diff --git a/backend/data_types.py b/backend/data_types.py index f1bdbba..1fff322 100644 --- a/backend/data_types.py +++ b/backend/data_types.py @@ -16,6 +16,9 @@ from dataclasses import dataclass, field import numpy as np +COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain", + "cividis", "magma", "copper", "afmhot") + @dataclass class DataField: data: np.ndarray # shape (yres, xres), dtype float64 @@ -28,6 +31,7 @@ class DataField: si_unit_xy: str = "m" si_unit_z: str = "m" domain: str = "spatial" # "spatial" or "frequency" + colormap: str = "viridis" def __post_init__(self) -> None: self.data = np.asarray(self.data, dtype=np.float64) @@ -48,6 +52,7 @@ class DataField: si_unit_xy=self.si_unit_xy, si_unit_z=self.si_unit_z, domain=self.domain, + colormap=self.colormap, ) def replace(self, **kwargs) -> "DataField": @@ -63,6 +68,7 @@ class DataField: "si_unit_xy": self.si_unit_xy, "si_unit_z": self.si_unit_z, "domain": self.domain, + "colormap": self.colormap, } base.update(kwargs) return DataField(**base) diff --git a/backend/execution.py b/backend/execution.py index b9a52af..4172dfb 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -50,6 +50,7 @@ class ExecutionEngine: on_table: Callable[[str, list], None] | None = None, on_mesh: Callable[[str, dict], None] | None = None, on_overlay: Callable[[str, str], None] | None = None, + on_warning: Callable[[str, str], None] | None = None, ) -> dict[str, tuple]: """ Execute the workflow described by `prompt`. @@ -62,6 +63,7 @@ class ExecutionEngine: on_preview : called with (node_id, data_uri) when a display node runs on_table : called with (node_id, table_list) when PrintTable runs on_overlay : called with (node_id, data_uri) for interactive overlays + on_warning : called with (node_id, message) for node warnings Returns ------- @@ -71,7 +73,7 @@ class ExecutionEngine: node_outputs: dict[str, tuple] = {} # Inject display callbacks before execution - self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay) + self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning) for node_id in order: node_def = prompt[node_id] @@ -174,12 +176,13 @@ class ExecutionEngine: on_table: Callable | None, on_mesh: Callable | None = None, on_overlay: Callable | None = None, + on_warning: Callable | None = None, ) -> None: """Wire up broadcast callbacks on display node classes.""" from backend.nodes.display import PreviewImage, PrintTable, View3D from backend.nodes.analysis import CrossSection, LineCursors from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine - from backend.nodes.io import SaveImage + from backend.nodes.io import SaveImage, LoadFile PreviewImage._broadcast_fn = on_preview ThresholdMask._broadcast_fn = on_preview @@ -190,6 +193,7 @@ class ExecutionEngine: PrintTable._broadcast_table_fn = on_table CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay + LoadFile._broadcast_warning_fn = on_warning SaveImage._broadcast_preview = ( (lambda data_uri: on_preview("save", data_uri)) if on_preview else None ) @@ -199,8 +203,9 @@ class ExecutionEngine: from backend.nodes.display import PreviewImage, PrintTable, View3D from backend.nodes.analysis import CrossSection, LineCursors from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine + from backend.nodes.io import LoadFile if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, - ThresholdMask, MaskMorphology, MaskInvert, MaskCombine): + ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile): cls._current_node_id = node_id def _auto_preview( @@ -232,7 +237,7 @@ class ExecutionEngine: value = result[slot] if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview: - arr = datafield_to_uint8(value, "viridis") + arr = datafield_to_uint8(value, value.colormap) on_preview(node_id, encode_preview(arr)) return # one preview per node is enough diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 60c6b30..fc9e6a1 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -326,6 +326,7 @@ class FFT2D: si_unit_xy="1/m", si_unit_z=z_unit, domain="frequency", + colormap=field.colormap, ) return (out_field,) diff --git a/backend/nodes/display.py b/backend/nodes/display.py index c12c7ab..c6a1a0f 100644 --- a/backend/nodes/display.py +++ b/backend/nodes/display.py @@ -9,7 +9,7 @@ before execution begins. from __future__ import annotations import numpy as np from backend.node_registry import register_node -from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview +from backend.data_types import DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview @register_node(display_name="Preview") @@ -18,7 +18,7 @@ class PreviewImage: def INPUT_TYPES(cls): return { "required": { - "colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],), + "colormap": (["auto"] + list(COLORMAPS),), }, "optional": { "image": ("IMAGE",), @@ -36,6 +36,10 @@ class PreviewImage: _current_node_id: str = "" def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple: + # Resolve "auto" — use field's colormap if available, else fall back to gray + if colormap == "auto": + colormap = field.colormap if field is not None else "gray" + # Prefer field if both are connected; accept whichever is provided if field is not None: arr_u8 = datafield_to_uint8(field, colormap) @@ -73,7 +77,7 @@ class View3D: return { "required": { "field": ("DATA_FIELD",), - "colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],), + "colormap": (["auto"] + list(COLORMAPS),), "z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}), "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), } @@ -114,7 +118,8 @@ class View3D: else: z_norm = np.zeros_like(z) - cmap = cm.get_cmap(colormap) + cmap_name = field.colormap if colormap == "auto" else colormap + cmap = cm.get_cmap(cmap_name) rgba = cmap(z_norm) # (ny, nx, 4) float [0,1] colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8) diff --git a/backend/nodes/io.py b/backend/nodes/io.py index 7a44f46..3d432c6 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -8,7 +8,7 @@ import numpy as np from pathlib import Path from backend.node_registry import register_node -from backend.data_types import DataField, encode_preview, image_to_uint8 +from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8 from backend.runtime_paths import demo_dir, input_dir, output_dir # Resolved at server startup so nodes know where to look @@ -19,112 +19,293 @@ OUTPUT_DIR = output_dir() _DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", ".gwy", ".sxm", ".ibw"} +_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"} +_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"} +_ARRAY_EXTENSIONS = {".npy", ".npz"} + # --------------------------------------------------------------------------- -# LoadImage +# Channel listing helper (used by the /channels endpoint) # --------------------------------------------------------------------------- -@register_node(display_name="Load Image") -class LoadImage: +def _resolve_path(filepath: str) -> Path: + path = Path(filepath) + if path.is_absolute(): + return path + # Try input dir first, then demo dir + candidate = INPUT_DIR / filepath + if candidate.exists(): + return candidate + candidate = DEMO_DIR / filepath + if candidate.exists(): + return candidate + # Fall back to input dir (will trigger FileNotFoundError later) + return INPUT_DIR / filepath + + +def list_channels(filepath: str) -> list[dict]: + """Return available channel info for a file. + + Returns a list of {"name": str, "type": "DATA_FIELD"} dicts. + For SPM formats this inspects the file header. + For images / arrays, returns a single unnamed channel. + """ + path = _resolve_path(filepath) + if not path.exists(): + return [{"name": "field", "type": "DATA_FIELD"}] + + ext = path.suffix.lower() + + if ext == ".gwy": + try: + import gwyfile + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if channels: + return [{"name": k, "type": "DATA_FIELD"} for k in channels] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + if ext == ".sxm": + try: + import nanonispy as nap + sxm = nap.read.Scan(str(path)) + if sxm.signals: + return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + if ext == ".ibw": + try: + from igor.binarywave import load as load_ibw + wave = load_ibw(str(path)) + raw = wave["wave"]["wData"] + labels = wave["wave"].get("labels", None) + if raw.ndim >= 3 and labels: + dim_idx = min(2, len(labels) - 1) + if dim_idx >= 0 and labels[dim_idx]: + decoded = [] + for lbl in labels[dim_idx]: + if lbl: + name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip() + if name: + decoded.append(name) + if decoded: + return [{"name": n, "type": "DATA_FIELD"} for n in decoded] + # Multi-channel without labels — use numeric names + if raw.ndim >= 3 and raw.shape[2] > 1: + return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + # Image or array — single channel + return [{"name": "field", "type": "DATA_FIELD"}] + + +# --------------------------------------------------------------------------- +# LoadFile (unified loader — replaces LoadImage + LoadSPM) +# --------------------------------------------------------------------------- + +@register_node(display_name="Load File") +class LoadFile: @classmethod def INPUT_TYPES(cls): return { "required": { "filename": ("FILE_PICKER", {"default": ""}), + "colormap": (list(COLORMAPS),), } } - RETURN_TYPES = ("IMAGE", "DATA_FIELD") - RETURN_NAMES = ("image", "field") + # Default outputs — overridden dynamically by the frontend for multi-channel files + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("field",) FUNCTION = "load" CATEGORY = "io" - DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD." - - def load(self, filename: str): - # Accept absolute paths or filenames relative to input/ - path = Path(filename) - if not path.is_absolute(): - path = INPUT_DIR / filename - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") - - ext = path.suffix.lower() - if ext in (".npy",): - arr = np.load(str(path)).astype(np.float64) - elif ext in (".npz",): - npz = np.load(str(path)) - key = list(npz.files)[0] - arr = npz[key].astype(np.float64) - else: - from PIL import Image - img = Image.open(str(path)) - arr = np.array(img) - if arr.dtype != np.uint8: - arr = arr.astype(np.float64) - - # Convert to float64 grayscale for the DATA_FIELD output - if arr.ndim == 3: - gray = np.mean(arr.astype(np.float64), axis=2) - else: - gray = arr.astype(np.float64) - - field = DataField(data=gray) - return (arr, field) - - -# --------------------------------------------------------------------------- -# LoadDemo -# --------------------------------------------------------------------------- - -def _list_demo_files() -> list[str]: - """Return sorted list of demo filenames available in the demo/ directory.""" - if not DEMO_DIR.exists(): - return [] - return sorted( - f.name for f in DEMO_DIR.iterdir() - if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + DESCRIPTION = ( + "Load any supported file. " + "SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; " + "each channel gets its own output. " + "Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields." ) + # Set by execution engine for warning broadcast + _broadcast_warning_fn = None + _current_node_id = None -@register_node(display_name="Load Demo Image") -class LoadDemo: - @classmethod - def INPUT_TYPES(cls): - choices = _list_demo_files() or ["(no demo images found)"] - return { - "required": { - "name": (choices,), - } - } - - RETURN_TYPES = ("IMAGE", "DATA_FIELD") - RETURN_NAMES = ("image", "field") - FUNCTION = "load" - CATEGORY = "io" - DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data." - - def load(self, name: str): - path = DEMO_DIR / name + def load(self, filename: str, colormap: str = "viridis"): + if not filename or not filename.strip(): + raise ValueError("No file selected — use Browse to pick a file.") + path = _resolve_path(filename) if not path.exists(): - raise FileNotFoundError(f"Demo image not found: {name}") + raise FileNotFoundError(f"File not found: {path}") + if path.is_dir(): + raise IsADirectoryError(f"Expected a file, got a directory: {path}") ext = path.suffix.lower() - # SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD - if ext == ".gwy": - field = LoadSPM()._load_gwy(path, "Z") - arr = field.data - return (arr, field) - elif ext == ".sxm": - field = LoadSPM()._load_sxm(path, "Z") - arr = field.data - return (arr, field) - elif ext == ".ibw": - field = LoadSPM()._load_ibw(path) - arr = field.data - return (arr, field) + if ext in _SPM_EXTENSIONS: + fields = self._load_spm_all(path, ext) + for f in fields: + f.colormap = colormap + return tuple(fields) - # npy / npz + # Image or array — uncalibrated, single output + field = self._load_image_or_array(path, ext) + field.colormap = colormap + self._send_warning("Uncalibrated data — no physical dimensions.") + return (field,) + + def _send_warning(self, message: str): + fn = LoadFile._broadcast_warning_fn + nid = LoadFile._current_node_id + if fn and nid: + fn(nid, message) + + # -- SPM: load all channels --------------------------------------------- + + def _load_spm_all(self, path: Path, ext: str) -> list[DataField]: + if ext == ".gwy": + return self._load_gwy_all(path) + elif ext == ".sxm": + return self._load_sxm_all(path) + elif ext == ".ibw": + return self._load_ibw_all(path) + else: + raise ValueError(f"Unsupported SPM format: {ext}") + + # -- GWY ---------------------------------------------------------------- + + def _load_gwy_all(self, path: Path) -> list[DataField]: + try: + import gwyfile + except ImportError: + raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") + + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if not channels: + raise ValueError(f"No data channels found in {path.name}") + + fields = [] + for ch in channels.values(): + data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) + fields.append(DataField( + data=data, + xreal=float(ch.xreal), + yreal=float(ch.yreal), + xoff=float(getattr(ch, "xoff", 0.0)), + yoff=float(getattr(ch, "yoff", 0.0)), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + # -- SXM ---------------------------------------------------------------- + + def _load_sxm_all(self, path: Path) -> list[DataField]: + try: + import nanonispy as nap + except ImportError: + raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") + + sxm = nap.read.Scan(str(path)) + signals = sxm.signals + if not signals: + raise ValueError(f"No signals found in {path.name}") + + header = sxm.header + scan_range = header.get("scan_range", [1e-6, 1e-6]) + + fields = [] + for sig in signals.values(): + data = sig.get("forward", list(sig.values())[0]) + data = np.asarray(data, dtype=np.float64) + if data.ndim != 2: + data = data.reshape(data.shape[-2], data.shape[-1]) + fields.append(DataField( + data=data, + xreal=float(scan_range[0]), + yreal=float(scan_range[1]), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + # -- IBW ---------------------------------------------------------------- + + def _load_ibw_all(self, path: Path) -> list[DataField]: + try: + from igor.binarywave import load as load_ibw + except ImportError: + raise ImportError("Install 'igor' package to load .ibw files: pip install igor") + + wave = load_ibw(str(path)) + wdata = wave["wave"] + header = wdata["wave_header"] + raw = wdata["wData"] + + n_channels = raw.shape[2] if raw.ndim >= 3 else 1 + + # Physical scaling + sfA = header.get("sfA", None) + + def _decode_unit(raw_unit): + if raw_unit is None: + return "m" + if isinstance(raw_unit, bytes): + return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + if isinstance(raw_unit, np.ndarray): + return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + return str(raw_unit).strip() or "m" + + dim_units_raw = header.get("dimUnits", None) + data_units_raw = header.get("dataUnits", None) + + if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2: + si_unit_xy = _decode_unit(dim_units_raw[0]) + elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0: + si_unit_xy = _decode_unit(dim_units_raw[0]) + else: + si_unit_xy = _decode_unit(dim_units_raw) + + si_unit_z = _decode_unit(data_units_raw) + + fields = [] + for ch_idx in range(n_channels): + if raw.ndim >= 3: + ch_data = raw[:, :, ch_idx] + elif raw.ndim == 1: + ch_data = raw.reshape(-1, 1) + else: + ch_data = raw + + # Transpose from (xres, yres) Igor order to (yres, xres) DataField order, + # then flip vertically to match gwyddion + data = np.flipud(ch_data.T).astype(np.float64) + yres, xres = data.shape + + if sfA is not None and len(sfA) >= 2: + xreal = abs(float(sfA[0]) * xres) or 1e-6 + yreal = abs(float(sfA[1]) * yres) or 1e-6 + else: + hsA = header.get("hsA", 0.0) + xreal = abs(float(hsA) * xres) or 1e-6 + yreal = xreal * (yres / xres) if xres else 1e-6 + + fields.append(DataField( + data=data, xreal=xreal, yreal=yreal, + si_unit_xy=si_unit_xy, si_unit_z=si_unit_z, + )) + + return fields + + # -- Image / array (uncalibrated) -------------------------------------- + + def _load_image_or_array(self, path: Path, ext: str) -> DataField: if ext == ".npy": arr = np.load(str(path)).astype(np.float64) elif ext == ".npz": @@ -143,22 +324,32 @@ class LoadDemo: else: gray = arr.astype(np.float64) - field = DataField(data=gray) - return (arr, field) + return DataField(data=gray) # --------------------------------------------------------------------------- -# LoadSPM +# LoadDemo # --------------------------------------------------------------------------- -@register_node(display_name="Load SPM File") -class LoadSPM: +def _list_demo_files() -> list[str]: + """Return sorted list of demo filenames available in the demo/ directory.""" + if not DEMO_DIR.exists(): + return [] + return sorted( + f.name for f in DEMO_DIR.iterdir() + if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + ) + + +@register_node(display_name="Load Demo File") +class LoadDemo: @classmethod def INPUT_TYPES(cls): + choices = _list_demo_files() or ["(no demo files found)"] return { "required": { - "filename": ("FILE_PICKER", {"default": ""}), - "channel": ("STRING", {"default": "Z"}), + "name": (choices,), + "colormap": (list(COLORMAPS),), } } @@ -166,111 +357,15 @@ class LoadSPM: RETURN_NAMES = ("field",) FUNCTION = "load" CATEGORY = "io" - DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField." + DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." - def load(self, filename: str, channel: str = "Z"): - path = Path(filename) - if not path.is_absolute(): - path = INPUT_DIR / filename + def load(self, name: str, colormap: str = "viridis"): + path = DEMO_DIR / name if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") + raise FileNotFoundError(f"Demo file not found: {name}") - ext = path.suffix.lower() - - if ext == ".gwy": - return (self._load_gwy(path, channel),) - elif ext == ".sxm": - return (self._load_sxm(path, channel),) - elif ext in (".ibw",): - return (self._load_ibw(path),) - elif ext in (".npy",): - data = np.load(str(path)).astype(np.float64) - return (DataField(data=data),) - elif ext in (".npz",): - npz = np.load(str(path)) - key = list(npz.files)[0] - return (DataField(data=npz[key].astype(np.float64)),) - else: - raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz") - - def _load_gwy(self, path: Path, channel: str) -> DataField: - try: - import gwyfile - except ImportError: - raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") - - obj = gwyfile.load(str(path)) - channels = gwyfile.util.get_datafields(obj) - if not channels: - raise ValueError(f"No data channels found in {path.name}") - - # Try requested channel name, fall back to first available - ch = None - for key, df in channels.items(): - if channel.lower() in key.lower(): - ch = df - break - if ch is None: - ch = next(iter(channels.values())) - - data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) - return DataField( - data=data, - xreal=float(ch.xreal), - yreal=float(ch.yreal), - xoff=float(getattr(ch, "xoff", 0.0)), - yoff=float(getattr(ch, "yoff", 0.0)), - si_unit_xy="m", - si_unit_z="m", - ) - - def _load_sxm(self, path: Path, channel: str) -> DataField: - try: - import nanonispy as nap - except ImportError: - raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") - - sxm = nap.read.Scan(str(path)) - signals = sxm.signals - - # Pick channel - ch_key = None - for key in signals: - if channel.upper() in key.upper(): - ch_key = key - break - if ch_key is None: - ch_key = next(iter(signals)) - - data = signals[ch_key].get("forward", list(signals[ch_key].values())[0]) - data = np.asarray(data, dtype=np.float64) - if data.ndim != 2: - data = data.reshape(data.shape[-2], data.shape[-1]) - - header = sxm.header - scan_range = header.get("scan_range", [1e-6, 1e-6]) - return DataField( - data=data, - xreal=float(scan_range[0]), - yreal=float(scan_range[1]), - si_unit_xy="m", - si_unit_z="m", - ) - - def _load_ibw(self, path: Path) -> DataField: - try: - import igor.igorpy as igorpy - wave = igorpy.load(str(path)) - data = wave.wave["wData"].squeeze().astype(np.float64) - except ImportError: - raise ImportError("Install 'igor' package to load .ibw files: pip install igor") - - if data.ndim == 1: - data = data.reshape(1, -1) - elif data.ndim != 2: - data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1) - - return DataField(data=data, si_unit_xy="m", si_unit_z="m") + loader = LoadFile() + return loader.load(filename=str(path), colormap=colormap) # --------------------------------------------------------------------------- diff --git a/backend/server.py b/backend/server.py index 87f9acd..7fb50d4 100644 --- a/backend/server.py +++ b/backend/server.py @@ -101,6 +101,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: def on_overlay(node_id: str, overlay_data) -> None: broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) + def on_warning(node_id: str, message: str) -> None: + broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) + # ------------------------------------------------------------------ # Route handlers # ------------------------------------------------------------------ @@ -193,6 +196,18 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: }, ) + async def get_channels(request: web.Request) -> web.Response: + """Return available channels for a given file path.""" + from backend.nodes.io import list_channels + filepath = request.query.get("file", "") + if not filepath: + return web.Response( + text=_dumps([{"name": "field", "type": "DATA_FIELD"}]), + content_type="application/json", + ) + channels = await loop.run_in_executor(None, list_channels, filepath) + return web.Response(text=_dumps(channels), content_type="application/json") + async def submit_prompt(request: web.Request) -> web.Response: body = await request.json() prompt = body.get("prompt") @@ -218,6 +233,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: on_table=on_table, on_mesh=on_mesh, on_overlay=on_overlay, + on_warning=on_warning, ), ) broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}}) @@ -262,6 +278,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: app.router.add_get("/browse", browse_dir) app.router.add_post("/upload", upload_file) app.router.add_post("/download", download_file) + app.router.add_get("/channels", get_channels) app.router.add_post("/prompt", submit_prompt) app.router.add_get("/ws", websocket_handler) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index d22d342..cbf6075 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -436,6 +436,9 @@ function Flow() { case 'overlay': updateNodeData(msg.data.node_id, { overlay: msg.data.overlay }); break; + case 'node_warning': + updateNodeData(msg.data.node_id, { warning: msg.data.message }); + break; } }); api.initWS(); @@ -500,9 +503,36 @@ function Flow() { data: { ...n.data, widgetValues: { ...n.data.widgetValues, [name]: value }, + // Clear warning when user changes a value + warning: null, }, }; })); + + // If this is a filename/name change on a LoadFile/LoadDemo node, fetch channels + if ((name === 'filename' || name === 'name') && value) { + const node = reactFlow.getNode(nodeId); + if (node && (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo')) { + api.getChannels(value).then((channels) => { + setNodes((prev) => prev.map((n) => { + if (n.id !== nodeId) return n; + return { + ...n, + data: { + ...n.data, + definition: { + ...n.data.definition, + output: channels.map((c) => c.type), + output_name: channels.map((c) => c.name), + }, + }, + }; + })); + reactFlow.updateNodeInternals(nodeId); + }); + } + } + scheduleAutoRun(); }, [setNodes]); // scheduleAutoRun is stable (no deps) @@ -568,6 +598,27 @@ function Flow() { setNodes((ns) => [...ns, newNode]); + // For LoadFile/LoadDemo, auto-fetch channels for the default value + if (className === 'LoadDemo' && widgetValues.name) { + api.getChannels(widgetValues.name).then((channels) => { + setNodes((prev) => prev.map((n) => { + if (n.id !== newNodeId) return n; + return { + ...n, + data: { + ...n.data, + definition: { + ...n.data.definition, + output: channels.map((c) => c.type), + output_name: channels.map((c) => c.name), + }, + }, + }; + })); + reactFlow.updateNodeInternals(newNodeId); + }); + } + // Auto-connect if this was triggered by dropping a connection on blank space if (contextMenu.pendingHandleId) { const filterType = contextMenu.filterType; diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index b26eef4..21ab15b 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -211,6 +211,11 @@ function CustomNode({ id, data }) { ); })} + {/* Warning notification */} + {data.warning && ( +