rework web server so multiple clients can be server at a time

This commit is contained in:
matei jordache
2026-03-27 16:18:22 -07:00
parent 1eda4030d1
commit 558046e7aa
33 changed files with 1042 additions and 551 deletions

View File

@@ -32,6 +32,7 @@ from time import perf_counter
from typing import Any, Callable from typing import Any, Callable
from backend.node_registry import NODE_CLASS_MAPPINGS from backend.node_registry import NODE_CLASS_MAPPINGS
from backend.execution_context import active_node, execution_callbacks
def _is_link(value: Any) -> bool: def _is_link(value: Any) -> bool:
@@ -85,63 +86,66 @@ class ExecutionEngine:
node_outputs: dict[str, tuple] = {} node_outputs: dict[str, tuple] = {}
node_output_signatures: dict[str, tuple[str, ...]] = {} node_output_signatures: dict[str, tuple[str, ...]] = {}
# Inject display callbacks before execution with execution_callbacks(
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning) preview=on_preview,
table=on_table,
mesh=on_mesh,
overlay=on_overlay,
value=on_value,
warning=on_warning,
):
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
for node_id in order: if class_name not in NODE_CLASS_MAPPINGS:
node_def = prompt[node_id] raise ValueError(f"Unknown node type: '{class_name}'")
class_name = node_def["class_type"]
if class_name not in NODE_CLASS_MAPPINGS: cls = NODE_CLASS_MAPPINGS[class_name]
raise ValueError(f"Unknown node type: '{class_name}'") raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
cls = NODE_CLASS_MAPPINGS[class_name] cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
raw_inputs = node_def.get("inputs", {}) if cache_entry is not None:
input_types = cls.INPUT_TYPES() result = self._clone_cached_outputs(cache_entry["outputs"])
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) elapsed_ms = 0.0
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures) else:
if on_node_start:
on_node_start(node_id)
# Let display nodes know their node_id so they can tag WS messages instance = cls()
self._set_node_id_on_display(cls, node_id) func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
with active_node(node_id):
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
cache_entry = self._get_cached_entry(node_id, class_name, input_signature) # Nodes must return a tuple; coerce single values just in case
if cache_entry is not None: if not isinstance(result, tuple):
result = self._clone_cached_outputs(cache_entry["outputs"]) result = (result,)
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
instance = cls() node_outputs[node_id] = result
func = getattr(instance, cls.FUNCTION) output_signatures = tuple(self._fingerprint_value(value) for value in result)
start_time = perf_counter() node_output_signatures[node_id] = output_signatures
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
# Nodes must return a tuple; coerce single values just in case if cache_entry is None and self._node_cacheable(cls):
if not isinstance(result, tuple): self._store_cache_entry(
result = (result,) node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
node_outputs[node_id] = result # Auto-preview: broadcast a thumbnail for any DATA_FIELD,
output_signatures = tuple(self._fingerprint_value(value) for value in result) # IMAGE, or table-like output so every node shows its result.
node_output_signatures[node_id] = output_signatures if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if cache_entry is None and self._node_cacheable(cls): if on_node_done:
self._store_cache_entry( on_node_done(node_id, elapsed_ms)
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if on_node_done:
on_node_done(node_id, elapsed_ms)
return node_outputs return node_outputs
@@ -421,88 +425,6 @@ class ExecutionEngine:
return deepcopy(value) return deepcopy(value)
return value return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
on_value: Callable | None = None,
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.annotations import Annotations
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.save import Save
from backend.nodes.save_image import SaveImage
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
MaskMorphology._broadcast_fn = on_preview
MaskInvert._broadcast_fn = on_preview
MaskCombine._broadcast_fn = on_preview
DrawMask._broadcast_overlay_fn = on_overlay
View3D._broadcast_mesh_fn = on_mesh
Annotations._broadcast_warning_fn = on_warning
PrintTable._broadcast_table_fn = on_table
ValueDisplay._broadcast_value_fn = on_value
Stats._broadcast_value_fn = on_value
Histogram._broadcast_overlay_fn = on_overlay
CrossSection._broadcast_overlay_fn = on_overlay
Cursors._broadcast_overlay_fn = on_overlay
CropResizeField._broadcast_overlay_fn = on_overlay
RotateField._broadcast_warning_fn = on_warning
Markup._broadcast_overlay_fn = on_overlay
Image._broadcast_warning_fn = on_warning
ImageDemo._broadcast_warning_fn = on_warning
Save._broadcast_warning_fn = on_warning
SaveImage._broadcast_warning_fn = on_warning
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.annotations import Annotations
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
from backend.nodes.save import Save
from backend.nodes.save_image import SaveImage
if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
Image, ImageDemo, Save, SaveImage):
cls._current_node_id = node_id
def _auto_preview( def _auto_preview(
self, self,
cls: type, cls: type,

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Callable
Callback = Callable[[str, Any], None]
_callbacks_var: ContextVar[dict[str, Callback | None]] = ContextVar(
"argonode_execution_callbacks",
default={},
)
_node_id_var: ContextVar[str | None] = ContextVar("argonode_execution_node_id", default=None)
@contextmanager
def execution_callbacks(
*,
preview: Callback | None = None,
table: Callback | None = None,
mesh: Callback | None = None,
overlay: Callback | None = None,
value: Callback | None = None,
warning: Callback | None = None,
):
token = _callbacks_var.set({
"preview": preview,
"table": table,
"mesh": mesh,
"overlay": overlay,
"value": value,
"warning": warning,
})
try:
yield
finally:
_callbacks_var.reset(token)
@contextmanager
def active_node(node_id: str):
token = _node_id_var.set(str(node_id))
try:
yield
finally:
_node_id_var.reset(token)
def current_node_id() -> str | None:
return _node_id_var.get()
def _emit(kind: str, payload: Any) -> None:
callbacks = _callbacks_var.get()
callback = callbacks.get(kind)
node_id = current_node_id()
if callback is not None and node_id:
callback(node_id, payload)
def emit_preview(payload: Any) -> None:
_emit("preview", payload)
def emit_table(rows: list) -> None:
_emit("table", rows)
def emit_mesh(mesh: dict) -> None:
_emit("mesh", mesh)
def emit_overlay(overlay: dict) -> None:
_emit("overlay", overlay)
def emit_value(payload: Any) -> None:
_emit("value", payload)
def emit_warning(message: str) -> None:
_emit("warning", message)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import ( from backend.data_types import (
COLORMAPS, COLORMAPS,
DataField, DataField,
@@ -120,7 +121,4 @@ class Annotations:
return (ImageData(annotated, metadata={"annotation_context": context}),) return (ImageData(annotated, metadata={"annotation_context": context}),)
def _send_warning(self, message: str): def _send_warning(self, message: str):
fn = Annotations._broadcast_warning_fn emit_warning(message)
nid = Annotations._current_node_id
if fn and nid:
fn(nid, message)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, datafield_to_uint8, encode_preview from backend.data_types import DataField, datafield_to_uint8, encode_preview
@@ -61,20 +62,16 @@ class CropResizeField:
x2 = float(np.clip(x2, 0.0, 1.0)) x2 = float(np.clip(x2, 0.0, 1.0))
y2 = float(np.clip(y2, 0.0, 1.0)) y2 = float(np.clip(y2, 0.0, 1.0))
if CropResizeField._broadcast_overlay_fn is not None: emit_overlay({
CropResizeField._broadcast_overlay_fn( "kind": "crop_box",
CropResizeField._current_node_id, "image": encode_preview(datafield_to_uint8(field, field.colormap)),
{ "x1": x1,
"kind": "crop_box", "y1": y1,
"image": encode_preview(datafield_to_uint8(field, field.colormap)), "x2": x2,
"x1": x1, "y2": y2,
"y1": y1, "a_locked": corner_a is not None,
"x2": x2, "b_locked": corner_b is not None,
"y2": y2, })
"a_locked": corner_a is not None,
"b_locked": corner_b is not None,
},
)
left = min(x1, x2) left = min(x1, x2)
right = max(x1, x2) right = max(x1, x2)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _extend_to_edges from backend.nodes.helpers import _extend_to_edges
@@ -73,19 +74,14 @@ class CrossSection:
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest") profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
if CrossSection._broadcast_overlay_fn is not None: image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) emit_overlay({
"image": image_uri,
CrossSection._broadcast_overlay_fn( "x1": marker_x1, "y1": marker_y1,
CrossSection._current_node_id, "x2": marker_x2, "y2": marker_y2,
{ "a_locked": marker_pair is not None,
"image": image_uri, "b_locked": marker_pair is not None,
"x1": marker_x1, "y1": marker_y1, })
"x2": marker_x2, "y2": marker_y2,
"a_locked": marker_pair is not None,
"b_locked": marker_pair is not None,
},
)
dx_real = (x2 - x1) * field.xreal dx_real = (x2 - x1) * field.xreal
dy_real = (y2 - y1) * field.yreal dy_real = (y2 - y1) * field.yreal

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview
@@ -87,22 +88,18 @@ class Cursors:
xa, ya = float(x[idx_a]), float(y[idx_a]) xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b]) xb, yb = float(x[idx_b]), float(y[idx_b])
if Cursors._broadcast_overlay_fn is not None: emit_overlay({
Cursors._broadcast_overlay_fn( "kind": "line_plot",
Cursors._current_node_id, "section_title": "Cursors",
{ "line": y.tolist(),
"kind": "line_plot", "x_axis": x.tolist(),
"section_title": "Cursors", "x1": x1,
"line": y.tolist(), "x2": x2,
"x_axis": x.tolist(), "y1": float(y1),
"x1": x1, "y2": float(y2),
"x2": x2, "a_locked": locked,
"y1": float(y1), "b_locked": locked,
"y2": float(y2), })
"a_locked": locked,
"b_locked": locked,
},
)
table = MeasureTable([ table = MeasureTable([
{"quantity": "A x", "value": xa, "unit": x_unit}, {"quantity": "A x", "value": xa, "unit": x_unit},
@@ -143,21 +140,17 @@ class Cursors:
bx = float(field.xoff + x2 * field.xreal) bx = float(field.xoff + x2 * field.xreal)
by = float(field.yoff + y2 * field.yreal) by = float(field.yoff + y2 * field.yreal)
if Cursors._broadcast_overlay_fn is not None: emit_overlay({
Cursors._broadcast_overlay_fn( "kind": "cursor_points",
Cursors._current_node_id, "section_title": "Cursors",
{ "image": encode_preview(render_datafield_preview(field, field.colormap)),
"kind": "cursor_points", "x1": x1,
"section_title": "Cursors", "y1": y1,
"image": encode_preview(render_datafield_preview(field, field.colormap)), "x2": x2,
"x1": x1, "y2": y2,
"y1": y1, "a_locked": locked,
"x2": x2, "b_locked": locked,
"y2": y2, })
"a_locked": locked,
"b_locked": locked,
},
)
table = MeasureTable([ table = MeasureTable([
{"quantity": "A x", "value": ax, "unit": field.si_unit_xy}, {"quantity": "A x", "value": ax, "unit": field.si_unit_xy},

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, datafield_to_uint8, encode_preview from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask
@@ -40,17 +41,13 @@ class DrawMask:
if invert: if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None: emit_overlay({
DrawMask._broadcast_overlay_fn( "kind": "mask_paint",
DrawMask._current_node_id, "section_title": "Mask",
{ "image": encode_preview(datafield_to_uint8(field, "gray")),
"kind": "mask_paint", "image_width": field.xres,
"section_title": "Mask", "image_height": field.yres,
"image": encode_preview(datafield_to_uint8(field, "gray")), "invert": bool(invert),
"image_width": field.xres, })
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,) return (mask,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, MeasureTable from backend.data_types import DataField, MeasureTable
@@ -72,22 +73,18 @@ class Histogram:
yb = float(counts[idx_b]) if len(counts) else 0.0 yb = float(counts[idx_b]) if len(counts) else 0.0
count_unit = "count" if y_scale == "linear" else "log10(1+count)" count_unit = "count" if y_scale == "linear" else "log10(1+count)"
if Histogram._broadcast_overlay_fn is not None: emit_overlay({
Histogram._broadcast_overlay_fn( "kind": "line_plot",
Histogram._current_node_id, "section_title": "Histogram",
{ "line": counts.tolist(),
"kind": "line_plot", "x_axis": bin_centers.astype(np.float64).tolist(),
"section_title": "Histogram", "x1": float(np.clip(x1, 0.0, 1.0)),
"line": counts.tolist(), "x2": float(np.clip(x2, 0.0, 1.0)),
"x_axis": bin_centers.astype(np.float64).tolist(), "y1": float(y1),
"x1": float(np.clip(x1, 0.0, 1.0)), "y2": float(y2),
"x2": float(np.clip(x2, 0.0, 1.0)), "a_locked": False,
"y1": float(y1), "b_locked": False,
"y2": float(y2), })
"a_locked": False,
"b_locked": False,
},
)
table = MeasureTable([ table = MeasureTable([
{"quantity": "A position", "value": xa, "unit": field.si_unit_z}, {"quantity": "A position", "value": xa, "unit": field.si_unit_z},

View File

@@ -4,6 +4,7 @@ import numpy as np
from pathlib import Path from pathlib import Path
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import COLORMAPS, DataField, resolve_colormap_input from backend.data_types import COLORMAPS, DataField, resolve_colormap_input
from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader
@@ -66,10 +67,7 @@ class Image:
return fields return fields
def _send_warning(self, message: str): def _send_warning(self, message: str):
fn = Image._broadcast_warning_fn emit_warning(message)
nid = Image._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod @staticmethod
@lru_cache(maxsize=32) @lru_cache(maxsize=32)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import ( from backend.data_types import (
DataField, DataField,
ImageData, ImageData,
@@ -70,17 +71,13 @@ class Markup:
metadata=image_metadata(input), metadata=image_metadata(input),
) )
if Markup._broadcast_overlay_fn is not None: emit_overlay({
Markup._broadcast_overlay_fn( "kind": "markup",
Markup._current_node_id, "section_title": "Markup",
{ "image": encode_preview(preview_base),
"kind": "markup", "shape": str(shape),
"section_title": "Markup", "stroke_color": _normalize_markup_color(stroke_color),
"image": encode_preview(preview_base), "stroke_width": max(1, int(stroke_width)),
"shape": str(shape), })
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
return (out,) return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay from backend.nodes.helpers import _mask_overlay
@@ -53,10 +54,8 @@ class MaskCombine:
out = result.astype(np.uint8) * 255 out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None: if field is not None:
overlay = _mask_overlay(field, out) overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn( emit_preview(encode_preview(overlay))
MaskCombine._current_node_id, encode_preview(overlay),
)
return (out,) return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay from backend.nodes.helpers import _mask_overlay
@@ -32,10 +33,8 @@ class MaskInvert:
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255)) out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None: if field is not None:
overlay = _mask_overlay(field, out) overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn( emit_preview(encode_preview(overlay))
MaskInvert._current_node_id, encode_preview(overlay),
)
return (out,) return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay, _mask_structure from backend.nodes.helpers import _mask_overlay, _mask_structure
@@ -62,10 +63,8 @@ class MaskMorphology:
out = result.astype(np.uint8) * 255 out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None: if field is not None:
overlay = _mask_overlay(field, out) overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn( emit_preview(encode_preview(overlay))
MaskMorphology._current_node_id, encode_preview(overlay),
)
return (out,) return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import ( from backend.data_types import (
COLORMAPS, COLORMAPS,
colormap_to_uint8, colormap_to_uint8,
@@ -68,7 +69,6 @@ class PreviewImage:
data_uri = encode_preview(arr_u8) data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None: emit_preview(data_uri)
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return () return ()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_table
@register_node(display_name="Print Table") @register_node(display_name="Print Table")
@@ -22,6 +23,5 @@ class PrintTable:
_current_node_id: str = "" _current_node_id: str = ""
def print_table(self, table: list) -> tuple: def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None: emit_table(table)
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return () return ()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import DataField from backend.data_types import DataField
@@ -84,10 +85,7 @@ class RotateField:
return (result,) return (result,)
def _send_warning(self, message: str): def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn emit_warning(message)
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod @staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]: def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8 from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8
@@ -255,7 +256,4 @@ class Save:
path.write_text("\n".join(lines) + "\n", encoding="utf-8") path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def _send_warning(self, message: str): def _send_warning(self, message: str):
fn = Save._broadcast_warning_fn emit_warning(message)
nid = Save._current_node_id
if fn and nid:
fn(nid, message)

View File

@@ -4,6 +4,7 @@ import numpy as np
from pathlib import Path from pathlib import Path
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import DataField, image_to_uint8 from backend.data_types import DataField, image_to_uint8
from backend.nodes.helpers import _MAX_SAVE_FIELDS from backend.nodes.helpers import _MAX_SAVE_FIELDS
@@ -174,9 +175,6 @@ class SaveImage:
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str): def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn emit_warning(message)
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return () return ()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_value
from backend.data_types import DataField, LineData, MeasureTable from backend.data_types import DataField, LineData, MeasureTable
from backend.nodes.helpers import ( from backend.nodes.helpers import (
LINE_OPS, LINE_OPS,
@@ -71,11 +72,9 @@ class Stats:
op_entry = ops[operation] op_entry = ops[operation]
fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry
result = fn(values) result = fn(values)
if Stats._broadcast_value_fn is not None: emit_value(
Stats._broadcast_value_fn( _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)),
Stats._current_node_id, )
_scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)),
)
return (result,) return (result,)
def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str: def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay from backend.nodes.helpers import _mask_overlay
@@ -52,10 +53,7 @@ class ThresholdMask:
else: else:
mask = (data < t).astype(np.uint8) * 255 mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None: overlay = _mask_overlay(field, mask)
overlay = _mask_overlay(field, mask) emit_preview(encode_preview(overlay))
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
return (mask,) return (mask,)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_value
from backend.data_types import MeasureTable from backend.data_types import MeasureTable
from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload
@@ -38,6 +39,5 @@ class ValueDisplay:
unit = row.get("unit", "") if isinstance(row.get("unit"), str) else "" unit = row.get("unit", "") if isinstance(row.get("unit"), str) else ""
else: else:
numeric = float(value) numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None: emit_value(_scalar_payload(numeric, unit))
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit))
return (numeric,) return (numeric,)

View File

@@ -3,6 +3,7 @@ import base64
import io import io
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.execution_context import emit_mesh
from backend.data_types import ( from backend.data_types import (
COLORMAPS, COLORMAPS,
DataField, DataField,
@@ -211,8 +212,7 @@ class View3D:
"y_range": [float(field.yoff), float(field.yoff + field.yreal)], "y_range": [float(field.yoff), float(field.yoff + field.yreal)],
} }
if View3D._broadcast_mesh_fn is not None: emit_mesh(mesh_data)
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
annotation_context = _annotation_context_from_field(color_field, resolved_colormap) annotation_context = _annotation_context_from_field(color_field, resolved_colormap)
annotation_context["xreal"] = float(field.xreal) annotation_context["xreal"] = float(field.xreal)

View File

@@ -6,7 +6,11 @@ Routes
GET / → serve frontend/index.html GET / → serve frontend/index.html
GET /static/{path} → serve frontend JS/CSS GET /static/{path} → serve frontend JS/CSS
GET /nodes → JSON dict of all registered node definitions GET /nodes → JSON dict of all registered node definitions
POST /uploadmultipart file upload to input/ GET /files list files in the current session upload workspace
GET /folder-files → list compatible files in a picked folder
GET /channels → inspect channels for a picked file
POST /upload → multipart file upload to the current session workspace
POST /upload-folder → create a folder in the current session workspace
POST /prompt → submit a workflow; returns {prompt_id} POST /prompt → submit a workflow; returns {prompt_id}
GET /ws → WebSocket upgrade GET /ws → WebSocket upgrade
@@ -15,7 +19,7 @@ WebSocket message types sent to clients
{"type": "execution_start", "data": {"prompt_id": "..."}} {"type": "execution_start", "data": {"prompt_id": "..."}}
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}} {"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}} {"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
{"type": "table", "data": {"node_id": "...", "rows": [...]}} {"type": "table", "data": {"node_id": "...", "rows": [...]} }
{"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}} {"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}}
{"type": "node_timing", "data": {"node_id": "...", "elapsed_ms": 12.34}} {"type": "node_timing", "data": {"node_id": "...", "elapsed_ms": 12.34}}
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}} {"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
@@ -23,39 +27,43 @@ WebSocket message types sent to clients
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import sys import sys
from collections import defaultdict
from copy import deepcopy
from pathlib import Path from pathlib import Path
from aiohttp import web, WSMsgType from aiohttp import web, WSMsgType
from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready
from backend.runtime_paths import ( from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, project_root
ensure_runtime_dirs, from backend.session_runtime import (
frontend_dir, PATH_INPUT_TYPES,
frontend_dist_dir, SESSION_HEADER,
input_dir, SESSION_QUERY,
output_dir, ensure_session_runtime_dirs,
project_root, normalize_relative_upload_path,
resolve_client_path,
server_path_to_client_path,
session_input_dir,
session_upload_uri,
validate_session_id,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
FRONTEND_DIR = frontend_dir() FRONTEND_DIR = frontend_dir()
DIST_DIR = frontend_dist_dir() DIST_DIR = frontend_dist_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# ---------------------------------------------------------------------------
# JSON helper — numpy scalars are not serialisable by default
# ---------------------------------------------------------------------------
class _SafeEncoder(json.JSONEncoder): class _SafeEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
import numpy as np import numpy as np
if isinstance(obj, (np.integer,)): if isinstance(obj, (np.integer,)):
return int(obj) return int(obj)
if isinstance(obj, (np.floating,)): if isinstance(obj, (np.floating,)):
@@ -81,45 +89,115 @@ def save_png_bytes(target_path: str, payload: bytes) -> Path:
return path return path
# --------------------------------------------------------------------------- def create_app(
# Application factory loop: asyncio.AbstractEventLoop,
# --------------------------------------------------------------------------- *,
allow_local_filesystem: bool = False,
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: ) -> web.Application:
# Import nodes to trigger registration decorators
import backend.nodes # noqa: F401 import backend.nodes # noqa: F401
from backend.node_registry import get_all_node_info
from backend.execution import ExecutionEngine, new_prompt_id from backend.execution import ExecutionEngine, new_prompt_id
from backend.node_registry import NODE_CLASS_MAPPINGS, get_all_node_info
ensure_runtime_dirs() ensure_runtime_dirs()
engine = ExecutionEngine() session_engines: dict[str, ExecutionEngine] = {}
websockets: set[web.WebSocketResponse] = set() session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set)
# ------------------------------------------------------------------ def _is_link(value) -> bool:
# WebSocket broadcast helpers return (
# ------------------------------------------------------------------ isinstance(value, (list, tuple))
and len(value) == 2
and isinstance(value[0], str)
and isinstance(value[1], int)
)
def broadcast(msg: dict) -> None: def require_session_id(request: web.Request) -> str:
"""Schedule a broadcast to all connected WebSocket clients.""" raw_session = request.headers.get(SESSION_HEADER) or request.query.get(SESSION_QUERY)
if not raw_session:
if allow_local_filesystem:
raw_session = "desktop-local-session"
else:
raise web.HTTPBadRequest(reason="Missing session id")
try:
session_id = validate_session_id(raw_session)
except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) from exc
ensure_session_runtime_dirs(session_id)
return session_id
def get_session_engine(session_id: str) -> ExecutionEngine:
engine = session_engines.get(session_id)
if engine is None:
engine = ExecutionEngine()
session_engines[session_id] = engine
return engine
def resolve_request_path(session_id: str, raw_value: str) -> Path:
try:
return resolve_client_path(
raw_value,
session_id=session_id,
allow_local_filesystem=allow_local_filesystem,
)
except PermissionError as exc:
raise web.HTTPForbidden(reason=str(exc)) from exc
except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) from exc
def rewrite_prompt_paths(prompt: dict, session_id: str) -> dict:
normalized = deepcopy(prompt)
for node_def in normalized.values():
class_name = node_def.get("class_type")
cls = NODE_CLASS_MAPPINGS.get(class_name)
if cls is None:
continue
input_types = cls.INPUT_TYPES()
specs = {}
specs.update(input_types.get("required", {}))
specs.update(input_types.get("optional", {}))
inputs = node_def.get("inputs", {})
if not isinstance(inputs, dict):
continue
for input_name, raw_value in list(inputs.items()):
if _is_link(raw_value) or not isinstance(raw_value, str):
continue
if not raw_value.strip():
continue
spec = specs.get(input_name)
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
if not isinstance(input_type, str):
continue
if input_type not in PATH_INPUT_TYPES:
continue
inputs[input_name] = str(resolve_request_path(session_id, raw_value))
return normalized
def broadcast(session_id: str, msg: dict) -> None:
payload = _dumps(msg) payload = _dumps(msg)
for ws in list(websockets): for ws in list(session_websockets.get(session_id, ())):
if not ws.closed: if not ws.closed:
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop) asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
def on_preview(node_id: str, data_uri: str) -> None: def on_preview(session_id: str, node_id: str, data_uri: str) -> None:
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}}) broadcast(session_id, {"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
def on_table(node_id: str, rows: list) -> None: def on_table(session_id: str, node_id: str, rows: list) -> None:
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}}) broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": rows}})
def on_mesh(node_id: str, mesh_data: dict) -> None: def on_mesh(session_id: str, node_id: str, mesh_data: dict) -> None:
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}}) broadcast(session_id, {"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
def on_overlay(node_id: str, overlay_data) -> None: def on_overlay(session_id: str, node_id: str, overlay_data) -> None:
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) broadcast(session_id, {"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
def on_value(node_id: str, payload) -> None: def on_value(session_id: str, node_id: str, payload) -> None:
if isinstance(payload, dict): if isinstance(payload, dict):
value = payload.get("value") value = payload.get("value")
unit = payload.get("unit", "") unit = payload.get("unit", "")
@@ -130,14 +208,10 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
data = {"node_id": node_id, "value": value} data = {"node_id": node_id, "value": value}
if isinstance(unit, str) and unit.strip(): if isinstance(unit, str) and unit.strip():
data["unit"] = unit.strip() data["unit"] = unit.strip()
broadcast({"type": "scalar", "data": data}) broadcast(session_id, {"type": "scalar", "data": data})
def on_warning(node_id: str, message: str) -> None: def on_warning(session_id: str, node_id: str, message: str) -> None:
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) broadcast(session_id, {"type": "node_warning", "data": {"node_id": node_id, "message": message}})
# ------------------------------------------------------------------
# Route handlers
# ------------------------------------------------------------------
async def index(request: web.Request) -> web.Response: async def index(request: web.Request) -> web.Response:
if not getattr(sys, "frozen", False): if not getattr(sys, "frozen", False):
@@ -167,88 +241,96 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
) )
async def get_nodes(request: web.Request) -> web.Response: async def get_nodes(request: web.Request) -> web.Response:
info = get_all_node_info()
return web.Response( return web.Response(
text=_dumps(info), text=_dumps(get_all_node_info()),
content_type="application/json", content_type="application/json",
) )
async def list_files(request: web.Request) -> web.Response: async def list_files(request: web.Request) -> web.Response:
"""List files in the input/ directory for the file picker widget.""" session_id = require_session_id(request)
input_path = session_input_dir(session_id)
files = sorted( files = sorted(
f.name for f in INPUT_DIR.iterdir() server_path_to_client_path(entry, session_id)
if f.is_file() and not f.name.startswith(".") for entry in input_path.iterdir()
) if INPUT_DIR.exists() else [] if entry.is_file() and not entry.name.startswith(".")
) if input_path.exists() else []
return web.Response(text=_dumps(files), content_type="application/json") return web.Response(text=_dumps(files), content_type="application/json")
async def browse_dir(request: web.Request) -> web.Response: async def create_upload_folder(request: web.Request) -> web.Response:
""" session_id = require_session_id(request)
Server-side directory browser for local file picking. body = await request.json()
GET /browse?dir=/some/path → {parent, dirs[], files[]} relative_path = normalize_relative_upload_path(body.get("path", ""))
""" target = session_input_dir(session_id) / Path(relative_path.as_posix())
dir_path = request.query.get("dir", str(Path.home())) target.mkdir(parents=True, exist_ok=True)
p = Path(dir_path).expanduser().resolve()
if not p.is_dir():
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
dirs = []
files = []
try:
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
if entry.name.startswith("."):
continue
if entry.is_dir():
dirs.append(entry.name)
elif entry.is_file():
files.append(entry.name)
except PermissionError:
pass
return web.Response( return web.Response(
text=_dumps({ text=_dumps({"path": session_upload_uri(relative_path)}),
"path": str(p),
"parent": str(p.parent) if p.parent != p else None,
"dirs": dirs,
"files": files,
}),
content_type="application/json", content_type="application/json",
) )
async def get_folder_files(request: web.Request) -> web.Response: async def get_folder_files(request: web.Request) -> web.Response:
folder_path = request.query.get("folder", "")
from backend.nodes.helpers import list_folder_paths from backend.nodes.helpers import list_folder_paths
loop = asyncio.get_running_loop()
entries = await loop.run_in_executor(None, list_folder_paths, folder_path) session_id = require_session_id(request)
return web.Response(text=_dumps(entries), content_type="application/json") folder_path = request.query.get("folder", "")
if not folder_path:
return web.Response(text=_dumps([]), content_type="application/json")
resolved_path = resolve_request_path(session_id, folder_path)
running_loop = asyncio.get_running_loop()
entries = await running_loop.run_in_executor(None, list_folder_paths, str(resolved_path))
payload = []
for entry in entries:
mapped = dict(entry)
if "path" in mapped:
mapped["path"] = server_path_to_client_path(mapped["path"], session_id)
payload.append(mapped)
return web.Response(text=_dumps(payload), content_type="application/json")
async def upload_file(request: web.Request) -> web.Response: async def upload_file(request: web.Request) -> web.Response:
session_id = require_session_id(request)
reader = await request.multipart() reader = await request.multipart()
field = await reader.next() relative_path = None
if field is None or field.name != "file": filename = ""
file_bytes = None
while True:
field = await reader.next()
if field is None:
break
if field.name == "relative_path":
relative_path = await field.text()
continue
if field.name == "file":
filename = Path(field.filename or "upload.bin").name
chunks = []
while True:
chunk = await field.read_chunk(65536)
if not chunk:
break
chunks.append(chunk)
file_bytes = b"".join(chunks)
if file_bytes is None:
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body") raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
filename = Path(field.filename).name # strip any path traversal relative = normalize_relative_upload_path(relative_path or filename)
dest = INPUT_DIR / filename dest = session_input_dir(session_id) / Path(relative.as_posix())
with open(dest, "wb") as f: dest.parent.mkdir(parents=True, exist_ok=True)
while True: dest.write_bytes(file_bytes)
chunk = await field.read_chunk(65536)
if not chunk:
break
f.write(chunk)
return web.Response(text=_dumps({"filename": filename}), content_type="application/json") return web.Response(
text=_dumps({"filename": filename, "path": session_upload_uri(relative)}),
content_type="application/json",
)
async def download_file(request: web.Request) -> web.Response: async def download_file(request: web.Request) -> web.Response:
"""Accept a blob POST and return it with Content-Disposition: attachment."""
body = await request.read() body = await request.read()
filename = request.query.get("filename", "workflow.png") filename = request.query.get("filename", "workflow.png")
return web.Response( return web.Response(
body=body, body=body,
content_type="application/octet-stream", content_type="application/octet-stream",
headers={ headers={"Content-Disposition": f'attachment; filename="{filename}"'},
"Content-Disposition": f'attachment; filename="{filename}"',
},
) )
async def save_workflow_png(request: web.Request) -> web.Response: async def save_workflow_png(request: web.Request) -> web.Response:
@@ -266,34 +348,39 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
) )
async def get_channels(request: web.Request) -> web.Response: async def get_channels(request: web.Request) -> web.Response:
"""Return available channels for a given file path."""
from backend.nodes.helpers import list_channels from backend.nodes.helpers import list_channels
session_id = require_session_id(request)
filepath = request.query.get("file", "") filepath = request.query.get("file", "")
if not filepath: if not filepath:
return web.Response( return web.Response(
text=_dumps([{"name": "field", "type": "DATA_FIELD"}]), text=_dumps([{"name": "field", "type": "DATA_FIELD"}]),
content_type="application/json", content_type="application/json",
) )
channels = await loop.run_in_executor(None, list_channels, filepath)
resolved_path = resolve_request_path(session_id, filepath)
channels = await loop.run_in_executor(None, list_channels, str(resolved_path))
return web.Response(text=_dumps(channels), content_type="application/json") return web.Response(text=_dumps(channels), content_type="application/json")
async def submit_prompt(request: web.Request) -> web.Response: async def submit_prompt(request: web.Request) -> web.Response:
session_id = require_session_id(request)
body = await request.json() body = await request.json()
prompt = body.get("prompt") prompt = body.get("prompt")
if not isinstance(prompt, dict) or not prompt: if not isinstance(prompt, dict) or not prompt:
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict") raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
normalized_prompt = rewrite_prompt_paths(prompt, session_id)
prompt_id = new_prompt_id() prompt_id = new_prompt_id()
engine = get_session_engine(session_id)
# Run execution in a thread pool so scipy doesn't block the event loop
async def run(): async def run():
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}}) broadcast(session_id, {"type": "execution_start", "data": {"prompt_id": prompt_id}})
def on_start(node_id: str) -> None: def on_start(node_id: str) -> None:
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}}) broadcast(session_id, {"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
def on_done(node_id: str, elapsed_ms: float) -> None: def on_done(node_id: str, elapsed_ms: float) -> None:
broadcast({ broadcast(session_id, {
"type": "node_timing", "type": "node_timing",
"data": {"node_id": node_id, "elapsed_ms": elapsed_ms}, "data": {"node_id": node_id, "elapsed_ms": elapsed_ms},
}) })
@@ -302,21 +389,21 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
await loop.run_in_executor( await loop.run_in_executor(
None, None,
lambda: engine.execute( lambda: engine.execute(
prompt, normalized_prompt,
on_node_start=on_start, on_node_start=on_start,
on_node_done=on_done, on_node_done=on_done,
on_preview=on_preview, on_preview=lambda node_id, payload: on_preview(session_id, node_id, payload),
on_table=on_table, on_table=lambda node_id, rows: on_table(session_id, node_id, rows),
on_mesh=on_mesh, on_mesh=lambda node_id, mesh_data: on_mesh(session_id, node_id, mesh_data),
on_overlay=on_overlay, on_overlay=lambda node_id, overlay_data: on_overlay(session_id, node_id, overlay_data),
on_value=on_value, on_value=lambda node_id, payload: on_value(session_id, node_id, payload),
on_warning=on_warning, on_warning=lambda node_id, message: on_warning(session_id, node_id, message),
), ),
) )
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}}) broadcast(session_id, {"type": "execution_complete", "data": {"prompt_id": prompt_id}})
except Exception as exc: except Exception as exc:
log.exception("Execution error") log.exception("Execution error")
broadcast({ broadcast(session_id, {
"type": "execution_error", "type": "execution_error",
"data": {"node_id": "", "message": str(exc)}, "data": {"node_id": "", "message": str(exc)},
}) })
@@ -328,32 +415,40 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
) )
async def websocket_handler(request: web.Request) -> web.WebSocketResponse: async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
session_id = require_session_id(request)
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
websockets.add(ws) session_websockets[session_id].add(ws)
log.info("WebSocket client connected (%d total)", len(websockets)) log.info(
"WebSocket client connected for session %s (%d total in session)",
session_id,
len(session_websockets[session_id]),
)
try: try:
async for msg in ws: async for msg in ws:
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
pass # clients don't need to send anything currently pass
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE): elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
break break
finally: finally:
websockets.discard(ws) session_websockets[session_id].discard(ws)
log.info("WebSocket client disconnected (%d total)", len(websockets)) if not session_websockets[session_id]:
session_websockets.pop(session_id, None)
log.info(
"WebSocket client disconnected for session %s (%d remaining in session)",
session_id,
len(session_websockets.get(session_id, ())),
)
return ws return ws
# ------------------------------------------------------------------
# App assembly
# ------------------------------------------------------------------
app = web.Application() app = web.Application()
app["allow_local_filesystem"] = allow_local_filesystem
app.router.add_get("/", index) app.router.add_get("/", index)
app.router.add_get("/nodes", get_nodes) app.router.add_get("/nodes", get_nodes)
app.router.add_get("/files", list_files) app.router.add_get("/files", list_files)
app.router.add_get("/browse", browse_dir)
app.router.add_get("/folder-files", get_folder_files) app.router.add_get("/folder-files", get_folder_files)
app.router.add_post("/upload-folder", create_upload_folder)
app.router.add_post("/upload", upload_file) app.router.add_post("/upload", upload_file)
app.router.add_post("/download", download_file) app.router.add_post("/download", download_file)
app.router.add_post("/save-workflow-png", save_workflow_png) app.router.add_post("/save-workflow-png", save_workflow_png)
@@ -361,26 +456,24 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
app.router.add_post("/prompt", submit_prompt) app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler) app.router.add_get("/ws", websocket_handler)
# Serve frontend static files (Vite build or raw)
if (DIST_DIR / "assets").exists(): if (DIST_DIR / "assets").exists():
app.router.add_static("/assets", DIST_DIR / "assets") app.router.add_static("/assets", DIST_DIR / "assets")
if FRONTEND_DIR.exists(): if FRONTEND_DIR.exists():
app.router.add_static("/static", FRONTEND_DIR) app.router.add_static("/static", FRONTEND_DIR)
# CORS — allow any origin (local dev only)
async def _cors_middleware(app_, handler): async def _cors_middleware(app_, handler):
async def middleware(request): async def middleware(request):
if request.method == "OPTIONS": if request.method == "OPTIONS":
return web.Response(headers={ return web.Response(headers={
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS", "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Headers": f"Content-Type, {SESSION_HEADER}",
}) })
response = await handler(request) response = await handler(request)
response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Origin"] = "*"
return response return response
return middleware return middleware
app.middlewares.append(_cors_middleware) app.middlewares.append(_cors_middleware)
return app return app

132
backend/session_runtime.py Normal file
View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import re
from pathlib import Path, PurePosixPath
from backend.runtime_paths import app_data_dir, demo_dir
SESSION_HEADER = "X-Argonode-Session"
SESSION_QUERY = "session"
SESSION_URI_PREFIX = "session://uploads/"
PATH_INPUT_TYPES = {"FILE_PICKER", "FILE_PATH", "FOLDER_PICKER", "DIRECTORY"}
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]{7,127}$")
def validate_session_id(session_id: str) -> str:
text = str(session_id or "").strip()
if not _SESSION_ID_RE.fullmatch(text):
raise ValueError("Invalid session id")
return text
def session_root_dir(session_id: str) -> Path:
validated = validate_session_id(session_id)
return app_data_dir() / "sessions" / validated
def session_input_dir(session_id: str) -> Path:
return session_root_dir(session_id) / "input"
def session_output_dir(session_id: str) -> Path:
return session_root_dir(session_id) / "output"
def ensure_session_runtime_dirs(session_id: str) -> tuple[Path, Path]:
input_path = session_input_dir(session_id)
output_path = session_output_dir(session_id)
input_path.mkdir(parents=True, exist_ok=True)
output_path.mkdir(parents=True, exist_ok=True)
return input_path, output_path
def normalize_relative_upload_path(raw_path: str) -> PurePosixPath:
raw_text = str(raw_path or "").replace("\\", "/").strip()
if not raw_text:
raise ValueError("Missing upload path")
path = PurePosixPath(raw_text)
if path.is_absolute():
raise ValueError("Upload paths must be relative")
parts: list[str] = []
for part in path.parts:
if part in ("", "."):
continue
if part == "..":
raise ValueError("Upload paths cannot escape the session directory")
if "\x00" in part:
raise ValueError("Upload paths cannot contain NUL bytes")
parts.append(part)
if not parts:
raise ValueError("Upload paths must contain at least one path segment")
return PurePosixPath(*parts)
def session_upload_uri(relative_path: str | PurePosixPath) -> str:
normalized = normalize_relative_upload_path(str(relative_path))
return f"{SESSION_URI_PREFIX}{normalized.as_posix()}"
def session_uri_to_relative_path(value: str) -> PurePosixPath | None:
text = str(value or "").strip()
if not text.startswith(SESSION_URI_PREFIX):
return None
return normalize_relative_upload_path(text[len(SESSION_URI_PREFIX):])
def is_path_within(root: Path, candidate: Path) -> bool:
try:
candidate.resolve(strict=False).relative_to(root.resolve(strict=False))
return True
except ValueError:
return False
def server_path_to_client_path(path_value: str | Path, session_id: str) -> str:
path = Path(path_value).expanduser().resolve(strict=False)
session_input = session_input_dir(session_id).resolve(strict=False)
if is_path_within(session_input, path):
rel = path.relative_to(session_input)
return session_upload_uri(rel.as_posix())
return str(path)
def resolve_client_path(
value: str,
*,
session_id: str,
allow_local_filesystem: bool,
) -> Path:
text = str(value or "").strip()
if not text:
return Path("")
rel = session_uri_to_relative_path(text)
if rel is not None:
return (session_input_dir(session_id) / Path(rel.as_posix())).resolve(strict=False)
candidate = Path(text).expanduser()
if not candidate.is_absolute():
demo_candidate = (demo_dir() / text).expanduser().resolve(strict=False)
if demo_candidate.exists():
return demo_candidate
if not candidate.is_absolute():
if allow_local_filesystem:
return candidate.resolve(strict=False)
raise PermissionError("Browser sessions may only use files uploaded through Browse.")
resolved = candidate.resolve(strict=False)
if allow_local_filesystem:
return resolved
session_root = session_root_dir(session_id).resolve(strict=False)
if is_path_within(session_root, resolved):
return resolved
raise PermissionError("Path is outside the current session workspace.")

View File

@@ -44,6 +44,19 @@ class _Api:
return result[0] return result[0]
return None return None
def open_folder_dialog(self) -> str | None:
"""Open a native folder picker and return the selected path (or None)."""
win = self._window_ref[0]
if win is None:
return None
result = win.create_file_dialog(
webview.FOLDER_DIALOG,
allow_multiple=False,
)
if result and len(result) > 0:
return result[0]
return None
def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None: def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None:
"""Open a native save dialog and return the chosen PNG path (or None).""" """Open a native save dialog and return the chosen PNG path (or None)."""
win = self._window_ref[0] win = self._window_ref[0]
@@ -90,7 +103,7 @@ def _run_server(host: str, port: int, ready: threading.Event, state: dict[str, o
state["loop"] = loop state["loop"] = loop
async def start() -> None: async def start() -> None:
app = create_app(loop) app = create_app(loop, allow_local_filesystem=True)
runner = web.AppRunner(app, access_log=None) runner = web.AppRunner(app, access_log=None)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, host, port) site = web.TCPSite(runner, host, port)

View File

@@ -4,13 +4,13 @@ import React, {
import { import {
ReactFlow, Background, Controls, MiniMap, ReactFlow, Background, Controls, MiniMap,
useNodesState, useEdgesState, addEdge, useReactFlow, useNodesState, useEdgesState, addEdge, useReactFlow,
ReactFlowProvider, getViewportForBounds, ReactFlowProvider, getViewportForBounds, PanOnScrollMode, SelectionMode,
} from '@xyflow/react'; } from '@xyflow/react';
import '@xyflow/react/dist/style.css'; import '@xyflow/react/dist/style.css';
import CustomNode, { NodeContext } from './CustomNode'; import CustomNode, { NodeContext } from './CustomNode';
import FileBrowser from './FileBrowser';
import * as api from './api'; import * as api from './api';
import { pickNativeDirectorySelection, pickNativeFileSelection } from './nativePicker';
import { toBlob } from 'html-to-image'; import { toBlob } from 'html-to-image';
import { embedWorkflow, extractWorkflow } from './pngMetadata'; import { embedWorkflow, extractWorkflow } from './pngMetadata';
import { captureViewportBlob as captureWorkflowViewportBlob } from './workflowCapture'; import { captureViewportBlob as captureWorkflowViewportBlob } from './workflowCapture';
@@ -791,7 +791,6 @@ function Flow() {
const [edges, setEdges, onEdgesChange] = useEdgesState([]); const [edges, setEdges, onEdgesChange] = useEdgesState([]);
const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' }); const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' });
const [contextMenu, setContextMenu] = useState(null); const [contextMenu, setContextMenu] = useState(null);
const [fileBrowserState, setFileBrowserState] = useState(null);
const nodeDefsRef = useRef({}); const nodeDefsRef = useRef({});
const nextIdRef = useRef(1); const nextIdRef = useRef(1);
@@ -1481,22 +1480,68 @@ function Flow() {
// ── File browser ──────────────────────────────────────────────────── // ── File browser ────────────────────────────────────────────────────
const openFileBrowser = useCallback((callback, { selectionMode = 'file' } = {}) => { const uploadBrowserSelection = useCallback(async (selection, selectionMode) => {
if (!selection) return null;
if (selectionMode === 'folder') {
const rootName = String(selection.rootName || '').trim();
if (!rootName) {
throw new Error('Selected folder is empty or could not be read.');
}
setStatus({
text: `Importing folder "${rootName}" into this session…`,
level: 'info',
});
const folder = await api.createUploadFolder(rootName);
for (const entry of selection.entries || []) {
await api.uploadFile(entry.file, { relativePath: entry.relativePath });
}
return folder.path;
}
const [entry] = selection.entries || [];
if (!entry) return null;
setStatus({
text: `Uploading ${entry.file.name}`,
level: 'info',
});
const uploaded = await api.uploadFile(entry.file, { relativePath: entry.relativePath });
return uploaded.path;
}, []);
const openFileBrowser = useCallback(async (callback, { selectionMode = 'file' } = {}) => {
if (selectionMode === 'folder' && window.pywebview?.api?.open_folder_dialog) { if (selectionMode === 'folder' && window.pywebview?.api?.open_folder_dialog) {
window.pywebview.api.open_folder_dialog().then((path) => { window.pywebview.api.open_folder_dialog().then((path) => {
if (path) callback(path); if (path) callback(path);
}); });
return; return;
} }
// Use native file picker when running inside pywebview (desktop app)
if (selectionMode === 'file' && window.pywebview?.api?.open_file_dialog) { if (selectionMode === 'file' && window.pywebview?.api?.open_file_dialog) {
window.pywebview.api.open_file_dialog().then((path) => { window.pywebview.api.open_file_dialog().then((path) => {
if (path) callback(path); if (path) callback(path);
}); });
return; return;
} }
setFileBrowserState({ callback, selectionMode });
}, []); try {
const selection = selectionMode === 'folder'
? await pickNativeDirectorySelection()
: await pickNativeFileSelection();
if (!selection) return;
const uploadedPath = await uploadBrowserSelection(selection, selectionMode);
if (uploadedPath) callback(uploadedPath);
} catch (error) {
setStatus({
text: `Browse failed: ${error.message || String(error)}`,
level: 'error',
});
}
}, [uploadBrowserSelection]);
// ── Node context value (stable) ───────────────────────────────────── // ── Node context value (stable) ─────────────────────────────────────
@@ -1782,6 +1827,21 @@ function Flow() {
setTimeout(() => reactFlow.updateNodeInternals(String(groupId)), 0); setTimeout(() => reactFlow.updateNodeInternals(String(groupId)), 0);
}, [reactFlow, setNodes]); }, [reactFlow, setNodes]);
const renameGroup = useCallback((groupId, label) => {
const nextLabel = String(label || '').trim() || 'group';
setNodes((existing) => existing.map((node) => {
if (String(node.id) !== String(groupId) || node.data?.className !== 'Group') return node;
if (String(node.data?.label || 'group') === nextLabel) return node;
return {
...node,
data: {
...node.data,
label: nextLabel,
},
};
}));
}, [setNodes]);
const contextValue = useMemo(() => ({ const contextValue = useMemo(() => ({
onWidgetChange, onWidgetChange,
onRuntimeValuesChange, onRuntimeValuesChange,
@@ -1789,8 +1849,9 @@ function Flow() {
onManualTrigger, onManualTrigger,
onToggleGroupCollapse: toggleGroupCollapse, onToggleGroupCollapse: toggleGroupCollapse,
onResizeGroup: resizeGroup, onResizeGroup: resizeGroup,
onRenameGroup: renameGroup,
onUngroup: ungroupGroup, onUngroup: ungroupGroup,
}), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger, resizeGroup, toggleGroupCollapse, ungroupGroup]); }), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger, renameGroup, resizeGroup, toggleGroupCollapse, ungroupGroup]);
const clearGraph = useCallback(() => { const clearGraph = useCallback(() => {
setNodes([]); setNodes([]);
@@ -2602,6 +2663,12 @@ function Flow() {
nodeTypes={NODE_TYPES} nodeTypes={NODE_TYPES}
onPaneContextMenu={onPaneContextMenu} onPaneContextMenu={onPaneContextMenu}
colorMode="dark" colorMode="dark"
panOnDrag={[1]}
panOnScroll
panOnScrollMode={PanOnScrollMode.Free}
zoomOnScroll={false}
selectionOnDrag
selectionMode={SelectionMode.Partial}
multiSelectionKeyCode={['Shift']} multiSelectionKeyCode={['Shift']}
deleteKeyCode={['Backspace', 'Delete']} deleteKeyCode={['Backspace', 'Delete']}
defaultEdgeOptions={{ type: 'default' }} defaultEdgeOptions={{ type: 'default' }}
@@ -2631,14 +2698,6 @@ function Flow() {
)} )}
</div> </div>
{/* File browser modal */}
{fileBrowserState && (
<FileBrowser
selectionMode={fileBrowserState.selectionMode}
onSelect={(path) => { fileBrowserState.callback(path); setFileBrowserState(null); }}
onClose={() => setFileBrowserState(null)}
/>
)}
</div> </div>
</NodeContext.Provider> </NodeContext.Provider>
); );

View File

@@ -45,6 +45,9 @@ function GroupNode({ id, data }) {
const childCount = Number(data.childCount) || 0; const childCount = Number(data.childCount) || 0;
const collapsed = !!data.collapsed; const collapsed = !!data.collapsed;
const maxRows = Math.max(proxyInputs.length, proxyOutputs.length, collapsed ? 1 : 0); const maxRows = Math.max(proxyInputs.length, proxyOutputs.length, collapsed ? 1 : 0);
const [isEditingLabel, setIsEditingLabel] = useState(false);
const [draftLabel, setDraftLabel] = useState(String(data.label || 'group'));
const labelInputRef = useRef(null);
const selected = useStore( const selected = useStore(
useCallback( useCallback(
(s) => { (s) => {
@@ -62,6 +65,33 @@ function GroupNode({ id, data }) {
[id], [id],
), ),
); );
const displayLabel = String(data.label || 'group');
useEffect(() => {
if (!isEditingLabel) {
setDraftLabel(displayLabel);
}
}, [displayLabel, isEditingLabel]);
useEffect(() => {
if (!isEditingLabel) return;
labelInputRef.current?.focus();
labelInputRef.current?.select();
}, [isEditingLabel]);
const commitLabel = useCallback(() => {
const nextLabel = String(draftLabel || '').trim() || 'group';
setIsEditingLabel(false);
setDraftLabel(nextLabel);
if (nextLabel !== displayLabel) {
ctx.onRenameGroup?.(id, nextLabel);
}
}, [ctx, displayLabel, draftLabel, id]);
const cancelLabelEdit = useCallback(() => {
setDraftLabel(displayLabel);
setIsEditingLabel(false);
}, [displayLabel]);
return ( return (
<> <>
@@ -84,7 +114,40 @@ function GroupNode({ id, data }) {
> >
{collapsed ? '▸' : '▾'} {collapsed ? '▸' : '▾'}
</button> </button>
<span className="node-title-main">{formatUiLabel(data.label || 'group')}</span> {isEditingLabel ? (
<input
ref={labelInputRef}
className="group-title-input nodrag"
type="text"
value={draftLabel}
onChange={(event) => setDraftLabel(event.target.value)}
onBlur={commitLabel}
onClick={(event) => event.stopPropagation()}
onPointerDown={(event) => event.stopPropagation()}
onKeyDown={(event) => {
if (event.key === 'Enter') {
event.preventDefault();
commitLabel();
} else if (event.key === 'Escape') {
event.preventDefault();
cancelLabelEdit();
}
}}
/>
) : (
<button
type="button"
className="group-title-button nodrag"
title="rename group"
onClick={(event) => {
event.stopPropagation();
setDraftLabel(displayLabel);
setIsEditingLabel(true);
}}
>
{displayLabel}
</button>
)}
<div className="group-node-actions"> <div className="group-node-actions">
<button <button
type="button" type="button"

View File

@@ -1,103 +0,0 @@
import React, { useState, useEffect, useCallback } from 'react';
import * as api from './api';
/**
* Server-side file browser modal.
*
* Props:
* onSelect(absolutePath) — called when user picks a file or folder
* onClose() — called when user dismisses the dialog
*/
export default function FileBrowser({ onSelect, onClose, selectionMode = 'file' }) {
const [path, setPath] = useState('');
const [parent, setParent] = useState(null);
const [dirs, setDirs] = useState([]);
const [files, setFiles] = useState([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
const navigate = useCallback(async (dir) => {
setLoading(true);
setError(null);
try {
const data = await api.browse(dir);
setPath(data.path);
setParent(data.parent);
setDirs(data.dirs);
setFiles(data.files);
} catch (err) {
setError(err.message);
} finally {
setLoading(false);
}
}, []);
// Start at home directory on mount
useEffect(() => {
navigate(null);
}, [navigate]);
return (
<div className="fb-backdrop" onClick={(e) => { if (e.target === e.currentTarget) onClose(); }}>
<div className="fb-dialog">
{/* Header */}
<div className="fb-header">
<span className="fb-path">{path}</span>
{selectionMode === 'folder' && (
<button className="fb-select-btn" onClick={() => { onSelect(path); onClose(); }}>
Select Folder
</button>
)}
<button className="fb-close" onClick={onClose}></button>
</div>
{/* File list */}
<div className="fb-list">
{loading && <div className="fb-loading">Loading</div>}
{error && <div className="fb-loading">Error: {error}</div>}
{!loading && !error && (
<>
{/* Parent directory */}
{parent && (
<div className="fb-entry fb-dir" onClick={() => navigate(parent)}>
..
</div>
)}
{/* Directories */}
{dirs.map((d) => (
<div
key={d}
className="fb-entry fb-dir"
onClick={() => navigate(path + '/' + d)}
>
📁 {d}
</div>
))}
{/* Files */}
{files.map((f) => (
<div
key={f}
className={`fb-entry fb-file${selectionMode === 'folder' ? ' fb-file-disabled' : ''}`}
onClick={() => {
if (selectionMode === 'folder') return;
onSelect(path + '/' + f);
onClose();
}}
>
{f}
</div>
))}
{dirs.length === 0 && files.length === 0 && (
<div className="fb-loading">Empty directory</div>
)}
</>
)}
</div>
</div>
</div>
);
}

View File

@@ -5,49 +5,105 @@
* and production same-origin serving both work transparently. * and production same-origin serving both work transparently.
*/ */
// ── REST helpers ────────────────────────────────────────────────────── const SESSION_STORAGE_KEY = 'argonode-session-id';
let _sessionId = null;
let _ws = null;
let _handler = null;
let _reconnectTimer = null;
function generateSessionId() {
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
return crypto.randomUUID();
}
return `session-${Math.random().toString(36).slice(2)}-${Date.now().toString(36)}`;
}
export function getSessionId() {
if (_sessionId) return _sessionId;
if (typeof window === 'undefined') {
_sessionId = 'session-test-runner';
return _sessionId;
}
try {
const stored = window.sessionStorage?.getItem(SESSION_STORAGE_KEY);
if (stored) {
_sessionId = stored;
return _sessionId;
}
} catch {
// Fall through to in-memory session id generation.
}
_sessionId = generateSessionId();
try {
window.sessionStorage?.setItem(SESSION_STORAGE_KEY, _sessionId);
} catch {
// Ignore storage failures and keep the in-memory id.
}
return _sessionId;
}
function withSessionHeaders(init = {}) {
const headers = new Headers(init.headers || {});
headers.set('X-Argonode-Session', getSessionId());
return { ...init, headers };
}
async function sessionFetch(input, init) {
return fetch(input, withSessionHeaders(init));
}
export async function getNodes() { export async function getNodes() {
const r = await fetch('/nodes'); const r = await sessionFetch('/nodes');
if (!r.ok) throw new Error(`GET /nodes failed: ${r.status}`); if (!r.ok) throw new Error(`GET /nodes failed: ${r.status}`);
return r.json(); return r.json();
} }
export async function getFiles() { export async function getFiles() {
const r = await fetch('/files'); const r = await sessionFetch('/files');
if (!r.ok) return []; if (!r.ok) return [];
return r.json(); return r.json();
} }
export async function browse(dir) { export async function createUploadFolder(relativePath) {
const url = dir ? `/browse?dir=${encodeURIComponent(dir)}` : '/browse'; const r = await sessionFetch('/upload-folder', {
const r = await fetch(url); method: 'POST',
if (!r.ok) throw new Error(`Browse failed: ${r.status}`); headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ path: relativePath }),
});
if (!r.ok) throw new Error(`Create folder failed: ${r.status}`);
return r.json(); return r.json();
} }
export async function uploadFile(file) { export async function uploadFile(file, { relativePath = '' } = {}) {
const fd = new FormData(); const fd = new FormData();
if (relativePath) fd.append('relative_path', relativePath);
fd.append('file', file); fd.append('file', file);
const r = await fetch('/upload', { method: 'POST', body: fd }); const r = await sessionFetch('/upload', { method: 'POST', body: fd });
if (!r.ok) throw new Error(`Upload failed: ${r.status}`); if (!r.ok) {
const text = await r.text();
throw new Error(`Upload failed (${r.status}): ${text}`);
}
return r.json(); return r.json();
} }
export async function getChannels(filepath) { export async function getChannels(filepath) {
const r = await fetch(`/channels?file=${encodeURIComponent(filepath)}`); const r = await sessionFetch(`/channels?file=${encodeURIComponent(filepath)}`);
if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }]; if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }];
return r.json(); return r.json();
} }
export async function getFolderFiles(folderpath) { export async function getFolderFiles(folderpath) {
const r = await fetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`); const r = await sessionFetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`);
if (!r.ok) return []; if (!r.ok) return [];
return r.json(); return r.json();
} }
export async function runPrompt(prompt) { export async function runPrompt(prompt) {
const r = await fetch('/prompt', { const r = await sessionFetch('/prompt', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ prompt }), body: JSON.stringify({ prompt }),
@@ -59,21 +115,16 @@ export async function runPrompt(prompt) {
return r.json(); return r.json();
} }
// ── WebSocket ─────────────────────────────────────────────────────────
let _ws = null;
let _handler = null;
let _reconnectTimer = null;
export function setMessageHandler(fn) { export function setMessageHandler(fn) {
_handler = fn; _handler = fn;
} }
export function initWS() { export function initWS() {
if (_ws && _ws.readyState < 2) return; // already open or connecting if (_ws && _ws.readyState < 2) return;
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
_ws = new WebSocket(`${protocol}//${window.location.host}/ws`); const session = encodeURIComponent(getSessionId());
_ws = new WebSocket(`${protocol}//${window.location.host}/ws?session=${session}`);
_ws.onopen = () => { _ws.onopen = () => {
console.log('[argonode] WebSocket connected'); console.log('[argonode] WebSocket connected');

View File

@@ -0,0 +1,118 @@
const FILE_ACCEPT = [
'.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp',
'.npy', '.npz',
'.gwy', '.sxm', '.ibw',
'.ttf', '.otf', '.woff', '.woff2',
].join(',');
function normalizeRelativePath(path) {
return String(path || '').replace(/\\/g, '/').replace(/^\/+/, '');
}
function pickWithInput({ directory = false } = {}) {
return new Promise((resolve) => {
const input = document.createElement('input');
input.type = 'file';
input.style.position = 'fixed';
input.style.left = '-9999px';
if (directory) {
input.multiple = true;
input.setAttribute('webkitdirectory', '');
input.setAttribute('directory', '');
} else {
input.accept = FILE_ACCEPT;
}
const cleanup = () => {
input.remove();
};
input.addEventListener('change', () => {
const files = Array.from(input.files || []);
cleanup();
resolve(files);
}, { once: true });
document.body.appendChild(input);
input.click();
});
}
async function collectDirectoryEntries(handle, prefix = handle.name) {
const entries = [];
for await (const [name, child] of handle.entries()) {
const relativePath = prefix ? `${prefix}/${name}` : name;
if (child.kind === 'file') {
const file = await child.getFile();
entries.push({ file, relativePath: normalizeRelativePath(relativePath) });
continue;
}
if (child.kind === 'directory') {
entries.push(...await collectDirectoryEntries(child, relativePath));
}
}
return entries;
}
export async function pickNativeFileSelection() {
try {
if (typeof window.showOpenFilePicker === 'function') {
const [handle] = await window.showOpenFilePicker({
multiple: false,
types: [{
description: 'Supported files',
accept: {
'application/octet-stream': ['.npy', '.npz', '.gwy', '.sxm', '.ibw', '.ttf', '.otf', '.woff', '.woff2'],
'image/*': ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'],
},
}],
});
if (!handle) return null;
const file = await handle.getFile();
return {
rootName: file.name,
entries: [{ file, relativePath: normalizeRelativePath(file.name) }],
};
}
} catch (error) {
if (error?.name !== 'AbortError') throw error;
return null;
}
const files = await pickWithInput({ directory: false });
if (files.length === 0) return null;
return {
rootName: files[0].name,
entries: [{ file: files[0], relativePath: normalizeRelativePath(files[0].name) }],
};
}
export async function pickNativeDirectorySelection() {
try {
if (typeof window.showDirectoryPicker === 'function') {
const handle = await window.showDirectoryPicker();
if (!handle) return null;
const entries = await collectDirectoryEntries(handle, handle.name);
return {
rootName: handle.name,
entries,
};
}
} catch (error) {
if (error?.name !== 'AbortError') throw error;
return null;
}
const files = await pickWithInput({ directory: true });
if (files.length === 0) return null;
const entries = files.map((file) => ({
file,
relativePath: normalizeRelativePath(file.webkitRelativePath || file.name),
}));
const rootName = entries[0]?.relativePath.split('/')[0] || '';
if (!rootName) return null;
return {
rootName,
entries,
};
}

View File

@@ -259,6 +259,34 @@ html, body, #root {
flex: 1; flex: 1;
} }
.group-title-button {
flex: 1;
min-width: 0;
padding: 0;
border: 0;
background: transparent;
color: var(--text-heading);
font: inherit;
font-weight: inherit;
text-align: left;
cursor: text;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.group-title-input {
flex: 1;
min-width: 0;
height: 22px;
padding: 2px 6px;
border: 1px solid rgba(148, 163, 184, 0.45);
border-radius: 4px;
background: rgba(15, 23, 42, 0.72);
color: var(--text-heading);
font: inherit;
}
.group-node-actions { .group-node-actions {
display: flex; display: flex;
align-items: center; align-items: center;

View File

@@ -9,9 +9,9 @@ export default defineConfig({
proxy: { proxy: {
'/nodes': 'http://127.0.0.1:8188', '/nodes': 'http://127.0.0.1:8188',
'/files': 'http://127.0.0.1:8188', '/files': 'http://127.0.0.1:8188',
'/browse': 'http://127.0.0.1:8188',
'/folder-files': 'http://127.0.0.1:8188', '/folder-files': 'http://127.0.0.1:8188',
'/channels': 'http://127.0.0.1:8188', '/channels': 'http://127.0.0.1:8188',
'/upload-folder': 'http://127.0.0.1:8188',
'/upload': 'http://127.0.0.1:8188', '/upload': 'http://127.0.0.1:8188',
'/download': 'http://127.0.0.1:8188', '/download': 'http://127.0.0.1:8188',
'/prompt': 'http://127.0.0.1:8188', '/prompt': 'http://127.0.0.1:8188',

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import threading
from pathlib import Path
import pytest
from backend.execution_context import active_node, emit_warning, execution_callbacks
from backend.session_runtime import (
ensure_session_runtime_dirs,
resolve_client_path,
server_path_to_client_path,
session_upload_uri,
)
def test_session_paths_round_trip(monkeypatch, tmp_path):
monkeypatch.setenv("ARGONODE_APPDATA", str(tmp_path / "appdata"))
session_id = "session-test-1234"
input_dir, _ = ensure_session_runtime_dirs(session_id)
target = input_dir / "picked-folder" / "image.png"
target.parent.mkdir(parents=True, exist_ok=True)
target.write_bytes(b"png")
client_path = session_upload_uri("picked-folder/image.png")
resolved = resolve_client_path(client_path, session_id=session_id, allow_local_filesystem=False)
assert resolved == target.resolve()
assert server_path_to_client_path(target, session_id) == client_path
def test_browser_sessions_cannot_escape_workspace(monkeypatch, tmp_path):
monkeypatch.setenv("ARGONODE_APPDATA", str(tmp_path / "appdata"))
session_id = "session-test-5678"
ensure_session_runtime_dirs(session_id)
outside_path = (tmp_path / "outside" / "secret.dat").resolve()
with pytest.raises(PermissionError):
resolve_client_path(str(outside_path), session_id=session_id, allow_local_filesystem=False)
def test_execution_callbacks_are_thread_local():
results = []
lock = threading.Lock()
barrier = threading.Barrier(2)
def worker(label: str):
def on_warning(node_id: str, message: str):
with lock:
results.append((label, node_id, message))
with execution_callbacks(warning=on_warning):
with active_node(f"node-{label}"):
barrier.wait(timeout=5)
emit_warning(f"warning-{label}")
threads = [
threading.Thread(target=worker, args=("a",)),
threading.Thread(target=worker, args=("b",)),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join(timeout=5)
assert sorted(results) == [
("a", "node-a", "warning-a"),
("b", "node-b", "warning-b"),
]