multichannel support + colormap inherit

This commit is contained in:
2026-03-24 21:01:58 -07:00
parent 53e2fc7746
commit a60b0c15ca
12 changed files with 889 additions and 220 deletions

View File

@@ -16,6 +16,9 @@ from dataclasses import dataclass, field
import numpy as np
COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain",
"cividis", "magma", "copper", "afmhot")
@dataclass
class DataField:
data: np.ndarray # shape (yres, xres), dtype float64
@@ -28,6 +31,7 @@ class DataField:
si_unit_xy: str = "m"
si_unit_z: str = "m"
domain: str = "spatial" # "spatial" or "frequency"
colormap: str = "viridis"
def __post_init__(self) -> None:
self.data = np.asarray(self.data, dtype=np.float64)
@@ -48,6 +52,7 @@ class DataField:
si_unit_xy=self.si_unit_xy,
si_unit_z=self.si_unit_z,
domain=self.domain,
colormap=self.colormap,
)
def replace(self, **kwargs) -> "DataField":
@@ -63,6 +68,7 @@ class DataField:
"si_unit_xy": self.si_unit_xy,
"si_unit_z": self.si_unit_z,
"domain": self.domain,
"colormap": self.colormap,
}
base.update(kwargs)
return DataField(**base)

View File

@@ -50,6 +50,7 @@ class ExecutionEngine:
on_table: Callable[[str, list], None] | None = None,
on_mesh: Callable[[str, dict], None] | None = None,
on_overlay: Callable[[str, str], None] | None = None,
on_warning: Callable[[str, str], None] | None = None,
) -> dict[str, tuple]:
"""
Execute the workflow described by `prompt`.
@@ -62,6 +63,7 @@ class ExecutionEngine:
on_preview : called with (node_id, data_uri) when a display node runs
on_table : called with (node_id, table_list) when PrintTable runs
on_overlay : called with (node_id, data_uri) for interactive overlays
on_warning : called with (node_id, message) for node warnings
Returns
-------
@@ -71,7 +73,7 @@ class ExecutionEngine:
node_outputs: dict[str, tuple] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning)
for node_id in order:
node_def = prompt[node_id]
@@ -174,12 +176,13 @@ class ExecutionEngine:
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import SaveImage
from backend.nodes.io import SaveImage, LoadFile
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
@@ -190,6 +193,7 @@ class ExecutionEngine:
PrintTable._broadcast_table_fn = on_table
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
LoadFile._broadcast_warning_fn = on_warning
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
@@ -199,8 +203,9 @@ class ExecutionEngine:
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import LoadFile
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine):
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile):
cls._current_node_id = node_id
def _auto_preview(
@@ -232,7 +237,7 @@ class ExecutionEngine:
value = result[slot]
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
arr = datafield_to_uint8(value, "viridis")
arr = datafield_to_uint8(value, value.colormap)
on_preview(node_id, encode_preview(arr))
return # one preview per node is enough

View File

@@ -326,6 +326,7 @@ class FFT2D:
si_unit_xy="1/m",
si_unit_z=z_unit,
domain="frequency",
colormap=field.colormap,
)
return (out_field,)

View File

@@ -9,7 +9,7 @@ before execution begins.
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
from backend.data_types import DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview
@register_node(display_name="Preview")
@@ -18,7 +18,7 @@ class PreviewImage:
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
"colormap": (["auto"] + list(COLORMAPS),),
},
"optional": {
"image": ("IMAGE",),
@@ -36,6 +36,10 @@ class PreviewImage:
_current_node_id: str = ""
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
# Resolve "auto" — use field's colormap if available, else fall back to gray
if colormap == "auto":
colormap = field.colormap if field is not None else "gray"
# Prefer field if both are connected; accept whichever is provided
if field is not None:
arr_u8 = datafield_to_uint8(field, colormap)
@@ -73,7 +77,7 @@ class View3D:
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
"colormap": (["auto"] + list(COLORMAPS),),
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
}
@@ -114,7 +118,8 @@ class View3D:
else:
z_norm = np.zeros_like(z)
cmap = cm.get_cmap(colormap)
cmap_name = field.colormap if colormap == "auto" else colormap
cmap = cm.get_cmap(cmap_name)
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)

View File

@@ -8,7 +8,7 @@ import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview, image_to_uint8
from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8
from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
@@ -19,112 +19,293 @@ OUTPUT_DIR = output_dir()
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
# ---------------------------------------------------------------------------
# LoadImage
# Channel listing helper (used by the /channels endpoint)
# ---------------------------------------------------------------------------
@register_node(display_name="Load Image")
class LoadImage:
def _resolve_path(filepath: str) -> Path:
path = Path(filepath)
if path.is_absolute():
return path
# Try input dir first, then demo dir
candidate = INPUT_DIR / filepath
if candidate.exists():
return candidate
candidate = DEMO_DIR / filepath
if candidate.exists():
return candidate
# Fall back to input dir (will trigger FileNotFoundError later)
return INPUT_DIR / filepath
def list_channels(filepath: str) -> list[dict]:
"""Return available channel info for a file.
Returns a list of {"name": str, "type": "DATA_FIELD"} dicts.
For SPM formats this inspects the file header.
For images / arrays, returns a single unnamed channel.
"""
path = _resolve_path(filepath)
if not path.exists():
return [{"name": "field", "type": "DATA_FIELD"}]
ext = path.suffix.lower()
if ext == ".gwy":
try:
import gwyfile
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if channels:
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".sxm":
try:
import nanonispy as nap
sxm = nap.read.Scan(str(path))
if sxm.signals:
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".ibw":
try:
from igor.binarywave import load as load_ibw
wave = load_ibw(str(path))
raw = wave["wave"]["wData"]
labels = wave["wave"].get("labels", None)
if raw.ndim >= 3 and labels:
dim_idx = min(2, len(labels) - 1)
if dim_idx >= 0 and labels[dim_idx]:
decoded = []
for lbl in labels[dim_idx]:
if lbl:
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
if name:
decoded.append(name)
if decoded:
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
# Multi-channel without labels — use numeric names
if raw.ndim >= 3 and raw.shape[2] > 1:
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
# Image or array — single channel
return [{"name": "field", "type": "DATA_FIELD"}]
# ---------------------------------------------------------------------------
# LoadFile (unified loader — replaces LoadImage + LoadSPM)
# ---------------------------------------------------------------------------
@register_node(display_name="Load File")
class LoadFile:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
"colormap": (list(COLORMAPS),),
}
}
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
RETURN_NAMES = ("image", "field")
# Default outputs — overridden dynamically by the frontend for multi-channel files
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
def load(self, filename: str):
# Accept absolute paths or filenames relative to input/
path = Path(filename)
if not path.is_absolute():
path = INPUT_DIR / filename
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
ext = path.suffix.lower()
if ext in (".npy",):
arr = np.load(str(path)).astype(np.float64)
elif ext in (".npz",):
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image
img = Image.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
# Convert to float64 grayscale for the DATA_FIELD output
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
field = DataField(data=gray)
return (arr, field)
# ---------------------------------------------------------------------------
# LoadDemo
# ---------------------------------------------------------------------------
def _list_demo_files() -> list[str]:
"""Return sorted list of demo filenames available in the demo/ directory."""
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
DESCRIPTION = (
"Load any supported file. "
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
"each channel gets its own output. "
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
)
# Set by execution engine for warning broadcast
_broadcast_warning_fn = None
_current_node_id = None
@register_node(display_name="Load Demo Image")
class LoadDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo images found)"]
return {
"required": {
"name": (choices,),
}
}
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
RETURN_NAMES = ("image", "field")
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data."
def load(self, name: str):
path = DEMO_DIR / name
def load(self, filename: str, colormap: str = "viridis"):
if not filename or not filename.strip():
raise ValueError("No file selected — use Browse to pick a file.")
path = _resolve_path(filename)
if not path.exists():
raise FileNotFoundError(f"Demo image not found: {name}")
raise FileNotFoundError(f"File not found: {path}")
if path.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path}")
ext = path.suffix.lower()
# SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD
if ext == ".gwy":
field = LoadSPM()._load_gwy(path, "Z")
arr = field.data
return (arr, field)
elif ext == ".sxm":
field = LoadSPM()._load_sxm(path, "Z")
arr = field.data
return (arr, field)
elif ext == ".ibw":
field = LoadSPM()._load_ibw(path)
arr = field.data
return (arr, field)
if ext in _SPM_EXTENSIONS:
fields = self._load_spm_all(path, ext)
for f in fields:
f.colormap = colormap
return tuple(fields)
# npy / npz
# Image or array — uncalibrated, single output
field = self._load_image_or_array(path, ext)
field.colormap = colormap
self._send_warning("Uncalibrated data — no physical dimensions.")
return (field,)
def _send_warning(self, message: str):
fn = LoadFile._broadcast_warning_fn
nid = LoadFile._current_node_id
if fn and nid:
fn(nid, message)
# -- SPM: load all channels ---------------------------------------------
def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
if ext == ".gwy":
return self._load_gwy_all(path)
elif ext == ".sxm":
return self._load_sxm_all(path)
elif ext == ".ibw":
return self._load_ibw_all(path)
else:
raise ValueError(f"Unsupported SPM format: {ext}")
# -- GWY ----------------------------------------------------------------
def _load_gwy_all(self, path: Path) -> list[DataField]:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
fields = []
for ch in channels.values():
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
fields.append(DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
))
return fields
# -- SXM ----------------------------------------------------------------
def _load_sxm_all(self, path: Path) -> list[DataField]:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
if not signals:
raise ValueError(f"No signals found in {path.name}")
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
fields = []
for sig in signals.values():
data = sig.get("forward", list(sig.values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
fields.append(DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
))
return fields
# -- IBW ----------------------------------------------------------------
def _load_ibw_all(self, path: Path) -> list[DataField]:
try:
from igor.binarywave import load as load_ibw
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
wave = load_ibw(str(path))
wdata = wave["wave"]
header = wdata["wave_header"]
raw = wdata["wData"]
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
# Physical scaling
sfA = header.get("sfA", None)
def _decode_unit(raw_unit):
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
dim_units_raw = header.get("dimUnits", None)
data_units_raw = header.get("dataUnits", None)
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
si_unit_xy = _decode_unit(dim_units_raw[0])
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
si_unit_xy = _decode_unit(dim_units_raw[0])
else:
si_unit_xy = _decode_unit(dim_units_raw)
si_unit_z = _decode_unit(data_units_raw)
fields = []
for ch_idx in range(n_channels):
if raw.ndim >= 3:
ch_data = raw[:, :, ch_idx]
elif raw.ndim == 1:
ch_data = raw.reshape(-1, 1)
else:
ch_data = raw
# Transpose from (xres, yres) Igor order to (yres, xres) DataField order,
# then flip vertically to match gwyddion
data = np.flipud(ch_data.T).astype(np.float64)
yres, xres = data.shape
if sfA is not None and len(sfA) >= 2:
xreal = abs(float(sfA[0]) * xres) or 1e-6
yreal = abs(float(sfA[1]) * yres) or 1e-6
else:
hsA = header.get("hsA", 0.0)
xreal = abs(float(hsA) * xres) or 1e-6
yreal = xreal * (yres / xres) if xres else 1e-6
fields.append(DataField(
data=data, xreal=xreal, yreal=yreal,
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
))
return fields
# -- Image / array (uncalibrated) --------------------------------------
def _load_image_or_array(self, path: Path, ext: str) -> DataField:
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
@@ -143,22 +324,32 @@ class LoadDemo:
else:
gray = arr.astype(np.float64)
field = DataField(data=gray)
return (arr, field)
return DataField(data=gray)
# ---------------------------------------------------------------------------
# LoadSPM
# LoadDemo
# ---------------------------------------------------------------------------
@register_node(display_name="Load SPM File")
class LoadSPM:
def _list_demo_files() -> list[str]:
"""Return sorted list of demo filenames available in the demo/ directory."""
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
)
@register_node(display_name="Load Demo File")
class LoadDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo files found)"]
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
"channel": ("STRING", {"default": "Z"}),
"name": (choices,),
"colormap": (list(COLORMAPS),),
}
}
@@ -166,111 +357,15 @@ class LoadSPM:
RETURN_NAMES = ("field",)
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
def load(self, filename: str, channel: str = "Z"):
path = Path(filename)
if not path.is_absolute():
path = INPUT_DIR / filename
def load(self, name: str, colormap: str = "viridis"):
path = DEMO_DIR / name
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
raise FileNotFoundError(f"Demo file not found: {name}")
ext = path.suffix.lower()
if ext == ".gwy":
return (self._load_gwy(path, channel),)
elif ext == ".sxm":
return (self._load_sxm(path, channel),)
elif ext in (".ibw",):
return (self._load_ibw(path),)
elif ext in (".npy",):
data = np.load(str(path)).astype(np.float64)
return (DataField(data=data),)
elif ext in (".npz",):
npz = np.load(str(path))
key = list(npz.files)[0]
return (DataField(data=npz[key].astype(np.float64)),)
else:
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
def _load_gwy(self, path: Path, channel: str) -> DataField:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
# Try requested channel name, fall back to first available
ch = None
for key, df in channels.items():
if channel.lower() in key.lower():
ch = df
break
if ch is None:
ch = next(iter(channels.values()))
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
return DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
)
def _load_sxm(self, path: Path, channel: str) -> DataField:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
# Pick channel
ch_key = None
for key in signals:
if channel.upper() in key.upper():
ch_key = key
break
if ch_key is None:
ch_key = next(iter(signals))
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
return DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
)
def _load_ibw(self, path: Path) -> DataField:
try:
import igor.igorpy as igorpy
wave = igorpy.load(str(path))
data = wave.wave["wData"].squeeze().astype(np.float64)
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
if data.ndim == 1:
data = data.reshape(1, -1)
elif data.ndim != 2:
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
loader = LoadFile()
return loader.load(filename=str(path), colormap=colormap)
# ---------------------------------------------------------------------------

View File

@@ -101,6 +101,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
def on_overlay(node_id: str, overlay_data) -> None:
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
def on_warning(node_id: str, message: str) -> None:
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}})
# ------------------------------------------------------------------
# Route handlers
# ------------------------------------------------------------------
@@ -193,6 +196,18 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
},
)
async def get_channels(request: web.Request) -> web.Response:
"""Return available channels for a given file path."""
from backend.nodes.io import list_channels
filepath = request.query.get("file", "")
if not filepath:
return web.Response(
text=_dumps([{"name": "field", "type": "DATA_FIELD"}]),
content_type="application/json",
)
channels = await loop.run_in_executor(None, list_channels, filepath)
return web.Response(text=_dumps(channels), content_type="application/json")
async def submit_prompt(request: web.Request) -> web.Response:
body = await request.json()
prompt = body.get("prompt")
@@ -218,6 +233,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
on_table=on_table,
on_mesh=on_mesh,
on_overlay=on_overlay,
on_warning=on_warning,
),
)
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
@@ -262,6 +278,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
app.router.add_get("/browse", browse_dir)
app.router.add_post("/upload", upload_file)
app.router.add_post("/download", download_file)
app.router.add_get("/channels", get_channels)
app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler)

View File

@@ -436,6 +436,9 @@ function Flow() {
case 'overlay':
updateNodeData(msg.data.node_id, { overlay: msg.data.overlay });
break;
case 'node_warning':
updateNodeData(msg.data.node_id, { warning: msg.data.message });
break;
}
});
api.initWS();
@@ -500,9 +503,36 @@ function Flow() {
data: {
...n.data,
widgetValues: { ...n.data.widgetValues, [name]: value },
// Clear warning when user changes a value
warning: null,
},
};
}));
// If this is a filename/name change on a LoadFile/LoadDemo node, fetch channels
if ((name === 'filename' || name === 'name') && value) {
const node = reactFlow.getNode(nodeId);
if (node && (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo')) {
api.getChannels(value).then((channels) => {
setNodes((prev) => prev.map((n) => {
if (n.id !== nodeId) return n;
return {
...n,
data: {
...n.data,
definition: {
...n.data.definition,
output: channels.map((c) => c.type),
output_name: channels.map((c) => c.name),
},
},
};
}));
reactFlow.updateNodeInternals(nodeId);
});
}
}
scheduleAutoRun();
}, [setNodes]); // scheduleAutoRun is stable (no deps)
@@ -568,6 +598,27 @@ function Flow() {
setNodes((ns) => [...ns, newNode]);
// For LoadFile/LoadDemo, auto-fetch channels for the default value
if (className === 'LoadDemo' && widgetValues.name) {
api.getChannels(widgetValues.name).then((channels) => {
setNodes((prev) => prev.map((n) => {
if (n.id !== newNodeId) return n;
return {
...n,
data: {
...n.data,
definition: {
...n.data.definition,
output: channels.map((c) => c.type),
output_name: channels.map((c) => c.name),
},
},
};
}));
reactFlow.updateNodeInternals(newNodeId);
});
}
// Auto-connect if this was triggered by dropping a connection on blank space
if (contextMenu.pendingHandleId) {
const filterType = contextMenu.filterType;

View File

@@ -211,6 +211,11 @@ function CustomNode({ id, data }) {
);
})}
{/* Warning notification */}
{data.warning && (
<div className="node-warning">{data.warning}</div>
)}
{/* Widget rows */}
{widgets.map((w) => (
<div className="widget-row" key={w.name}>

View File

@@ -34,6 +34,12 @@ export async function uploadFile(file) {
return r.json();
}
export async function getChannels(filepath) {
const r = await fetch(`/channels?file=${encodeURIComponent(filepath)}`);
if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }];
return r.json();
}
export async function runPrompt(prompt) {
const r = await fetch('/prompt', {
method: 'POST',

View File

@@ -82,10 +82,7 @@ html, body, #root {
padding: 4px 10px;
border-radius: 4px;
font-size: 11px;
max-width: 400px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
max-width: 60%;
flex-shrink: 1;
}
.status-bar.info { color: #90caf9; }
@@ -155,6 +152,15 @@ html, body, #root {
padding: 4px 0;
}
.node-warning {
padding: 3px 10px;
font-size: 10px;
color: #fbbf24;
background: rgba(251, 191, 36, 0.1);
border-top: 1px solid rgba(251, 191, 36, 0.2);
border-bottom: 1px solid rgba(251, 191, 36, 0.2);
}
/* ── I/O rows ──────────────────────────────────────────────────────── */
.io-row {
display: flex;

View File

@@ -9,6 +9,7 @@ export default defineConfig({
'/nodes': 'http://127.0.0.1:8188',
'/files': 'http://127.0.0.1:8188',
'/browse': 'http://127.0.0.1:8188',
'/channels': 'http://127.0.0.1:8188',
'/upload': 'http://127.0.0.1:8188',
'/download': 'http://127.0.0.1:8188',
'/prompt': 'http://127.0.0.1:8188',

View File

@@ -523,41 +523,42 @@ def test_particle_analysis():
# I/O
# =========================================================================
def test_load_image():
print("=== Test: LoadImage ===")
from backend.nodes.io import LoadImage
def test_load_file():
print("=== Test: LoadFile ===")
from backend.nodes.io import LoadFile
from PIL import Image
node = LoadImage()
node = LoadFile()
with tempfile.TemporaryDirectory() as tmpdir:
# Test loading a grayscale PNG
# Test loading a grayscale PNG → single DataField output
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
img = Image.fromarray(arr, mode="L")
path = os.path.join(tmpdir, "test_gray.png")
img.save(path)
image, field = node.load(filename=path)
assert image.shape == (48, 64)
result = node.load(filename=path)
assert len(result) == 1
field = result[0]
assert field.data.shape == (48, 64)
assert field.data.dtype == np.float64
# Test loading an RGB PNG (should average to grayscale for field)
# Test loading an RGB PNG (should average to grayscale)
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
path_rgb = os.path.join(tmpdir, "test_rgb.png")
img_rgb.save(path_rgb)
image_rgb, field_rgb = node.load(filename=path_rgb)
assert image_rgb.shape == (32, 32, 3)
assert field_rgb.data.shape == (32, 32)
result_rgb = node.load(filename=path_rgb)
assert len(result_rgb) == 1
assert result_rgb[0].data.shape == (32, 32)
# Test loading a .npy file
data_npy = np.random.default_rng(3).standard_normal((50, 60))
path_npy = os.path.join(tmpdir, "test.npy")
np.save(path_npy, data_npy)
image_npy, field_npy = node.load(filename=path_npy)
assert np.allclose(field_npy.data, data_npy)
result_npy = node.load(filename=path_npy)
assert np.allclose(result_npy[0].data, data_npy)
print(" PASS\n")
@@ -641,6 +642,464 @@ def test_print_table():
print(" PASS\n")
# =========================================================================
# I/O — IBW multi-channel loading
# =========================================================================
def test_load_file_ibw():
print("=== Test: LoadFile IBW multi-channel ===")
from backend.nodes.io import LoadFile
node = LoadFile()
ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
ibw_path = os.path.abspath(ibw_path)
if not os.path.exists(ibw_path):
print(" SKIP (demo IBW file not found)\n")
return
result = node.load(filename=ibw_path)
# BR_New20012.ibw has 4 channels
assert len(result) == 4, f"Expected 4 channels, got {len(result)}"
for i, field in enumerate(result):
assert isinstance(field, DataField), f"Channel {i} is not a DataField"
assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}"
assert field.data.dtype == np.float64
# Physical dimensions should be populated (not default 1e-6)
assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}"
assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}"
assert field.si_unit_xy == "m"
assert field.si_unit_z == "m"
# All channels should share the same physical dimensions
assert result[0].xreal == result[1].xreal
assert result[0].yreal == result[1].yreal
# Different channels should have different data
assert not np.array_equal(result[0].data, result[1].data)
print(" PASS\n")
def test_load_file_npz():
print("=== Test: LoadFile .npz ===")
from backend.nodes.io import LoadFile
node = LoadFile()
with tempfile.TemporaryDirectory() as tmpdir:
data = np.random.default_rng(99).standard_normal((30, 40))
path = os.path.join(tmpdir, "test.npz")
np.savez(path, my_array=data)
result = node.load(filename=path)
assert len(result) == 1
assert np.allclose(result[0].data, data)
print(" PASS\n")
def test_load_file_not_found():
print("=== Test: LoadFile not found ===")
from backend.nodes.io import LoadFile
node = LoadFile()
try:
node.load(filename="/nonexistent/path/file.png")
assert False, "Should have raised FileNotFoundError"
except FileNotFoundError:
pass
print(" PASS\n")
def test_load_file_unsupported():
print("=== Test: LoadFile unsupported format ===")
from backend.nodes.io import LoadFile
node = LoadFile()
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "test.xyz")
with open(path, "w") as f:
f.write("hello")
try:
node.load(filename=path)
assert False, "Should have raised an error for .xyz"
except Exception:
pass
print(" PASS\n")
def test_load_file_warning():
print("=== Test: LoadFile warning for uncalibrated data ===")
from backend.nodes.io import LoadFile
from PIL import Image
node = LoadFile()
warnings = []
LoadFile._broadcast_warning_fn = lambda nid, msg: warnings.append(msg)
LoadFile._current_node_id = "test"
with tempfile.TemporaryDirectory() as tmpdir:
arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8)
img = Image.fromarray(arr)
path = os.path.join(tmpdir, "test.png")
img.save(path)
result = node.load(filename=path)
assert len(result) == 1
assert len(warnings) == 1
assert "Uncalibrated" in warnings[0]
LoadFile._broadcast_warning_fn = None
print(" PASS\n")
# =========================================================================
# I/O — list_channels helper
# =========================================================================
def test_list_channels():
print("=== Test: list_channels ===")
from backend.nodes.io import list_channels
# Non-existent file → default
ch = list_channels("/nonexistent/file.ibw")
assert len(ch) == 1
assert ch[0]["name"] == "field"
# IBW with channels
ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw"))
if os.path.exists(ibw_path):
ch = list_channels(ibw_path)
assert len(ch) == 4
names = [c["name"] for c in ch]
assert "HeightRetrace" in names
assert "AmplitudeRetrace" in names
assert all(c["type"] == "DATA_FIELD" for c in ch)
# Plain image → single default channel
with tempfile.TemporaryDirectory() as tmpdir:
from PIL import Image
img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
path = os.path.join(tmpdir, "test.png")
img.save(path)
ch = list_channels(path)
assert len(ch) == 1
assert ch[0]["name"] == "field"
# .npy → single default channel
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "test.npy")
np.save(path, np.zeros((4, 4)))
ch = list_channels(path)
assert len(ch) == 1
print(" PASS\n")
# =========================================================================
# I/O — LoadDemo
# =========================================================================
def test_load_demo():
print("=== Test: LoadDemo ===")
from backend.nodes.io import LoadDemo
node = LoadDemo()
# Should be able to load a demo file by name
result = node.load(name="nanoparticles.npy")
assert len(result) >= 1
assert isinstance(result[0], DataField)
assert result[0].data.ndim == 2
# IBW demo should return multiple channels
result_ibw = node.load(name="whiskers.ibw")
assert len(result_ibw) == 4
for field in result_ibw:
assert isinstance(field, DataField)
# Non-existent demo should raise
try:
node.load(name="nonexistent_file.png")
assert False, "Should have raised FileNotFoundError"
except FileNotFoundError:
pass
print(" PASS\n")
# =========================================================================
# I/O — Coordinate
# =========================================================================
def test_coordinate():
print("=== Test: Coordinate ===")
from backend.nodes.io import Coordinate
node = Coordinate()
result = node.process(x=0.3, y=0.7)
assert len(result) == 1
assert result[0] == (0.3, 0.7)
# Edge values
result_zero = node.process(x=0.0, y=0.0)
assert result_zero[0] == (0.0, 0.0)
result_one = node.process(x=1.0, y=1.0)
assert result_one[0] == (1.0, 1.0)
print(" PASS\n")
# =========================================================================
# Analysis — LineCursors
# =========================================================================
def test_line_cursors():
print("=== Test: LineCursors ===")
from backend.nodes.analysis import LineCursors
node = LineCursors()
# Create a simple linear ramp
line = np.linspace(0, 10, 100).astype(np.float64)
# Capture overlay
overlays = []
LineCursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
LineCursors._current_node_id = "test"
table, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
# Should produce a 6-row table
assert len(table) == 6
quantities = {row["quantity"] for row in table}
assert "A position" in quantities
assert "B position" in quantities
assert "delta X" in quantities
assert "delta Y" in quantities
# B should be at a later position than A
a_pos = next(r["value"] for r in table if r["quantity"] == "A position")
b_pos = next(r["value"] for r in table if r["quantity"] == "B position")
assert b_pos > a_pos
# Delta Y should reflect the height difference along the ramp
dy = next(r["value"] for r in table if r["quantity"] == "delta Y")
assert dy > 0 # ramp goes upward
# Overlay should have been broadcast
assert len(overlays) == 1
assert "image" in overlays[0]
assert overlays[0]["image"].startswith("data:image/png;base64,")
# With x_axis provided
x_axis = np.linspace(0, 1, 100).astype(np.float64)
table2, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5, x_axis=x_axis)
assert len(table2) == 6
LineCursors._broadcast_overlay_fn = None
print(" PASS\n")
# =========================================================================
# Analysis — FFT2D
# =========================================================================
def test_fft2d():
print("=== Test: FFT2D ===")
from backend.nodes.analysis import FFT2D
node = FFT2D()
# Pure single-frequency signal: peak should appear at the right location
N = 64
y, x = np.mgrid[0:N, 0:N] / N
freq = 5
data = np.sin(2 * np.pi * freq * x)
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# log_magnitude
spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude")
assert spectrum.data.shape == (N, N)
assert spectrum.domain == "frequency"
assert spectrum.si_unit_xy == "1/m"
# Peak should be symmetric about centre
centre = N // 2
row = spectrum.data[centre, :]
peak_idx = np.argmax(row[centre + 1:]) + centre + 1
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
# magnitude output
spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude")
assert spec_mag.data.shape == (N, N)
assert np.all(spec_mag.data >= 0)
# phase output
spec_phase, = node.process(field, windowing="none", level="none", output="phase")
assert spec_phase.data.shape == (N, N)
assert spec_phase.data.min() >= -np.pi - 0.01
assert spec_phase.data.max() <= np.pi + 0.01
# psdf output — units should reflect PSDF calibration
spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf")
assert spec_psdf.data.shape == (N, N)
assert np.all(spec_psdf.data >= 0)
assert "^2" in spec_psdf.si_unit_z
# Constant field should have all energy at DC
const_field = make_field(data=np.ones((32, 32)) * 3.0)
spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude")
centre32 = 16
dc_val = spec_const.data[centre32, centre32]
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
# Blackman windowing should also work without error
spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude")
assert spec_bk.data.shape == (N, N)
print(" PASS\n")
# =========================================================================
# Analysis — LineMath
# =========================================================================
def test_line_math():
print("=== Test: LineMath ===")
from backend.nodes.analysis import LineMath
node = LineMath()
line = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Basic stats
table, = node.process(line, operation="min")
assert table[0]["value"] == 1.0
table, = node.process(line, operation="max")
assert table[0]["value"] == 5.0
table, = node.process(line, operation="mean")
assert table[0]["value"] == 3.0
table, = node.process(line, operation="median")
assert table[0]["value"] == 3.0
table, = node.process(line, operation="sum")
assert table[0]["value"] == 15.0
table, = node.process(line, operation="range")
assert table[0]["value"] == 4.0
table, = node.process(line, operation="length")
assert table[0]["value"] == 5.0
# RMS of [1,2,3,4,5]
table, = node.process(line, operation="rms")
expected_rms = np.sqrt(np.mean(line ** 2))
assert abs(table[0]["value"] - expected_rms) < 1e-10
# Roughness parameters
table, = node.process(line, operation="Ra")
d = line - line.mean()
expected_ra = float(np.mean(np.abs(d)))
assert abs(table[0]["value"] - expected_ra) < 1e-10
table, = node.process(line, operation="Rq")
expected_rq = float(np.sqrt(np.mean(d ** 2)))
assert abs(table[0]["value"] - expected_rq) < 1e-10
# Rp = max of (z - mean)
table, = node.process(line, operation="Rp")
assert abs(table[0]["value"] - d.max()) < 1e-10
# Rv = -(min of (z - mean))
table, = node.process(line, operation="Rv")
assert abs(table[0]["value"] - (-d.min())) < 1e-10
# Rt = Rp + Rv = range of (z - mean)
table, = node.process(line, operation="Rt")
assert abs(table[0]["value"] - (d.max() - d.min())) < 1e-10
# Constant line: roughness parameters should all be zero
const_line = np.ones(10) * 7.0
table, = node.process(const_line, operation="Ra")
assert table[0]["value"] == 0.0
table, = node.process(const_line, operation="Rq")
assert table[0]["value"] == 0.0
table, = node.process(const_line, operation="Rsk")
assert table[0]["value"] == 0.0
table, = node.process(const_line, operation="Rku")
assert table[0]["value"] == 0.0
# Slope-based: Dq and Da
table, = node.process(line, operation="Dq")
dz = np.diff(line)
expected_dq = float(np.sqrt(np.mean(dz * dz)))
assert abs(table[0]["value"] - expected_dq) < 1e-10
table, = node.process(line, operation="Da")
expected_da = float(np.mean(np.abs(dz)))
assert abs(table[0]["value"] - expected_da) < 1e-10
print(" PASS\n")
# =========================================================================
# Display — View3D
# =========================================================================
def test_view3d():
print("=== Test: View3D ===")
from backend.nodes.display import View3D
node = View3D()
field = make_field()
captured = []
View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh)
View3D._current_node_id = "test"
result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64)
assert result == ()
assert len(captured) == 1
mesh = captured[0]
assert "width" in mesh
assert "height" in mesh
assert "z_data" in mesh
assert "colors" in mesh
assert mesh["z_scale"] == 2.0
assert mesh["width"] <= 64
assert mesh["height"] <= 64
# z_min < z_max for non-constant data
assert mesh["z_min"] < mesh["z_max"]
# Verify base64 data can be decoded
import base64
z_bytes = base64.b64decode(mesh["z_data"])
assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32
colors_bytes = base64.b64decode(mesh["colors"])
assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 # uint8 RGB
# High-res input should be downsampled
big_field = make_field(shape=(256, 256))
captured.clear()
node.render(big_field, colormap="hot", z_scale=1.0, resolution=64)
assert captured[0]["width"] <= 64
assert captured[0]["height"] <= 64
View3D._broadcast_mesh_fn = None
print(" PASS\n")
# =========================================================================
# Run all tests
# =========================================================================
@@ -662,6 +1121,9 @@ if __name__ == "__main__":
test_statistics()
test_height_histogram()
test_cross_section()
test_line_cursors()
test_fft2d()
test_line_math()
# Mask
test_threshold_mask()
@@ -673,11 +1135,20 @@ if __name__ == "__main__":
test_particle_analysis()
# I/O
test_load_image()
test_load_file()
test_load_file_ibw()
test_load_file_npz()
test_load_file_not_found()
test_load_file_unsupported()
test_load_file_warning()
test_list_channels()
test_load_demo()
test_coordinate()
test_save_image()
# Display
test_preview_image()
test_print_table()
test_view3d()
print("All tests passed!")