diff --git a/.gitignore b/.gitignore index f4cd701..d0515b3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,6 @@ desktop-build/ desktop-dist/ frontend/node_modules/ frontend/dist/ -.venv/ \ No newline at end of file +.venv/ +sessions/ +.*/ \ No newline at end of file diff --git a/README.md b/README.md index 6a1d8bb..e72148d 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ http://127.0.0.1:5173 Notes: - The frontend dev server proxies API and WebSocket requests to the backend. +- `npm run dev` now clears Vite's local cache and stale Python bytecode first, then starts Vite with `--force`. - If you open the backend directly in a browser instead of the Vite dev server, argonode now refreshes `frontend/dist` automatically when checked-out frontend sources are newer, such as after a `git pull`. - If you want the frontend accessible from other devices on your LAN, run: @@ -95,14 +96,9 @@ npm run dev -- --host 0.0.0.0 ## Running the Local Desktop Version The desktop launcher starts the Python server internally and opens a native window with `pywebview`. +`npm run desktop` now rebuilds the frontend first so the native app always uses a fresh `frontend/dist`. -Build the frontend first: - -```powershell -npm run build -``` - -Then launch the desktop app from source: +Launch the desktop app from source: ```powershell npm run desktop @@ -110,8 +106,7 @@ npm run desktop Notes: -- `npm run desktop` uses the built frontend from `frontend/dist`. -- If you change frontend code, run `npm run build` again before starting the desktop version. +- `npm run build` clears stale frontend output, Vite cache, and Python bytecode before producing `frontend/dist`. ## Building the Windows `.exe` diff --git a/backend/execution.py b/backend/execution.py index 0d14ed0..c7f74c0 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -32,6 +32,7 @@ from time import perf_counter from typing import Any, Callable from backend.node_registry import NODE_CLASS_MAPPINGS +from backend.execution_context import active_node, execution_callbacks def _is_link(value: Any) -> bool: @@ -85,63 +86,66 @@ class ExecutionEngine: node_outputs: dict[str, tuple] = {} node_output_signatures: dict[str, tuple[str, ...]] = {} - # Inject display callbacks before execution - self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning) + with execution_callbacks( + preview=on_preview, + table=on_table, + mesh=on_mesh, + overlay=on_overlay, + value=on_value, + warning=on_warning, + ): + for node_id in order: + node_def = prompt[node_id] + class_name = node_def["class_type"] - for node_id in order: - node_def = prompt[node_id] - class_name = node_def["class_type"] + if class_name not in NODE_CLASS_MAPPINGS: + raise ValueError(f"Unknown node type: '{class_name}'") - if class_name not in NODE_CLASS_MAPPINGS: - raise ValueError(f"Unknown node type: '{class_name}'") + cls = NODE_CLASS_MAPPINGS[class_name] + raw_inputs = node_def.get("inputs", {}) + input_types = cls.INPUT_TYPES() + inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) + input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures) - cls = NODE_CLASS_MAPPINGS[class_name] - raw_inputs = node_def.get("inputs", {}) - input_types = cls.INPUT_TYPES() - inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) - input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures) + cache_entry = self._get_cached_entry(node_id, class_name, input_signature) + if cache_entry is not None: + result = self._clone_cached_outputs(cache_entry["outputs"]) + elapsed_ms = 0.0 + else: + if on_node_start: + on_node_start(node_id) - # Let display nodes know their node_id so they can tag WS messages - self._set_node_id_on_display(cls, node_id) + instance = cls() + func = getattr(instance, cls.FUNCTION) + start_time = perf_counter() + with active_node(node_id): + result = func(**inputs) + elapsed_ms = (perf_counter() - start_time) * 1000.0 - cache_entry = self._get_cached_entry(node_id, class_name, input_signature) - if cache_entry is not None: - result = self._clone_cached_outputs(cache_entry["outputs"]) - elapsed_ms = 0.0 - else: - if on_node_start: - on_node_start(node_id) + # Nodes must return a tuple; coerce single values just in case + if not isinstance(result, tuple): + result = (result,) - instance = cls() - func = getattr(instance, cls.FUNCTION) - start_time = perf_counter() - result = func(**inputs) - elapsed_ms = (perf_counter() - start_time) * 1000.0 + node_outputs[node_id] = result + output_signatures = tuple(self._fingerprint_value(value) for value in result) + node_output_signatures[node_id] = output_signatures - # Nodes must return a tuple; coerce single values just in case - if not isinstance(result, tuple): - result = (result,) + if cache_entry is None and self._node_cacheable(cls): + self._store_cache_entry( + node_id=node_id, + class_name=class_name, + input_signature=input_signature, + output_signatures=output_signatures, + outputs=self._clone_cached_outputs(result), + ) - node_outputs[node_id] = result - output_signatures = tuple(self._fingerprint_value(value) for value in result) - node_output_signatures[node_id] = output_signatures + # Auto-preview: broadcast a thumbnail for any DATA_FIELD, + # IMAGE, or table-like output so every node shows its result. + if on_preview or on_table: + self._auto_preview(cls, node_id, result, on_preview, on_table, inputs) - if cache_entry is None and self._node_cacheable(cls): - self._store_cache_entry( - node_id=node_id, - class_name=class_name, - input_signature=input_signature, - output_signatures=output_signatures, - outputs=self._clone_cached_outputs(result), - ) - - # Auto-preview: broadcast a thumbnail for any DATA_FIELD, - # IMAGE, or table-like output so every node shows its result. - if on_preview or on_table: - self._auto_preview(cls, node_id, result, on_preview, on_table, inputs) - - if on_node_done: - on_node_done(node_id, elapsed_ms) + if on_node_done: + on_node_done(node_id, elapsed_ms) return node_outputs @@ -421,88 +425,6 @@ class ExecutionEngine: return deepcopy(value) return value - def _inject_display_callbacks( - self, - on_preview: Callable | None, - on_table: Callable | None, - on_mesh: Callable | None = None, - on_overlay: Callable | None = None, - on_value: Callable | None = None, - on_warning: Callable | None = None, - ) -> None: - """Wire up broadcast callbacks on display node classes.""" - from backend.nodes.preview_image import PreviewImage - from backend.nodes.print_table import PrintTable - from backend.nodes.view_3d import View3D - from backend.nodes.annotations import Annotations - from backend.nodes.value_display import ValueDisplay - from backend.nodes.markup import Markup - from backend.nodes.cross_section import CrossSection - from backend.nodes.cursors import Cursors - from backend.nodes.stats import Stats - from backend.nodes.histogram import Histogram - from backend.nodes.crop_resize_field import CropResizeField - from backend.nodes.rotate_field import RotateField - from backend.nodes.threshold_mask import ThresholdMask - from backend.nodes.mask_morphology import MaskMorphology - from backend.nodes.mask_invert import MaskInvert - from backend.nodes.mask_combine import MaskCombine - from backend.nodes.draw_mask import DrawMask - from backend.nodes.save import Save - from backend.nodes.save_image import SaveImage - from backend.nodes.image import Image - from backend.nodes.image_demo import ImageDemo - - PreviewImage._broadcast_fn = on_preview - ThresholdMask._broadcast_fn = on_preview - MaskMorphology._broadcast_fn = on_preview - MaskInvert._broadcast_fn = on_preview - MaskCombine._broadcast_fn = on_preview - DrawMask._broadcast_overlay_fn = on_overlay - View3D._broadcast_mesh_fn = on_mesh - Annotations._broadcast_warning_fn = on_warning - PrintTable._broadcast_table_fn = on_table - ValueDisplay._broadcast_value_fn = on_value - Stats._broadcast_value_fn = on_value - Histogram._broadcast_overlay_fn = on_overlay - CrossSection._broadcast_overlay_fn = on_overlay - Cursors._broadcast_overlay_fn = on_overlay - CropResizeField._broadcast_overlay_fn = on_overlay - RotateField._broadcast_warning_fn = on_warning - Markup._broadcast_overlay_fn = on_overlay - Image._broadcast_warning_fn = on_warning - ImageDemo._broadcast_warning_fn = on_warning - Save._broadcast_warning_fn = on_warning - SaveImage._broadcast_warning_fn = on_warning - - def _set_node_id_on_display(self, cls: type, node_id: str) -> None: - """Inform display nodes of their current node_id for WS tagging.""" - from backend.nodes.preview_image import PreviewImage - from backend.nodes.print_table import PrintTable - from backend.nodes.view_3d import View3D - from backend.nodes.annotations import Annotations - from backend.nodes.value_display import ValueDisplay - from backend.nodes.markup import Markup - from backend.nodes.cross_section import CrossSection - from backend.nodes.cursors import Cursors - from backend.nodes.stats import Stats - from backend.nodes.histogram import Histogram - from backend.nodes.crop_resize_field import CropResizeField - from backend.nodes.rotate_field import RotateField - from backend.nodes.threshold_mask import ThresholdMask - from backend.nodes.mask_morphology import MaskMorphology - from backend.nodes.mask_invert import MaskInvert - from backend.nodes.mask_combine import MaskCombine - from backend.nodes.draw_mask import DrawMask - from backend.nodes.image import Image - from backend.nodes.image_demo import ImageDemo - from backend.nodes.save import Save - from backend.nodes.save_image import SaveImage - if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup, - ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, - Image, ImageDemo, Save, SaveImage): - cls._current_node_id = node_id - def _auto_preview( self, cls: type, diff --git a/backend/execution_context.py b/backend/execution_context.py new file mode 100644 index 0000000..ca19cfb --- /dev/null +++ b/backend/execution_context.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Callable + +Callback = Callable[[str, Any], None] + +_callbacks_var: ContextVar[dict[str, Callback | None]] = ContextVar( + "argonode_execution_callbacks", + default={}, +) +_node_id_var: ContextVar[str | None] = ContextVar("argonode_execution_node_id", default=None) + + +@contextmanager +def execution_callbacks( + *, + preview: Callback | None = None, + table: Callback | None = None, + mesh: Callback | None = None, + overlay: Callback | None = None, + value: Callback | None = None, + warning: Callback | None = None, +): + token = _callbacks_var.set({ + "preview": preview, + "table": table, + "mesh": mesh, + "overlay": overlay, + "value": value, + "warning": warning, + }) + try: + yield + finally: + _callbacks_var.reset(token) + + +@contextmanager +def active_node(node_id: str): + token = _node_id_var.set(str(node_id)) + try: + yield + finally: + _node_id_var.reset(token) + + +def current_node_id() -> str | None: + return _node_id_var.get() + + +def _emit(kind: str, payload: Any) -> None: + callbacks = _callbacks_var.get() + callback = callbacks.get(kind) + node_id = current_node_id() + if callback is not None and node_id: + callback(node_id, payload) + + +def emit_preview(payload: Any) -> None: + _emit("preview", payload) + + +def emit_table(rows: list) -> None: + _emit("table", rows) + + +def emit_mesh(mesh: dict) -> None: + _emit("mesh", mesh) + + +def emit_overlay(overlay: dict) -> None: + _emit("overlay", overlay) + + +def emit_value(payload: Any) -> None: + _emit("value", payload) + + +def emit_warning(message: str) -> None: + _emit("warning", message) diff --git a/backend/nodes/annotations.py b/backend/nodes/annotations.py index 8e74c3e..4deb199 100644 --- a/backend/nodes/annotations.py +++ b/backend/nodes/annotations.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import ( COLORMAPS, DataField, @@ -120,7 +121,4 @@ class Annotations: return (ImageData(annotated, metadata={"annotation_context": context}),) def _send_warning(self, message: str): - fn = Annotations._broadcast_warning_fn - nid = Annotations._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) diff --git a/backend/nodes/crop_resize_field.py b/backend/nodes/crop_resize_field.py index 3c69a21..d9d85b8 100644 --- a/backend/nodes/crop_resize_field.py +++ b/backend/nodes/crop_resize_field.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, datafield_to_uint8, encode_preview @@ -61,20 +62,16 @@ class CropResizeField: x2 = float(np.clip(x2, 0.0, 1.0)) y2 = float(np.clip(y2, 0.0, 1.0)) - if CropResizeField._broadcast_overlay_fn is not None: - CropResizeField._broadcast_overlay_fn( - CropResizeField._current_node_id, - { - "kind": "crop_box", - "image": encode_preview(datafield_to_uint8(field, field.colormap)), - "x1": x1, - "y1": y1, - "x2": x2, - "y2": y2, - "a_locked": corner_a is not None, - "b_locked": corner_b is not None, - }, - ) + emit_overlay({ + "kind": "crop_box", + "image": encode_preview(datafield_to_uint8(field, field.colormap)), + "x1": x1, + "y1": y1, + "x2": x2, + "y2": y2, + "a_locked": corner_a is not None, + "b_locked": corner_b is not None, + }) left = min(x1, x2) right = max(x1, x2) diff --git a/backend/nodes/cross_section.py b/backend/nodes/cross_section.py index e63a74c..b3f28cd 100644 --- a/backend/nodes/cross_section.py +++ b/backend/nodes/cross_section.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview from backend.nodes.helpers import _extend_to_edges @@ -73,19 +74,14 @@ class CrossSection: profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest") - if CrossSection._broadcast_overlay_fn is not None: - image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) - - CrossSection._broadcast_overlay_fn( - CrossSection._current_node_id, - { - "image": image_uri, - "x1": marker_x1, "y1": marker_y1, - "x2": marker_x2, "y2": marker_y2, - "a_locked": marker_pair is not None, - "b_locked": marker_pair is not None, - }, - ) + image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) + emit_overlay({ + "image": image_uri, + "x1": marker_x1, "y1": marker_y1, + "x2": marker_x2, "y2": marker_y2, + "a_locked": marker_pair is not None, + "b_locked": marker_pair is not None, + }) dx_real = (x2 - x1) * field.xreal dy_real = (y2 - y1) * field.yreal diff --git a/backend/nodes/cursors.py b/backend/nodes/cursors.py index c4655a1..5488d70 100644 --- a/backend/nodes/cursors.py +++ b/backend/nodes/cursors.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview @@ -87,22 +88,18 @@ class Cursors: xa, ya = float(x[idx_a]), float(y[idx_a]) xb, yb = float(x[idx_b]), float(y[idx_b]) - if Cursors._broadcast_overlay_fn is not None: - Cursors._broadcast_overlay_fn( - Cursors._current_node_id, - { - "kind": "line_plot", - "section_title": "Cursors", - "line": y.tolist(), - "x_axis": x.tolist(), - "x1": x1, - "x2": x2, - "y1": float(y1), - "y2": float(y2), - "a_locked": locked, - "b_locked": locked, - }, - ) + emit_overlay({ + "kind": "line_plot", + "section_title": "Cursors", + "line": y.tolist(), + "x_axis": x.tolist(), + "x1": x1, + "x2": x2, + "y1": float(y1), + "y2": float(y2), + "a_locked": locked, + "b_locked": locked, + }) table = MeasureTable([ {"quantity": "A x", "value": xa, "unit": x_unit}, @@ -143,21 +140,17 @@ class Cursors: bx = float(field.xoff + x2 * field.xreal) by = float(field.yoff + y2 * field.yreal) - if Cursors._broadcast_overlay_fn is not None: - Cursors._broadcast_overlay_fn( - Cursors._current_node_id, - { - "kind": "cursor_points", - "section_title": "Cursors", - "image": encode_preview(render_datafield_preview(field, field.colormap)), - "x1": x1, - "y1": y1, - "x2": x2, - "y2": y2, - "a_locked": locked, - "b_locked": locked, - }, - ) + emit_overlay({ + "kind": "cursor_points", + "section_title": "Cursors", + "image": encode_preview(render_datafield_preview(field, field.colormap)), + "x1": x1, + "y1": y1, + "x2": x2, + "y2": y2, + "a_locked": locked, + "b_locked": locked, + }) table = MeasureTable([ {"quantity": "A x", "value": ax, "unit": field.si_unit_xy}, diff --git a/backend/nodes/draw_mask.py b/backend/nodes/draw_mask.py index e39a888..0a8a9d9 100644 --- a/backend/nodes/draw_mask.py +++ b/backend/nodes/draw_mask.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, datafield_to_uint8, encode_preview from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask @@ -40,17 +41,13 @@ class DrawMask: if invert: mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) - if DrawMask._broadcast_overlay_fn is not None: - DrawMask._broadcast_overlay_fn( - DrawMask._current_node_id, - { - "kind": "mask_paint", - "section_title": "Mask", - "image": encode_preview(datafield_to_uint8(field, "gray")), - "image_width": field.xres, - "image_height": field.yres, - "invert": bool(invert), - }, - ) + emit_overlay({ + "kind": "mask_paint", + "section_title": "Mask", + "image": encode_preview(datafield_to_uint8(field, "gray")), + "image_width": field.xres, + "image_height": field.yres, + "invert": bool(invert), + }) return (mask,) diff --git a/backend/nodes/helpers.py b/backend/nodes/helpers.py index 756bc46..72b9333 100644 --- a/backend/nodes/helpers.py +++ b/backend/nodes/helpers.py @@ -180,6 +180,20 @@ def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int] return text_image +def _import_ibw_loader(): + """Import igor's binary wave loader with NumPy 2 compatibility.""" + if not hasattr(np, "complex"): + # igor 0.3 still references np.complex at import time. + setattr(np, "complex", complex) + + try: + from igor.binarywave import load as load_ibw + except ImportError: + raise ImportError("Install 'igor' package to load .ibw files: pip install igor") + + return load_ibw + + # --------------------------------------------------------------------------- # Markup helpers (from display.py — used by Markup) # --------------------------------------------------------------------------- @@ -508,7 +522,7 @@ def list_channels(filepath: str) -> list[dict]: if ext == ".ibw": try: - from igor.binarywave import load as load_ibw + load_ibw = _import_ibw_loader() wave = load_ibw(str(path)) raw = wave["wave"]["wData"] labels = wave["wave"].get("labels", None) diff --git a/backend/nodes/histogram.py b/backend/nodes/histogram.py index 4888aa4..9199f33 100644 --- a/backend/nodes/histogram.py +++ b/backend/nodes/histogram.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, MeasureTable @@ -72,22 +73,18 @@ class Histogram: yb = float(counts[idx_b]) if len(counts) else 0.0 count_unit = "count" if y_scale == "linear" else "log10(1+count)" - if Histogram._broadcast_overlay_fn is not None: - Histogram._broadcast_overlay_fn( - Histogram._current_node_id, - { - "kind": "line_plot", - "section_title": "Histogram", - "line": counts.tolist(), - "x_axis": bin_centers.astype(np.float64).tolist(), - "x1": float(np.clip(x1, 0.0, 1.0)), - "x2": float(np.clip(x2, 0.0, 1.0)), - "y1": float(y1), - "y2": float(y2), - "a_locked": False, - "b_locked": False, - }, - ) + emit_overlay({ + "kind": "line_plot", + "section_title": "Histogram", + "line": counts.tolist(), + "x_axis": bin_centers.astype(np.float64).tolist(), + "x1": float(np.clip(x1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y1": float(y1), + "y2": float(y2), + "a_locked": False, + "b_locked": False, + }) table = MeasureTable([ {"quantity": "A position", "value": xa, "unit": field.si_unit_z}, diff --git a/backend/nodes/image.py b/backend/nodes/image.py index 4cf8e18..8aa068e 100644 --- a/backend/nodes/image.py +++ b/backend/nodes/image.py @@ -4,8 +4,9 @@ import numpy as np from pathlib import Path from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import COLORMAPS, DataField, resolve_colormap_input -from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS +from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader @register_node(display_name="Image") @@ -66,10 +67,7 @@ class Image: return fields def _send_warning(self, message: str): - fn = Image._broadcast_warning_fn - nid = Image._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) @staticmethod @lru_cache(maxsize=32) @@ -149,11 +147,7 @@ class Image: @staticmethod def _load_ibw_all(path: Path) -> list[DataField]: - try: - from igor.binarywave import load as load_ibw - except ImportError: - raise ImportError("Install 'igor' package to load .ibw files: pip install igor") - + load_ibw = _import_ibw_loader() wave = load_ibw(str(path)) wdata = wave["wave"] header = wdata["wave_header"] diff --git a/backend/nodes/markup.py b/backend/nodes/markup.py index 3190946..f85ee25 100644 --- a/backend/nodes/markup.py +++ b/backend/nodes/markup.py @@ -1,5 +1,6 @@ from __future__ import annotations from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import ( DataField, ImageData, @@ -70,17 +71,13 @@ class Markup: metadata=image_metadata(input), ) - if Markup._broadcast_overlay_fn is not None: - Markup._broadcast_overlay_fn( - Markup._current_node_id, - { - "kind": "markup", - "section_title": "Markup", - "image": encode_preview(preview_base), - "shape": str(shape), - "stroke_color": _normalize_markup_color(stroke_color), - "stroke_width": max(1, int(stroke_width)), - }, - ) + emit_overlay({ + "kind": "markup", + "section_title": "Markup", + "image": encode_preview(preview_base), + "shape": str(shape), + "stroke_color": _normalize_markup_color(stroke_color), + "stroke_width": max(1, int(stroke_width)), + }) return (out,) diff --git a/backend/nodes/mask_combine.py b/backend/nodes/mask_combine.py index 28d7fae..ab25489 100644 --- a/backend/nodes/mask_combine.py +++ b/backend/nodes/mask_combine.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay @@ -53,10 +54,8 @@ class MaskCombine: out = result.astype(np.uint8) * 255 - if field is not None and MaskCombine._broadcast_fn is not None: + if field is not None: overlay = _mask_overlay(field, out) - MaskCombine._broadcast_fn( - MaskCombine._current_node_id, encode_preview(overlay), - ) + emit_preview(encode_preview(overlay)) return (out,) diff --git a/backend/nodes/mask_invert.py b/backend/nodes/mask_invert.py index 1cf3529..4b90e0b 100644 --- a/backend/nodes/mask_invert.py +++ b/backend/nodes/mask_invert.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay @@ -32,10 +33,8 @@ class MaskInvert: def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: out = np.where(mask > 127, np.uint8(0), np.uint8(255)) - if field is not None and MaskInvert._broadcast_fn is not None: + if field is not None: overlay = _mask_overlay(field, out) - MaskInvert._broadcast_fn( - MaskInvert._current_node_id, encode_preview(overlay), - ) + emit_preview(encode_preview(overlay)) return (out,) diff --git a/backend/nodes/mask_morphology.py b/backend/nodes/mask_morphology.py index e9603f0..351cb34 100644 --- a/backend/nodes/mask_morphology.py +++ b/backend/nodes/mask_morphology.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay, _mask_structure @@ -62,10 +63,8 @@ class MaskMorphology: out = result.astype(np.uint8) * 255 - if field is not None and MaskMorphology._broadcast_fn is not None: + if field is not None: overlay = _mask_overlay(field, out) - MaskMorphology._broadcast_fn( - MaskMorphology._current_node_id, encode_preview(overlay), - ) + emit_preview(encode_preview(overlay)) return (out,) diff --git a/backend/nodes/preview_image.py b/backend/nodes/preview_image.py index a955458..ad4d268 100644 --- a/backend/nodes/preview_image.py +++ b/backend/nodes/preview_image.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import ( COLORMAPS, colormap_to_uint8, @@ -68,7 +69,6 @@ class PreviewImage: data_uri = encode_preview(arr_u8) - if PreviewImage._broadcast_fn is not None: - PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri) + emit_preview(data_uri) return () diff --git a/backend/nodes/print_table.py b/backend/nodes/print_table.py index f53f93b..72f1c44 100644 --- a/backend/nodes/print_table.py +++ b/backend/nodes/print_table.py @@ -1,5 +1,6 @@ from __future__ import annotations from backend.node_registry import register_node +from backend.execution_context import emit_table @register_node(display_name="Print Table") @@ -22,6 +23,5 @@ class PrintTable: _current_node_id: str = "" def print_table(self, table: list) -> tuple: - if PrintTable._broadcast_table_fn is not None: - PrintTable._broadcast_table_fn(PrintTable._current_node_id, table) + emit_table(table) return () diff --git a/backend/nodes/rotate_field.py b/backend/nodes/rotate_field.py index 2c2a601..6fd4309 100644 --- a/backend/nodes/rotate_field.py +++ b/backend/nodes/rotate_field.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import DataField @@ -84,10 +85,7 @@ class RotateField: return (result,) def _send_warning(self, message: str): - fn = RotateField._broadcast_warning_fn - nid = RotateField._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) @staticmethod def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]: diff --git a/backend/nodes/save.py b/backend/nodes/save.py index 95ef18b..b7fce1c 100644 --- a/backend/nodes/save.py +++ b/backend/nodes/save.py @@ -7,6 +7,7 @@ from pathlib import Path import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8 @@ -34,6 +35,7 @@ class Save: "choices_by_source_type": { "DATA_FIELD": ["TIFF", "PNG", "NPZ"], "IMAGE": ["PNG", "TIFF", "NPZ"], + "ANNOTATION_SOURCE": ["PNG", "TIFF", "NPZ"], "LINE": ["CSV", "NPZ", "JSON"], "MEASURE_TABLE": ["CSV", "JSON"], "RECORD_TABLE": ["CSV", "JSON"], @@ -254,7 +256,4 @@ class Save: path.write_text("\n".join(lines) + "\n", encoding="utf-8") def _send_warning(self, message: str): - fn = Save._broadcast_warning_fn - nid = Save._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) diff --git a/backend/nodes/save_image.py b/backend/nodes/save_image.py index f122c5e..f9365c1 100644 --- a/backend/nodes/save_image.py +++ b/backend/nodes/save_image.py @@ -4,6 +4,7 @@ import numpy as np from pathlib import Path from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import DataField, image_to_uint8 from backend.nodes.helpers import _MAX_SAVE_FIELDS @@ -174,9 +175,6 @@ class SaveImage: raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") def _send_warning(self, message: str): - fn = SaveImage._broadcast_warning_fn - nid = SaveImage._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) return () diff --git a/backend/nodes/stats.py b/backend/nodes/stats.py index 59ea289..bcbf780 100644 --- a/backend/nodes/stats.py +++ b/backend/nodes/stats.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_value from backend.data_types import DataField, LineData, MeasureTable from backend.nodes.helpers import ( LINE_OPS, @@ -71,11 +72,9 @@ class Stats: op_entry = ops[operation] fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry result = fn(values) - if Stats._broadcast_value_fn is not None: - Stats._broadcast_value_fn( - Stats._current_node_id, - _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), - ) + emit_value( + _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), + ) return (result,) def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str: diff --git a/backend/nodes/threshold_mask.py b/backend/nodes/threshold_mask.py index 56bbac2..108cc97 100644 --- a/backend/nodes/threshold_mask.py +++ b/backend/nodes/threshold_mask.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay @@ -52,10 +53,7 @@ class ThresholdMask: else: mask = (data < t).astype(np.uint8) * 255 - if ThresholdMask._broadcast_fn is not None: - overlay = _mask_overlay(field, mask) - ThresholdMask._broadcast_fn( - ThresholdMask._current_node_id, encode_preview(overlay), - ) + overlay = _mask_overlay(field, mask) + emit_preview(encode_preview(overlay)) return (mask,) diff --git a/backend/nodes/value_display.py b/backend/nodes/value_display.py index e4a58e8..e21cabf 100644 --- a/backend/nodes/value_display.py +++ b/backend/nodes/value_display.py @@ -1,5 +1,6 @@ from __future__ import annotations from backend.node_registry import register_node +from backend.execution_context import emit_value from backend.data_types import MeasureTable from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload @@ -38,6 +39,5 @@ class ValueDisplay: unit = row.get("unit", "") if isinstance(row.get("unit"), str) else "" else: numeric = float(value) - if ValueDisplay._broadcast_value_fn is not None: - ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit)) + emit_value(_scalar_payload(numeric, unit)) return (numeric,) diff --git a/backend/nodes/view_3d.py b/backend/nodes/view_3d.py index 8bfcef1..ee40ba1 100644 --- a/backend/nodes/view_3d.py +++ b/backend/nodes/view_3d.py @@ -3,6 +3,7 @@ import base64 import io import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_mesh from backend.data_types import ( COLORMAPS, DataField, @@ -36,19 +37,42 @@ def _grid_triangle_indices(nx: int, ny: int, *, reverse: bool = False) -> list[l return faces -def _build_mesh_model(z: np.ndarray, colors_u8: np.ndarray, z_scale: float, make_solid: bool) -> MeshModel: +def _surface_extent_scale(xreal: float, yreal: float, nx: int, ny: int) -> tuple[float, float]: + def _resolve_span(value: float, fallback_points: int) -> float: + try: + span = abs(float(value)) + except (TypeError, ValueError): + span = 0.0 + if not np.isfinite(span) or span <= 0.0: + span = float(max(fallback_points - 1, 1)) + return span + + x_span = _resolve_span(xreal, nx) + y_span = _resolve_span(yreal, ny) + max_span = max(x_span, y_span, 1.0) + return (x_span / max_span, y_span / max_span) + + +def _build_mesh_model( + z: np.ndarray, + colors_u8: np.ndarray, + z_scale: float, + make_solid: bool, + lateral_extent: tuple[float, float] = (1.0, 1.0), +) -> MeshModel: ny, nx = z.shape zmin = float(z.min()) zmax = float(z.max()) z_range = zmax - zmin if zmax != zmin else 1.0 + x_extent, y_extent = lateral_extent top_vertices = np.empty((nx * ny, 3), dtype=np.float32) top_colors = colors_u8.reshape(-1, 3).astype(np.uint8) for iy in range(ny): - py = iy / max(ny - 1, 1) - 0.5 + py = (iy / max(ny - 1, 1) - 0.5) * y_extent for ix in range(nx): idx = iy * nx + ix - px = ix / max(nx - 1, 1) - 0.5 + px = (ix / max(nx - 1, 1) - 0.5) * x_extent pz = ((float(z[iy, ix]) - zmin) / z_range - 0.5) * z_scale top_vertices[idx] = (px, pz, py) @@ -98,6 +122,9 @@ class View3D: "camera_azimuth": ("FLOAT", {"default": 0.0, "hidden": True}), "camera_polar": ("FLOAT", {"default": 1.1, "hidden": True}), "camera_distance": ("FLOAT", {"default": 1.8, "hidden": True}), + "camera_target_x": ("FLOAT", {"default": 0.0, "hidden": True}), + "camera_target_y": ("FLOAT", {"default": 0.0, "hidden": True}), + "camera_target_z": ("FLOAT", {"default": 0.0, "hidden": True}), "viewport_snapshot": ("STRING", {"default": "", "hidden": True}), }, "optional": { @@ -114,7 +141,7 @@ class View3D: DESCRIPTION = ( "Interactive 3D surface view of a DATA_FIELD. " "Use the mesh input for geometry and optionally a second map input for coloring. " - "Drag to rotate, scroll to zoom. z_scale exaggerates height." + "Drag to rotate, middle-drag to pan, and right-drag or scroll to zoom. z_scale exaggerates height." ) _broadcast_mesh_fn = None @@ -124,6 +151,7 @@ class View3D: self, field: DataField, colormap: str, z_scale: float, resolution: int, make_solid: bool = False, camera_azimuth: float = 0.0, camera_polar: float = 1.1, camera_distance: float = 1.8, + camera_target_x: float = 0.0, camera_target_y: float = 0.0, camera_target_z: float = 0.0, viewport_snapshot: str = "", map_field: DataField | None = None, colormap_map=None, ) -> tuple: @@ -182,7 +210,14 @@ class View3D: default="gray", ) colors_u8 = colormap_to_uint8(z_norm, resolved_colormap) - mesh_model = _build_mesh_model(z, colors_u8, float(z_scale * 0.1), bool(make_solid)) + surface_extent = _surface_extent_scale(field.xreal, field.yreal, nx, ny) + mesh_model = _build_mesh_model( + z, + colors_u8, + float(z_scale * 0.1), + bool(make_solid), + lateral_extent=surface_extent, + ) z_b64 = base64.b64encode(z.tobytes()).decode() colors_b64 = base64.b64encode(colors_u8.tobytes()).decode() @@ -207,12 +242,16 @@ class View3D: "camera_azimuth": float(camera_azimuth), "camera_polar": float(camera_polar), "camera_distance": float(camera_distance), + "camera_target_x": float(camera_target_x), + "camera_target_y": float(camera_target_y), + "camera_target_z": float(camera_target_z), "x_range": [float(field.xoff), float(field.xoff + field.xreal)], "y_range": [float(field.yoff), float(field.yoff + field.yreal)], + "surface_extent_x": float(surface_extent[0]), + "surface_extent_y": float(surface_extent[1]), } - if View3D._broadcast_mesh_fn is not None: - View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data) + emit_mesh(mesh_data) annotation_context = _annotation_context_from_field(color_field, resolved_colormap) annotation_context["xreal"] = float(field.xreal) @@ -225,6 +264,9 @@ class View3D: "azimuth": float(camera_azimuth), "polar": float(camera_polar), "distance": float(camera_distance), + "target_x": float(camera_target_x), + "target_y": float(camera_target_y), + "target_z": float(camera_target_z), }, }, ) diff --git a/backend/server.py b/backend/server.py index 4426c3d..bd702d7 100644 --- a/backend/server.py +++ b/backend/server.py @@ -6,7 +6,11 @@ Routes GET / → serve frontend/index.html GET /static/{path} → serve frontend JS/CSS GET /nodes → JSON dict of all registered node definitions -POST /upload → multipart file upload to input/ +GET /files → list files in the current session upload workspace +GET /folder-files → list compatible files in a picked folder +GET /channels → inspect channels for a picked file +POST /upload → multipart file upload to the current session workspace +POST /upload-folder → create a folder in the current session workspace POST /prompt → submit a workflow; returns {prompt_id} GET /ws → WebSocket upgrade @@ -15,7 +19,7 @@ WebSocket message types sent to clients {"type": "execution_start", "data": {"prompt_id": "..."}} {"type": "executing", "data": {"node": "...", "prompt_id": "..."}} {"type": "preview", "data": {"node_id": "...", "image": "data:..."}} -{"type": "table", "data": {"node_id": "...", "rows": [...]}} +{"type": "table", "data": {"node_id": "...", "rows": [...]} } {"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}} {"type": "node_timing", "data": {"node_id": "...", "elapsed_ms": 12.34}} {"type": "execution_error", "data": {"node_id": "...", "message": "..."}} @@ -23,39 +27,43 @@ WebSocket message types sent to clients """ from __future__ import annotations + import asyncio import json import logging import sys +from collections import defaultdict +from copy import deepcopy from pathlib import Path from aiohttp import web, WSMsgType + from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready -from backend.runtime_paths import ( - ensure_runtime_dirs, - frontend_dir, - frontend_dist_dir, - input_dir, - output_dir, - project_root, +from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, project_root +from backend.session_runtime import ( + PATH_INPUT_TYPES, + SESSION_HEADER, + SESSION_QUERY, + ensure_session_runtime_dirs, + normalize_relative_upload_path, + resolve_client_path, + server_path_to_client_path, + session_input_dir, + session_upload_uri, + validate_session_id, ) log = logging.getLogger(__name__) FRONTEND_DIR = frontend_dir() DIST_DIR = frontend_dist_dir() -INPUT_DIR = input_dir() -OUTPUT_DIR = output_dir() PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" -# --------------------------------------------------------------------------- -# JSON helper — numpy scalars are not serialisable by default -# --------------------------------------------------------------------------- - class _SafeEncoder(json.JSONEncoder): def default(self, obj): import numpy as np + if isinstance(obj, (np.integer,)): return int(obj) if isinstance(obj, (np.floating,)): @@ -81,45 +89,115 @@ def save_png_bytes(target_path: str, payload: bytes) -> Path: return path -# --------------------------------------------------------------------------- -# Application factory -# --------------------------------------------------------------------------- - -def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: - # Import nodes to trigger registration decorators +def create_app( + loop: asyncio.AbstractEventLoop, + *, + allow_local_filesystem: bool = False, +) -> web.Application: import backend.nodes # noqa: F401 - from backend.node_registry import get_all_node_info from backend.execution import ExecutionEngine, new_prompt_id + from backend.node_registry import NODE_CLASS_MAPPINGS, get_all_node_info ensure_runtime_dirs() - engine = ExecutionEngine() - websockets: set[web.WebSocketResponse] = set() + session_engines: dict[str, ExecutionEngine] = {} + session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set) - # ------------------------------------------------------------------ - # WebSocket broadcast helpers - # ------------------------------------------------------------------ + def _is_link(value) -> bool: + return ( + isinstance(value, (list, tuple)) + and len(value) == 2 + and isinstance(value[0], str) + and isinstance(value[1], int) + ) - def broadcast(msg: dict) -> None: - """Schedule a broadcast to all connected WebSocket clients.""" + def require_session_id(request: web.Request) -> str: + raw_session = request.headers.get(SESSION_HEADER) or request.query.get(SESSION_QUERY) + if not raw_session: + if allow_local_filesystem: + raw_session = "desktop-local-session" + else: + raise web.HTTPBadRequest(reason="Missing session id") + + try: + session_id = validate_session_id(raw_session) + except ValueError as exc: + raise web.HTTPBadRequest(reason=str(exc)) from exc + + ensure_session_runtime_dirs(session_id) + return session_id + + def get_session_engine(session_id: str) -> ExecutionEngine: + engine = session_engines.get(session_id) + if engine is None: + engine = ExecutionEngine() + session_engines[session_id] = engine + return engine + + def resolve_request_path(session_id: str, raw_value: str) -> Path: + try: + return resolve_client_path( + raw_value, + session_id=session_id, + allow_local_filesystem=allow_local_filesystem, + ) + except PermissionError as exc: + raise web.HTTPForbidden(reason=str(exc)) from exc + except ValueError as exc: + raise web.HTTPBadRequest(reason=str(exc)) from exc + + def rewrite_prompt_paths(prompt: dict, session_id: str) -> dict: + normalized = deepcopy(prompt) + for node_def in normalized.values(): + class_name = node_def.get("class_type") + cls = NODE_CLASS_MAPPINGS.get(class_name) + if cls is None: + continue + + input_types = cls.INPUT_TYPES() + specs = {} + specs.update(input_types.get("required", {})) + specs.update(input_types.get("optional", {})) + + inputs = node_def.get("inputs", {}) + if not isinstance(inputs, dict): + continue + + for input_name, raw_value in list(inputs.items()): + if _is_link(raw_value) or not isinstance(raw_value, str): + continue + if not raw_value.strip(): + continue + + spec = specs.get(input_name) + input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec + if not isinstance(input_type, str): + continue + if input_type not in PATH_INPUT_TYPES: + continue + + inputs[input_name] = str(resolve_request_path(session_id, raw_value)) + return normalized + + def broadcast(session_id: str, msg: dict) -> None: payload = _dumps(msg) - for ws in list(websockets): + for ws in list(session_websockets.get(session_id, ())): if not ws.closed: asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop) - def on_preview(node_id: str, data_uri: str) -> None: - broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}}) + def on_preview(session_id: str, node_id: str, data_uri: str) -> None: + broadcast(session_id, {"type": "preview", "data": {"node_id": node_id, "image": data_uri}}) - def on_table(node_id: str, rows: list) -> None: - broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}}) + def on_table(session_id: str, node_id: str, rows: list) -> None: + broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": rows}}) - def on_mesh(node_id: str, mesh_data: dict) -> None: - broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}}) + def on_mesh(session_id: str, node_id: str, mesh_data: dict) -> None: + broadcast(session_id, {"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}}) - def on_overlay(node_id: str, overlay_data) -> None: - broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) + def on_overlay(session_id: str, node_id: str, overlay_data) -> None: + broadcast(session_id, {"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) - def on_value(node_id: str, payload) -> None: + def on_value(session_id: str, node_id: str, payload) -> None: if isinstance(payload, dict): value = payload.get("value") unit = payload.get("unit", "") @@ -130,14 +208,10 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: data = {"node_id": node_id, "value": value} if isinstance(unit, str) and unit.strip(): data["unit"] = unit.strip() - broadcast({"type": "scalar", "data": data}) + broadcast(session_id, {"type": "scalar", "data": data}) - def on_warning(node_id: str, message: str) -> None: - broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) - - # ------------------------------------------------------------------ - # Route handlers - # ------------------------------------------------------------------ + def on_warning(session_id: str, node_id: str, message: str) -> None: + broadcast(session_id, {"type": "node_warning", "data": {"node_id": node_id, "message": message}}) async def index(request: web.Request) -> web.Response: if not getattr(sys, "frozen", False): @@ -167,88 +241,96 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) async def get_nodes(request: web.Request) -> web.Response: - info = get_all_node_info() return web.Response( - text=_dumps(info), + text=_dumps(get_all_node_info()), content_type="application/json", ) async def list_files(request: web.Request) -> web.Response: - """List files in the input/ directory for the file picker widget.""" + session_id = require_session_id(request) + input_path = session_input_dir(session_id) files = sorted( - f.name for f in INPUT_DIR.iterdir() - if f.is_file() and not f.name.startswith(".") - ) if INPUT_DIR.exists() else [] + server_path_to_client_path(entry, session_id) + for entry in input_path.iterdir() + if entry.is_file() and not entry.name.startswith(".") + ) if input_path.exists() else [] return web.Response(text=_dumps(files), content_type="application/json") - async def browse_dir(request: web.Request) -> web.Response: - """ - Server-side directory browser for local file picking. - GET /browse?dir=/some/path → {parent, dirs[], files[]} - """ - dir_path = request.query.get("dir", str(Path.home())) - p = Path(dir_path).expanduser().resolve() - - if not p.is_dir(): - raise web.HTTPBadRequest(reason=f"Not a directory: {p}") - - dirs = [] - files = [] - try: - for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()): - if entry.name.startswith("."): - continue - if entry.is_dir(): - dirs.append(entry.name) - elif entry.is_file(): - files.append(entry.name) - except PermissionError: - pass - + async def create_upload_folder(request: web.Request) -> web.Response: + session_id = require_session_id(request) + body = await request.json() + relative_path = normalize_relative_upload_path(body.get("path", "")) + target = session_input_dir(session_id) / Path(relative_path.as_posix()) + target.mkdir(parents=True, exist_ok=True) return web.Response( - text=_dumps({ - "path": str(p), - "parent": str(p.parent) if p.parent != p else None, - "dirs": dirs, - "files": files, - }), + text=_dumps({"path": session_upload_uri(relative_path)}), content_type="application/json", ) async def get_folder_files(request: web.Request) -> web.Response: - folder_path = request.query.get("folder", "") from backend.nodes.helpers import list_folder_paths - loop = asyncio.get_running_loop() - entries = await loop.run_in_executor(None, list_folder_paths, folder_path) - return web.Response(text=_dumps(entries), content_type="application/json") + + session_id = require_session_id(request) + folder_path = request.query.get("folder", "") + if not folder_path: + return web.Response(text=_dumps([]), content_type="application/json") + + resolved_path = resolve_request_path(session_id, folder_path) + running_loop = asyncio.get_running_loop() + entries = await running_loop.run_in_executor(None, list_folder_paths, str(resolved_path)) + + payload = [] + for entry in entries: + mapped = dict(entry) + if "path" in mapped: + mapped["path"] = server_path_to_client_path(mapped["path"], session_id) + payload.append(mapped) + return web.Response(text=_dumps(payload), content_type="application/json") async def upload_file(request: web.Request) -> web.Response: + session_id = require_session_id(request) reader = await request.multipart() - field = await reader.next() - if field is None or field.name != "file": + relative_path = None + filename = "" + file_bytes = None + + while True: + field = await reader.next() + if field is None: + break + if field.name == "relative_path": + relative_path = await field.text() + continue + if field.name == "file": + filename = Path(field.filename or "upload.bin").name + chunks = [] + while True: + chunk = await field.read_chunk(65536) + if not chunk: + break + chunks.append(chunk) + file_bytes = b"".join(chunks) + + if file_bytes is None: raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body") - filename = Path(field.filename).name # strip any path traversal - dest = INPUT_DIR / filename - with open(dest, "wb") as f: - while True: - chunk = await field.read_chunk(65536) - if not chunk: - break - f.write(chunk) + relative = normalize_relative_upload_path(relative_path or filename) + dest = session_input_dir(session_id) / Path(relative.as_posix()) + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(file_bytes) - return web.Response(text=_dumps({"filename": filename}), content_type="application/json") + return web.Response( + text=_dumps({"filename": filename, "path": session_upload_uri(relative)}), + content_type="application/json", + ) async def download_file(request: web.Request) -> web.Response: - """Accept a blob POST and return it with Content-Disposition: attachment.""" body = await request.read() filename = request.query.get("filename", "workflow.png") return web.Response( body=body, content_type="application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="{filename}"', - }, + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, ) async def save_workflow_png(request: web.Request) -> web.Response: @@ -266,34 +348,39 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) async def get_channels(request: web.Request) -> web.Response: - """Return available channels for a given file path.""" from backend.nodes.helpers import list_channels + + session_id = require_session_id(request) filepath = request.query.get("file", "") if not filepath: return web.Response( text=_dumps([{"name": "field", "type": "DATA_FIELD"}]), content_type="application/json", ) - channels = await loop.run_in_executor(None, list_channels, filepath) + + resolved_path = resolve_request_path(session_id, filepath) + channels = await loop.run_in_executor(None, list_channels, str(resolved_path)) return web.Response(text=_dumps(channels), content_type="application/json") async def submit_prompt(request: web.Request) -> web.Response: + session_id = require_session_id(request) body = await request.json() prompt = body.get("prompt") if not isinstance(prompt, dict) or not prompt: raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict") + normalized_prompt = rewrite_prompt_paths(prompt, session_id) prompt_id = new_prompt_id() + engine = get_session_engine(session_id) - # Run execution in a thread pool so scipy doesn't block the event loop async def run(): - broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}}) + broadcast(session_id, {"type": "execution_start", "data": {"prompt_id": prompt_id}}) def on_start(node_id: str) -> None: - broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}}) + broadcast(session_id, {"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}}) def on_done(node_id: str, elapsed_ms: float) -> None: - broadcast({ + broadcast(session_id, { "type": "node_timing", "data": {"node_id": node_id, "elapsed_ms": elapsed_ms}, }) @@ -302,21 +389,21 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: await loop.run_in_executor( None, lambda: engine.execute( - prompt, + normalized_prompt, on_node_start=on_start, on_node_done=on_done, - on_preview=on_preview, - on_table=on_table, - on_mesh=on_mesh, - on_overlay=on_overlay, - on_value=on_value, - on_warning=on_warning, + on_preview=lambda node_id, payload: on_preview(session_id, node_id, payload), + on_table=lambda node_id, rows: on_table(session_id, node_id, rows), + on_mesh=lambda node_id, mesh_data: on_mesh(session_id, node_id, mesh_data), + on_overlay=lambda node_id, overlay_data: on_overlay(session_id, node_id, overlay_data), + on_value=lambda node_id, payload: on_value(session_id, node_id, payload), + on_warning=lambda node_id, message: on_warning(session_id, node_id, message), ), ) - broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}}) + broadcast(session_id, {"type": "execution_complete", "data": {"prompt_id": prompt_id}}) except Exception as exc: log.exception("Execution error") - broadcast({ + broadcast(session_id, { "type": "execution_error", "data": {"node_id": "", "message": str(exc)}, }) @@ -328,32 +415,40 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + session_id = require_session_id(request) ws = web.WebSocketResponse() await ws.prepare(request) - websockets.add(ws) - log.info("WebSocket client connected (%d total)", len(websockets)) + session_websockets[session_id].add(ws) + log.info( + "WebSocket client connected for session %s (%d total in session)", + session_id, + len(session_websockets[session_id]), + ) try: async for msg in ws: if msg.type == WSMsgType.TEXT: - pass # clients don't need to send anything currently + pass elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE): break finally: - websockets.discard(ws) - log.info("WebSocket client disconnected (%d total)", len(websockets)) + session_websockets[session_id].discard(ws) + if not session_websockets[session_id]: + session_websockets.pop(session_id, None) + log.info( + "WebSocket client disconnected for session %s (%d remaining in session)", + session_id, + len(session_websockets.get(session_id, ())), + ) return ws - # ------------------------------------------------------------------ - # App assembly - # ------------------------------------------------------------------ - app = web.Application() + app["allow_local_filesystem"] = allow_local_filesystem app.router.add_get("/", index) app.router.add_get("/nodes", get_nodes) app.router.add_get("/files", list_files) - app.router.add_get("/browse", browse_dir) app.router.add_get("/folder-files", get_folder_files) + app.router.add_post("/upload-folder", create_upload_folder) app.router.add_post("/upload", upload_file) app.router.add_post("/download", download_file) app.router.add_post("/save-workflow-png", save_workflow_png) @@ -361,26 +456,24 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: app.router.add_post("/prompt", submit_prompt) app.router.add_get("/ws", websocket_handler) - # Serve frontend static files (Vite build or raw) if (DIST_DIR / "assets").exists(): app.router.add_static("/assets", DIST_DIR / "assets") if FRONTEND_DIR.exists(): app.router.add_static("/static", FRONTEND_DIR) - # CORS — allow any origin (local dev only) async def _cors_middleware(app_, handler): async def middleware(request): if request.method == "OPTIONS": return web.Response(headers={ "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Allow-Headers": f"Content-Type, {SESSION_HEADER}", }) response = await handler(request) response.headers["Access-Control-Allow-Origin"] = "*" return response + return middleware app.middlewares.append(_cors_middleware) - return app diff --git a/backend/session_runtime.py b/backend/session_runtime.py new file mode 100644 index 0000000..803ce54 --- /dev/null +++ b/backend/session_runtime.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import re +from pathlib import Path, PurePosixPath + +from backend.runtime_paths import app_data_dir, demo_dir + +SESSION_HEADER = "X-Argonode-Session" +SESSION_QUERY = "session" +SESSION_URI_PREFIX = "session://uploads/" + +PATH_INPUT_TYPES = {"FILE_PICKER", "FILE_PATH", "FOLDER_PICKER", "DIRECTORY"} + +_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]{7,127}$") + + +def validate_session_id(session_id: str) -> str: + text = str(session_id or "").strip() + if not _SESSION_ID_RE.fullmatch(text): + raise ValueError("Invalid session id") + return text + + +def session_root_dir(session_id: str) -> Path: + validated = validate_session_id(session_id) + return app_data_dir() / "sessions" / validated + + +def session_input_dir(session_id: str) -> Path: + return session_root_dir(session_id) / "input" + + +def session_output_dir(session_id: str) -> Path: + return session_root_dir(session_id) / "output" + + +def ensure_session_runtime_dirs(session_id: str) -> tuple[Path, Path]: + input_path = session_input_dir(session_id) + output_path = session_output_dir(session_id) + input_path.mkdir(parents=True, exist_ok=True) + output_path.mkdir(parents=True, exist_ok=True) + return input_path, output_path + + +def normalize_relative_upload_path(raw_path: str) -> PurePosixPath: + raw_text = str(raw_path or "").replace("\\", "/").strip() + if not raw_text: + raise ValueError("Missing upload path") + + path = PurePosixPath(raw_text) + if path.is_absolute(): + raise ValueError("Upload paths must be relative") + + parts: list[str] = [] + for part in path.parts: + if part in ("", "."): + continue + if part == "..": + raise ValueError("Upload paths cannot escape the session directory") + if "\x00" in part: + raise ValueError("Upload paths cannot contain NUL bytes") + parts.append(part) + + if not parts: + raise ValueError("Upload paths must contain at least one path segment") + + return PurePosixPath(*parts) + + +def session_upload_uri(relative_path: str | PurePosixPath) -> str: + normalized = normalize_relative_upload_path(str(relative_path)) + return f"{SESSION_URI_PREFIX}{normalized.as_posix()}" + + +def session_uri_to_relative_path(value: str) -> PurePosixPath | None: + text = str(value or "").strip() + if not text.startswith(SESSION_URI_PREFIX): + return None + return normalize_relative_upload_path(text[len(SESSION_URI_PREFIX):]) + + +def is_path_within(root: Path, candidate: Path) -> bool: + try: + candidate.resolve(strict=False).relative_to(root.resolve(strict=False)) + return True + except ValueError: + return False + + +def server_path_to_client_path(path_value: str | Path, session_id: str) -> str: + path = Path(path_value).expanduser().resolve(strict=False) + session_input = session_input_dir(session_id).resolve(strict=False) + if is_path_within(session_input, path): + rel = path.relative_to(session_input) + return session_upload_uri(rel.as_posix()) + return str(path) + + +def resolve_client_path( + value: str, + *, + session_id: str, + allow_local_filesystem: bool, +) -> Path: + text = str(value or "").strip() + if not text: + return Path("") + + rel = session_uri_to_relative_path(text) + if rel is not None: + return (session_input_dir(session_id) / Path(rel.as_posix())).resolve(strict=False) + + candidate = Path(text).expanduser() + if not candidate.is_absolute(): + demo_candidate = (demo_dir() / text).expanduser().resolve(strict=False) + if demo_candidate.exists(): + return demo_candidate + + if not candidate.is_absolute(): + if allow_local_filesystem: + return candidate.resolve(strict=False) + raise PermissionError("Browser sessions may only use files uploaded through Browse.") + + resolved = candidate.resolve(strict=False) + if allow_local_filesystem: + return resolved + + session_root = session_root_dir(session_id).resolve(strict=False) + if is_path_within(session_root, resolved): + return resolved + + raise PermissionError("Path is outside the current session workspace.") diff --git a/demo/APL_Figure4.ibw b/demo/APL_Figure4.ibw new file mode 100644 index 0000000..a454c87 Binary files /dev/null and b/demo/APL_Figure4.ibw differ diff --git a/demo/BR_New20012.ibw b/demo/BR_New20012.ibw deleted file mode 100644 index d0988b6..0000000 Binary files a/demo/BR_New20012.ibw and /dev/null differ diff --git a/demo/Calcite0012.ibw b/demo/Calcite0012.ibw new file mode 100644 index 0000000..95db20a Binary files /dev/null and b/demo/Calcite0012.ibw differ diff --git a/demo/DNA1305.ibw b/demo/DNA1305.ibw new file mode 100644 index 0000000..33d5ded Binary files /dev/null and b/demo/DNA1305.ibw differ diff --git a/demo/Grat500nm0d0602.ibw b/demo/Grat500nm0d0602.ibw new file mode 100644 index 0000000..bdb0c2e Binary files /dev/null and b/demo/Grat500nm0d0602.ibw differ diff --git a/demo/Image0002.ibw b/demo/Image0002.ibw new file mode 100644 index 0000000..f70dc70 Binary files /dev/null and b/demo/Image0002.ibw differ diff --git a/demo/LB_media0002.ibw b/demo/LB_media0002.ibw new file mode 100644 index 0000000..f95f3e7 Binary files /dev/null and b/demo/LB_media0002.ibw differ diff --git a/demo/PAbacteria0007.ibw b/demo/PAbacteria0007.ibw new file mode 100644 index 0000000..fbe175d Binary files /dev/null and b/demo/PAbacteria0007.ibw differ diff --git a/demo/PMNJupiter0006.ibw b/demo/PMNJupiter0006.ibw new file mode 100644 index 0000000..a19fd04 Binary files /dev/null and b/demo/PMNJupiter0006.ibw differ diff --git a/demo/PP_PS_b_0000.ibw b/demo/PP_PS_b_0000.ibw new file mode 100644 index 0000000..46ef926 Binary files /dev/null and b/demo/PP_PS_b_0000.ibw differ diff --git a/demo/PZTJupiter0001.ibw b/demo/PZTJupiter0001.ibw new file mode 100644 index 0000000..c85c66a Binary files /dev/null and b/demo/PZTJupiter0001.ibw differ diff --git a/demo/Poly1u0d1101.ibw b/demo/Poly1u0d1101.ibw new file mode 100644 index 0000000..bbf1402 Binary files /dev/null and b/demo/Poly1u0d1101.ibw differ diff --git a/demo/tBLG_057_0032.ibw b/demo/tBLG_057_0032.ibw new file mode 100644 index 0000000..bc770c8 Binary files /dev/null and b/demo/tBLG_057_0032.ibw differ diff --git a/desktop.py b/desktop.py index e77c683..1528826 100644 --- a/desktop.py +++ b/desktop.py @@ -44,6 +44,19 @@ class _Api: return result[0] return None + def open_folder_dialog(self) -> str | None: + """Open a native folder picker and return the selected path (or None).""" + win = self._window_ref[0] + if win is None: + return None + result = win.create_file_dialog( + webview.FOLDER_DIALOG, + allow_multiple=False, + ) + if result and len(result) > 0: + return result[0] + return None + def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None: """Open a native save dialog and return the chosen PNG path (or None).""" win = self._window_ref[0] @@ -90,7 +103,7 @@ def _run_server(host: str, port: int, ready: threading.Event, state: dict[str, o state["loop"] = loop async def start() -> None: - app = create_app(loop) + app = create_app(loop, allow_local_filesystem=True) runner = web.AppRunner(app, access_log=None) await runner.setup() site = web.TCPSite(runner, host, port) diff --git a/frontend/package.json b/frontend/package.json index ab03857..61b0016 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -7,8 +7,8 @@ "npm": ">=9.0.0" }, "scripts": { - "dev": "vite", - "build": "vite build", + "dev": "vite --force", + "build": "vite build --emptyOutDir", "preview": "vite preview", "test": "node --test tests/**/*.test.mjs" }, diff --git a/frontend/public/workflow.png b/frontend/public/workflow.png new file mode 100644 index 0000000..0502ed9 Binary files /dev/null and b/frontend/public/workflow.png differ diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 3e37f7b..09dd94a 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -4,18 +4,19 @@ import React, { import { ReactFlow, Background, Controls, MiniMap, useNodesState, useEdgesState, addEdge, useReactFlow, - ReactFlowProvider, getViewportForBounds, + ReactFlowProvider, getViewportForBounds, PanOnScrollMode, SelectionMode, } from '@xyflow/react'; import '@xyflow/react/dist/style.css'; import CustomNode, { NodeContext } from './CustomNode'; -import FileBrowser from './FileBrowser'; import * as api from './api'; +import { pickNativeDirectorySelection, pickNativeFileSelection } from './nativePicker'; import { toBlob } from 'html-to-image'; import { embedWorkflow, extractWorkflow } from './pngMetadata'; import { captureViewportBlob as captureWorkflowViewportBlob } from './workflowCapture'; import { hydrateWorkflowState } from './workflowHydration'; import { serializeWorkflowState } from './workflowSerialization'; +import { sortNodesForParentOrder } from './nodeHierarchy.js'; import { buildNodeClipboardPayload, buildNodeClipboardPayloadForIds, @@ -36,6 +37,17 @@ import { const NODE_TYPES = { custom: CustomNode }; +const GROUP_PADDING_X = 24; +const GROUP_PADDING_Y = 24; +const GROUP_HEADER_HEIGHT = 36; +const GROUP_WORKSPACE_INSET = 12; +const GROUP_MIN_WIDTH = 260; +const GROUP_MIN_HEIGHT = 180; +const CANVAS_MIN_ZOOM = 0.2; +const CANVAS_MAX_ZOOM = 4; +const CANVAS_RIGHT_DRAG_ZOOM_SENSITIVITY = 0.0065; +const CANVAS_RIGHT_DRAG_ZOOM_THRESHOLD = 5; + // ── Handle ID helpers ───────────────────────────────────────────────── function getHandleType(handleId) { @@ -50,6 +62,316 @@ function getOutputSlot(handleId) { return parseInt(handleId.split('::')[1], 10); } +function encodeProxyHandleRef(handleId) { + return encodeURIComponent(String(handleId || '')); +} + +function decodeProxyHandleRef(encoded) { + try { + return decodeURIComponent(String(encoded || '')); + } catch { + return String(encoded || ''); + } +} + +function parseGroupProxyHandle(handleId) { + const text = String(handleId || ''); + if (!text.startsWith('group-proxy::')) return null; + const parts = text.split('::'); + if (parts.length < 5) return null; + return { + direction: parts[1], + nodeId: parts[2], + type: parts[3], + realHandle: decodeProxyHandleRef(parts.slice(4).join('::')), + }; +} + +function getConnectionHandleType(handleId) { + const proxy = parseGroupProxyHandle(handleId); + return proxy?.type || getHandleType(handleId); +} + +function getNodeDimension(node, axis) { + if (axis === 'width') return node.measured?.width || node.style?.width || node.width || 200; + return node.measured?.height || node.style?.height || node.height || 120; +} + +function applyNodeSize(node, width, height) { + const nextWidth = Math.round(Number(width) || 0); + const nextHeight = Math.round(Number(height) || 0); + return { + ...node, + width: nextWidth, + height: nextHeight, + style: { ...(node.style || {}), width: nextWidth, height: nextHeight }, + }; +} + +function getNodeAbsolutePosition(node, nodeMap) { + if (node?.positionAbsolute) { + return { + x: Number(node.positionAbsolute.x) || 0, + y: Number(node.positionAbsolute.y) || 0, + }; + } + const local = { + x: Number(node?.position?.x) || 0, + y: Number(node?.position?.y) || 0, + }; + if (!node?.parentId) return local; + const parent = nodeMap.get(String(node.parentId)); + if (!parent) return local; + const parentPos = getNodeAbsolutePosition(parent, nodeMap); + return { x: parentPos.x + local.x, y: parentPos.y + local.y }; +} + +function collectGroupDescendantIds(nodes, groupId) { + const allNodes = Array.isArray(nodes) ? nodes : []; + const result = new Set(); + let changed = true; + while (changed) { + changed = false; + for (const node of allNodes) { + const parentId = node?.parentId ? String(node.parentId) : null; + const nodeId = String(node?.id); + if (!parentId) continue; + if ((parentId === String(groupId) || result.has(parentId)) && !result.has(nodeId)) { + result.add(nodeId); + changed = true; + } + } + } + return result; +} + +function getGroupMembers(nodes, groupId) { + const descendants = collectGroupDescendantIds(nodes, groupId); + return Array.from(descendants); +} + +function getGroupDisplayBounds(nodes, selectedIds) { + const nodeMap = new Map((nodes || []).map((node) => [String(node.id), node])); + let minX = Infinity; + let minY = Infinity; + let maxX = -Infinity; + let maxY = -Infinity; + + for (const id of selectedIds) { + const node = nodeMap.get(String(id)); + if (!node) continue; + const pos = getNodeAbsolutePosition(node, nodeMap); + const width = Number(getNodeDimension(node, 'width')) || 200; + const height = Number(getNodeDimension(node, 'height')) || 120; + minX = Math.min(minX, pos.x); + minY = Math.min(minY, pos.y); + maxX = Math.max(maxX, pos.x + width); + maxY = Math.max(maxY, pos.y + height); + } + + if (!Number.isFinite(minX) || !Number.isFinite(minY) || !Number.isFinite(maxX) || !Number.isFinite(maxY)) { + return null; + } + + return { minX, minY, maxX, maxY }; +} + +function getGroupWorkspaceBounds(groupNode, nodeMap) { + const pos = getNodeAbsolutePosition(groupNode, nodeMap); + const width = Number(getNodeDimension(groupNode, 'width')) || 200; + const height = Number(getNodeDimension(groupNode, 'height')) || 120; + return { + left: pos.x + GROUP_WORKSPACE_INSET, + top: pos.y + GROUP_HEADER_HEIGHT + GROUP_WORKSPACE_INSET, + right: pos.x + width - GROUP_WORKSPACE_INSET, + bottom: pos.y + height - GROUP_WORKSPACE_INSET, + }; +} + +function getNodeCenter(node, nodeMap) { + const pos = getNodeAbsolutePosition(node, nodeMap); + const width = Number(getNodeDimension(node, 'width')) || 200; + const height = Number(getNodeDimension(node, 'height')) || 120; + return { + x: pos.x + width / 2, + y: pos.y + height / 2, + }; +} + +function getNodeRect(node, nodeMap) { + const pos = getNodeAbsolutePosition(node, nodeMap); + const width = Number(getNodeDimension(node, 'width')) || 200; + const height = Number(getNodeDimension(node, 'height')) || 120; + return { + left: pos.x, + top: pos.y, + right: pos.x + width, + bottom: pos.y + height, + }; +} + +function getAbsoluteRectForNodePosition(node, absolutePosition) { + const width = Number(getNodeDimension(node, 'width')) || 200; + const height = Number(getNodeDimension(node, 'height')) || 120; + return { + left: absolutePosition.x, + top: absolutePosition.y, + right: absolutePosition.x + width, + bottom: absolutePosition.y + height, + }; +} + +function rectContainsPoint(rect, point) { + return point.x >= rect.left + && point.x <= rect.right + && point.y >= rect.top + && point.y <= rect.bottom; +} + +function rectContainsRect(outerRect, innerRect) { + return innerRect.left >= outerRect.left + && innerRect.top >= outerRect.top + && innerRect.right <= outerRect.right + && innerRect.bottom <= outerRect.bottom; +} + +function getEventClientPosition(event) { + if (!event) return null; + const point = 'changedTouches' in event && event.changedTouches?.[0] + ? event.changedTouches[0] + : ('touches' in event && event.touches?.[0] ? event.touches[0] : event); + if (!Number.isFinite(point?.clientX) || !Number.isFinite(point?.clientY)) return null; + return { x: point.clientX, y: point.clientY }; +} + +function getEventFlowPosition(event, reactFlow) { + const clientPosition = getEventClientPosition(event); + if (!clientPosition || typeof reactFlow?.screenToFlowPosition !== 'function') return null; + return reactFlow.screenToFlowPosition(clientPosition); +} + +function getDragIntent(event, reactFlow, dragState) { + if (!dragState?.pointerOffset || !dragState?.anchorStartAbsolute) return null; + const pointerFlowPos = getEventFlowPosition(event, reactFlow); + if (!pointerFlowPos) return null; + + const anchorAbsolute = { + x: pointerFlowPos.x - dragState.pointerOffset.x, + y: pointerFlowPos.y - dragState.pointerOffset.y, + }; + const delta = { + x: anchorAbsolute.x - (Number(dragState.anchorStartAbsolute.x) || 0), + y: anchorAbsolute.y - (Number(dragState.anchorStartAbsolute.y) || 0), + }; + const absolutePositions = new Map( + Object.entries(dragState.absolutePositions || {}).map(([id, pos]) => [ + id, + { + x: (Number(pos?.x) || 0) + delta.x, + y: (Number(pos?.y) || 0) + delta.y, + }, + ]), + ); + + return { + pointerFlowPos, + anchorAbsolute, + absolutePositions, + }; +} + +function findExpandedGroupDropTarget(nodes, draggedNodeIds, anchorNodeId, anchorPoint = null) { + const nodeMap = new Map((nodes || []).map((node) => [String(node.id), node])); + const anchorNode = nodeMap.get(String(anchorNodeId)); + if (!anchorNode) return null; + + const draggedIdSet = new Set((draggedNodeIds || []).map((id) => String(id))); + const anchorCenter = anchorPoint && Number.isFinite(anchorPoint.x) && Number.isFinite(anchorPoint.y) + ? anchorPoint + : getNodeCenter(anchorNode, nodeMap); + + return (nodes || []) + .filter((node) => ( + node?.data?.className === 'Group' + && !node?.data?.collapsed + && !draggedIdSet.has(String(node.id)) + )) + .map((node) => { + const rect = getGroupWorkspaceBounds(node, nodeMap); + return { + node, + rect, + area: Math.max(1, rect.right - rect.left) * Math.max(1, rect.bottom - rect.top), + }; + }) + .filter(({ rect }) => rectContainsPoint(rect, anchorCenter)) + .sort((a, b) => a.area - b.area)[0]?.node || null; +} + +function getInputLabelForNode(node, inputName) { + const inputs = { + ...(node?.data?.definition?.input?.required || {}), + ...(node?.data?.definition?.input?.optional || {}), + }; + const spec = inputs[inputName]; + if (!spec) return inputName; + const [, opts] = Array.isArray(spec) ? spec : [spec, {}]; + return opts?.label || inputName; +} + +function getOutputLabelForNode(node, slot, handleId) { + const outputNames = node?.data?.definition?.output_name || []; + const outputTypes = node?.data?.definition?.output || []; + if (Number.isInteger(slot) && outputNames[slot]) return outputNames[slot]; + const proxy = parseGroupProxyHandle(handleId); + return proxy?.realHandle ? getOutputLabelForNode(node, getOutputSlot(proxy.realHandle), proxy.realHandle) : outputTypes[slot] || 'output'; +} + +function buildGroupProxyData(groupId, nodes, edges) { + const nodeMap = new Map((nodes || []).map((node) => [String(node.id), node])); + const memberIds = new Set(getGroupMembers(nodes, groupId)); + const proxyInputs = []; + const proxyOutputs = []; + const seenInputs = new Set(); + const seenOutputs = new Set(); + + for (const edge of edges || []) { + const original = edge?.data?.groupProxyOriginal || {}; + const sourceId = String(original.source || edge.source); + const targetId = String(original.target || edge.target); + const sourceHandle = original.sourceHandle || edge.sourceHandle; + const targetHandle = original.targetHandle || edge.targetHandle; + const sourceInside = memberIds.has(sourceId); + const targetInside = memberIds.has(targetId); + + if (!sourceInside && targetInside) { + const key = `${targetId}::${targetHandle}`; + if (seenInputs.has(key)) continue; + seenInputs.add(key); + proxyInputs.push({ + key, + type: getHandleType(targetHandle), + label: getInputLabelForNode(nodeMap.get(targetId), getInputName(targetHandle)), + handleId: `group-proxy::in::${targetId}::${getHandleType(targetHandle)}::${encodeProxyHandleRef(targetHandle)}`, + }); + } + + if (sourceInside && !targetInside) { + const key = `${sourceId}::${sourceHandle}`; + if (seenOutputs.has(key)) continue; + seenOutputs.add(key); + proxyOutputs.push({ + key, + type: getHandleType(sourceHandle), + label: getOutputLabelForNode(nodeMap.get(sourceId), getOutputSlot(sourceHandle), sourceHandle), + handleId: `group-proxy::out::${sourceId}::${getHandleType(sourceHandle)}::${encodeProxyHandleRef(sourceHandle)}`, + }); + } + } + + return { proxyInputs, proxyOutputs, childCount: memberIds.size }; +} + function sameStringArray(a = [], b = []) { if (a === b) return true; if (!Array.isArray(a) || !Array.isArray(b) || a.length !== b.length) return false; @@ -62,6 +384,19 @@ function isEditableTarget(target) { return target.closest('[contenteditable="true"]') !== null; } +function clampNumber(value, min, max) { + return Math.min(max, Math.max(min, value)); +} + +function canStartCanvasRightDragZoom(target) { + if (!target || !(target instanceof Element)) return false; + if (isEditableTarget(target)) return false; + if (target.closest('.context-menu, .react-flow__node, .react-flow__edge, .react-flow__controls, .react-flow__minimap, .surface-view-container')) { + return false; + } + return target.closest('.react-flow__pane, .react-flow__background') !== null; +} + function compareMenuNodes(a, b) { const orderA = Number.isFinite(a?.def?.menu_order) ? a.def.menu_order : Number.MAX_SAFE_INTEGER; const orderB = Number.isFinite(b?.def?.menu_order) ? b.def.menu_order : Number.MAX_SAFE_INTEGER; @@ -243,7 +578,7 @@ async function captureViewportBlob(viewportEl, options) { // ── Context menu component ──────────────────────────────────────────── -function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection }) { +function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) { const [openCat, setOpenCat] = useState(null); const [search, setSearch] = useState(''); const menuRef = useRef(null); @@ -396,6 +731,15 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti /> + {!filterType && selectedNodeCount > 1 && typeof onCreateGroup === 'function' && ( +
{ onCreateGroup(); onClose(); }} + > + create group +
+ )} + {searchResults ? (
{searchResults.length === 0 ? ( @@ -464,8 +808,9 @@ function Flow() { const [edges, setEdges, onEdgesChange] = useEdgesState([]); const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' }); const [contextMenu, setContextMenu] = useState(null); - const [fileBrowserState, setFileBrowserState] = useState(null); + const [isCanvasRightZooming, setIsCanvasRightZooming] = useState(false); + const flowContainerRef = useRef(null); const nodeDefsRef = useRef({}); const nextIdRef = useRef(1); const autoRunTimer = useRef(null); @@ -474,6 +819,10 @@ function Flow() { const lastPastedClipboardTextRef = useRef(''); const pasteRepeatCountRef = useRef(0); const duplicateDragRef = useRef(null); + const dragStateRef = useRef(null); + const activeDragNodeIdRef = useRef(null); + const canvasRightZoomRef = useRef(null); + const suppressPaneContextMenuUntilRef = useRef(0); const reactFlow = useReactFlow(); // ── WebSocket ─────────────────────────────────────────────────────── @@ -484,6 +833,290 @@ function Flow() { )); }, [setNodes]); + const refreshGroupNode = useCallback((groupId, explicitNodes = null, explicitEdges = null) => { + const currentNodes = explicitNodes || reactFlow.getNodes(); + const currentEdges = explicitEdges || reactFlow.getEdges(); + const groupNode = currentNodes.find((node) => node.id === groupId && node.data?.className === 'Group'); + if (!groupNode) return; + + const { proxyInputs, proxyOutputs, childCount } = buildGroupProxyData(groupId, currentNodes, currentEdges); + setNodes((prev) => prev.map((node) => ( + node.id !== groupId + ? node + : { + ...node, + className: 'group-shell', + data: { + ...node.data, + proxyInputs, + proxyOutputs, + childCount, + }, + } + ))); + reactFlow.updateNodeInternals(groupId); + }, [reactFlow, setNodes]); + + const toggleGroupCollapse = useCallback((groupId) => { + const currentNodes = reactFlow.getNodes(); + const currentEdges = reactFlow.getEdges(); + const groupNode = currentNodes.find((node) => node.id === groupId && node.data?.className === 'Group'); + if (!groupNode) return; + + const memberIds = new Set(getGroupMembers(currentNodes, groupId)); + const collapsed = !groupNode.data?.collapsed; + const proxyData = buildGroupProxyData(groupId, currentNodes, currentEdges); + + const nextNodes = currentNodes.map((node) => { + if (memberIds.has(String(node.id))) { + return { ...node, hidden: collapsed }; + } + if (node.id !== groupId) return node; + const expandedSize = groupNode.data?.expandedSize || { + width: Number(groupNode.style?.width) || 320, + height: Number(groupNode.style?.height) || 240, + }; + const collapsedHeight = Math.max(74, 38 + Math.max(proxyData.proxyInputs.length, proxyData.proxyOutputs.length, 1) * 24 + 26); + return { + ...applyNodeSize( + node, + collapsed ? 260 : expandedSize.width, + collapsed ? collapsedHeight : expandedSize.height, + ), + data: { + ...node.data, + collapsed, + expandedSize, + proxyInputs: proxyData.proxyInputs, + proxyOutputs: proxyData.proxyOutputs, + childCount: proxyData.childCount, + }, + }; + }); + + const nextEdges = currentEdges.map((edge) => { + if (collapsed) { + if (edge.data?.groupProxyOwner === groupId || edge.data?.groupInternalHiddenBy === groupId) { + return edge; + } + const sourceInside = memberIds.has(String(edge.source)); + const targetInside = memberIds.has(String(edge.target)); + if (sourceInside && targetInside) { + return { + ...edge, + hidden: true, + data: { ...(edge.data || {}), groupInternalHiddenBy: groupId }, + }; + } + if (!sourceInside && targetInside) { + return { + ...edge, + target: groupId, + targetHandle: `group-proxy::in::${edge.target}::${getHandleType(edge.targetHandle)}::${encodeProxyHandleRef(edge.targetHandle)}`, + data: { + ...(edge.data || {}), + groupProxyOwner: groupId, + groupProxyOriginal: { + target: edge.target, + targetHandle: edge.targetHandle, + }, + }, + }; + } + if (sourceInside && !targetInside) { + return { + ...edge, + source: groupId, + sourceHandle: `group-proxy::out::${edge.source}::${getHandleType(edge.sourceHandle)}::${encodeProxyHandleRef(edge.sourceHandle)}`, + data: { + ...(edge.data || {}), + groupProxyOwner: groupId, + groupProxyOriginal: { + source: edge.source, + sourceHandle: edge.sourceHandle, + }, + }, + }; + } + return edge; + } + + if (edge.data?.groupInternalHiddenBy === groupId) { + const nextData = { ...(edge.data || {}) }; + delete nextData.groupInternalHiddenBy; + return { + ...edge, + hidden: false, + data: Object.keys(nextData).length > 0 ? nextData : undefined, + }; + } + if (edge.data?.groupProxyOwner === groupId) { + const nextData = { ...(edge.data || {}) }; + const original = nextData.groupProxyOriginal || {}; + delete nextData.groupProxyOwner; + delete nextData.groupProxyOriginal; + return { + ...edge, + source: original.source || edge.source, + sourceHandle: original.sourceHandle || edge.sourceHandle, + target: original.target || edge.target, + targetHandle: original.targetHandle || edge.targetHandle, + data: Object.keys(nextData).length > 0 ? nextData : undefined, + }; + } + return edge; + }); + + setNodes(nextNodes); + setEdges(nextEdges); + setTimeout(() => refreshGroupNode(groupId, nextNodes, nextEdges), 0); + }, [reactFlow, refreshGroupNode, setEdges, setNodes]); + + const ungroupGroup = useCallback((groupId) => { + const currentNodes = reactFlow.getNodes(); + const currentEdges = reactFlow.getEdges(); + const nodeMap = new Map(currentNodes.map((node) => [String(node.id), node])); + const groupNode = nodeMap.get(String(groupId)); + if (!groupNode || groupNode.data?.className !== 'Group') return; + + const memberIds = new Set(getGroupMembers(currentNodes, groupId)); + const groupSelected = !!groupNode.selected; + + const nextNodes = currentNodes + .filter((node) => String(node.id) !== String(groupId)) + .map((node) => { + if (!memberIds.has(String(node.id))) return node; + const absolute = getNodeAbsolutePosition(node, nodeMap); + return { + ...node, + parentId: undefined, + extent: undefined, + hidden: false, + selected: groupSelected, + position: absolute, + }; + }); + + const nextEdges = currentEdges + .map((edge) => { + if (edge.data?.groupInternalHiddenBy === groupId) { + const nextData = { ...(edge.data || {}) }; + delete nextData.groupInternalHiddenBy; + return { + ...edge, + hidden: false, + data: Object.keys(nextData).length > 0 ? nextData : undefined, + }; + } + if (edge.data?.groupProxyOwner === groupId) { + const nextData = { ...(edge.data || {}) }; + const original = nextData.groupProxyOriginal || {}; + delete nextData.groupProxyOwner; + delete nextData.groupProxyOriginal; + return { + ...edge, + source: original.source || edge.source, + sourceHandle: original.sourceHandle || edge.sourceHandle, + target: original.target || edge.target, + targetHandle: original.targetHandle || edge.targetHandle, + hidden: false, + data: Object.keys(nextData).length > 0 ? nextData : undefined, + }; + } + return edge; + }) + .filter((edge) => String(edge.source) !== String(groupId) && String(edge.target) !== String(groupId)); + + setNodes(nextNodes); + setEdges(nextEdges); + setTimeout(() => { + reactFlow.getNodes() + .filter((node) => node.data?.className === 'Group') + .forEach((node) => refreshGroupNode(node.id, nextNodes, nextEdges)); + }, 0); + }, [reactFlow, refreshGroupNode, setEdges, setNodes]); + + const createGroupFromSelection = useCallback(() => { + const currentNodes = reactFlow.getNodes(); + const selectedNodes = currentNodes.filter((node) => node.selected && node.data?.className !== 'Group'); + if (selectedNodes.length < 2) return; + + const selectedIds = selectedNodes.map((node) => String(node.id)); + const bounds = getGroupDisplayBounds(currentNodes, selectedIds); + if (!bounds) return; + + const groupId = String(nextIdRef.current++); + const groupPosition = { + x: bounds.minX - GROUP_PADDING_X, + y: bounds.minY - (GROUP_HEADER_HEIGHT + GROUP_PADDING_Y), + }; + const groupWidth = Math.max( + GROUP_MIN_WIDTH, + Math.round(bounds.maxX - bounds.minX + GROUP_PADDING_X * 2), + ); + const groupHeight = Math.max( + GROUP_MIN_HEIGHT, + Math.round(bounds.maxY - bounds.minY + GROUP_HEADER_HEIGHT + GROUP_PADDING_Y * 2), + ); + + const groupNode = { + id: groupId, + type: 'custom', + className: 'group-shell', + position: groupPosition, + width: groupWidth, + height: groupHeight, + dragHandle: '.drag-handle', + style: { width: groupWidth, height: groupHeight }, + data: { + label: 'group', + className: 'Group', + definition: null, + widgetValues: {}, + runtimeValues: {}, + collapsed: false, + expandedSize: { width: groupWidth, height: groupHeight }, + proxyInputs: [], + proxyOutputs: [], + childCount: selectedNodes.length, + previewImage: null, + tableRows: null, + meshData: null, + overlay: null, + scalarValue: null, + processingTimeMs: null, + warning: null, + }, + selected: true, + }; + + const nodeMap = new Map(currentNodes.map((node) => [String(node.id), node])); + const nextNodes = [ + ...currentNodes.map((node) => { + if (!selectedIds.includes(String(node.id))) { + return { ...node, selected: false }; + } + const absolute = getNodeAbsolutePosition(node, nodeMap); + return { + ...node, + selected: false, + parentId: groupId, + extent: 'parent', + hidden: false, + position: { + x: absolute.x - groupPosition.x, + y: absolute.y - groupPosition.y, + }, + }; + }), + groupNode, + ]; + + const orderedNodes = sortNodesForParentOrder(nextNodes); + setNodes(orderedNodes); + setTimeout(() => refreshGroupNode(groupId, orderedNodes, reactFlow.getEdges()), 0); + }, [reactFlow, refreshGroupNode, setNodes]); + const setNodeOutputs = useCallback((nodeId, output, outputName, extraDefinitionPatch = {}) => { setNodes((prev) => prev.map((node) => { if (node.id !== nodeId) return node; @@ -516,9 +1149,12 @@ function Flow() { (e) => e.target === nodeId && getInputName(e.targetHandle) === 'path' ); if (!edge) return null; - const sourceNode = reactFlow.getNode(edge.source); + const original = edge.data?.groupProxyOriginal || {}; + const sourceId = original.source || edge.source; + const sourceHandle = original.sourceHandle || edge.sourceHandle; + const sourceNode = reactFlow.getNode(sourceId); const outputPaths = sourceNode?.data?.definition?.output_paths; - const outputSlot = getOutputSlot(edge.sourceHandle); + const outputSlot = getOutputSlot(sourceHandle); if (Array.isArray(outputPaths) && typeof outputPaths[outputSlot] === 'string') { return outputPaths[outputSlot]; } @@ -653,38 +1289,66 @@ function Flow() { // ── Connection handling ───────────────────────────────────────────── const isValidConnection = useCallback((connection) => { - const srcType = getHandleType(connection.sourceHandle); - const tgtType = getHandleType(connection.targetHandle); + const srcType = getConnectionHandleType(connection.sourceHandle); + const tgtType = getConnectionHandleType(connection.targetHandle); return socketTypesCompatible(srcType, tgtType); }, []); const onConnect = useCallback((params) => { - const type = getHandleType(params.sourceHandle); + const sourceProxy = parseGroupProxyHandle(params.sourceHandle); + const targetProxy = parseGroupProxyHandle(params.targetHandle); + const type = getConnectionHandleType(params.sourceHandle); const color = TYPE_COLORS[type] || 'var(--fallback-type)'; + const edgePayload = { + ...params, + style: { stroke: color, strokeWidth: 2 }, + }; + const proxyOriginal = {}; + if (sourceProxy) { + proxyOriginal.source = sourceProxy.nodeId; + proxyOriginal.sourceHandle = sourceProxy.realHandle; + } + if (targetProxy) { + proxyOriginal.target = targetProxy.nodeId; + proxyOriginal.targetHandle = targetProxy.realHandle; + } + if (Object.keys(proxyOriginal).length > 0) { + edgePayload.data = { + ...(edgePayload.data || {}), + groupProxyOwner: sourceProxy?.direction === 'out' ? params.source : params.target, + groupProxyOriginal: proxyOriginal, + }; + } + setEdges((eds) => { // Enforce single connection per input handle const filtered = eds.filter( (e) => !(e.target === params.target && e.targetHandle === params.targetHandle) ); - return addEdge( - { ...params, style: { stroke: color, strokeWidth: 2 } }, - filtered - ); + return addEdge(edgePayload, filtered); }); - if (getInputName(params.targetHandle) === 'path') { + const effectiveTargetHandle = targetProxy?.realHandle || params.targetHandle; + const effectiveTargetNode = targetProxy?.nodeId || params.target; + if (getInputName(effectiveTargetHandle) === 'path') { setTimeout(() => { - refreshLoadNodeOutputs(params.target); + refreshLoadNodeOutputs(effectiveTargetNode); }, 0); } - const targetNode = reactFlow.getNode(params.target); + const targetNode = reactFlow.getNode(effectiveTargetNode); if (targetNode && (targetNode.data.className === 'Annotations' || targetNode.data.className === 'Markup')) { setTimeout(() => { - refreshAnnotationNodeOutputs(params.target); + refreshAnnotationNodeOutputs(effectiveTargetNode); }, 0); } + if (sourceProxy) { + setTimeout(() => refreshGroupNode(params.source), 0); + } + if (targetProxy) { + setTimeout(() => refreshGroupNode(params.target), 0); + } scheduleAutoRun(); - }, [reactFlow, refreshAnnotationNodeOutputs, refreshLoadNodeOutputs, setEdges]); // scheduleAutoRun is stable (no deps) + }, [reactFlow, refreshAnnotationNodeOutputs, refreshGroupNode, refreshLoadNodeOutputs, setEdges]); // scheduleAutoRun is stable (no deps) const handleEdgesChange = useCallback((changes) => { const currentEdges = reactFlow.getEdges(); @@ -721,7 +1385,68 @@ function Flow() { }); }, 0); } - }, [onEdgesChange, reactFlow, refreshAnnotationNodeOutputs, refreshLoadNodeOutputs]); + setTimeout(() => { + reactFlow.getNodes() + .filter((node) => node.data?.className === 'Group') + .forEach((node) => refreshGroupNode(node.id)); + }, 0); + }, [onEdgesChange, reactFlow, refreshAnnotationNodeOutputs, refreshGroupNode, refreshLoadNodeOutputs]); + + const handleNodesChange = useCallback((changes) => { + const currentNodes = reactFlow.getNodes(); + const selectedGroupIds = new Set( + changes + .filter((change) => change.type === 'select' && change.selected) + .map((change) => String(change.id)) + .filter((id) => currentNodes.some((node) => String(node.id) === id && node.data?.className === 'Group')), + ); + const removedIds = new Set( + changes + .filter((change) => change.type === 'remove') + .map((change) => String(change.id)), + ); + + onNodesChange(changes); + + if (selectedGroupIds.size > 0) { + const deselectedDescendantIds = new Set(); + selectedGroupIds.forEach((groupId) => { + collectGroupDescendantIds(currentNodes, groupId).forEach((id) => deselectedDescendantIds.add(id)); + }); + + if (deselectedDescendantIds.size > 0) { + setNodes((existing) => existing.map((node) => ( + deselectedDescendantIds.has(String(node.id)) + ? { ...node, selected: false } + : node + ))); + } + } + + if (removedIds.size === 0) return; + + const groupIds = currentNodes + .filter((node) => removedIds.has(String(node.id)) && node.data?.className === 'Group') + .map((node) => String(node.id)); + const removedWithDescendants = new Set(removedIds); + for (const groupId of groupIds) { + collectGroupDescendantIds(currentNodes, groupId).forEach((id) => removedWithDescendants.add(id)); + } + + if (groupIds.length > 0) { + setNodes((existing) => existing.filter((node) => !removedWithDescendants.has(String(node.id)))); + setEdges((existing) => existing.filter((edge) => ( + !removedWithDescendants.has(String(edge.source)) + && !removedWithDescendants.has(String(edge.target)) + ))); + } + + setTimeout(() => { + reactFlow.getNodes() + .filter((node) => node.data?.className === 'Group') + .forEach((node) => refreshGroupNode(node.id)); + }, 0); + }, [onNodesChange, reactFlow, refreshGroupNode, setEdges, setNodes]); // ── Drop-on-blank: open filtered context menu ────────────────────── @@ -733,7 +1458,7 @@ function Flow() { if (!fromHandle || !fromHandle.id) return; const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event; - const handleType = getHandleType(fromHandle.id); + const handleType = getConnectionHandleType(fromHandle.id); setContextMenu({ x: clientX, @@ -776,22 +1501,68 @@ function Flow() { // ── File browser ──────────────────────────────────────────────────── - const openFileBrowser = useCallback((callback, { selectionMode = 'file' } = {}) => { + const uploadBrowserSelection = useCallback(async (selection, selectionMode) => { + if (!selection) return null; + + if (selectionMode === 'folder') { + const rootName = String(selection.rootName || '').trim(); + if (!rootName) { + throw new Error('Selected folder is empty or could not be read.'); + } + + setStatus({ + text: `Importing folder "${rootName}" into this session…`, + level: 'info', + }); + + const folder = await api.createUploadFolder(rootName); + for (const entry of selection.entries || []) { + await api.uploadFile(entry.file, { relativePath: entry.relativePath }); + } + return folder.path; + } + + const [entry] = selection.entries || []; + if (!entry) return null; + + setStatus({ + text: `Uploading ${entry.file.name}…`, + level: 'info', + }); + + const uploaded = await api.uploadFile(entry.file, { relativePath: entry.relativePath }); + return uploaded.path; + }, []); + + const openFileBrowser = useCallback(async (callback, { selectionMode = 'file' } = {}) => { if (selectionMode === 'folder' && window.pywebview?.api?.open_folder_dialog) { window.pywebview.api.open_folder_dialog().then((path) => { if (path) callback(path); }); return; } - // Use native file picker when running inside pywebview (desktop app) if (selectionMode === 'file' && window.pywebview?.api?.open_file_dialog) { window.pywebview.api.open_file_dialog().then((path) => { if (path) callback(path); }); return; } - setFileBrowserState({ callback, selectionMode }); - }, []); + + try { + const selection = selectionMode === 'folder' + ? await pickNativeDirectorySelection() + : await pickNativeFileSelection(); + if (!selection) return; + + const uploadedPath = await uploadBrowserSelection(selection, selectionMode); + if (uploadedPath) callback(uploadedPath); + } catch (error) { + setStatus({ + text: `Browse failed: ${error.message || String(error)}`, + level: 'error', + }); + } + }, [uploadBrowserSelection]); // ── Node context value (stable) ───────────────────────────────────── @@ -1028,10 +1799,10 @@ function Flow() { nextIdRef.current = pasted.nextNodeId; - setNodes((existing) => [ + setNodes((existing) => sortNodesForParentOrder([ ...existing.map((node) => ({ ...node, selected: false })), ...pasted.nodes, - ]); + ])); setEdges((existing) => [ ...existing.map((edge) => ({ ...edge, selected: false })), ...pasted.edges, @@ -1053,12 +1824,55 @@ function Flow() { setNodes, ]); + const resizeGroup = useCallback((groupId, size) => { + const nextWidth = Math.round(Number(size?.width) || 0); + const nextHeight = Math.round(Number(size?.height) || 0); + if (!nextWidth || !nextHeight) return; + + setNodes((existing) => existing.map((node) => { + if (String(node.id) !== String(groupId) || node.data?.className !== 'Group') return node; + + const sameSize = Math.abs((Number(node.style?.width) || 0) - nextWidth) < 0.5 + && Math.abs((Number(node.style?.height) || 0) - nextHeight) < 0.5; + if (sameSize) return node; + + return { + ...applyNodeSize(node, nextWidth, nextHeight), + data: { + ...node.data, + expandedSize: { width: nextWidth, height: nextHeight }, + }, + }; + })); + + setTimeout(() => reactFlow.updateNodeInternals(String(groupId)), 0); + }, [reactFlow, setNodes]); + + const renameGroup = useCallback((groupId, label) => { + const nextLabel = String(label || '').trim() || 'group'; + setNodes((existing) => existing.map((node) => { + if (String(node.id) !== String(groupId) || node.data?.className !== 'Group') return node; + if (String(node.data?.label || 'group') === nextLabel) return node; + return { + ...node, + data: { + ...node.data, + label: nextLabel, + }, + }; + })); + }, [setNodes]); + const contextValue = useMemo(() => ({ onWidgetChange, onRuntimeValuesChange, openFileBrowser, onManualTrigger, - }), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger]); + onToggleGroupCollapse: toggleGroupCollapse, + onResizeGroup: resizeGroup, + onRenameGroup: renameGroup, + onUngroup: ungroupGroup, + }), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger, renameGroup, resizeGroup, toggleGroupCollapse, ungroupGroup]); const clearGraph = useCallback(() => { setNodes([]); @@ -1069,7 +1883,7 @@ function Flow() { const applyWorkflowData = useCallback((data) => { const hydrated = hydrateWorkflowState(data, nodeDefsRef.current); - setNodes(hydrated.nodes); + setNodes(sortNodesForParentOrder(hydrated.nodes)); setEdges(hydrated.edges); nextIdRef.current = hydrated.nextNodeId; initializeDynamicNodes(hydrated.nodes); @@ -1298,8 +2112,51 @@ function Flow() { }, []); const onNodeDragStart = useCallback((event, node) => { + activeDragNodeIdRef.current = String(node.id); + dragStateRef.current = null; + if (!(event.ctrlKey || event.metaKey)) { duplicateDragRef.current = null; + const currentNodes = reactFlow.getNodes(); + const draggedNodes = node.data?.className === 'Group' + ? [] + : ( + node.selected + ? currentNodes.filter((candidate) => candidate.selected && candidate.data?.className !== 'Group') + : currentNodes.filter((candidate) => candidate.id === node.id) + ); + const pointerFlowPos = getEventFlowPosition(event, reactFlow); + if (draggedNodes.length > 0 && pointerFlowPos) { + const nodeMap = new Map(currentNodes.map((candidate) => [String(candidate.id), candidate])); + const absolutePositions = Object.fromEntries( + draggedNodes.map((candidate) => [ + String(candidate.id), + getNodeAbsolutePosition(candidate, nodeMap), + ]), + ); + const anchorAbsolute = absolutePositions[String(node.id)] || getNodeAbsolutePosition(node, nodeMap); + dragStateRef.current = { + anchorId: String(node.id), + anchorStartAbsolute: anchorAbsolute, + absolutePositions, + releasedNodeIds: new Set(), + touchedGroupIds: new Set(), + pointerOffset: { + x: pointerFlowPos.x - anchorAbsolute.x, + y: pointerFlowPos.y - anchorAbsolute.y, + }, + }; + } + if (node.data?.className === 'Group') { + const descendantIds = collectGroupDescendantIds(currentNodes, node.id); + if (descendantIds.size > 0) { + setNodes((existing) => existing.map((candidate) => ( + descendantIds.has(String(candidate.id)) + ? { ...candidate, selected: false } + : candidate + ))); + } + } return; } @@ -1343,15 +2200,16 @@ function Flow() { ); duplicateDragRef.current = { + anchorId: String(node.id), draggedIds, originPositions, duplicateSourceById, }; - setNodes((existing) => [ + setNodes((existing) => sortNodesForParentOrder([ ...existing.map((candidate) => ({ ...candidate, selected: false })), ...duplicated.nodes, - ]); + ])); setEdges((existing) => [ ...existing.map((edge) => ({ ...edge, selected: false })), ...duplicated.edges, @@ -1360,105 +2218,349 @@ function Flow() { initializeDynamicNodes(duplicated.nodes); }, [initializeDynamicNodes, reactFlow, setEdges, setNodes]); - const onNodeDrag = useCallback((_event, node) => { + const onNodeDrag = useCallback((event, node) => { + if (String(node.id) !== activeDragNodeIdRef.current) return; + const duplicateState = duplicateDragRef.current; - if (!duplicateState) return; + if (duplicateState) { + const anchorId = duplicateState.anchorId || duplicateState.draggedIds[0]; + const anchorOrigin = duplicateState.originPositions[anchorId]; + if (!anchorOrigin) return; - const anchorId = duplicateState.draggedIds.includes(String(node.id)) - ? String(node.id) - : duplicateState.draggedIds[0]; - const anchorOrigin = duplicateState.originPositions[anchorId]; - if (!anchorOrigin) return; + const offset = { + x: (Number(node.position?.x) || 0) - anchorOrigin.x, + y: (Number(node.position?.y) || 0) - anchorOrigin.y, + }; + const draggedIdSet = new Set(duplicateState.draggedIds); - const offset = { - x: (Number(node.position?.x) || 0) - anchorOrigin.x, - y: (Number(node.position?.y) || 0) - anchorOrigin.y, - }; - const draggedIdSet = new Set(duplicateState.draggedIds); + setNodes((existing) => existing.map((candidate) => { + const candidateId = String(candidate.id); + const originalPosition = duplicateState.originPositions[candidateId]; + if (draggedIdSet.has(candidateId) && originalPosition) { + return { + ...candidate, + selected: false, + position: originalPosition, + }; + } - setNodes((existing) => existing.map((candidate) => { + const sourceId = duplicateState.duplicateSourceById[candidateId]; + if (sourceId) { + const sourceOrigin = duplicateState.originPositions[sourceId]; + if (!sourceOrigin) return candidate; + return { + ...candidate, + selected: true, + position: { + x: sourceOrigin.x + offset.x, + y: sourceOrigin.y + offset.y, + }, + }; + } + + return candidate; + })); + return; + } + + const dragState = dragStateRef.current; + if (!dragState || node.data?.className === 'Group') return; + + const currentNodes = reactFlow.getNodes(); + const draggedNodes = node.selected + ? currentNodes.filter((candidate) => candidate.selected && candidate.data?.className !== 'Group') + : currentNodes.filter((candidate) => candidate.id === node.id); + if (draggedNodes.length === 0) return; + + const dragIntent = getDragIntent(event, reactFlow, dragState); + if (!dragIntent?.pointerFlowPos) return; + + const draggedIdSet = new Set(draggedNodes.map((candidate) => String(candidate.id))); + const nodeMap = new Map(currentNodes.map((candidate) => [String(candidate.id), candidate])); + const releasedNodeIds = dragState.releasedNodeIds instanceof Set + ? new Set(dragState.releasedNodeIds) + : new Set(); + const touchedGroupIds = dragState.touchedGroupIds instanceof Set + ? new Set(dragState.touchedGroupIds) + : new Set(); + + let nextNodes = currentNodes; + let changed = false; + let structureChanged = false; + + nextNodes = nextNodes.map((candidate) => { const candidateId = String(candidate.id); - const originalPosition = duplicateState.originPositions[candidateId]; - if (draggedIdSet.has(candidateId) && originalPosition) { - return { - ...candidate, - selected: false, - position: originalPosition, - }; - } - - const sourceId = duplicateState.duplicateSourceById[candidateId]; - if (sourceId) { - const sourceOrigin = duplicateState.originPositions[sourceId]; - if (!sourceOrigin) return candidate; - return { - ...candidate, - selected: true, - position: { - x: sourceOrigin.x + offset.x, - y: sourceOrigin.y + offset.y, - }, - }; + if (!draggedIdSet.has(candidateId)) return candidate; + + const absolute = dragIntent.absolutePositions.get(candidateId) + || getNodeAbsolutePosition(candidate, nodeMap); + if (!absolute) return candidate; + + if (candidate.parentId) { + const parentId = String(candidate.parentId); + const parentNode = nodeMap.get(parentId); + if (parentNode?.data?.className === 'Group') { + const parentRect = getGroupWorkspaceBounds(parentNode, nodeMap); + const parentAbsolute = getNodeAbsolutePosition(parentNode, nodeMap); + const nextPosition = { + x: absolute.x - parentAbsolute.x, + y: absolute.y - parentAbsolute.y, + }; + const candidateRect = getAbsoluteRectForNodePosition(candidate, absolute); + const samePosition = Math.abs((Number(candidate.position?.x) || 0) - nextPosition.x) < 0.5 + && Math.abs((Number(candidate.position?.y) || 0) - nextPosition.y) < 0.5; + + if (!releasedNodeIds.has(candidateId) && !rectContainsRect(parentRect, candidateRect)) { + releasedNodeIds.add(candidateId); + changed = true; + return { + ...candidate, + extent: undefined, + hidden: false, + position: nextPosition, + }; + } + + if (releasedNodeIds.has(candidateId)) { + if (!candidate.parentId && !candidate.extent && candidate.hidden !== true && samePosition) { + return candidate; + } + + changed = true; + return { + ...candidate, + extent: undefined, + hidden: false, + position: nextPosition, + }; + } + } } + if (!releasedNodeIds.has(candidateId)) return candidate; return candidate; - })); - }, [setNodes]); + }); - const onNodeDragStop = useCallback((_event, node) => { + if (!changed) return; + + dragStateRef.current = { + ...dragState, + releasedNodeIds, + touchedGroupIds, + }; + + setNodes(structureChanged ? sortNodesForParentOrder(nextNodes) : nextNodes); + + if (structureChanged) { + setTimeout(() => { + touchedGroupIds.forEach((groupId) => { + if (groupId) refreshGroupNode(groupId, nextNodes, reactFlow.getEdges()); + }); + }, 0); + } + }, [reactFlow, refreshGroupNode, setNodes]); + + const onNodeDragStop = useCallback((event, node) => { + if (String(node.id) !== activeDragNodeIdRef.current) return; + activeDragNodeIdRef.current = null; + + const dragState = dragStateRef.current; + dragStateRef.current = null; const duplicateState = duplicateDragRef.current; duplicateDragRef.current = null; - if (!duplicateState) return; + if (duplicateState) { + const anchorId = duplicateState.anchorId || duplicateState.draggedIds[0]; + const anchorOrigin = duplicateState.originPositions[anchorId]; + if (!anchorOrigin) return; - const anchorId = duplicateState.draggedIds.includes(String(node.id)) - ? String(node.id) - : duplicateState.draggedIds[0]; - const anchorOrigin = duplicateState.originPositions[anchorId]; - if (!anchorOrigin) return; + const offset = { + x: (Number(node.position?.x) || 0) - anchorOrigin.x, + y: (Number(node.position?.y) || 0) - anchorOrigin.y, + }; + const draggedIdSet = new Set(duplicateState.draggedIds); - const offset = { - x: (Number(node.position?.x) || 0) - anchorOrigin.x, - y: (Number(node.position?.y) || 0) - anchorOrigin.y, - }; - const draggedIdSet = new Set(duplicateState.draggedIds); + setNodes((existing) => existing.map((candidate) => { + const candidateId = String(candidate.id); + const originalPosition = duplicateState.originPositions[candidateId]; + if (draggedIdSet.has(candidateId) && originalPosition) { + return { + ...candidate, + selected: false, + position: originalPosition, + }; + } + + const sourceId = duplicateState.duplicateSourceById[candidateId]; + if (sourceId) { + const sourceOrigin = duplicateState.originPositions[sourceId]; + if (!sourceOrigin) return candidate; + return { + ...candidate, + selected: true, + position: { + x: sourceOrigin.x + offset.x, + y: sourceOrigin.y + offset.y, + }, + }; + } - setNodes((existing) => existing.map((candidate) => { - const candidateId = String(candidate.id); - const originalPosition = duplicateState.originPositions[candidateId]; - if (draggedIdSet.has(candidateId) && originalPosition) { return { ...candidate, selected: false, - position: originalPosition, }; + })); + + setStatus({ + text: `Duplicated ${Object.keys(duplicateState.duplicateSourceById).length} node${Object.keys(duplicateState.duplicateSourceById).length === 1 ? '' : 's'}.`, + level: 'info', + }); + scheduleAutoRun(); + return; + } + + const currentNodes = reactFlow.getNodes(); + const dragIntent = getDragIntent(event, reactFlow, dragState); + const touchedGroupIds = dragState?.touchedGroupIds instanceof Set + ? new Set(dragState.touchedGroupIds) + : new Set(); + let nextNodes = currentNodes; + let changed = false; + + const draggedNodes = node.data?.className === 'Group' + ? [] + : ( + node.selected + ? nextNodes.filter((candidate) => candidate.selected && candidate.data?.className !== 'Group') + : nextNodes.filter((candidate) => candidate.id === node.id) + ); + + if (draggedNodes.length > 0) { + const draggedIdSet = new Set(draggedNodes.map((candidate) => String(candidate.id))); + const nodeMap = new Map(nextNodes.map((candidate) => [String(candidate.id), candidate])); + const anchorNode = nodeMap.get(String(dragState?.anchorId || node.id)); + const intendedAnchorAbsolute = dragIntent?.absolutePositions.get(String(anchorNode?.id || node.id)) + || (anchorNode ? getNodeAbsolutePosition(anchorNode, nodeMap) : null); + const intendedAnchorCenter = anchorNode && intendedAnchorAbsolute + ? { + x: intendedAnchorAbsolute.x + (Number(getNodeDimension(anchorNode, 'width')) || 200) / 2, + y: intendedAnchorAbsolute.y + (Number(getNodeDimension(anchorNode, 'height')) || 120) / 2, + } + : null; + const targetGroup = findExpandedGroupDropTarget( + nextNodes, + Array.from(draggedIdSet), + node.id, + intendedAnchorCenter, + ); + if (targetGroup) { + const targetRect = getGroupWorkspaceBounds(targetGroup, nodeMap); + const targetAbs = getNodeAbsolutePosition(targetGroup, nodeMap); + let joinedCount = 0; + + nextNodes = nextNodes.map((candidate) => { + if (!draggedIdSet.has(String(candidate.id))) return candidate; + + const intendedAbsolute = dragIntent?.absolutePositions.get(String(candidate.id)); + const width = Number(getNodeDimension(candidate, 'width')) || 200; + const height = Number(getNodeDimension(candidate, 'height')) || 120; + const center = intendedAbsolute + ? { x: intendedAbsolute.x + width / 2, y: intendedAbsolute.y + height / 2 } + : getNodeCenter(candidate, nodeMap); + if (!rectContainsPoint(targetRect, center)) return candidate; + + const absolute = intendedAbsolute || getNodeAbsolutePosition(candidate, nodeMap); + const nextPosition = { + x: absolute.x - targetAbs.x, + y: absolute.y - targetAbs.y, + }; + const alreadyInTarget = String(candidate.parentId || '') === String(targetGroup.id); + const samePosition = Math.abs((Number(candidate.position?.x) || 0) - nextPosition.x) < 0.5 + && Math.abs((Number(candidate.position?.y) || 0) - nextPosition.y) < 0.5; + if (alreadyInTarget && candidate.extent === 'parent' && samePosition) return candidate; + + if (candidate.parentId) { + touchedGroupIds.add(String(candidate.parentId)); + } + touchedGroupIds.add(String(targetGroup.id)); + joinedCount += 1; + changed = true; + return { + ...candidate, + parentId: String(targetGroup.id), + extent: 'parent', + hidden: false, + position: nextPosition, + }; + }); + + if (joinedCount > 0) { + setStatus({ + text: `Added ${joinedCount} node${joinedCount === 1 ? '' : 's'} to group.`, + level: 'info', + }); + } + } else { + let removedCount = 0; + + nextNodes = nextNodes.map((candidate) => { + if (!draggedIdSet.has(String(candidate.id)) || !candidate.parentId) return candidate; + + const parentId = String(candidate.parentId); + const parentNode = nodeMap.get(parentId); + if (!parentNode || parentNode.data?.className !== 'Group') return candidate; + const absolute = dragIntent?.absolutePositions.get(String(candidate.id)) + || getNodeAbsolutePosition(candidate, nodeMap); + const parentWorkspaceRect = getGroupWorkspaceBounds(parentNode, nodeMap); + const candidateRect = getAbsoluteRectForNodePosition(candidate, absolute); + if (rectContainsRect(parentWorkspaceRect, candidateRect)) { + if (candidate.extent === 'parent') return candidate; + changed = true; + return { + ...candidate, + extent: 'parent', + hidden: false, + }; + } + + touchedGroupIds.add(parentId); + removedCount += 1; + changed = true; + return { + ...candidate, + parentId: undefined, + extent: undefined, + hidden: false, + position: absolute, + }; + }); + + if (removedCount > 0) { + setStatus({ + text: `Removed ${removedCount} node${removedCount === 1 ? '' : 's'} from group.`, + level: 'info', + }); + } } + } - const sourceId = duplicateState.duplicateSourceById[candidateId]; - if (sourceId) { - const sourceOrigin = duplicateState.originPositions[sourceId]; - if (!sourceOrigin) return candidate; - return { - ...candidate, - selected: true, - position: { - x: sourceOrigin.x + offset.x, - y: sourceOrigin.y + offset.y, - }, - }; + if (!changed) { + const releasedCount = dragState?.releasedNodeIds instanceof Set ? dragState.releasedNodeIds.size : 0; + if (releasedCount > 0) { + setStatus({ + text: `Removed ${releasedCount} node${releasedCount === 1 ? '' : 's'} from group.`, + level: 'info', + }); } + return; + } - return { - ...candidate, - selected: false, - }; - })); - - setStatus({ - text: `Duplicated ${Object.keys(duplicateState.duplicateSourceById).length} node${Object.keys(duplicateState.duplicateSourceById).length === 1 ? '' : 's'}.`, - level: 'info', - }); - scheduleAutoRun(); - }, [scheduleAutoRun, setNodes]); + setNodes(sortNodesForParentOrder(nextNodes)); + setTimeout(() => { + touchedGroupIds.forEach((groupId) => { + if (groupId) refreshGroupNode(groupId, nextNodes, reactFlow.getEdges()); + }); + }, 0); + }, [reactFlow, refreshGroupNode, scheduleAutoRun, setNodes]); // ── Keyboard shortcut ─────────────────────────────────────────────── @@ -1516,9 +2618,106 @@ function Flow() { const onPaneContextMenu = useCallback((event) => { event.preventDefault(); + if (performance.now() < suppressPaneContextMenuUntilRef.current) { + suppressPaneContextMenuUntilRef.current = 0; + return; + } setContextMenu({ x: event.clientX, y: event.clientY }); }, []); + const onFlowContainerPointerDown = useCallback((event) => { + if (event.button !== 2) return; + if (!canStartCanvasRightDragZoom(event.target)) return; + + event.preventDefault(); + event.stopPropagation(); + setContextMenu(null); + + const viewport = reactFlow.getViewport(); + canvasRightZoomRef.current = { + pointerId: event.pointerId, + startY: event.clientY, + startZoom: Number(viewport.zoom) || 1, + moved: false, + }; + setIsCanvasRightZooming(true); + + try { + event.currentTarget.setPointerCapture?.(event.pointerId); + } catch { + // Ignore capture failures; global listeners still complete the interaction. + } + }, [reactFlow]); + + const onFlowContainerContextMenuCapture = useCallback((event) => { + if (canvasRightZoomRef.current?.moved || performance.now() < suppressPaneContextMenuUntilRef.current) { + event.preventDefault(); + event.stopPropagation(); + } + }, []); + + useEffect(() => { + const handlePointerMove = (event) => { + const zoomState = canvasRightZoomRef.current; + if (!zoomState || event.pointerId !== zoomState.pointerId) return; + + const deltaY = event.clientY - zoomState.startY; + if (Math.abs(deltaY) < CANVAS_RIGHT_DRAG_ZOOM_THRESHOLD) return; + + event.preventDefault(); + zoomState.moved = true; + + const container = flowContainerRef.current; + if (!container) return; + const bounds = container.getBoundingClientRect(); + const localX = event.clientX - bounds.left; + const localY = event.clientY - bounds.top; + const currentViewport = reactFlow.getViewport(); + const flowX = (localX - currentViewport.x) / currentViewport.zoom; + const flowY = (localY - currentViewport.y) / currentViewport.zoom; + const nextZoom = clampNumber( + zoomState.startZoom * Math.exp(-deltaY * CANVAS_RIGHT_DRAG_ZOOM_SENSITIVITY), + CANVAS_MIN_ZOOM, + CANVAS_MAX_ZOOM, + ); + + reactFlow.setViewport({ + x: localX - (flowX * nextZoom), + y: localY - (flowY * nextZoom), + zoom: nextZoom, + }, { duration: 0 }); + }; + + const finishPointerInteraction = (event) => { + const zoomState = canvasRightZoomRef.current; + if (!zoomState || event.pointerId !== zoomState.pointerId) return; + + if (zoomState.moved) { + suppressPaneContextMenuUntilRef.current = performance.now() + 250; + } + canvasRightZoomRef.current = null; + setIsCanvasRightZooming(false); + + const container = flowContainerRef.current; + if (container?.hasPointerCapture?.(event.pointerId)) { + try { + container.releasePointerCapture(event.pointerId); + } catch { + // Ignore capture release errors. + } + } + }; + + window.addEventListener('pointermove', handlePointerMove, true); + window.addEventListener('pointerup', finishPointerInteraction, true); + window.addEventListener('pointercancel', finishPointerInteraction, true); + return () => { + window.removeEventListener('pointermove', handlePointerMove, true); + window.removeEventListener('pointerup', finishPointerInteraction, true); + window.removeEventListener('pointercancel', finishPointerInteraction, true); + }; + }, [reactFlow]); + useEffect(() => { if (!contextMenu) return undefined; @@ -1531,6 +2730,8 @@ function Flow() { return () => window.removeEventListener('pointerdown', handlePointerDown, true); }, [contextMenu]); + const selectedNodeCount = nodes.filter((node) => node.selected).length; + // ── Render ────────────────────────────────────────────────────────── return ( @@ -1565,11 +2766,18 @@ function Flow() {
{/* React Flow canvas */} -
+
setContextMenu(null)} filterType={contextMenu.filterType} filterDirection={contextMenu.filterDirection} + selectedNodeCount={selectedNodeCount} /> )}
- {/* File browser modal */} - {fileBrowserState && ( - { fileBrowserState.callback(path); setFileBrowserState(null); }} - onClose={() => setFileBrowserState(null)} - /> - )}
); diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index d4fb0a0..df8ea01 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -1,5 +1,5 @@ import React, { useContext, useRef, useCallback, useState, useEffect, memo, lazy, Suspense } from 'react'; -import { Handle, Position, useStore } from '@xyflow/react'; +import { Handle, NodeResizeControl, Position, useStore } from '@xyflow/react'; import LinePlotOverlay from './LinePlotOverlay'; const SurfaceView = lazy(() => import('./SurfaceView')); @@ -11,6 +11,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay')); import { DATA_TYPES, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS, } from './constants'; +import { getGroupMinimumSize } from './groupSizing.js'; // ── Context (provided by App) ───────────────────────────────────────── @@ -24,6 +25,198 @@ function formatUiLabel(text) { .toLowerCase(); } +function parseProxyHandle(handleId) { + const text = String(handleId || ''); + if (!text.startsWith('group-proxy::')) return null; + const parts = text.split('::'); + if (parts.length < 5) return null; + return { + direction: parts[1], + nodeId: parts[2], + type: parts[3], + realHandle: decodeURIComponent(parts.slice(4).join('::')), + }; +} + +function GroupNode({ id, data }) { + const ctx = useContext(NodeContext); + const proxyInputs = Array.isArray(data.proxyInputs) ? data.proxyInputs : []; + const proxyOutputs = Array.isArray(data.proxyOutputs) ? data.proxyOutputs : []; + const childCount = Number(data.childCount) || 0; + const collapsed = !!data.collapsed; + const maxRows = Math.max(proxyInputs.length, proxyOutputs.length, collapsed ? 1 : 0); + const [isEditingLabel, setIsEditingLabel] = useState(false); + const [draftLabel, setDraftLabel] = useState(String(data.label || 'group')); + const labelInputRef = useRef(null); + const selected = useStore( + useCallback( + (s) => { + const node = s.nodeLookup?.get(id) || s.nodes?.find((candidate) => candidate.id === id); + return !!node?.selected; + }, + [id], + ), + ); + const groupMinSize = useStore( + useCallback( + (s) => getGroupMinimumSize( + (s.nodes || []).filter((candidate) => String(candidate.parentId || '') === String(id)), + ), + [id], + ), + ); + const displayLabel = String(data.label || 'group'); + const labelFieldSize = Math.max(2, Math.min(40, String(draftLabel || displayLabel || 'group').length)); + + useEffect(() => { + if (!isEditingLabel) { + setDraftLabel(displayLabel); + } + }, [displayLabel, isEditingLabel]); + + useEffect(() => { + if (!isEditingLabel) return; + labelInputRef.current?.focus(); + labelInputRef.current?.select(); + }, [isEditingLabel]); + + const commitLabel = useCallback(() => { + const nextLabel = String(draftLabel || '').trim() || 'group'; + setIsEditingLabel(false); + setDraftLabel(nextLabel); + if (nextLabel !== displayLabel) { + ctx.onRenameGroup?.(id, nextLabel); + } + }, [ctx, displayLabel, draftLabel, id]); + + const cancelLabelEdit = useCallback(() => { + setDraftLabel(displayLabel); + setIsEditingLabel(false); + }, [displayLabel]); + + return ( + <> + {!collapsed && selected && ( + ctx.onResizeGroup?.(id, params)} + /> + )} +
+
+ +
+ {isEditingLabel ? ( + setDraftLabel(event.target.value)} + onBlur={commitLabel} + onClick={(event) => event.stopPropagation()} + onPointerDown={(event) => event.stopPropagation()} + onKeyDown={(event) => { + if (event.key === 'Enter') { + event.preventDefault(); + commitLabel(); + } else if (event.key === 'Escape') { + event.preventDefault(); + cancelLabelEdit(); + } + }} + /> + ) : ( + + )} +
+
+ +
+
+ +
+ {collapsed ? ( + <> + {Array.from({ length: maxRows }, (_, index) => { + const input = proxyInputs[index]; + const output = proxyOutputs[index]; + return ( +
+
+ {input && ( + <> + + {formatUiLabel(input.label || input.name)} + + )} +
+
+ {output && ( + <> + {formatUiLabel(output.label || output.name)} + + + )} +
+
+ ); + })} +
{childCount} nodes
+ + ) : ( +
+
workflow group
+
{childCount} nodes
+
+ )} +
+
+ + ); +} + class PreviewBoundary extends React.Component { constructor(props) { super(props); @@ -390,6 +583,8 @@ function getSourceTypeForInput(store, nodeId, inputName) { const targetHandle = `input::${inputName}::`; const edge = store.edges?.find((e) => e.target === nodeId && e.targetHandle?.startsWith(targetHandle)); if (!edge?.sourceHandle) return null; + const proxy = parseProxyHandle(edge.sourceHandle); + if (proxy) return proxy.type || null; const parts = edge.sourceHandle.split('::'); return parts[2] || null; } @@ -405,8 +600,11 @@ function getConnectedOutputInfo(store, nodeId, inputName) { const targetHandle = `input::${inputName}::`; const edge = store.edges?.find((e) => e.target === nodeId && e.targetHandle?.startsWith(targetHandle)); if (!edge?.sourceHandle) return null; - const sourceNode = store.nodeLookup?.get(edge.source) || store.nodes?.find((n) => n.id === edge.source) || null; - const slot = Number.parseInt(edge.sourceHandle.split('::')[1], 10); + const proxy = parseProxyHandle(edge.sourceHandle); + const sourceNodeId = proxy?.nodeId || edge.source; + const sourceHandle = proxy?.realHandle || edge.sourceHandle; + const sourceNode = store.nodeLookup?.get(sourceNodeId) || store.nodes?.find((n) => n.id === sourceNodeId) || null; + const slot = Number.parseInt(sourceHandle.split('::')[1], 10); if (!sourceNode || !Number.isInteger(slot)) return null; return { path: sourceNode.data?.definition?.output_paths?.[slot] || null, @@ -751,6 +949,9 @@ function NodeTable({ rows }) { function CustomNode({ id, data }) { const ctx = useContext(NodeContext); + if (data.className === 'Group') { + return ; + } const def = data.definition; const scalarDisplay = formatScalarDisplay(data.scalarValue); const processingTimeText = formatProcessingTime(data.processingTimeMs); diff --git a/frontend/src/FileBrowser.jsx b/frontend/src/FileBrowser.jsx deleted file mode 100644 index 4220862..0000000 --- a/frontend/src/FileBrowser.jsx +++ /dev/null @@ -1,103 +0,0 @@ -import React, { useState, useEffect, useCallback } from 'react'; -import * as api from './api'; - -/** - * Server-side file browser modal. - * - * Props: - * onSelect(absolutePath) — called when user picks a file or folder - * onClose() — called when user dismisses the dialog - */ -export default function FileBrowser({ onSelect, onClose, selectionMode = 'file' }) { - const [path, setPath] = useState(''); - const [parent, setParent] = useState(null); - const [dirs, setDirs] = useState([]); - const [files, setFiles] = useState([]); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); - - const navigate = useCallback(async (dir) => { - setLoading(true); - setError(null); - try { - const data = await api.browse(dir); - setPath(data.path); - setParent(data.parent); - setDirs(data.dirs); - setFiles(data.files); - } catch (err) { - setError(err.message); - } finally { - setLoading(false); - } - }, []); - - // Start at home directory on mount - useEffect(() => { - navigate(null); - }, [navigate]); - - return ( -
{ if (e.target === e.currentTarget) onClose(); }}> -
- {/* Header */} -
- {path} - {selectionMode === 'folder' && ( - - )} - -
- - {/* File list */} -
- {loading &&
Loading…
} - {error &&
Error: {error}
} - - {!loading && !error && ( - <> - {/* Parent directory */} - {parent && ( -
navigate(parent)}> - ⬆ .. -
- )} - - {/* Directories */} - {dirs.map((d) => ( -
navigate(path + '/' + d)} - > - 📁 {d} -
- ))} - - {/* Files */} - {files.map((f) => ( -
{ - if (selectionMode === 'folder') return; - onSelect(path + '/' + f); - onClose(); - }} - > - {f} -
- ))} - - {dirs.length === 0 && files.length === 0 && ( -
Empty directory
- )} - - )} -
-
-
- ); -} diff --git a/frontend/src/LinePlotOverlay.jsx b/frontend/src/LinePlotOverlay.jsx index bf0f791..4c239b7 100644 --- a/frontend/src/LinePlotOverlay.jsx +++ b/frontend/src/LinePlotOverlay.jsx @@ -249,7 +249,6 @@ export default function LinePlotOverlay({ <> - { @@ -28,19 +97,27 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal const state = threeRef.current; if (!state || !nodeId || !onRuntimeValuesChange) return; const { renderer, controls } = state; - const azimuth = Number(controls.getAzimuthalAngle().toFixed(4)); - const polar = Number(controls.getPolarAngle().toFixed(4)); - const distance = Number(controls.getDistance().toFixed(4)); + const cameraState = { + azimuth: Number(controls.getAzimuthalAngle().toFixed(4)), + polar: Number(controls.getPolarAngle().toFixed(4)), + distance: Number(controls.getDistance().toFixed(4)), + targetX: Number(controls.target.x.toFixed(4)), + targetY: Number(controls.target.y.toFixed(4)), + targetZ: Number(controls.target.z.toFixed(4)), + }; const snapshot = renderer.domElement.toDataURL('image/png'); - const previous = lastAnglesRef.current; + const previous = lastCameraStateRef.current; const patch = {}; - if (previous.azimuth !== azimuth) patch.camera_azimuth = azimuth; - if (previous.polar !== polar) patch.camera_polar = polar; - if (previous.distance !== distance) patch.camera_distance = distance; + if (previous.azimuth !== cameraState.azimuth) patch.camera_azimuth = cameraState.azimuth; + if (previous.polar !== cameraState.polar) patch.camera_polar = cameraState.polar; + if (previous.distance !== cameraState.distance) patch.camera_distance = cameraState.distance; + if (previous.targetX !== cameraState.targetX) patch.camera_target_x = cameraState.targetX; + if (previous.targetY !== cameraState.targetY) patch.camera_target_y = cameraState.targetY; + if (previous.targetZ !== cameraState.targetZ) patch.camera_target_z = cameraState.targetZ; if (snapshot !== lastSnapshotRef.current) patch.viewport_snapshot = snapshot; if (Object.keys(patch).length > 0) { onRuntimeValuesChange(nodeId, patch, { scheduleRun }); - lastAnglesRef.current = { azimuth, polar, distance }; + lastCameraStateRef.current = cameraState; lastSnapshotRef.current = snapshot; } }, [nodeId, onRuntimeValuesChange]); @@ -55,17 +132,26 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal }, delay); }, [syncViewportState]); - const applyCameraState = useCallback((azimuth, polar, distance) => { + const applyCameraState = useCallback((cameraState = {}) => { const state = threeRef.current; if (!state) return; const { camera, controls } = state; - const target = controls.target.clone(); + const target = new THREE.Vector3( + getFiniteNumber(cameraState.targetX, controls.target.x, DEFAULT_CAMERA_STATE.targetX), + getFiniteNumber(cameraState.targetY, controls.target.y, DEFAULT_CAMERA_STATE.targetY), + getFiniteNumber(cameraState.targetZ, controls.target.z, DEFAULT_CAMERA_STATE.targetZ), + ); const spherical = new THREE.Spherical( - Math.max(0.3, Number.isFinite(distance) ? distance : 1.8), - THREE.MathUtils.clamp(Number.isFinite(polar) ? polar : 1.1, 0.01, Math.PI - 0.01), - Number.isFinite(azimuth) ? azimuth : 0.0, + Math.max(0.3, getFiniteNumber(cameraState.distance, DEFAULT_CAMERA_STATE.distance)), + THREE.MathUtils.clamp( + getFiniteNumber(cameraState.polar, DEFAULT_CAMERA_STATE.polar), + 0.01, + Math.PI - 0.01, + ), + getFiniteNumber(cameraState.azimuth, DEFAULT_CAMERA_STATE.azimuth), ); const offset = new THREE.Vector3().setFromSpherical(spherical); + controls.target.copy(target); camera.position.copy(target).add(offset); controls.update(); }, []); @@ -96,8 +182,26 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal const controls = new OrbitControls(camera, renderer.domElement); controls.enableDamping = true; controls.dampingFactor = 0.1; + controls.enablePan = true; + controls.enableZoom = true; + controls.screenSpacePanning = true; + controls.panSpeed = 1.0; + controls.zoomSpeed = 2.2; controls.minDistance = 0.3; controls.maxDistance = 10; + controls.mouseButtons = { + LEFT: THREE.MOUSE.ROTATE, + MIDDLE: THREE.MOUSE.PAN, + RIGHT: THREE.MOUSE.DOLLY, + }; + controls.touches = { + ONE: THREE.TOUCH.ROTATE, + TWO: THREE.TOUCH.DOLLY_PAN, + }; + if ('zoomToCursor' in controls) { + controls.zoomToCursor = true; + } + renderer.domElement.style.touchAction = 'none'; const handleControlsEnd = () => scheduleViewportSync(0, true); controls.addEventListener('end', handleControlsEnd); @@ -121,11 +225,7 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal animate(); threeRef.current = { renderer, scene, camera, controls, mesh: null, animId }; - applyCameraState( - Number(runtimeValues?.camera_azimuth ?? widgetValues?.camera_azimuth), - Number(runtimeValues?.camera_polar ?? widgetValues?.camera_polar), - Number(runtimeValues?.camera_distance ?? widgetValues?.camera_distance), - ); + applyCameraState(getCameraState(meshData, widgetValues, runtimeValues)); // Resize observer to maintain 1:1 aspect when node width changes const ro = new ResizeObserver((entries) => { @@ -152,16 +252,17 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal } threeRef.current = null; }; - }, [applyCameraState, scheduleViewportSync]); + }, [applyCameraState, meshData, runtimeValues, scheduleViewportSync, widgetValues]); // Update mesh when data changes useEffect(() => { if (!threeRef.current || !meshData) return; - const { scene, camera, controls } = threeRef.current; + const { scene, controls } = threeRef.current; const { width: nx, height: ny, z_data, colors, z_min, z_max, z_scale, - positions, indices, vertex_colors, camera_azimuth, camera_polar, camera_distance, + positions, indices, vertex_colors, + surface_extent_x, surface_extent_y, } = meshData; // Decode arrays @@ -182,14 +283,16 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal const geom = new THREE.BufferGeometry(); const positionsArray = posArr ?? new Float32Array(nx * ny * 3); const colorAttr = new Float32Array((vertexColorArr ? vertexColorArr.length : (nx * ny * 3))); + const surfaceExtentX = getFiniteNumber(surface_extent_x, 1.0); + const surfaceExtentY = getFiniteNumber(surface_extent_y, 1.0); if (!posArr) { const zRange = z_max - z_min || 1; for (let iy = 0; iy < ny; iy++) { for (let ix = 0; ix < nx; ix++) { const idx = iy * nx + ix; - const px = ix / (nx - 1) - 0.5; - const py = iy / (ny - 1) - 0.5; + const px = (ix / Math.max(nx - 1, 1) - 0.5) * surfaceExtentX; + const py = (iy / Math.max(ny - 1, 1) - 0.5) * surfaceExtentY; const pz = ((zArr[idx] - z_min) / zRange - 0.5) * z_scale; positionsArray[idx * 3] = px; @@ -238,21 +341,24 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal scene.add(mesh); threeRef.current.mesh = mesh; - // Reset camera target to center of mesh - controls.target.set(0, 0, 0); - if (!hasSyncedInitialSnapshotRef.current) { - applyCameraState( - Number.isFinite(camera_azimuth) ? camera_azimuth : Number(runtimeValues?.camera_azimuth ?? widgetValues?.camera_azimuth), - Number.isFinite(camera_polar) ? camera_polar : Number(runtimeValues?.camera_polar ?? widgetValues?.camera_polar), - Number.isFinite(camera_distance) ? camera_distance : Number(runtimeValues?.camera_distance ?? widgetValues?.camera_distance), - ); - hasSyncedInitialSnapshotRef.current = true; - } + const bounds = new THREE.Box3().setFromObject(mesh); + const center = bounds.isEmpty() ? new THREE.Vector3() : bounds.getCenter(new THREE.Vector3()); + const size = bounds.isEmpty() ? new THREE.Vector3(1, 1, 1) : bounds.getSize(new THREE.Vector3()); + const maxDimension = Math.max(size.x, size.y, size.z, 0.25); + controls.minDistance = Math.max(0.1, maxDimension * 0.35); + controls.maxDistance = Math.max(10, maxDimension * 14); + applyCameraState(getCameraState(meshData, widgetValues, runtimeValues, center)); scheduleViewportSync(0, false); }, [meshData, decode, applyCameraState, runtimeValues, scheduleViewportSync, widgetValues]); // Prevent scroll events from propagating to React Flow const onWheel = useCallback((e) => { + e.preventDefault(); + e.stopPropagation(); + }, []); + + const onContextMenu = useCallback((e) => { + e.preventDefault(); e.stopPropagation(); }, []); @@ -261,6 +367,7 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal ref={containerRef} className="nodrag nowheel surface-view-container" onWheelCapture={onWheel} + onContextMenu={onContextMenu} /> ); } diff --git a/frontend/src/api.js b/frontend/src/api.js index 225240f..7f129ca 100644 --- a/frontend/src/api.js +++ b/frontend/src/api.js @@ -5,49 +5,105 @@ * and production same-origin serving both work transparently. */ -// ── REST helpers ────────────────────────────────────────────────────── +const SESSION_STORAGE_KEY = 'argonode-session-id'; + +let _sessionId = null; +let _ws = null; +let _handler = null; +let _reconnectTimer = null; + +function generateSessionId() { + if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { + return crypto.randomUUID(); + } + return `session-${Math.random().toString(36).slice(2)}-${Date.now().toString(36)}`; +} + +export function getSessionId() { + if (_sessionId) return _sessionId; + + if (typeof window === 'undefined') { + _sessionId = 'session-test-runner'; + return _sessionId; + } + + try { + const stored = window.sessionStorage?.getItem(SESSION_STORAGE_KEY); + if (stored) { + _sessionId = stored; + return _sessionId; + } + } catch { + // Fall through to in-memory session id generation. + } + + _sessionId = generateSessionId(); + try { + window.sessionStorage?.setItem(SESSION_STORAGE_KEY, _sessionId); + } catch { + // Ignore storage failures and keep the in-memory id. + } + return _sessionId; +} + +function withSessionHeaders(init = {}) { + const headers = new Headers(init.headers || {}); + headers.set('X-Argonode-Session', getSessionId()); + return { ...init, headers }; +} + +async function sessionFetch(input, init) { + return fetch(input, withSessionHeaders(init)); +} export async function getNodes() { - const r = await fetch('/nodes'); + const r = await sessionFetch('/nodes'); if (!r.ok) throw new Error(`GET /nodes failed: ${r.status}`); return r.json(); } export async function getFiles() { - const r = await fetch('/files'); + const r = await sessionFetch('/files'); if (!r.ok) return []; return r.json(); } -export async function browse(dir) { - const url = dir ? `/browse?dir=${encodeURIComponent(dir)}` : '/browse'; - const r = await fetch(url); - if (!r.ok) throw new Error(`Browse failed: ${r.status}`); +export async function createUploadFolder(relativePath) { + const r = await sessionFetch('/upload-folder', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ path: relativePath }), + }); + if (!r.ok) throw new Error(`Create folder failed: ${r.status}`); return r.json(); } -export async function uploadFile(file) { +export async function uploadFile(file, { relativePath = '' } = {}) { const fd = new FormData(); + if (relativePath) fd.append('relative_path', relativePath); fd.append('file', file); - const r = await fetch('/upload', { method: 'POST', body: fd }); - if (!r.ok) throw new Error(`Upload failed: ${r.status}`); + const r = await sessionFetch('/upload', { method: 'POST', body: fd }); + if (!r.ok) { + const text = await r.text(); + throw new Error(`Upload failed (${r.status}): ${text}`); + } return r.json(); } export async function getChannels(filepath) { - const r = await fetch(`/channels?file=${encodeURIComponent(filepath)}`); + const r = await sessionFetch(`/channels?file=${encodeURIComponent(filepath)}`); if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }]; return r.json(); } export async function getFolderFiles(folderpath) { - const r = await fetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`); + const r = await sessionFetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`); if (!r.ok) return []; return r.json(); } export async function runPrompt(prompt) { - const r = await fetch('/prompt', { + const r = await sessionFetch('/prompt', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ prompt }), @@ -59,21 +115,16 @@ export async function runPrompt(prompt) { return r.json(); } -// ── WebSocket ───────────────────────────────────────────────────────── - -let _ws = null; -let _handler = null; -let _reconnectTimer = null; - export function setMessageHandler(fn) { _handler = fn; } export function initWS() { - if (_ws && _ws.readyState < 2) return; // already open or connecting + if (_ws && _ws.readyState < 2) return; const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - _ws = new WebSocket(`${protocol}//${window.location.host}/ws`); + const session = encodeURIComponent(getSessionId()); + _ws = new WebSocket(`${protocol}//${window.location.host}/ws?session=${session}`); _ws.onopen = () => { console.log('[argonode] WebSocket connected'); diff --git a/frontend/src/constants.js b/frontend/src/constants.js index 691da4d..baef17a 100644 --- a/frontend/src/constants.js +++ b/frontend/src/constants.js @@ -49,7 +49,7 @@ export const SOCKET_COMPATIBILITY = { VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']), ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']), SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']), - SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']), + SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'ANNOTATION_SOURCE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']), FLOAT: new Set(['INT']), INT: new Set(['FLOAT']), LINE: new Set(['COORDPAIR']), diff --git a/frontend/src/executionGraph.js b/frontend/src/executionGraph.js index 1bc0cf7..759a3f5 100644 --- a/frontend/src/executionGraph.js +++ b/frontend/src/executionGraph.js @@ -1,4 +1,4 @@ -import { DATA_TYPES } from './constants'; +import { DATA_TYPES } from './constants.js'; function getInputName(handleId) { return handleId.split('::')[1]; @@ -8,11 +8,24 @@ function getOutputSlot(handleId) { return parseInt(handleId.split('::')[1], 10); } +function resolveExecutionEdge(edge) { + const original = edge?.data?.groupProxyOriginal; + if (!original) return edge; + return { + ...edge, + source: original.source || edge.source, + sourceHandle: original.sourceHandle || edge.sourceHandle, + target: original.target || edge.target, + targetHandle: original.targetHandle || edge.targetHandle, + }; +} + export function getConnectedNodeIds(edges) { const connectedNodeIds = new Set(); for (const edge of edges) { - connectedNodeIds.add(edge.source); - connectedNodeIds.add(edge.target); + const resolved = resolveExecutionEdge(edge); + connectedNodeIds.add(resolved.source); + connectedNodeIds.add(resolved.target); } return connectedNodeIds; } @@ -53,6 +66,7 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f if (!runnableNodeIds.has(node.id)) continue; const { className, definition, widgetValues, runtimeValues } = node.data; + if (className === 'Group') continue; if (!definition) continue; if (excludeManualTrigger && definition.manual_trigger) continue; @@ -72,7 +86,9 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f } } - const incoming = edges.filter((edge) => edge.target === node.id); + const incoming = edges + .map(resolveExecutionEdge) + .filter((edge) => edge.target === node.id); for (const edge of incoming) { const inputName = getInputName(edge.targetHandle); const outputSlot = getOutputSlot(edge.sourceHandle); @@ -97,12 +113,15 @@ export function hasBlockingAutoRunInput(node, edges) { const required = def.input.required || {}; for (const [name, spec] of Object.entries(required)) { const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; - const hiddenByConnectedInput = (() => { - const raw = opts?.hide_when_input_connected; - if (!raw) return false; - const inputs = Array.isArray(raw) ? raw : [raw]; - return inputs.some((inputName) => edges.some( - (edge) => edge.target === node.id && getInputName(edge.targetHandle) === String(inputName) + const hiddenByConnectedInput = (() => { + const raw = opts?.hide_when_input_connected; + if (!raw) return false; + const inputs = Array.isArray(raw) ? raw : [raw]; + return inputs.some((inputName) => edges.some( + (edge) => { + const resolved = resolveExecutionEdge(edge); + return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName); + } )); })(); @@ -114,7 +133,10 @@ export function hasBlockingAutoRunInput(node, edges) { } if (!DATA_TYPES.has(type)) continue; const hasEdge = edges.some( - (edge) => edge.target === node.id && getInputName(edge.targetHandle) === name + (edge) => { + const resolved = resolveExecutionEdge(edge); + return resolved.target === node.id && getInputName(resolved.targetHandle) === name; + } ); if (!hasEdge) return true; } diff --git a/frontend/src/groupDrag.js b/frontend/src/groupDrag.js new file mode 100644 index 0000000..9851a41 --- /dev/null +++ b/frontend/src/groupDrag.js @@ -0,0 +1,18 @@ +export const GROUP_DRAG_RELEASE_DISTANCE = 18; + +export function getPointDistanceOutsideRect(rect, point) { + if (!rect || !point) return Infinity; + + const dx = point.x < rect.left + ? rect.left - point.x + : (point.x > rect.right ? point.x - rect.right : 0); + const dy = point.y < rect.top + ? rect.top - point.y + : (point.y > rect.bottom ? point.y - rect.bottom : 0); + + return Math.hypot(dx, dy); +} + +export function shouldReleaseFromGroup(rect, point, threshold = GROUP_DRAG_RELEASE_DISTANCE) { + return getPointDistanceOutsideRect(rect, point) >= threshold; +} diff --git a/frontend/src/groupSizing.js b/frontend/src/groupSizing.js new file mode 100644 index 0000000..0b415b9 --- /dev/null +++ b/frontend/src/groupSizing.js @@ -0,0 +1,35 @@ +const DEFAULT_CHILD_WIDTH = 200; +const DEFAULT_CHILD_HEIGHT = 120; + +function getNodeSize(node, axis) { + const fallback = axis === 'width' ? DEFAULT_CHILD_WIDTH : DEFAULT_CHILD_HEIGHT; + const measured = Number(node?.measured?.[axis]); + if (Number.isFinite(measured) && measured > 0) return measured; + const direct = Number(node?.[axis]); + if (Number.isFinite(direct) && direct > 0) return direct; + const styled = Number(node?.style?.[axis]); + if (Number.isFinite(styled) && styled > 0) return styled; + return fallback; +} + +export function getGroupMinimumSize(memberNodes, { + minWidth = 260, + minHeight = 180, + paddingX = 24, + paddingY = 24, +} = {}) { + let maxRight = 0; + let maxBottom = 0; + + for (const node of memberNodes || []) { + const x = Number(node?.position?.x) || 0; + const y = Number(node?.position?.y) || 0; + maxRight = Math.max(maxRight, x + getNodeSize(node, 'width')); + maxBottom = Math.max(maxBottom, y + getNodeSize(node, 'height')); + } + + return { + width: Math.max(minWidth, Math.ceil(maxRight + paddingX)), + height: Math.max(minHeight, Math.ceil(maxBottom + paddingY)), + }; +} diff --git a/frontend/src/nativePicker.js b/frontend/src/nativePicker.js new file mode 100644 index 0000000..d03f688 --- /dev/null +++ b/frontend/src/nativePicker.js @@ -0,0 +1,118 @@ +const FILE_ACCEPT = [ + '.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp', + '.npy', '.npz', + '.gwy', '.sxm', '.ibw', + '.ttf', '.otf', '.woff', '.woff2', +].join(','); + +function normalizeRelativePath(path) { + return String(path || '').replace(/\\/g, '/').replace(/^\/+/, ''); +} + +function pickWithInput({ directory = false } = {}) { + return new Promise((resolve) => { + const input = document.createElement('input'); + input.type = 'file'; + input.style.position = 'fixed'; + input.style.left = '-9999px'; + if (directory) { + input.multiple = true; + input.setAttribute('webkitdirectory', ''); + input.setAttribute('directory', ''); + } else { + input.accept = FILE_ACCEPT; + } + + const cleanup = () => { + input.remove(); + }; + + input.addEventListener('change', () => { + const files = Array.from(input.files || []); + cleanup(); + resolve(files); + }, { once: true }); + + document.body.appendChild(input); + input.click(); + }); +} + +async function collectDirectoryEntries(handle, prefix = handle.name) { + const entries = []; + for await (const [name, child] of handle.entries()) { + const relativePath = prefix ? `${prefix}/${name}` : name; + if (child.kind === 'file') { + const file = await child.getFile(); + entries.push({ file, relativePath: normalizeRelativePath(relativePath) }); + continue; + } + if (child.kind === 'directory') { + entries.push(...await collectDirectoryEntries(child, relativePath)); + } + } + return entries; +} + +export async function pickNativeFileSelection() { + try { + if (typeof window.showOpenFilePicker === 'function') { + const [handle] = await window.showOpenFilePicker({ + multiple: false, + types: [{ + description: 'Supported files', + accept: { + 'application/octet-stream': ['.npy', '.npz', '.gwy', '.sxm', '.ibw', '.ttf', '.otf', '.woff', '.woff2'], + 'image/*': ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'], + }, + }], + }); + if (!handle) return null; + const file = await handle.getFile(); + return { + rootName: file.name, + entries: [{ file, relativePath: normalizeRelativePath(file.name) }], + }; + } + } catch (error) { + if (error?.name !== 'AbortError') throw error; + return null; + } + + const files = await pickWithInput({ directory: false }); + if (files.length === 0) return null; + return { + rootName: files[0].name, + entries: [{ file: files[0], relativePath: normalizeRelativePath(files[0].name) }], + }; +} + +export async function pickNativeDirectorySelection() { + try { + if (typeof window.showDirectoryPicker === 'function') { + const handle = await window.showDirectoryPicker(); + if (!handle) return null; + const entries = await collectDirectoryEntries(handle, handle.name); + return { + rootName: handle.name, + entries, + }; + } + } catch (error) { + if (error?.name !== 'AbortError') throw error; + return null; + } + + const files = await pickWithInput({ directory: true }); + if (files.length === 0) return null; + const entries = files.map((file) => ({ + file, + relativePath: normalizeRelativePath(file.webkitRelativePath || file.name), + })); + const rootName = entries[0]?.relativePath.split('/')[0] || ''; + if (!rootName) return null; + return { + rootName, + entries, + }; +} diff --git a/frontend/src/nodeClipboard.js b/frontend/src/nodeClipboard.js index b00506f..4216f3d 100644 --- a/frontend/src/nodeClipboard.js +++ b/frontend/src/nodeClipboard.js @@ -1,3 +1,5 @@ +import { sortNodesForParentOrder } from './nodeHierarchy.js'; + export const NODE_CLIPBOARD_KIND = 'argonode/node-selection'; export const NODE_CLIPBOARD_MIME = 'application/x-argonode-node-selection'; @@ -18,13 +20,52 @@ function clonePlainObject(value) { return cloneValue(value) || {}; } +function collectSelectedNodeIds(nodes, nodeIds) { + const selectedIdSet = new Set((Array.isArray(nodeIds) ? nodeIds : []).map((id) => String(id))); + if (selectedIdSet.size === 0) return selectedIdSet; + + let changed = true; + while (changed) { + changed = false; + for (const node of Array.isArray(nodes) ? nodes : []) { + const parentId = node?.parentId ? String(node.parentId) : null; + const nodeId = String(node?.id); + if (parentId && selectedIdSet.has(parentId) && !selectedIdSet.has(nodeId)) { + selectedIdSet.add(nodeId); + changed = true; + } + } + } + return selectedIdSet; +} + +function extractExtraData(data) { + const source = data || {}; + return Object.fromEntries( + Object.entries(source).filter(([key]) => ![ + 'label', + 'className', + 'widgetValues', + 'runtimeValues', + 'definition', + 'previewImage', + 'tableRows', + 'meshData', + 'overlay', + 'scalarValue', + 'processingTimeMs', + 'warning', + ].includes(key)), + ); +} + export function buildNodeClipboardPayloadForIds( nodes, edges, nodeIds, { includeIncomingExternalEdges = false } = {}, ) { - const selectedIdSet = new Set((Array.isArray(nodeIds) ? nodeIds : []).map((id) => String(id))); + const selectedIdSet = collectSelectedNodeIds(nodes, nodeIds); const selectedNodes = Array.isArray(nodes) ? nodes.filter((node) => selectedIdSet.has(String(node.id))) : []; @@ -50,12 +91,18 @@ export function buildNodeClipboardPayloadForIds( x: Number(node.position?.x) || 0, y: Number(node.position?.y) || 0, }, + ...(node.className ? { className: node.className } : {}), + ...(node.parentId ? { parentId: String(node.parentId) } : {}), + ...(node.extent ? { extent: node.extent } : {}), + ...(node.hidden ? { hidden: true } : {}), + ...(node.style ? { style: cloneValue(node.style) } : {}), dragHandle: node.dragHandle || '.drag-handle', data: { label: node.data?.label || node.data?.className || 'Node', className: node.data?.className || '', widgetValues: clonePlainObject(node.data?.widgetValues), runtimeValues: clonePlainObject(node.data?.runtimeValues), + extraData: clonePlainObject(extractExtraData(node.data)), }, })), edges: capturedEdges.map((edge) => ({ @@ -64,15 +111,19 @@ export function buildNodeClipboardPayloadForIds( target: String(edge.target), targetHandle: edge.targetHandle, ...(edge.style ? { style: { ...edge.style } } : {}), + ...(edge.hidden ? { hidden: true } : {}), + ...(edge.data ? { data: cloneValue(edge.data) } : {}), })), }; } export function buildNodeClipboardPayload(nodes, edges) { - const selectedIds = Array.isArray(nodes) - ? nodes.filter((node) => node?.selected).map((node) => String(node.id)) + const selectedNodes = Array.isArray(nodes) + ? nodes.filter((node) => node?.selected) : []; - return buildNodeClipboardPayloadForIds(nodes, edges, selectedIds); + const selectedIds = selectedNodes.map((node) => String(node.id)); + const includeIncomingExternalEdges = selectedNodes.some((node) => node?.data?.className === 'Group'); + return buildNodeClipboardPayloadForIds(nodes, edges, selectedIds, { includeIncomingExternalEdges }); } export function parseNodeClipboardPayload(text) { @@ -102,19 +153,27 @@ export function instantiateNodeClipboardPayload( const idMap = new Map(); let currentId = Number(nextNodeId) || 1; - const nodes = payload.nodes.map((node) => { - const newId = String(currentId++); - idMap.set(String(node.id), newId); + payload.nodes.forEach((node) => { + idMap.set(String(node.id), String(currentId++)); + }); + + const nodes = sortNodesForParentOrder(payload.nodes.map((node) => { + const newId = idMap.get(String(node.id)); const className = node.data?.className || ''; const definition = className ? defs[className] || null : null; return { id: newId, type: node.type || 'custom', + className: node.className, position: { x: (Number(node.position?.x) || 0) + (Number(offset?.x) || 0), y: (Number(node.position?.y) || 0) + (Number(offset?.y) || 0), }, + ...(node.parentId ? { parentId: idMap.get(String(node.parentId)) || String(node.parentId) } : {}), + ...(node.extent ? { extent: node.extent } : {}), + ...(node.hidden ? { hidden: true } : {}), + ...(node.style ? { style: cloneValue(node.style) } : {}), dragHandle: node.dragHandle || '.drag-handle', selected: true, data: { @@ -122,6 +181,7 @@ export function instantiateNodeClipboardPayload( className, widgetValues: clonePlainObject(node.data?.widgetValues), runtimeValues: clonePlainObject(node.data?.runtimeValues), + ...(clonePlainObject(node.data?.extraData)), definition, previewImage: null, tableRows: null, @@ -132,7 +192,7 @@ export function instantiateNodeClipboardPayload( warning: null, }, }; - }); + })); const edges = payload.edges .filter((edge) => ( @@ -147,6 +207,8 @@ export function instantiateNodeClipboardPayload( targetHandle: edge.targetHandle, selected: false, ...(edge.style ? { style: { ...edge.style } } : {}), + ...(edge.hidden ? { hidden: true } : {}), + ...(edge.data ? { data: cloneValue(edge.data) } : {}), })); return { diff --git a/frontend/src/nodeHierarchy.js b/frontend/src/nodeHierarchy.js new file mode 100644 index 0000000..c3742c8 --- /dev/null +++ b/frontend/src/nodeHierarchy.js @@ -0,0 +1,28 @@ +export function sortNodesForParentOrder(nodes) { + const list = Array.isArray(nodes) ? nodes.filter(Boolean) : []; + const entries = list.map((node) => ({ id: String(node.id), node })); + const byId = new Map(entries.map((entry) => [entry.id, entry])); + const visiting = new Set(); + const visited = new Set(); + const ordered = []; + + function visit(entry) { + if (!entry) return; + const { id, node } = entry; + if (visited.has(id) || visiting.has(id)) return; + + visiting.add(id); + + const parentId = node?.parentId ? String(node.parentId) : null; + if (parentId) { + visit(byId.get(parentId)); + } + + visiting.delete(id); + visited.add(id); + ordered.push(node); + } + + entries.forEach((entry) => visit(entry)); + return ordered; +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index a76a579..8b33b15 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -217,6 +217,11 @@ html, body, #root { flex: 1; position: relative; } +.flow-container.canvas-right-zooming, +.flow-container.canvas-right-zooming .react-flow__pane, +.flow-container.canvas-right-zooming .react-flow__background { + cursor: ns-resize !important; +} /* ── React Flow dark overrides ─────────────────────────────────────── */ .react-flow { @@ -236,8 +241,143 @@ html, body, #root { overflow: hidden; } +.group-node { + width: 100%; + height: 100%; + min-width: 220px; + resize: none; + border-style: dashed; + display: flex; + flex-direction: column; + background: + linear-gradient(180deg, rgba(30, 41, 59, 0.82), rgba(15, 23, 42, 0.72)); + box-shadow: + inset 0 0 0 1px rgba(148, 163, 184, 0.08), + inset 0 1px 18px rgba(15, 23, 42, 0.28); +} + +.group-node-title { + background: #334155; +} + +.group-node-title .node-title-main { + flex: 1; +} + +.group-title-slot { + display: flex; + align-items: center; + flex: 0 1 auto; + min-width: 0; + max-width: 100%; +} + +.group-title-button { + flex: 0 1 auto; + min-width: 0; + max-width: 100%; + padding: 0; + border: 0; + background: transparent; + color: var(--text-heading); + font: inherit; + font-weight: inherit; + text-align: left; + cursor: text; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.group-title-input { + flex: 0 1 auto; + min-width: 0; + max-width: min(40ch, 100%); + width: auto; + height: 22px; + padding: 2px 6px; + border: 1px solid rgba(148, 163, 184, 0.45); + border-radius: 4px; + background: rgba(15, 23, 42, 0.72); + color: var(--text-heading); + font: inherit; +} + +.group-node-actions { + display: flex; + align-items: center; + gap: 6px; + margin-left: auto; +} + +.group-toggle { + border: 0; + background: rgba(15, 23, 42, 0.65); + color: var(--text-heading); + border-radius: 4px; + padding: 2px 8px; + cursor: pointer; + font-size: 12px; + line-height: 1; +} + +.group-toggle-collapse { + min-width: 24px; + padding: 2px 6px; +} + +.group-node-summary { + padding: 6px 10px; + color: var(--text-secondary); + font-size: 10px; + border-top: 1px solid var(--border-subtle); +} + +.group-node .node-body { + flex: 1; + min-height: 0; +} + +.group-node-expanded .node-body { + padding: 12px; +} + +.group-node-workspace { + position: relative; + width: 100%; + height: 100%; + min-height: 120px; + border-radius: 8px; + border: 1px solid rgba(148, 163, 184, 0.2); + background: + linear-gradient(180deg, rgba(15, 23, 42, 0.16), rgba(15, 23, 42, 0.34)); + box-shadow: + inset 0 0 0 1px rgba(15, 23, 42, 0.12), + inset 0 12px 28px rgba(15, 23, 42, 0.18); + pointer-events: none; +} + +.group-node-workspace-label { + position: absolute; + top: 10px; + left: 12px; + color: rgba(148, 163, 184, 0.58); + font-size: 10px; + letter-spacing: 0.06em; + text-transform: lowercase; +} + +.group-node-expanded .group-node-summary { + position: absolute; + right: 10px; + bottom: 8px; + border-top: 0; + padding: 0; + background: transparent; +} + /* Let React Flow node wrapper fit to the custom-node's size */ -.react-flow__node-custom { +.react-flow__node-custom:not(.group-shell) { width: auto !important; height: auto !important; } diff --git a/frontend/src/workflowCapture.js b/frontend/src/workflowCapture.js index b75a17f..7b59aa7 100644 --- a/frontend/src/workflowCapture.js +++ b/frontend/src/workflowCapture.js @@ -1,5 +1,5 @@ import { toBlob } from 'html-to-image'; -import { CANVAS_COLORS } from './constants'; +import { CANVAS_COLORS } from './constants.js'; export const OVERLAY_CAPTURE_SELECTORS = [ '.lineplot-overlay', diff --git a/frontend/src/workflowHydration.js b/frontend/src/workflowHydration.js index 3e97032..8d1d070 100644 --- a/frontend/src/workflowHydration.js +++ b/frontend/src/workflowHydration.js @@ -1,3 +1,5 @@ +import { sortNodesForParentOrder } from './nodeHierarchy.js'; + function mergeDefinition(nodeData, defs) { const savedData = nodeData || {}; const registryDefinition = savedData.className ? defs[savedData.className] : null; @@ -34,27 +36,35 @@ export function hydrateWorkflowState(data, defs = {}) { const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : []; const loadedEdges = Array.isArray(data?.edges) ? data.edges : []; - const nodes = loadedNodes.map((node) => { + const nodes = sortNodesForParentOrder(loadedNodes.map((node) => { const definition = mergeDefinition(node.data, defs); return { ...node, type: node.type || 'custom', + className: node.className, + parentId: node.parentId, + extent: node.extent, + hidden: !!node.hidden, + style: node.style, dragHandle: node.dragHandle || '.drag-handle', data: { ...node.data, label: node.data?.label || node.data?.className || 'Node', widgetValues: sanitizeWidgetValues(node.data?.widgetValues, definition), - runtimeValues: {}, + runtimeValues: node.data?.runtimeValues || {}, + ...(node.data?.extraData || {}), definition, previewImage: null, tableRows: null, meshData: null, overlay: null, scalarValue: null, + processingTimeMs: null, + warning: null, }, }; - }); + })); const edges = loadedEdges.map((edge) => ({ ...edge })); diff --git a/frontend/src/workflowSerialization.js b/frontend/src/workflowSerialization.js index cdac583..170914a 100644 --- a/frontend/src/workflowSerialization.js +++ b/frontend/src/workflowSerialization.js @@ -1,15 +1,44 @@ export function serializeWorkflowState(nodes, edges) { + const compactObject = (value) => { + if (!value || typeof value !== 'object') return null; + const entries = Object.entries(value); + return entries.length > 0 ? Object.fromEntries(entries) : null; + }; + const getExtraData = (data) => compactObject(Object.fromEntries( + Object.entries(data || {}).filter(([key]) => ![ + 'label', + 'className', + 'widgetValues', + 'runtimeValues', + 'definition', + 'previewImage', + 'tableRows', + 'meshData', + 'overlay', + 'scalarValue', + 'processingTimeMs', + 'warning', + ].includes(key)) + )); + return { version: 1, nodes: nodes.map((node) => ({ id: node.id, type: node.type || 'custom', position: node.position, + ...(node.className ? { className: node.className } : {}), + ...(node.parentId ? { parentId: node.parentId } : {}), + ...(node.extent ? { extent: node.extent } : {}), + ...(node.hidden ? { hidden: true } : {}), + ...(node.style ? { style: node.style } : {}), dragHandle: node.dragHandle || '.drag-handle', data: { label: node.data?.label || node.data?.className || 'Node', className: node.data?.className || '', widgetValues: node.data?.widgetValues || {}, + ...(compactObject(node.data?.runtimeValues) ? { runtimeValues: compactObject(node.data?.runtimeValues) } : {}), + ...(getExtraData(node.data) ? { extraData: getExtraData(node.data) } : {}), output: node.data?.definition?.output || [], output_name: node.data?.definition?.output_name || [], }, @@ -21,6 +50,8 @@ export function serializeWorkflowState(nodes, edges) { target: edge.target, targetHandle: edge.targetHandle, ...(edge.style ? { style: edge.style } : {}), + ...(edge.hidden ? { hidden: true } : {}), + ...(edge.data ? { data: edge.data } : {}), })), }; } diff --git a/frontend/tests/constants.test.mjs b/frontend/tests/constants.test.mjs new file mode 100644 index 0000000..987b9c6 --- /dev/null +++ b/frontend/tests/constants.test.mjs @@ -0,0 +1,8 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { SOCKET_COMPATIBILITY } from '../src/constants.js'; + +test('SAVE_VALUE accepts ANNOTATION_SOURCE inputs', () => { + assert.equal(SOCKET_COMPATIBILITY.SAVE_VALUE.has('ANNOTATION_SOURCE'), true); +}); diff --git a/frontend/tests/executionGraph.test.mjs b/frontend/tests/executionGraph.test.mjs index 1c4c763..c6f1354 100644 --- a/frontend/tests/executionGraph.test.mjs +++ b/frontend/tests/executionGraph.test.mjs @@ -192,6 +192,73 @@ test('serializeExecutionGraph allows a singleton ImageDemo graph so previews can }); }); +test('serializeExecutionGraph ignores group shells and resolves collapsed proxy edges back to child endpoints', () => { + const nodes = [ + { + id: '1', + data: { + className: 'Image', + definition: { + input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { filename: 'scan.gwy' }, + }, + }, + { + id: '10', + data: { + className: 'Group', + definition: null, + widgetValues: {}, + }, + }, + { + id: '2', + parentId: '10', + hidden: true, + data: { + className: 'PreviewImage', + definition: { + input: { required: { field: ['DATA_FIELD', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: {}, + }, + }, + ]; + + const edges = [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '10', + targetHandle: 'group-proxy::in::2::DATA_FIELD::input%3A%3Afield%3A%3ADATA_FIELD', + data: { + groupProxyOwner: '10', + groupProxyOriginal: { + target: '2', + targetHandle: 'input::field::DATA_FIELD', + }, + }, + }, + ]; + + const prompt = serializeExecutionGraph(nodes, edges); + + assert.deepEqual(prompt, { + '1': { + class_type: 'Image', + inputs: { filename: 'scan.gwy' }, + }, + '2': { + class_type: 'PreviewImage', + inputs: { field: ['1', 0] }, + }, + }); + assert.equal('10' in prompt, false); +}); + test('getAutoRunnableNodes ignores disconnected nodes when deciding what can auto-run', () => { const nodes = [ { id: '1', data: { definition: {}, widgetValues: {} } }, diff --git a/frontend/tests/groupDrag.test.mjs b/frontend/tests/groupDrag.test.mjs new file mode 100644 index 0000000..18ad72b --- /dev/null +++ b/frontend/tests/groupDrag.test.mjs @@ -0,0 +1,26 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { + GROUP_DRAG_RELEASE_DISTANCE, + getPointDistanceOutsideRect, + shouldReleaseFromGroup, +} from '../src/groupDrag.js'; + +test('getPointDistanceOutsideRect returns zero inside the rect', () => { + const rect = { left: 10, top: 20, right: 110, bottom: 120 }; + assert.equal(getPointDistanceOutsideRect(rect, { x: 60, y: 70 }), 0); +}); + +test('shouldReleaseFromGroup waits for a small overshoot before releasing', () => { + const rect = { left: 10, top: 20, right: 110, bottom: 120 }; + + assert.equal( + shouldReleaseFromGroup(rect, { x: 110 + GROUP_DRAG_RELEASE_DISTANCE - 1, y: 70 }), + false, + ); + assert.equal( + shouldReleaseFromGroup(rect, { x: 110 + GROUP_DRAG_RELEASE_DISTANCE, y: 70 }), + true, + ); +}); diff --git a/frontend/tests/groupSizing.test.mjs b/frontend/tests/groupSizing.test.mjs new file mode 100644 index 0000000..44082d0 --- /dev/null +++ b/frontend/tests/groupSizing.test.mjs @@ -0,0 +1,26 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { getGroupMinimumSize } from '../src/groupSizing.js'; + +test('getGroupMinimumSize keeps the base minimum for empty groups', () => { + assert.deepEqual(getGroupMinimumSize([]), { width: 260, height: 180 }); +}); + +test('getGroupMinimumSize grows to fit child bounds plus padding', () => { + const nodes = [ + { + position: { x: 24, y: 60 }, + style: { width: 180, height: 100 }, + }, + { + position: { x: 260, y: 150 }, + style: { width: 220, height: 140 }, + }, + ]; + + assert.deepEqual(getGroupMinimumSize(nodes), { + width: 504, + height: 314, + }); +}); diff --git a/frontend/tests/nodeClipboard.test.mjs b/frontend/tests/nodeClipboard.test.mjs index 682ba4f..09fef6a 100644 --- a/frontend/tests/nodeClipboard.test.mjs +++ b/frontend/tests/nodeClipboard.test.mjs @@ -265,3 +265,28 @@ test('clipboard payload deep-copies local widget and runtime fields', () => { assert.equal(payload.nodes[0].data.widgetValues.markup_shapes[0].points[0], 0.1); assert.equal(payload.nodes[0].data.runtimeValues.camera.azimuth, 15); }); + +test('clipboard payload preserves wrapper class names for group shells', () => { + const payload = buildNodeClipboardPayloadForIds( + [ + { + id: '50', + type: 'custom', + className: 'group-shell', + position: { x: 0, y: 0 }, + data: { + label: 'group', + className: 'Group', + widgetValues: {}, + }, + }, + ], + [], + ['50'], + ); + + const instantiated = instantiateNodeClipboardPayload(payload, {}, 80); + + assert.equal(payload.nodes[0].className, 'group-shell'); + assert.equal(instantiated.nodes[0].className, 'group-shell'); +}); diff --git a/frontend/tests/nodeHierarchy.test.mjs b/frontend/tests/nodeHierarchy.test.mjs new file mode 100644 index 0000000..7dffff8 --- /dev/null +++ b/frontend/tests/nodeHierarchy.test.mjs @@ -0,0 +1,96 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { sortNodesForParentOrder } from '../src/nodeHierarchy.js'; +import { hydrateWorkflowState } from '../src/workflowHydration.js'; +import { instantiateNodeClipboardPayload, NODE_CLIPBOARD_KIND } from '../src/nodeClipboard.js'; + +test('sortNodesForParentOrder places parents before descendants', () => { + const nodes = [ + { id: '2', parentId: '1', position: { x: 80, y: 60 }, data: { className: 'Preview' } }, + { id: '3', position: { x: 300, y: 20 }, data: { className: 'Image' } }, + { id: '1', className: 'group-shell', position: { x: 0, y: 0 }, data: { className: 'Group' } }, + { id: '4', parentId: '2', position: { x: 30, y: 24 }, data: { className: 'Save' } }, + ]; + + const ordered = sortNodesForParentOrder(nodes); + + assert.deepEqual(ordered.map((node) => node.id), ['1', '2', '3', '4']); +}); + +test('hydrateWorkflowState reorders group parents ahead of children', () => { + const saved = { + nodes: [ + { + id: '11', + type: 'custom', + position: { x: 48, y: 72 }, + parentId: '10', + extent: 'parent', + data: { + label: 'preview', + className: 'Preview', + widgetValues: {}, + }, + }, + { + id: '10', + type: 'custom', + className: 'group-shell', + position: { x: 12, y: 24 }, + style: { width: 320, height: 220 }, + data: { + label: 'group', + className: 'Group', + widgetValues: {}, + }, + }, + ], + edges: [], + }; + + const hydrated = hydrateWorkflowState(saved, {}); + + assert.deepEqual(hydrated.nodes.map((node) => node.id), ['10', '11']); + assert.equal(hydrated.nodes[1].parentId, '10'); +}); + +test('instantiateNodeClipboardPayload remaps parent ids before sorting grouped nodes', () => { + const payload = { + kind: NODE_CLIPBOARD_KIND, + version: 1, + nodes: [ + { + id: 'child', + type: 'custom', + position: { x: 48, y: 72 }, + parentId: 'group', + extent: 'parent', + data: { + label: 'preview', + className: 'Preview', + widgetValues: {}, + }, + }, + { + id: 'group', + type: 'custom', + className: 'group-shell', + position: { x: 12, y: 24 }, + style: { width: 320, height: 220 }, + data: { + label: 'group', + className: 'Group', + widgetValues: {}, + }, + }, + ], + edges: [], + }; + + const instantiated = instantiateNodeClipboardPayload(payload, {}, 20); + + assert.deepEqual(instantiated.nodes.map((node) => node.id), ['21', '20']); + assert.equal(instantiated.nodes[1].parentId, '21'); + assert.equal(instantiated.nextNodeId, 22); +}); diff --git a/frontend/tests/workflowSerialization.test.mjs b/frontend/tests/workflowSerialization.test.mjs index 0bf80e4..f9ad553 100644 --- a/frontend/tests/workflowSerialization.test.mjs +++ b/frontend/tests/workflowSerialization.test.mjs @@ -226,3 +226,26 @@ test('hydrateWorkflowState clears saved folder selections on shared workflows', assert.deepEqual(hydrated.nodes[0].data.definition.output, ['PATH']); assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['path']); }); + +test('workflow serialization preserves wrapper class names for group shells', () => { + const nodes = [ + { + id: '31', + type: 'custom', + className: 'group-shell', + position: { x: 5, y: 15 }, + style: { width: 420, height: 260 }, + data: { + label: 'group', + className: 'Group', + widgetValues: {}, + }, + }, + ]; + + const serialized = serializeWorkflowState(nodes, []); + const hydrated = hydrateWorkflowState(serialized, {}); + + assert.equal(serialized.nodes[0].className, 'group-shell'); + assert.equal(hydrated.nodes[0].className, 'group-shell'); +}); diff --git a/frontend/vite.config.js b/frontend/vite.config.js index f674bd5..eb12a43 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -9,9 +9,9 @@ export default defineConfig({ proxy: { '/nodes': 'http://127.0.0.1:8188', '/files': 'http://127.0.0.1:8188', - '/browse': 'http://127.0.0.1:8188', '/folder-files': 'http://127.0.0.1:8188', '/channels': 'http://127.0.0.1:8188', + '/upload-folder': 'http://127.0.0.1:8188', '/upload': 'http://127.0.0.1:8188', '/download': 'http://127.0.0.1:8188', '/prompt': 'http://127.0.0.1:8188', diff --git a/package.json b/package.json index 0618d35..58f1308 100644 --- a/package.json +++ b/package.json @@ -7,12 +7,15 @@ }, "scripts": { "postinstall": "npm --prefix frontend install", - "dev": "npm --prefix frontend run dev", - "build": "npm --prefix frontend run build", + "clean:dev": "node scripts/clean-build-artifacts.mjs", + "clean:build": "node scripts/clean-build-artifacts.mjs", + "clean:native": "node scripts/clean-build-artifacts.mjs --mode=native", + "dev": "npm run clean:dev && npm --prefix frontend run dev", + "build": "npm run clean:build && npm --prefix frontend run build", "preview": "npm --prefix frontend run preview", "test:frontend": "npm --prefix frontend test", "backend": "python -m backend.main", - "desktop": "python desktop.py", + "desktop": "npm run build && python desktop.py", "build:windows": "powershell -ExecutionPolicy Bypass -File scripts\\build-windows.ps1", "build:mac": "bash scripts/build-mac.sh", "build:linux": "bash scripts/build-linux.sh" diff --git a/scripts/build-windows.ps1 b/scripts/build-windows.ps1 index 5da3c78..3ff8951 100644 --- a/scripts/build-windows.ps1 +++ b/scripts/build-windows.ps1 @@ -38,6 +38,10 @@ $pythonExe = if (Test-Path ".\.venv\Scripts\python.exe") { $frontendDist = Join-Path $repoRoot "frontend\dist" $demoDir = Join-Path $repoRoot "demo" +Write-Host "Removing cached frontend and desktop build artifacts..." +node scripts\clean-build-artifacts.mjs --mode=native +Assert-LastExitCode "Artifact cleanup" + Write-Host "Building frontend bundle..." npm run build Assert-LastExitCode "Frontend build" diff --git a/scripts/clean-build-artifacts.mjs b/scripts/clean-build-artifacts.mjs new file mode 100644 index 0000000..cc0ae9f --- /dev/null +++ b/scripts/clean-build-artifacts.mjs @@ -0,0 +1,70 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const scriptDir = path.dirname(fileURLToPath(import.meta.url)); +const repoRoot = path.resolve(scriptDir, '..'); +const args = new Set(process.argv.slice(2)); +const mode = args.has('--mode=native') || args.has('--native') ? 'native' : 'frontend'; + +const removed = []; + +function removePath(targetPath) { + if (!fs.existsSync(targetPath)) return; + fs.rmSync(targetPath, { recursive: true, force: true }); + removed.push(path.relative(repoRoot, targetPath) || '.'); +} + +function removePythonCaches(rootPath) { + const stack = [rootPath]; + const skipDirs = new Set(['.git', '.venv', 'node_modules']); + + while (stack.length > 0) { + const current = stack.pop(); + let entries = []; + + try { + entries = fs.readdirSync(current, { withFileTypes: true }); + } catch { + continue; + } + + for (const entry of entries) { + const fullPath = path.join(current, entry.name); + + if (entry.isDirectory()) { + if (entry.name === '__pycache__') { + removePath(fullPath); + continue; + } + if (skipDirs.has(entry.name)) continue; + stack.push(fullPath); + continue; + } + + if (entry.isFile() && (entry.name.endsWith('.pyc') || entry.name.endsWith('.pyo'))) { + try { + fs.rmSync(fullPath, { force: true }); + removed.push(path.relative(repoRoot, fullPath) || '.'); + } catch { + // Ignore files held open by another process; the rest of the clean can still continue. + } + } + } + } +} + +removePath(path.join(repoRoot, 'frontend', 'dist')); +removePath(path.join(repoRoot, 'frontend', 'node_modules', '.vite')); +removePythonCaches(repoRoot); + +if (mode === 'native') { + removePath(path.join(repoRoot, 'desktop-build')); + removePath(path.join(repoRoot, 'desktop-dist')); +} + +if (removed.length === 0) { + console.log(`[clean] No cached build artifacts found (${mode}).`); +} else { + console.log(`[clean] Removed ${removed.length} artifact${removed.length === 1 ? '' : 's'} (${mode}).`); +} diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 7f4926c..8fb4e22 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -2110,6 +2110,7 @@ def test_view3d(): print("=== Test: View3D ===") from backend.nodes.view_3d import View3D from backend.data_types import ImageData, MeshModel + from backend.execution_context import active_node, execution_callbacks import base64 import io from PIL import Image @@ -2118,28 +2119,34 @@ def test_view3d(): field = make_field() captured = [] - View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh) - View3D._current_node_id = "test" + mesh_callback = lambda nid, mesh: captured.append(mesh) preview_image = Image.new("RGB", (12, 10), (255, 0, 0)) preview_buffer = io.BytesIO() preview_image.save(preview_buffer, format="PNG") viewport_snapshot = "data:image/png;base64," + base64.b64encode(preview_buffer.getvalue()).decode() - result = node.render( - field, - colormap="viridis", - z_scale=2.0, - resolution=64, - make_solid=False, - viewport_snapshot=viewport_snapshot, - ) + with execution_callbacks(mesh=mesh_callback), active_node("test"): + result = node.render( + field, + colormap="viridis", + z_scale=2.0, + resolution=64, + make_solid=False, + camera_target_x=0.1, + camera_target_y=-0.2, + camera_target_z=0.3, + viewport_snapshot=viewport_snapshot, + ) assert len(result) == 2 assert isinstance(result[0], MeshModel) assert isinstance(result[1], ImageData) assert result[1].shape == (10, 12, 3) assert np.all(result[1][0, 0] == np.array([255, 0, 0], dtype=np.uint8)) assert result[1].metadata["annotation_context"]["si_unit_xy"] == field.si_unit_xy + assert result[1].metadata["viewport_camera"]["target_x"] == 0.1 + assert result[1].metadata["viewport_camera"]["target_y"] == -0.2 + assert result[1].metadata["viewport_camera"]["target_z"] == 0.3 assert len(captured) == 1 mesh = captured[0] @@ -2150,6 +2157,9 @@ def test_view3d(): assert mesh["z_scale"] == 0.2 assert mesh["width"] <= 64 assert mesh["height"] <= 64 + assert mesh["camera_target_x"] == 0.1 + assert mesh["camera_target_y"] == -0.2 + assert mesh["camera_target_z"] == 0.3 # z_min < z_max for non-constant data assert mesh["z_min"] < mesh["z_max"] @@ -2163,7 +2173,8 @@ def test_view3d(): # High-res input should be downsampled big_field = make_field(shape=(256, 256)) captured.clear() - node.render(big_field, colormap="hot", z_scale=1.0, resolution=64, make_solid=False) + with execution_callbacks(mesh=mesh_callback), active_node("test"): + node.render(big_field, colormap="hot", z_scale=1.0, resolution=64, make_solid=False) assert captured[0]["width"] <= 64 assert captured[0]["height"] <= 64 @@ -2171,16 +2182,23 @@ def test_view3d(): mesh_field = make_field(data=np.zeros((64, 64), dtype=np.float64), xreal=2.0, yreal=3.0) map_field = make_field(data=np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float64), (64, 1)), xreal=2.0, yreal=3.0) captured.clear() - mapped_result = node.render(mesh_field, map_field=map_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) + with execution_callbacks(mesh=mesh_callback), active_node("test"): + mapped_result = node.render(mesh_field, map_field=map_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) mapped_mesh = captured[0] assert mapped_mesh["x_range"] == [float(mesh_field.xoff), float(mesh_field.xoff + mesh_field.xreal)] assert mapped_mesh["y_range"] == [float(mesh_field.yoff), float(mesh_field.yoff + mesh_field.yreal)] + assert np.isclose(mapped_mesh["surface_extent_x"] / mapped_mesh["surface_extent_y"], mesh_field.xreal / mesh_field.yreal) mapped_z = np.frombuffer(base64.b64decode(mapped_mesh["z_data"]), dtype=np.float32) assert np.allclose(mapped_z, 0.0) mapped_colors = np.frombuffer(base64.b64decode(mapped_mesh["colors"]), dtype=np.uint8) + top_vertices = np.asarray(mapped_result[0].vertices, dtype=np.float32) + x_span = float(top_vertices[:, 0].max() - top_vertices[:, 0].min()) + y_span = float(top_vertices[:, 2].max() - top_vertices[:, 2].min()) + assert np.isclose(x_span / y_span, mesh_field.xreal / mesh_field.yreal) captured.clear() - node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) + with execution_callbacks(mesh=mesh_callback), active_node("test"): + node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) mesh_only = captured[0] mesh_only_colors = np.frombuffer(base64.b64decode(mesh_only["colors"]), dtype=np.uint8) assert not np.array_equal(mapped_colors, mesh_only_colors) @@ -2189,7 +2207,8 @@ def test_view3d(): solid_mesh = mapped_result[0] assert isinstance(solid_mesh, MeshModel) captured.clear() - solid_result = node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=16, make_solid=True) + with execution_callbacks(mesh=mesh_callback), active_node("test"): + solid_result = node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=16, make_solid=True) assert len(solid_result[0].vertices) > 16 * 16 assert len(solid_result[0].faces) > (15 * 15 * 2) solid_payload = captured[0] @@ -2197,19 +2216,19 @@ def test_view3d(): assert "positions" in solid_payload assert "indices" in solid_payload assert "vertex_colors" in solid_payload - - View3D._broadcast_mesh_fn = None print(" PASS\n") def test_save_generic(): print("=== Test: Save ===") from backend.nodes.save import Save - from backend.data_types import DataField, LineData, MeasureTable, MeshModel, RecordTable + from backend.data_types import DataField, ImageData, LineData, MeasureTable, MeshModel, RecordTable import tifffile from PIL import Image as PILImage node = Save() + format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"] + assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"] with tempfile.TemporaryDirectory() as tmpdir: # Save scalar as TXT and JSON @@ -2282,6 +2301,26 @@ def test_save_generic(): image_npz = np.load(Path(tmpdir, "image_npz.npz")) assert np.array_equal(image_npz["image"], image) + # Save ANNOTATION_SOURCE as PNG, TIFF, and NPZ + annotation_image = ImageData( + image, + metadata={"annotation_context": {"si_unit_xy": "um", "si_unit_z": "nm"}}, + ) + node.save(filename="annotation_png", directory_path=tmpdir, format="PNG", value=annotation_image) + annotation_png = np.asarray(PILImage.open(Path(tmpdir, "annotation_png.png"))) + assert annotation_png.shape == image.shape + assert np.array_equal(annotation_png, image) + + node.save(filename="annotation_tiff", directory_path=tmpdir, format="TIFF", value=annotation_image) + annotation_tiff = tifffile.imread(Path(tmpdir, "annotation_tiff.tiff")) + assert annotation_tiff.shape == image.shape + assert annotation_tiff.dtype == np.uint8 + assert np.array_equal(annotation_tiff, image) + + node.save(filename="annotation_npz", directory_path=tmpdir, format="NPZ", value=annotation_image) + annotation_npz = np.load(Path(tmpdir, "annotation_npz.npz")) + assert np.array_equal(annotation_npz["image"], image) + # Save tables as CSV and JSON measure_table = MeasureTable([ {"quantity": "Rq", "value": 1.23, "unit": "nm"}, diff --git a/tests/test_session_runtime.py b/tests/test_session_runtime.py new file mode 100644 index 0000000..78f1eb8 --- /dev/null +++ b/tests/test_session_runtime.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import threading +from pathlib import Path + +import pytest + +from backend.execution_context import active_node, emit_warning, execution_callbacks +from backend.session_runtime import ( + ensure_session_runtime_dirs, + resolve_client_path, + server_path_to_client_path, + session_upload_uri, +) + + +def test_session_paths_round_trip(monkeypatch, tmp_path): + monkeypatch.setenv("ARGONODE_APPDATA", str(tmp_path / "appdata")) + + session_id = "session-test-1234" + input_dir, _ = ensure_session_runtime_dirs(session_id) + target = input_dir / "picked-folder" / "image.png" + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(b"png") + + client_path = session_upload_uri("picked-folder/image.png") + resolved = resolve_client_path(client_path, session_id=session_id, allow_local_filesystem=False) + + assert resolved == target.resolve() + assert server_path_to_client_path(target, session_id) == client_path + + +def test_browser_sessions_cannot_escape_workspace(monkeypatch, tmp_path): + monkeypatch.setenv("ARGONODE_APPDATA", str(tmp_path / "appdata")) + + session_id = "session-test-5678" + ensure_session_runtime_dirs(session_id) + + outside_path = (tmp_path / "outside" / "secret.dat").resolve() + + with pytest.raises(PermissionError): + resolve_client_path(str(outside_path), session_id=session_id, allow_local_filesystem=False) + + +def test_execution_callbacks_are_thread_local(): + results = [] + lock = threading.Lock() + barrier = threading.Barrier(2) + + def worker(label: str): + def on_warning(node_id: str, message: str): + with lock: + results.append((label, node_id, message)) + + with execution_callbacks(warning=on_warning): + with active_node(f"node-{label}"): + barrier.wait(timeout=5) + emit_warning(f"warning-{label}") + + threads = [ + threading.Thread(target=worker, args=("a",)), + threading.Thread(target=worker, args=("b",)), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=5) + + assert sorted(results) == [ + ("a", "node-a", "warning-a"), + ("b", "node-b", "warning-b"), + ]