From 558046e7aa949a79fb0e4664fc6f4b47167e2070 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Fri, 27 Mar 2026 16:18:22 -0700 Subject: [PATCH] rework web server so multiple clients can be server at a time --- backend/execution.py | 182 +++++---------- backend/execution_context.py | 82 +++++++ backend/nodes/annotations.py | 6 +- backend/nodes/crop_resize_field.py | 25 +- backend/nodes/cross_section.py | 22 +- backend/nodes/cursors.py | 55 ++--- backend/nodes/draw_mask.py | 21 +- backend/nodes/histogram.py | 29 ++- backend/nodes/image.py | 6 +- backend/nodes/markup.py | 21 +- backend/nodes/mask_combine.py | 7 +- backend/nodes/mask_invert.py | 7 +- backend/nodes/mask_morphology.py | 7 +- backend/nodes/preview_image.py | 4 +- backend/nodes/print_table.py | 4 +- backend/nodes/rotate_field.py | 6 +- backend/nodes/save.py | 6 +- backend/nodes/save_image.py | 6 +- backend/nodes/stats.py | 9 +- backend/nodes/threshold_mask.py | 8 +- backend/nodes/value_display.py | 4 +- backend/nodes/view_3d.py | 4 +- backend/server.py | 353 ++++++++++++++++++----------- backend/session_runtime.py | 132 +++++++++++ desktop.py | 15 +- frontend/src/App.jsx | 91 ++++++-- frontend/src/CustomNode.jsx | 65 +++++- frontend/src/FileBrowser.jsx | 103 --------- frontend/src/api.js | 93 ++++++-- frontend/src/nativePicker.js | 118 ++++++++++ frontend/src/styles.css | 28 +++ frontend/vite.config.js | 2 +- tests/test_session_runtime.py | 72 ++++++ 33 files changed, 1042 insertions(+), 551 deletions(-) create mode 100644 backend/execution_context.py create mode 100644 backend/session_runtime.py delete mode 100644 frontend/src/FileBrowser.jsx create mode 100644 frontend/src/nativePicker.js create mode 100644 tests/test_session_runtime.py diff --git a/backend/execution.py b/backend/execution.py index 0d14ed0..c7f74c0 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -32,6 +32,7 @@ from time import perf_counter from typing import Any, Callable from backend.node_registry import NODE_CLASS_MAPPINGS +from backend.execution_context import active_node, execution_callbacks def _is_link(value: Any) -> bool: @@ -85,63 +86,66 @@ class ExecutionEngine: node_outputs: dict[str, tuple] = {} node_output_signatures: dict[str, tuple[str, ...]] = {} - # Inject display callbacks before execution - self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning) + with execution_callbacks( + preview=on_preview, + table=on_table, + mesh=on_mesh, + overlay=on_overlay, + value=on_value, + warning=on_warning, + ): + for node_id in order: + node_def = prompt[node_id] + class_name = node_def["class_type"] - for node_id in order: - node_def = prompt[node_id] - class_name = node_def["class_type"] + if class_name not in NODE_CLASS_MAPPINGS: + raise ValueError(f"Unknown node type: '{class_name}'") - if class_name not in NODE_CLASS_MAPPINGS: - raise ValueError(f"Unknown node type: '{class_name}'") + cls = NODE_CLASS_MAPPINGS[class_name] + raw_inputs = node_def.get("inputs", {}) + input_types = cls.INPUT_TYPES() + inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) + input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures) - cls = NODE_CLASS_MAPPINGS[class_name] - raw_inputs = node_def.get("inputs", {}) - input_types = cls.INPUT_TYPES() - inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) - input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures) + cache_entry = self._get_cached_entry(node_id, class_name, input_signature) + if cache_entry is not None: + result = self._clone_cached_outputs(cache_entry["outputs"]) + elapsed_ms = 0.0 + else: + if on_node_start: + on_node_start(node_id) - # Let display nodes know their node_id so they can tag WS messages - self._set_node_id_on_display(cls, node_id) + instance = cls() + func = getattr(instance, cls.FUNCTION) + start_time = perf_counter() + with active_node(node_id): + result = func(**inputs) + elapsed_ms = (perf_counter() - start_time) * 1000.0 - cache_entry = self._get_cached_entry(node_id, class_name, input_signature) - if cache_entry is not None: - result = self._clone_cached_outputs(cache_entry["outputs"]) - elapsed_ms = 0.0 - else: - if on_node_start: - on_node_start(node_id) + # Nodes must return a tuple; coerce single values just in case + if not isinstance(result, tuple): + result = (result,) - instance = cls() - func = getattr(instance, cls.FUNCTION) - start_time = perf_counter() - result = func(**inputs) - elapsed_ms = (perf_counter() - start_time) * 1000.0 + node_outputs[node_id] = result + output_signatures = tuple(self._fingerprint_value(value) for value in result) + node_output_signatures[node_id] = output_signatures - # Nodes must return a tuple; coerce single values just in case - if not isinstance(result, tuple): - result = (result,) + if cache_entry is None and self._node_cacheable(cls): + self._store_cache_entry( + node_id=node_id, + class_name=class_name, + input_signature=input_signature, + output_signatures=output_signatures, + outputs=self._clone_cached_outputs(result), + ) - node_outputs[node_id] = result - output_signatures = tuple(self._fingerprint_value(value) for value in result) - node_output_signatures[node_id] = output_signatures + # Auto-preview: broadcast a thumbnail for any DATA_FIELD, + # IMAGE, or table-like output so every node shows its result. + if on_preview or on_table: + self._auto_preview(cls, node_id, result, on_preview, on_table, inputs) - if cache_entry is None and self._node_cacheable(cls): - self._store_cache_entry( - node_id=node_id, - class_name=class_name, - input_signature=input_signature, - output_signatures=output_signatures, - outputs=self._clone_cached_outputs(result), - ) - - # Auto-preview: broadcast a thumbnail for any DATA_FIELD, - # IMAGE, or table-like output so every node shows its result. - if on_preview or on_table: - self._auto_preview(cls, node_id, result, on_preview, on_table, inputs) - - if on_node_done: - on_node_done(node_id, elapsed_ms) + if on_node_done: + on_node_done(node_id, elapsed_ms) return node_outputs @@ -421,88 +425,6 @@ class ExecutionEngine: return deepcopy(value) return value - def _inject_display_callbacks( - self, - on_preview: Callable | None, - on_table: Callable | None, - on_mesh: Callable | None = None, - on_overlay: Callable | None = None, - on_value: Callable | None = None, - on_warning: Callable | None = None, - ) -> None: - """Wire up broadcast callbacks on display node classes.""" - from backend.nodes.preview_image import PreviewImage - from backend.nodes.print_table import PrintTable - from backend.nodes.view_3d import View3D - from backend.nodes.annotations import Annotations - from backend.nodes.value_display import ValueDisplay - from backend.nodes.markup import Markup - from backend.nodes.cross_section import CrossSection - from backend.nodes.cursors import Cursors - from backend.nodes.stats import Stats - from backend.nodes.histogram import Histogram - from backend.nodes.crop_resize_field import CropResizeField - from backend.nodes.rotate_field import RotateField - from backend.nodes.threshold_mask import ThresholdMask - from backend.nodes.mask_morphology import MaskMorphology - from backend.nodes.mask_invert import MaskInvert - from backend.nodes.mask_combine import MaskCombine - from backend.nodes.draw_mask import DrawMask - from backend.nodes.save import Save - from backend.nodes.save_image import SaveImage - from backend.nodes.image import Image - from backend.nodes.image_demo import ImageDemo - - PreviewImage._broadcast_fn = on_preview - ThresholdMask._broadcast_fn = on_preview - MaskMorphology._broadcast_fn = on_preview - MaskInvert._broadcast_fn = on_preview - MaskCombine._broadcast_fn = on_preview - DrawMask._broadcast_overlay_fn = on_overlay - View3D._broadcast_mesh_fn = on_mesh - Annotations._broadcast_warning_fn = on_warning - PrintTable._broadcast_table_fn = on_table - ValueDisplay._broadcast_value_fn = on_value - Stats._broadcast_value_fn = on_value - Histogram._broadcast_overlay_fn = on_overlay - CrossSection._broadcast_overlay_fn = on_overlay - Cursors._broadcast_overlay_fn = on_overlay - CropResizeField._broadcast_overlay_fn = on_overlay - RotateField._broadcast_warning_fn = on_warning - Markup._broadcast_overlay_fn = on_overlay - Image._broadcast_warning_fn = on_warning - ImageDemo._broadcast_warning_fn = on_warning - Save._broadcast_warning_fn = on_warning - SaveImage._broadcast_warning_fn = on_warning - - def _set_node_id_on_display(self, cls: type, node_id: str) -> None: - """Inform display nodes of their current node_id for WS tagging.""" - from backend.nodes.preview_image import PreviewImage - from backend.nodes.print_table import PrintTable - from backend.nodes.view_3d import View3D - from backend.nodes.annotations import Annotations - from backend.nodes.value_display import ValueDisplay - from backend.nodes.markup import Markup - from backend.nodes.cross_section import CrossSection - from backend.nodes.cursors import Cursors - from backend.nodes.stats import Stats - from backend.nodes.histogram import Histogram - from backend.nodes.crop_resize_field import CropResizeField - from backend.nodes.rotate_field import RotateField - from backend.nodes.threshold_mask import ThresholdMask - from backend.nodes.mask_morphology import MaskMorphology - from backend.nodes.mask_invert import MaskInvert - from backend.nodes.mask_combine import MaskCombine - from backend.nodes.draw_mask import DrawMask - from backend.nodes.image import Image - from backend.nodes.image_demo import ImageDemo - from backend.nodes.save import Save - from backend.nodes.save_image import SaveImage - if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup, - ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, - Image, ImageDemo, Save, SaveImage): - cls._current_node_id = node_id - def _auto_preview( self, cls: type, diff --git a/backend/execution_context.py b/backend/execution_context.py new file mode 100644 index 0000000..ca19cfb --- /dev/null +++ b/backend/execution_context.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Callable + +Callback = Callable[[str, Any], None] + +_callbacks_var: ContextVar[dict[str, Callback | None]] = ContextVar( + "argonode_execution_callbacks", + default={}, +) +_node_id_var: ContextVar[str | None] = ContextVar("argonode_execution_node_id", default=None) + + +@contextmanager +def execution_callbacks( + *, + preview: Callback | None = None, + table: Callback | None = None, + mesh: Callback | None = None, + overlay: Callback | None = None, + value: Callback | None = None, + warning: Callback | None = None, +): + token = _callbacks_var.set({ + "preview": preview, + "table": table, + "mesh": mesh, + "overlay": overlay, + "value": value, + "warning": warning, + }) + try: + yield + finally: + _callbacks_var.reset(token) + + +@contextmanager +def active_node(node_id: str): + token = _node_id_var.set(str(node_id)) + try: + yield + finally: + _node_id_var.reset(token) + + +def current_node_id() -> str | None: + return _node_id_var.get() + + +def _emit(kind: str, payload: Any) -> None: + callbacks = _callbacks_var.get() + callback = callbacks.get(kind) + node_id = current_node_id() + if callback is not None and node_id: + callback(node_id, payload) + + +def emit_preview(payload: Any) -> None: + _emit("preview", payload) + + +def emit_table(rows: list) -> None: + _emit("table", rows) + + +def emit_mesh(mesh: dict) -> None: + _emit("mesh", mesh) + + +def emit_overlay(overlay: dict) -> None: + _emit("overlay", overlay) + + +def emit_value(payload: Any) -> None: + _emit("value", payload) + + +def emit_warning(message: str) -> None: + _emit("warning", message) diff --git a/backend/nodes/annotations.py b/backend/nodes/annotations.py index 8e74c3e..4deb199 100644 --- a/backend/nodes/annotations.py +++ b/backend/nodes/annotations.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import ( COLORMAPS, DataField, @@ -120,7 +121,4 @@ class Annotations: return (ImageData(annotated, metadata={"annotation_context": context}),) def _send_warning(self, message: str): - fn = Annotations._broadcast_warning_fn - nid = Annotations._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) diff --git a/backend/nodes/crop_resize_field.py b/backend/nodes/crop_resize_field.py index 3c69a21..d9d85b8 100644 --- a/backend/nodes/crop_resize_field.py +++ b/backend/nodes/crop_resize_field.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, datafield_to_uint8, encode_preview @@ -61,20 +62,16 @@ class CropResizeField: x2 = float(np.clip(x2, 0.0, 1.0)) y2 = float(np.clip(y2, 0.0, 1.0)) - if CropResizeField._broadcast_overlay_fn is not None: - CropResizeField._broadcast_overlay_fn( - CropResizeField._current_node_id, - { - "kind": "crop_box", - "image": encode_preview(datafield_to_uint8(field, field.colormap)), - "x1": x1, - "y1": y1, - "x2": x2, - "y2": y2, - "a_locked": corner_a is not None, - "b_locked": corner_b is not None, - }, - ) + emit_overlay({ + "kind": "crop_box", + "image": encode_preview(datafield_to_uint8(field, field.colormap)), + "x1": x1, + "y1": y1, + "x2": x2, + "y2": y2, + "a_locked": corner_a is not None, + "b_locked": corner_b is not None, + }) left = min(x1, x2) right = max(x1, x2) diff --git a/backend/nodes/cross_section.py b/backend/nodes/cross_section.py index e63a74c..b3f28cd 100644 --- a/backend/nodes/cross_section.py +++ b/backend/nodes/cross_section.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview from backend.nodes.helpers import _extend_to_edges @@ -73,19 +74,14 @@ class CrossSection: profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest") - if CrossSection._broadcast_overlay_fn is not None: - image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) - - CrossSection._broadcast_overlay_fn( - CrossSection._current_node_id, - { - "image": image_uri, - "x1": marker_x1, "y1": marker_y1, - "x2": marker_x2, "y2": marker_y2, - "a_locked": marker_pair is not None, - "b_locked": marker_pair is not None, - }, - ) + image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) + emit_overlay({ + "image": image_uri, + "x1": marker_x1, "y1": marker_y1, + "x2": marker_x2, "y2": marker_y2, + "a_locked": marker_pair is not None, + "b_locked": marker_pair is not None, + }) dx_real = (x2 - x1) * field.xreal dy_real = (y2 - y1) * field.yreal diff --git a/backend/nodes/cursors.py b/backend/nodes/cursors.py index c4655a1..5488d70 100644 --- a/backend/nodes/cursors.py +++ b/backend/nodes/cursors.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview @@ -87,22 +88,18 @@ class Cursors: xa, ya = float(x[idx_a]), float(y[idx_a]) xb, yb = float(x[idx_b]), float(y[idx_b]) - if Cursors._broadcast_overlay_fn is not None: - Cursors._broadcast_overlay_fn( - Cursors._current_node_id, - { - "kind": "line_plot", - "section_title": "Cursors", - "line": y.tolist(), - "x_axis": x.tolist(), - "x1": x1, - "x2": x2, - "y1": float(y1), - "y2": float(y2), - "a_locked": locked, - "b_locked": locked, - }, - ) + emit_overlay({ + "kind": "line_plot", + "section_title": "Cursors", + "line": y.tolist(), + "x_axis": x.tolist(), + "x1": x1, + "x2": x2, + "y1": float(y1), + "y2": float(y2), + "a_locked": locked, + "b_locked": locked, + }) table = MeasureTable([ {"quantity": "A x", "value": xa, "unit": x_unit}, @@ -143,21 +140,17 @@ class Cursors: bx = float(field.xoff + x2 * field.xreal) by = float(field.yoff + y2 * field.yreal) - if Cursors._broadcast_overlay_fn is not None: - Cursors._broadcast_overlay_fn( - Cursors._current_node_id, - { - "kind": "cursor_points", - "section_title": "Cursors", - "image": encode_preview(render_datafield_preview(field, field.colormap)), - "x1": x1, - "y1": y1, - "x2": x2, - "y2": y2, - "a_locked": locked, - "b_locked": locked, - }, - ) + emit_overlay({ + "kind": "cursor_points", + "section_title": "Cursors", + "image": encode_preview(render_datafield_preview(field, field.colormap)), + "x1": x1, + "y1": y1, + "x2": x2, + "y2": y2, + "a_locked": locked, + "b_locked": locked, + }) table = MeasureTable([ {"quantity": "A x", "value": ax, "unit": field.si_unit_xy}, diff --git a/backend/nodes/draw_mask.py b/backend/nodes/draw_mask.py index e39a888..0a8a9d9 100644 --- a/backend/nodes/draw_mask.py +++ b/backend/nodes/draw_mask.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, datafield_to_uint8, encode_preview from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask @@ -40,17 +41,13 @@ class DrawMask: if invert: mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) - if DrawMask._broadcast_overlay_fn is not None: - DrawMask._broadcast_overlay_fn( - DrawMask._current_node_id, - { - "kind": "mask_paint", - "section_title": "Mask", - "image": encode_preview(datafield_to_uint8(field, "gray")), - "image_width": field.xres, - "image_height": field.yres, - "invert": bool(invert), - }, - ) + emit_overlay({ + "kind": "mask_paint", + "section_title": "Mask", + "image": encode_preview(datafield_to_uint8(field, "gray")), + "image_width": field.xres, + "image_height": field.yres, + "invert": bool(invert), + }) return (mask,) diff --git a/backend/nodes/histogram.py b/backend/nodes/histogram.py index 4888aa4..9199f33 100644 --- a/backend/nodes/histogram.py +++ b/backend/nodes/histogram.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import DataField, MeasureTable @@ -72,22 +73,18 @@ class Histogram: yb = float(counts[idx_b]) if len(counts) else 0.0 count_unit = "count" if y_scale == "linear" else "log10(1+count)" - if Histogram._broadcast_overlay_fn is not None: - Histogram._broadcast_overlay_fn( - Histogram._current_node_id, - { - "kind": "line_plot", - "section_title": "Histogram", - "line": counts.tolist(), - "x_axis": bin_centers.astype(np.float64).tolist(), - "x1": float(np.clip(x1, 0.0, 1.0)), - "x2": float(np.clip(x2, 0.0, 1.0)), - "y1": float(y1), - "y2": float(y2), - "a_locked": False, - "b_locked": False, - }, - ) + emit_overlay({ + "kind": "line_plot", + "section_title": "Histogram", + "line": counts.tolist(), + "x_axis": bin_centers.astype(np.float64).tolist(), + "x1": float(np.clip(x1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y1": float(y1), + "y2": float(y2), + "a_locked": False, + "b_locked": False, + }) table = MeasureTable([ {"quantity": "A position", "value": xa, "unit": field.si_unit_z}, diff --git a/backend/nodes/image.py b/backend/nodes/image.py index 2d46d34..8aa068e 100644 --- a/backend/nodes/image.py +++ b/backend/nodes/image.py @@ -4,6 +4,7 @@ import numpy as np from pathlib import Path from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import COLORMAPS, DataField, resolve_colormap_input from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader @@ -66,10 +67,7 @@ class Image: return fields def _send_warning(self, message: str): - fn = Image._broadcast_warning_fn - nid = Image._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) @staticmethod @lru_cache(maxsize=32) diff --git a/backend/nodes/markup.py b/backend/nodes/markup.py index 3190946..f85ee25 100644 --- a/backend/nodes/markup.py +++ b/backend/nodes/markup.py @@ -1,5 +1,6 @@ from __future__ import annotations from backend.node_registry import register_node +from backend.execution_context import emit_overlay from backend.data_types import ( DataField, ImageData, @@ -70,17 +71,13 @@ class Markup: metadata=image_metadata(input), ) - if Markup._broadcast_overlay_fn is not None: - Markup._broadcast_overlay_fn( - Markup._current_node_id, - { - "kind": "markup", - "section_title": "Markup", - "image": encode_preview(preview_base), - "shape": str(shape), - "stroke_color": _normalize_markup_color(stroke_color), - "stroke_width": max(1, int(stroke_width)), - }, - ) + emit_overlay({ + "kind": "markup", + "section_title": "Markup", + "image": encode_preview(preview_base), + "shape": str(shape), + "stroke_color": _normalize_markup_color(stroke_color), + "stroke_width": max(1, int(stroke_width)), + }) return (out,) diff --git a/backend/nodes/mask_combine.py b/backend/nodes/mask_combine.py index 28d7fae..ab25489 100644 --- a/backend/nodes/mask_combine.py +++ b/backend/nodes/mask_combine.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay @@ -53,10 +54,8 @@ class MaskCombine: out = result.astype(np.uint8) * 255 - if field is not None and MaskCombine._broadcast_fn is not None: + if field is not None: overlay = _mask_overlay(field, out) - MaskCombine._broadcast_fn( - MaskCombine._current_node_id, encode_preview(overlay), - ) + emit_preview(encode_preview(overlay)) return (out,) diff --git a/backend/nodes/mask_invert.py b/backend/nodes/mask_invert.py index 1cf3529..4b90e0b 100644 --- a/backend/nodes/mask_invert.py +++ b/backend/nodes/mask_invert.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay @@ -32,10 +33,8 @@ class MaskInvert: def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: out = np.where(mask > 127, np.uint8(0), np.uint8(255)) - if field is not None and MaskInvert._broadcast_fn is not None: + if field is not None: overlay = _mask_overlay(field, out) - MaskInvert._broadcast_fn( - MaskInvert._current_node_id, encode_preview(overlay), - ) + emit_preview(encode_preview(overlay)) return (out,) diff --git a/backend/nodes/mask_morphology.py b/backend/nodes/mask_morphology.py index e9603f0..351cb34 100644 --- a/backend/nodes/mask_morphology.py +++ b/backend/nodes/mask_morphology.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay, _mask_structure @@ -62,10 +63,8 @@ class MaskMorphology: out = result.astype(np.uint8) * 255 - if field is not None and MaskMorphology._broadcast_fn is not None: + if field is not None: overlay = _mask_overlay(field, out) - MaskMorphology._broadcast_fn( - MaskMorphology._current_node_id, encode_preview(overlay), - ) + emit_preview(encode_preview(overlay)) return (out,) diff --git a/backend/nodes/preview_image.py b/backend/nodes/preview_image.py index a955458..ad4d268 100644 --- a/backend/nodes/preview_image.py +++ b/backend/nodes/preview_image.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import ( COLORMAPS, colormap_to_uint8, @@ -68,7 +69,6 @@ class PreviewImage: data_uri = encode_preview(arr_u8) - if PreviewImage._broadcast_fn is not None: - PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri) + emit_preview(data_uri) return () diff --git a/backend/nodes/print_table.py b/backend/nodes/print_table.py index f53f93b..72f1c44 100644 --- a/backend/nodes/print_table.py +++ b/backend/nodes/print_table.py @@ -1,5 +1,6 @@ from __future__ import annotations from backend.node_registry import register_node +from backend.execution_context import emit_table @register_node(display_name="Print Table") @@ -22,6 +23,5 @@ class PrintTable: _current_node_id: str = "" def print_table(self, table: list) -> tuple: - if PrintTable._broadcast_table_fn is not None: - PrintTable._broadcast_table_fn(PrintTable._current_node_id, table) + emit_table(table) return () diff --git a/backend/nodes/rotate_field.py b/backend/nodes/rotate_field.py index 2c2a601..6fd4309 100644 --- a/backend/nodes/rotate_field.py +++ b/backend/nodes/rotate_field.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import DataField @@ -84,10 +85,7 @@ class RotateField: return (result,) def _send_warning(self, message: str): - fn = RotateField._broadcast_warning_fn - nid = RotateField._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) @staticmethod def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]: diff --git a/backend/nodes/save.py b/backend/nodes/save.py index c0536b8..b7fce1c 100644 --- a/backend/nodes/save.py +++ b/backend/nodes/save.py @@ -7,6 +7,7 @@ from pathlib import Path import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8 @@ -255,7 +256,4 @@ class Save: path.write_text("\n".join(lines) + "\n", encoding="utf-8") def _send_warning(self, message: str): - fn = Save._broadcast_warning_fn - nid = Save._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) diff --git a/backend/nodes/save_image.py b/backend/nodes/save_image.py index f122c5e..f9365c1 100644 --- a/backend/nodes/save_image.py +++ b/backend/nodes/save_image.py @@ -4,6 +4,7 @@ import numpy as np from pathlib import Path from backend.node_registry import register_node +from backend.execution_context import emit_warning from backend.data_types import DataField, image_to_uint8 from backend.nodes.helpers import _MAX_SAVE_FIELDS @@ -174,9 +175,6 @@ class SaveImage: raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") def _send_warning(self, message: str): - fn = SaveImage._broadcast_warning_fn - nid = SaveImage._current_node_id - if fn and nid: - fn(nid, message) + emit_warning(message) return () diff --git a/backend/nodes/stats.py b/backend/nodes/stats.py index 59ea289..bcbf780 100644 --- a/backend/nodes/stats.py +++ b/backend/nodes/stats.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_value from backend.data_types import DataField, LineData, MeasureTable from backend.nodes.helpers import ( LINE_OPS, @@ -71,11 +72,9 @@ class Stats: op_entry = ops[operation] fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry result = fn(values) - if Stats._broadcast_value_fn is not None: - Stats._broadcast_value_fn( - Stats._current_node_id, - _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), - ) + emit_value( + _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), + ) return (result,) def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str: diff --git a/backend/nodes/threshold_mask.py b/backend/nodes/threshold_mask.py index 56bbac2..108cc97 100644 --- a/backend/nodes/threshold_mask.py +++ b/backend/nodes/threshold_mask.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_preview from backend.data_types import DataField, encode_preview from backend.nodes.helpers import _mask_overlay @@ -52,10 +53,7 @@ class ThresholdMask: else: mask = (data < t).astype(np.uint8) * 255 - if ThresholdMask._broadcast_fn is not None: - overlay = _mask_overlay(field, mask) - ThresholdMask._broadcast_fn( - ThresholdMask._current_node_id, encode_preview(overlay), - ) + overlay = _mask_overlay(field, mask) + emit_preview(encode_preview(overlay)) return (mask,) diff --git a/backend/nodes/value_display.py b/backend/nodes/value_display.py index e4a58e8..e21cabf 100644 --- a/backend/nodes/value_display.py +++ b/backend/nodes/value_display.py @@ -1,5 +1,6 @@ from __future__ import annotations from backend.node_registry import register_node +from backend.execution_context import emit_value from backend.data_types import MeasureTable from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload @@ -38,6 +39,5 @@ class ValueDisplay: unit = row.get("unit", "") if isinstance(row.get("unit"), str) else "" else: numeric = float(value) - if ValueDisplay._broadcast_value_fn is not None: - ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit)) + emit_value(_scalar_payload(numeric, unit)) return (numeric,) diff --git a/backend/nodes/view_3d.py b/backend/nodes/view_3d.py index 8bfcef1..3b1bec7 100644 --- a/backend/nodes/view_3d.py +++ b/backend/nodes/view_3d.py @@ -3,6 +3,7 @@ import base64 import io import numpy as np from backend.node_registry import register_node +from backend.execution_context import emit_mesh from backend.data_types import ( COLORMAPS, DataField, @@ -211,8 +212,7 @@ class View3D: "y_range": [float(field.yoff), float(field.yoff + field.yreal)], } - if View3D._broadcast_mesh_fn is not None: - View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data) + emit_mesh(mesh_data) annotation_context = _annotation_context_from_field(color_field, resolved_colormap) annotation_context["xreal"] = float(field.xreal) diff --git a/backend/server.py b/backend/server.py index 4426c3d..bd702d7 100644 --- a/backend/server.py +++ b/backend/server.py @@ -6,7 +6,11 @@ Routes GET / → serve frontend/index.html GET /static/{path} → serve frontend JS/CSS GET /nodes → JSON dict of all registered node definitions -POST /upload → multipart file upload to input/ +GET /files → list files in the current session upload workspace +GET /folder-files → list compatible files in a picked folder +GET /channels → inspect channels for a picked file +POST /upload → multipart file upload to the current session workspace +POST /upload-folder → create a folder in the current session workspace POST /prompt → submit a workflow; returns {prompt_id} GET /ws → WebSocket upgrade @@ -15,7 +19,7 @@ WebSocket message types sent to clients {"type": "execution_start", "data": {"prompt_id": "..."}} {"type": "executing", "data": {"node": "...", "prompt_id": "..."}} {"type": "preview", "data": {"node_id": "...", "image": "data:..."}} -{"type": "table", "data": {"node_id": "...", "rows": [...]}} +{"type": "table", "data": {"node_id": "...", "rows": [...]} } {"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}} {"type": "node_timing", "data": {"node_id": "...", "elapsed_ms": 12.34}} {"type": "execution_error", "data": {"node_id": "...", "message": "..."}} @@ -23,39 +27,43 @@ WebSocket message types sent to clients """ from __future__ import annotations + import asyncio import json import logging import sys +from collections import defaultdict +from copy import deepcopy from pathlib import Path from aiohttp import web, WSMsgType + from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready -from backend.runtime_paths import ( - ensure_runtime_dirs, - frontend_dir, - frontend_dist_dir, - input_dir, - output_dir, - project_root, +from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, project_root +from backend.session_runtime import ( + PATH_INPUT_TYPES, + SESSION_HEADER, + SESSION_QUERY, + ensure_session_runtime_dirs, + normalize_relative_upload_path, + resolve_client_path, + server_path_to_client_path, + session_input_dir, + session_upload_uri, + validate_session_id, ) log = logging.getLogger(__name__) FRONTEND_DIR = frontend_dir() DIST_DIR = frontend_dist_dir() -INPUT_DIR = input_dir() -OUTPUT_DIR = output_dir() PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" -# --------------------------------------------------------------------------- -# JSON helper — numpy scalars are not serialisable by default -# --------------------------------------------------------------------------- - class _SafeEncoder(json.JSONEncoder): def default(self, obj): import numpy as np + if isinstance(obj, (np.integer,)): return int(obj) if isinstance(obj, (np.floating,)): @@ -81,45 +89,115 @@ def save_png_bytes(target_path: str, payload: bytes) -> Path: return path -# --------------------------------------------------------------------------- -# Application factory -# --------------------------------------------------------------------------- - -def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: - # Import nodes to trigger registration decorators +def create_app( + loop: asyncio.AbstractEventLoop, + *, + allow_local_filesystem: bool = False, +) -> web.Application: import backend.nodes # noqa: F401 - from backend.node_registry import get_all_node_info from backend.execution import ExecutionEngine, new_prompt_id + from backend.node_registry import NODE_CLASS_MAPPINGS, get_all_node_info ensure_runtime_dirs() - engine = ExecutionEngine() - websockets: set[web.WebSocketResponse] = set() + session_engines: dict[str, ExecutionEngine] = {} + session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set) - # ------------------------------------------------------------------ - # WebSocket broadcast helpers - # ------------------------------------------------------------------ + def _is_link(value) -> bool: + return ( + isinstance(value, (list, tuple)) + and len(value) == 2 + and isinstance(value[0], str) + and isinstance(value[1], int) + ) - def broadcast(msg: dict) -> None: - """Schedule a broadcast to all connected WebSocket clients.""" + def require_session_id(request: web.Request) -> str: + raw_session = request.headers.get(SESSION_HEADER) or request.query.get(SESSION_QUERY) + if not raw_session: + if allow_local_filesystem: + raw_session = "desktop-local-session" + else: + raise web.HTTPBadRequest(reason="Missing session id") + + try: + session_id = validate_session_id(raw_session) + except ValueError as exc: + raise web.HTTPBadRequest(reason=str(exc)) from exc + + ensure_session_runtime_dirs(session_id) + return session_id + + def get_session_engine(session_id: str) -> ExecutionEngine: + engine = session_engines.get(session_id) + if engine is None: + engine = ExecutionEngine() + session_engines[session_id] = engine + return engine + + def resolve_request_path(session_id: str, raw_value: str) -> Path: + try: + return resolve_client_path( + raw_value, + session_id=session_id, + allow_local_filesystem=allow_local_filesystem, + ) + except PermissionError as exc: + raise web.HTTPForbidden(reason=str(exc)) from exc + except ValueError as exc: + raise web.HTTPBadRequest(reason=str(exc)) from exc + + def rewrite_prompt_paths(prompt: dict, session_id: str) -> dict: + normalized = deepcopy(prompt) + for node_def in normalized.values(): + class_name = node_def.get("class_type") + cls = NODE_CLASS_MAPPINGS.get(class_name) + if cls is None: + continue + + input_types = cls.INPUT_TYPES() + specs = {} + specs.update(input_types.get("required", {})) + specs.update(input_types.get("optional", {})) + + inputs = node_def.get("inputs", {}) + if not isinstance(inputs, dict): + continue + + for input_name, raw_value in list(inputs.items()): + if _is_link(raw_value) or not isinstance(raw_value, str): + continue + if not raw_value.strip(): + continue + + spec = specs.get(input_name) + input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec + if not isinstance(input_type, str): + continue + if input_type not in PATH_INPUT_TYPES: + continue + + inputs[input_name] = str(resolve_request_path(session_id, raw_value)) + return normalized + + def broadcast(session_id: str, msg: dict) -> None: payload = _dumps(msg) - for ws in list(websockets): + for ws in list(session_websockets.get(session_id, ())): if not ws.closed: asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop) - def on_preview(node_id: str, data_uri: str) -> None: - broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}}) + def on_preview(session_id: str, node_id: str, data_uri: str) -> None: + broadcast(session_id, {"type": "preview", "data": {"node_id": node_id, "image": data_uri}}) - def on_table(node_id: str, rows: list) -> None: - broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}}) + def on_table(session_id: str, node_id: str, rows: list) -> None: + broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": rows}}) - def on_mesh(node_id: str, mesh_data: dict) -> None: - broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}}) + def on_mesh(session_id: str, node_id: str, mesh_data: dict) -> None: + broadcast(session_id, {"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}}) - def on_overlay(node_id: str, overlay_data) -> None: - broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) + def on_overlay(session_id: str, node_id: str, overlay_data) -> None: + broadcast(session_id, {"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) - def on_value(node_id: str, payload) -> None: + def on_value(session_id: str, node_id: str, payload) -> None: if isinstance(payload, dict): value = payload.get("value") unit = payload.get("unit", "") @@ -130,14 +208,10 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: data = {"node_id": node_id, "value": value} if isinstance(unit, str) and unit.strip(): data["unit"] = unit.strip() - broadcast({"type": "scalar", "data": data}) + broadcast(session_id, {"type": "scalar", "data": data}) - def on_warning(node_id: str, message: str) -> None: - broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) - - # ------------------------------------------------------------------ - # Route handlers - # ------------------------------------------------------------------ + def on_warning(session_id: str, node_id: str, message: str) -> None: + broadcast(session_id, {"type": "node_warning", "data": {"node_id": node_id, "message": message}}) async def index(request: web.Request) -> web.Response: if not getattr(sys, "frozen", False): @@ -167,88 +241,96 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) async def get_nodes(request: web.Request) -> web.Response: - info = get_all_node_info() return web.Response( - text=_dumps(info), + text=_dumps(get_all_node_info()), content_type="application/json", ) async def list_files(request: web.Request) -> web.Response: - """List files in the input/ directory for the file picker widget.""" + session_id = require_session_id(request) + input_path = session_input_dir(session_id) files = sorted( - f.name for f in INPUT_DIR.iterdir() - if f.is_file() and not f.name.startswith(".") - ) if INPUT_DIR.exists() else [] + server_path_to_client_path(entry, session_id) + for entry in input_path.iterdir() + if entry.is_file() and not entry.name.startswith(".") + ) if input_path.exists() else [] return web.Response(text=_dumps(files), content_type="application/json") - async def browse_dir(request: web.Request) -> web.Response: - """ - Server-side directory browser for local file picking. - GET /browse?dir=/some/path → {parent, dirs[], files[]} - """ - dir_path = request.query.get("dir", str(Path.home())) - p = Path(dir_path).expanduser().resolve() - - if not p.is_dir(): - raise web.HTTPBadRequest(reason=f"Not a directory: {p}") - - dirs = [] - files = [] - try: - for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()): - if entry.name.startswith("."): - continue - if entry.is_dir(): - dirs.append(entry.name) - elif entry.is_file(): - files.append(entry.name) - except PermissionError: - pass - + async def create_upload_folder(request: web.Request) -> web.Response: + session_id = require_session_id(request) + body = await request.json() + relative_path = normalize_relative_upload_path(body.get("path", "")) + target = session_input_dir(session_id) / Path(relative_path.as_posix()) + target.mkdir(parents=True, exist_ok=True) return web.Response( - text=_dumps({ - "path": str(p), - "parent": str(p.parent) if p.parent != p else None, - "dirs": dirs, - "files": files, - }), + text=_dumps({"path": session_upload_uri(relative_path)}), content_type="application/json", ) async def get_folder_files(request: web.Request) -> web.Response: - folder_path = request.query.get("folder", "") from backend.nodes.helpers import list_folder_paths - loop = asyncio.get_running_loop() - entries = await loop.run_in_executor(None, list_folder_paths, folder_path) - return web.Response(text=_dumps(entries), content_type="application/json") + + session_id = require_session_id(request) + folder_path = request.query.get("folder", "") + if not folder_path: + return web.Response(text=_dumps([]), content_type="application/json") + + resolved_path = resolve_request_path(session_id, folder_path) + running_loop = asyncio.get_running_loop() + entries = await running_loop.run_in_executor(None, list_folder_paths, str(resolved_path)) + + payload = [] + for entry in entries: + mapped = dict(entry) + if "path" in mapped: + mapped["path"] = server_path_to_client_path(mapped["path"], session_id) + payload.append(mapped) + return web.Response(text=_dumps(payload), content_type="application/json") async def upload_file(request: web.Request) -> web.Response: + session_id = require_session_id(request) reader = await request.multipart() - field = await reader.next() - if field is None or field.name != "file": + relative_path = None + filename = "" + file_bytes = None + + while True: + field = await reader.next() + if field is None: + break + if field.name == "relative_path": + relative_path = await field.text() + continue + if field.name == "file": + filename = Path(field.filename or "upload.bin").name + chunks = [] + while True: + chunk = await field.read_chunk(65536) + if not chunk: + break + chunks.append(chunk) + file_bytes = b"".join(chunks) + + if file_bytes is None: raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body") - filename = Path(field.filename).name # strip any path traversal - dest = INPUT_DIR / filename - with open(dest, "wb") as f: - while True: - chunk = await field.read_chunk(65536) - if not chunk: - break - f.write(chunk) + relative = normalize_relative_upload_path(relative_path or filename) + dest = session_input_dir(session_id) / Path(relative.as_posix()) + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(file_bytes) - return web.Response(text=_dumps({"filename": filename}), content_type="application/json") + return web.Response( + text=_dumps({"filename": filename, "path": session_upload_uri(relative)}), + content_type="application/json", + ) async def download_file(request: web.Request) -> web.Response: - """Accept a blob POST and return it with Content-Disposition: attachment.""" body = await request.read() filename = request.query.get("filename", "workflow.png") return web.Response( body=body, content_type="application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="{filename}"', - }, + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, ) async def save_workflow_png(request: web.Request) -> web.Response: @@ -266,34 +348,39 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) async def get_channels(request: web.Request) -> web.Response: - """Return available channels for a given file path.""" from backend.nodes.helpers import list_channels + + session_id = require_session_id(request) filepath = request.query.get("file", "") if not filepath: return web.Response( text=_dumps([{"name": "field", "type": "DATA_FIELD"}]), content_type="application/json", ) - channels = await loop.run_in_executor(None, list_channels, filepath) + + resolved_path = resolve_request_path(session_id, filepath) + channels = await loop.run_in_executor(None, list_channels, str(resolved_path)) return web.Response(text=_dumps(channels), content_type="application/json") async def submit_prompt(request: web.Request) -> web.Response: + session_id = require_session_id(request) body = await request.json() prompt = body.get("prompt") if not isinstance(prompt, dict) or not prompt: raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict") + normalized_prompt = rewrite_prompt_paths(prompt, session_id) prompt_id = new_prompt_id() + engine = get_session_engine(session_id) - # Run execution in a thread pool so scipy doesn't block the event loop async def run(): - broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}}) + broadcast(session_id, {"type": "execution_start", "data": {"prompt_id": prompt_id}}) def on_start(node_id: str) -> None: - broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}}) + broadcast(session_id, {"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}}) def on_done(node_id: str, elapsed_ms: float) -> None: - broadcast({ + broadcast(session_id, { "type": "node_timing", "data": {"node_id": node_id, "elapsed_ms": elapsed_ms}, }) @@ -302,21 +389,21 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: await loop.run_in_executor( None, lambda: engine.execute( - prompt, + normalized_prompt, on_node_start=on_start, on_node_done=on_done, - on_preview=on_preview, - on_table=on_table, - on_mesh=on_mesh, - on_overlay=on_overlay, - on_value=on_value, - on_warning=on_warning, + on_preview=lambda node_id, payload: on_preview(session_id, node_id, payload), + on_table=lambda node_id, rows: on_table(session_id, node_id, rows), + on_mesh=lambda node_id, mesh_data: on_mesh(session_id, node_id, mesh_data), + on_overlay=lambda node_id, overlay_data: on_overlay(session_id, node_id, overlay_data), + on_value=lambda node_id, payload: on_value(session_id, node_id, payload), + on_warning=lambda node_id, message: on_warning(session_id, node_id, message), ), ) - broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}}) + broadcast(session_id, {"type": "execution_complete", "data": {"prompt_id": prompt_id}}) except Exception as exc: log.exception("Execution error") - broadcast({ + broadcast(session_id, { "type": "execution_error", "data": {"node_id": "", "message": str(exc)}, }) @@ -328,32 +415,40 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + session_id = require_session_id(request) ws = web.WebSocketResponse() await ws.prepare(request) - websockets.add(ws) - log.info("WebSocket client connected (%d total)", len(websockets)) + session_websockets[session_id].add(ws) + log.info( + "WebSocket client connected for session %s (%d total in session)", + session_id, + len(session_websockets[session_id]), + ) try: async for msg in ws: if msg.type == WSMsgType.TEXT: - pass # clients don't need to send anything currently + pass elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE): break finally: - websockets.discard(ws) - log.info("WebSocket client disconnected (%d total)", len(websockets)) + session_websockets[session_id].discard(ws) + if not session_websockets[session_id]: + session_websockets.pop(session_id, None) + log.info( + "WebSocket client disconnected for session %s (%d remaining in session)", + session_id, + len(session_websockets.get(session_id, ())), + ) return ws - # ------------------------------------------------------------------ - # App assembly - # ------------------------------------------------------------------ - app = web.Application() + app["allow_local_filesystem"] = allow_local_filesystem app.router.add_get("/", index) app.router.add_get("/nodes", get_nodes) app.router.add_get("/files", list_files) - app.router.add_get("/browse", browse_dir) app.router.add_get("/folder-files", get_folder_files) + app.router.add_post("/upload-folder", create_upload_folder) app.router.add_post("/upload", upload_file) app.router.add_post("/download", download_file) app.router.add_post("/save-workflow-png", save_workflow_png) @@ -361,26 +456,24 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: app.router.add_post("/prompt", submit_prompt) app.router.add_get("/ws", websocket_handler) - # Serve frontend static files (Vite build or raw) if (DIST_DIR / "assets").exists(): app.router.add_static("/assets", DIST_DIR / "assets") if FRONTEND_DIR.exists(): app.router.add_static("/static", FRONTEND_DIR) - # CORS — allow any origin (local dev only) async def _cors_middleware(app_, handler): async def middleware(request): if request.method == "OPTIONS": return web.Response(headers={ "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Allow-Headers": f"Content-Type, {SESSION_HEADER}", }) response = await handler(request) response.headers["Access-Control-Allow-Origin"] = "*" return response + return middleware app.middlewares.append(_cors_middleware) - return app diff --git a/backend/session_runtime.py b/backend/session_runtime.py new file mode 100644 index 0000000..803ce54 --- /dev/null +++ b/backend/session_runtime.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import re +from pathlib import Path, PurePosixPath + +from backend.runtime_paths import app_data_dir, demo_dir + +SESSION_HEADER = "X-Argonode-Session" +SESSION_QUERY = "session" +SESSION_URI_PREFIX = "session://uploads/" + +PATH_INPUT_TYPES = {"FILE_PICKER", "FILE_PATH", "FOLDER_PICKER", "DIRECTORY"} + +_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]{7,127}$") + + +def validate_session_id(session_id: str) -> str: + text = str(session_id or "").strip() + if not _SESSION_ID_RE.fullmatch(text): + raise ValueError("Invalid session id") + return text + + +def session_root_dir(session_id: str) -> Path: + validated = validate_session_id(session_id) + return app_data_dir() / "sessions" / validated + + +def session_input_dir(session_id: str) -> Path: + return session_root_dir(session_id) / "input" + + +def session_output_dir(session_id: str) -> Path: + return session_root_dir(session_id) / "output" + + +def ensure_session_runtime_dirs(session_id: str) -> tuple[Path, Path]: + input_path = session_input_dir(session_id) + output_path = session_output_dir(session_id) + input_path.mkdir(parents=True, exist_ok=True) + output_path.mkdir(parents=True, exist_ok=True) + return input_path, output_path + + +def normalize_relative_upload_path(raw_path: str) -> PurePosixPath: + raw_text = str(raw_path or "").replace("\\", "/").strip() + if not raw_text: + raise ValueError("Missing upload path") + + path = PurePosixPath(raw_text) + if path.is_absolute(): + raise ValueError("Upload paths must be relative") + + parts: list[str] = [] + for part in path.parts: + if part in ("", "."): + continue + if part == "..": + raise ValueError("Upload paths cannot escape the session directory") + if "\x00" in part: + raise ValueError("Upload paths cannot contain NUL bytes") + parts.append(part) + + if not parts: + raise ValueError("Upload paths must contain at least one path segment") + + return PurePosixPath(*parts) + + +def session_upload_uri(relative_path: str | PurePosixPath) -> str: + normalized = normalize_relative_upload_path(str(relative_path)) + return f"{SESSION_URI_PREFIX}{normalized.as_posix()}" + + +def session_uri_to_relative_path(value: str) -> PurePosixPath | None: + text = str(value or "").strip() + if not text.startswith(SESSION_URI_PREFIX): + return None + return normalize_relative_upload_path(text[len(SESSION_URI_PREFIX):]) + + +def is_path_within(root: Path, candidate: Path) -> bool: + try: + candidate.resolve(strict=False).relative_to(root.resolve(strict=False)) + return True + except ValueError: + return False + + +def server_path_to_client_path(path_value: str | Path, session_id: str) -> str: + path = Path(path_value).expanduser().resolve(strict=False) + session_input = session_input_dir(session_id).resolve(strict=False) + if is_path_within(session_input, path): + rel = path.relative_to(session_input) + return session_upload_uri(rel.as_posix()) + return str(path) + + +def resolve_client_path( + value: str, + *, + session_id: str, + allow_local_filesystem: bool, +) -> Path: + text = str(value or "").strip() + if not text: + return Path("") + + rel = session_uri_to_relative_path(text) + if rel is not None: + return (session_input_dir(session_id) / Path(rel.as_posix())).resolve(strict=False) + + candidate = Path(text).expanduser() + if not candidate.is_absolute(): + demo_candidate = (demo_dir() / text).expanduser().resolve(strict=False) + if demo_candidate.exists(): + return demo_candidate + + if not candidate.is_absolute(): + if allow_local_filesystem: + return candidate.resolve(strict=False) + raise PermissionError("Browser sessions may only use files uploaded through Browse.") + + resolved = candidate.resolve(strict=False) + if allow_local_filesystem: + return resolved + + session_root = session_root_dir(session_id).resolve(strict=False) + if is_path_within(session_root, resolved): + return resolved + + raise PermissionError("Path is outside the current session workspace.") diff --git a/desktop.py b/desktop.py index e77c683..1528826 100644 --- a/desktop.py +++ b/desktop.py @@ -44,6 +44,19 @@ class _Api: return result[0] return None + def open_folder_dialog(self) -> str | None: + """Open a native folder picker and return the selected path (or None).""" + win = self._window_ref[0] + if win is None: + return None + result = win.create_file_dialog( + webview.FOLDER_DIALOG, + allow_multiple=False, + ) + if result and len(result) > 0: + return result[0] + return None + def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None: """Open a native save dialog and return the chosen PNG path (or None).""" win = self._window_ref[0] @@ -90,7 +103,7 @@ def _run_server(host: str, port: int, ready: threading.Event, state: dict[str, o state["loop"] = loop async def start() -> None: - app = create_app(loop) + app = create_app(loop, allow_local_filesystem=True) runner = web.AppRunner(app, access_log=None) await runner.setup() site = web.TCPSite(runner, host, port) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 4f81f79..e13e24e 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -4,13 +4,13 @@ import React, { import { ReactFlow, Background, Controls, MiniMap, useNodesState, useEdgesState, addEdge, useReactFlow, - ReactFlowProvider, getViewportForBounds, + ReactFlowProvider, getViewportForBounds, PanOnScrollMode, SelectionMode, } from '@xyflow/react'; import '@xyflow/react/dist/style.css'; import CustomNode, { NodeContext } from './CustomNode'; -import FileBrowser from './FileBrowser'; import * as api from './api'; +import { pickNativeDirectorySelection, pickNativeFileSelection } from './nativePicker'; import { toBlob } from 'html-to-image'; import { embedWorkflow, extractWorkflow } from './pngMetadata'; import { captureViewportBlob as captureWorkflowViewportBlob } from './workflowCapture'; @@ -791,7 +791,6 @@ function Flow() { const [edges, setEdges, onEdgesChange] = useEdgesState([]); const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' }); const [contextMenu, setContextMenu] = useState(null); - const [fileBrowserState, setFileBrowserState] = useState(null); const nodeDefsRef = useRef({}); const nextIdRef = useRef(1); @@ -1481,22 +1480,68 @@ function Flow() { // ── File browser ──────────────────────────────────────────────────── - const openFileBrowser = useCallback((callback, { selectionMode = 'file' } = {}) => { + const uploadBrowserSelection = useCallback(async (selection, selectionMode) => { + if (!selection) return null; + + if (selectionMode === 'folder') { + const rootName = String(selection.rootName || '').trim(); + if (!rootName) { + throw new Error('Selected folder is empty or could not be read.'); + } + + setStatus({ + text: `Importing folder "${rootName}" into this session…`, + level: 'info', + }); + + const folder = await api.createUploadFolder(rootName); + for (const entry of selection.entries || []) { + await api.uploadFile(entry.file, { relativePath: entry.relativePath }); + } + return folder.path; + } + + const [entry] = selection.entries || []; + if (!entry) return null; + + setStatus({ + text: `Uploading ${entry.file.name}…`, + level: 'info', + }); + + const uploaded = await api.uploadFile(entry.file, { relativePath: entry.relativePath }); + return uploaded.path; + }, []); + + const openFileBrowser = useCallback(async (callback, { selectionMode = 'file' } = {}) => { if (selectionMode === 'folder' && window.pywebview?.api?.open_folder_dialog) { window.pywebview.api.open_folder_dialog().then((path) => { if (path) callback(path); }); return; } - // Use native file picker when running inside pywebview (desktop app) if (selectionMode === 'file' && window.pywebview?.api?.open_file_dialog) { window.pywebview.api.open_file_dialog().then((path) => { if (path) callback(path); }); return; } - setFileBrowserState({ callback, selectionMode }); - }, []); + + try { + const selection = selectionMode === 'folder' + ? await pickNativeDirectorySelection() + : await pickNativeFileSelection(); + if (!selection) return; + + const uploadedPath = await uploadBrowserSelection(selection, selectionMode); + if (uploadedPath) callback(uploadedPath); + } catch (error) { + setStatus({ + text: `Browse failed: ${error.message || String(error)}`, + level: 'error', + }); + } + }, [uploadBrowserSelection]); // ── Node context value (stable) ───────────────────────────────────── @@ -1782,6 +1827,21 @@ function Flow() { setTimeout(() => reactFlow.updateNodeInternals(String(groupId)), 0); }, [reactFlow, setNodes]); + const renameGroup = useCallback((groupId, label) => { + const nextLabel = String(label || '').trim() || 'group'; + setNodes((existing) => existing.map((node) => { + if (String(node.id) !== String(groupId) || node.data?.className !== 'Group') return node; + if (String(node.data?.label || 'group') === nextLabel) return node; + return { + ...node, + data: { + ...node.data, + label: nextLabel, + }, + }; + })); + }, [setNodes]); + const contextValue = useMemo(() => ({ onWidgetChange, onRuntimeValuesChange, @@ -1789,8 +1849,9 @@ function Flow() { onManualTrigger, onToggleGroupCollapse: toggleGroupCollapse, onResizeGroup: resizeGroup, + onRenameGroup: renameGroup, onUngroup: ungroupGroup, - }), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger, resizeGroup, toggleGroupCollapse, ungroupGroup]); + }), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger, renameGroup, resizeGroup, toggleGroupCollapse, ungroupGroup]); const clearGraph = useCallback(() => { setNodes([]); @@ -2602,6 +2663,12 @@ function Flow() { nodeTypes={NODE_TYPES} onPaneContextMenu={onPaneContextMenu} colorMode="dark" + panOnDrag={[1]} + panOnScroll + panOnScrollMode={PanOnScrollMode.Free} + zoomOnScroll={false} + selectionOnDrag + selectionMode={SelectionMode.Partial} multiSelectionKeyCode={['Shift']} deleteKeyCode={['Backspace', 'Delete']} defaultEdgeOptions={{ type: 'default' }} @@ -2631,14 +2698,6 @@ function Flow() { )} - {/* File browser modal */} - {fileBrowserState && ( - { fileBrowserState.callback(path); setFileBrowserState(null); }} - onClose={() => setFileBrowserState(null)} - /> - )} ); diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 7983ad6..2c66f20 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -45,6 +45,9 @@ function GroupNode({ id, data }) { const childCount = Number(data.childCount) || 0; const collapsed = !!data.collapsed; const maxRows = Math.max(proxyInputs.length, proxyOutputs.length, collapsed ? 1 : 0); + const [isEditingLabel, setIsEditingLabel] = useState(false); + const [draftLabel, setDraftLabel] = useState(String(data.label || 'group')); + const labelInputRef = useRef(null); const selected = useStore( useCallback( (s) => { @@ -62,6 +65,33 @@ function GroupNode({ id, data }) { [id], ), ); + const displayLabel = String(data.label || 'group'); + + useEffect(() => { + if (!isEditingLabel) { + setDraftLabel(displayLabel); + } + }, [displayLabel, isEditingLabel]); + + useEffect(() => { + if (!isEditingLabel) return; + labelInputRef.current?.focus(); + labelInputRef.current?.select(); + }, [isEditingLabel]); + + const commitLabel = useCallback(() => { + const nextLabel = String(draftLabel || '').trim() || 'group'; + setIsEditingLabel(false); + setDraftLabel(nextLabel); + if (nextLabel !== displayLabel) { + ctx.onRenameGroup?.(id, nextLabel); + } + }, [ctx, displayLabel, draftLabel, id]); + + const cancelLabelEdit = useCallback(() => { + setDraftLabel(displayLabel); + setIsEditingLabel(false); + }, [displayLabel]); return ( <> @@ -84,7 +114,40 @@ function GroupNode({ id, data }) { > {collapsed ? '▸' : '▾'} - {formatUiLabel(data.label || 'group')} + {isEditingLabel ? ( + setDraftLabel(event.target.value)} + onBlur={commitLabel} + onClick={(event) => event.stopPropagation()} + onPointerDown={(event) => event.stopPropagation()} + onKeyDown={(event) => { + if (event.key === 'Enter') { + event.preventDefault(); + commitLabel(); + } else if (event.key === 'Escape') { + event.preventDefault(); + cancelLabelEdit(); + } + }} + /> + ) : ( + + )}
- )} - -
- - {/* File list */} -
- {loading &&
Loading…
} - {error &&
Error: {error}
} - - {!loading && !error && ( - <> - {/* Parent directory */} - {parent && ( -
navigate(parent)}> - ⬆ .. -
- )} - - {/* Directories */} - {dirs.map((d) => ( -
navigate(path + '/' + d)} - > - 📁 {d} -
- ))} - - {/* Files */} - {files.map((f) => ( -
{ - if (selectionMode === 'folder') return; - onSelect(path + '/' + f); - onClose(); - }} - > - {f} -
- ))} - - {dirs.length === 0 && files.length === 0 && ( -
Empty directory
- )} - - )} -
- - - ); -} diff --git a/frontend/src/api.js b/frontend/src/api.js index 225240f..7f129ca 100644 --- a/frontend/src/api.js +++ b/frontend/src/api.js @@ -5,49 +5,105 @@ * and production same-origin serving both work transparently. */ -// ── REST helpers ────────────────────────────────────────────────────── +const SESSION_STORAGE_KEY = 'argonode-session-id'; + +let _sessionId = null; +let _ws = null; +let _handler = null; +let _reconnectTimer = null; + +function generateSessionId() { + if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { + return crypto.randomUUID(); + } + return `session-${Math.random().toString(36).slice(2)}-${Date.now().toString(36)}`; +} + +export function getSessionId() { + if (_sessionId) return _sessionId; + + if (typeof window === 'undefined') { + _sessionId = 'session-test-runner'; + return _sessionId; + } + + try { + const stored = window.sessionStorage?.getItem(SESSION_STORAGE_KEY); + if (stored) { + _sessionId = stored; + return _sessionId; + } + } catch { + // Fall through to in-memory session id generation. + } + + _sessionId = generateSessionId(); + try { + window.sessionStorage?.setItem(SESSION_STORAGE_KEY, _sessionId); + } catch { + // Ignore storage failures and keep the in-memory id. + } + return _sessionId; +} + +function withSessionHeaders(init = {}) { + const headers = new Headers(init.headers || {}); + headers.set('X-Argonode-Session', getSessionId()); + return { ...init, headers }; +} + +async function sessionFetch(input, init) { + return fetch(input, withSessionHeaders(init)); +} export async function getNodes() { - const r = await fetch('/nodes'); + const r = await sessionFetch('/nodes'); if (!r.ok) throw new Error(`GET /nodes failed: ${r.status}`); return r.json(); } export async function getFiles() { - const r = await fetch('/files'); + const r = await sessionFetch('/files'); if (!r.ok) return []; return r.json(); } -export async function browse(dir) { - const url = dir ? `/browse?dir=${encodeURIComponent(dir)}` : '/browse'; - const r = await fetch(url); - if (!r.ok) throw new Error(`Browse failed: ${r.status}`); +export async function createUploadFolder(relativePath) { + const r = await sessionFetch('/upload-folder', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ path: relativePath }), + }); + if (!r.ok) throw new Error(`Create folder failed: ${r.status}`); return r.json(); } -export async function uploadFile(file) { +export async function uploadFile(file, { relativePath = '' } = {}) { const fd = new FormData(); + if (relativePath) fd.append('relative_path', relativePath); fd.append('file', file); - const r = await fetch('/upload', { method: 'POST', body: fd }); - if (!r.ok) throw new Error(`Upload failed: ${r.status}`); + const r = await sessionFetch('/upload', { method: 'POST', body: fd }); + if (!r.ok) { + const text = await r.text(); + throw new Error(`Upload failed (${r.status}): ${text}`); + } return r.json(); } export async function getChannels(filepath) { - const r = await fetch(`/channels?file=${encodeURIComponent(filepath)}`); + const r = await sessionFetch(`/channels?file=${encodeURIComponent(filepath)}`); if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }]; return r.json(); } export async function getFolderFiles(folderpath) { - const r = await fetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`); + const r = await sessionFetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`); if (!r.ok) return []; return r.json(); } export async function runPrompt(prompt) { - const r = await fetch('/prompt', { + const r = await sessionFetch('/prompt', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ prompt }), @@ -59,21 +115,16 @@ export async function runPrompt(prompt) { return r.json(); } -// ── WebSocket ───────────────────────────────────────────────────────── - -let _ws = null; -let _handler = null; -let _reconnectTimer = null; - export function setMessageHandler(fn) { _handler = fn; } export function initWS() { - if (_ws && _ws.readyState < 2) return; // already open or connecting + if (_ws && _ws.readyState < 2) return; const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - _ws = new WebSocket(`${protocol}//${window.location.host}/ws`); + const session = encodeURIComponent(getSessionId()); + _ws = new WebSocket(`${protocol}//${window.location.host}/ws?session=${session}`); _ws.onopen = () => { console.log('[argonode] WebSocket connected'); diff --git a/frontend/src/nativePicker.js b/frontend/src/nativePicker.js new file mode 100644 index 0000000..d03f688 --- /dev/null +++ b/frontend/src/nativePicker.js @@ -0,0 +1,118 @@ +const FILE_ACCEPT = [ + '.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp', + '.npy', '.npz', + '.gwy', '.sxm', '.ibw', + '.ttf', '.otf', '.woff', '.woff2', +].join(','); + +function normalizeRelativePath(path) { + return String(path || '').replace(/\\/g, '/').replace(/^\/+/, ''); +} + +function pickWithInput({ directory = false } = {}) { + return new Promise((resolve) => { + const input = document.createElement('input'); + input.type = 'file'; + input.style.position = 'fixed'; + input.style.left = '-9999px'; + if (directory) { + input.multiple = true; + input.setAttribute('webkitdirectory', ''); + input.setAttribute('directory', ''); + } else { + input.accept = FILE_ACCEPT; + } + + const cleanup = () => { + input.remove(); + }; + + input.addEventListener('change', () => { + const files = Array.from(input.files || []); + cleanup(); + resolve(files); + }, { once: true }); + + document.body.appendChild(input); + input.click(); + }); +} + +async function collectDirectoryEntries(handle, prefix = handle.name) { + const entries = []; + for await (const [name, child] of handle.entries()) { + const relativePath = prefix ? `${prefix}/${name}` : name; + if (child.kind === 'file') { + const file = await child.getFile(); + entries.push({ file, relativePath: normalizeRelativePath(relativePath) }); + continue; + } + if (child.kind === 'directory') { + entries.push(...await collectDirectoryEntries(child, relativePath)); + } + } + return entries; +} + +export async function pickNativeFileSelection() { + try { + if (typeof window.showOpenFilePicker === 'function') { + const [handle] = await window.showOpenFilePicker({ + multiple: false, + types: [{ + description: 'Supported files', + accept: { + 'application/octet-stream': ['.npy', '.npz', '.gwy', '.sxm', '.ibw', '.ttf', '.otf', '.woff', '.woff2'], + 'image/*': ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'], + }, + }], + }); + if (!handle) return null; + const file = await handle.getFile(); + return { + rootName: file.name, + entries: [{ file, relativePath: normalizeRelativePath(file.name) }], + }; + } + } catch (error) { + if (error?.name !== 'AbortError') throw error; + return null; + } + + const files = await pickWithInput({ directory: false }); + if (files.length === 0) return null; + return { + rootName: files[0].name, + entries: [{ file: files[0], relativePath: normalizeRelativePath(files[0].name) }], + }; +} + +export async function pickNativeDirectorySelection() { + try { + if (typeof window.showDirectoryPicker === 'function') { + const handle = await window.showDirectoryPicker(); + if (!handle) return null; + const entries = await collectDirectoryEntries(handle, handle.name); + return { + rootName: handle.name, + entries, + }; + } + } catch (error) { + if (error?.name !== 'AbortError') throw error; + return null; + } + + const files = await pickWithInput({ directory: true }); + if (files.length === 0) return null; + const entries = files.map((file) => ({ + file, + relativePath: normalizeRelativePath(file.webkitRelativePath || file.name), + })); + const rootName = entries[0]?.relativePath.split('/')[0] || ''; + if (!rootName) return null; + return { + rootName, + entries, + }; +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index 22ffc63..ccc6a62 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -259,6 +259,34 @@ html, body, #root { flex: 1; } +.group-title-button { + flex: 1; + min-width: 0; + padding: 0; + border: 0; + background: transparent; + color: var(--text-heading); + font: inherit; + font-weight: inherit; + text-align: left; + cursor: text; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.group-title-input { + flex: 1; + min-width: 0; + height: 22px; + padding: 2px 6px; + border: 1px solid rgba(148, 163, 184, 0.45); + border-radius: 4px; + background: rgba(15, 23, 42, 0.72); + color: var(--text-heading); + font: inherit; +} + .group-node-actions { display: flex; align-items: center; diff --git a/frontend/vite.config.js b/frontend/vite.config.js index f674bd5..eb12a43 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -9,9 +9,9 @@ export default defineConfig({ proxy: { '/nodes': 'http://127.0.0.1:8188', '/files': 'http://127.0.0.1:8188', - '/browse': 'http://127.0.0.1:8188', '/folder-files': 'http://127.0.0.1:8188', '/channels': 'http://127.0.0.1:8188', + '/upload-folder': 'http://127.0.0.1:8188', '/upload': 'http://127.0.0.1:8188', '/download': 'http://127.0.0.1:8188', '/prompt': 'http://127.0.0.1:8188', diff --git a/tests/test_session_runtime.py b/tests/test_session_runtime.py new file mode 100644 index 0000000..78f1eb8 --- /dev/null +++ b/tests/test_session_runtime.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import threading +from pathlib import Path + +import pytest + +from backend.execution_context import active_node, emit_warning, execution_callbacks +from backend.session_runtime import ( + ensure_session_runtime_dirs, + resolve_client_path, + server_path_to_client_path, + session_upload_uri, +) + + +def test_session_paths_round_trip(monkeypatch, tmp_path): + monkeypatch.setenv("ARGONODE_APPDATA", str(tmp_path / "appdata")) + + session_id = "session-test-1234" + input_dir, _ = ensure_session_runtime_dirs(session_id) + target = input_dir / "picked-folder" / "image.png" + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(b"png") + + client_path = session_upload_uri("picked-folder/image.png") + resolved = resolve_client_path(client_path, session_id=session_id, allow_local_filesystem=False) + + assert resolved == target.resolve() + assert server_path_to_client_path(target, session_id) == client_path + + +def test_browser_sessions_cannot_escape_workspace(monkeypatch, tmp_path): + monkeypatch.setenv("ARGONODE_APPDATA", str(tmp_path / "appdata")) + + session_id = "session-test-5678" + ensure_session_runtime_dirs(session_id) + + outside_path = (tmp_path / "outside" / "secret.dat").resolve() + + with pytest.raises(PermissionError): + resolve_client_path(str(outside_path), session_id=session_id, allow_local_filesystem=False) + + +def test_execution_callbacks_are_thread_local(): + results = [] + lock = threading.Lock() + barrier = threading.Barrier(2) + + def worker(label: str): + def on_warning(node_id: str, message: str): + with lock: + results.append((label, node_id, message)) + + with execution_callbacks(warning=on_warning): + with active_node(f"node-{label}"): + barrier.wait(timeout=5) + emit_warning(f"warning-{label}") + + threads = [ + threading.Thread(target=worker, args=("a",)), + threading.Thread(target=worker, args=("b",)), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=5) + + assert sorted(results) == [ + ("a", "node-a", "warning-a"), + ("b", "node-b", "warning-b"), + ]