add folder, file nodes and major usability improvements

This commit is contained in:
2026-03-25 22:18:25 -07:00
parent 61b68c142b
commit 7f3dfa8fdf
22 changed files with 3881 additions and 299 deletions

View File

@@ -12,13 +12,23 @@ DataField mirrors Gwyddion's GwyDataField structure:
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any
import numpy as np
COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain",
"cividis", "magma", "copper", "afmhot")
DEFAULT_CUSTOM_COLORMAP_STOPS = (
{"position": 0.0, "color": "#440154"},
{"position": 1.0, "color": "#fde725"},
)
SYSTEM_DEFAULT_FONT = "System Default"
CUSTOM_FILE_FONT = "Custom File"
PREVIEW_MARKUP_REFERENCE_DIM = 512
class RecordTable(list):
@@ -28,6 +38,147 @@ class RecordTable(list):
class MeasureTable(list):
"""Named scalar measurements, typically rows of quantity/value/unit."""
def _normalize_hex_color(color: Any, default: str = "#000000") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _hex_to_rgb01(color: str) -> tuple[float, float, float]:
normalized = _normalize_hex_color(color)
return (
int(normalized[1:3], 16) / 255.0,
int(normalized[3:5], 16) / 255.0,
int(normalized[5:7], 16) / 255.0,
)
def _normalize_custom_colormap_stops(raw_stops: Any) -> list[dict[str, float | str]] | None:
if not isinstance(raw_stops, list):
return None
parsed = []
for stop in raw_stops:
if not isinstance(stop, dict):
continue
try:
position = float(stop.get("position"))
except (TypeError, ValueError):
continue
if not np.isfinite(position):
continue
parsed.append({
"position": float(np.clip(position, 0.0, 1.0)),
"color": _normalize_hex_color(stop.get("color"), "#000000"),
})
if len(parsed) < 2:
return None
parsed.sort(key=lambda stop: stop["position"])
unique: list[dict[str, float | str]] = []
for stop in parsed:
if unique and abs(float(stop["position"]) - float(unique[-1]["position"])) < 1e-9:
unique[-1] = stop
else:
unique.append(stop)
if len(unique) < 2:
return None
unique[0] = {"position": 0.0, "color": unique[0]["color"]}
unique[-1] = {"position": 1.0, "color": unique[-1]["color"]}
return unique
def normalize_colormap_spec(colormap: Any, fallback: Any = "viridis") -> str | dict[str, Any]:
if isinstance(colormap, str):
if colormap in COLORMAPS:
return colormap
elif isinstance(colormap, dict):
mode = str(colormap.get("mode", "")).strip().lower()
preset = colormap.get("preset")
if isinstance(preset, str) and preset in COLORMAPS:
if mode == "preset" or "stops" not in colormap:
return {"mode": "preset", "preset": preset}
stops = _normalize_custom_colormap_stops(colormap.get("stops"))
if stops is not None:
return {"mode": "custom", "stops": stops}
if fallback is colormap:
return "viridis"
return normalize_colormap_spec(fallback, fallback="viridis")
def resolve_colormap_input(
selection: Any = "auto",
*,
colormap_input: Any = None,
inherited: Any = None,
default: Any = "gray",
) -> str | dict[str, Any]:
if colormap_input is not None:
return normalize_colormap_spec(colormap_input, fallback=inherited if inherited is not None else default)
if isinstance(selection, str) and selection != "auto":
return normalize_colormap_spec(selection, fallback=inherited if inherited is not None else default)
if inherited is not None:
return normalize_colormap_spec(inherited, fallback=default)
return normalize_colormap_spec(default, fallback="gray")
def normalize_font_spec(font: Any) -> dict[str, str] | None:
if isinstance(font, str):
family = font.strip()
return {"family": family, "path": ""} if family else None
if isinstance(font, dict):
family = str(font.get("family", "")).strip()
path = str(font.get("path", "")).strip()
if family or path:
return {"family": family, "path": path}
return None
def colormap_to_uint8(normalized: np.ndarray, colormap: Any = "gray") -> np.ndarray:
normalized = np.asarray(normalized, dtype=np.float64)
spec = normalize_colormap_spec(colormap, fallback="gray")
if spec == "gray":
grey = np.rint(normalized * 255.0).astype(np.uint8)
rgb = np.empty(grey.shape + (3,), dtype=np.uint8)
rgb[..., 0] = grey
rgb[..., 1] = grey
rgb[..., 2] = grey
return rgb
if isinstance(spec, dict) and spec.get("mode") == "custom":
stops = spec["stops"]
positions = np.array([float(stop["position"]) for stop in stops], dtype=np.float64)
colors = np.array([_hex_to_rgb01(str(stop["color"])) for stop in stops], dtype=np.float64)
flat = normalized.reshape(-1)
rgb = np.empty((flat.shape[0], 3), dtype=np.uint8)
for channel in range(3):
rgb[:, channel] = np.rint(np.interp(flat, positions, colors[:, channel]) * 255.0).astype(np.uint8)
return rgb.reshape(normalized.shape + (3,))
cmap_name = spec["preset"] if isinstance(spec, dict) else spec
cmap = _get_colormap(cmap_name)
rgba = cmap(normalized)
return (rgba[:, :, :3] * 255).astype(np.uint8)
@dataclass
class DataField:
data: np.ndarray # shape (yres, xres), dtype float64
@@ -40,15 +191,17 @@ class DataField:
si_unit_xy: str = "m"
si_unit_z: str = "m"
domain: str = "spatial" # "spatial" or "frequency"
colormap: str = "viridis"
colormap: str | dict[str, Any] = "viridis"
display_offset: float = 0.0
display_scale: float = 1.0
overlays: list[dict[str, Any]] = field(default_factory=list)
def __post_init__(self) -> None:
self.data = np.asarray(self.data, dtype=np.float64)
if self.data.ndim != 2:
raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}")
self.yres, self.xres = self.data.shape
self.overlays = deepcopy(self.overlays) if isinstance(self.overlays, list) else []
def copy(self) -> "DataField":
"""Return a deep copy with independent data array."""
@@ -66,6 +219,7 @@ class DataField:
colormap=self.colormap,
display_offset=self.display_offset,
display_scale=self.display_scale,
overlays=deepcopy(self.overlays),
)
def replace(self, **kwargs) -> "DataField":
@@ -84,6 +238,7 @@ class DataField:
"colormap": self.colormap,
"display_offset": self.display_offset,
"display_scale": self.display_scale,
"overlays": deepcopy(self.overlays),
}
base.update(kwargs)
return DataField(**base)
@@ -137,7 +292,7 @@ def normalize_for_colormap(
return np.clip((base_norm - offset) / scale, 0.0, 1.0)
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
def datafield_to_uint8(df: DataField, colormap: Any = "gray") -> np.ndarray:
"""
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
Returns shape (H, W, 3) uint8.
@@ -147,19 +302,447 @@ def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
offset=df.display_offset,
scale=df.display_scale,
)
return colormap_to_uint8(normalized, colormap)
if colormap == "gray":
grey = np.rint(normalized * 255.0).astype(np.uint8)
rgb = np.empty(grey.shape + (3,), dtype=np.uint8)
rgb[..., 0] = grey
rgb[..., 1] = grey
rgb[..., 2] = grey
return rgb
cmap = _get_colormap(colormap)
rgba = cmap(normalized) # (H, W, 4) float [0,1]
rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
return rgb
_SI_PREFIXES = [
(1e24, "Y"),
(1e21, "Z"),
(1e18, "E"),
(1e15, "P"),
(1e12, "T"),
(1e9, "G"),
(1e6, "M"),
(1e3, "k"),
(1.0, ""),
(1e-3, "m"),
(1e-6, "u"),
(1e-9, "n"),
(1e-12, "p"),
(1e-15, "f"),
(1e-18, "a"),
(1e-21, "z"),
(1e-24, "y"),
]
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
unit = (unit or "").strip()
if not unit:
return _format_numeric(value)
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
abs_value = abs(value)
for scale, prefix in _SI_PREFIXES:
scaled = abs_value / scale
if 1 <= scaled < 1000:
signed = value / scale
return f"{_format_numeric(signed)} {prefix}{unit}"
return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field: DataField) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _normalize_font_match_key(text: str) -> str:
return "".join(ch for ch in text.lower() if ch.isalnum())
def _render_overlay_text(
text: str,
size_px: int,
color: tuple[int, int, int],
font_spec: Any = None,
):
from PIL import Image, ImageDraw, ImageFont
size_px = max(8, int(round(size_px)))
font = _load_overlay_font(size_px, ImageFont, font_spec=font_spec)
if font is not None:
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_image)
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
return text_image
font = ImageFont.load_default()
probe = Image.new("L", (1, 1), 0)
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
mask = Image.new("L", (width, height), 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
scale = max(1.0, size_px / max(1, height))
scaled_width = max(1, int(round(width * scale)))
scaled_height = max(1, int(round(height * scale)))
resampling = getattr(Image, "Resampling", Image)
# Preserve edge sharpness if we ever have to scale the bitmap fallback font.
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.NEAREST)
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
text_image.putalpha(scaled_mask)
return text_image
@lru_cache(maxsize=1)
def _overlay_font_candidates() -> tuple[str, ...]:
candidates: list[str] = []
try:
import PIL
pil_dir = Path(PIL.__file__).resolve().parent
candidates.extend([
str(pil_dir / "fonts" / "DejaVuSans.ttf"),
str(pil_dir / "Tests" / "fonts" / "DejaVuSans.ttf"),
])
except Exception:
pass
candidates.extend([
"/System/Library/Fonts/Supplemental/Arial.ttf",
"/System/Library/Fonts/Supplemental/Helvetica.ttc",
"/System/Library/Fonts/Supplemental/Times New Roman.ttf",
"/Library/Fonts/Arial.ttf",
"/Library/Fonts/Helvetica.ttc",
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
"/usr/share/fonts/truetype/liberation2/LiberationSans-Regular.ttf",
])
unique: list[str] = []
for candidate in candidates:
if candidate not in unique and Path(candidate).exists():
unique.append(candidate)
return tuple(unique)
def list_overlay_font_choices() -> tuple[str, ...]:
labels: list[str] = []
for candidate in _overlay_font_candidates():
label = Path(candidate).stem
if label and label not in labels:
labels.append(label)
return tuple(labels)
def _matching_overlay_font_candidates(family: str) -> list[str]:
key = _normalize_font_match_key(family)
if not key:
return []
matches: list[str] = []
for candidate in _overlay_font_candidates():
stem_key = _normalize_font_match_key(Path(candidate).stem)
if key == stem_key or key in stem_key or stem_key in key:
matches.append(candidate)
return matches
def _load_overlay_font(size_px: int, image_font_module, font_spec: Any = None) -> Any:
normalized = normalize_font_spec(font_spec)
if normalized is not None:
requested_path = normalized.get("path", "")
requested_family = normalized.get("family", "")
if requested_path:
try:
return image_font_module.truetype(requested_path, size_px)
except Exception:
pass
if requested_family:
try:
return image_font_module.truetype(requested_family, size_px)
except Exception:
pass
for candidate in _matching_overlay_font_candidates(requested_family):
try:
return image_font_module.truetype(candidate, size_px)
except Exception:
continue
for candidate in _overlay_font_candidates():
try:
return image_font_module.truetype(candidate, size_px)
except Exception:
continue
try:
return image_font_module.truetype("DejaVuSans.ttf", size_px)
except Exception:
return None
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (shaft_end[0] + px * head_width / 2.0, shaft_end[1] + py * head_width / 2.0)
right = (shaft_end[0] - px * head_width / 2.0, shaft_end[1] - py * head_width / 2.0)
draw.polygon([end, left, right], fill=color)
def _preview_markup_stroke_width(width: int, image_width: int, image_height: int) -> int:
width = max(1, int(width))
longest_dim = max(1, int(image_width), int(image_height))
scale = max(1.0, longest_dim / float(PREVIEW_MARKUP_REFERENCE_DIM))
return max(1, int(round(width * scale)))
def _sanitize_markup_shapes(shapes: Any) -> list[dict[str, Any]]:
if not isinstance(shapes, list):
return []
parsed = []
for shape in shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
if not all(np.isfinite(value) for value in (x1, y1, x2, y2)):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _apply_annotation_overlay(
image: np.ndarray,
field: DataField,
colormap: Any,
spec: dict[str, Any],
) -> np.ndarray:
from PIL import Image, ImageDraw
show_scale_bar = bool(spec.get("show_scale_bar", True))
show_color_map = bool(spec.get("show_color_map", True))
text_size = float(spec.get("text_size", 14.0))
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
font_spec = normalize_font_spec(spec.get("font"))
current = np.asarray(image, dtype=np.uint8)
if current.ndim == 2:
current = np.repeat(current[:, :, np.newaxis], 3, axis=2)
height, current_width = current.shape[:2]
field_width = max(1, int(field.xres))
legend_width = max(72, int(round(field_width * 0.18))) if show_color_map else 0
canvas_width = current_width + legend_width
canvas = np.full((height, canvas_width, 3), 255, dtype=np.uint8)
canvas[:, :current_width] = current
pil_image = Image.fromarray(canvas)
draw = ImageDraw.Draw(pil_image)
base_font_px = max(6, int(round(text_size)))
if show_scale_bar and field_width > 0 and np.isfinite(field.xreal) and field.xreal > 0:
target_real = field.xreal / 5.0
bar_real = _nice_length(target_real)
if bar_real > 0 and np.isfinite(field.dx) and field.dx > 0:
bar_px = max(1, int(round(bar_real / field.dx)))
margin_x = max(8, field_width // 24)
margin_y = max(8, height // 24)
bar_height = max(3, int(round(height * 0.012)))
bar_px = min(bar_px, max(1, field_width - 2 * margin_x))
x0 = margin_x
x1 = x0 + bar_px
y1 = height - margin_y
y0 = y1 - bar_height
text = _format_with_unit(bar_real, field.si_unit_xy)
text_image = _render_overlay_text(text, base_font_px, (255, 255, 255), font_spec=font_spec)
text_w, text_h = text_image.size
label_pad = 2
bg_left = max(0, x0 - 4)
bg_top = max(0, y0 - text_h - label_pad * 3)
bg_right = min(canvas_width, max(x1 + 4, x0 + text_w + 8))
bg_bottom = min(height, y1 + 4)
draw.rectangle((bg_left, bg_top, bg_right, bg_bottom), fill=(0, 0, 0))
draw.rectangle((x0, y0, x1, y1), fill=(255, 255, 255))
pil_image.paste(text_image, (x0, bg_top + label_pad), text_image)
if show_color_map and legend_width > 0:
panel_x0 = current_width
draw.rectangle((panel_x0, 0, canvas_width, height), fill=(245, 245, 245))
grad_x0 = panel_x0 + max(8, legend_width // 7)
grad_w = max(12, legend_width // 5)
grad_y0 = max(10, height // 18)
grad_y1 = max(grad_y0 + 10, height - grad_y0)
grad_h = grad_y1 - grad_y0
gradient = np.linspace(1.0, 0.0, grad_h, dtype=np.float64)[:, np.newaxis]
gradient = np.repeat(gradient, grad_w, axis=1)
gradient_rgb = colormap_to_uint8(gradient, colormap)
pil_image.paste(Image.fromarray(gradient_rgb), (grad_x0, grad_y0))
draw.rectangle((grad_x0, grad_y0, grad_x0 + grad_w, grad_y1), outline=(40, 40, 40), width=1)
legend_min, legend_mid, legend_max = _display_value_range(field)
labels = [
(legend_max, grad_y0),
(legend_mid, grad_y0 + grad_h // 2),
(legend_min, grad_y1),
]
text_x = grad_x0 + grad_w + 8
for value, y_center in labels:
text_image = _render_overlay_text(
_format_with_unit(value, field.si_unit_z),
base_font_px,
(20, 20, 20),
font_spec=font_spec,
)
text_y = int(round(y_center - text_image.size[1] / 2))
text_y = max(0, min(height - text_image.size[1], text_y))
pil_image.paste(text_image, (text_x, text_y), text_image)
return np.asarray(pil_image, dtype=np.uint8)
def _apply_markup_overlay(image: np.ndarray, field: DataField, spec: dict[str, Any]) -> np.ndarray:
from PIL import Image, ImageDraw
current = np.asarray(image, dtype=np.uint8)
if current.ndim == 2:
current = np.repeat(current[:, :, np.newaxis], 3, axis=2)
pil_image = Image.fromarray(current.copy())
draw = ImageDraw.Draw(pil_image)
field_width = max(1, int(field.xres))
field_height = max(1, int(field.yres))
for shape in _sanitize_markup_shapes(spec.get("shapes")):
x1 = float(shape["x1"]) * field_width
y1 = float(shape["y1"]) * field_height
x2 = float(shape["x2"]) * field_width
y2 = float(shape["y2"]) * field_height
color = str(shape["color"])
stroke_width = _preview_markup_stroke_width(int(shape["width"]), field_width, field_height)
kind = str(shape["kind"])
if kind == "line":
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
elif kind == "rectangle":
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "circle":
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "arrow":
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
return np.asarray(pil_image, dtype=np.uint8)
def render_datafield_preview(df: DataField, colormap: Any = "gray") -> np.ndarray:
current = datafield_to_uint8(df, colormap)
for overlay in df.overlays:
if not isinstance(overlay, dict):
continue
kind = str(overlay.get("kind", "")).strip().lower()
if kind == "annotation":
current = _apply_annotation_overlay(current, df, colormap, overlay)
elif kind == "markup":
current = _apply_markup_overlay(current, df, overlay)
return current
def image_to_uint8(image: np.ndarray) -> np.ndarray:
@@ -195,5 +778,5 @@ def encode_preview(arr: np.ndarray) -> str:
@lru_cache(maxsize=len(COLORMAPS))
def _get_colormap(colormap: str):
import matplotlib.cm as cm
return cm.get_cmap(colormap)
from matplotlib import colormaps
return colormaps[colormap]

View File

@@ -23,6 +23,7 @@ The engine:
from __future__ import annotations
import uuid
from collections import defaultdict, deque
from math import isfinite
from time import perf_counter
from typing import Any, Callable
@@ -87,7 +88,8 @@ class ExecutionEngine:
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
inputs = self._resolve_inputs(raw_inputs, node_outputs)
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
@@ -110,7 +112,7 @@ class ExecutionEngine:
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table)
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
if on_node_done:
on_node_done(node_id, elapsed_ms)
@@ -154,8 +156,14 @@ class ExecutionEngine:
self,
raw_inputs: dict[str, Any],
node_outputs: dict[str, tuple],
input_types: dict[str, dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Replace [src_id, slot] links with actual output values."""
specs = {}
if input_types:
specs.update(input_types.get("required", {}))
specs.update(input_types.get("optional", {}))
resolved = {}
for key, value in raw_inputs.items():
if _is_link(value):
@@ -170,11 +178,36 @@ class ExecutionEngine:
f"Node '{src_id}' only has {len(outputs)} outputs, "
f"but slot {slot} was requested."
)
resolved[key] = outputs[slot]
resolved_value = outputs[slot]
else:
resolved[key] = value
resolved_value = value
resolved[key] = self._coerce_input_value(resolved_value, specs.get(key))
return resolved
def _coerce_input_value(self, value: Any, spec: Any) -> Any:
if spec is None:
return value
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
if isinstance(input_type, list):
return value
if input_type == "INT":
numeric = float(value)
if not isfinite(numeric):
raise ValueError(f"Expected a finite numeric value for INT input, got {value!r}")
rounded = int(abs(numeric) + 0.5)
return rounded if numeric >= 0 else -rounded
if input_type == "FLOAT":
numeric = float(value)
if not isfinite(numeric):
raise ValueError(f"Expected a finite numeric value for FLOAT input, got {value!r}")
return numeric
return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,
@@ -185,11 +218,11 @@ class ExecutionEngine:
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.modify import CropResizeField, RotateField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import SaveImage, LoadFile
from backend.nodes.io import SaveImage, LoadFile, LoadDemo
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
@@ -206,19 +239,22 @@ class ExecutionEngine:
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
CropResizeField._broadcast_overlay_fn = on_overlay
RotateField._broadcast_warning_fn = on_warning
Markup._broadcast_overlay_fn = on_overlay
LoadFile._broadcast_warning_fn = on_warning
LoadDemo._broadcast_warning_fn = on_warning
SaveImage._broadcast_warning_fn = on_warning
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.modify import CropResizeField, RotateField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
from backend.nodes.io import LoadFile, LoadDemo, SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
LoadFile, SaveImage):
LoadFile, LoadDemo, SaveImage):
cls._current_node_id = node_id
def _auto_preview(
@@ -228,6 +264,7 @@ class ExecutionEngine:
result: tuple,
on_preview: Callable | None,
on_table: Callable | None,
inputs: dict[str, Any] | None = None,
) -> None:
"""
After every node executes, inspect its outputs and broadcast
@@ -236,12 +273,19 @@ class ExecutionEngine:
"""
import numpy as np
from backend.data_types import (
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
DataField, image_to_uint8, encode_preview, render_datafield_preview,
)
from backend.nodes.io import LoadFile, LoadDemo
if getattr(cls, "_CUSTOM_PREVIEW", False):
return
if cls in (LoadFile, LoadDemo) and on_preview:
preview = self._render_load_node_preview(result, inputs or {})
if preview:
on_preview(node_id, preview)
return
return_types = getattr(cls, "RETURN_TYPES", ())
for slot, type_name in enumerate(return_types):
@@ -250,7 +294,7 @@ class ExecutionEngine:
value = result[slot]
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
arr = datafield_to_uint8(value, value.colormap)
arr = render_datafield_preview(value, value.colormap)
on_preview(node_id, encode_preview(arr))
return # one preview per node is enough
@@ -269,6 +313,39 @@ class ExecutionEngine:
on_table(node_id, value)
return
def _render_load_node_preview(
self,
result: tuple,
inputs: dict[str, Any],
) -> dict | None:
from backend.data_types import DataField, encode_preview, render_datafield_preview
from backend.nodes.io import list_channels
fields = [value for value in result if isinstance(value, DataField)]
if not fields:
return None
selected_path = str(inputs.get("path") or inputs.get("filename") or inputs.get("name") or "").strip()
channel_names: list[str] = []
if selected_path:
try:
channel_names = [str(entry.get("name", "")).strip() or "field" for entry in list_channels(selected_path)]
except Exception:
channel_names = []
layers = []
for index, field in enumerate(fields):
arr = render_datafield_preview(field, field.colormap)
layers.append({
"name": channel_names[index] if index < len(channel_names) else f"layer {index + 1}",
"image": encode_preview(arr),
})
return {
"kind": "layer_gallery",
"layers": layers,
}
def _render_line_preview(
self,

View File

@@ -7,10 +7,26 @@ before execution begins.
"""
from __future__ import annotations
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
DataField, MeasureTable, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview, normalize_for_colormap,
DataField,
MeasureTable,
COLORMAPS,
CUSTOM_FILE_FONT,
DEFAULT_CUSTOM_COLORMAP_STOPS,
SYSTEM_DEFAULT_FONT,
colormap_to_uint8,
datafield_to_uint8,
encode_preview,
image_to_uint8,
list_overlay_font_choices,
normalize_colormap_spec,
normalize_font_spec,
normalize_for_colormap,
render_datafield_preview,
resolve_colormap_input,
)
@@ -59,15 +75,473 @@ def _scalar_payload(value: float, unit: str = "") -> dict:
return payload
_SI_PREFIXES = [
(1e24, "Y"),
(1e21, "Z"),
(1e18, "E"),
(1e15, "P"),
(1e12, "T"),
(1e9, "G"),
(1e6, "M"),
(1e3, "k"),
(1.0, ""),
(1e-3, "m"),
(1e-6, "u"),
(1e-9, "n"),
(1e-12, "p"),
(1e-15, "f"),
(1e-18, "a"),
(1e-21, "z"),
(1e-24, "y"),
]
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
unit = (unit or "").strip()
if not unit:
return _format_numeric(value)
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
abs_value = abs(value)
for scale, prefix in _SI_PREFIXES:
scaled = abs_value / scale
if 1 <= scaled < 1000:
signed = value / scale
return f"{_format_numeric(signed)} {prefix}{unit}"
return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field: DataField) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
from PIL import Image, ImageDraw, ImageFont
size_px = max(8, int(round(size_px)))
try:
font = ImageFont.truetype("DejaVuSans.ttf", size_px)
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_image)
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
return text_image
except Exception:
font = ImageFont.load_default()
probe = Image.new("L", (1, 1), 0)
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
mask = Image.new("L", (width, height), 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
scale = max(1.0, size_px / max(1, height))
scaled_width = max(1, int(round(width * scale)))
scaled_height = max(1, int(round(height * scale)))
resampling = getattr(Image, "Resampling", Image)
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR)
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
text_image.putalpha(scaled_mask)
return text_image
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _parse_markup_shapes(raw_shapes: str | list | None) -> list[dict[str, object]]:
if isinstance(raw_shapes, str):
try:
raw_shapes = json.loads(raw_shapes or "[]")
except json.JSONDecodeError:
raw_shapes = []
if not isinstance(raw_shapes, list):
return []
parsed: list[dict[str, object]] = []
for shape in raw_shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
coords = [x1, y1, x2, y2]
if not all(np.isfinite(value) for value in coords):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (
shaft_end[0] + px * head_width / 2.0,
shaft_end[1] + py * head_width / 2.0,
)
right = (
shaft_end[0] - px * head_width / 2.0,
shaft_end[1] - py * head_width / 2.0,
)
draw.polygon([end, left, right], fill=color)
def _render_markup_image(image: np.ndarray, shapes: list[dict[str, object]]) -> np.ndarray:
from PIL import Image, ImageDraw
base = image_to_uint8(image)
if base.ndim == 2:
base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
canvas = Image.fromarray(base.copy())
draw = ImageDraw.Draw(canvas)
height, width = base.shape[:2]
for shape in shapes:
x1 = float(shape["x1"]) * width
y1 = float(shape["y1"]) * height
x2 = float(shape["x2"]) * width
y2 = float(shape["y2"]) * height
color = str(shape["color"])
stroke_width = int(shape["width"])
kind = str(shape["kind"])
if kind == "line":
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
elif kind == "rectangle":
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "circle":
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "arrow":
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
return np.asarray(canvas, dtype=np.uint8)
@register_node(display_name="Color Map")
class ColorMap:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["preset", "custom"], {"default": "preset"}),
"preset": (list(COLORMAPS), {
"default": "viridis",
"show_when_widget_value": {"mode": ["preset"]},
}),
"stops": ("STRING", {
"default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)),
"colormap_stops": True,
"show_when_widget_value": {"mode": ["custom"]},
}),
}
}
RETURN_TYPES = ("COLORMAP",)
RETURN_NAMES = ("colormap",)
FUNCTION = "build"
CATEGORY = "display"
DESCRIPTION = (
"Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours "
"and any number of intermediate stops."
)
def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple:
if mode == "preset":
return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},)
try:
raw_stops = stops if stops is not None else stops_json
stops_data = json.loads(raw_stops or "[]")
except json.JSONDecodeError as exc:
raise ValueError("Custom colormap stops must be valid JSON.") from exc
spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None)
if not (isinstance(spec, dict) and spec.get("mode") == "custom"):
raise ValueError("Custom colormap must include at least min and max colours.")
return (spec,)
@register_node(display_name="Font")
class Font:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], {
"default": SYSTEM_DEFAULT_FONT,
}),
"font_file": ("FILE_PICKER", {
"default": "",
"show_when_widget_value": {"family": [CUSTOM_FILE_FONT]},
}),
}
}
RETURN_TYPES = ("FONT",)
RETURN_NAMES = ("font",)
FUNCTION = "build"
CATEGORY = "display"
DESCRIPTION = (
"Build a reusable font spec for annotation overlays. Choose a discovered system font, "
"use the default fallback stack, or point to a custom font file."
)
def build(self, family: str, font_file: str = "") -> tuple:
if family == SYSTEM_DEFAULT_FONT:
return (None,)
if family == CUSTOM_FILE_FONT:
return (normalize_font_spec({"path": font_file}),)
return (normalize_font_spec({"family": family}),)
@register_node(display_name="Annotations")
class Annotations:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}),
"text_size": ("FLOAT", {
"default": 14.0,
"min": 6.0,
"max": 96.0,
"step": 1.0,
}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"font": ("FONT",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "render"
CATEGORY = "display"
DESCRIPTION = (
"Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. "
"The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values."
)
def render(
self,
field: DataField,
colormap: str,
show_scale_bar: bool,
show_color_map: bool,
text_size: float = 1.0,
colormap_map=None,
font=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
out = field.replace(
colormap=resolved_colormap,
overlays=[
*field.overlays,
{
"kind": "annotation",
"show_scale_bar": bool(show_scale_bar),
"show_color_map": bool(show_color_map),
"text_size": text_size,
"font": normalize_font_spec(font),
},
],
)
return (out,)
@register_node(display_name="Markup")
class Markup:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}),
"stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
"clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}),
"markup_shapes": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "process"
CATEGORY = "display"
DESCRIPTION = (
"Draw simple vector markup over a DATA_FIELD without flattening the underlying data. "
"Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
shape: str,
stroke_color: str,
stroke_width: int,
markup_shapes: str,
) -> tuple:
shapes = _parse_markup_shapes(markup_shapes)
out = field.replace(
overlays=[
*field.overlays,
{
"kind": "markup",
"shapes": shapes,
},
],
)
if Markup._broadcast_overlay_fn is not None:
Markup._broadcast_overlay_fn(
Markup._current_node_id,
{
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
return (out,)
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["auto"] + list(COLORMAPS),),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
@@ -82,30 +556,35 @@ class PreviewImage:
_broadcast_fn = None
_current_node_id: str = ""
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
# Resolve "auto" — use field's colormap if available, else fall back to gray
if colormap == "auto":
colormap = field.colormap if field is not None else "gray"
def preview(
self,
colormap: str,
image: np.ndarray | None = None,
field=None,
colormap_map=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap if field is not None else None,
default="gray",
)
# Prefer field if both are connected; accept whichever is provided
if field is not None:
arr_u8 = datafield_to_uint8(field, colormap)
arr_u8 = render_datafield_preview(field, resolved_colormap)
elif image is not None:
if image.dtype != np.uint8:
imin, imax = image.min(), image.max()
if imax > imin:
norm = (image - imin) / (imax - imin)
arr_u8 = image_to_uint8(image)
if arr_u8.ndim == 2:
if image.dtype == np.uint8:
normalized = arr_u8.astype(np.float64) / 255.0
else:
norm = np.zeros_like(image)
arr_u8 = (norm * 255).astype(np.uint8)
else:
arr_u8 = image
if arr_u8.ndim == 2 and colormap != "gray":
import matplotlib.cm as cm
cmap = cm.get_cmap(colormap)
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
imin, imax = image.min(), image.max()
if imax > imin:
normalized = (image - imin) / (imax - imin)
else:
normalized = np.zeros_like(image, dtype=np.float64)
arr_u8 = colormap_to_uint8(normalized, resolved_colormap)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
@@ -124,10 +603,13 @@ class View3D:
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS),),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
}
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ()
@@ -144,9 +626,8 @@ class View3D:
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int,
colormap: str, z_scale: float, resolution: int, colormap_map=None,
) -> tuple:
import matplotlib.cm as cm
import base64
data = field.data
@@ -168,10 +649,13 @@ class View3D:
data_max=float(field.data.max()),
)
cmap_name = field.colormap if colormap == "auto" else colormap
cmap = cm.get_cmap(cmap_name)
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
colors_u8 = colormap_to_uint8(z_norm, resolved_colormap)
# Base64-encode arrays for efficient WS transport
z_b64 = base64.b64encode(z.tobytes()).decode()

View File

@@ -4,11 +4,12 @@ I/O nodes: load and save images and SPM data.
from __future__ import annotations
import os
import re
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8
from backend.data_types import COLORMAPS, DataField, encode_preview, image_to_uint8, resolve_colormap_input
from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
@@ -22,6 +23,7 @@ _DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
# ---------------------------------------------------------------------------
@@ -105,6 +107,23 @@ def list_channels(filepath: str) -> list[dict]:
return [{"name": "field", "type": "DATA_FIELD"}]
def list_folder_paths(folderpath: str) -> list[dict]:
"""Return a folder DIRECTORY plus compatible image/array/SPM FILE_PATH outputs."""
path = _resolve_path(folderpath)
if not path.exists() or not path.is_dir():
return []
resolved_dir = str(path.resolve())
results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}]
for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()):
if not entry.is_file() or entry.name.startswith("."):
continue
if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS:
continue
results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())})
return results
# ---------------------------------------------------------------------------
# LoadFile (unified loader — replaces LoadImage + LoadSPM)
# ---------------------------------------------------------------------------
@@ -115,9 +134,13 @@ class LoadFile:
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
"colormap": (list(COLORMAPS),),
}
"filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"path": ("FILE_PATH", {"label": "path"}),
},
}
# Default outputs — overridden dynamically by the frontend for multi-channel files
@@ -136,26 +159,28 @@ class LoadFile:
_broadcast_warning_fn = None
_current_node_id = None
def load(self, filename: str, colormap: str = "viridis"):
if not filename or not filename.strip():
def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None):
selected_path = str(path).strip() if path is not None else str(filename).strip()
if not selected_path:
raise ValueError("No file selected — use Browse to pick a file.")
path = _resolve_path(filename)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
if path.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path}")
path_obj = _resolve_path(selected_path)
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {path_obj}")
if path_obj.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}")
ext = path.suffix.lower()
ext = path_obj.suffix.lower()
resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis")
if ext in _SPM_EXTENSIONS:
fields = self._load_spm_all(path, ext)
fields = self._load_spm_all(path_obj, ext)
for f in fields:
f.colormap = colormap
f.colormap = resolved_colormap
return tuple(fields)
# Image or array — uncalibrated, single output
field = self._load_image_or_array(path, ext)
field.colormap = colormap
field = self._load_image_or_array(path_obj, ext)
field.colormap = resolved_colormap
self._send_warning("Uncalibrated data — no physical dimensions.")
return (field,)
@@ -349,8 +374,11 @@ class LoadDemo:
return {
"required": {
"name": (choices,),
"colormap": (list(COLORMAPS),),
}
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
@@ -359,13 +387,38 @@ class LoadDemo:
CATEGORY = "io"
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
def load(self, name: str, colormap: str = "viridis"):
path = DEMO_DIR / name
if not path.exists():
raise FileNotFoundError(f"Demo file not found: {name}")
def load(self, name: str = "", colormap: str = "viridis", colormap_map=None):
loader = LoadFile()
return loader.load(filename=str(path), colormap=colormap)
demo_path = DEMO_DIR / name
if not demo_path.exists():
raise FileNotFoundError(f"Demo file not found: {name}")
return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map)
@register_node(display_name="Folder")
class Folder:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}),
}
}
RETURN_TYPES = ("DIRECTORY",)
RETURN_NAMES = ("directory",)
FUNCTION = "list_files"
CATEGORY = "io"
DESCRIPTION = (
"Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. "
"Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans."
)
def list_files(self, folder: str) -> tuple:
entries = list_folder_paths(folder)
if not entries:
return tuple()
return tuple(item["path"] for item in entries)
# ---------------------------------------------------------------------------
@@ -395,6 +448,36 @@ class Coordinate:
return ((float(x), float(y)),)
# ---------------------------------------------------------------------------
# Number
# ---------------------------------------------------------------------------
@register_node(display_name="Number")
class Number:
"""Provide a fixed scalar value that can feed FLOAT or INT widget sockets."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "step": 0.01}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "io"
DESCRIPTION = (
"Output a fixed numeric value. "
"When connected to FLOAT inputs the exact value is used; "
"INT inputs round to the nearest integer at execution time."
)
def process(self, value: float) -> tuple:
return (float(value),)
# ---------------------------------------------------------------------------
# RangeSlider
# ---------------------------------------------------------------------------
@@ -445,12 +528,32 @@ _MAX_SAVE_FIELDS = 8
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {}
optional = {
"directory": ("DIRECTORY", {"label": "directory"}),
}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("DATA_FIELD",)
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
optional[f"layer_name_{i}"] = ("STRING", {
"default": "",
"placeholder": "name",
"show_when_input_visible": f"field_{i}",
"inline_with_input": f"field_{i}",
"hide_label": True,
})
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
"filename": ("STRING", {
"default": "",
"placeholder": "filename",
"placement": "top",
}),
"directory_path": ("FOLDER_PICKER", {
"default": "",
"label": "directory",
"placement": "top",
"hide_when_input_connected": "directory",
"top_socket_input": "directory",
}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
@@ -462,59 +565,130 @@ class SaveImage:
OUTPUT_NODE = True
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more DATA_FIELD layers to a single file. "
"Connect fields to the inputs — a new slot appears as each is filled. "
"TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
"Save one or more layers to a single file. "
"Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. "
"Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. "
"A new slot appears as each one is filled, with a matching per-layer name field. "
"TIFF writes multi-page data and stores layer names as page descriptions; "
"NPZ writes named arrays using those layer names as keys. "
"Click Save to write (does not auto-run)."
)
_broadcast_warning_fn = None
_current_node_id = None
def save(self, filename: str, format: str = "TIFF", **kwargs):
# Collect connected fields in order
fields = []
def save(
self,
filename: str,
directory_path: str = "",
format: str = "TIFF",
directory: str | None = None,
**kwargs,
):
layers = []
layer_names = []
for i in range(_MAX_SAVE_FIELDS):
f = kwargs.get(f"field_{i}")
if f is not None:
fields.append(f)
layer = kwargs.get(f"field_{i}")
if layer is not None:
layers.append(layer)
layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i))
if not fields:
raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
if not layers:
raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.")
if not filename or not filename.strip():
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(filename)
# Ensure parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)
# Force correct extension
ext = ".tiff" if format == "TIFF" else ".npz"
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
path = self._resolve_save_path(filename, format, directory, directory_path)
if format == "TIFF":
self._save_tiff(path, fields)
self._save_tiff(path, layers, layer_names)
else:
self._save_npz(path, fields)
self._save_npz(path, layers, layer_names)
self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, fields: list[DataField]):
from PIL import Image
images = []
for f in fields:
images.append(Image.fromarray(f.data.astype(np.float32)))
images[0].save(str(path), save_all=True, append_images=images[1:])
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
import tifffile
def _save_npz(self, path: Path, fields: list[DataField]):
with tifffile.TiffWriter(str(path)) as tif:
for layer, layer_name in zip(layers, layer_names):
tif.write(self._layer_array_for_tiff(layer), description=layer_name)
def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
arrays = {}
for i, f in enumerate(fields):
arrays[f"layer_{i}"] = f.data
used_keys = set()
for i, (layer, layer_name) in enumerate(zip(layers, layer_names)):
arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer)
np.savez(str(path), **arrays)
def _resolve_layer_name(self, raw_name: object, index: int) -> str:
text = str(raw_name).strip() if raw_name is not None else ""
return text or f"layer_{index}"
def _resolve_save_path(
self,
filename: str,
format: str,
directory: str | None,
directory_path: str = "",
) -> Path:
ext = ".tiff" if format == "TIFF" else ".npz"
raw_filename = str(filename).strip() if filename is not None else ""
raw_directory = str(directory).strip() if directory is not None else ""
if not raw_directory:
raw_directory = str(directory_path).strip() if directory_path is not None else ""
if raw_directory:
dir_path = Path(raw_directory).expanduser()
if dir_path.exists() and not dir_path.is_dir():
raise ValueError("Directory input expects a folder path, not a file path.")
if not dir_path.exists():
if dir_path.suffix:
raise ValueError("Directory input expects a folder path, not a file path.")
dir_path.mkdir(parents=True, exist_ok=True)
filename_part = Path(raw_filename).name if raw_filename else ""
if not filename_part:
raise ValueError("No output filename selected — enter a file name when using a directory input.")
path = dir_path / filename_part
else:
if not raw_filename:
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(raw_filename).expanduser()
path.parent.mkdir(parents=True, exist_ok=True)
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
return path
def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str:
key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_")
if not key:
key = f"layer_{index}"
if key[0].isdigit():
key = f"layer_{key}"
candidate = key
suffix = 2
while candidate in used_keys:
candidate = f"{key}_{suffix}"
suffix += 1
used_keys.add(candidate)
return candidate
def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data, dtype=np.float32)
if isinstance(layer, np.ndarray):
return image_to_uint8(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data)
if isinstance(layer, np.ndarray):
return np.asarray(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id

View File

@@ -143,6 +143,7 @@ class CropResizeField:
yreal=(py1 - py0) * field.dy,
xoff=field.xoff + px0 * field.dx,
yoff=field.yoff + py0 * field.dy,
overlays=[],
)
target_width, target_height = self._resolve_target_shape(
@@ -217,6 +218,9 @@ class RotateField:
"Optionally expand the canvas to keep the full rotated field while preserving the field center."
)
_broadcast_warning_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
@@ -224,6 +228,9 @@ class RotateField:
interpolation: str,
expand_canvas: bool,
) -> tuple:
if field.overlays:
self._send_warning("Rotate clears annotation/markup overlays!")
angle = float(angle)
order_map = {
"nearest": 0,
@@ -264,9 +271,16 @@ class RotateField:
yreal=new_yreal,
xoff=center_x - new_xreal / 2.0,
yoff=center_y - new_yreal / 2.0,
overlays=[],
)
return (result,)
def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:
if not expand_canvas:

View File

@@ -215,6 +215,13 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
content_type="application/json",
)
async def get_folder_files(request: web.Request) -> web.Response:
folder_path = request.query.get("folder", "")
from backend.nodes.io import list_folder_paths
loop = asyncio.get_running_loop()
entries = await loop.run_in_executor(None, list_folder_paths, folder_path)
return web.Response(text=_dumps(entries), content_type="application/json")
async def upload_file(request: web.Request) -> web.Response:
reader = await request.multipart()
field = await reader.next()
@@ -346,6 +353,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
app.router.add_get("/nodes", get_nodes)
app.router.add_get("/files", list_files)
app.router.add_get("/browse", browse_dir)
app.router.add_get("/folder-files", get_folder_files)
app.router.add_post("/upload", upload_file)
app.router.add_post("/download", download_file)
app.router.add_post("/save-workflow-png", save_workflow_png)