rework web server so multiple clients can be server at a time
This commit is contained in:
@@ -32,6 +32,7 @@ from time import perf_counter
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.node_registry import NODE_CLASS_MAPPINGS
|
||||
from backend.execution_context import active_node, execution_callbacks
|
||||
|
||||
|
||||
def _is_link(value: Any) -> bool:
|
||||
@@ -85,63 +86,66 @@ class ExecutionEngine:
|
||||
node_outputs: dict[str, tuple] = {}
|
||||
node_output_signatures: dict[str, tuple[str, ...]] = {}
|
||||
|
||||
# Inject display callbacks before execution
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
|
||||
with execution_callbacks(
|
||||
preview=on_preview,
|
||||
table=on_table,
|
||||
mesh=on_mesh,
|
||||
overlay=on_overlay,
|
||||
value=on_value,
|
||||
warning=on_warning,
|
||||
):
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
class_name = node_def["class_type"]
|
||||
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
class_name = node_def["class_type"]
|
||||
if class_name not in NODE_CLASS_MAPPINGS:
|
||||
raise ValueError(f"Unknown node type: '{class_name}'")
|
||||
|
||||
if class_name not in NODE_CLASS_MAPPINGS:
|
||||
raise ValueError(f"Unknown node type: '{class_name}'")
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
input_types = cls.INPUT_TYPES()
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
|
||||
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
|
||||
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
input_types = cls.INPUT_TYPES()
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
|
||||
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
|
||||
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
|
||||
if cache_entry is not None:
|
||||
result = self._clone_cached_outputs(cache_entry["outputs"])
|
||||
elapsed_ms = 0.0
|
||||
else:
|
||||
if on_node_start:
|
||||
on_node_start(node_id)
|
||||
|
||||
# Let display nodes know their node_id so they can tag WS messages
|
||||
self._set_node_id_on_display(cls, node_id)
|
||||
instance = cls()
|
||||
func = getattr(instance, cls.FUNCTION)
|
||||
start_time = perf_counter()
|
||||
with active_node(node_id):
|
||||
result = func(**inputs)
|
||||
elapsed_ms = (perf_counter() - start_time) * 1000.0
|
||||
|
||||
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
|
||||
if cache_entry is not None:
|
||||
result = self._clone_cached_outputs(cache_entry["outputs"])
|
||||
elapsed_ms = 0.0
|
||||
else:
|
||||
if on_node_start:
|
||||
on_node_start(node_id)
|
||||
# Nodes must return a tuple; coerce single values just in case
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
|
||||
instance = cls()
|
||||
func = getattr(instance, cls.FUNCTION)
|
||||
start_time = perf_counter()
|
||||
result = func(**inputs)
|
||||
elapsed_ms = (perf_counter() - start_time) * 1000.0
|
||||
node_outputs[node_id] = result
|
||||
output_signatures = tuple(self._fingerprint_value(value) for value in result)
|
||||
node_output_signatures[node_id] = output_signatures
|
||||
|
||||
# Nodes must return a tuple; coerce single values just in case
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
if cache_entry is None and self._node_cacheable(cls):
|
||||
self._store_cache_entry(
|
||||
node_id=node_id,
|
||||
class_name=class_name,
|
||||
input_signature=input_signature,
|
||||
output_signatures=output_signatures,
|
||||
outputs=self._clone_cached_outputs(result),
|
||||
)
|
||||
|
||||
node_outputs[node_id] = result
|
||||
output_signatures = tuple(self._fingerprint_value(value) for value in result)
|
||||
node_output_signatures[node_id] = output_signatures
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or table-like output so every node shows its result.
|
||||
if on_preview or on_table:
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
|
||||
|
||||
if cache_entry is None and self._node_cacheable(cls):
|
||||
self._store_cache_entry(
|
||||
node_id=node_id,
|
||||
class_name=class_name,
|
||||
input_signature=input_signature,
|
||||
output_signatures=output_signatures,
|
||||
outputs=self._clone_cached_outputs(result),
|
||||
)
|
||||
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or table-like output so every node shows its result.
|
||||
if on_preview or on_table:
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
|
||||
|
||||
if on_node_done:
|
||||
on_node_done(node_id, elapsed_ms)
|
||||
if on_node_done:
|
||||
on_node_done(node_id, elapsed_ms)
|
||||
|
||||
return node_outputs
|
||||
|
||||
@@ -421,88 +425,6 @@ class ExecutionEngine:
|
||||
return deepcopy(value)
|
||||
return value
|
||||
|
||||
def _inject_display_callbacks(
|
||||
self,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
on_mesh: Callable | None = None,
|
||||
on_overlay: Callable | None = None,
|
||||
on_value: Callable | None = None,
|
||||
on_warning: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.preview_image import PreviewImage
|
||||
from backend.nodes.print_table import PrintTable
|
||||
from backend.nodes.view_3d import View3D
|
||||
from backend.nodes.annotations import Annotations
|
||||
from backend.nodes.value_display import ValueDisplay
|
||||
from backend.nodes.markup import Markup
|
||||
from backend.nodes.cross_section import CrossSection
|
||||
from backend.nodes.cursors import Cursors
|
||||
from backend.nodes.stats import Stats
|
||||
from backend.nodes.histogram import Histogram
|
||||
from backend.nodes.crop_resize_field import CropResizeField
|
||||
from backend.nodes.rotate_field import RotateField
|
||||
from backend.nodes.threshold_mask import ThresholdMask
|
||||
from backend.nodes.mask_morphology import MaskMorphology
|
||||
from backend.nodes.mask_invert import MaskInvert
|
||||
from backend.nodes.mask_combine import MaskCombine
|
||||
from backend.nodes.draw_mask import DrawMask
|
||||
from backend.nodes.save import Save
|
||||
from backend.nodes.save_image import SaveImage
|
||||
from backend.nodes.image import Image
|
||||
from backend.nodes.image_demo import ImageDemo
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
ThresholdMask._broadcast_fn = on_preview
|
||||
MaskMorphology._broadcast_fn = on_preview
|
||||
MaskInvert._broadcast_fn = on_preview
|
||||
MaskCombine._broadcast_fn = on_preview
|
||||
DrawMask._broadcast_overlay_fn = on_overlay
|
||||
View3D._broadcast_mesh_fn = on_mesh
|
||||
Annotations._broadcast_warning_fn = on_warning
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
ValueDisplay._broadcast_value_fn = on_value
|
||||
Stats._broadcast_value_fn = on_value
|
||||
Histogram._broadcast_overlay_fn = on_overlay
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
Cursors._broadcast_overlay_fn = on_overlay
|
||||
CropResizeField._broadcast_overlay_fn = on_overlay
|
||||
RotateField._broadcast_warning_fn = on_warning
|
||||
Markup._broadcast_overlay_fn = on_overlay
|
||||
Image._broadcast_warning_fn = on_warning
|
||||
ImageDemo._broadcast_warning_fn = on_warning
|
||||
Save._broadcast_warning_fn = on_warning
|
||||
SaveImage._broadcast_warning_fn = on_warning
|
||||
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.preview_image import PreviewImage
|
||||
from backend.nodes.print_table import PrintTable
|
||||
from backend.nodes.view_3d import View3D
|
||||
from backend.nodes.annotations import Annotations
|
||||
from backend.nodes.value_display import ValueDisplay
|
||||
from backend.nodes.markup import Markup
|
||||
from backend.nodes.cross_section import CrossSection
|
||||
from backend.nodes.cursors import Cursors
|
||||
from backend.nodes.stats import Stats
|
||||
from backend.nodes.histogram import Histogram
|
||||
from backend.nodes.crop_resize_field import CropResizeField
|
||||
from backend.nodes.rotate_field import RotateField
|
||||
from backend.nodes.threshold_mask import ThresholdMask
|
||||
from backend.nodes.mask_morphology import MaskMorphology
|
||||
from backend.nodes.mask_invert import MaskInvert
|
||||
from backend.nodes.mask_combine import MaskCombine
|
||||
from backend.nodes.draw_mask import DrawMask
|
||||
from backend.nodes.image import Image
|
||||
from backend.nodes.image_demo import ImageDemo
|
||||
from backend.nodes.save import Save
|
||||
from backend.nodes.save_image import SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
|
||||
Image, ImageDemo, Save, SaveImage):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
self,
|
||||
cls: type,
|
||||
|
||||
Reference in New Issue
Block a user