rework web server so multiple clients can be server at a time

This commit is contained in:
matei jordache
2026-03-27 16:18:22 -07:00
parent 1eda4030d1
commit 558046e7aa
33 changed files with 1042 additions and 551 deletions

View File

@@ -32,6 +32,7 @@ from time import perf_counter
from typing import Any, Callable
from backend.node_registry import NODE_CLASS_MAPPINGS
from backend.execution_context import active_node, execution_callbacks
def _is_link(value: Any) -> bool:
@@ -85,63 +86,66 @@ class ExecutionEngine:
node_outputs: dict[str, tuple] = {}
node_output_signatures: dict[str, tuple[str, ...]] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
with execution_callbacks(
preview=on_preview,
table=on_table,
mesh=on_mesh,
overlay=on_overlay,
value=on_value,
warning=on_warning,
):
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
if class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Unknown node type: '{class_name}'")
if class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Unknown node type: '{class_name}'")
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
instance = cls()
func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
with active_node(node_id):
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
instance = cls()
func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if on_node_done:
on_node_done(node_id, elapsed_ms)
if on_node_done:
on_node_done(node_id, elapsed_ms)
return node_outputs
@@ -421,88 +425,6 @@ class ExecutionEngine:
return deepcopy(value)
return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
on_value: Callable | None = None,
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.annotations import Annotations
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.save import Save
from backend.nodes.save_image import SaveImage
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
MaskMorphology._broadcast_fn = on_preview
MaskInvert._broadcast_fn = on_preview
MaskCombine._broadcast_fn = on_preview
DrawMask._broadcast_overlay_fn = on_overlay
View3D._broadcast_mesh_fn = on_mesh
Annotations._broadcast_warning_fn = on_warning
PrintTable._broadcast_table_fn = on_table
ValueDisplay._broadcast_value_fn = on_value
Stats._broadcast_value_fn = on_value
Histogram._broadcast_overlay_fn = on_overlay
CrossSection._broadcast_overlay_fn = on_overlay
Cursors._broadcast_overlay_fn = on_overlay
CropResizeField._broadcast_overlay_fn = on_overlay
RotateField._broadcast_warning_fn = on_warning
Markup._broadcast_overlay_fn = on_overlay
Image._broadcast_warning_fn = on_warning
ImageDemo._broadcast_warning_fn = on_warning
Save._broadcast_warning_fn = on_warning
SaveImage._broadcast_warning_fn = on_warning
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.annotations import Annotations
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
from backend.nodes.save import Save
from backend.nodes.save_image import SaveImage
if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
Image, ImageDemo, Save, SaveImage):
cls._current_node_id = node_id
def _auto_preview(
self,
cls: type,