rework web server so multiple clients can be server at a time

This commit is contained in:
matei jordache
2026-03-27 16:18:22 -07:00
parent 1eda4030d1
commit 558046e7aa
33 changed files with 1042 additions and 551 deletions

View File

@@ -32,6 +32,7 @@ from time import perf_counter
from typing import Any, Callable
from backend.node_registry import NODE_CLASS_MAPPINGS
from backend.execution_context import active_node, execution_callbacks
def _is_link(value: Any) -> bool:
@@ -85,63 +86,66 @@ class ExecutionEngine:
node_outputs: dict[str, tuple] = {}
node_output_signatures: dict[str, tuple[str, ...]] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
with execution_callbacks(
preview=on_preview,
table=on_table,
mesh=on_mesh,
overlay=on_overlay,
value=on_value,
warning=on_warning,
):
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
if class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Unknown node type: '{class_name}'")
if class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Unknown node type: '{class_name}'")
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
instance = cls()
func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
with active_node(node_id):
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
instance = cls()
func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if on_node_done:
on_node_done(node_id, elapsed_ms)
if on_node_done:
on_node_done(node_id, elapsed_ms)
return node_outputs
@@ -421,88 +425,6 @@ class ExecutionEngine:
return deepcopy(value)
return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
on_value: Callable | None = None,
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.annotations import Annotations
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.save import Save
from backend.nodes.save_image import SaveImage
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
MaskMorphology._broadcast_fn = on_preview
MaskInvert._broadcast_fn = on_preview
MaskCombine._broadcast_fn = on_preview
DrawMask._broadcast_overlay_fn = on_overlay
View3D._broadcast_mesh_fn = on_mesh
Annotations._broadcast_warning_fn = on_warning
PrintTable._broadcast_table_fn = on_table
ValueDisplay._broadcast_value_fn = on_value
Stats._broadcast_value_fn = on_value
Histogram._broadcast_overlay_fn = on_overlay
CrossSection._broadcast_overlay_fn = on_overlay
Cursors._broadcast_overlay_fn = on_overlay
CropResizeField._broadcast_overlay_fn = on_overlay
RotateField._broadcast_warning_fn = on_warning
Markup._broadcast_overlay_fn = on_overlay
Image._broadcast_warning_fn = on_warning
ImageDemo._broadcast_warning_fn = on_warning
Save._broadcast_warning_fn = on_warning
SaveImage._broadcast_warning_fn = on_warning
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.annotations import Annotations
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
from backend.nodes.save import Save
from backend.nodes.save_image import SaveImage
if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
Image, ImageDemo, Save, SaveImage):
cls._current_node_id = node_id
def _auto_preview(
self,
cls: type,

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Callable
Callback = Callable[[str, Any], None]
_callbacks_var: ContextVar[dict[str, Callback | None]] = ContextVar(
"argonode_execution_callbacks",
default={},
)
_node_id_var: ContextVar[str | None] = ContextVar("argonode_execution_node_id", default=None)
@contextmanager
def execution_callbacks(
*,
preview: Callback | None = None,
table: Callback | None = None,
mesh: Callback | None = None,
overlay: Callback | None = None,
value: Callback | None = None,
warning: Callback | None = None,
):
token = _callbacks_var.set({
"preview": preview,
"table": table,
"mesh": mesh,
"overlay": overlay,
"value": value,
"warning": warning,
})
try:
yield
finally:
_callbacks_var.reset(token)
@contextmanager
def active_node(node_id: str):
token = _node_id_var.set(str(node_id))
try:
yield
finally:
_node_id_var.reset(token)
def current_node_id() -> str | None:
return _node_id_var.get()
def _emit(kind: str, payload: Any) -> None:
callbacks = _callbacks_var.get()
callback = callbacks.get(kind)
node_id = current_node_id()
if callback is not None and node_id:
callback(node_id, payload)
def emit_preview(payload: Any) -> None:
_emit("preview", payload)
def emit_table(rows: list) -> None:
_emit("table", rows)
def emit_mesh(mesh: dict) -> None:
_emit("mesh", mesh)
def emit_overlay(overlay: dict) -> None:
_emit("overlay", overlay)
def emit_value(payload: Any) -> None:
_emit("value", payload)
def emit_warning(message: str) -> None:
_emit("warning", message)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import (
COLORMAPS,
DataField,
@@ -120,7 +121,4 @@ class Annotations:
return (ImageData(annotated, metadata={"annotation_context": context}),)
def _send_warning(self, message: str):
fn = Annotations._broadcast_warning_fn
nid = Annotations._current_node_id
if fn and nid:
fn(nid, message)
emit_warning(message)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, datafield_to_uint8, encode_preview
@@ -61,20 +62,16 @@ class CropResizeField:
x2 = float(np.clip(x2, 0.0, 1.0))
y2 = float(np.clip(y2, 0.0, 1.0))
if CropResizeField._broadcast_overlay_fn is not None:
CropResizeField._broadcast_overlay_fn(
CropResizeField._current_node_id,
{
"kind": "crop_box",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"a_locked": corner_a is not None,
"b_locked": corner_b is not None,
},
)
emit_overlay({
"kind": "crop_box",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"a_locked": corner_a is not None,
"b_locked": corner_b is not None,
})
left = min(x1, x2)
right = max(x1, x2)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _extend_to_edges
@@ -73,19 +74,14 @@ class CrossSection:
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
if CrossSection._broadcast_overlay_fn is not None:
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,
{
"image": image_uri,
"x1": marker_x1, "y1": marker_y1,
"x2": marker_x2, "y2": marker_y2,
"a_locked": marker_pair is not None,
"b_locked": marker_pair is not None,
},
)
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
emit_overlay({
"image": image_uri,
"x1": marker_x1, "y1": marker_y1,
"x2": marker_x2, "y2": marker_y2,
"a_locked": marker_pair is not None,
"b_locked": marker_pair is not None,
})
dx_real = (x2 - x1) * field.xreal
dy_real = (y2 - y1) * field.yreal

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview
@@ -87,22 +88,18 @@ class Cursors:
xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b])
if Cursors._broadcast_overlay_fn is not None:
Cursors._broadcast_overlay_fn(
Cursors._current_node_id,
{
"kind": "line_plot",
"section_title": "Cursors",
"line": y.tolist(),
"x_axis": x.tolist(),
"x1": x1,
"x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": locked,
"b_locked": locked,
},
)
emit_overlay({
"kind": "line_plot",
"section_title": "Cursors",
"line": y.tolist(),
"x_axis": x.tolist(),
"x1": x1,
"x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": locked,
"b_locked": locked,
})
table = MeasureTable([
{"quantity": "A x", "value": xa, "unit": x_unit},
@@ -143,21 +140,17 @@ class Cursors:
bx = float(field.xoff + x2 * field.xreal)
by = float(field.yoff + y2 * field.yreal)
if Cursors._broadcast_overlay_fn is not None:
Cursors._broadcast_overlay_fn(
Cursors._current_node_id,
{
"kind": "cursor_points",
"section_title": "Cursors",
"image": encode_preview(render_datafield_preview(field, field.colormap)),
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"a_locked": locked,
"b_locked": locked,
},
)
emit_overlay({
"kind": "cursor_points",
"section_title": "Cursors",
"image": encode_preview(render_datafield_preview(field, field.colormap)),
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"a_locked": locked,
"b_locked": locked,
})
table = MeasureTable([
{"quantity": "A x", "value": ax, "unit": field.si_unit_xy},

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask
@@ -40,17 +41,13 @@ class DrawMask:
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
emit_overlay({
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
})
return (mask,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import DataField, MeasureTable
@@ -72,22 +73,18 @@ class Histogram:
yb = float(counts[idx_b]) if len(counts) else 0.0
count_unit = "count" if y_scale == "linear" else "log10(1+count)"
if Histogram._broadcast_overlay_fn is not None:
Histogram._broadcast_overlay_fn(
Histogram._current_node_id,
{
"kind": "line_plot",
"section_title": "Histogram",
"line": counts.tolist(),
"x_axis": bin_centers.astype(np.float64).tolist(),
"x1": float(np.clip(x1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y1": float(y1),
"y2": float(y2),
"a_locked": False,
"b_locked": False,
},
)
emit_overlay({
"kind": "line_plot",
"section_title": "Histogram",
"line": counts.tolist(),
"x_axis": bin_centers.astype(np.float64).tolist(),
"x1": float(np.clip(x1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y1": float(y1),
"y2": float(y2),
"a_locked": False,
"b_locked": False,
})
table = MeasureTable([
{"quantity": "A position", "value": xa, "unit": field.si_unit_z},

View File

@@ -4,6 +4,7 @@ import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import COLORMAPS, DataField, resolve_colormap_input
from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader
@@ -66,10 +67,7 @@ class Image:
return fields
def _send_warning(self, message: str):
fn = Image._broadcast_warning_fn
nid = Image._current_node_id
if fn and nid:
fn(nid, message)
emit_warning(message)
@staticmethod
@lru_cache(maxsize=32)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.execution_context import emit_overlay
from backend.data_types import (
DataField,
ImageData,
@@ -70,17 +71,13 @@ class Markup:
metadata=image_metadata(input),
)
if Markup._broadcast_overlay_fn is not None:
Markup._broadcast_overlay_fn(
Markup._current_node_id,
{
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(preview_base),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
emit_overlay({
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(preview_base),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
})
return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@@ -53,10 +54,8 @@ class MaskCombine:
out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None:
if field is not None:
overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn(
MaskCombine._current_node_id, encode_preview(overlay),
)
emit_preview(encode_preview(overlay))
return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@@ -32,10 +33,8 @@ class MaskInvert:
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None:
if field is not None:
overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn(
MaskInvert._current_node_id, encode_preview(overlay),
)
emit_preview(encode_preview(overlay))
return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay, _mask_structure
@@ -62,10 +63,8 @@ class MaskMorphology:
out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None:
if field is not None:
overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn(
MaskMorphology._current_node_id, encode_preview(overlay),
)
emit_preview(encode_preview(overlay))
return (out,)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import (
COLORMAPS,
colormap_to_uint8,
@@ -68,7 +69,6 @@ class PreviewImage:
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
emit_preview(data_uri)
return ()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.execution_context import emit_table
@register_node(display_name="Print Table")
@@ -22,6 +23,5 @@ class PrintTable:
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
emit_table(table)
return ()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import DataField
@@ -84,10 +85,7 @@ class RotateField:
return (result,)
def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
emit_warning(message)
@staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8
@@ -255,7 +256,4 @@ class Save:
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def _send_warning(self, message: str):
fn = Save._broadcast_warning_fn
nid = Save._current_node_id
if fn and nid:
fn(nid, message)
emit_warning(message)

View File

@@ -4,6 +4,7 @@ import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import DataField, image_to_uint8
from backend.nodes.helpers import _MAX_SAVE_FIELDS
@@ -174,9 +175,6 @@ class SaveImage:
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
emit_warning(message)
return ()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_value
from backend.data_types import DataField, LineData, MeasureTable
from backend.nodes.helpers import (
LINE_OPS,
@@ -71,11 +72,9 @@ class Stats:
op_entry = ops[operation]
fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry
result = fn(values)
if Stats._broadcast_value_fn is not None:
Stats._broadcast_value_fn(
Stats._current_node_id,
_scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)),
)
emit_value(
_scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)),
)
return (result,)
def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_preview
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@@ -52,10 +53,7 @@ class ThresholdMask:
else:
mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None:
overlay = _mask_overlay(field, mask)
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
overlay = _mask_overlay(field, mask)
emit_preview(encode_preview(overlay))
return (mask,)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.execution_context import emit_value
from backend.data_types import MeasureTable
from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload
@@ -38,6 +39,5 @@ class ValueDisplay:
unit = row.get("unit", "") if isinstance(row.get("unit"), str) else ""
else:
numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None:
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit))
emit_value(_scalar_payload(numeric, unit))
return (numeric,)

View File

@@ -3,6 +3,7 @@ import base64
import io
import numpy as np
from backend.node_registry import register_node
from backend.execution_context import emit_mesh
from backend.data_types import (
COLORMAPS,
DataField,
@@ -211,8 +212,7 @@ class View3D:
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
emit_mesh(mesh_data)
annotation_context = _annotation_context_from_field(color_field, resolved_colormap)
annotation_context["xreal"] = float(field.xreal)

View File

@@ -6,7 +6,11 @@ Routes
GET / → serve frontend/index.html
GET /static/{path} → serve frontend JS/CSS
GET /nodes → JSON dict of all registered node definitions
POST /uploadmultipart file upload to input/
GET /files list files in the current session upload workspace
GET /folder-files → list compatible files in a picked folder
GET /channels → inspect channels for a picked file
POST /upload → multipart file upload to the current session workspace
POST /upload-folder → create a folder in the current session workspace
POST /prompt → submit a workflow; returns {prompt_id}
GET /ws → WebSocket upgrade
@@ -15,7 +19,7 @@ WebSocket message types sent to clients
{"type": "execution_start", "data": {"prompt_id": "..."}}
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
{"type": "table", "data": {"node_id": "...", "rows": [...]} }
{"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}}
{"type": "node_timing", "data": {"node_id": "...", "elapsed_ms": 12.34}}
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
@@ -23,39 +27,43 @@ WebSocket message types sent to clients
"""
from __future__ import annotations
import asyncio
import json
import logging
import sys
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from aiohttp import web, WSMsgType
from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready
from backend.runtime_paths import (
ensure_runtime_dirs,
frontend_dir,
frontend_dist_dir,
input_dir,
output_dir,
project_root,
from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, project_root
from backend.session_runtime import (
PATH_INPUT_TYPES,
SESSION_HEADER,
SESSION_QUERY,
ensure_session_runtime_dirs,
normalize_relative_upload_path,
resolve_client_path,
server_path_to_client_path,
session_input_dir,
session_upload_uri,
validate_session_id,
)
log = logging.getLogger(__name__)
FRONTEND_DIR = frontend_dir()
DIST_DIR = frontend_dist_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# ---------------------------------------------------------------------------
# JSON helper — numpy scalars are not serialisable by default
# ---------------------------------------------------------------------------
class _SafeEncoder(json.JSONEncoder):
def default(self, obj):
import numpy as np
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
@@ -81,45 +89,115 @@ def save_png_bytes(target_path: str, payload: bytes) -> Path:
return path
# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
# Import nodes to trigger registration decorators
def create_app(
loop: asyncio.AbstractEventLoop,
*,
allow_local_filesystem: bool = False,
) -> web.Application:
import backend.nodes # noqa: F401
from backend.node_registry import get_all_node_info
from backend.execution import ExecutionEngine, new_prompt_id
from backend.node_registry import NODE_CLASS_MAPPINGS, get_all_node_info
ensure_runtime_dirs()
engine = ExecutionEngine()
websockets: set[web.WebSocketResponse] = set()
session_engines: dict[str, ExecutionEngine] = {}
session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set)
# ------------------------------------------------------------------
# WebSocket broadcast helpers
# ------------------------------------------------------------------
def _is_link(value) -> bool:
return (
isinstance(value, (list, tuple))
and len(value) == 2
and isinstance(value[0], str)
and isinstance(value[1], int)
)
def broadcast(msg: dict) -> None:
"""Schedule a broadcast to all connected WebSocket clients."""
def require_session_id(request: web.Request) -> str:
raw_session = request.headers.get(SESSION_HEADER) or request.query.get(SESSION_QUERY)
if not raw_session:
if allow_local_filesystem:
raw_session = "desktop-local-session"
else:
raise web.HTTPBadRequest(reason="Missing session id")
try:
session_id = validate_session_id(raw_session)
except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) from exc
ensure_session_runtime_dirs(session_id)
return session_id
def get_session_engine(session_id: str) -> ExecutionEngine:
engine = session_engines.get(session_id)
if engine is None:
engine = ExecutionEngine()
session_engines[session_id] = engine
return engine
def resolve_request_path(session_id: str, raw_value: str) -> Path:
try:
return resolve_client_path(
raw_value,
session_id=session_id,
allow_local_filesystem=allow_local_filesystem,
)
except PermissionError as exc:
raise web.HTTPForbidden(reason=str(exc)) from exc
except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) from exc
def rewrite_prompt_paths(prompt: dict, session_id: str) -> dict:
normalized = deepcopy(prompt)
for node_def in normalized.values():
class_name = node_def.get("class_type")
cls = NODE_CLASS_MAPPINGS.get(class_name)
if cls is None:
continue
input_types = cls.INPUT_TYPES()
specs = {}
specs.update(input_types.get("required", {}))
specs.update(input_types.get("optional", {}))
inputs = node_def.get("inputs", {})
if not isinstance(inputs, dict):
continue
for input_name, raw_value in list(inputs.items()):
if _is_link(raw_value) or not isinstance(raw_value, str):
continue
if not raw_value.strip():
continue
spec = specs.get(input_name)
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
if not isinstance(input_type, str):
continue
if input_type not in PATH_INPUT_TYPES:
continue
inputs[input_name] = str(resolve_request_path(session_id, raw_value))
return normalized
def broadcast(session_id: str, msg: dict) -> None:
payload = _dumps(msg)
for ws in list(websockets):
for ws in list(session_websockets.get(session_id, ())):
if not ws.closed:
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
def on_preview(node_id: str, data_uri: str) -> None:
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
def on_preview(session_id: str, node_id: str, data_uri: str) -> None:
broadcast(session_id, {"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
def on_table(node_id: str, rows: list) -> None:
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}})
def on_table(session_id: str, node_id: str, rows: list) -> None:
broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": rows}})
def on_mesh(node_id: str, mesh_data: dict) -> None:
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
def on_mesh(session_id: str, node_id: str, mesh_data: dict) -> None:
broadcast(session_id, {"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
def on_overlay(node_id: str, overlay_data) -> None:
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
def on_overlay(session_id: str, node_id: str, overlay_data) -> None:
broadcast(session_id, {"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
def on_value(node_id: str, payload) -> None:
def on_value(session_id: str, node_id: str, payload) -> None:
if isinstance(payload, dict):
value = payload.get("value")
unit = payload.get("unit", "")
@@ -130,14 +208,10 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
data = {"node_id": node_id, "value": value}
if isinstance(unit, str) and unit.strip():
data["unit"] = unit.strip()
broadcast({"type": "scalar", "data": data})
broadcast(session_id, {"type": "scalar", "data": data})
def on_warning(node_id: str, message: str) -> None:
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}})
# ------------------------------------------------------------------
# Route handlers
# ------------------------------------------------------------------
def on_warning(session_id: str, node_id: str, message: str) -> None:
broadcast(session_id, {"type": "node_warning", "data": {"node_id": node_id, "message": message}})
async def index(request: web.Request) -> web.Response:
if not getattr(sys, "frozen", False):
@@ -167,88 +241,96 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
)
async def get_nodes(request: web.Request) -> web.Response:
info = get_all_node_info()
return web.Response(
text=_dumps(info),
text=_dumps(get_all_node_info()),
content_type="application/json",
)
async def list_files(request: web.Request) -> web.Response:
"""List files in the input/ directory for the file picker widget."""
session_id = require_session_id(request)
input_path = session_input_dir(session_id)
files = sorted(
f.name for f in INPUT_DIR.iterdir()
if f.is_file() and not f.name.startswith(".")
) if INPUT_DIR.exists() else []
server_path_to_client_path(entry, session_id)
for entry in input_path.iterdir()
if entry.is_file() and not entry.name.startswith(".")
) if input_path.exists() else []
return web.Response(text=_dumps(files), content_type="application/json")
async def browse_dir(request: web.Request) -> web.Response:
"""
Server-side directory browser for local file picking.
GET /browse?dir=/some/path → {parent, dirs[], files[]}
"""
dir_path = request.query.get("dir", str(Path.home()))
p = Path(dir_path).expanduser().resolve()
if not p.is_dir():
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
dirs = []
files = []
try:
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
if entry.name.startswith("."):
continue
if entry.is_dir():
dirs.append(entry.name)
elif entry.is_file():
files.append(entry.name)
except PermissionError:
pass
async def create_upload_folder(request: web.Request) -> web.Response:
session_id = require_session_id(request)
body = await request.json()
relative_path = normalize_relative_upload_path(body.get("path", ""))
target = session_input_dir(session_id) / Path(relative_path.as_posix())
target.mkdir(parents=True, exist_ok=True)
return web.Response(
text=_dumps({
"path": str(p),
"parent": str(p.parent) if p.parent != p else None,
"dirs": dirs,
"files": files,
}),
text=_dumps({"path": session_upload_uri(relative_path)}),
content_type="application/json",
)
async def get_folder_files(request: web.Request) -> web.Response:
folder_path = request.query.get("folder", "")
from backend.nodes.helpers import list_folder_paths
loop = asyncio.get_running_loop()
entries = await loop.run_in_executor(None, list_folder_paths, folder_path)
return web.Response(text=_dumps(entries), content_type="application/json")
session_id = require_session_id(request)
folder_path = request.query.get("folder", "")
if not folder_path:
return web.Response(text=_dumps([]), content_type="application/json")
resolved_path = resolve_request_path(session_id, folder_path)
running_loop = asyncio.get_running_loop()
entries = await running_loop.run_in_executor(None, list_folder_paths, str(resolved_path))
payload = []
for entry in entries:
mapped = dict(entry)
if "path" in mapped:
mapped["path"] = server_path_to_client_path(mapped["path"], session_id)
payload.append(mapped)
return web.Response(text=_dumps(payload), content_type="application/json")
async def upload_file(request: web.Request) -> web.Response:
session_id = require_session_id(request)
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "file":
relative_path = None
filename = ""
file_bytes = None
while True:
field = await reader.next()
if field is None:
break
if field.name == "relative_path":
relative_path = await field.text()
continue
if field.name == "file":
filename = Path(field.filename or "upload.bin").name
chunks = []
while True:
chunk = await field.read_chunk(65536)
if not chunk:
break
chunks.append(chunk)
file_bytes = b"".join(chunks)
if file_bytes is None:
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
filename = Path(field.filename).name # strip any path traversal
dest = INPUT_DIR / filename
with open(dest, "wb") as f:
while True:
chunk = await field.read_chunk(65536)
if not chunk:
break
f.write(chunk)
relative = normalize_relative_upload_path(relative_path or filename)
dest = session_input_dir(session_id) / Path(relative.as_posix())
dest.parent.mkdir(parents=True, exist_ok=True)
dest.write_bytes(file_bytes)
return web.Response(text=_dumps({"filename": filename}), content_type="application/json")
return web.Response(
text=_dumps({"filename": filename, "path": session_upload_uri(relative)}),
content_type="application/json",
)
async def download_file(request: web.Request) -> web.Response:
"""Accept a blob POST and return it with Content-Disposition: attachment."""
body = await request.read()
filename = request.query.get("filename", "workflow.png")
return web.Response(
body=body,
content_type="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
},
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
async def save_workflow_png(request: web.Request) -> web.Response:
@@ -266,34 +348,39 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
)
async def get_channels(request: web.Request) -> web.Response:
"""Return available channels for a given file path."""
from backend.nodes.helpers import list_channels
session_id = require_session_id(request)
filepath = request.query.get("file", "")
if not filepath:
return web.Response(
text=_dumps([{"name": "field", "type": "DATA_FIELD"}]),
content_type="application/json",
)
channels = await loop.run_in_executor(None, list_channels, filepath)
resolved_path = resolve_request_path(session_id, filepath)
channels = await loop.run_in_executor(None, list_channels, str(resolved_path))
return web.Response(text=_dumps(channels), content_type="application/json")
async def submit_prompt(request: web.Request) -> web.Response:
session_id = require_session_id(request)
body = await request.json()
prompt = body.get("prompt")
if not isinstance(prompt, dict) or not prompt:
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
normalized_prompt = rewrite_prompt_paths(prompt, session_id)
prompt_id = new_prompt_id()
engine = get_session_engine(session_id)
# Run execution in a thread pool so scipy doesn't block the event loop
async def run():
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}})
broadcast(session_id, {"type": "execution_start", "data": {"prompt_id": prompt_id}})
def on_start(node_id: str) -> None:
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
broadcast(session_id, {"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
def on_done(node_id: str, elapsed_ms: float) -> None:
broadcast({
broadcast(session_id, {
"type": "node_timing",
"data": {"node_id": node_id, "elapsed_ms": elapsed_ms},
})
@@ -302,21 +389,21 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
await loop.run_in_executor(
None,
lambda: engine.execute(
prompt,
normalized_prompt,
on_node_start=on_start,
on_node_done=on_done,
on_preview=on_preview,
on_table=on_table,
on_mesh=on_mesh,
on_overlay=on_overlay,
on_value=on_value,
on_warning=on_warning,
on_preview=lambda node_id, payload: on_preview(session_id, node_id, payload),
on_table=lambda node_id, rows: on_table(session_id, node_id, rows),
on_mesh=lambda node_id, mesh_data: on_mesh(session_id, node_id, mesh_data),
on_overlay=lambda node_id, overlay_data: on_overlay(session_id, node_id, overlay_data),
on_value=lambda node_id, payload: on_value(session_id, node_id, payload),
on_warning=lambda node_id, message: on_warning(session_id, node_id, message),
),
)
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
broadcast(session_id, {"type": "execution_complete", "data": {"prompt_id": prompt_id}})
except Exception as exc:
log.exception("Execution error")
broadcast({
broadcast(session_id, {
"type": "execution_error",
"data": {"node_id": "", "message": str(exc)},
})
@@ -328,32 +415,40 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
)
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
session_id = require_session_id(request)
ws = web.WebSocketResponse()
await ws.prepare(request)
websockets.add(ws)
log.info("WebSocket client connected (%d total)", len(websockets))
session_websockets[session_id].add(ws)
log.info(
"WebSocket client connected for session %s (%d total in session)",
session_id,
len(session_websockets[session_id]),
)
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
pass # clients don't need to send anything currently
pass
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
break
finally:
websockets.discard(ws)
log.info("WebSocket client disconnected (%d total)", len(websockets))
session_websockets[session_id].discard(ws)
if not session_websockets[session_id]:
session_websockets.pop(session_id, None)
log.info(
"WebSocket client disconnected for session %s (%d remaining in session)",
session_id,
len(session_websockets.get(session_id, ())),
)
return ws
# ------------------------------------------------------------------
# App assembly
# ------------------------------------------------------------------
app = web.Application()
app["allow_local_filesystem"] = allow_local_filesystem
app.router.add_get("/", index)
app.router.add_get("/nodes", get_nodes)
app.router.add_get("/files", list_files)
app.router.add_get("/browse", browse_dir)
app.router.add_get("/folder-files", get_folder_files)
app.router.add_post("/upload-folder", create_upload_folder)
app.router.add_post("/upload", upload_file)
app.router.add_post("/download", download_file)
app.router.add_post("/save-workflow-png", save_workflow_png)
@@ -361,26 +456,24 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler)
# Serve frontend static files (Vite build or raw)
if (DIST_DIR / "assets").exists():
app.router.add_static("/assets", DIST_DIR / "assets")
if FRONTEND_DIR.exists():
app.router.add_static("/static", FRONTEND_DIR)
# CORS — allow any origin (local dev only)
async def _cors_middleware(app_, handler):
async def middleware(request):
if request.method == "OPTIONS":
return web.Response(headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Allow-Headers": f"Content-Type, {SESSION_HEADER}",
})
response = await handler(request)
response.headers["Access-Control-Allow-Origin"] = "*"
return response
return middleware
app.middlewares.append(_cors_middleware)
return app

132
backend/session_runtime.py Normal file
View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import re
from pathlib import Path, PurePosixPath
from backend.runtime_paths import app_data_dir, demo_dir
SESSION_HEADER = "X-Argonode-Session"
SESSION_QUERY = "session"
SESSION_URI_PREFIX = "session://uploads/"
PATH_INPUT_TYPES = {"FILE_PICKER", "FILE_PATH", "FOLDER_PICKER", "DIRECTORY"}
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]{7,127}$")
def validate_session_id(session_id: str) -> str:
text = str(session_id or "").strip()
if not _SESSION_ID_RE.fullmatch(text):
raise ValueError("Invalid session id")
return text
def session_root_dir(session_id: str) -> Path:
validated = validate_session_id(session_id)
return app_data_dir() / "sessions" / validated
def session_input_dir(session_id: str) -> Path:
return session_root_dir(session_id) / "input"
def session_output_dir(session_id: str) -> Path:
return session_root_dir(session_id) / "output"
def ensure_session_runtime_dirs(session_id: str) -> tuple[Path, Path]:
input_path = session_input_dir(session_id)
output_path = session_output_dir(session_id)
input_path.mkdir(parents=True, exist_ok=True)
output_path.mkdir(parents=True, exist_ok=True)
return input_path, output_path
def normalize_relative_upload_path(raw_path: str) -> PurePosixPath:
raw_text = str(raw_path or "").replace("\\", "/").strip()
if not raw_text:
raise ValueError("Missing upload path")
path = PurePosixPath(raw_text)
if path.is_absolute():
raise ValueError("Upload paths must be relative")
parts: list[str] = []
for part in path.parts:
if part in ("", "."):
continue
if part == "..":
raise ValueError("Upload paths cannot escape the session directory")
if "\x00" in part:
raise ValueError("Upload paths cannot contain NUL bytes")
parts.append(part)
if not parts:
raise ValueError("Upload paths must contain at least one path segment")
return PurePosixPath(*parts)
def session_upload_uri(relative_path: str | PurePosixPath) -> str:
normalized = normalize_relative_upload_path(str(relative_path))
return f"{SESSION_URI_PREFIX}{normalized.as_posix()}"
def session_uri_to_relative_path(value: str) -> PurePosixPath | None:
text = str(value or "").strip()
if not text.startswith(SESSION_URI_PREFIX):
return None
return normalize_relative_upload_path(text[len(SESSION_URI_PREFIX):])
def is_path_within(root: Path, candidate: Path) -> bool:
try:
candidate.resolve(strict=False).relative_to(root.resolve(strict=False))
return True
except ValueError:
return False
def server_path_to_client_path(path_value: str | Path, session_id: str) -> str:
path = Path(path_value).expanduser().resolve(strict=False)
session_input = session_input_dir(session_id).resolve(strict=False)
if is_path_within(session_input, path):
rel = path.relative_to(session_input)
return session_upload_uri(rel.as_posix())
return str(path)
def resolve_client_path(
value: str,
*,
session_id: str,
allow_local_filesystem: bool,
) -> Path:
text = str(value or "").strip()
if not text:
return Path("")
rel = session_uri_to_relative_path(text)
if rel is not None:
return (session_input_dir(session_id) / Path(rel.as_posix())).resolve(strict=False)
candidate = Path(text).expanduser()
if not candidate.is_absolute():
demo_candidate = (demo_dir() / text).expanduser().resolve(strict=False)
if demo_candidate.exists():
return demo_candidate
if not candidate.is_absolute():
if allow_local_filesystem:
return candidate.resolve(strict=False)
raise PermissionError("Browser sessions may only use files uploaded through Browse.")
resolved = candidate.resolve(strict=False)
if allow_local_filesystem:
return resolved
session_root = session_root_dir(session_id).resolve(strict=False)
if is_path_within(session_root, resolved):
return resolved
raise PermissionError("Path is outside the current session workspace.")