multichannel support + colormap inherit

This commit is contained in:
2026-03-24 21:01:58 -07:00
parent 53e2fc7746
commit a60b0c15ca
12 changed files with 889 additions and 220 deletions

View File

@@ -50,6 +50,7 @@ class ExecutionEngine:
on_table: Callable[[str, list], None] | None = None,
on_mesh: Callable[[str, dict], None] | None = None,
on_overlay: Callable[[str, str], None] | None = None,
on_warning: Callable[[str, str], None] | None = None,
) -> dict[str, tuple]:
"""
Execute the workflow described by `prompt`.
@@ -62,6 +63,7 @@ class ExecutionEngine:
on_preview : called with (node_id, data_uri) when a display node runs
on_table : called with (node_id, table_list) when PrintTable runs
on_overlay : called with (node_id, data_uri) for interactive overlays
on_warning : called with (node_id, message) for node warnings
Returns
-------
@@ -71,7 +73,7 @@ class ExecutionEngine:
node_outputs: dict[str, tuple] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning)
for node_id in order:
node_def = prompt[node_id]
@@ -174,12 +176,13 @@ class ExecutionEngine:
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import SaveImage
from backend.nodes.io import SaveImage, LoadFile
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
@@ -190,6 +193,7 @@ class ExecutionEngine:
PrintTable._broadcast_table_fn = on_table
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
LoadFile._broadcast_warning_fn = on_warning
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
@@ -199,8 +203,9 @@ class ExecutionEngine:
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import LoadFile
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine):
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile):
cls._current_node_id = node_id
def _auto_preview(
@@ -232,7 +237,7 @@ class ExecutionEngine:
value = result[slot]
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
arr = datafield_to_uint8(value, "viridis")
arr = datafield_to_uint8(value, value.colormap)
on_preview(node_id, encode_preview(arr))
return # one preview per node is enough