From a60b0c15ca38cbdac274d5e5ea21d7c6218e1a92 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Tue, 24 Mar 2026 21:01:58 -0700 Subject: [PATCH] multichannel support + colormap inherit --- backend/data_types.py | 6 + backend/execution.py | 13 +- backend/nodes/analysis.py | 1 + backend/nodes/display.py | 13 +- backend/nodes/io.py | 483 ++++++++++++++++++++-------------- backend/server.py | 17 ++ frontend/src/App.jsx | 51 ++++ frontend/src/CustomNode.jsx | 5 + frontend/src/api.js | 6 + frontend/src/styles.css | 14 +- frontend/vite.config.js | 1 + tests/test_nodes.py | 499 +++++++++++++++++++++++++++++++++++- 12 files changed, 889 insertions(+), 220 deletions(-) diff --git a/backend/data_types.py b/backend/data_types.py index f1bdbba..1fff322 100644 --- a/backend/data_types.py +++ b/backend/data_types.py @@ -16,6 +16,9 @@ from dataclasses import dataclass, field import numpy as np +COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain", + "cividis", "magma", "copper", "afmhot") + @dataclass class DataField: data: np.ndarray # shape (yres, xres), dtype float64 @@ -28,6 +31,7 @@ class DataField: si_unit_xy: str = "m" si_unit_z: str = "m" domain: str = "spatial" # "spatial" or "frequency" + colormap: str = "viridis" def __post_init__(self) -> None: self.data = np.asarray(self.data, dtype=np.float64) @@ -48,6 +52,7 @@ class DataField: si_unit_xy=self.si_unit_xy, si_unit_z=self.si_unit_z, domain=self.domain, + colormap=self.colormap, ) def replace(self, **kwargs) -> "DataField": @@ -63,6 +68,7 @@ class DataField: "si_unit_xy": self.si_unit_xy, "si_unit_z": self.si_unit_z, "domain": self.domain, + "colormap": self.colormap, } base.update(kwargs) return DataField(**base) diff --git a/backend/execution.py b/backend/execution.py index b9a52af..4172dfb 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -50,6 +50,7 @@ class ExecutionEngine: on_table: Callable[[str, list], None] | None = None, on_mesh: Callable[[str, dict], None] | None = None, on_overlay: Callable[[str, str], None] | None = None, + on_warning: Callable[[str, str], None] | None = None, ) -> dict[str, tuple]: """ Execute the workflow described by `prompt`. @@ -62,6 +63,7 @@ class ExecutionEngine: on_preview : called with (node_id, data_uri) when a display node runs on_table : called with (node_id, table_list) when PrintTable runs on_overlay : called with (node_id, data_uri) for interactive overlays + on_warning : called with (node_id, message) for node warnings Returns ------- @@ -71,7 +73,7 @@ class ExecutionEngine: node_outputs: dict[str, tuple] = {} # Inject display callbacks before execution - self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay) + self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning) for node_id in order: node_def = prompt[node_id] @@ -174,12 +176,13 @@ class ExecutionEngine: on_table: Callable | None, on_mesh: Callable | None = None, on_overlay: Callable | None = None, + on_warning: Callable | None = None, ) -> None: """Wire up broadcast callbacks on display node classes.""" from backend.nodes.display import PreviewImage, PrintTable, View3D from backend.nodes.analysis import CrossSection, LineCursors from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine - from backend.nodes.io import SaveImage + from backend.nodes.io import SaveImage, LoadFile PreviewImage._broadcast_fn = on_preview ThresholdMask._broadcast_fn = on_preview @@ -190,6 +193,7 @@ class ExecutionEngine: PrintTable._broadcast_table_fn = on_table CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay + LoadFile._broadcast_warning_fn = on_warning SaveImage._broadcast_preview = ( (lambda data_uri: on_preview("save", data_uri)) if on_preview else None ) @@ -199,8 +203,9 @@ class ExecutionEngine: from backend.nodes.display import PreviewImage, PrintTable, View3D from backend.nodes.analysis import CrossSection, LineCursors from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine + from backend.nodes.io import LoadFile if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, - ThresholdMask, MaskMorphology, MaskInvert, MaskCombine): + ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile): cls._current_node_id = node_id def _auto_preview( @@ -232,7 +237,7 @@ class ExecutionEngine: value = result[slot] if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview: - arr = datafield_to_uint8(value, "viridis") + arr = datafield_to_uint8(value, value.colormap) on_preview(node_id, encode_preview(arr)) return # one preview per node is enough diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 60c6b30..fc9e6a1 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -326,6 +326,7 @@ class FFT2D: si_unit_xy="1/m", si_unit_z=z_unit, domain="frequency", + colormap=field.colormap, ) return (out_field,) diff --git a/backend/nodes/display.py b/backend/nodes/display.py index c12c7ab..c6a1a0f 100644 --- a/backend/nodes/display.py +++ b/backend/nodes/display.py @@ -9,7 +9,7 @@ before execution begins. from __future__ import annotations import numpy as np from backend.node_registry import register_node -from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview +from backend.data_types import DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview @register_node(display_name="Preview") @@ -18,7 +18,7 @@ class PreviewImage: def INPUT_TYPES(cls): return { "required": { - "colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],), + "colormap": (["auto"] + list(COLORMAPS),), }, "optional": { "image": ("IMAGE",), @@ -36,6 +36,10 @@ class PreviewImage: _current_node_id: str = "" def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple: + # Resolve "auto" — use field's colormap if available, else fall back to gray + if colormap == "auto": + colormap = field.colormap if field is not None else "gray" + # Prefer field if both are connected; accept whichever is provided if field is not None: arr_u8 = datafield_to_uint8(field, colormap) @@ -73,7 +77,7 @@ class View3D: return { "required": { "field": ("DATA_FIELD",), - "colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],), + "colormap": (["auto"] + list(COLORMAPS),), "z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}), "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), } @@ -114,7 +118,8 @@ class View3D: else: z_norm = np.zeros_like(z) - cmap = cm.get_cmap(colormap) + cmap_name = field.colormap if colormap == "auto" else colormap + cmap = cm.get_cmap(cmap_name) rgba = cmap(z_norm) # (ny, nx, 4) float [0,1] colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8) diff --git a/backend/nodes/io.py b/backend/nodes/io.py index 7a44f46..3d432c6 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -8,7 +8,7 @@ import numpy as np from pathlib import Path from backend.node_registry import register_node -from backend.data_types import DataField, encode_preview, image_to_uint8 +from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8 from backend.runtime_paths import demo_dir, input_dir, output_dir # Resolved at server startup so nodes know where to look @@ -19,112 +19,293 @@ OUTPUT_DIR = output_dir() _DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", ".gwy", ".sxm", ".ibw"} +_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"} +_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"} +_ARRAY_EXTENSIONS = {".npy", ".npz"} + # --------------------------------------------------------------------------- -# LoadImage +# Channel listing helper (used by the /channels endpoint) # --------------------------------------------------------------------------- -@register_node(display_name="Load Image") -class LoadImage: +def _resolve_path(filepath: str) -> Path: + path = Path(filepath) + if path.is_absolute(): + return path + # Try input dir first, then demo dir + candidate = INPUT_DIR / filepath + if candidate.exists(): + return candidate + candidate = DEMO_DIR / filepath + if candidate.exists(): + return candidate + # Fall back to input dir (will trigger FileNotFoundError later) + return INPUT_DIR / filepath + + +def list_channels(filepath: str) -> list[dict]: + """Return available channel info for a file. + + Returns a list of {"name": str, "type": "DATA_FIELD"} dicts. + For SPM formats this inspects the file header. + For images / arrays, returns a single unnamed channel. + """ + path = _resolve_path(filepath) + if not path.exists(): + return [{"name": "field", "type": "DATA_FIELD"}] + + ext = path.suffix.lower() + + if ext == ".gwy": + try: + import gwyfile + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if channels: + return [{"name": k, "type": "DATA_FIELD"} for k in channels] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + if ext == ".sxm": + try: + import nanonispy as nap + sxm = nap.read.Scan(str(path)) + if sxm.signals: + return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + if ext == ".ibw": + try: + from igor.binarywave import load as load_ibw + wave = load_ibw(str(path)) + raw = wave["wave"]["wData"] + labels = wave["wave"].get("labels", None) + if raw.ndim >= 3 and labels: + dim_idx = min(2, len(labels) - 1) + if dim_idx >= 0 and labels[dim_idx]: + decoded = [] + for lbl in labels[dim_idx]: + if lbl: + name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip() + if name: + decoded.append(name) + if decoded: + return [{"name": n, "type": "DATA_FIELD"} for n in decoded] + # Multi-channel without labels — use numeric names + if raw.ndim >= 3 and raw.shape[2] > 1: + return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + # Image or array — single channel + return [{"name": "field", "type": "DATA_FIELD"}] + + +# --------------------------------------------------------------------------- +# LoadFile (unified loader — replaces LoadImage + LoadSPM) +# --------------------------------------------------------------------------- + +@register_node(display_name="Load File") +class LoadFile: @classmethod def INPUT_TYPES(cls): return { "required": { "filename": ("FILE_PICKER", {"default": ""}), + "colormap": (list(COLORMAPS),), } } - RETURN_TYPES = ("IMAGE", "DATA_FIELD") - RETURN_NAMES = ("image", "field") + # Default outputs — overridden dynamically by the frontend for multi-channel files + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("field",) FUNCTION = "load" CATEGORY = "io" - DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD." - - def load(self, filename: str): - # Accept absolute paths or filenames relative to input/ - path = Path(filename) - if not path.is_absolute(): - path = INPUT_DIR / filename - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") - - ext = path.suffix.lower() - if ext in (".npy",): - arr = np.load(str(path)).astype(np.float64) - elif ext in (".npz",): - npz = np.load(str(path)) - key = list(npz.files)[0] - arr = npz[key].astype(np.float64) - else: - from PIL import Image - img = Image.open(str(path)) - arr = np.array(img) - if arr.dtype != np.uint8: - arr = arr.astype(np.float64) - - # Convert to float64 grayscale for the DATA_FIELD output - if arr.ndim == 3: - gray = np.mean(arr.astype(np.float64), axis=2) - else: - gray = arr.astype(np.float64) - - field = DataField(data=gray) - return (arr, field) - - -# --------------------------------------------------------------------------- -# LoadDemo -# --------------------------------------------------------------------------- - -def _list_demo_files() -> list[str]: - """Return sorted list of demo filenames available in the demo/ directory.""" - if not DEMO_DIR.exists(): - return [] - return sorted( - f.name for f in DEMO_DIR.iterdir() - if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + DESCRIPTION = ( + "Load any supported file. " + "SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; " + "each channel gets its own output. " + "Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields." ) + # Set by execution engine for warning broadcast + _broadcast_warning_fn = None + _current_node_id = None -@register_node(display_name="Load Demo Image") -class LoadDemo: - @classmethod - def INPUT_TYPES(cls): - choices = _list_demo_files() or ["(no demo images found)"] - return { - "required": { - "name": (choices,), - } - } - - RETURN_TYPES = ("IMAGE", "DATA_FIELD") - RETURN_NAMES = ("image", "field") - FUNCTION = "load" - CATEGORY = "io" - DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data." - - def load(self, name: str): - path = DEMO_DIR / name + def load(self, filename: str, colormap: str = "viridis"): + if not filename or not filename.strip(): + raise ValueError("No file selected — use Browse to pick a file.") + path = _resolve_path(filename) if not path.exists(): - raise FileNotFoundError(f"Demo image not found: {name}") + raise FileNotFoundError(f"File not found: {path}") + if path.is_dir(): + raise IsADirectoryError(f"Expected a file, got a directory: {path}") ext = path.suffix.lower() - # SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD - if ext == ".gwy": - field = LoadSPM()._load_gwy(path, "Z") - arr = field.data - return (arr, field) - elif ext == ".sxm": - field = LoadSPM()._load_sxm(path, "Z") - arr = field.data - return (arr, field) - elif ext == ".ibw": - field = LoadSPM()._load_ibw(path) - arr = field.data - return (arr, field) + if ext in _SPM_EXTENSIONS: + fields = self._load_spm_all(path, ext) + for f in fields: + f.colormap = colormap + return tuple(fields) - # npy / npz + # Image or array — uncalibrated, single output + field = self._load_image_or_array(path, ext) + field.colormap = colormap + self._send_warning("Uncalibrated data — no physical dimensions.") + return (field,) + + def _send_warning(self, message: str): + fn = LoadFile._broadcast_warning_fn + nid = LoadFile._current_node_id + if fn and nid: + fn(nid, message) + + # -- SPM: load all channels --------------------------------------------- + + def _load_spm_all(self, path: Path, ext: str) -> list[DataField]: + if ext == ".gwy": + return self._load_gwy_all(path) + elif ext == ".sxm": + return self._load_sxm_all(path) + elif ext == ".ibw": + return self._load_ibw_all(path) + else: + raise ValueError(f"Unsupported SPM format: {ext}") + + # -- GWY ---------------------------------------------------------------- + + def _load_gwy_all(self, path: Path) -> list[DataField]: + try: + import gwyfile + except ImportError: + raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") + + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if not channels: + raise ValueError(f"No data channels found in {path.name}") + + fields = [] + for ch in channels.values(): + data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) + fields.append(DataField( + data=data, + xreal=float(ch.xreal), + yreal=float(ch.yreal), + xoff=float(getattr(ch, "xoff", 0.0)), + yoff=float(getattr(ch, "yoff", 0.0)), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + # -- SXM ---------------------------------------------------------------- + + def _load_sxm_all(self, path: Path) -> list[DataField]: + try: + import nanonispy as nap + except ImportError: + raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") + + sxm = nap.read.Scan(str(path)) + signals = sxm.signals + if not signals: + raise ValueError(f"No signals found in {path.name}") + + header = sxm.header + scan_range = header.get("scan_range", [1e-6, 1e-6]) + + fields = [] + for sig in signals.values(): + data = sig.get("forward", list(sig.values())[0]) + data = np.asarray(data, dtype=np.float64) + if data.ndim != 2: + data = data.reshape(data.shape[-2], data.shape[-1]) + fields.append(DataField( + data=data, + xreal=float(scan_range[0]), + yreal=float(scan_range[1]), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + # -- IBW ---------------------------------------------------------------- + + def _load_ibw_all(self, path: Path) -> list[DataField]: + try: + from igor.binarywave import load as load_ibw + except ImportError: + raise ImportError("Install 'igor' package to load .ibw files: pip install igor") + + wave = load_ibw(str(path)) + wdata = wave["wave"] + header = wdata["wave_header"] + raw = wdata["wData"] + + n_channels = raw.shape[2] if raw.ndim >= 3 else 1 + + # Physical scaling + sfA = header.get("sfA", None) + + def _decode_unit(raw_unit): + if raw_unit is None: + return "m" + if isinstance(raw_unit, bytes): + return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + if isinstance(raw_unit, np.ndarray): + return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + return str(raw_unit).strip() or "m" + + dim_units_raw = header.get("dimUnits", None) + data_units_raw = header.get("dataUnits", None) + + if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2: + si_unit_xy = _decode_unit(dim_units_raw[0]) + elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0: + si_unit_xy = _decode_unit(dim_units_raw[0]) + else: + si_unit_xy = _decode_unit(dim_units_raw) + + si_unit_z = _decode_unit(data_units_raw) + + fields = [] + for ch_idx in range(n_channels): + if raw.ndim >= 3: + ch_data = raw[:, :, ch_idx] + elif raw.ndim == 1: + ch_data = raw.reshape(-1, 1) + else: + ch_data = raw + + # Transpose from (xres, yres) Igor order to (yres, xres) DataField order, + # then flip vertically to match gwyddion + data = np.flipud(ch_data.T).astype(np.float64) + yres, xres = data.shape + + if sfA is not None and len(sfA) >= 2: + xreal = abs(float(sfA[0]) * xres) or 1e-6 + yreal = abs(float(sfA[1]) * yres) or 1e-6 + else: + hsA = header.get("hsA", 0.0) + xreal = abs(float(hsA) * xres) or 1e-6 + yreal = xreal * (yres / xres) if xres else 1e-6 + + fields.append(DataField( + data=data, xreal=xreal, yreal=yreal, + si_unit_xy=si_unit_xy, si_unit_z=si_unit_z, + )) + + return fields + + # -- Image / array (uncalibrated) -------------------------------------- + + def _load_image_or_array(self, path: Path, ext: str) -> DataField: if ext == ".npy": arr = np.load(str(path)).astype(np.float64) elif ext == ".npz": @@ -143,22 +324,32 @@ class LoadDemo: else: gray = arr.astype(np.float64) - field = DataField(data=gray) - return (arr, field) + return DataField(data=gray) # --------------------------------------------------------------------------- -# LoadSPM +# LoadDemo # --------------------------------------------------------------------------- -@register_node(display_name="Load SPM File") -class LoadSPM: +def _list_demo_files() -> list[str]: + """Return sorted list of demo filenames available in the demo/ directory.""" + if not DEMO_DIR.exists(): + return [] + return sorted( + f.name for f in DEMO_DIR.iterdir() + if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + ) + + +@register_node(display_name="Load Demo File") +class LoadDemo: @classmethod def INPUT_TYPES(cls): + choices = _list_demo_files() or ["(no demo files found)"] return { "required": { - "filename": ("FILE_PICKER", {"default": ""}), - "channel": ("STRING", {"default": "Z"}), + "name": (choices,), + "colormap": (list(COLORMAPS),), } } @@ -166,111 +357,15 @@ class LoadSPM: RETURN_NAMES = ("field",) FUNCTION = "load" CATEGORY = "io" - DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField." + DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." - def load(self, filename: str, channel: str = "Z"): - path = Path(filename) - if not path.is_absolute(): - path = INPUT_DIR / filename + def load(self, name: str, colormap: str = "viridis"): + path = DEMO_DIR / name if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") + raise FileNotFoundError(f"Demo file not found: {name}") - ext = path.suffix.lower() - - if ext == ".gwy": - return (self._load_gwy(path, channel),) - elif ext == ".sxm": - return (self._load_sxm(path, channel),) - elif ext in (".ibw",): - return (self._load_ibw(path),) - elif ext in (".npy",): - data = np.load(str(path)).astype(np.float64) - return (DataField(data=data),) - elif ext in (".npz",): - npz = np.load(str(path)) - key = list(npz.files)[0] - return (DataField(data=npz[key].astype(np.float64)),) - else: - raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz") - - def _load_gwy(self, path: Path, channel: str) -> DataField: - try: - import gwyfile - except ImportError: - raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") - - obj = gwyfile.load(str(path)) - channels = gwyfile.util.get_datafields(obj) - if not channels: - raise ValueError(f"No data channels found in {path.name}") - - # Try requested channel name, fall back to first available - ch = None - for key, df in channels.items(): - if channel.lower() in key.lower(): - ch = df - break - if ch is None: - ch = next(iter(channels.values())) - - data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) - return DataField( - data=data, - xreal=float(ch.xreal), - yreal=float(ch.yreal), - xoff=float(getattr(ch, "xoff", 0.0)), - yoff=float(getattr(ch, "yoff", 0.0)), - si_unit_xy="m", - si_unit_z="m", - ) - - def _load_sxm(self, path: Path, channel: str) -> DataField: - try: - import nanonispy as nap - except ImportError: - raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") - - sxm = nap.read.Scan(str(path)) - signals = sxm.signals - - # Pick channel - ch_key = None - for key in signals: - if channel.upper() in key.upper(): - ch_key = key - break - if ch_key is None: - ch_key = next(iter(signals)) - - data = signals[ch_key].get("forward", list(signals[ch_key].values())[0]) - data = np.asarray(data, dtype=np.float64) - if data.ndim != 2: - data = data.reshape(data.shape[-2], data.shape[-1]) - - header = sxm.header - scan_range = header.get("scan_range", [1e-6, 1e-6]) - return DataField( - data=data, - xreal=float(scan_range[0]), - yreal=float(scan_range[1]), - si_unit_xy="m", - si_unit_z="m", - ) - - def _load_ibw(self, path: Path) -> DataField: - try: - import igor.igorpy as igorpy - wave = igorpy.load(str(path)) - data = wave.wave["wData"].squeeze().astype(np.float64) - except ImportError: - raise ImportError("Install 'igor' package to load .ibw files: pip install igor") - - if data.ndim == 1: - data = data.reshape(1, -1) - elif data.ndim != 2: - data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1) - - return DataField(data=data, si_unit_xy="m", si_unit_z="m") + loader = LoadFile() + return loader.load(filename=str(path), colormap=colormap) # --------------------------------------------------------------------------- diff --git a/backend/server.py b/backend/server.py index 87f9acd..7fb50d4 100644 --- a/backend/server.py +++ b/backend/server.py @@ -101,6 +101,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: def on_overlay(node_id: str, overlay_data) -> None: broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) + def on_warning(node_id: str, message: str) -> None: + broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) + # ------------------------------------------------------------------ # Route handlers # ------------------------------------------------------------------ @@ -193,6 +196,18 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: }, ) + async def get_channels(request: web.Request) -> web.Response: + """Return available channels for a given file path.""" + from backend.nodes.io import list_channels + filepath = request.query.get("file", "") + if not filepath: + return web.Response( + text=_dumps([{"name": "field", "type": "DATA_FIELD"}]), + content_type="application/json", + ) + channels = await loop.run_in_executor(None, list_channels, filepath) + return web.Response(text=_dumps(channels), content_type="application/json") + async def submit_prompt(request: web.Request) -> web.Response: body = await request.json() prompt = body.get("prompt") @@ -218,6 +233,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: on_table=on_table, on_mesh=on_mesh, on_overlay=on_overlay, + on_warning=on_warning, ), ) broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}}) @@ -262,6 +278,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: app.router.add_get("/browse", browse_dir) app.router.add_post("/upload", upload_file) app.router.add_post("/download", download_file) + app.router.add_get("/channels", get_channels) app.router.add_post("/prompt", submit_prompt) app.router.add_get("/ws", websocket_handler) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index d22d342..cbf6075 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -436,6 +436,9 @@ function Flow() { case 'overlay': updateNodeData(msg.data.node_id, { overlay: msg.data.overlay }); break; + case 'node_warning': + updateNodeData(msg.data.node_id, { warning: msg.data.message }); + break; } }); api.initWS(); @@ -500,9 +503,36 @@ function Flow() { data: { ...n.data, widgetValues: { ...n.data.widgetValues, [name]: value }, + // Clear warning when user changes a value + warning: null, }, }; })); + + // If this is a filename/name change on a LoadFile/LoadDemo node, fetch channels + if ((name === 'filename' || name === 'name') && value) { + const node = reactFlow.getNode(nodeId); + if (node && (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo')) { + api.getChannels(value).then((channels) => { + setNodes((prev) => prev.map((n) => { + if (n.id !== nodeId) return n; + return { + ...n, + data: { + ...n.data, + definition: { + ...n.data.definition, + output: channels.map((c) => c.type), + output_name: channels.map((c) => c.name), + }, + }, + }; + })); + reactFlow.updateNodeInternals(nodeId); + }); + } + } + scheduleAutoRun(); }, [setNodes]); // scheduleAutoRun is stable (no deps) @@ -568,6 +598,27 @@ function Flow() { setNodes((ns) => [...ns, newNode]); + // For LoadFile/LoadDemo, auto-fetch channels for the default value + if (className === 'LoadDemo' && widgetValues.name) { + api.getChannels(widgetValues.name).then((channels) => { + setNodes((prev) => prev.map((n) => { + if (n.id !== newNodeId) return n; + return { + ...n, + data: { + ...n.data, + definition: { + ...n.data.definition, + output: channels.map((c) => c.type), + output_name: channels.map((c) => c.name), + }, + }, + }; + })); + reactFlow.updateNodeInternals(newNodeId); + }); + } + // Auto-connect if this was triggered by dropping a connection on blank space if (contextMenu.pendingHandleId) { const filterType = contextMenu.filterType; diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index b26eef4..21ab15b 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -211,6 +211,11 @@ function CustomNode({ id, data }) { ); })} + {/* Warning notification */} + {data.warning && ( +
{data.warning}
+ )} + {/* Widget rows */} {widgets.map((w) => (
diff --git a/frontend/src/api.js b/frontend/src/api.js index b90acae..39ede09 100644 --- a/frontend/src/api.js +++ b/frontend/src/api.js @@ -34,6 +34,12 @@ export async function uploadFile(file) { return r.json(); } +export async function getChannels(filepath) { + const r = await fetch(`/channels?file=${encodeURIComponent(filepath)}`); + if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }]; + return r.json(); +} + export async function runPrompt(prompt) { const r = await fetch('/prompt', { method: 'POST', diff --git a/frontend/src/styles.css b/frontend/src/styles.css index c79483e..41ee658 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -82,10 +82,7 @@ html, body, #root { padding: 4px 10px; border-radius: 4px; font-size: 11px; - max-width: 400px; - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; + max-width: 60%; flex-shrink: 1; } .status-bar.info { color: #90caf9; } @@ -155,6 +152,15 @@ html, body, #root { padding: 4px 0; } +.node-warning { + padding: 3px 10px; + font-size: 10px; + color: #fbbf24; + background: rgba(251, 191, 36, 0.1); + border-top: 1px solid rgba(251, 191, 36, 0.2); + border-bottom: 1px solid rgba(251, 191, 36, 0.2); +} + /* ── I/O rows ──────────────────────────────────────────────────────── */ .io-row { display: flex; diff --git a/frontend/vite.config.js b/frontend/vite.config.js index 63e6370..eac01c1 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -9,6 +9,7 @@ export default defineConfig({ '/nodes': 'http://127.0.0.1:8188', '/files': 'http://127.0.0.1:8188', '/browse': 'http://127.0.0.1:8188', + '/channels': 'http://127.0.0.1:8188', '/upload': 'http://127.0.0.1:8188', '/download': 'http://127.0.0.1:8188', '/prompt': 'http://127.0.0.1:8188', diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 541ca60..f81efa0 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -523,41 +523,42 @@ def test_particle_analysis(): # I/O # ========================================================================= -def test_load_image(): - print("=== Test: LoadImage ===") - from backend.nodes.io import LoadImage +def test_load_file(): + print("=== Test: LoadFile ===") + from backend.nodes.io import LoadFile from PIL import Image - node = LoadImage() + node = LoadFile() with tempfile.TemporaryDirectory() as tmpdir: - # Test loading a grayscale PNG + # Test loading a grayscale PNG → single DataField output arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8) img = Image.fromarray(arr, mode="L") path = os.path.join(tmpdir, "test_gray.png") img.save(path) - image, field = node.load(filename=path) - assert image.shape == (48, 64) + result = node.load(filename=path) + assert len(result) == 1 + field = result[0] assert field.data.shape == (48, 64) assert field.data.dtype == np.float64 - # Test loading an RGB PNG (should average to grayscale for field) + # Test loading an RGB PNG (should average to grayscale) arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8) img_rgb = Image.fromarray(arr_rgb, mode="RGB") path_rgb = os.path.join(tmpdir, "test_rgb.png") img_rgb.save(path_rgb) - image_rgb, field_rgb = node.load(filename=path_rgb) - assert image_rgb.shape == (32, 32, 3) - assert field_rgb.data.shape == (32, 32) + result_rgb = node.load(filename=path_rgb) + assert len(result_rgb) == 1 + assert result_rgb[0].data.shape == (32, 32) # Test loading a .npy file data_npy = np.random.default_rng(3).standard_normal((50, 60)) path_npy = os.path.join(tmpdir, "test.npy") np.save(path_npy, data_npy) - image_npy, field_npy = node.load(filename=path_npy) - assert np.allclose(field_npy.data, data_npy) + result_npy = node.load(filename=path_npy) + assert np.allclose(result_npy[0].data, data_npy) print(" PASS\n") @@ -641,6 +642,464 @@ def test_print_table(): print(" PASS\n") +# ========================================================================= +# I/O — IBW multi-channel loading +# ========================================================================= + +def test_load_file_ibw(): + print("=== Test: LoadFile IBW multi-channel ===") + from backend.nodes.io import LoadFile + + node = LoadFile() + ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw") + ibw_path = os.path.abspath(ibw_path) + if not os.path.exists(ibw_path): + print(" SKIP (demo IBW file not found)\n") + return + + result = node.load(filename=ibw_path) + + # BR_New20012.ibw has 4 channels + assert len(result) == 4, f"Expected 4 channels, got {len(result)}" + + for i, field in enumerate(result): + assert isinstance(field, DataField), f"Channel {i} is not a DataField" + assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}" + assert field.data.dtype == np.float64 + # Physical dimensions should be populated (not default 1e-6) + assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}" + assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}" + assert field.si_unit_xy == "m" + assert field.si_unit_z == "m" + + # All channels should share the same physical dimensions + assert result[0].xreal == result[1].xreal + assert result[0].yreal == result[1].yreal + + # Different channels should have different data + assert not np.array_equal(result[0].data, result[1].data) + + print(" PASS\n") + + +def test_load_file_npz(): + print("=== Test: LoadFile .npz ===") + from backend.nodes.io import LoadFile + + node = LoadFile() + with tempfile.TemporaryDirectory() as tmpdir: + data = np.random.default_rng(99).standard_normal((30, 40)) + path = os.path.join(tmpdir, "test.npz") + np.savez(path, my_array=data) + + result = node.load(filename=path) + assert len(result) == 1 + assert np.allclose(result[0].data, data) + + print(" PASS\n") + + +def test_load_file_not_found(): + print("=== Test: LoadFile not found ===") + from backend.nodes.io import LoadFile + + node = LoadFile() + try: + node.load(filename="/nonexistent/path/file.png") + assert False, "Should have raised FileNotFoundError" + except FileNotFoundError: + pass + + print(" PASS\n") + + +def test_load_file_unsupported(): + print("=== Test: LoadFile unsupported format ===") + from backend.nodes.io import LoadFile + + node = LoadFile() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test.xyz") + with open(path, "w") as f: + f.write("hello") + try: + node.load(filename=path) + assert False, "Should have raised an error for .xyz" + except Exception: + pass + + print(" PASS\n") + + +def test_load_file_warning(): + print("=== Test: LoadFile warning for uncalibrated data ===") + from backend.nodes.io import LoadFile + from PIL import Image + + node = LoadFile() + warnings = [] + LoadFile._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) + LoadFile._current_node_id = "test" + + with tempfile.TemporaryDirectory() as tmpdir: + arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8) + img = Image.fromarray(arr) + path = os.path.join(tmpdir, "test.png") + img.save(path) + + result = node.load(filename=path) + assert len(result) == 1 + assert len(warnings) == 1 + assert "Uncalibrated" in warnings[0] + + LoadFile._broadcast_warning_fn = None + print(" PASS\n") + + +# ========================================================================= +# I/O — list_channels helper +# ========================================================================= + +def test_list_channels(): + print("=== Test: list_channels ===") + from backend.nodes.io import list_channels + + # Non-existent file → default + ch = list_channels("/nonexistent/file.ibw") + assert len(ch) == 1 + assert ch[0]["name"] == "field" + + # IBW with channels + ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")) + if os.path.exists(ibw_path): + ch = list_channels(ibw_path) + assert len(ch) == 4 + names = [c["name"] for c in ch] + assert "HeightRetrace" in names + assert "AmplitudeRetrace" in names + assert all(c["type"] == "DATA_FIELD" for c in ch) + + # Plain image → single default channel + with tempfile.TemporaryDirectory() as tmpdir: + from PIL import Image + img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) + path = os.path.join(tmpdir, "test.png") + img.save(path) + + ch = list_channels(path) + assert len(ch) == 1 + assert ch[0]["name"] == "field" + + # .npy → single default channel + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test.npy") + np.save(path, np.zeros((4, 4))) + + ch = list_channels(path) + assert len(ch) == 1 + + print(" PASS\n") + + +# ========================================================================= +# I/O — LoadDemo +# ========================================================================= + +def test_load_demo(): + print("=== Test: LoadDemo ===") + from backend.nodes.io import LoadDemo + + node = LoadDemo() + + # Should be able to load a demo file by name + result = node.load(name="nanoparticles.npy") + assert len(result) >= 1 + assert isinstance(result[0], DataField) + assert result[0].data.ndim == 2 + + # IBW demo should return multiple channels + result_ibw = node.load(name="whiskers.ibw") + assert len(result_ibw) == 4 + for field in result_ibw: + assert isinstance(field, DataField) + + # Non-existent demo should raise + try: + node.load(name="nonexistent_file.png") + assert False, "Should have raised FileNotFoundError" + except FileNotFoundError: + pass + + print(" PASS\n") + + +# ========================================================================= +# I/O — Coordinate +# ========================================================================= + +def test_coordinate(): + print("=== Test: Coordinate ===") + from backend.nodes.io import Coordinate + + node = Coordinate() + + result = node.process(x=0.3, y=0.7) + assert len(result) == 1 + assert result[0] == (0.3, 0.7) + + # Edge values + result_zero = node.process(x=0.0, y=0.0) + assert result_zero[0] == (0.0, 0.0) + + result_one = node.process(x=1.0, y=1.0) + assert result_one[0] == (1.0, 1.0) + + print(" PASS\n") + + +# ========================================================================= +# Analysis — LineCursors +# ========================================================================= + +def test_line_cursors(): + print("=== Test: LineCursors ===") + from backend.nodes.analysis import LineCursors + + node = LineCursors() + + # Create a simple linear ramp + line = np.linspace(0, 10, 100).astype(np.float64) + + # Capture overlay + overlays = [] + LineCursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + LineCursors._current_node_id = "test" + + table, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5) + + # Should produce a 6-row table + assert len(table) == 6 + quantities = {row["quantity"] for row in table} + assert "A position" in quantities + assert "B position" in quantities + assert "delta X" in quantities + assert "delta Y" in quantities + + # B should be at a later position than A + a_pos = next(r["value"] for r in table if r["quantity"] == "A position") + b_pos = next(r["value"] for r in table if r["quantity"] == "B position") + assert b_pos > a_pos + + # Delta Y should reflect the height difference along the ramp + dy = next(r["value"] for r in table if r["quantity"] == "delta Y") + assert dy > 0 # ramp goes upward + + # Overlay should have been broadcast + assert len(overlays) == 1 + assert "image" in overlays[0] + assert overlays[0]["image"].startswith("data:image/png;base64,") + + # With x_axis provided + x_axis = np.linspace(0, 1, 100).astype(np.float64) + table2, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5, x_axis=x_axis) + assert len(table2) == 6 + + LineCursors._broadcast_overlay_fn = None + print(" PASS\n") + + +# ========================================================================= +# Analysis — FFT2D +# ========================================================================= + +def test_fft2d(): + print("=== Test: FFT2D ===") + from backend.nodes.analysis import FFT2D + + node = FFT2D() + + # Pure single-frequency signal: peak should appear at the right location + N = 64 + y, x = np.mgrid[0:N, 0:N] / N + freq = 5 + data = np.sin(2 * np.pi * freq * x) + field = make_field(data=data, xreal=1e-6, yreal=1e-6) + + # log_magnitude + spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude") + assert spectrum.data.shape == (N, N) + assert spectrum.domain == "frequency" + assert spectrum.si_unit_xy == "1/m" + # Peak should be symmetric about centre + centre = N // 2 + row = spectrum.data[centre, :] + peak_idx = np.argmax(row[centre + 1:]) + centre + 1 + assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}" + + # magnitude output + spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude") + assert spec_mag.data.shape == (N, N) + assert np.all(spec_mag.data >= 0) + + # phase output + spec_phase, = node.process(field, windowing="none", level="none", output="phase") + assert spec_phase.data.shape == (N, N) + assert spec_phase.data.min() >= -np.pi - 0.01 + assert spec_phase.data.max() <= np.pi + 0.01 + + # psdf output — units should reflect PSDF calibration + spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf") + assert spec_psdf.data.shape == (N, N) + assert np.all(spec_psdf.data >= 0) + assert "^2" in spec_psdf.si_unit_z + + # Constant field should have all energy at DC + const_field = make_field(data=np.ones((32, 32)) * 3.0) + spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude") + centre32 = 16 + dc_val = spec_const.data[centre32, centre32] + assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field" + + # Blackman windowing should also work without error + spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude") + assert spec_bk.data.shape == (N, N) + + print(" PASS\n") + + +# ========================================================================= +# Analysis — LineMath +# ========================================================================= + +def test_line_math(): + print("=== Test: LineMath ===") + from backend.nodes.analysis import LineMath + + node = LineMath() + line = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + + # Basic stats + table, = node.process(line, operation="min") + assert table[0]["value"] == 1.0 + + table, = node.process(line, operation="max") + assert table[0]["value"] == 5.0 + + table, = node.process(line, operation="mean") + assert table[0]["value"] == 3.0 + + table, = node.process(line, operation="median") + assert table[0]["value"] == 3.0 + + table, = node.process(line, operation="sum") + assert table[0]["value"] == 15.0 + + table, = node.process(line, operation="range") + assert table[0]["value"] == 4.0 + + table, = node.process(line, operation="length") + assert table[0]["value"] == 5.0 + + # RMS of [1,2,3,4,5] + table, = node.process(line, operation="rms") + expected_rms = np.sqrt(np.mean(line ** 2)) + assert abs(table[0]["value"] - expected_rms) < 1e-10 + + # Roughness parameters + table, = node.process(line, operation="Ra") + d = line - line.mean() + expected_ra = float(np.mean(np.abs(d))) + assert abs(table[0]["value"] - expected_ra) < 1e-10 + + table, = node.process(line, operation="Rq") + expected_rq = float(np.sqrt(np.mean(d ** 2))) + assert abs(table[0]["value"] - expected_rq) < 1e-10 + + # Rp = max of (z - mean) + table, = node.process(line, operation="Rp") + assert abs(table[0]["value"] - d.max()) < 1e-10 + + # Rv = -(min of (z - mean)) + table, = node.process(line, operation="Rv") + assert abs(table[0]["value"] - (-d.min())) < 1e-10 + + # Rt = Rp + Rv = range of (z - mean) + table, = node.process(line, operation="Rt") + assert abs(table[0]["value"] - (d.max() - d.min())) < 1e-10 + + # Constant line: roughness parameters should all be zero + const_line = np.ones(10) * 7.0 + table, = node.process(const_line, operation="Ra") + assert table[0]["value"] == 0.0 + table, = node.process(const_line, operation="Rq") + assert table[0]["value"] == 0.0 + table, = node.process(const_line, operation="Rsk") + assert table[0]["value"] == 0.0 + table, = node.process(const_line, operation="Rku") + assert table[0]["value"] == 0.0 + + # Slope-based: Dq and Da + table, = node.process(line, operation="Dq") + dz = np.diff(line) + expected_dq = float(np.sqrt(np.mean(dz * dz))) + assert abs(table[0]["value"] - expected_dq) < 1e-10 + + table, = node.process(line, operation="Da") + expected_da = float(np.mean(np.abs(dz))) + assert abs(table[0]["value"] - expected_da) < 1e-10 + + print(" PASS\n") + + +# ========================================================================= +# Display — View3D +# ========================================================================= + +def test_view3d(): + print("=== Test: View3D ===") + from backend.nodes.display import View3D + + node = View3D() + field = make_field() + + captured = [] + View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh) + View3D._current_node_id = "test" + + result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64) + assert result == () + assert len(captured) == 1 + + mesh = captured[0] + assert "width" in mesh + assert "height" in mesh + assert "z_data" in mesh + assert "colors" in mesh + assert mesh["z_scale"] == 2.0 + assert mesh["width"] <= 64 + assert mesh["height"] <= 64 + # z_min < z_max for non-constant data + assert mesh["z_min"] < mesh["z_max"] + + # Verify base64 data can be decoded + import base64 + z_bytes = base64.b64decode(mesh["z_data"]) + assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32 + + colors_bytes = base64.b64decode(mesh["colors"]) + assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 # uint8 RGB + + # High-res input should be downsampled + big_field = make_field(shape=(256, 256)) + captured.clear() + node.render(big_field, colormap="hot", z_scale=1.0, resolution=64) + assert captured[0]["width"] <= 64 + assert captured[0]["height"] <= 64 + + View3D._broadcast_mesh_fn = None + print(" PASS\n") + + # ========================================================================= # Run all tests # ========================================================================= @@ -662,6 +1121,9 @@ if __name__ == "__main__": test_statistics() test_height_histogram() test_cross_section() + test_line_cursors() + test_fft2d() + test_line_math() # Mask test_threshold_mask() @@ -673,11 +1135,20 @@ if __name__ == "__main__": test_particle_analysis() # I/O - test_load_image() + test_load_file() + test_load_file_ibw() + test_load_file_npz() + test_load_file_not_found() + test_load_file_unsupported() + test_load_file_warning() + test_list_channels() + test_load_demo() + test_coordinate() test_save_image() # Display test_preview_image() test_print_table() + test_view3d() print("All tests passed!")