add folder, file nodes and major usability improvements
This commit is contained in:
@@ -23,6 +23,7 @@ The engine:
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from math import isfinite
|
||||
from time import perf_counter
|
||||
from typing import Any, Callable
|
||||
|
||||
@@ -87,7 +88,8 @@ class ExecutionEngine:
|
||||
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs)
|
||||
input_types = cls.INPUT_TYPES()
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
|
||||
|
||||
# Let display nodes know their node_id so they can tag WS messages
|
||||
self._set_node_id_on_display(cls, node_id)
|
||||
@@ -110,7 +112,7 @@ class ExecutionEngine:
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or table-like output so every node shows its result.
|
||||
if on_preview or on_table:
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table)
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
|
||||
|
||||
if on_node_done:
|
||||
on_node_done(node_id, elapsed_ms)
|
||||
@@ -154,8 +156,14 @@ class ExecutionEngine:
|
||||
self,
|
||||
raw_inputs: dict[str, Any],
|
||||
node_outputs: dict[str, tuple],
|
||||
input_types: dict[str, dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Replace [src_id, slot] links with actual output values."""
|
||||
specs = {}
|
||||
if input_types:
|
||||
specs.update(input_types.get("required", {}))
|
||||
specs.update(input_types.get("optional", {}))
|
||||
|
||||
resolved = {}
|
||||
for key, value in raw_inputs.items():
|
||||
if _is_link(value):
|
||||
@@ -170,11 +178,36 @@ class ExecutionEngine:
|
||||
f"Node '{src_id}' only has {len(outputs)} outputs, "
|
||||
f"but slot {slot} was requested."
|
||||
)
|
||||
resolved[key] = outputs[slot]
|
||||
resolved_value = outputs[slot]
|
||||
else:
|
||||
resolved[key] = value
|
||||
resolved_value = value
|
||||
|
||||
resolved[key] = self._coerce_input_value(resolved_value, specs.get(key))
|
||||
return resolved
|
||||
|
||||
def _coerce_input_value(self, value: Any, spec: Any) -> Any:
|
||||
if spec is None:
|
||||
return value
|
||||
|
||||
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
|
||||
if isinstance(input_type, list):
|
||||
return value
|
||||
|
||||
if input_type == "INT":
|
||||
numeric = float(value)
|
||||
if not isfinite(numeric):
|
||||
raise ValueError(f"Expected a finite numeric value for INT input, got {value!r}")
|
||||
rounded = int(abs(numeric) + 0.5)
|
||||
return rounded if numeric >= 0 else -rounded
|
||||
|
||||
if input_type == "FLOAT":
|
||||
numeric = float(value)
|
||||
if not isfinite(numeric):
|
||||
raise ValueError(f"Expected a finite numeric value for FLOAT input, got {value!r}")
|
||||
return numeric
|
||||
|
||||
return value
|
||||
|
||||
def _inject_display_callbacks(
|
||||
self,
|
||||
on_preview: Callable | None,
|
||||
@@ -185,11 +218,11 @@ class ExecutionEngine:
|
||||
on_warning: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.modify import CropResizeField, RotateField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
|
||||
from backend.nodes.io import SaveImage, LoadFile
|
||||
from backend.nodes.io import SaveImage, LoadFile, LoadDemo
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
ThresholdMask._broadcast_fn = on_preview
|
||||
@@ -206,19 +239,22 @@ class ExecutionEngine:
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
LineCursors._broadcast_overlay_fn = on_overlay
|
||||
CropResizeField._broadcast_overlay_fn = on_overlay
|
||||
RotateField._broadcast_warning_fn = on_warning
|
||||
Markup._broadcast_overlay_fn = on_overlay
|
||||
LoadFile._broadcast_warning_fn = on_warning
|
||||
LoadDemo._broadcast_warning_fn = on_warning
|
||||
SaveImage._broadcast_warning_fn = on_warning
|
||||
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.modify import CropResizeField, RotateField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
|
||||
from backend.nodes.io import LoadFile, SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
|
||||
from backend.nodes.io import LoadFile, LoadDemo, SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
|
||||
LoadFile, SaveImage):
|
||||
LoadFile, LoadDemo, SaveImage):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
@@ -228,6 +264,7 @@ class ExecutionEngine:
|
||||
result: tuple,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
After every node executes, inspect its outputs and broadcast
|
||||
@@ -236,12 +273,19 @@ class ExecutionEngine:
|
||||
"""
|
||||
import numpy as np
|
||||
from backend.data_types import (
|
||||
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
|
||||
DataField, image_to_uint8, encode_preview, render_datafield_preview,
|
||||
)
|
||||
from backend.nodes.io import LoadFile, LoadDemo
|
||||
|
||||
if getattr(cls, "_CUSTOM_PREVIEW", False):
|
||||
return
|
||||
|
||||
if cls in (LoadFile, LoadDemo) and on_preview:
|
||||
preview = self._render_load_node_preview(result, inputs or {})
|
||||
if preview:
|
||||
on_preview(node_id, preview)
|
||||
return
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
for slot, type_name in enumerate(return_types):
|
||||
@@ -250,7 +294,7 @@ class ExecutionEngine:
|
||||
value = result[slot]
|
||||
|
||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||
arr = datafield_to_uint8(value, value.colormap)
|
||||
arr = render_datafield_preview(value, value.colormap)
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return # one preview per node is enough
|
||||
|
||||
@@ -269,6 +313,39 @@ class ExecutionEngine:
|
||||
on_table(node_id, value)
|
||||
return
|
||||
|
||||
def _render_load_node_preview(
|
||||
self,
|
||||
result: tuple,
|
||||
inputs: dict[str, Any],
|
||||
) -> dict | None:
|
||||
from backend.data_types import DataField, encode_preview, render_datafield_preview
|
||||
from backend.nodes.io import list_channels
|
||||
|
||||
fields = [value for value in result if isinstance(value, DataField)]
|
||||
if not fields:
|
||||
return None
|
||||
|
||||
selected_path = str(inputs.get("path") or inputs.get("filename") or inputs.get("name") or "").strip()
|
||||
channel_names: list[str] = []
|
||||
if selected_path:
|
||||
try:
|
||||
channel_names = [str(entry.get("name", "")).strip() or "field" for entry in list_channels(selected_path)]
|
||||
except Exception:
|
||||
channel_names = []
|
||||
|
||||
layers = []
|
||||
for index, field in enumerate(fields):
|
||||
arr = render_datafield_preview(field, field.colormap)
|
||||
layers.append({
|
||||
"name": channel_names[index] if index < len(channel_names) else f"layer {index + 1}",
|
||||
"image": encode_preview(arr),
|
||||
})
|
||||
|
||||
return {
|
||||
"kind": "layer_gallery",
|
||||
"layers": layers,
|
||||
}
|
||||
|
||||
|
||||
def _render_line_preview(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user