add folder, file nodes and major usability improvements
This commit is contained in:
@@ -12,13 +12,23 @@ DataField mirrors Gwyddion's GwyDataField structure:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
|
||||
COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain",
|
||||
"cividis", "magma", "copper", "afmhot")
|
||||
DEFAULT_CUSTOM_COLORMAP_STOPS = (
|
||||
{"position": 0.0, "color": "#440154"},
|
||||
{"position": 1.0, "color": "#fde725"},
|
||||
)
|
||||
SYSTEM_DEFAULT_FONT = "System Default"
|
||||
CUSTOM_FILE_FONT = "Custom File"
|
||||
PREVIEW_MARKUP_REFERENCE_DIM = 512
|
||||
|
||||
|
||||
class RecordTable(list):
|
||||
@@ -28,6 +38,147 @@ class RecordTable(list):
|
||||
class MeasureTable(list):
|
||||
"""Named scalar measurements, typically rows of quantity/value/unit."""
|
||||
|
||||
|
||||
def _normalize_hex_color(color: Any, default: str = "#000000") -> str:
|
||||
if isinstance(color, str):
|
||||
text = color.strip()
|
||||
if len(text) == 4 and text.startswith("#"):
|
||||
text = "#" + "".join(ch * 2 for ch in text[1:])
|
||||
if len(text) == 7 and text.startswith("#"):
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
return text.lower()
|
||||
except ValueError:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _hex_to_rgb01(color: str) -> tuple[float, float, float]:
|
||||
normalized = _normalize_hex_color(color)
|
||||
return (
|
||||
int(normalized[1:3], 16) / 255.0,
|
||||
int(normalized[3:5], 16) / 255.0,
|
||||
int(normalized[5:7], 16) / 255.0,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_custom_colormap_stops(raw_stops: Any) -> list[dict[str, float | str]] | None:
|
||||
if not isinstance(raw_stops, list):
|
||||
return None
|
||||
|
||||
parsed = []
|
||||
for stop in raw_stops:
|
||||
if not isinstance(stop, dict):
|
||||
continue
|
||||
try:
|
||||
position = float(stop.get("position"))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if not np.isfinite(position):
|
||||
continue
|
||||
parsed.append({
|
||||
"position": float(np.clip(position, 0.0, 1.0)),
|
||||
"color": _normalize_hex_color(stop.get("color"), "#000000"),
|
||||
})
|
||||
|
||||
if len(parsed) < 2:
|
||||
return None
|
||||
|
||||
parsed.sort(key=lambda stop: stop["position"])
|
||||
unique: list[dict[str, float | str]] = []
|
||||
for stop in parsed:
|
||||
if unique and abs(float(stop["position"]) - float(unique[-1]["position"])) < 1e-9:
|
||||
unique[-1] = stop
|
||||
else:
|
||||
unique.append(stop)
|
||||
|
||||
if len(unique) < 2:
|
||||
return None
|
||||
|
||||
unique[0] = {"position": 0.0, "color": unique[0]["color"]}
|
||||
unique[-1] = {"position": 1.0, "color": unique[-1]["color"]}
|
||||
return unique
|
||||
|
||||
|
||||
def normalize_colormap_spec(colormap: Any, fallback: Any = "viridis") -> str | dict[str, Any]:
|
||||
if isinstance(colormap, str):
|
||||
if colormap in COLORMAPS:
|
||||
return colormap
|
||||
elif isinstance(colormap, dict):
|
||||
mode = str(colormap.get("mode", "")).strip().lower()
|
||||
preset = colormap.get("preset")
|
||||
if isinstance(preset, str) and preset in COLORMAPS:
|
||||
if mode == "preset" or "stops" not in colormap:
|
||||
return {"mode": "preset", "preset": preset}
|
||||
stops = _normalize_custom_colormap_stops(colormap.get("stops"))
|
||||
if stops is not None:
|
||||
return {"mode": "custom", "stops": stops}
|
||||
|
||||
if fallback is colormap:
|
||||
return "viridis"
|
||||
return normalize_colormap_spec(fallback, fallback="viridis")
|
||||
|
||||
|
||||
def resolve_colormap_input(
|
||||
selection: Any = "auto",
|
||||
*,
|
||||
colormap_input: Any = None,
|
||||
inherited: Any = None,
|
||||
default: Any = "gray",
|
||||
) -> str | dict[str, Any]:
|
||||
if colormap_input is not None:
|
||||
return normalize_colormap_spec(colormap_input, fallback=inherited if inherited is not None else default)
|
||||
|
||||
if isinstance(selection, str) and selection != "auto":
|
||||
return normalize_colormap_spec(selection, fallback=inherited if inherited is not None else default)
|
||||
|
||||
if inherited is not None:
|
||||
return normalize_colormap_spec(inherited, fallback=default)
|
||||
|
||||
return normalize_colormap_spec(default, fallback="gray")
|
||||
|
||||
|
||||
def normalize_font_spec(font: Any) -> dict[str, str] | None:
|
||||
if isinstance(font, str):
|
||||
family = font.strip()
|
||||
return {"family": family, "path": ""} if family else None
|
||||
|
||||
if isinstance(font, dict):
|
||||
family = str(font.get("family", "")).strip()
|
||||
path = str(font.get("path", "")).strip()
|
||||
if family or path:
|
||||
return {"family": family, "path": path}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def colormap_to_uint8(normalized: np.ndarray, colormap: Any = "gray") -> np.ndarray:
|
||||
normalized = np.asarray(normalized, dtype=np.float64)
|
||||
spec = normalize_colormap_spec(colormap, fallback="gray")
|
||||
|
||||
if spec == "gray":
|
||||
grey = np.rint(normalized * 255.0).astype(np.uint8)
|
||||
rgb = np.empty(grey.shape + (3,), dtype=np.uint8)
|
||||
rgb[..., 0] = grey
|
||||
rgb[..., 1] = grey
|
||||
rgb[..., 2] = grey
|
||||
return rgb
|
||||
|
||||
if isinstance(spec, dict) and spec.get("mode") == "custom":
|
||||
stops = spec["stops"]
|
||||
positions = np.array([float(stop["position"]) for stop in stops], dtype=np.float64)
|
||||
colors = np.array([_hex_to_rgb01(str(stop["color"])) for stop in stops], dtype=np.float64)
|
||||
flat = normalized.reshape(-1)
|
||||
rgb = np.empty((flat.shape[0], 3), dtype=np.uint8)
|
||||
for channel in range(3):
|
||||
rgb[:, channel] = np.rint(np.interp(flat, positions, colors[:, channel]) * 255.0).astype(np.uint8)
|
||||
return rgb.reshape(normalized.shape + (3,))
|
||||
|
||||
cmap_name = spec["preset"] if isinstance(spec, dict) else spec
|
||||
cmap = _get_colormap(cmap_name)
|
||||
rgba = cmap(normalized)
|
||||
return (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
|
||||
@dataclass
|
||||
class DataField:
|
||||
data: np.ndarray # shape (yres, xres), dtype float64
|
||||
@@ -40,15 +191,17 @@ class DataField:
|
||||
si_unit_xy: str = "m"
|
||||
si_unit_z: str = "m"
|
||||
domain: str = "spatial" # "spatial" or "frequency"
|
||||
colormap: str = "viridis"
|
||||
colormap: str | dict[str, Any] = "viridis"
|
||||
display_offset: float = 0.0
|
||||
display_scale: float = 1.0
|
||||
overlays: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.data = np.asarray(self.data, dtype=np.float64)
|
||||
if self.data.ndim != 2:
|
||||
raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}")
|
||||
self.yres, self.xres = self.data.shape
|
||||
self.overlays = deepcopy(self.overlays) if isinstance(self.overlays, list) else []
|
||||
|
||||
def copy(self) -> "DataField":
|
||||
"""Return a deep copy with independent data array."""
|
||||
@@ -66,6 +219,7 @@ class DataField:
|
||||
colormap=self.colormap,
|
||||
display_offset=self.display_offset,
|
||||
display_scale=self.display_scale,
|
||||
overlays=deepcopy(self.overlays),
|
||||
)
|
||||
|
||||
def replace(self, **kwargs) -> "DataField":
|
||||
@@ -84,6 +238,7 @@ class DataField:
|
||||
"colormap": self.colormap,
|
||||
"display_offset": self.display_offset,
|
||||
"display_scale": self.display_scale,
|
||||
"overlays": deepcopy(self.overlays),
|
||||
}
|
||||
base.update(kwargs)
|
||||
return DataField(**base)
|
||||
@@ -137,7 +292,7 @@ def normalize_for_colormap(
|
||||
return np.clip((base_norm - offset) / scale, 0.0, 1.0)
|
||||
|
||||
|
||||
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
|
||||
def datafield_to_uint8(df: DataField, colormap: Any = "gray") -> np.ndarray:
|
||||
"""
|
||||
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
|
||||
Returns shape (H, W, 3) uint8.
|
||||
@@ -147,19 +302,447 @@ def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
|
||||
offset=df.display_offset,
|
||||
scale=df.display_scale,
|
||||
)
|
||||
return colormap_to_uint8(normalized, colormap)
|
||||
|
||||
if colormap == "gray":
|
||||
grey = np.rint(normalized * 255.0).astype(np.uint8)
|
||||
rgb = np.empty(grey.shape + (3,), dtype=np.uint8)
|
||||
rgb[..., 0] = grey
|
||||
rgb[..., 1] = grey
|
||||
rgb[..., 2] = grey
|
||||
return rgb
|
||||
|
||||
cmap = _get_colormap(colormap)
|
||||
rgba = cmap(normalized) # (H, W, 4) float [0,1]
|
||||
rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
return rgb
|
||||
_SI_PREFIXES = [
|
||||
(1e24, "Y"),
|
||||
(1e21, "Z"),
|
||||
(1e18, "E"),
|
||||
(1e15, "P"),
|
||||
(1e12, "T"),
|
||||
(1e9, "G"),
|
||||
(1e6, "M"),
|
||||
(1e3, "k"),
|
||||
(1.0, ""),
|
||||
(1e-3, "m"),
|
||||
(1e-6, "u"),
|
||||
(1e-9, "n"),
|
||||
(1e-12, "p"),
|
||||
(1e-15, "f"),
|
||||
(1e-18, "a"),
|
||||
(1e-21, "z"),
|
||||
(1e-24, "y"),
|
||||
]
|
||||
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"}
|
||||
|
||||
|
||||
def _format_numeric(value: float) -> str:
|
||||
if not np.isfinite(value):
|
||||
return str(value)
|
||||
abs_value = abs(value)
|
||||
if abs_value == 0:
|
||||
return "0"
|
||||
if abs_value >= 1e4 or abs_value < 1e-3:
|
||||
return f"{value:.3e}"
|
||||
return f"{value:.4g}"
|
||||
|
||||
|
||||
def _format_with_unit(value: float, unit: str) -> str:
|
||||
unit = (unit or "").strip()
|
||||
if not unit:
|
||||
return _format_numeric(value)
|
||||
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
|
||||
abs_value = abs(value)
|
||||
for scale, prefix in _SI_PREFIXES:
|
||||
scaled = abs_value / scale
|
||||
if 1 <= scaled < 1000:
|
||||
signed = value / scale
|
||||
return f"{_format_numeric(signed)} {prefix}{unit}"
|
||||
return f"{_format_numeric(value)} {unit}"
|
||||
|
||||
|
||||
def _nice_length(target: float) -> float:
|
||||
if not np.isfinite(target) or target <= 0:
|
||||
return 0.0
|
||||
exponent = np.floor(np.log10(target))
|
||||
base = 10.0 ** exponent
|
||||
for step in (5.0, 2.0, 1.0):
|
||||
candidate = step * base
|
||||
if candidate <= target:
|
||||
return candidate
|
||||
return base
|
||||
|
||||
|
||||
def _display_value_range(field: DataField) -> tuple[float, float, float]:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
dmin = float(data.min())
|
||||
dmax = float(data.max())
|
||||
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
|
||||
return dmin, dmin, dmin
|
||||
|
||||
offset = float(field.display_offset)
|
||||
scale = float(field.display_scale)
|
||||
if not np.isfinite(offset):
|
||||
offset = 0.0
|
||||
if not np.isfinite(scale) or scale <= 0.0:
|
||||
scale = 1.0
|
||||
|
||||
low_norm = float(np.clip(offset, 0.0, 1.0))
|
||||
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
|
||||
if high_norm < low_norm:
|
||||
low_norm, high_norm = high_norm, low_norm
|
||||
mid_norm = 0.5 * (low_norm + high_norm)
|
||||
|
||||
span = dmax - dmin
|
||||
return (
|
||||
dmin + low_norm * span,
|
||||
dmin + mid_norm * span,
|
||||
dmin + high_norm * span,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_font_match_key(text: str) -> str:
|
||||
return "".join(ch for ch in text.lower() if ch.isalnum())
|
||||
|
||||
|
||||
def _render_overlay_text(
|
||||
text: str,
|
||||
size_px: int,
|
||||
color: tuple[int, int, int],
|
||||
font_spec: Any = None,
|
||||
):
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
size_px = max(8, int(round(size_px)))
|
||||
font = _load_overlay_font(size_px, ImageFont, font_spec=font_spec)
|
||||
if font is not None:
|
||||
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
|
||||
probe_draw = ImageDraw.Draw(probe)
|
||||
bbox = probe_draw.textbbox((0, 0), text, font=font)
|
||||
width = max(1, bbox[2] - bbox[0])
|
||||
height = max(1, bbox[3] - bbox[1])
|
||||
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
text_draw = ImageDraw.Draw(text_image)
|
||||
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
|
||||
return text_image
|
||||
|
||||
font = ImageFont.load_default()
|
||||
probe = Image.new("L", (1, 1), 0)
|
||||
probe_draw = ImageDraw.Draw(probe)
|
||||
bbox = probe_draw.textbbox((0, 0), text, font=font)
|
||||
width = max(1, bbox[2] - bbox[0])
|
||||
height = max(1, bbox[3] - bbox[1])
|
||||
mask = Image.new("L", (width, height), 0)
|
||||
mask_draw = ImageDraw.Draw(mask)
|
||||
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
|
||||
scale = max(1.0, size_px / max(1, height))
|
||||
scaled_width = max(1, int(round(width * scale)))
|
||||
scaled_height = max(1, int(round(height * scale)))
|
||||
resampling = getattr(Image, "Resampling", Image)
|
||||
# Preserve edge sharpness if we ever have to scale the bitmap fallback font.
|
||||
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.NEAREST)
|
||||
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
|
||||
text_image.putalpha(scaled_mask)
|
||||
return text_image
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _overlay_font_candidates() -> tuple[str, ...]:
|
||||
candidates: list[str] = []
|
||||
|
||||
try:
|
||||
import PIL
|
||||
|
||||
pil_dir = Path(PIL.__file__).resolve().parent
|
||||
candidates.extend([
|
||||
str(pil_dir / "fonts" / "DejaVuSans.ttf"),
|
||||
str(pil_dir / "Tests" / "fonts" / "DejaVuSans.ttf"),
|
||||
])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
candidates.extend([
|
||||
"/System/Library/Fonts/Supplemental/Arial.ttf",
|
||||
"/System/Library/Fonts/Supplemental/Helvetica.ttc",
|
||||
"/System/Library/Fonts/Supplemental/Times New Roman.ttf",
|
||||
"/Library/Fonts/Arial.ttf",
|
||||
"/Library/Fonts/Helvetica.ttc",
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
|
||||
"/usr/share/fonts/truetype/liberation2/LiberationSans-Regular.ttf",
|
||||
])
|
||||
|
||||
unique: list[str] = []
|
||||
for candidate in candidates:
|
||||
if candidate not in unique and Path(candidate).exists():
|
||||
unique.append(candidate)
|
||||
return tuple(unique)
|
||||
|
||||
|
||||
def list_overlay_font_choices() -> tuple[str, ...]:
|
||||
labels: list[str] = []
|
||||
for candidate in _overlay_font_candidates():
|
||||
label = Path(candidate).stem
|
||||
if label and label not in labels:
|
||||
labels.append(label)
|
||||
return tuple(labels)
|
||||
|
||||
|
||||
def _matching_overlay_font_candidates(family: str) -> list[str]:
|
||||
key = _normalize_font_match_key(family)
|
||||
if not key:
|
||||
return []
|
||||
|
||||
matches: list[str] = []
|
||||
for candidate in _overlay_font_candidates():
|
||||
stem_key = _normalize_font_match_key(Path(candidate).stem)
|
||||
if key == stem_key or key in stem_key or stem_key in key:
|
||||
matches.append(candidate)
|
||||
return matches
|
||||
|
||||
|
||||
def _load_overlay_font(size_px: int, image_font_module, font_spec: Any = None) -> Any:
|
||||
normalized = normalize_font_spec(font_spec)
|
||||
|
||||
if normalized is not None:
|
||||
requested_path = normalized.get("path", "")
|
||||
requested_family = normalized.get("family", "")
|
||||
|
||||
if requested_path:
|
||||
try:
|
||||
return image_font_module.truetype(requested_path, size_px)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if requested_family:
|
||||
try:
|
||||
return image_font_module.truetype(requested_family, size_px)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for candidate in _matching_overlay_font_candidates(requested_family):
|
||||
try:
|
||||
return image_font_module.truetype(candidate, size_px)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for candidate in _overlay_font_candidates():
|
||||
try:
|
||||
return image_font_module.truetype(candidate, size_px)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
try:
|
||||
return image_font_module.truetype("DejaVuSans.ttf", size_px)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
|
||||
if isinstance(color, str):
|
||||
text = color.strip()
|
||||
if len(text) == 4 and text.startswith("#"):
|
||||
text = "#" + "".join(ch * 2 for ch in text[1:])
|
||||
if len(text) == 7 and text.startswith("#"):
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
return text.lower()
|
||||
except ValueError:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int):
|
||||
dx = end[0] - start[0]
|
||||
dy = end[1] - start[1]
|
||||
length = float(np.hypot(dx, dy))
|
||||
if length <= 1e-6:
|
||||
radius = max(1.0, width / 2.0)
|
||||
draw.ellipse(
|
||||
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
|
||||
fill=color,
|
||||
)
|
||||
return
|
||||
|
||||
ux = dx / length
|
||||
uy = dy / length
|
||||
head_length = max(10.0, width * 4.0)
|
||||
head_width = max(8.0, width * 3.0)
|
||||
shaft_end = (
|
||||
end[0] - ux * head_length,
|
||||
end[1] - uy * head_length,
|
||||
)
|
||||
draw.line((start, shaft_end), fill=color, width=width)
|
||||
px = -uy
|
||||
py = ux
|
||||
left = (shaft_end[0] + px * head_width / 2.0, shaft_end[1] + py * head_width / 2.0)
|
||||
right = (shaft_end[0] - px * head_width / 2.0, shaft_end[1] - py * head_width / 2.0)
|
||||
draw.polygon([end, left, right], fill=color)
|
||||
|
||||
|
||||
def _preview_markup_stroke_width(width: int, image_width: int, image_height: int) -> int:
|
||||
width = max(1, int(width))
|
||||
longest_dim = max(1, int(image_width), int(image_height))
|
||||
scale = max(1.0, longest_dim / float(PREVIEW_MARKUP_REFERENCE_DIM))
|
||||
return max(1, int(round(width * scale)))
|
||||
|
||||
|
||||
def _sanitize_markup_shapes(shapes: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(shapes, list):
|
||||
return []
|
||||
|
||||
parsed = []
|
||||
for shape in shapes:
|
||||
if not isinstance(shape, dict):
|
||||
continue
|
||||
kind = str(shape.get("kind", "")).strip().lower()
|
||||
if kind not in {"line", "rectangle", "circle", "arrow"}:
|
||||
continue
|
||||
try:
|
||||
x1 = float(shape.get("x1"))
|
||||
y1 = float(shape.get("y1"))
|
||||
x2 = float(shape.get("x2"))
|
||||
y2 = float(shape.get("y2"))
|
||||
width = int(round(float(shape.get("width", 3))))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if not all(np.isfinite(value) for value in (x1, y1, x2, y2)):
|
||||
continue
|
||||
parsed.append({
|
||||
"kind": kind,
|
||||
"x1": float(np.clip(x1, 0.0, 1.0)),
|
||||
"y1": float(np.clip(y1, 0.0, 1.0)),
|
||||
"x2": float(np.clip(x2, 0.0, 1.0)),
|
||||
"y2": float(np.clip(y2, 0.0, 1.0)),
|
||||
"width": max(1, min(128, width)),
|
||||
"color": _normalize_markup_color(shape.get("color")),
|
||||
})
|
||||
return parsed
|
||||
|
||||
|
||||
def _apply_annotation_overlay(
|
||||
image: np.ndarray,
|
||||
field: DataField,
|
||||
colormap: Any,
|
||||
spec: dict[str, Any],
|
||||
) -> np.ndarray:
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
show_scale_bar = bool(spec.get("show_scale_bar", True))
|
||||
show_color_map = bool(spec.get("show_color_map", True))
|
||||
text_size = float(spec.get("text_size", 14.0))
|
||||
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
|
||||
font_spec = normalize_font_spec(spec.get("font"))
|
||||
|
||||
current = np.asarray(image, dtype=np.uint8)
|
||||
if current.ndim == 2:
|
||||
current = np.repeat(current[:, :, np.newaxis], 3, axis=2)
|
||||
|
||||
height, current_width = current.shape[:2]
|
||||
field_width = max(1, int(field.xres))
|
||||
legend_width = max(72, int(round(field_width * 0.18))) if show_color_map else 0
|
||||
canvas_width = current_width + legend_width
|
||||
canvas = np.full((height, canvas_width, 3), 255, dtype=np.uint8)
|
||||
canvas[:, :current_width] = current
|
||||
|
||||
pil_image = Image.fromarray(canvas)
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
base_font_px = max(6, int(round(text_size)))
|
||||
|
||||
if show_scale_bar and field_width > 0 and np.isfinite(field.xreal) and field.xreal > 0:
|
||||
target_real = field.xreal / 5.0
|
||||
bar_real = _nice_length(target_real)
|
||||
if bar_real > 0 and np.isfinite(field.dx) and field.dx > 0:
|
||||
bar_px = max(1, int(round(bar_real / field.dx)))
|
||||
margin_x = max(8, field_width // 24)
|
||||
margin_y = max(8, height // 24)
|
||||
bar_height = max(3, int(round(height * 0.012)))
|
||||
bar_px = min(bar_px, max(1, field_width - 2 * margin_x))
|
||||
x0 = margin_x
|
||||
x1 = x0 + bar_px
|
||||
y1 = height - margin_y
|
||||
y0 = y1 - bar_height
|
||||
text = _format_with_unit(bar_real, field.si_unit_xy)
|
||||
text_image = _render_overlay_text(text, base_font_px, (255, 255, 255), font_spec=font_spec)
|
||||
text_w, text_h = text_image.size
|
||||
label_pad = 2
|
||||
bg_left = max(0, x0 - 4)
|
||||
bg_top = max(0, y0 - text_h - label_pad * 3)
|
||||
bg_right = min(canvas_width, max(x1 + 4, x0 + text_w + 8))
|
||||
bg_bottom = min(height, y1 + 4)
|
||||
draw.rectangle((bg_left, bg_top, bg_right, bg_bottom), fill=(0, 0, 0))
|
||||
draw.rectangle((x0, y0, x1, y1), fill=(255, 255, 255))
|
||||
pil_image.paste(text_image, (x0, bg_top + label_pad), text_image)
|
||||
|
||||
if show_color_map and legend_width > 0:
|
||||
panel_x0 = current_width
|
||||
draw.rectangle((panel_x0, 0, canvas_width, height), fill=(245, 245, 245))
|
||||
grad_x0 = panel_x0 + max(8, legend_width // 7)
|
||||
grad_w = max(12, legend_width // 5)
|
||||
grad_y0 = max(10, height // 18)
|
||||
grad_y1 = max(grad_y0 + 10, height - grad_y0)
|
||||
grad_h = grad_y1 - grad_y0
|
||||
gradient = np.linspace(1.0, 0.0, grad_h, dtype=np.float64)[:, np.newaxis]
|
||||
gradient = np.repeat(gradient, grad_w, axis=1)
|
||||
gradient_rgb = colormap_to_uint8(gradient, colormap)
|
||||
pil_image.paste(Image.fromarray(gradient_rgb), (grad_x0, grad_y0))
|
||||
draw.rectangle((grad_x0, grad_y0, grad_x0 + grad_w, grad_y1), outline=(40, 40, 40), width=1)
|
||||
|
||||
legend_min, legend_mid, legend_max = _display_value_range(field)
|
||||
labels = [
|
||||
(legend_max, grad_y0),
|
||||
(legend_mid, grad_y0 + grad_h // 2),
|
||||
(legend_min, grad_y1),
|
||||
]
|
||||
text_x = grad_x0 + grad_w + 8
|
||||
for value, y_center in labels:
|
||||
text_image = _render_overlay_text(
|
||||
_format_with_unit(value, field.si_unit_z),
|
||||
base_font_px,
|
||||
(20, 20, 20),
|
||||
font_spec=font_spec,
|
||||
)
|
||||
text_y = int(round(y_center - text_image.size[1] / 2))
|
||||
text_y = max(0, min(height - text_image.size[1], text_y))
|
||||
pil_image.paste(text_image, (text_x, text_y), text_image)
|
||||
|
||||
return np.asarray(pil_image, dtype=np.uint8)
|
||||
|
||||
|
||||
def _apply_markup_overlay(image: np.ndarray, field: DataField, spec: dict[str, Any]) -> np.ndarray:
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
current = np.asarray(image, dtype=np.uint8)
|
||||
if current.ndim == 2:
|
||||
current = np.repeat(current[:, :, np.newaxis], 3, axis=2)
|
||||
|
||||
pil_image = Image.fromarray(current.copy())
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
field_width = max(1, int(field.xres))
|
||||
field_height = max(1, int(field.yres))
|
||||
|
||||
for shape in _sanitize_markup_shapes(spec.get("shapes")):
|
||||
x1 = float(shape["x1"]) * field_width
|
||||
y1 = float(shape["y1"]) * field_height
|
||||
x2 = float(shape["x2"]) * field_width
|
||||
y2 = float(shape["y2"]) * field_height
|
||||
color = str(shape["color"])
|
||||
stroke_width = _preview_markup_stroke_width(int(shape["width"]), field_width, field_height)
|
||||
kind = str(shape["kind"])
|
||||
|
||||
if kind == "line":
|
||||
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
|
||||
elif kind == "rectangle":
|
||||
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
|
||||
elif kind == "circle":
|
||||
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
|
||||
elif kind == "arrow":
|
||||
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
|
||||
|
||||
return np.asarray(pil_image, dtype=np.uint8)
|
||||
|
||||
|
||||
def render_datafield_preview(df: DataField, colormap: Any = "gray") -> np.ndarray:
|
||||
current = datafield_to_uint8(df, colormap)
|
||||
for overlay in df.overlays:
|
||||
if not isinstance(overlay, dict):
|
||||
continue
|
||||
kind = str(overlay.get("kind", "")).strip().lower()
|
||||
if kind == "annotation":
|
||||
current = _apply_annotation_overlay(current, df, colormap, overlay)
|
||||
elif kind == "markup":
|
||||
current = _apply_markup_overlay(current, df, overlay)
|
||||
return current
|
||||
|
||||
|
||||
def image_to_uint8(image: np.ndarray) -> np.ndarray:
|
||||
@@ -195,5 +778,5 @@ def encode_preview(arr: np.ndarray) -> str:
|
||||
|
||||
@lru_cache(maxsize=len(COLORMAPS))
|
||||
def _get_colormap(colormap: str):
|
||||
import matplotlib.cm as cm
|
||||
return cm.get_cmap(colormap)
|
||||
from matplotlib import colormaps
|
||||
return colormaps[colormap]
|
||||
|
||||
@@ -23,6 +23,7 @@ The engine:
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from math import isfinite
|
||||
from time import perf_counter
|
||||
from typing import Any, Callable
|
||||
|
||||
@@ -87,7 +88,8 @@ class ExecutionEngine:
|
||||
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs)
|
||||
input_types = cls.INPUT_TYPES()
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
|
||||
|
||||
# Let display nodes know their node_id so they can tag WS messages
|
||||
self._set_node_id_on_display(cls, node_id)
|
||||
@@ -110,7 +112,7 @@ class ExecutionEngine:
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or table-like output so every node shows its result.
|
||||
if on_preview or on_table:
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table)
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table, inputs)
|
||||
|
||||
if on_node_done:
|
||||
on_node_done(node_id, elapsed_ms)
|
||||
@@ -154,8 +156,14 @@ class ExecutionEngine:
|
||||
self,
|
||||
raw_inputs: dict[str, Any],
|
||||
node_outputs: dict[str, tuple],
|
||||
input_types: dict[str, dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Replace [src_id, slot] links with actual output values."""
|
||||
specs = {}
|
||||
if input_types:
|
||||
specs.update(input_types.get("required", {}))
|
||||
specs.update(input_types.get("optional", {}))
|
||||
|
||||
resolved = {}
|
||||
for key, value in raw_inputs.items():
|
||||
if _is_link(value):
|
||||
@@ -170,11 +178,36 @@ class ExecutionEngine:
|
||||
f"Node '{src_id}' only has {len(outputs)} outputs, "
|
||||
f"but slot {slot} was requested."
|
||||
)
|
||||
resolved[key] = outputs[slot]
|
||||
resolved_value = outputs[slot]
|
||||
else:
|
||||
resolved[key] = value
|
||||
resolved_value = value
|
||||
|
||||
resolved[key] = self._coerce_input_value(resolved_value, specs.get(key))
|
||||
return resolved
|
||||
|
||||
def _coerce_input_value(self, value: Any, spec: Any) -> Any:
|
||||
if spec is None:
|
||||
return value
|
||||
|
||||
input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec
|
||||
if isinstance(input_type, list):
|
||||
return value
|
||||
|
||||
if input_type == "INT":
|
||||
numeric = float(value)
|
||||
if not isfinite(numeric):
|
||||
raise ValueError(f"Expected a finite numeric value for INT input, got {value!r}")
|
||||
rounded = int(abs(numeric) + 0.5)
|
||||
return rounded if numeric >= 0 else -rounded
|
||||
|
||||
if input_type == "FLOAT":
|
||||
numeric = float(value)
|
||||
if not isfinite(numeric):
|
||||
raise ValueError(f"Expected a finite numeric value for FLOAT input, got {value!r}")
|
||||
return numeric
|
||||
|
||||
return value
|
||||
|
||||
def _inject_display_callbacks(
|
||||
self,
|
||||
on_preview: Callable | None,
|
||||
@@ -185,11 +218,11 @@ class ExecutionEngine:
|
||||
on_warning: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.modify import CropResizeField, RotateField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
|
||||
from backend.nodes.io import SaveImage, LoadFile
|
||||
from backend.nodes.io import SaveImage, LoadFile, LoadDemo
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
ThresholdMask._broadcast_fn = on_preview
|
||||
@@ -206,19 +239,22 @@ class ExecutionEngine:
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
LineCursors._broadcast_overlay_fn = on_overlay
|
||||
CropResizeField._broadcast_overlay_fn = on_overlay
|
||||
RotateField._broadcast_warning_fn = on_warning
|
||||
Markup._broadcast_overlay_fn = on_overlay
|
||||
LoadFile._broadcast_warning_fn = on_warning
|
||||
LoadDemo._broadcast_warning_fn = on_warning
|
||||
SaveImage._broadcast_warning_fn = on_warning
|
||||
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.modify import CropResizeField, RotateField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
|
||||
from backend.nodes.io import LoadFile, SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
|
||||
from backend.nodes.io import LoadFile, LoadDemo, SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
|
||||
LoadFile, SaveImage):
|
||||
LoadFile, LoadDemo, SaveImage):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
@@ -228,6 +264,7 @@ class ExecutionEngine:
|
||||
result: tuple,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
After every node executes, inspect its outputs and broadcast
|
||||
@@ -236,12 +273,19 @@ class ExecutionEngine:
|
||||
"""
|
||||
import numpy as np
|
||||
from backend.data_types import (
|
||||
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
|
||||
DataField, image_to_uint8, encode_preview, render_datafield_preview,
|
||||
)
|
||||
from backend.nodes.io import LoadFile, LoadDemo
|
||||
|
||||
if getattr(cls, "_CUSTOM_PREVIEW", False):
|
||||
return
|
||||
|
||||
if cls in (LoadFile, LoadDemo) and on_preview:
|
||||
preview = self._render_load_node_preview(result, inputs or {})
|
||||
if preview:
|
||||
on_preview(node_id, preview)
|
||||
return
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
for slot, type_name in enumerate(return_types):
|
||||
@@ -250,7 +294,7 @@ class ExecutionEngine:
|
||||
value = result[slot]
|
||||
|
||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||
arr = datafield_to_uint8(value, value.colormap)
|
||||
arr = render_datafield_preview(value, value.colormap)
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return # one preview per node is enough
|
||||
|
||||
@@ -269,6 +313,39 @@ class ExecutionEngine:
|
||||
on_table(node_id, value)
|
||||
return
|
||||
|
||||
def _render_load_node_preview(
|
||||
self,
|
||||
result: tuple,
|
||||
inputs: dict[str, Any],
|
||||
) -> dict | None:
|
||||
from backend.data_types import DataField, encode_preview, render_datafield_preview
|
||||
from backend.nodes.io import list_channels
|
||||
|
||||
fields = [value for value in result if isinstance(value, DataField)]
|
||||
if not fields:
|
||||
return None
|
||||
|
||||
selected_path = str(inputs.get("path") or inputs.get("filename") or inputs.get("name") or "").strip()
|
||||
channel_names: list[str] = []
|
||||
if selected_path:
|
||||
try:
|
||||
channel_names = [str(entry.get("name", "")).strip() or "field" for entry in list_channels(selected_path)]
|
||||
except Exception:
|
||||
channel_names = []
|
||||
|
||||
layers = []
|
||||
for index, field in enumerate(fields):
|
||||
arr = render_datafield_preview(field, field.colormap)
|
||||
layers.append({
|
||||
"name": channel_names[index] if index < len(channel_names) else f"layer {index + 1}",
|
||||
"image": encode_preview(arr),
|
||||
})
|
||||
|
||||
return {
|
||||
"kind": "layer_gallery",
|
||||
"layers": layers,
|
||||
}
|
||||
|
||||
|
||||
def _render_line_preview(
|
||||
self,
|
||||
|
||||
@@ -7,10 +7,26 @@ before execution begins.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import (
|
||||
DataField, MeasureTable, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview, normalize_for_colormap,
|
||||
DataField,
|
||||
MeasureTable,
|
||||
COLORMAPS,
|
||||
CUSTOM_FILE_FONT,
|
||||
DEFAULT_CUSTOM_COLORMAP_STOPS,
|
||||
SYSTEM_DEFAULT_FONT,
|
||||
colormap_to_uint8,
|
||||
datafield_to_uint8,
|
||||
encode_preview,
|
||||
image_to_uint8,
|
||||
list_overlay_font_choices,
|
||||
normalize_colormap_spec,
|
||||
normalize_font_spec,
|
||||
normalize_for_colormap,
|
||||
render_datafield_preview,
|
||||
resolve_colormap_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -59,15 +75,473 @@ def _scalar_payload(value: float, unit: str = "") -> dict:
|
||||
return payload
|
||||
|
||||
|
||||
_SI_PREFIXES = [
|
||||
(1e24, "Y"),
|
||||
(1e21, "Z"),
|
||||
(1e18, "E"),
|
||||
(1e15, "P"),
|
||||
(1e12, "T"),
|
||||
(1e9, "G"),
|
||||
(1e6, "M"),
|
||||
(1e3, "k"),
|
||||
(1.0, ""),
|
||||
(1e-3, "m"),
|
||||
(1e-6, "u"),
|
||||
(1e-9, "n"),
|
||||
(1e-12, "p"),
|
||||
(1e-15, "f"),
|
||||
(1e-18, "a"),
|
||||
(1e-21, "z"),
|
||||
(1e-24, "y"),
|
||||
]
|
||||
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"}
|
||||
|
||||
|
||||
def _format_numeric(value: float) -> str:
|
||||
if not np.isfinite(value):
|
||||
return str(value)
|
||||
abs_value = abs(value)
|
||||
if abs_value == 0:
|
||||
return "0"
|
||||
if abs_value >= 1e4 or abs_value < 1e-3:
|
||||
return f"{value:.3e}"
|
||||
return f"{value:.4g}"
|
||||
|
||||
|
||||
def _format_with_unit(value: float, unit: str) -> str:
|
||||
unit = (unit or "").strip()
|
||||
if not unit:
|
||||
return _format_numeric(value)
|
||||
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
|
||||
abs_value = abs(value)
|
||||
for scale, prefix in _SI_PREFIXES:
|
||||
scaled = abs_value / scale
|
||||
if 1 <= scaled < 1000:
|
||||
signed = value / scale
|
||||
return f"{_format_numeric(signed)} {prefix}{unit}"
|
||||
return f"{_format_numeric(value)} {unit}"
|
||||
|
||||
|
||||
def _nice_length(target: float) -> float:
|
||||
if not np.isfinite(target) or target <= 0:
|
||||
return 0.0
|
||||
exponent = np.floor(np.log10(target))
|
||||
base = 10.0 ** exponent
|
||||
for step in (5.0, 2.0, 1.0):
|
||||
candidate = step * base
|
||||
if candidate <= target:
|
||||
return candidate
|
||||
return base
|
||||
|
||||
|
||||
def _display_value_range(field: DataField) -> tuple[float, float, float]:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
dmin = float(data.min())
|
||||
dmax = float(data.max())
|
||||
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
|
||||
return dmin, dmin, dmin
|
||||
|
||||
offset = float(field.display_offset)
|
||||
scale = float(field.display_scale)
|
||||
if not np.isfinite(offset):
|
||||
offset = 0.0
|
||||
if not np.isfinite(scale) or scale <= 0.0:
|
||||
scale = 1.0
|
||||
|
||||
low_norm = float(np.clip(offset, 0.0, 1.0))
|
||||
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
|
||||
if high_norm < low_norm:
|
||||
low_norm, high_norm = high_norm, low_norm
|
||||
mid_norm = 0.5 * (low_norm + high_norm)
|
||||
|
||||
span = dmax - dmin
|
||||
return (
|
||||
dmin + low_norm * span,
|
||||
dmin + mid_norm * span,
|
||||
dmin + high_norm * span,
|
||||
)
|
||||
|
||||
|
||||
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
size_px = max(8, int(round(size_px)))
|
||||
try:
|
||||
font = ImageFont.truetype("DejaVuSans.ttf", size_px)
|
||||
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
|
||||
probe_draw = ImageDraw.Draw(probe)
|
||||
bbox = probe_draw.textbbox((0, 0), text, font=font)
|
||||
width = max(1, bbox[2] - bbox[0])
|
||||
height = max(1, bbox[3] - bbox[1])
|
||||
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
text_draw = ImageDraw.Draw(text_image)
|
||||
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
|
||||
return text_image
|
||||
except Exception:
|
||||
font = ImageFont.load_default()
|
||||
probe = Image.new("L", (1, 1), 0)
|
||||
probe_draw = ImageDraw.Draw(probe)
|
||||
bbox = probe_draw.textbbox((0, 0), text, font=font)
|
||||
width = max(1, bbox[2] - bbox[0])
|
||||
height = max(1, bbox[3] - bbox[1])
|
||||
mask = Image.new("L", (width, height), 0)
|
||||
mask_draw = ImageDraw.Draw(mask)
|
||||
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
|
||||
|
||||
scale = max(1.0, size_px / max(1, height))
|
||||
scaled_width = max(1, int(round(width * scale)))
|
||||
scaled_height = max(1, int(round(height * scale)))
|
||||
resampling = getattr(Image, "Resampling", Image)
|
||||
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR)
|
||||
|
||||
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
|
||||
text_image.putalpha(scaled_mask)
|
||||
return text_image
|
||||
|
||||
|
||||
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
|
||||
if isinstance(color, str):
|
||||
text = color.strip()
|
||||
if len(text) == 4 and text.startswith("#"):
|
||||
text = "#" + "".join(ch * 2 for ch in text[1:])
|
||||
if len(text) == 7 and text.startswith("#"):
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
return text.lower()
|
||||
except ValueError:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _parse_markup_shapes(raw_shapes: str | list | None) -> list[dict[str, object]]:
|
||||
if isinstance(raw_shapes, str):
|
||||
try:
|
||||
raw_shapes = json.loads(raw_shapes or "[]")
|
||||
except json.JSONDecodeError:
|
||||
raw_shapes = []
|
||||
|
||||
if not isinstance(raw_shapes, list):
|
||||
return []
|
||||
|
||||
parsed: list[dict[str, object]] = []
|
||||
for shape in raw_shapes:
|
||||
if not isinstance(shape, dict):
|
||||
continue
|
||||
|
||||
kind = str(shape.get("kind", "")).strip().lower()
|
||||
if kind not in {"line", "rectangle", "circle", "arrow"}:
|
||||
continue
|
||||
|
||||
try:
|
||||
x1 = float(shape.get("x1"))
|
||||
y1 = float(shape.get("y1"))
|
||||
x2 = float(shape.get("x2"))
|
||||
y2 = float(shape.get("y2"))
|
||||
width = int(round(float(shape.get("width", 3))))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
|
||||
coords = [x1, y1, x2, y2]
|
||||
if not all(np.isfinite(value) for value in coords):
|
||||
continue
|
||||
|
||||
parsed.append({
|
||||
"kind": kind,
|
||||
"x1": float(np.clip(x1, 0.0, 1.0)),
|
||||
"y1": float(np.clip(y1, 0.0, 1.0)),
|
||||
"x2": float(np.clip(x2, 0.0, 1.0)),
|
||||
"y2": float(np.clip(y2, 0.0, 1.0)),
|
||||
"width": max(1, min(128, width)),
|
||||
"color": _normalize_markup_color(shape.get("color")),
|
||||
})
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int):
|
||||
dx = end[0] - start[0]
|
||||
dy = end[1] - start[1]
|
||||
length = float(np.hypot(dx, dy))
|
||||
if length <= 1e-6:
|
||||
radius = max(1.0, width / 2.0)
|
||||
draw.ellipse(
|
||||
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
|
||||
fill=color,
|
||||
)
|
||||
return
|
||||
|
||||
ux = dx / length
|
||||
uy = dy / length
|
||||
head_length = max(10.0, width * 4.0)
|
||||
head_width = max(8.0, width * 3.0)
|
||||
shaft_end = (
|
||||
end[0] - ux * head_length,
|
||||
end[1] - uy * head_length,
|
||||
)
|
||||
|
||||
draw.line((start, shaft_end), fill=color, width=width)
|
||||
|
||||
px = -uy
|
||||
py = ux
|
||||
left = (
|
||||
shaft_end[0] + px * head_width / 2.0,
|
||||
shaft_end[1] + py * head_width / 2.0,
|
||||
)
|
||||
right = (
|
||||
shaft_end[0] - px * head_width / 2.0,
|
||||
shaft_end[1] - py * head_width / 2.0,
|
||||
)
|
||||
draw.polygon([end, left, right], fill=color)
|
||||
|
||||
|
||||
def _render_markup_image(image: np.ndarray, shapes: list[dict[str, object]]) -> np.ndarray:
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
base = image_to_uint8(image)
|
||||
if base.ndim == 2:
|
||||
base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
|
||||
|
||||
canvas = Image.fromarray(base.copy())
|
||||
draw = ImageDraw.Draw(canvas)
|
||||
height, width = base.shape[:2]
|
||||
|
||||
for shape in shapes:
|
||||
x1 = float(shape["x1"]) * width
|
||||
y1 = float(shape["y1"]) * height
|
||||
x2 = float(shape["x2"]) * width
|
||||
y2 = float(shape["y2"]) * height
|
||||
color = str(shape["color"])
|
||||
stroke_width = int(shape["width"])
|
||||
kind = str(shape["kind"])
|
||||
|
||||
if kind == "line":
|
||||
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
|
||||
elif kind == "rectangle":
|
||||
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
|
||||
elif kind == "circle":
|
||||
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
|
||||
elif kind == "arrow":
|
||||
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
|
||||
|
||||
return np.asarray(canvas, dtype=np.uint8)
|
||||
|
||||
|
||||
@register_node(display_name="Color Map")
|
||||
class ColorMap:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mode": (["preset", "custom"], {"default": "preset"}),
|
||||
"preset": (list(COLORMAPS), {
|
||||
"default": "viridis",
|
||||
"show_when_widget_value": {"mode": ["preset"]},
|
||||
}),
|
||||
"stops": ("STRING", {
|
||||
"default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)),
|
||||
"colormap_stops": True,
|
||||
"show_when_widget_value": {"mode": ["custom"]},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("COLORMAP",)
|
||||
RETURN_NAMES = ("colormap",)
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "display"
|
||||
DESCRIPTION = (
|
||||
"Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours "
|
||||
"and any number of intermediate stops."
|
||||
)
|
||||
|
||||
def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple:
|
||||
if mode == "preset":
|
||||
return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},)
|
||||
|
||||
try:
|
||||
raw_stops = stops if stops is not None else stops_json
|
||||
stops_data = json.loads(raw_stops or "[]")
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError("Custom colormap stops must be valid JSON.") from exc
|
||||
|
||||
spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None)
|
||||
if not (isinstance(spec, dict) and spec.get("mode") == "custom"):
|
||||
raise ValueError("Custom colormap must include at least min and max colours.")
|
||||
return (spec,)
|
||||
|
||||
|
||||
@register_node(display_name="Font")
|
||||
class Font:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], {
|
||||
"default": SYSTEM_DEFAULT_FONT,
|
||||
}),
|
||||
"font_file": ("FILE_PICKER", {
|
||||
"default": "",
|
||||
"show_when_widget_value": {"family": [CUSTOM_FILE_FONT]},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FONT",)
|
||||
RETURN_NAMES = ("font",)
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "display"
|
||||
DESCRIPTION = (
|
||||
"Build a reusable font spec for annotation overlays. Choose a discovered system font, "
|
||||
"use the default fallback stack, or point to a custom font file."
|
||||
)
|
||||
|
||||
def build(self, family: str, font_file: str = "") -> tuple:
|
||||
if family == SYSTEM_DEFAULT_FONT:
|
||||
return (None,)
|
||||
if family == CUSTOM_FILE_FONT:
|
||||
return (normalize_font_spec({"path": font_file}),)
|
||||
return (normalize_font_spec({"family": family}),)
|
||||
|
||||
|
||||
@register_node(display_name="Annotations")
|
||||
class Annotations:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
"show_scale_bar": ("BOOLEAN", {"default": True}),
|
||||
"show_color_map": ("BOOLEAN", {"default": True}),
|
||||
"text_size": ("FLOAT", {
|
||||
"default": 14.0,
|
||||
"min": 6.0,
|
||||
"max": 96.0,
|
||||
"step": 1.0,
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"colormap_map": ("COLORMAP", {"label": "colormap"}),
|
||||
"font": ("FONT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("annotated",)
|
||||
FUNCTION = "render"
|
||||
CATEGORY = "display"
|
||||
DESCRIPTION = (
|
||||
"Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. "
|
||||
"The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values."
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
field: DataField,
|
||||
colormap: str,
|
||||
show_scale_bar: bool,
|
||||
show_color_map: bool,
|
||||
text_size: float = 1.0,
|
||||
colormap_map=None,
|
||||
font=None,
|
||||
) -> tuple:
|
||||
resolved_colormap = resolve_colormap_input(
|
||||
colormap,
|
||||
colormap_input=colormap_map,
|
||||
inherited=field.colormap,
|
||||
default="gray",
|
||||
)
|
||||
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
|
||||
out = field.replace(
|
||||
colormap=resolved_colormap,
|
||||
overlays=[
|
||||
*field.overlays,
|
||||
{
|
||||
"kind": "annotation",
|
||||
"show_scale_bar": bool(show_scale_bar),
|
||||
"show_color_map": bool(show_color_map),
|
||||
"text_size": text_size,
|
||||
"font": normalize_font_spec(font),
|
||||
},
|
||||
],
|
||||
)
|
||||
return (out,)
|
||||
|
||||
|
||||
@register_node(display_name="Markup")
|
||||
class Markup:
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}),
|
||||
"stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}),
|
||||
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
|
||||
"clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}),
|
||||
"markup_shapes": ("STRING", {"default": "[]", "hidden": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("annotated",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "display"
|
||||
DESCRIPTION = (
|
||||
"Draw simple vector markup over a DATA_FIELD without flattening the underlying data. "
|
||||
"Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows."
|
||||
)
|
||||
|
||||
_broadcast_overlay_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
shape: str,
|
||||
stroke_color: str,
|
||||
stroke_width: int,
|
||||
markup_shapes: str,
|
||||
) -> tuple:
|
||||
shapes = _parse_markup_shapes(markup_shapes)
|
||||
out = field.replace(
|
||||
overlays=[
|
||||
*field.overlays,
|
||||
{
|
||||
"kind": "markup",
|
||||
"shapes": shapes,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
if Markup._broadcast_overlay_fn is not None:
|
||||
Markup._broadcast_overlay_fn(
|
||||
Markup._current_node_id,
|
||||
{
|
||||
"kind": "markup",
|
||||
"section_title": "Markup",
|
||||
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
|
||||
"shape": str(shape),
|
||||
"stroke_color": _normalize_markup_color(stroke_color),
|
||||
"stroke_width": max(1, int(stroke_width)),
|
||||
},
|
||||
)
|
||||
|
||||
return (out,)
|
||||
|
||||
|
||||
@register_node(display_name="Preview")
|
||||
class PreviewImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"colormap": (["auto"] + list(COLORMAPS),),
|
||||
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
},
|
||||
"optional": {
|
||||
"colormap_map": ("COLORMAP", {"label": "colormap"}),
|
||||
"image": ("IMAGE",),
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
@@ -82,30 +556,35 @@ class PreviewImage:
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
|
||||
# Resolve "auto" — use field's colormap if available, else fall back to gray
|
||||
if colormap == "auto":
|
||||
colormap = field.colormap if field is not None else "gray"
|
||||
def preview(
|
||||
self,
|
||||
colormap: str,
|
||||
image: np.ndarray | None = None,
|
||||
field=None,
|
||||
colormap_map=None,
|
||||
) -> tuple:
|
||||
resolved_colormap = resolve_colormap_input(
|
||||
colormap,
|
||||
colormap_input=colormap_map,
|
||||
inherited=field.colormap if field is not None else None,
|
||||
default="gray",
|
||||
)
|
||||
|
||||
# Prefer field if both are connected; accept whichever is provided
|
||||
if field is not None:
|
||||
arr_u8 = datafield_to_uint8(field, colormap)
|
||||
arr_u8 = render_datafield_preview(field, resolved_colormap)
|
||||
elif image is not None:
|
||||
if image.dtype != np.uint8:
|
||||
imin, imax = image.min(), image.max()
|
||||
if imax > imin:
|
||||
norm = (image - imin) / (imax - imin)
|
||||
arr_u8 = image_to_uint8(image)
|
||||
if arr_u8.ndim == 2:
|
||||
if image.dtype == np.uint8:
|
||||
normalized = arr_u8.astype(np.float64) / 255.0
|
||||
else:
|
||||
norm = np.zeros_like(image)
|
||||
arr_u8 = (norm * 255).astype(np.uint8)
|
||||
else:
|
||||
arr_u8 = image
|
||||
|
||||
if arr_u8.ndim == 2 and colormap != "gray":
|
||||
import matplotlib.cm as cm
|
||||
cmap = cm.get_cmap(colormap)
|
||||
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
|
||||
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
imin, imax = image.min(), image.max()
|
||||
if imax > imin:
|
||||
normalized = (image - imin) / (imax - imin)
|
||||
else:
|
||||
normalized = np.zeros_like(image, dtype=np.float64)
|
||||
arr_u8 = colormap_to_uint8(normalized, resolved_colormap)
|
||||
else:
|
||||
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
|
||||
|
||||
@@ -124,10 +603,13 @@ class View3D:
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"colormap": (["auto"] + list(COLORMAPS),),
|
||||
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
|
||||
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"colormap_map": ("COLORMAP", {"label": "colormap"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
@@ -144,9 +626,8 @@ class View3D:
|
||||
|
||||
def render(
|
||||
self, field: DataField,
|
||||
colormap: str, z_scale: float, resolution: int,
|
||||
colormap: str, z_scale: float, resolution: int, colormap_map=None,
|
||||
) -> tuple:
|
||||
import matplotlib.cm as cm
|
||||
import base64
|
||||
|
||||
data = field.data
|
||||
@@ -168,10 +649,13 @@ class View3D:
|
||||
data_max=float(field.data.max()),
|
||||
)
|
||||
|
||||
cmap_name = field.colormap if colormap == "auto" else colormap
|
||||
cmap = cm.get_cmap(cmap_name)
|
||||
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
|
||||
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
resolved_colormap = resolve_colormap_input(
|
||||
colormap,
|
||||
colormap_input=colormap_map,
|
||||
inherited=field.colormap,
|
||||
default="gray",
|
||||
)
|
||||
colors_u8 = colormap_to_uint8(z_norm, resolved_colormap)
|
||||
|
||||
# Base64-encode arrays for efficient WS transport
|
||||
z_b64 = base64.b64encode(z.tobytes()).decode()
|
||||
|
||||
@@ -4,11 +4,12 @@ I/O nodes: load and save images and SPM data.
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8
|
||||
from backend.data_types import COLORMAPS, DataField, encode_preview, image_to_uint8, resolve_colormap_input
|
||||
from backend.runtime_paths import demo_dir, input_dir, output_dir
|
||||
|
||||
# Resolved at server startup so nodes know where to look
|
||||
@@ -22,6 +23,7 @@ _DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
|
||||
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
|
||||
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
|
||||
_ARRAY_EXTENSIONS = {".npy", ".npz"}
|
||||
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -105,6 +107,23 @@ def list_channels(filepath: str) -> list[dict]:
|
||||
return [{"name": "field", "type": "DATA_FIELD"}]
|
||||
|
||||
|
||||
def list_folder_paths(folderpath: str) -> list[dict]:
|
||||
"""Return a folder DIRECTORY plus compatible image/array/SPM FILE_PATH outputs."""
|
||||
path = _resolve_path(folderpath)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return []
|
||||
|
||||
resolved_dir = str(path.resolve())
|
||||
results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}]
|
||||
for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()):
|
||||
if not entry.is_file() or entry.name.startswith("."):
|
||||
continue
|
||||
if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS:
|
||||
continue
|
||||
results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())})
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadFile (unified loader — replaces LoadImage + LoadSPM)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -115,9 +134,13 @@ class LoadFile:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
"colormap": (list(COLORMAPS),),
|
||||
}
|
||||
"filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}),
|
||||
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
},
|
||||
"optional": {
|
||||
"colormap_map": ("COLORMAP", {"label": "colormap"}),
|
||||
"path": ("FILE_PATH", {"label": "path"}),
|
||||
},
|
||||
}
|
||||
|
||||
# Default outputs — overridden dynamically by the frontend for multi-channel files
|
||||
@@ -136,26 +159,28 @@ class LoadFile:
|
||||
_broadcast_warning_fn = None
|
||||
_current_node_id = None
|
||||
|
||||
def load(self, filename: str, colormap: str = "viridis"):
|
||||
if not filename or not filename.strip():
|
||||
def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None):
|
||||
selected_path = str(path).strip() if path is not None else str(filename).strip()
|
||||
if not selected_path:
|
||||
raise ValueError("No file selected — use Browse to pick a file.")
|
||||
path = _resolve_path(filename)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
if path.is_dir():
|
||||
raise IsADirectoryError(f"Expected a file, got a directory: {path}")
|
||||
path_obj = _resolve_path(selected_path)
|
||||
if not path_obj.exists():
|
||||
raise FileNotFoundError(f"File not found: {path_obj}")
|
||||
if path_obj.is_dir():
|
||||
raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
ext = path_obj.suffix.lower()
|
||||
resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis")
|
||||
|
||||
if ext in _SPM_EXTENSIONS:
|
||||
fields = self._load_spm_all(path, ext)
|
||||
fields = self._load_spm_all(path_obj, ext)
|
||||
for f in fields:
|
||||
f.colormap = colormap
|
||||
f.colormap = resolved_colormap
|
||||
return tuple(fields)
|
||||
|
||||
# Image or array — uncalibrated, single output
|
||||
field = self._load_image_or_array(path, ext)
|
||||
field.colormap = colormap
|
||||
field = self._load_image_or_array(path_obj, ext)
|
||||
field.colormap = resolved_colormap
|
||||
self._send_warning("Uncalibrated data — no physical dimensions.")
|
||||
return (field,)
|
||||
|
||||
@@ -349,8 +374,11 @@ class LoadDemo:
|
||||
return {
|
||||
"required": {
|
||||
"name": (choices,),
|
||||
"colormap": (list(COLORMAPS),),
|
||||
}
|
||||
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
},
|
||||
"optional": {
|
||||
"colormap_map": ("COLORMAP", {"label": "colormap"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
@@ -359,13 +387,38 @@ class LoadDemo:
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
|
||||
|
||||
def load(self, name: str, colormap: str = "viridis"):
|
||||
path = DEMO_DIR / name
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Demo file not found: {name}")
|
||||
|
||||
def load(self, name: str = "", colormap: str = "viridis", colormap_map=None):
|
||||
loader = LoadFile()
|
||||
return loader.load(filename=str(path), colormap=colormap)
|
||||
demo_path = DEMO_DIR / name
|
||||
if not demo_path.exists():
|
||||
raise FileNotFoundError(f"Demo file not found: {name}")
|
||||
return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map)
|
||||
|
||||
|
||||
@register_node(display_name="Folder")
|
||||
class Folder:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DIRECTORY",)
|
||||
RETURN_NAMES = ("directory",)
|
||||
FUNCTION = "list_files"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = (
|
||||
"Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. "
|
||||
"Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans."
|
||||
)
|
||||
|
||||
def list_files(self, folder: str) -> tuple:
|
||||
entries = list_folder_paths(folder)
|
||||
if not entries:
|
||||
return tuple()
|
||||
return tuple(item["path"] for item in entries)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -395,6 +448,36 @@ class Coordinate:
|
||||
return ((float(x), float(y)),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Number
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Number")
|
||||
class Number:
|
||||
"""Provide a fixed scalar value that can feed FLOAT or INT widget sockets."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 0.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
RETURN_NAMES = ("value",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = (
|
||||
"Output a fixed numeric value. "
|
||||
"When connected to FLOAT inputs the exact value is used; "
|
||||
"INT inputs round to the nearest integer at execution time."
|
||||
)
|
||||
|
||||
def process(self, value: float) -> tuple:
|
||||
return (float(value),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RangeSlider
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -445,12 +528,32 @@ _MAX_SAVE_FIELDS = 8
|
||||
class SaveImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
optional = {}
|
||||
optional = {
|
||||
"directory": ("DIRECTORY", {"label": "directory"}),
|
||||
}
|
||||
for i in range(_MAX_SAVE_FIELDS):
|
||||
optional[f"field_{i}"] = ("DATA_FIELD",)
|
||||
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
|
||||
optional[f"layer_name_{i}"] = ("STRING", {
|
||||
"default": "",
|
||||
"placeholder": "name",
|
||||
"show_when_input_visible": f"field_{i}",
|
||||
"inline_with_input": f"field_{i}",
|
||||
"hide_label": True,
|
||||
})
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
"filename": ("STRING", {
|
||||
"default": "",
|
||||
"placeholder": "filename",
|
||||
"placement": "top",
|
||||
}),
|
||||
"directory_path": ("FOLDER_PICKER", {
|
||||
"default": "",
|
||||
"label": "directory",
|
||||
"placement": "top",
|
||||
"hide_when_input_connected": "directory",
|
||||
"top_socket_input": "directory",
|
||||
}),
|
||||
"format": (["TIFF", "NPZ"],),
|
||||
},
|
||||
"optional": optional,
|
||||
@@ -462,59 +565,130 @@ class SaveImage:
|
||||
OUTPUT_NODE = True
|
||||
MANUAL_TRIGGER = True
|
||||
DESCRIPTION = (
|
||||
"Save one or more DATA_FIELD layers to a single file. "
|
||||
"Connect fields to the inputs — a new slot appears as each is filled. "
|
||||
"TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
|
||||
"Save one or more layers to a single file. "
|
||||
"Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. "
|
||||
"Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. "
|
||||
"A new slot appears as each one is filled, with a matching per-layer name field. "
|
||||
"TIFF writes multi-page data and stores layer names as page descriptions; "
|
||||
"NPZ writes named arrays using those layer names as keys. "
|
||||
"Click Save to write (does not auto-run)."
|
||||
)
|
||||
|
||||
_broadcast_warning_fn = None
|
||||
_current_node_id = None
|
||||
|
||||
def save(self, filename: str, format: str = "TIFF", **kwargs):
|
||||
# Collect connected fields in order
|
||||
fields = []
|
||||
def save(
|
||||
self,
|
||||
filename: str,
|
||||
directory_path: str = "",
|
||||
format: str = "TIFF",
|
||||
directory: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
layers = []
|
||||
layer_names = []
|
||||
for i in range(_MAX_SAVE_FIELDS):
|
||||
f = kwargs.get(f"field_{i}")
|
||||
if f is not None:
|
||||
fields.append(f)
|
||||
layer = kwargs.get(f"field_{i}")
|
||||
if layer is not None:
|
||||
layers.append(layer)
|
||||
layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i))
|
||||
|
||||
if not fields:
|
||||
raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
|
||||
if not layers:
|
||||
raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.")
|
||||
|
||||
if not filename or not filename.strip():
|
||||
raise ValueError("No output path selected — use Browse to pick a location.")
|
||||
|
||||
path = Path(filename)
|
||||
# Ensure parent directory exists
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Force correct extension
|
||||
ext = ".tiff" if format == "TIFF" else ".npz"
|
||||
if path.suffix.lower() != ext:
|
||||
path = path.with_suffix(ext)
|
||||
path = self._resolve_save_path(filename, format, directory, directory_path)
|
||||
|
||||
if format == "TIFF":
|
||||
self._save_tiff(path, fields)
|
||||
self._save_tiff(path, layers, layer_names)
|
||||
else:
|
||||
self._save_npz(path, fields)
|
||||
self._save_npz(path, layers, layer_names)
|
||||
|
||||
self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
|
||||
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
|
||||
return ()
|
||||
|
||||
def _save_tiff(self, path: Path, fields: list[DataField]):
|
||||
from PIL import Image
|
||||
images = []
|
||||
for f in fields:
|
||||
images.append(Image.fromarray(f.data.astype(np.float32)))
|
||||
images[0].save(str(path), save_all=True, append_images=images[1:])
|
||||
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
|
||||
import tifffile
|
||||
|
||||
def _save_npz(self, path: Path, fields: list[DataField]):
|
||||
with tifffile.TiffWriter(str(path)) as tif:
|
||||
for layer, layer_name in zip(layers, layer_names):
|
||||
tif.write(self._layer_array_for_tiff(layer), description=layer_name)
|
||||
|
||||
def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
|
||||
arrays = {}
|
||||
for i, f in enumerate(fields):
|
||||
arrays[f"layer_{i}"] = f.data
|
||||
used_keys = set()
|
||||
for i, (layer, layer_name) in enumerate(zip(layers, layer_names)):
|
||||
arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer)
|
||||
np.savez(str(path), **arrays)
|
||||
|
||||
def _resolve_layer_name(self, raw_name: object, index: int) -> str:
|
||||
text = str(raw_name).strip() if raw_name is not None else ""
|
||||
return text or f"layer_{index}"
|
||||
|
||||
def _resolve_save_path(
|
||||
self,
|
||||
filename: str,
|
||||
format: str,
|
||||
directory: str | None,
|
||||
directory_path: str = "",
|
||||
) -> Path:
|
||||
ext = ".tiff" if format == "TIFF" else ".npz"
|
||||
raw_filename = str(filename).strip() if filename is not None else ""
|
||||
raw_directory = str(directory).strip() if directory is not None else ""
|
||||
if not raw_directory:
|
||||
raw_directory = str(directory_path).strip() if directory_path is not None else ""
|
||||
|
||||
if raw_directory:
|
||||
dir_path = Path(raw_directory).expanduser()
|
||||
if dir_path.exists() and not dir_path.is_dir():
|
||||
raise ValueError("Directory input expects a folder path, not a file path.")
|
||||
if not dir_path.exists():
|
||||
if dir_path.suffix:
|
||||
raise ValueError("Directory input expects a folder path, not a file path.")
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
filename_part = Path(raw_filename).name if raw_filename else ""
|
||||
if not filename_part:
|
||||
raise ValueError("No output filename selected — enter a file name when using a directory input.")
|
||||
path = dir_path / filename_part
|
||||
else:
|
||||
if not raw_filename:
|
||||
raise ValueError("No output path selected — use Browse to pick a location.")
|
||||
path = Path(raw_filename).expanduser()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if path.suffix.lower() != ext:
|
||||
path = path.with_suffix(ext)
|
||||
return path
|
||||
|
||||
def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str:
|
||||
key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_")
|
||||
if not key:
|
||||
key = f"layer_{index}"
|
||||
if key[0].isdigit():
|
||||
key = f"layer_{key}"
|
||||
|
||||
candidate = key
|
||||
suffix = 2
|
||||
while candidate in used_keys:
|
||||
candidate = f"{key}_{suffix}"
|
||||
suffix += 1
|
||||
used_keys.add(candidate)
|
||||
return candidate
|
||||
|
||||
def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray:
|
||||
if isinstance(layer, DataField):
|
||||
return np.asarray(layer.data, dtype=np.float32)
|
||||
if isinstance(layer, np.ndarray):
|
||||
return image_to_uint8(layer)
|
||||
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
|
||||
|
||||
def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray:
|
||||
if isinstance(layer, DataField):
|
||||
return np.asarray(layer.data)
|
||||
if isinstance(layer, np.ndarray):
|
||||
return np.asarray(layer)
|
||||
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
|
||||
|
||||
def _send_warning(self, message: str):
|
||||
fn = SaveImage._broadcast_warning_fn
|
||||
nid = SaveImage._current_node_id
|
||||
|
||||
@@ -143,6 +143,7 @@ class CropResizeField:
|
||||
yreal=(py1 - py0) * field.dy,
|
||||
xoff=field.xoff + px0 * field.dx,
|
||||
yoff=field.yoff + py0 * field.dy,
|
||||
overlays=[],
|
||||
)
|
||||
|
||||
target_width, target_height = self._resolve_target_shape(
|
||||
@@ -217,6 +218,9 @@ class RotateField:
|
||||
"Optionally expand the canvas to keep the full rotated field while preserving the field center."
|
||||
)
|
||||
|
||||
_broadcast_warning_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
@@ -224,6 +228,9 @@ class RotateField:
|
||||
interpolation: str,
|
||||
expand_canvas: bool,
|
||||
) -> tuple:
|
||||
if field.overlays:
|
||||
self._send_warning("Rotate clears annotation/markup overlays!")
|
||||
|
||||
angle = float(angle)
|
||||
order_map = {
|
||||
"nearest": 0,
|
||||
@@ -264,9 +271,16 @@ class RotateField:
|
||||
yreal=new_yreal,
|
||||
xoff=center_x - new_xreal / 2.0,
|
||||
yoff=center_y - new_yreal / 2.0,
|
||||
overlays=[],
|
||||
)
|
||||
return (result,)
|
||||
|
||||
def _send_warning(self, message: str):
|
||||
fn = RotateField._broadcast_warning_fn
|
||||
nid = RotateField._current_node_id
|
||||
if fn and nid:
|
||||
fn(nid, message)
|
||||
|
||||
@staticmethod
|
||||
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:
|
||||
if not expand_canvas:
|
||||
|
||||
@@ -215,6 +215,13 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def get_folder_files(request: web.Request) -> web.Response:
|
||||
folder_path = request.query.get("folder", "")
|
||||
from backend.nodes.io import list_folder_paths
|
||||
loop = asyncio.get_running_loop()
|
||||
entries = await loop.run_in_executor(None, list_folder_paths, folder_path)
|
||||
return web.Response(text=_dumps(entries), content_type="application/json")
|
||||
|
||||
async def upload_file(request: web.Request) -> web.Response:
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
@@ -346,6 +353,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
app.router.add_get("/nodes", get_nodes)
|
||||
app.router.add_get("/files", list_files)
|
||||
app.router.add_get("/browse", browse_dir)
|
||||
app.router.add_get("/folder-files", get_folder_files)
|
||||
app.router.add_post("/upload", upload_file)
|
||||
app.router.add_post("/download", download_file)
|
||||
app.router.add_post("/save-workflow-png", save_workflow_png)
|
||||
|
||||
Reference in New Issue
Block a user