add folder, file nodes and major usability improvements

This commit is contained in:
2026-03-25 22:18:25 -07:00
parent 61b68c142b
commit 7f3dfa8fdf
22 changed files with 3881 additions and 299 deletions

View File

@@ -23,6 +23,7 @@ The engine:
from __future__ import annotations
import uuid
from collections import defaultdict, deque
from math import isfinite
from time import perf_counter
from typing import Any, Callable
@@ -87,7 +88,8 @@ class ExecutionEngine:
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
inputs = self._resolve_inputs(raw_inputs, node_outputs)
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
@@ -110,7 +112,7 @@ class ExecutionEngine:
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table)
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if on_node_done:
on_node_done(node_id, elapsed_ms)
@@ -154,8 +156,14 @@ class ExecutionEngine:
self,
raw_inputs: dict[str, Any],
node_outputs: dict[str, tuple],
input_types: dict[str, dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Replace [src_id, slot] links with actual output values."""
specs = {}
if input_types:
specs.update(input_types.get("required", {}))
specs.update(input_types.get("optional", {}))
resolved = {}
for key, value in raw_inputs.items():
if _is_link(value):
@@ -170,11 +178,36 @@ class ExecutionEngine:
f"Node '{src_id}' only has {len(outputs)} outputs, "
f"but slot {slot} was requested."
)
resolved[key] = outputs[slot]
resolved_value = outputs[slot]
else:
resolved[key] = value
resolved_value = value
resolved[key] = self._coerce_input_value(resolved_value, specs.get(key))
return resolved
def _coerce_input_value(self, value: Any, spec: Any) -> Any:
if spec is None:
return value
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
if isinstance(input_type, list):
return value
if input_type == "INT":
numeric = float(value)
if not isfinite(numeric):
raise ValueError(f"Expected a finite numeric value for INT input, got {value!r}")
rounded = int(abs(numeric) + 0.5)
return rounded if numeric >= 0 else -rounded
if input_type == "FLOAT":
numeric = float(value)
if not isfinite(numeric):
raise ValueError(f"Expected a finite numeric value for FLOAT input, got {value!r}")
return numeric
return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,
@@ -185,11 +218,11 @@ class ExecutionEngine:
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.modify import CropResizeField, RotateField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import SaveImage, LoadFile
from backend.nodes.io import SaveImage, LoadFile, LoadDemo
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
@@ -206,19 +239,22 @@ class ExecutionEngine:
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
CropResizeField._broadcast_overlay_fn = on_overlay
RotateField._broadcast_warning_fn = on_warning
Markup._broadcast_overlay_fn = on_overlay
LoadFile._broadcast_warning_fn = on_warning
LoadDemo._broadcast_warning_fn = on_warning
SaveImage._broadcast_warning_fn = on_warning
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.modify import CropResizeField, RotateField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
from backend.nodes.io import LoadFile, LoadDemo, SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
LoadFile, SaveImage):
LoadFile, LoadDemo, SaveImage):
cls._current_node_id = node_id
def _auto_preview(
@@ -228,6 +264,7 @@ class ExecutionEngine:
result: tuple,
on_preview: Callable | None,
on_table: Callable | None,
inputs: dict[str, Any] | None = None,
) -> None:
"""
After every node executes, inspect its outputs and broadcast
@@ -236,12 +273,19 @@ class ExecutionEngine:
"""
import numpy as np
from backend.data_types import (
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
DataField, image_to_uint8, encode_preview, render_datafield_preview,
)
from backend.nodes.io import LoadFile, LoadDemo
if getattr(cls, "_CUSTOM_PREVIEW", False):
return
if cls in (LoadFile, LoadDemo) and on_preview:
preview = self._render_load_node_preview(result, inputs or {})
if preview:
on_preview(node_id, preview)
return
return_types = getattr(cls, "RETURN_TYPES", ())
for slot, type_name in enumerate(return_types):
@@ -250,7 +294,7 @@ class ExecutionEngine:
value = result[slot]
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
arr = datafield_to_uint8(value, value.colormap)
arr = render_datafield_preview(value, value.colormap)
on_preview(node_id, encode_preview(arr))
return # one preview per node is enough
@@ -269,6 +313,39 @@ class ExecutionEngine:
on_table(node_id, value)
return
def _render_load_node_preview(
self,
result: tuple,
inputs: dict[str, Any],
) -> dict | None:
from backend.data_types import DataField, encode_preview, render_datafield_preview
from backend.nodes.io import list_channels
fields = [value for value in result if isinstance(value, DataField)]
if not fields:
return None
selected_path = str(inputs.get("path") or inputs.get("filename") or inputs.get("name") or "").strip()
channel_names: list[str] = []
if selected_path:
try:
channel_names = [str(entry.get("name", "")).strip() or "field" for entry in list_channels(selected_path)]
except Exception:
channel_names = []
layers = []
for index, field in enumerate(fields):
arr = render_datafield_preview(field, field.colormap)
layers.append({
"name": channel_names[index] if index < len(channel_names) else f"layer {index + 1}",
"image": encode_preview(arr),
})
return {
"kind": "layer_gallery",
"layers": layers,
}
def _render_line_preview(
self,