diff --git a/asdf.tiff b/asdf.tiff new file mode 100644 index 0000000..351c96d Binary files /dev/null and b/asdf.tiff differ diff --git a/backend/data_types.py b/backend/data_types.py index e65d77a..cbbdee2 100644 --- a/backend/data_types.py +++ b/backend/data_types.py @@ -12,13 +12,23 @@ DataField mirrors Gwyddion's GwyDataField structure: from __future__ import annotations +from copy import deepcopy from dataclasses import dataclass, field from functools import lru_cache +from pathlib import Path +from typing import Any import numpy as np COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain", "cividis", "magma", "copper", "afmhot") +DEFAULT_CUSTOM_COLORMAP_STOPS = ( + {"position": 0.0, "color": "#440154"}, + {"position": 1.0, "color": "#fde725"}, +) +SYSTEM_DEFAULT_FONT = "System Default" +CUSTOM_FILE_FONT = "Custom File" +PREVIEW_MARKUP_REFERENCE_DIM = 512 class RecordTable(list): @@ -28,6 +38,147 @@ class RecordTable(list): class MeasureTable(list): """Named scalar measurements, typically rows of quantity/value/unit.""" + +def _normalize_hex_color(color: Any, default: str = "#000000") -> str: + if isinstance(color, str): + text = color.strip() + if len(text) == 4 and text.startswith("#"): + text = "#" + "".join(ch * 2 for ch in text[1:]) + if len(text) == 7 and text.startswith("#"): + try: + int(text[1:], 16) + return text.lower() + except ValueError: + pass + return default + + +def _hex_to_rgb01(color: str) -> tuple[float, float, float]: + normalized = _normalize_hex_color(color) + return ( + int(normalized[1:3], 16) / 255.0, + int(normalized[3:5], 16) / 255.0, + int(normalized[5:7], 16) / 255.0, + ) + + +def _normalize_custom_colormap_stops(raw_stops: Any) -> list[dict[str, float | str]] | None: + if not isinstance(raw_stops, list): + return None + + parsed = [] + for stop in raw_stops: + if not isinstance(stop, dict): + continue + try: + position = float(stop.get("position")) + except (TypeError, ValueError): + continue + if not np.isfinite(position): + continue + parsed.append({ + "position": float(np.clip(position, 0.0, 1.0)), + "color": _normalize_hex_color(stop.get("color"), "#000000"), + }) + + if len(parsed) < 2: + return None + + parsed.sort(key=lambda stop: stop["position"]) + unique: list[dict[str, float | str]] = [] + for stop in parsed: + if unique and abs(float(stop["position"]) - float(unique[-1]["position"])) < 1e-9: + unique[-1] = stop + else: + unique.append(stop) + + if len(unique) < 2: + return None + + unique[0] = {"position": 0.0, "color": unique[0]["color"]} + unique[-1] = {"position": 1.0, "color": unique[-1]["color"]} + return unique + + +def normalize_colormap_spec(colormap: Any, fallback: Any = "viridis") -> str | dict[str, Any]: + if isinstance(colormap, str): + if colormap in COLORMAPS: + return colormap + elif isinstance(colormap, dict): + mode = str(colormap.get("mode", "")).strip().lower() + preset = colormap.get("preset") + if isinstance(preset, str) and preset in COLORMAPS: + if mode == "preset" or "stops" not in colormap: + return {"mode": "preset", "preset": preset} + stops = _normalize_custom_colormap_stops(colormap.get("stops")) + if stops is not None: + return {"mode": "custom", "stops": stops} + + if fallback is colormap: + return "viridis" + return normalize_colormap_spec(fallback, fallback="viridis") + + +def resolve_colormap_input( + selection: Any = "auto", + *, + colormap_input: Any = None, + inherited: Any = None, + default: Any = "gray", +) -> str | dict[str, Any]: + if colormap_input is not None: + return normalize_colormap_spec(colormap_input, fallback=inherited if inherited is not None else default) + + if isinstance(selection, str) and selection != "auto": + return normalize_colormap_spec(selection, fallback=inherited if inherited is not None else default) + + if inherited is not None: + return normalize_colormap_spec(inherited, fallback=default) + + return normalize_colormap_spec(default, fallback="gray") + + +def normalize_font_spec(font: Any) -> dict[str, str] | None: + if isinstance(font, str): + family = font.strip() + return {"family": family, "path": ""} if family else None + + if isinstance(font, dict): + family = str(font.get("family", "")).strip() + path = str(font.get("path", "")).strip() + if family or path: + return {"family": family, "path": path} + + return None + + +def colormap_to_uint8(normalized: np.ndarray, colormap: Any = "gray") -> np.ndarray: + normalized = np.asarray(normalized, dtype=np.float64) + spec = normalize_colormap_spec(colormap, fallback="gray") + + if spec == "gray": + grey = np.rint(normalized * 255.0).astype(np.uint8) + rgb = np.empty(grey.shape + (3,), dtype=np.uint8) + rgb[..., 0] = grey + rgb[..., 1] = grey + rgb[..., 2] = grey + return rgb + + if isinstance(spec, dict) and spec.get("mode") == "custom": + stops = spec["stops"] + positions = np.array([float(stop["position"]) for stop in stops], dtype=np.float64) + colors = np.array([_hex_to_rgb01(str(stop["color"])) for stop in stops], dtype=np.float64) + flat = normalized.reshape(-1) + rgb = np.empty((flat.shape[0], 3), dtype=np.uint8) + for channel in range(3): + rgb[:, channel] = np.rint(np.interp(flat, positions, colors[:, channel]) * 255.0).astype(np.uint8) + return rgb.reshape(normalized.shape + (3,)) + + cmap_name = spec["preset"] if isinstance(spec, dict) else spec + cmap = _get_colormap(cmap_name) + rgba = cmap(normalized) + return (rgba[:, :, :3] * 255).astype(np.uint8) + @dataclass class DataField: data: np.ndarray # shape (yres, xres), dtype float64 @@ -40,15 +191,17 @@ class DataField: si_unit_xy: str = "m" si_unit_z: str = "m" domain: str = "spatial" # "spatial" or "frequency" - colormap: str = "viridis" + colormap: str | dict[str, Any] = "viridis" display_offset: float = 0.0 display_scale: float = 1.0 + overlays: list[dict[str, Any]] = field(default_factory=list) def __post_init__(self) -> None: self.data = np.asarray(self.data, dtype=np.float64) if self.data.ndim != 2: raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}") self.yres, self.xres = self.data.shape + self.overlays = deepcopy(self.overlays) if isinstance(self.overlays, list) else [] def copy(self) -> "DataField": """Return a deep copy with independent data array.""" @@ -66,6 +219,7 @@ class DataField: colormap=self.colormap, display_offset=self.display_offset, display_scale=self.display_scale, + overlays=deepcopy(self.overlays), ) def replace(self, **kwargs) -> "DataField": @@ -84,6 +238,7 @@ class DataField: "colormap": self.colormap, "display_offset": self.display_offset, "display_scale": self.display_scale, + "overlays": deepcopy(self.overlays), } base.update(kwargs) return DataField(**base) @@ -137,7 +292,7 @@ def normalize_for_colormap( return np.clip((base_norm - offset) / scale, 0.0, 1.0) -def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray: +def datafield_to_uint8(df: DataField, colormap: Any = "gray") -> np.ndarray: """ Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap. Returns shape (H, W, 3) uint8. @@ -147,19 +302,447 @@ def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray: offset=df.display_offset, scale=df.display_scale, ) + return colormap_to_uint8(normalized, colormap) - if colormap == "gray": - grey = np.rint(normalized * 255.0).astype(np.uint8) - rgb = np.empty(grey.shape + (3,), dtype=np.uint8) - rgb[..., 0] = grey - rgb[..., 1] = grey - rgb[..., 2] = grey - return rgb - cmap = _get_colormap(colormap) - rgba = cmap(normalized) # (H, W, 4) float [0,1] - rgb = (rgba[:, :, :3] * 255).astype(np.uint8) - return rgb +_SI_PREFIXES = [ + (1e24, "Y"), + (1e21, "Z"), + (1e18, "E"), + (1e15, "P"), + (1e12, "T"), + (1e9, "G"), + (1e6, "M"), + (1e3, "k"), + (1.0, ""), + (1e-3, "m"), + (1e-6, "u"), + (1e-9, "n"), + (1e-12, "p"), + (1e-15, "f"), + (1e-18, "a"), + (1e-21, "z"), + (1e-24, "y"), +] +_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"} + + +def _format_numeric(value: float) -> str: + if not np.isfinite(value): + return str(value) + abs_value = abs(value) + if abs_value == 0: + return "0" + if abs_value >= 1e4 or abs_value < 1e-3: + return f"{value:.3e}" + return f"{value:.4g}" + + +def _format_with_unit(value: float, unit: str) -> str: + unit = (unit or "").strip() + if not unit: + return _format_numeric(value) + if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0: + abs_value = abs(value) + for scale, prefix in _SI_PREFIXES: + scaled = abs_value / scale + if 1 <= scaled < 1000: + signed = value / scale + return f"{_format_numeric(signed)} {prefix}{unit}" + return f"{_format_numeric(value)} {unit}" + + +def _nice_length(target: float) -> float: + if not np.isfinite(target) or target <= 0: + return 0.0 + exponent = np.floor(np.log10(target)) + base = 10.0 ** exponent + for step in (5.0, 2.0, 1.0): + candidate = step * base + if candidate <= target: + return candidate + return base + + +def _display_value_range(field: DataField) -> tuple[float, float, float]: + data = np.asarray(field.data, dtype=np.float64) + dmin = float(data.min()) + dmax = float(data.max()) + if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin: + return dmin, dmin, dmin + + offset = float(field.display_offset) + scale = float(field.display_scale) + if not np.isfinite(offset): + offset = 0.0 + if not np.isfinite(scale) or scale <= 0.0: + scale = 1.0 + + low_norm = float(np.clip(offset, 0.0, 1.0)) + high_norm = float(np.clip(offset + scale, 0.0, 1.0)) + if high_norm < low_norm: + low_norm, high_norm = high_norm, low_norm + mid_norm = 0.5 * (low_norm + high_norm) + + span = dmax - dmin + return ( + dmin + low_norm * span, + dmin + mid_norm * span, + dmin + high_norm * span, + ) + + +def _normalize_font_match_key(text: str) -> str: + return "".join(ch for ch in text.lower() if ch.isalnum()) + + +def _render_overlay_text( + text: str, + size_px: int, + color: tuple[int, int, int], + font_spec: Any = None, +): + from PIL import Image, ImageDraw, ImageFont + + size_px = max(8, int(round(size_px))) + font = _load_overlay_font(size_px, ImageFont, font_spec=font_spec) + if font is not None: + probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0)) + probe_draw = ImageDraw.Draw(probe) + bbox = probe_draw.textbbox((0, 0), text, font=font) + width = max(1, bbox[2] - bbox[0]) + height = max(1, bbox[3] - bbox[1]) + text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + text_draw = ImageDraw.Draw(text_image) + text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255)) + return text_image + + font = ImageFont.load_default() + probe = Image.new("L", (1, 1), 0) + probe_draw = ImageDraw.Draw(probe) + bbox = probe_draw.textbbox((0, 0), text, font=font) + width = max(1, bbox[2] - bbox[0]) + height = max(1, bbox[3] - bbox[1]) + mask = Image.new("L", (width, height), 0) + mask_draw = ImageDraw.Draw(mask) + mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255) + scale = max(1.0, size_px / max(1, height)) + scaled_width = max(1, int(round(width * scale))) + scaled_height = max(1, int(round(height * scale))) + resampling = getattr(Image, "Resampling", Image) + # Preserve edge sharpness if we ever have to scale the bitmap fallback font. + scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.NEAREST) + text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0)) + text_image.putalpha(scaled_mask) + return text_image + + +@lru_cache(maxsize=1) +def _overlay_font_candidates() -> tuple[str, ...]: + candidates: list[str] = [] + + try: + import PIL + + pil_dir = Path(PIL.__file__).resolve().parent + candidates.extend([ + str(pil_dir / "fonts" / "DejaVuSans.ttf"), + str(pil_dir / "Tests" / "fonts" / "DejaVuSans.ttf"), + ]) + except Exception: + pass + + candidates.extend([ + "/System/Library/Fonts/Supplemental/Arial.ttf", + "/System/Library/Fonts/Supplemental/Helvetica.ttc", + "/System/Library/Fonts/Supplemental/Times New Roman.ttf", + "/Library/Fonts/Arial.ttf", + "/Library/Fonts/Helvetica.ttc", + "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", + "/usr/share/fonts/truetype/liberation2/LiberationSans-Regular.ttf", + ]) + + unique: list[str] = [] + for candidate in candidates: + if candidate not in unique and Path(candidate).exists(): + unique.append(candidate) + return tuple(unique) + + +def list_overlay_font_choices() -> tuple[str, ...]: + labels: list[str] = [] + for candidate in _overlay_font_candidates(): + label = Path(candidate).stem + if label and label not in labels: + labels.append(label) + return tuple(labels) + + +def _matching_overlay_font_candidates(family: str) -> list[str]: + key = _normalize_font_match_key(family) + if not key: + return [] + + matches: list[str] = [] + for candidate in _overlay_font_candidates(): + stem_key = _normalize_font_match_key(Path(candidate).stem) + if key == stem_key or key in stem_key or stem_key in key: + matches.append(candidate) + return matches + + +def _load_overlay_font(size_px: int, image_font_module, font_spec: Any = None) -> Any: + normalized = normalize_font_spec(font_spec) + + if normalized is not None: + requested_path = normalized.get("path", "") + requested_family = normalized.get("family", "") + + if requested_path: + try: + return image_font_module.truetype(requested_path, size_px) + except Exception: + pass + + if requested_family: + try: + return image_font_module.truetype(requested_family, size_px) + except Exception: + pass + + for candidate in _matching_overlay_font_candidates(requested_family): + try: + return image_font_module.truetype(candidate, size_px) + except Exception: + continue + + for candidate in _overlay_font_candidates(): + try: + return image_font_module.truetype(candidate, size_px) + except Exception: + continue + + try: + return image_font_module.truetype("DejaVuSans.ttf", size_px) + except Exception: + return None + + +def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str: + if isinstance(color, str): + text = color.strip() + if len(text) == 4 and text.startswith("#"): + text = "#" + "".join(ch * 2 for ch in text[1:]) + if len(text) == 7 and text.startswith("#"): + try: + int(text[1:], 16) + return text.lower() + except ValueError: + pass + return default + + +def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int): + dx = end[0] - start[0] + dy = end[1] - start[1] + length = float(np.hypot(dx, dy)) + if length <= 1e-6: + radius = max(1.0, width / 2.0) + draw.ellipse( + (start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius), + fill=color, + ) + return + + ux = dx / length + uy = dy / length + head_length = max(10.0, width * 4.0) + head_width = max(8.0, width * 3.0) + shaft_end = ( + end[0] - ux * head_length, + end[1] - uy * head_length, + ) + draw.line((start, shaft_end), fill=color, width=width) + px = -uy + py = ux + left = (shaft_end[0] + px * head_width / 2.0, shaft_end[1] + py * head_width / 2.0) + right = (shaft_end[0] - px * head_width / 2.0, shaft_end[1] - py * head_width / 2.0) + draw.polygon([end, left, right], fill=color) + + +def _preview_markup_stroke_width(width: int, image_width: int, image_height: int) -> int: + width = max(1, int(width)) + longest_dim = max(1, int(image_width), int(image_height)) + scale = max(1.0, longest_dim / float(PREVIEW_MARKUP_REFERENCE_DIM)) + return max(1, int(round(width * scale))) + + +def _sanitize_markup_shapes(shapes: Any) -> list[dict[str, Any]]: + if not isinstance(shapes, list): + return [] + + parsed = [] + for shape in shapes: + if not isinstance(shape, dict): + continue + kind = str(shape.get("kind", "")).strip().lower() + if kind not in {"line", "rectangle", "circle", "arrow"}: + continue + try: + x1 = float(shape.get("x1")) + y1 = float(shape.get("y1")) + x2 = float(shape.get("x2")) + y2 = float(shape.get("y2")) + width = int(round(float(shape.get("width", 3)))) + except (TypeError, ValueError): + continue + if not all(np.isfinite(value) for value in (x1, y1, x2, y2)): + continue + parsed.append({ + "kind": kind, + "x1": float(np.clip(x1, 0.0, 1.0)), + "y1": float(np.clip(y1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y2": float(np.clip(y2, 0.0, 1.0)), + "width": max(1, min(128, width)), + "color": _normalize_markup_color(shape.get("color")), + }) + return parsed + + +def _apply_annotation_overlay( + image: np.ndarray, + field: DataField, + colormap: Any, + spec: dict[str, Any], +) -> np.ndarray: + from PIL import Image, ImageDraw + + show_scale_bar = bool(spec.get("show_scale_bar", True)) + show_color_map = bool(spec.get("show_color_map", True)) + text_size = float(spec.get("text_size", 14.0)) + text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0 + font_spec = normalize_font_spec(spec.get("font")) + + current = np.asarray(image, dtype=np.uint8) + if current.ndim == 2: + current = np.repeat(current[:, :, np.newaxis], 3, axis=2) + + height, current_width = current.shape[:2] + field_width = max(1, int(field.xres)) + legend_width = max(72, int(round(field_width * 0.18))) if show_color_map else 0 + canvas_width = current_width + legend_width + canvas = np.full((height, canvas_width, 3), 255, dtype=np.uint8) + canvas[:, :current_width] = current + + pil_image = Image.fromarray(canvas) + draw = ImageDraw.Draw(pil_image) + base_font_px = max(6, int(round(text_size))) + + if show_scale_bar and field_width > 0 and np.isfinite(field.xreal) and field.xreal > 0: + target_real = field.xreal / 5.0 + bar_real = _nice_length(target_real) + if bar_real > 0 and np.isfinite(field.dx) and field.dx > 0: + bar_px = max(1, int(round(bar_real / field.dx))) + margin_x = max(8, field_width // 24) + margin_y = max(8, height // 24) + bar_height = max(3, int(round(height * 0.012))) + bar_px = min(bar_px, max(1, field_width - 2 * margin_x)) + x0 = margin_x + x1 = x0 + bar_px + y1 = height - margin_y + y0 = y1 - bar_height + text = _format_with_unit(bar_real, field.si_unit_xy) + text_image = _render_overlay_text(text, base_font_px, (255, 255, 255), font_spec=font_spec) + text_w, text_h = text_image.size + label_pad = 2 + bg_left = max(0, x0 - 4) + bg_top = max(0, y0 - text_h - label_pad * 3) + bg_right = min(canvas_width, max(x1 + 4, x0 + text_w + 8)) + bg_bottom = min(height, y1 + 4) + draw.rectangle((bg_left, bg_top, bg_right, bg_bottom), fill=(0, 0, 0)) + draw.rectangle((x0, y0, x1, y1), fill=(255, 255, 255)) + pil_image.paste(text_image, (x0, bg_top + label_pad), text_image) + + if show_color_map and legend_width > 0: + panel_x0 = current_width + draw.rectangle((panel_x0, 0, canvas_width, height), fill=(245, 245, 245)) + grad_x0 = panel_x0 + max(8, legend_width // 7) + grad_w = max(12, legend_width // 5) + grad_y0 = max(10, height // 18) + grad_y1 = max(grad_y0 + 10, height - grad_y0) + grad_h = grad_y1 - grad_y0 + gradient = np.linspace(1.0, 0.0, grad_h, dtype=np.float64)[:, np.newaxis] + gradient = np.repeat(gradient, grad_w, axis=1) + gradient_rgb = colormap_to_uint8(gradient, colormap) + pil_image.paste(Image.fromarray(gradient_rgb), (grad_x0, grad_y0)) + draw.rectangle((grad_x0, grad_y0, grad_x0 + grad_w, grad_y1), outline=(40, 40, 40), width=1) + + legend_min, legend_mid, legend_max = _display_value_range(field) + labels = [ + (legend_max, grad_y0), + (legend_mid, grad_y0 + grad_h // 2), + (legend_min, grad_y1), + ] + text_x = grad_x0 + grad_w + 8 + for value, y_center in labels: + text_image = _render_overlay_text( + _format_with_unit(value, field.si_unit_z), + base_font_px, + (20, 20, 20), + font_spec=font_spec, + ) + text_y = int(round(y_center - text_image.size[1] / 2)) + text_y = max(0, min(height - text_image.size[1], text_y)) + pil_image.paste(text_image, (text_x, text_y), text_image) + + return np.asarray(pil_image, dtype=np.uint8) + + +def _apply_markup_overlay(image: np.ndarray, field: DataField, spec: dict[str, Any]) -> np.ndarray: + from PIL import Image, ImageDraw + + current = np.asarray(image, dtype=np.uint8) + if current.ndim == 2: + current = np.repeat(current[:, :, np.newaxis], 3, axis=2) + + pil_image = Image.fromarray(current.copy()) + draw = ImageDraw.Draw(pil_image) + field_width = max(1, int(field.xres)) + field_height = max(1, int(field.yres)) + + for shape in _sanitize_markup_shapes(spec.get("shapes")): + x1 = float(shape["x1"]) * field_width + y1 = float(shape["y1"]) * field_height + x2 = float(shape["x2"]) * field_width + y2 = float(shape["y2"]) * field_height + color = str(shape["color"]) + stroke_width = _preview_markup_stroke_width(int(shape["width"]), field_width, field_height) + kind = str(shape["kind"]) + + if kind == "line": + draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width) + elif kind == "rectangle": + draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width) + elif kind == "circle": + draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width) + elif kind == "arrow": + _draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width) + + return np.asarray(pil_image, dtype=np.uint8) + + +def render_datafield_preview(df: DataField, colormap: Any = "gray") -> np.ndarray: + current = datafield_to_uint8(df, colormap) + for overlay in df.overlays: + if not isinstance(overlay, dict): + continue + kind = str(overlay.get("kind", "")).strip().lower() + if kind == "annotation": + current = _apply_annotation_overlay(current, df, colormap, overlay) + elif kind == "markup": + current = _apply_markup_overlay(current, df, overlay) + return current def image_to_uint8(image: np.ndarray) -> np.ndarray: @@ -195,5 +778,5 @@ def encode_preview(arr: np.ndarray) -> str: @lru_cache(maxsize=len(COLORMAPS)) def _get_colormap(colormap: str): - import matplotlib.cm as cm - return cm.get_cmap(colormap) + from matplotlib import colormaps + return colormaps[colormap] diff --git a/backend/execution.py b/backend/execution.py index 7f6a44f..918f4d9 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -23,6 +23,7 @@ The engine: from __future__ import annotations import uuid from collections import defaultdict, deque +from math import isfinite from time import perf_counter from typing import Any, Callable @@ -87,7 +88,8 @@ class ExecutionEngine: cls = NODE_CLASS_MAPPINGS[class_name] raw_inputs = node_def.get("inputs", {}) - inputs = self._resolve_inputs(raw_inputs, node_outputs) + input_types = cls.INPUT_TYPES() + inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) # Let display nodes know their node_id so they can tag WS messages self._set_node_id_on_display(cls, node_id) @@ -110,7 +112,7 @@ class ExecutionEngine: # Auto-preview: broadcast a thumbnail for any DATA_FIELD, # IMAGE, or table-like output so every node shows its result. if on_preview or on_table: - self._auto_preview(cls, node_id, result, on_preview, on_table) + self._auto_preview(cls, node_id, result, on_preview, on_table, inputs) if on_node_done: on_node_done(node_id, elapsed_ms) @@ -154,8 +156,14 @@ class ExecutionEngine: self, raw_inputs: dict[str, Any], node_outputs: dict[str, tuple], + input_types: dict[str, dict[str, Any]] | None = None, ) -> dict[str, Any]: """Replace [src_id, slot] links with actual output values.""" + specs = {} + if input_types: + specs.update(input_types.get("required", {})) + specs.update(input_types.get("optional", {})) + resolved = {} for key, value in raw_inputs.items(): if _is_link(value): @@ -170,11 +178,36 @@ class ExecutionEngine: f"Node '{src_id}' only has {len(outputs)} outputs, " f"but slot {slot} was requested." ) - resolved[key] = outputs[slot] + resolved_value = outputs[slot] else: - resolved[key] = value + resolved_value = value + + resolved[key] = self._coerce_input_value(resolved_value, specs.get(key)) return resolved + def _coerce_input_value(self, value: Any, spec: Any) -> Any: + if spec is None: + return value + + input_type = spec[0] if isinstance(spec, (list, tuple)) and spec else spec + if isinstance(input_type, list): + return value + + if input_type == "INT": + numeric = float(value) + if not isfinite(numeric): + raise ValueError(f"Expected a finite numeric value for INT input, got {value!r}") + rounded = int(abs(numeric) + 0.5) + return rounded if numeric >= 0 else -rounded + + if input_type == "FLOAT": + numeric = float(value) + if not isfinite(numeric): + raise ValueError(f"Expected a finite numeric value for FLOAT input, got {value!r}") + return numeric + + return value + def _inject_display_callbacks( self, on_preview: Callable | None, @@ -185,11 +218,11 @@ class ExecutionEngine: on_warning: Callable | None = None, ) -> None: """Wire up broadcast callbacks on display node classes.""" - from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay + from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram - from backend.nodes.modify import CropResizeField + from backend.nodes.modify import CropResizeField, RotateField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import SaveImage, LoadFile + from backend.nodes.io import SaveImage, LoadFile, LoadDemo PreviewImage._broadcast_fn = on_preview ThresholdMask._broadcast_fn = on_preview @@ -206,19 +239,22 @@ class ExecutionEngine: CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay CropResizeField._broadcast_overlay_fn = on_overlay + RotateField._broadcast_warning_fn = on_warning + Markup._broadcast_overlay_fn = on_overlay LoadFile._broadcast_warning_fn = on_warning + LoadDemo._broadcast_warning_fn = on_warning SaveImage._broadcast_warning_fn = on_warning def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" - from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay + from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram - from backend.nodes.modify import CropResizeField + from backend.nodes.modify import CropResizeField, RotateField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import LoadFile, SaveImage - if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, + from backend.nodes.io import LoadFile, LoadDemo, SaveImage + if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, - LoadFile, SaveImage): + LoadFile, LoadDemo, SaveImage): cls._current_node_id = node_id def _auto_preview( @@ -228,6 +264,7 @@ class ExecutionEngine: result: tuple, on_preview: Callable | None, on_table: Callable | None, + inputs: dict[str, Any] | None = None, ) -> None: """ After every node executes, inspect its outputs and broadcast @@ -236,12 +273,19 @@ class ExecutionEngine: """ import numpy as np from backend.data_types import ( - DataField, datafield_to_uint8, image_to_uint8, encode_preview, + DataField, image_to_uint8, encode_preview, render_datafield_preview, ) + from backend.nodes.io import LoadFile, LoadDemo if getattr(cls, "_CUSTOM_PREVIEW", False): return + if cls in (LoadFile, LoadDemo) and on_preview: + preview = self._render_load_node_preview(result, inputs or {}) + if preview: + on_preview(node_id, preview) + return + return_types = getattr(cls, "RETURN_TYPES", ()) for slot, type_name in enumerate(return_types): @@ -250,7 +294,7 @@ class ExecutionEngine: value = result[slot] if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview: - arr = datafield_to_uint8(value, value.colormap) + arr = render_datafield_preview(value, value.colormap) on_preview(node_id, encode_preview(arr)) return # one preview per node is enough @@ -269,6 +313,39 @@ class ExecutionEngine: on_table(node_id, value) return + def _render_load_node_preview( + self, + result: tuple, + inputs: dict[str, Any], + ) -> dict | None: + from backend.data_types import DataField, encode_preview, render_datafield_preview + from backend.nodes.io import list_channels + + fields = [value for value in result if isinstance(value, DataField)] + if not fields: + return None + + selected_path = str(inputs.get("path") or inputs.get("filename") or inputs.get("name") or "").strip() + channel_names: list[str] = [] + if selected_path: + try: + channel_names = [str(entry.get("name", "")).strip() or "field" for entry in list_channels(selected_path)] + except Exception: + channel_names = [] + + layers = [] + for index, field in enumerate(fields): + arr = render_datafield_preview(field, field.colormap) + layers.append({ + "name": channel_names[index] if index < len(channel_names) else f"layer {index + 1}", + "image": encode_preview(arr), + }) + + return { + "kind": "layer_gallery", + "layers": layers, + } + def _render_line_preview( self, diff --git a/backend/nodes/display.py b/backend/nodes/display.py index 8d7f7bf..0c263ba 100644 --- a/backend/nodes/display.py +++ b/backend/nodes/display.py @@ -7,10 +7,26 @@ before execution begins. """ from __future__ import annotations +import json import numpy as np from backend.node_registry import register_node from backend.data_types import ( - DataField, MeasureTable, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview, normalize_for_colormap, + DataField, + MeasureTable, + COLORMAPS, + CUSTOM_FILE_FONT, + DEFAULT_CUSTOM_COLORMAP_STOPS, + SYSTEM_DEFAULT_FONT, + colormap_to_uint8, + datafield_to_uint8, + encode_preview, + image_to_uint8, + list_overlay_font_choices, + normalize_colormap_spec, + normalize_font_spec, + normalize_for_colormap, + render_datafield_preview, + resolve_colormap_input, ) @@ -59,15 +75,473 @@ def _scalar_payload(value: float, unit: str = "") -> dict: return payload +_SI_PREFIXES = [ + (1e24, "Y"), + (1e21, "Z"), + (1e18, "E"), + (1e15, "P"), + (1e12, "T"), + (1e9, "G"), + (1e6, "M"), + (1e3, "k"), + (1.0, ""), + (1e-3, "m"), + (1e-6, "u"), + (1e-9, "n"), + (1e-12, "p"), + (1e-15, "f"), + (1e-18, "a"), + (1e-21, "z"), + (1e-24, "y"), +] +_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"} + + +def _format_numeric(value: float) -> str: + if not np.isfinite(value): + return str(value) + abs_value = abs(value) + if abs_value == 0: + return "0" + if abs_value >= 1e4 or abs_value < 1e-3: + return f"{value:.3e}" + return f"{value:.4g}" + + +def _format_with_unit(value: float, unit: str) -> str: + unit = (unit or "").strip() + if not unit: + return _format_numeric(value) + if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0: + abs_value = abs(value) + for scale, prefix in _SI_PREFIXES: + scaled = abs_value / scale + if 1 <= scaled < 1000: + signed = value / scale + return f"{_format_numeric(signed)} {prefix}{unit}" + return f"{_format_numeric(value)} {unit}" + + +def _nice_length(target: float) -> float: + if not np.isfinite(target) or target <= 0: + return 0.0 + exponent = np.floor(np.log10(target)) + base = 10.0 ** exponent + for step in (5.0, 2.0, 1.0): + candidate = step * base + if candidate <= target: + return candidate + return base + + +def _display_value_range(field: DataField) -> tuple[float, float, float]: + data = np.asarray(field.data, dtype=np.float64) + dmin = float(data.min()) + dmax = float(data.max()) + if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin: + return dmin, dmin, dmin + + offset = float(field.display_offset) + scale = float(field.display_scale) + if not np.isfinite(offset): + offset = 0.0 + if not np.isfinite(scale) or scale <= 0.0: + scale = 1.0 + + low_norm = float(np.clip(offset, 0.0, 1.0)) + high_norm = float(np.clip(offset + scale, 0.0, 1.0)) + if high_norm < low_norm: + low_norm, high_norm = high_norm, low_norm + mid_norm = 0.5 * (low_norm + high_norm) + + span = dmax - dmin + return ( + dmin + low_norm * span, + dmin + mid_norm * span, + dmin + high_norm * span, + ) + + +def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]): + from PIL import Image, ImageDraw, ImageFont + + size_px = max(8, int(round(size_px))) + try: + font = ImageFont.truetype("DejaVuSans.ttf", size_px) + probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0)) + probe_draw = ImageDraw.Draw(probe) + bbox = probe_draw.textbbox((0, 0), text, font=font) + width = max(1, bbox[2] - bbox[0]) + height = max(1, bbox[3] - bbox[1]) + text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + text_draw = ImageDraw.Draw(text_image) + text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255)) + return text_image + except Exception: + font = ImageFont.load_default() + probe = Image.new("L", (1, 1), 0) + probe_draw = ImageDraw.Draw(probe) + bbox = probe_draw.textbbox((0, 0), text, font=font) + width = max(1, bbox[2] - bbox[0]) + height = max(1, bbox[3] - bbox[1]) + mask = Image.new("L", (width, height), 0) + mask_draw = ImageDraw.Draw(mask) + mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255) + + scale = max(1.0, size_px / max(1, height)) + scaled_width = max(1, int(round(width * scale))) + scaled_height = max(1, int(round(height * scale))) + resampling = getattr(Image, "Resampling", Image) + scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR) + + text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0)) + text_image.putalpha(scaled_mask) + return text_image + + +def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str: + if isinstance(color, str): + text = color.strip() + if len(text) == 4 and text.startswith("#"): + text = "#" + "".join(ch * 2 for ch in text[1:]) + if len(text) == 7 and text.startswith("#"): + try: + int(text[1:], 16) + return text.lower() + except ValueError: + pass + return default + + +def _parse_markup_shapes(raw_shapes: str | list | None) -> list[dict[str, object]]: + if isinstance(raw_shapes, str): + try: + raw_shapes = json.loads(raw_shapes or "[]") + except json.JSONDecodeError: + raw_shapes = [] + + if not isinstance(raw_shapes, list): + return [] + + parsed: list[dict[str, object]] = [] + for shape in raw_shapes: + if not isinstance(shape, dict): + continue + + kind = str(shape.get("kind", "")).strip().lower() + if kind not in {"line", "rectangle", "circle", "arrow"}: + continue + + try: + x1 = float(shape.get("x1")) + y1 = float(shape.get("y1")) + x2 = float(shape.get("x2")) + y2 = float(shape.get("y2")) + width = int(round(float(shape.get("width", 3)))) + except (TypeError, ValueError): + continue + + coords = [x1, y1, x2, y2] + if not all(np.isfinite(value) for value in coords): + continue + + parsed.append({ + "kind": kind, + "x1": float(np.clip(x1, 0.0, 1.0)), + "y1": float(np.clip(y1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y2": float(np.clip(y2, 0.0, 1.0)), + "width": max(1, min(128, width)), + "color": _normalize_markup_color(shape.get("color")), + }) + + return parsed + + +def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int): + dx = end[0] - start[0] + dy = end[1] - start[1] + length = float(np.hypot(dx, dy)) + if length <= 1e-6: + radius = max(1.0, width / 2.0) + draw.ellipse( + (start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius), + fill=color, + ) + return + + ux = dx / length + uy = dy / length + head_length = max(10.0, width * 4.0) + head_width = max(8.0, width * 3.0) + shaft_end = ( + end[0] - ux * head_length, + end[1] - uy * head_length, + ) + + draw.line((start, shaft_end), fill=color, width=width) + + px = -uy + py = ux + left = ( + shaft_end[0] + px * head_width / 2.0, + shaft_end[1] + py * head_width / 2.0, + ) + right = ( + shaft_end[0] - px * head_width / 2.0, + shaft_end[1] - py * head_width / 2.0, + ) + draw.polygon([end, left, right], fill=color) + + +def _render_markup_image(image: np.ndarray, shapes: list[dict[str, object]]) -> np.ndarray: + from PIL import Image, ImageDraw + + base = image_to_uint8(image) + if base.ndim == 2: + base = np.repeat(base[:, :, np.newaxis], 3, axis=2) + + canvas = Image.fromarray(base.copy()) + draw = ImageDraw.Draw(canvas) + height, width = base.shape[:2] + + for shape in shapes: + x1 = float(shape["x1"]) * width + y1 = float(shape["y1"]) * height + x2 = float(shape["x2"]) * width + y2 = float(shape["y2"]) * height + color = str(shape["color"]) + stroke_width = int(shape["width"]) + kind = str(shape["kind"]) + + if kind == "line": + draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width) + elif kind == "rectangle": + draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width) + elif kind == "circle": + draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width) + elif kind == "arrow": + _draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width) + + return np.asarray(canvas, dtype=np.uint8) + + +@register_node(display_name="Color Map") +class ColorMap: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mode": (["preset", "custom"], {"default": "preset"}), + "preset": (list(COLORMAPS), { + "default": "viridis", + "show_when_widget_value": {"mode": ["preset"]}, + }), + "stops": ("STRING", { + "default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)), + "colormap_stops": True, + "show_when_widget_value": {"mode": ["custom"]}, + }), + } + } + + RETURN_TYPES = ("COLORMAP",) + RETURN_NAMES = ("colormap",) + FUNCTION = "build" + CATEGORY = "display" + DESCRIPTION = ( + "Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours " + "and any number of intermediate stops." + ) + + def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple: + if mode == "preset": + return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},) + + try: + raw_stops = stops if stops is not None else stops_json + stops_data = json.loads(raw_stops or "[]") + except json.JSONDecodeError as exc: + raise ValueError("Custom colormap stops must be valid JSON.") from exc + + spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None) + if not (isinstance(spec, dict) and spec.get("mode") == "custom"): + raise ValueError("Custom colormap must include at least min and max colours.") + return (spec,) + + +@register_node(display_name="Font") +class Font: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], { + "default": SYSTEM_DEFAULT_FONT, + }), + "font_file": ("FILE_PICKER", { + "default": "", + "show_when_widget_value": {"family": [CUSTOM_FILE_FONT]}, + }), + } + } + + RETURN_TYPES = ("FONT",) + RETURN_NAMES = ("font",) + FUNCTION = "build" + CATEGORY = "display" + DESCRIPTION = ( + "Build a reusable font spec for annotation overlays. Choose a discovered system font, " + "use the default fallback stack, or point to a custom font file." + ) + + def build(self, family: str, font_file: str = "") -> tuple: + if family == SYSTEM_DEFAULT_FONT: + return (None,) + if family == CUSTOM_FILE_FONT: + return (normalize_font_spec({"path": font_file}),) + return (normalize_font_spec({"family": family}),) + + +@register_node(display_name="Annotations") +class Annotations: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + "show_scale_bar": ("BOOLEAN", {"default": True}), + "show_color_map": ("BOOLEAN", {"default": True}), + "text_size": ("FLOAT", { + "default": 14.0, + "min": 6.0, + "max": 96.0, + "step": 1.0, + }), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + "font": ("FONT",), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("annotated",) + FUNCTION = "render" + CATEGORY = "display" + DESCRIPTION = ( + "Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. " + "The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values." + ) + + def render( + self, + field: DataField, + colormap: str, + show_scale_bar: bool, + show_color_map: bool, + text_size: float = 1.0, + colormap_map=None, + font=None, + ) -> tuple: + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=field.colormap, + default="gray", + ) + text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0 + out = field.replace( + colormap=resolved_colormap, + overlays=[ + *field.overlays, + { + "kind": "annotation", + "show_scale_bar": bool(show_scale_bar), + "show_color_map": bool(show_color_map), + "text_size": text_size, + "font": normalize_font_spec(font), + }, + ], + ) + return (out,) + + +@register_node(display_name="Markup") +class Markup: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}), + "stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}), + "stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}), + "clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}), + "markup_shapes": ("STRING", {"default": "[]", "hidden": True}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("annotated",) + FUNCTION = "process" + CATEGORY = "display" + DESCRIPTION = ( + "Draw simple vector markup over a DATA_FIELD without flattening the underlying data. " + "Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, + field: DataField, + shape: str, + stroke_color: str, + stroke_width: int, + markup_shapes: str, + ) -> tuple: + shapes = _parse_markup_shapes(markup_shapes) + out = field.replace( + overlays=[ + *field.overlays, + { + "kind": "markup", + "shapes": shapes, + }, + ], + ) + + if Markup._broadcast_overlay_fn is not None: + Markup._broadcast_overlay_fn( + Markup._current_node_id, + { + "kind": "markup", + "section_title": "Markup", + "image": encode_preview(datafield_to_uint8(field, field.colormap)), + "shape": str(shape), + "stroke_color": _normalize_markup_color(stroke_color), + "stroke_width": max(1, int(stroke_width)), + }, + ) + + return (out,) + + @register_node(display_name="Preview") class PreviewImage: @classmethod def INPUT_TYPES(cls): return { "required": { - "colormap": (["auto"] + list(COLORMAPS),), + "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), }, "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), "image": ("IMAGE",), "field": ("DATA_FIELD",), } @@ -82,30 +556,35 @@ class PreviewImage: _broadcast_fn = None _current_node_id: str = "" - def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple: - # Resolve "auto" — use field's colormap if available, else fall back to gray - if colormap == "auto": - colormap = field.colormap if field is not None else "gray" + def preview( + self, + colormap: str, + image: np.ndarray | None = None, + field=None, + colormap_map=None, + ) -> tuple: + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=field.colormap if field is not None else None, + default="gray", + ) # Prefer field if both are connected; accept whichever is provided if field is not None: - arr_u8 = datafield_to_uint8(field, colormap) + arr_u8 = render_datafield_preview(field, resolved_colormap) elif image is not None: - if image.dtype != np.uint8: - imin, imax = image.min(), image.max() - if imax > imin: - norm = (image - imin) / (imax - imin) + arr_u8 = image_to_uint8(image) + if arr_u8.ndim == 2: + if image.dtype == np.uint8: + normalized = arr_u8.astype(np.float64) / 255.0 else: - norm = np.zeros_like(image) - arr_u8 = (norm * 255).astype(np.uint8) - else: - arr_u8 = image - - if arr_u8.ndim == 2 and colormap != "gray": - import matplotlib.cm as cm - cmap = cm.get_cmap(colormap) - rgba = cmap(arr_u8.astype(np.float32) / 255.0) - arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8) + imin, imax = image.min(), image.max() + if imax > imin: + normalized = (image - imin) / (imax - imin) + else: + normalized = np.zeros_like(image, dtype=np.float64) + arr_u8 = colormap_to_uint8(normalized, resolved_colormap) else: raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.") @@ -124,10 +603,13 @@ class View3D: return { "required": { "field": ("DATA_FIELD",), - "colormap": (["auto"] + list(COLORMAPS),), + "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), "z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}), "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), - } + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + }, } RETURN_TYPES = () @@ -144,9 +626,8 @@ class View3D: def render( self, field: DataField, - colormap: str, z_scale: float, resolution: int, + colormap: str, z_scale: float, resolution: int, colormap_map=None, ) -> tuple: - import matplotlib.cm as cm import base64 data = field.data @@ -168,10 +649,13 @@ class View3D: data_max=float(field.data.max()), ) - cmap_name = field.colormap if colormap == "auto" else colormap - cmap = cm.get_cmap(cmap_name) - rgba = cmap(z_norm) # (ny, nx, 4) float [0,1] - colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8) + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=field.colormap, + default="gray", + ) + colors_u8 = colormap_to_uint8(z_norm, resolved_colormap) # Base64-encode arrays for efficient WS transport z_b64 = base64.b64encode(z.tobytes()).decode() diff --git a/backend/nodes/io.py b/backend/nodes/io.py index f34fe52..ca9b7d4 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -4,11 +4,12 @@ I/O nodes: load and save images and SPM data. from __future__ import annotations import os +import re import numpy as np from pathlib import Path from backend.node_registry import register_node -from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8 +from backend.data_types import COLORMAPS, DataField, encode_preview, image_to_uint8, resolve_colormap_input from backend.runtime_paths import demo_dir, input_dir, output_dir # Resolved at server startup so nodes know where to look @@ -22,6 +23,7 @@ _DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", _SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"} _IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"} _ARRAY_EXTENSIONS = {".npy", ".npz"} +_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS # --------------------------------------------------------------------------- @@ -105,6 +107,23 @@ def list_channels(filepath: str) -> list[dict]: return [{"name": "field", "type": "DATA_FIELD"}] +def list_folder_paths(folderpath: str) -> list[dict]: + """Return a folder DIRECTORY plus compatible image/array/SPM FILE_PATH outputs.""" + path = _resolve_path(folderpath) + if not path.exists() or not path.is_dir(): + return [] + + resolved_dir = str(path.resolve()) + results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}] + for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()): + if not entry.is_file() or entry.name.startswith("."): + continue + if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS: + continue + results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())}) + return results + + # --------------------------------------------------------------------------- # LoadFile (unified loader — replaces LoadImage + LoadSPM) # --------------------------------------------------------------------------- @@ -115,9 +134,13 @@ class LoadFile: def INPUT_TYPES(cls): return { "required": { - "filename": ("FILE_PICKER", {"default": ""}), - "colormap": (list(COLORMAPS),), - } + "filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}), + "colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + "path": ("FILE_PATH", {"label": "path"}), + }, } # Default outputs — overridden dynamically by the frontend for multi-channel files @@ -136,26 +159,28 @@ class LoadFile: _broadcast_warning_fn = None _current_node_id = None - def load(self, filename: str, colormap: str = "viridis"): - if not filename or not filename.strip(): + def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None): + selected_path = str(path).strip() if path is not None else str(filename).strip() + if not selected_path: raise ValueError("No file selected — use Browse to pick a file.") - path = _resolve_path(filename) - if not path.exists(): - raise FileNotFoundError(f"File not found: {path}") - if path.is_dir(): - raise IsADirectoryError(f"Expected a file, got a directory: {path}") + path_obj = _resolve_path(selected_path) + if not path_obj.exists(): + raise FileNotFoundError(f"File not found: {path_obj}") + if path_obj.is_dir(): + raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}") - ext = path.suffix.lower() + ext = path_obj.suffix.lower() + resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis") if ext in _SPM_EXTENSIONS: - fields = self._load_spm_all(path, ext) + fields = self._load_spm_all(path_obj, ext) for f in fields: - f.colormap = colormap + f.colormap = resolved_colormap return tuple(fields) # Image or array — uncalibrated, single output - field = self._load_image_or_array(path, ext) - field.colormap = colormap + field = self._load_image_or_array(path_obj, ext) + field.colormap = resolved_colormap self._send_warning("Uncalibrated data — no physical dimensions.") return (field,) @@ -349,8 +374,11 @@ class LoadDemo: return { "required": { "name": (choices,), - "colormap": (list(COLORMAPS),), - } + "colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + }, } RETURN_TYPES = ("DATA_FIELD",) @@ -359,13 +387,38 @@ class LoadDemo: CATEGORY = "io" DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." - def load(self, name: str, colormap: str = "viridis"): - path = DEMO_DIR / name - if not path.exists(): - raise FileNotFoundError(f"Demo file not found: {name}") - + def load(self, name: str = "", colormap: str = "viridis", colormap_map=None): loader = LoadFile() - return loader.load(filename=str(path), colormap=colormap) + demo_path = DEMO_DIR / name + if not demo_path.exists(): + raise FileNotFoundError(f"Demo file not found: {name}") + return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map) + + +@register_node(display_name="Folder") +class Folder: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}), + } + } + + RETURN_TYPES = ("DIRECTORY",) + RETURN_NAMES = ("directory",) + FUNCTION = "list_files" + CATEGORY = "io" + DESCRIPTION = ( + "Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. " + "Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans." + ) + + def list_files(self, folder: str) -> tuple: + entries = list_folder_paths(folder) + if not entries: + return tuple() + return tuple(item["path"] for item in entries) # --------------------------------------------------------------------------- @@ -395,6 +448,36 @@ class Coordinate: return ((float(x), float(y)),) +# --------------------------------------------------------------------------- +# Number +# --------------------------------------------------------------------------- + +@register_node(display_name="Number") +class Number: + """Provide a fixed scalar value that can feed FLOAT or INT widget sockets.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "io" + DESCRIPTION = ( + "Output a fixed numeric value. " + "When connected to FLOAT inputs the exact value is used; " + "INT inputs round to the nearest integer at execution time." + ) + + def process(self, value: float) -> tuple: + return (float(value),) + + # --------------------------------------------------------------------------- # RangeSlider # --------------------------------------------------------------------------- @@ -445,12 +528,32 @@ _MAX_SAVE_FIELDS = 8 class SaveImage: @classmethod def INPUT_TYPES(cls): - optional = {} + optional = { + "directory": ("DIRECTORY", {"label": "directory"}), + } for i in range(_MAX_SAVE_FIELDS): - optional[f"field_{i}"] = ("DATA_FIELD",) + optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"}) + optional[f"layer_name_{i}"] = ("STRING", { + "default": "", + "placeholder": "name", + "show_when_input_visible": f"field_{i}", + "inline_with_input": f"field_{i}", + "hide_label": True, + }) return { "required": { - "filename": ("FILE_PICKER", {"default": ""}), + "filename": ("STRING", { + "default": "", + "placeholder": "filename", + "placement": "top", + }), + "directory_path": ("FOLDER_PICKER", { + "default": "", + "label": "directory", + "placement": "top", + "hide_when_input_connected": "directory", + "top_socket_input": "directory", + }), "format": (["TIFF", "NPZ"],), }, "optional": optional, @@ -462,59 +565,130 @@ class SaveImage: OUTPUT_NODE = True MANUAL_TRIGGER = True DESCRIPTION = ( - "Save one or more DATA_FIELD layers to a single file. " - "Connect fields to the inputs — a new slot appears as each is filled. " - "TIFF writes float32 multi-page; NPZ writes float64 named arrays. " + "Save one or more layers to a single file. " + "Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. " + "Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. " + "A new slot appears as each one is filled, with a matching per-layer name field. " + "TIFF writes multi-page data and stores layer names as page descriptions; " + "NPZ writes named arrays using those layer names as keys. " "Click Save to write (does not auto-run)." ) _broadcast_warning_fn = None _current_node_id = None - def save(self, filename: str, format: str = "TIFF", **kwargs): - # Collect connected fields in order - fields = [] + def save( + self, + filename: str, + directory_path: str = "", + format: str = "TIFF", + directory: str | None = None, + **kwargs, + ): + layers = [] + layer_names = [] for i in range(_MAX_SAVE_FIELDS): - f = kwargs.get(f"field_{i}") - if f is not None: - fields.append(f) + layer = kwargs.get(f"field_{i}") + if layer is not None: + layers.append(layer) + layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i)) - if not fields: - raise ValueError("No fields connected — connect at least one DATA_FIELD input.") + if not layers: + raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.") - if not filename or not filename.strip(): - raise ValueError("No output path selected — use Browse to pick a location.") - - path = Path(filename) - # Ensure parent directory exists - path.parent.mkdir(parents=True, exist_ok=True) - - # Force correct extension - ext = ".tiff" if format == "TIFF" else ".npz" - if path.suffix.lower() != ext: - path = path.with_suffix(ext) + path = self._resolve_save_path(filename, format, directory, directory_path) if format == "TIFF": - self._save_tiff(path, fields) + self._save_tiff(path, layers, layer_names) else: - self._save_npz(path, fields) + self._save_npz(path, layers, layer_names) - self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}") + self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}") return () - def _save_tiff(self, path: Path, fields: list[DataField]): - from PIL import Image - images = [] - for f in fields: - images.append(Image.fromarray(f.data.astype(np.float32))) - images[0].save(str(path), save_all=True, append_images=images[1:]) + def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]): + import tifffile - def _save_npz(self, path: Path, fields: list[DataField]): + with tifffile.TiffWriter(str(path)) as tif: + for layer, layer_name in zip(layers, layer_names): + tif.write(self._layer_array_for_tiff(layer), description=layer_name) + + def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]): arrays = {} - for i, f in enumerate(fields): - arrays[f"layer_{i}"] = f.data + used_keys = set() + for i, (layer, layer_name) in enumerate(zip(layers, layer_names)): + arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer) np.savez(str(path), **arrays) + def _resolve_layer_name(self, raw_name: object, index: int) -> str: + text = str(raw_name).strip() if raw_name is not None else "" + return text or f"layer_{index}" + + def _resolve_save_path( + self, + filename: str, + format: str, + directory: str | None, + directory_path: str = "", + ) -> Path: + ext = ".tiff" if format == "TIFF" else ".npz" + raw_filename = str(filename).strip() if filename is not None else "" + raw_directory = str(directory).strip() if directory is not None else "" + if not raw_directory: + raw_directory = str(directory_path).strip() if directory_path is not None else "" + + if raw_directory: + dir_path = Path(raw_directory).expanduser() + if dir_path.exists() and not dir_path.is_dir(): + raise ValueError("Directory input expects a folder path, not a file path.") + if not dir_path.exists(): + if dir_path.suffix: + raise ValueError("Directory input expects a folder path, not a file path.") + dir_path.mkdir(parents=True, exist_ok=True) + + filename_part = Path(raw_filename).name if raw_filename else "" + if not filename_part: + raise ValueError("No output filename selected — enter a file name when using a directory input.") + path = dir_path / filename_part + else: + if not raw_filename: + raise ValueError("No output path selected — use Browse to pick a location.") + path = Path(raw_filename).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + + if path.suffix.lower() != ext: + path = path.with_suffix(ext) + return path + + def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str: + key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_") + if not key: + key = f"layer_{index}" + if key[0].isdigit(): + key = f"layer_{key}" + + candidate = key + suffix = 2 + while candidate in used_keys: + candidate = f"{key}_{suffix}" + suffix += 1 + used_keys.add(candidate) + return candidate + + def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray: + if isinstance(layer, DataField): + return np.asarray(layer.data, dtype=np.float32) + if isinstance(layer, np.ndarray): + return image_to_uint8(layer) + raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") + + def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray: + if isinstance(layer, DataField): + return np.asarray(layer.data) + if isinstance(layer, np.ndarray): + return np.asarray(layer) + raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") + def _send_warning(self, message: str): fn = SaveImage._broadcast_warning_fn nid = SaveImage._current_node_id diff --git a/backend/nodes/modify.py b/backend/nodes/modify.py index 42f0976..39527ed 100644 --- a/backend/nodes/modify.py +++ b/backend/nodes/modify.py @@ -143,6 +143,7 @@ class CropResizeField: yreal=(py1 - py0) * field.dy, xoff=field.xoff + px0 * field.dx, yoff=field.yoff + py0 * field.dy, + overlays=[], ) target_width, target_height = self._resolve_target_shape( @@ -217,6 +218,9 @@ class RotateField: "Optionally expand the canvas to keep the full rotated field while preserving the field center." ) + _broadcast_warning_fn = None + _current_node_id: str = "" + def process( self, field: DataField, @@ -224,6 +228,9 @@ class RotateField: interpolation: str, expand_canvas: bool, ) -> tuple: + if field.overlays: + self._send_warning("Rotate clears annotation/markup overlays!") + angle = float(angle) order_map = { "nearest": 0, @@ -264,9 +271,16 @@ class RotateField: yreal=new_yreal, xoff=center_x - new_xreal / 2.0, yoff=center_y - new_yreal / 2.0, + overlays=[], ) return (result,) + def _send_warning(self, message: str): + fn = RotateField._broadcast_warning_fn + nid = RotateField._current_node_id + if fn and nid: + fn(nid, message) + @staticmethod def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]: if not expand_canvas: diff --git a/backend/server.py b/backend/server.py index dec4caf..6af0284 100644 --- a/backend/server.py +++ b/backend/server.py @@ -215,6 +215,13 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: content_type="application/json", ) + async def get_folder_files(request: web.Request) -> web.Response: + folder_path = request.query.get("folder", "") + from backend.nodes.io import list_folder_paths + loop = asyncio.get_running_loop() + entries = await loop.run_in_executor(None, list_folder_paths, folder_path) + return web.Response(text=_dumps(entries), content_type="application/json") + async def upload_file(request: web.Request) -> web.Response: reader = await request.multipart() field = await reader.next() @@ -346,6 +353,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: app.router.add_get("/nodes", get_nodes) app.router.add_get("/files", list_files) app.router.add_get("/browse", browse_dir) + app.router.add_get("/folder-files", get_folder_files) app.router.add_post("/upload", upload_file) app.router.add_post("/download", download_file) app.router.add_post("/save-workflow-png", save_workflow_png) diff --git a/frontend/public/README.txt b/frontend/public/README.txt new file mode 100644 index 0000000..61f1c18 --- /dev/null +++ b/frontend/public/README.txt @@ -0,0 +1,7 @@ +Put a checked-in default workflow asset here to load it on boot. + +Supported filenames: +- default-workflow.png +- default-workflow.json + +If both are present, Argonode loads default-workflow.json first. diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index b803089..bdb0f27 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -16,18 +16,27 @@ import { embedWorkflow, extractWorkflow } from './pngMetadata'; import { captureViewportBlob as captureWorkflowViewportBlob } from './workflowCapture'; import { hydrateWorkflowState } from './workflowHydration'; import { serializeWorkflowState } from './workflowSerialization'; +import { loadDefaultWorkflowAsset } from './defaultWorkflow'; +import { + serializeExecutionGraph, + getAutoRunnableNodes, + hasBlockingAutoRunInput, +} from './executionGraph'; // ── Constants ───────────────────────────────────────────────────────── const DATA_TYPES = new Set([ 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', - 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', + 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', 'COLORMAP', 'SAVE_LAYER', 'FONT', 'FILE_PATH', 'DIRECTORY', ]); const SOCKET_COMPATIBILITY = { STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']), ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']), VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']), + SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']), + FLOAT: new Set(['INT']), + INT: new Set(['FLOAT']), }; const TYPE_COLORS = { @@ -39,8 +48,14 @@ const TYPE_COLORS = { ANY_TABLE: '#67e8f9', COORD: '#e91ed1', FLOAT: '#7dd3fc', + INT: '#38bdf8', STATS_SOURCE:'#c084fc', VALUE_SOURCE:'#60a5fa', + COLORMAP: '#f472b6', + SAVE_LAYER: '#22c55e', + FONT: '#fb7185', + FILE_PATH: '#f59e0b', + DIRECTORY: '#f97316', }; const NODE_TYPES = { custom: CustomNode }; @@ -59,6 +74,12 @@ function getOutputSlot(handleId) { return parseInt(handleId.split('::')[1], 10); } +function sameStringArray(a = [], b = []) { + if (a === b) return true; + if (!Array.isArray(a) || !Array.isArray(b) || a.length !== b.length) return false; + return a.every((item, index) => item === b[index]); +} + function socketTypesCompatible(sourceType, targetType) { if (sourceType === targetType) return true; const accepted = SOCKET_COMPATIBILITY[targetType]; @@ -221,43 +242,6 @@ async function captureViewportBlob(viewportEl, options) { } } -// ── Graph serialisation → backend prompt format ─────────────────────── - -function serializeGraph(nodes, edges, { excludeManualTrigger = false } = {}) { - const prompt = {}; - - for (const node of nodes) { - const { className, definition, widgetValues } = node.data; - if (!definition) continue; - if (excludeManualTrigger && definition.manual_trigger) continue; - - const inputs = {}; - - // Widget (scalar) values - const required = definition.input.required || {}; - for (const [name, spec] of Object.entries(required)) { - const [type] = Array.isArray(spec) ? spec : [spec]; - if (DATA_TYPES.has(type)) continue; // socket, handled via edges - if (type === 'BUTTON') continue; // UI-only widget, not a backend input - if (widgetValues[name] !== undefined) { - inputs[name] = widgetValues[name]; - } - } - - // Connected (socket) inputs from edges - const incoming = edges.filter((e) => e.target === node.id); - for (const edge of incoming) { - const inputName = getInputName(edge.targetHandle); - const outputSlot = getOutputSlot(edge.sourceHandle); - inputs[inputName] = [edge.source, outputSlot]; - } - - prompt[node.id] = { class_type: className, inputs }; - } - - return prompt; -} - // ── Context menu component ──────────────────────────────────────────── function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection }) { @@ -461,25 +445,15 @@ function Flow() { const [edges, setEdges, onEdgesChange] = useEdgesState([]); const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' }); const [contextMenu, setContextMenu] = useState(null); - const [fileBrowserCb, setFileBrowserCb] = useState(null); + const [fileBrowserState, setFileBrowserState] = useState(null); const nodeDefsRef = useRef({}); const nextIdRef = useRef(1); const autoRunTimer = useRef(null); const autoRunRef = useRef(null); + const defaultWorkflowLoadAttemptedRef = useRef(false); const reactFlow = useReactFlow(); - // ── Load node definitions ─────────────────────────────────────────── - - useEffect(() => { - api.getNodes().then((defs) => { - nodeDefsRef.current = defs; - setStatus({ text: `Loaded ${Object.keys(defs).length} nodes.`, level: 'info' }); - }).catch((err) => { - setStatus({ text: 'Failed to load nodes: ' + err.message, level: 'error' }); - }); - }, []); - // ── WebSocket ─────────────────────────────────────────────────────── const updateNodeData = useCallback((nodeId, patch) => { @@ -488,6 +462,96 @@ function Flow() { )); }, [setNodes]); + const setNodeOutputs = useCallback((nodeId, output, outputName, extraDefinitionPatch = {}) => { + setNodes((prev) => prev.map((node) => { + if (node.id !== nodeId) return node; + const currentDefinition = node.data.definition || {}; + const nextDefinition = { + ...currentDefinition, + ...extraDefinitionPatch, + output, + output_name: outputName, + }; + const sameOutputs = sameStringArray(currentDefinition.output, output); + const sameNames = sameStringArray(currentDefinition.output_name, outputName); + const sameOutputPaths = sameStringArray(currentDefinition.output_paths, nextDefinition.output_paths); + if (sameOutputs && sameNames && sameOutputPaths) { + return node; + } + return { + ...node, + data: { + ...node.data, + definition: nextDefinition, + }, + }; + })); + reactFlow.updateNodeInternals(nodeId); + }, [reactFlow, setNodes]); + + const getResolvedPathInput = useCallback((nodeId) => { + const edge = reactFlow.getEdges().find( + (e) => e.target === nodeId && getInputName(e.targetHandle) === 'path' + ); + if (!edge) return null; + const sourceNode = reactFlow.getNode(edge.source); + const outputPaths = sourceNode?.data?.definition?.output_paths; + const outputSlot = getOutputSlot(edge.sourceHandle); + if (Array.isArray(outputPaths) && typeof outputPaths[outputSlot] === 'string') { + return outputPaths[outputSlot]; + } + return null; + }, [reactFlow]); + + const refreshLoadNodeOutputs = useCallback(async (nodeId, explicitPath = null) => { + const node = reactFlow.getNode(nodeId); + if (!node) return; + + let resolvedPath = typeof explicitPath === 'string' && explicitPath ? explicitPath : null; + if (!resolvedPath) { + resolvedPath = getResolvedPathInput(nodeId); + } + if (!resolvedPath) { + if (node.data.className === 'LoadFile') { + resolvedPath = node.data.widgetValues?.filename || ''; + } else if (node.data.className === 'LoadDemo') { + resolvedPath = node.data.widgetValues?.name || ''; + } + } + + if (!resolvedPath) { + setNodeOutputs(nodeId, ['DATA_FIELD'], ['field'], { output_paths: [] }); + return; + } + + const channels = await api.getChannels(resolvedPath); + setNodeOutputs( + nodeId, + channels.map((channel) => channel.type), + channels.map((channel) => channel.name), + { output_paths: [] }, + ); + }, [getResolvedPathInput, reactFlow, setNodeOutputs]); + + const refreshFolderNodeOutputs = useCallback(async (nodeId, folderPath) => { + const entries = folderPath ? await api.getFolderFiles(folderPath) : []; + setNodeOutputs( + nodeId, + entries.map((entry) => entry.type), + entries.map((entry) => entry.name), + { output_paths: entries.map((entry) => entry.path) }, + ); + + const downstreamPathEdges = reactFlow.getEdges().filter( + (edge) => edge.source === nodeId && getInputName(edge.targetHandle) === 'path' + ); + for (const edge of downstreamPathEdges) { + const outputSlot = getOutputSlot(edge.sourceHandle); + const resolvedPath = entries[outputSlot]?.path || null; + await refreshLoadNodeOutputs(edge.target, resolvedPath); + } + }, [reactFlow, refreshLoadNodeOutputs, setNodeOutputs]); + useEffect(() => { api.setMessageHandler((msg) => { console.log('[argonode] WS:', msg.type, msg.data?.node_id || msg.data?.node || ''); @@ -532,7 +596,7 @@ function Flow() { case 'overlay': updateNodeData( msg.data.node_id, - msg.data.overlay?.kind === 'mask_paint' + msg.data.overlay?.kind === 'mask_paint' || msg.data.overlay?.kind === 'markup' ? { overlay: msg.data.overlay, previewImage: null } : { overlay: msg.data.overlay }, ); @@ -568,8 +632,36 @@ function Flow() { filtered ); }); + if (getInputName(params.targetHandle) === 'path') { + setTimeout(() => { + refreshLoadNodeOutputs(params.target); + }, 0); + } scheduleAutoRun(); - }, [setEdges]); + }, [refreshLoadNodeOutputs, setEdges]); // scheduleAutoRun is stable (no deps) + + const handleEdgesChange = useCallback((changes) => { + const currentEdges = reactFlow.getEdges(); + onEdgesChange(changes); + + const affectedPathTargets = new Set(); + for (const change of changes) { + if (change.type !== 'remove') continue; + const removedEdge = currentEdges.find((edge) => edge.id === change.id); + if (!removedEdge) continue; + if (getInputName(removedEdge.targetHandle) === 'path') { + affectedPathTargets.add(removedEdge.target); + } + } + + if (affectedPathTargets.size > 0) { + setTimeout(() => { + affectedPathTargets.forEach((nodeId) => { + refreshLoadNodeOutputs(nodeId); + }); + }, 0); + } + }, [onEdgesChange, reactFlow, refreshLoadNodeOutputs]); // ── Drop-on-blank: open filtered context menu ────────────────────── @@ -610,44 +702,35 @@ function Flow() { }; })); - // If this is a filename/name change on a LoadFile/LoadDemo node, fetch channels - if ((name === 'filename' || name === 'name') && value) { - const node = reactFlow.getNode(nodeId); - if (node && (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo')) { - api.getChannels(value).then((channels) => { - setNodes((prev) => prev.map((n) => { - if (n.id !== nodeId) return n; - return { - ...n, - data: { - ...n.data, - definition: { - ...n.data.definition, - output: channels.map((c) => c.type), - output_name: channels.map((c) => c.name), - }, - }, - }; - })); - reactFlow.updateNodeInternals(nodeId); - }); - } + const node = reactFlow.getNode(nodeId); + if (node && node.data.className === 'Folder' && name === 'folder') { + refreshFolderNodeOutputs(nodeId, value); + } + + if (node && (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo') && (name === 'filename' || name === 'name')) { + refreshLoadNodeOutputs(nodeId, value); } scheduleAutoRun(); - }, [setNodes]); // scheduleAutoRun is stable (no deps) + }, [reactFlow, refreshFolderNodeOutputs, refreshLoadNodeOutputs, setNodes]); // scheduleAutoRun is stable (no deps) // ── File browser ──────────────────────────────────────────────────── - const openFileBrowser = useCallback((callback) => { + const openFileBrowser = useCallback((callback, { selectionMode = 'file' } = {}) => { + if (selectionMode === 'folder' && window.pywebview?.api?.open_folder_dialog) { + window.pywebview.api.open_folder_dialog().then((path) => { + if (path) callback(path); + }); + return; + } // Use native file picker when running inside pywebview (desktop app) - if (window.pywebview?.api?.open_file_dialog) { + if (selectionMode === 'file' && window.pywebview?.api?.open_file_dialog) { window.pywebview.api.open_file_dialog().then((path) => { if (path) callback(path); }); return; } - setFileBrowserCb(() => callback); + setFileBrowserState({ callback, selectionMode }); }, []); // ── Node context value (stable) ───────────────────────────────────── @@ -656,7 +739,7 @@ function Flow() { const currentNodes = reactFlow.getNodes(); const currentEdges = reactFlow.getEdges(); // Include ALL nodes (no excludeManualTrigger) so the save node is in the prompt - const prompt = serializeGraph(currentNodes, currentEdges); + const prompt = serializeExecutionGraph(currentNodes, currentEdges); if (!prompt || Object.keys(prompt).length === 0) return; setStatus({ text: 'Saving…', level: 'info' }); api.runPrompt(prompt).catch((err) => { @@ -715,25 +798,17 @@ function Flow() { setNodes((ns) => [...ns, newNode]); + // Initialize dynamic outputs for nodes that depend on the selected path/folder. + if (className === 'Folder' && widgetValues.folder) { + refreshFolderNodeOutputs(newNodeId, widgetValues.folder); + } + // For LoadFile/LoadDemo, auto-fetch channels for the default value if (className === 'LoadDemo' && widgetValues.name) { - api.getChannels(widgetValues.name).then((channels) => { - setNodes((prev) => prev.map((n) => { - if (n.id !== newNodeId) return n; - return { - ...n, - data: { - ...n.data, - definition: { - ...n.data.definition, - output: channels.map((c) => c.type), - output_name: channels.map((c) => c.name), - }, - }, - }; - })); - reactFlow.updateNodeInternals(newNodeId); - }); + refreshLoadNodeOutputs(newNodeId, widgetValues.name); + } + if (className === 'LoadFile' && widgetValues.filename) { + refreshLoadNodeOutputs(newNodeId, widgetValues.filename); } // Auto-connect if this was triggered by dropping a connection on blank space @@ -783,7 +858,7 @@ function Flow() { setContextMenu(null); scheduleAutoRun(); - }, [contextMenu, reactFlow, setNodes, setEdges]); + }, [contextMenu, reactFlow, refreshFolderNodeOutputs, refreshLoadNodeOutputs, setNodes, setEdges]); // scheduleAutoRun is stable (no deps) // ── Toolbar actions ───────────────────────────────────────────────── @@ -791,7 +866,7 @@ function Flow() { // Read current state via functional ref to avoid stale closure const currentNodes = reactFlow.getNodes(); const currentEdges = reactFlow.getEdges(); - const prompt = serializeGraph(currentNodes, currentEdges); + const prompt = serializeExecutionGraph(currentNodes, currentEdges); if (!prompt || Object.keys(prompt).length === 0) { setStatus({ text: 'Graph is empty — add some nodes first.', level: 'error' }); @@ -809,28 +884,15 @@ function Flow() { autoRunRef.current = () => { const currentNodes = reactFlow.getNodes(); const currentEdges = reactFlow.getEdges(); + const runnableNodes = getAutoRunnableNodes(currentNodes, currentEdges); // Don't run if any non-manual node has unconnected required data inputs // or any FILE_PICKER widget is empty - for (const node of currentNodes) { - const def = node.data?.definition; - if (!def || def.manual_trigger) continue; // skip manual-trigger nodes - const required = def.input.required || {}; - for (const [name, spec] of Object.entries(required)) { - const [type] = Array.isArray(spec) ? spec : [spec]; - if (type === 'FILE_PICKER') { - if (!node.data.widgetValues?.[name]) return; // no file selected, skip - continue; - } - if (!DATA_TYPES.has(type)) continue; - const hasEdge = currentEdges.some( - (e) => e.target === node.id && getInputName(e.targetHandle) === name - ); - if (!hasEdge) return; // incomplete graph, skip auto-run - } + for (const node of runnableNodes) { + if (hasBlockingAutoRunInput(node, currentEdges)) return; } - const prompt = serializeGraph(currentNodes, currentEdges, { excludeManualTrigger: true }); + const prompt = serializeExecutionGraph(currentNodes, currentEdges, { excludeManualTrigger: true }); if (!prompt || Object.keys(prompt).length === 0) return; setStatus({ text: 'Running…', level: 'info' }); api.runPrompt(prompt).catch((err) => { @@ -855,7 +917,57 @@ function Flow() { setNodes(hydrated.nodes); setEdges(hydrated.edges); nextIdRef.current = hydrated.nextNodeId; - }, [setNodes, setEdges]); + setTimeout(() => { + hydrated.nodes.forEach((node) => { + if (node.data.className === 'Folder' && node.data.widgetValues?.folder) { + refreshFolderNodeOutputs(node.id, node.data.widgetValues.folder); + } + }); + hydrated.nodes.forEach((node) => { + if (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo') { + refreshLoadNodeOutputs(node.id); + } + }); + }, 0); + }, [refreshFolderNodeOutputs, refreshLoadNodeOutputs, setNodes, setEdges]); + + const loadDefaultWorkflow = useCallback(async () => { + if (defaultWorkflowLoadAttemptedRef.current) return; + defaultWorkflowLoadAttemptedRef.current = true; + + const graphHasContent = () => { + const currentNodes = reactFlow.getNodes(); + const currentEdges = reactFlow.getEdges(); + return currentNodes.length > 0 || currentEdges.length > 0; + }; + + if (graphHasContent()) return; + + try { + const loaded = await loadDefaultWorkflowAsset(); + if (!loaded || graphHasContent()) return; + + applyWorkflowData(loaded.workflow); + setStatus({ text: `Loaded default workflow from ${loaded.source}.`, level: 'info' }); + requestAnimationFrame(() => { + requestAnimationFrame(() => scheduleAutoRun()); + }); + } catch (err) { + setStatus({ text: 'Default workflow failed to load: ' + err.message, level: 'error' }); + } + }, [applyWorkflowData, reactFlow, scheduleAutoRun]); + + // ── Load node definitions ─────────────────────────────────────────── + + useEffect(() => { + api.getNodes().then((defs) => { + nodeDefsRef.current = defs; + setStatus({ text: `Loaded ${Object.keys(defs).length} nodes.`, level: 'info' }); + loadDefaultWorkflow(); + }).catch((err) => { + setStatus({ text: 'Failed to load nodes: ' + err.message, level: 'error' }); + }); + }, [loadDefaultWorkflow]); const getWorkflowBlob = useCallback(async () => { const viewportEl = document.querySelector('.react-flow__viewport'); @@ -1112,7 +1224,7 @@ function Flow() { nodes={nodes} edges={edges} onNodesChange={onNodesChange} - onEdgesChange={onEdgesChange} + onEdgesChange={handleEdgesChange} onConnect={onConnect} onConnectEnd={onConnectEnd} isValidConnection={isValidConnection} @@ -1150,10 +1262,11 @@ function Flow() { {/* File browser modal */} - {fileBrowserCb && ( + {fileBrowserState && ( { fileBrowserCb(path); setFileBrowserCb(null); }} - onClose={() => setFileBrowserCb(null)} + selectionMode={fileBrowserState.selectionMode} + onSelect={(path) => { fileBrowserState.callback(path); setFileBrowserState(null); }} + onClose={() => setFileBrowserState(null)} /> )} diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 1752dc0..5c3336b 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -6,14 +6,15 @@ const SurfaceView = lazy(() => import('./SurfaceView')); const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay')); const CropBoxOverlay = lazy(() => import('./CropBoxOverlay')); const MaskPaintOverlay = lazy(() => import('./MaskPaintOverlay')); +const MarkupOverlay = lazy(() => import('./MarkupOverlay')); // ── Constants ───────────────────────────────────────────────────────── const DATA_TYPES = new Set([ 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', - 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', + 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', 'COLORMAP', 'SAVE_LAYER', 'FONT', 'FILE_PATH', 'DIRECTORY', ]); -const SOCKET_WIDGET_TYPES = new Set(['FLOAT']); +const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']); const TYPE_COLORS = { DATA_FIELD: '#3a7abf', @@ -24,8 +25,14 @@ const TYPE_COLORS = { ANY_TABLE: '#67e8f9', COORD: '#e91e63', FLOAT: '#7dd3fc', + INT: '#38bdf8', STATS_SOURCE:'#c084fc', VALUE_SOURCE:'#60a5fa', + COLORMAP: '#f472b6', + SAVE_LAYER: '#22c55e', + FONT: '#fb7185', + FILE_PATH: '#f59e0b', + DIRECTORY: '#f97316', }; const CAT_COLORS = { @@ -128,6 +135,21 @@ function DraggableNumber({ value, step, min, max, precision, onChange }) { } }, [display]); + const onWheel = useCallback((e) => { + if (editing) return; + e.preventDefault(); + + const baseStep = Number(step) || 1; + const multiplier = e.shiftKey ? 10 : 1; + const delta = (e.deltaY < 0 ? 1 : -1) * baseStep * multiplier; + const startVal = Number(value); + const raw = (Number.isFinite(startVal) ? startVal : 0) + delta; + const rounded = precision != null + ? parseFloat(raw.toFixed(precision)) + : Math.round(raw); + onChange(clamp(rounded)); + }, [editing, step, value, precision, onChange, clamp]); + const commitEdit = useCallback(() => { setEditing(false); const parsed = parseFloat(editText); @@ -155,6 +177,7 @@ function DraggableNumber({ value, step, min, max, precision, onChange }) { onPointerDown={onPointerDown} onPointerMove={onPointerMove} onPointerUp={onPointerUp} + onWheel={onWheel} > {display} @@ -179,6 +202,57 @@ function CollapsibleSection({ title, defaultOpen, children }) { ); } +function LayerGalleryPreview({ overlay }) { + const layers = Array.isArray(overlay?.layers) ? overlay.layers : []; + const [index, setIndex] = useState(0); + + useEffect(() => { + setIndex(0); + }, [overlay]); + + useEffect(() => { + if (layers.length === 0) { + setIndex(0); + return; + } + if (index >= layers.length) { + setIndex(layers.length - 1); + } + }, [index, layers.length]); + + if (layers.length === 0) return null; + + const active = layers[index] || layers[0]; + + return ( +
+
+ +
+ {active.name || `Layer ${index + 1}`} +
+ +
+
+ {index + 1} / {layers.length} +
+
+ {active.name +
+
+ ); +} + function getTableColumns(rows) { const columns = []; for (const row of rows) { @@ -352,6 +426,28 @@ function getSourceNodeForInput(store, nodeId, inputName) { return store.nodeLookup?.get(edge.source) || store.nodes?.find((n) => n.id === edge.source) || null; } +function getConnectedOutputInfo(store, nodeId, inputName) { + const targetHandle = `input::${inputName}::`; + const edge = store.edges?.find((e) => e.target === nodeId && e.targetHandle?.startsWith(targetHandle)); + if (!edge?.sourceHandle) return null; + const sourceNode = store.nodeLookup?.get(edge.source) || store.nodes?.find((n) => n.id === edge.source) || null; + const slot = Number.parseInt(edge.sourceHandle.split('::')[1], 10); + if (!sourceNode || !Number.isInteger(slot)) return null; + return { + path: sourceNode.data?.definition?.output_paths?.[slot] || null, + name: sourceNode.data?.definition?.output_name?.[slot] || null, + }; +} + +function getBasename(value) { + if (typeof value !== 'string') return ''; + const trimmed = value.trim(); + if (!trimmed) return ''; + const normalized = trimmed.replace(/\\/g, '/').replace(/\/+$/, ''); + const parts = normalized.split('/'); + return parts[parts.length - 1] || ''; +} + function getWidgetSourceInputName(opts) { return opts?.source_type_input || opts?.choices_from_table_input @@ -368,6 +464,197 @@ function widgetVisibleForSourceType(widget, sourceType) { return allowed.includes(sourceType); } +function widgetVisibleForWidgetValues(widget, widgetValues) { + const rules = widget?.opts?.show_when_widget_value; + if (!rules || typeof rules !== 'object') return true; + + for (const [widgetName, allowedValues] of Object.entries(rules)) { + const allowed = Array.isArray(allowedValues) ? allowedValues.map(String) : []; + if (allowed.length === 0) continue; + if (!allowed.includes(String(widgetValues?.[widgetName] ?? ''))) { + return false; + } + } + + return true; +} + +function widgetHiddenByConnectedInput(widget, connectedInputs) { + const raw = widget?.opts?.hide_when_input_connected; + if (!raw || !connectedInputs) return false; + const inputs = Array.isArray(raw) ? raw : [raw]; + return inputs.some((inputName) => connectedInputs.has(String(inputName))); +} + +function widgetVisibleForInputVisibility(widget, visibleInputs) { + const raw = widget?.opts?.show_when_input_visible; + if (!raw) return true; + const inputs = Array.isArray(raw) ? raw : [raw]; + return inputs.some((inputName) => visibleInputs?.has(String(inputName))); +} + +function getWidgetInlineInputName(widget) { + const raw = widget?.opts?.inline_with_input; + if (!raw) return null; + return String(Array.isArray(raw) ? raw[0] : raw); +} + +const DEFAULT_COLORMAP_STOPS = [ + { position: 0, color: '#440154' }, + { position: 1, color: '#fde725' }, +]; + +function normalizeHexColor(color, fallback = '#000000') { + if (typeof color !== 'string') return fallback; + let text = color.trim(); + if (text.startsWith('#') && text.length === 4) { + text = `#${text.slice(1).split('').map((ch) => `${ch}${ch}`).join('')}`; + } + if (/^#[0-9a-fA-F]{6}$/.test(text)) { + return text.toLowerCase(); + } + return fallback; +} + +function parseColorMapStops(raw) { + let parsed = raw; + if (typeof raw === 'string') { + try { + parsed = JSON.parse(raw); + } catch { + parsed = DEFAULT_COLORMAP_STOPS; + } + } + + if (!Array.isArray(parsed)) { + parsed = DEFAULT_COLORMAP_STOPS; + } + + const stops = parsed + .map((stop) => { + const position = Number(stop?.position); + return { + position: Number.isFinite(position) ? Math.max(0, Math.min(1, position)) : 0, + color: normalizeHexColor(stop?.color, '#000000'), + }; + }) + .sort((a, b) => a.position - b.position); + + if (stops.length < 2) { + return DEFAULT_COLORMAP_STOPS.map((stop) => ({ ...stop })); + } + + stops[0].position = 0; + stops[stops.length - 1].position = 1; + return stops; +} + +function serializeColorMapStops(stops) { + return JSON.stringify(stops.map((stop, index) => ({ + position: index === 0 ? 0 : index === stops.length - 1 ? 1 : Number(stop.position.toFixed(4)), + color: normalizeHexColor(stop.color, '#000000'), + }))); +} + +function colorMapGradient(stops) { + return `linear-gradient(90deg, ${stops.map((stop) => `${stop.color} ${Math.round(stop.position * 1000) / 10}%`).join(', ')})`; +} + +function ColorMapStopsEditor({ nodeId, name, value, onChange }) { + const stops = parseColorMapStops(value); + + const commitStops = useCallback((nextStops) => { + const ordered = [...nextStops].sort((a, b) => a.position - b.position); + if (ordered.length < 2) return; + ordered[0] = { ...ordered[0], position: 0 }; + ordered[ordered.length - 1] = { ...ordered[ordered.length - 1], position: 1 }; + onChange(nodeId, name, serializeColorMapStops(ordered)); + }, [name, nodeId, onChange]); + + const updateStop = useCallback((index, patch) => { + const next = stops.map((stop, stopIndex) => (stopIndex === index ? { ...stop, ...patch } : { ...stop })); + if (index > 0 && index < next.length - 1) { + const prev = next[index - 1].position + 0.001; + const after = next[index + 1].position - 0.001; + next[index].position = Math.max(prev, Math.min(after, next[index].position)); + } + commitStops(next); + }, [commitStops, stops]); + + const removeStop = useCallback((index) => { + if (stops.length <= 2) return; + commitStops(stops.filter((_, stopIndex) => stopIndex !== index)); + }, [commitStops, stops]); + + const addStop = useCallback(() => { + let gapIndex = 0; + let gapSize = -1; + for (let i = 0; i < stops.length - 1; i += 1) { + const gap = stops[i + 1].position - stops[i].position; + if (gap > gapSize) { + gapIndex = i; + gapSize = gap; + } + } + + const left = stops[gapIndex]; + const right = stops[gapIndex + 1]; + const newStop = { + position: Number((((left.position + right.position) / 2)).toFixed(4)), + color: left.color, + }; + const next = [...stops]; + next.splice(gapIndex + 1, 0, newStop); + commitStops(next); + }, [commitStops, stops]); + + return ( +
+
+
+ {stops.map((stop, index) => { + const isEndpoint = index === 0 || index === stops.length - 1; + return ( +
+ {isEndpoint ? (index === 0 ? 'min' : 'max') : `stop ${index}`} + updateStop(index, { color: e.target.value })} + /> + {isEndpoint ? ( + {index === 0 ? '0%' : '100%'} + ) : ( + updateStop(index, { position: Number(e.target.value) })} + /> + )} + +
+ ); + })} +
+ +
+ ); +} + function NodeTable({ rows }) { const columns = getTableColumns(rows); if (columns.length === 0) return null; @@ -440,6 +727,9 @@ function CustomNode({ id, data }) { const def = data.definition; const scalarDisplay = formatScalarDisplay(data.scalarValue); const processingTimeText = formatProcessingTime(data.processingTimeMs); + const connectedPathInfo = useStore( + useCallback((s) => getConnectedOutputInfo(s, id, 'path'), [id]), + ); // Parse inputs into data handles and widgets const required = def.input.required || {}; @@ -447,13 +737,15 @@ function CustomNode({ id, data }) { const dataInputs = []; const widgets = []; + const visibleInputNames = new Set(); const hiddenWidgets = new Set(); for (const [name, spec] of Object.entries(required)) { const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; if (DATA_TYPES.has(type)) { - dataInputs.push({ name, type }); + dataInputs.push({ name, type, label: opts?.label || name }); + visibleInputNames.add(name); } else if (opts?.hidden) { hiddenWidgets.add(name); } else { @@ -467,7 +759,6 @@ function CustomNode({ id, data }) { const connectedInputs = useStore( useCallback( (s) => { - if (!isProgressive) return null; const set = new Set(); for (const e of s.edges) { if (e.target === id) { @@ -477,7 +768,7 @@ function CustomNode({ id, data }) { } return set; }, - [id, isProgressive], + [id], ), ); @@ -503,7 +794,8 @@ function CustomNode({ id, data }) { if (match) { const idx = parseInt(match[1], 10); if (idx === 0 || (connectedInputs && connectedInputs.has(`field_${idx - 1}`))) { - dataInputs.push({ name, type }); + dataInputs.push({ name, type, label: opts?.label || name }); + visibleInputNames.add(name); } continue; } @@ -511,12 +803,42 @@ function CustomNode({ id, data }) { if (opts?.hidden) { hiddenWidgets.add(name); } else if (DATA_TYPES.has(type)) { - dataInputs.push({ name, type }); + dataInputs.push({ name, type, label: opts?.label || name }); + visibleInputNames.add(name); } else { widgets.push({ name, type, opts: opts || {}, socketType: SOCKET_WIDGET_TYPES.has(type) ? type : null }); } } + const visibleWidgets = widgets.filter((w) => ( + widgetVisibleForSourceType(w, connectedSourceTypes?.[getWidgetSourceInputName(w.opts)]) + && widgetVisibleForWidgetValues(w, data.widgetValues) + && widgetVisibleForInputVisibility(w, visibleInputNames) + && !widgetHiddenByConnectedInput(w, connectedInputs) + )); + + const combinedTopInputNames = new Set( + visibleWidgets + .map((widget) => widget?.opts?.top_socket_input) + .filter((name) => typeof name === 'string' && name.length > 0), + ); + const renderedDataInputs = dataInputs.filter((input) => !combinedTopInputNames.has(input.name)); + const dataInputByName = new Map(dataInputs.map((input) => [input.name, input])); + + const inlineWidgetsByInput = new Map(); + const topWidgets = []; + const standaloneWidgets = []; + for (const widget of visibleWidgets) { + const inlineInputName = getWidgetInlineInputName(widget); + if (inlineInputName) { + inlineWidgetsByInput.set(inlineInputName, widget); + } else if (widget.opts?.placement === 'top') { + topWidgets.push(widget); + } else { + standaloneWidgets.push(widget); + } + } + const outputs = def.output.map((type, i) => ({ name: def.output_name[i] || type, type, @@ -524,30 +846,85 @@ function CustomNode({ id, data }) { })); const catColor = CAT_COLORS[def.category] || '#333'; - const maxIORows = Math.max(dataInputs.length, outputs.length); + const maxIORows = Math.max(renderedDataInputs.length, outputs.length); const hasInteractiveLineOverlay = data.overlay?.kind === 'line_plot' && hiddenWidgets.has('x1'); - const hasInteractiveOverlay = !!data.overlay && (hiddenWidgets.has('x1') || data.overlay.kind === 'mask_paint'); - const hidePreviewForInteractiveMask = data.overlay?.kind === 'mask_paint'; + const hasInteractiveOverlay = !!data.overlay && ( + hiddenWidgets.has('x1') + || data.overlay.kind === 'mask_paint' + || data.overlay.kind === 'markup' + ); + const hidePreviewForInteractiveMask = data.overlay?.kind === 'mask_paint' || data.overlay?.kind === 'markup'; const overlayTitle = data.overlay?.section_title || (data.overlay?.kind === 'mask_paint' ? 'Mask' + : data.overlay?.kind === 'markup' + ? 'Markup' : data.overlay?.kind === 'crop_box' ? 'Crop' : data.overlay?.kind === 'line_plot' ? 'Line Plot' - : 'Cross Section'); + : 'Cross Section'); + const headerMeta = (() => { + if (data.className === 'Folder') { + return getBasename(data.widgetValues?.folder); + } + if (data.className === 'LoadFile') { + return getBasename(connectedPathInfo?.path || data.widgetValues?.filename); + } + if (data.className === 'LoadDemo') { + return getBasename(data.widgetValues?.name); + } + return ''; + })(); return (
{/* Title */}
- {data.label} + {data.label} + {headerMeta && {headerMeta}}
+ {topWidgets.length > 0 && ( +
+ {topWidgets.map((w) => ( +
+ {(w.socketType || w.opts?.top_socket_input) && (() => { + const socketInput = w.opts?.top_socket_input ? dataInputByName.get(w.opts.top_socket_input) : null; + const socketType = w.socketType || socketInput?.type; + const socketName = w.socketType ? w.name : socketInput?.name; + if (!socketType || !socketName) return null; + return ( + + ); + })()} + +
+ ))} +
+ )} + {/* I/O rows — pair inputs[i] with outputs[i] */} {Array.from({ length: maxIORows }, (_, i) => { - const inp = dataInputs[i]; + const inp = renderedDataInputs[i]; const out = outputs[i]; return (
@@ -561,7 +938,20 @@ function CustomNode({ id, data }) { className="typed-handle" style={{ background: TYPE_COLORS[inp.type] || '#999' }} /> - {inp.name} + {inp.label || inp.name} + {inlineWidgetsByInput.has(inp.name) && ( +
+ +
+ )} )}
@@ -601,7 +991,7 @@ function CustomNode({ id, data }) { )} {/* Widget rows */} - {widgets.filter((w) => widgetVisibleForSourceType(w, connectedSourceTypes?.[getWidgetSourceInputName(w.opts)])).map((w) => ( + {standaloneWidgets.map((w) => (
{w.socketType && (
))} @@ -654,6 +1045,7 @@ function CustomNode({ id, data }) { resetKey={typeof data.previewImage === 'string' ? data.previewImage : JSON.stringify({ kind: data.previewImage.kind, len: data.previewImage.line?.length, + layers: data.previewImage.layers?.length, })} fallbackImage={typeof data.previewImage === 'object' ? data.previewImage.fallback_image : null} > @@ -661,6 +1053,8 @@ function CustomNode({ id, data }) {
preview
+ ) : data.previewImage.kind === 'layer_gallery' ? ( + ) : data.previewImage.kind === 'line_plot' ? ( ) : null} @@ -704,6 +1098,16 @@ function CustomNode({ id, data }) { nodeId={id} onWidgetChange={ctx.onWidgetChange} /> + ) : data.overlay.kind === 'markup' ? ( + ) : ( { @@ -818,11 +1224,34 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile onChange(nodeId, name, dynamicTypeChoices[0]); }, [dynamicTypeChoices, name, nodeId, onChange, val]); + if (connected) { + return ( + <> + {!hideLabel && } +
Connected
+ + ); + } + + if (opts?.colormap_stops) { + return ( + <> + {!hideLabel && } + + + ); + } + // Combo / enum — type itself is the array of options if (Array.isArray(type)) { return ( <> - + {!hideLabel && } - + {!hideLabel && } - + {!hideLabel && }
onChange(nodeId, name, e.target.value)} - placeholder="Select file…" + placeholder={placeholder || (isFolderPicker ? 'Select folder…' : 'Select file…')} /> @@ -913,6 +1346,23 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile ); } + if (type === 'STRING' && opts?.color_picker) { + const normalized = typeof val === 'string' && /^#[0-9a-fA-F]{6}$/.test(val) + ? val + : '#ffd54f'; + return ( + <> + {!hideLabel && } + onChange(nodeId, name, e.target.value)} + /> + + ); + } + if (type === 'BUTTON') { const updates = opts?.set_widgets && typeof opts.set_widgets === 'object' ? Object.entries(opts.set_widgets) @@ -950,7 +1400,7 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile return ( <> - + {!hideLabel && }
- + {!hideLabel && } - + {!hideLabel && } - + {!hideLabel && } - + {!hideLabel && } onChange(nodeId, name, e.target.value)} /> diff --git a/frontend/src/FileBrowser.jsx b/frontend/src/FileBrowser.jsx index 31477b8..4220862 100644 --- a/frontend/src/FileBrowser.jsx +++ b/frontend/src/FileBrowser.jsx @@ -5,10 +5,10 @@ import * as api from './api'; * Server-side file browser modal. * * Props: - * onSelect(absolutePath) — called when user picks a file + * onSelect(absolutePath) — called when user picks a file or folder * onClose() — called when user dismisses the dialog */ -export default function FileBrowser({ onSelect, onClose }) { +export default function FileBrowser({ onSelect, onClose, selectionMode = 'file' }) { const [path, setPath] = useState(''); const [parent, setParent] = useState(null); const [dirs, setDirs] = useState([]); @@ -43,6 +43,11 @@ export default function FileBrowser({ onSelect, onClose }) { {/* Header */}
{path} + {selectionMode === 'folder' && ( + + )}
@@ -75,8 +80,12 @@ export default function FileBrowser({ onSelect, onClose }) { {files.map((f) => (
{ onSelect(path + '/' + f); onClose(); }} + className={`fb-entry fb-file${selectionMode === 'folder' ? ' fb-file-disabled' : ''}`} + onClick={() => { + if (selectionMode === 'folder') return; + onSelect(path + '/' + f); + onClose(); + }} > {f}
diff --git a/frontend/src/MarkupOverlay.jsx b/frontend/src/MarkupOverlay.jsx new file mode 100644 index 0000000..c2c7fec --- /dev/null +++ b/frontend/src/MarkupOverlay.jsx @@ -0,0 +1,285 @@ +import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; + +function clampFraction(value) { + const numeric = Number(value); + if (!Number.isFinite(numeric)) return 0; + return Math.max(0, Math.min(1, numeric)); +} + +function sanitizeColor(color, fallback = '#ffd54f') { + if (typeof color !== 'string') return fallback; + const value = color.trim(); + return /^#[0-9a-fA-F]{6}$/.test(value) ? value.toLowerCase() : fallback; +} + +function sanitizeShape(shape, fallbackShape, fallbackColor, fallbackWidth) { + if (!shape || typeof shape !== 'object') return null; + const kind = ['line', 'rectangle', 'circle', 'arrow'].includes(shape.kind) ? shape.kind : fallbackShape; + const x1 = clampFraction(shape.x1); + const y1 = clampFraction(shape.y1); + const x2 = clampFraction(shape.x2); + const y2 = clampFraction(shape.y2); + const width = Math.max(1, Math.min(64, Math.round(Number(shape.width) || fallbackWidth || 1))); + return { + kind, + x1: Number(x1.toFixed(4)), + y1: Number(y1.toFixed(4)), + x2: Number(x2.toFixed(4)), + y2: Number(y2.toFixed(4)), + width, + color: sanitizeColor(shape.color, fallbackColor), + }; +} + +function parseMarkupShapes(markupShapes, fallbackShape, fallbackColor, fallbackWidth) { + if (Array.isArray(markupShapes)) { + return markupShapes + .map((shape) => sanitizeShape(shape, fallbackShape, fallbackColor, fallbackWidth)) + .filter(Boolean); + } + + if (typeof markupShapes !== 'string' || !markupShapes.trim()) return []; + + try { + const parsed = JSON.parse(markupShapes); + if (!Array.isArray(parsed)) return []; + return parsed + .map((shape) => sanitizeShape(shape, fallbackShape, fallbackColor, fallbackWidth)) + .filter(Boolean); + } catch { + return []; + } +} + +function arrowPoints(shape, imageWidth, imageHeight) { + const x1 = shape.x1 * imageWidth; + const y1 = shape.y1 * imageHeight; + const x2 = shape.x2 * imageWidth; + const y2 = shape.y2 * imageHeight; + const dx = x2 - x1; + const dy = y2 - y1; + const length = Math.hypot(dx, dy) || 1; + const ux = dx / length; + const uy = dy / length; + const strokeWidth = Math.max(1, shape.width); + const headLength = Math.max(10, strokeWidth * 4); + const headWidth = Math.max(8, strokeWidth * 3); + const overlap = Math.max(1, strokeWidth * 0.75); + const shaftX = x2 - ux * Math.max(0, headLength - overlap); + const shaftY = y2 - uy * Math.max(0, headLength - overlap); + const headBaseX = x2 - ux * headLength; + const headBaseY = y2 - uy * headLength; + const px = -uy; + const py = ux; + const leftX = headBaseX + px * headWidth * 0.5; + const leftY = headBaseY + py * headWidth * 0.5; + const rightX = headBaseX - px * headWidth * 0.5; + const rightY = headBaseY - py * headWidth * 0.5; + return { + line: `${x1},${y1} ${shaftX},${shaftY}`, + head: `${x2},${y2} ${leftX},${leftY} ${rightX},${rightY}`, + }; +} + +function ShapeElement({ shape, imageWidth, imageHeight }) { + const x1 = shape.x1 * imageWidth; + const y1 = shape.y1 * imageHeight; + const x2 = shape.x2 * imageWidth; + const y2 = shape.y2 * imageHeight; + const left = Math.min(x1, x2); + const top = Math.min(y1, y2); + const width = Math.abs(x2 - x1); + const height = Math.abs(y2 - y1); + const strokeWidth = Math.max(1, shape.width); + const common = { + fill: 'none', + stroke: shape.color, + strokeWidth, + strokeLinecap: 'round', + strokeLinejoin: 'round', + vectorEffect: 'non-scaling-stroke', + }; + + if (shape.kind === 'line') { + return ; + } + + if (shape.kind === 'rectangle') { + return ; + } + + if (shape.kind === 'circle') { + return ( + + ); + } + + const arrow = arrowPoints(shape, imageWidth, imageHeight); + return ( + <> + + + + ); +} + +export default function MarkupOverlay({ + image, + shape, + strokeColor, + strokeWidth, + markupShapes, + nodeId, + onWidgetChange, +}) { + const containerRef = useRef(null); + const imageRef = useRef(null); + const shapesRef = useRef([]); + const [draftShape, setDraftShape] = useState(null); + const [drawing, setDrawing] = useState(false); + const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); + + const normalizedShape = useMemo( + () => (['line', 'rectangle', 'circle', 'arrow'].includes(shape) ? shape : 'line'), + [shape], + ); + const normalizedColor = useMemo(() => sanitizeColor(strokeColor, '#ffd54f'), [strokeColor]); + const normalizedWidth = useMemo( + () => Math.max(1, Math.min(64, Math.round(Number(strokeWidth) || 3))), + [strokeWidth], + ); + + const committedShapes = useMemo( + () => parseMarkupShapes(markupShapes, normalizedShape, normalizedColor, normalizedWidth), + [markupShapes, normalizedShape, normalizedColor, normalizedWidth], + ); + + useEffect(() => { + shapesRef.current = committedShapes; + }, [committedShapes]); + + useEffect(() => { + const img = imageRef.current; + if (!img) return; + const updateImageSize = () => { + const width = Math.max(1, img.naturalWidth || img.width || 1); + const height = Math.max(1, img.naturalHeight || img.height || 1); + setImageSize({ width, height }); + }; + + updateImageSize(); + if (!img.complete) { + img.addEventListener('load', updateImageSize); + return () => img.removeEventListener('load', updateImageSize); + } + return undefined; + }, [image]); + + const getPoint = useCallback((event) => { + const rect = containerRef.current?.getBoundingClientRect(); + if (!rect) return null; + return { + x: Number(clampFraction((event.clientX - rect.left) / rect.width).toFixed(4)), + y: Number(clampFraction((event.clientY - rect.top) / rect.height).toFixed(4)), + }; + }, []); + + const commitShapes = useCallback((nextShapes) => { + if (!nodeId || !onWidgetChange) return; + onWidgetChange(nodeId, 'markup_shapes', JSON.stringify(nextShapes)); + }, [nodeId, onWidgetChange]); + + const handlePointerDown = useCallback((event) => { + if (!onWidgetChange || event.target.closest('button')) return; + const point = getPoint(event); + if (!point) return; + event.preventDefault(); + event.stopPropagation(); + event.currentTarget.setPointerCapture(event.pointerId); + setDrawing(true); + setDraftShape({ + kind: normalizedShape, + color: normalizedColor, + width: normalizedWidth, + x1: point.x, + y1: point.y, + x2: point.x, + y2: point.y, + }); + }, [getPoint, normalizedColor, normalizedShape, normalizedWidth, onWidgetChange]); + + const handlePointerMove = useCallback((event) => { + if (!drawing) return; + const point = getPoint(event); + if (!point) return; + setDraftShape((current) => (current ? { ...current, x2: point.x, y2: point.y } : current)); + }, [drawing, getPoint]); + + const finishDrawing = useCallback(() => { + if (!draftShape) { + setDrawing(false); + return; + } + const nextShape = sanitizeShape(draftShape, normalizedShape, normalizedColor, normalizedWidth); + setDraftShape(null); + setDrawing(false); + if (!nextShape) return; + commitShapes([...shapesRef.current, nextShape]); + }, [commitShapes, draftShape, normalizedColor, normalizedShape, normalizedWidth]); + + const undoLast = useCallback(() => { + if (shapesRef.current.length === 0) return; + commitShapes(shapesRef.current.slice(0, -1)); + }, [commitShapes]); + + const clearAll = useCallback(() => { + commitShapes([]); + }, [commitShapes]); + + const renderedShapes = draftShape ? [...committedShapes, draftShape] : committedShapes; + + return ( +
+ markup source + + {renderedShapes.map((item, index) => ( + + ))} + +
+ + +
+
+ ); +} diff --git a/frontend/src/api.js b/frontend/src/api.js index 39ede09..225240f 100644 --- a/frontend/src/api.js +++ b/frontend/src/api.js @@ -40,6 +40,12 @@ export async function getChannels(filepath) { return r.json(); } +export async function getFolderFiles(folderpath) { + const r = await fetch(`/folder-files?folder=${encodeURIComponent(folderpath)}`); + if (!r.ok) return []; + return r.json(); +} + export async function runPrompt(prompt) { const r = await fetch('/prompt', { method: 'POST', diff --git a/frontend/src/defaultWorkflow.js b/frontend/src/defaultWorkflow.js new file mode 100644 index 0000000..d86b940 --- /dev/null +++ b/frontend/src/defaultWorkflow.js @@ -0,0 +1,56 @@ +import { extractWorkflow } from './pngMetadata.js'; + +const DEFAULT_WORKFLOW_CANDIDATES = [ + { path: '/default-workflow.json', type: 'json' }, + { path: '/default-workflow.png', type: 'png' }, +]; + +async function loadCandidate(candidate, fetchImpl, extractWorkflowFn) { + let response; + try { + response = await fetchImpl(candidate.path, { cache: 'no-store' }); + } catch { + return null; + } + + const contentType = response.headers?.get?.('content-type') || ''; + const isHtmlFallback = typeof contentType === 'string' && contentType.toLowerCase().includes('text/html'); + + if (!response.ok) { + if (response.status === 404 || response.status === 0) return null; + throw new Error(`Failed to load ${candidate.path} (${response.status})`); + } + + if (candidate.type === 'json') { + if (isHtmlFallback) return null; + try { + return await response.json(); + } catch { + throw new Error(`${candidate.path} is not valid JSON`); + } + } + + if (isHtmlFallback) return null; + const workflow = await extractWorkflowFn(await response.blob()); + if (!workflow) { + throw new Error(`${candidate.path} does not contain embedded workflow metadata`); + } + return workflow; +} + +export async function loadDefaultWorkflowAsset({ + fetchImpl = fetch, + extractWorkflowFn = extractWorkflow, +} = {}) { + for (const candidate of DEFAULT_WORKFLOW_CANDIDATES) { + const workflow = await loadCandidate(candidate, fetchImpl, extractWorkflowFn); + if (workflow) { + return { + source: candidate.path, + format: candidate.type, + workflow, + }; + } + } + return null; +} diff --git a/frontend/src/executionGraph.js b/frontend/src/executionGraph.js new file mode 100644 index 0000000..693c917 --- /dev/null +++ b/frontend/src/executionGraph.js @@ -0,0 +1,125 @@ +const DATA_TYPES = new Set([ + 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', + 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', 'COLORMAP', 'SAVE_LAYER', 'FONT', 'FILE_PATH', 'DIRECTORY', +]); + +function getInputName(handleId) { + return handleId.split('::')[1]; +} + +function getOutputSlot(handleId) { + return parseInt(handleId.split('::')[1], 10); +} + +export function getConnectedNodeIds(edges) { + const connectedNodeIds = new Set(); + for (const edge of edges) { + connectedNodeIds.add(edge.source); + connectedNodeIds.add(edge.target); + } + return connectedNodeIds; +} + +function isPreviewLoadNode(node) { + return ['LoadFile', 'LoadDemo'].includes(node?.data?.className); +} + +function hasPreviewLoadSelection(node) { + if (node?.data?.className === 'LoadFile') { + return !!String(node.data?.widgetValues?.filename || '').trim(); + } + if (node?.data?.className === 'LoadDemo') { + return !!String(node.data?.widgetValues?.name || '').trim(); + } + return false; +} + +function getRunnableNodeIds(nodes, edges) { + const connectedNodeIds = getConnectedNodeIds(edges); + + const runnableNodeIds = new Set(connectedNodeIds); + for (const node of nodes) { + if (connectedNodeIds.has(node.id)) continue; + if (isPreviewLoadNode(node) && hasPreviewLoadSelection(node)) { + runnableNodeIds.add(node.id); + } + } + + return runnableNodeIds; +} + +export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = false } = {}) { + const runnableNodeIds = getRunnableNodeIds(nodes, edges); + const prompt = {}; + + for (const node of nodes) { + if (!runnableNodeIds.has(node.id)) continue; + + const { className, definition, widgetValues } = node.data; + if (!definition) continue; + if (excludeManualTrigger && definition.manual_trigger) continue; + + const inputs = {}; + + const allWidgets = { + ...(definition.input.required || {}), + ...(definition.input.optional || {}), + }; + for (const [name, spec] of Object.entries(allWidgets)) { + const [type] = Array.isArray(spec) ? spec : [spec]; + if (DATA_TYPES.has(type)) continue; + if (type === 'BUTTON') continue; + if (widgetValues[name] !== undefined) { + inputs[name] = widgetValues[name]; + } + } + + const incoming = edges.filter((edge) => edge.target === node.id); + for (const edge of incoming) { + const inputName = getInputName(edge.targetHandle); + const outputSlot = getOutputSlot(edge.sourceHandle); + inputs[inputName] = [edge.source, outputSlot]; + } + + prompt[node.id] = { class_type: className, inputs }; + } + + return prompt; +} + +export function getAutoRunnableNodes(nodes, edges) { + const runnableNodeIds = getRunnableNodeIds(nodes, edges); + return nodes.filter((node) => runnableNodeIds.has(node.id)); +} + +export function hasBlockingAutoRunInput(node, edges) { + const def = node.data?.definition; + if (!def || def.manual_trigger) return false; + + const required = def.input.required || {}; + for (const [name, spec] of Object.entries(required)) { + const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; + const hiddenByConnectedInput = (() => { + const raw = opts?.hide_when_input_connected; + if (!raw) return false; + const inputs = Array.isArray(raw) ? raw : [raw]; + return inputs.some((inputName) => edges.some( + (edge) => edge.target === node.id && getInputName(edge.targetHandle) === String(inputName) + )); + })(); + + if (hiddenByConnectedInput) continue; + + if (type === 'FILE_PICKER' || type === 'FOLDER_PICKER') { + if (!node.data.widgetValues?.[name]) return true; + continue; + } + if (!DATA_TYPES.has(type)) continue; + const hasEdge = edges.some( + (edge) => edge.target === node.id && getInputName(edge.targetHandle) === name + ); + if (!hasEdge) return true; + } + + return false; +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index d68b4ae..04c1bb3 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -141,6 +141,10 @@ html, body, #root { .node-title { padding: 5px 10px; + display: flex; + align-items: center; + justify-content: space-between; + gap: 8px; font-weight: 600; font-size: 12px; color: white; @@ -148,12 +152,36 @@ html, body, #root { border-bottom: 1px solid rgba(0, 0, 0, 0.3); } +.node-title-main { + min-width: 0; +} + +.node-title-meta { + max-width: 48%; + min-width: 0; + padding: 1px 6px; + border-radius: 999px; + background: rgba(15, 23, 42, 0.28); + color: rgba(255, 255, 255, 0.88); + font-size: 10px; + font-weight: 500; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + .node-body { padding: 4px 0; display: flex; flex-direction: column; } +.top-widget-section { + padding-bottom: 2px; + border-bottom: 1px solid rgba(51, 65, 85, 0.35); + margin-bottom: 2px; +} + .node-warning { padding: 3px 10px; font-size: 10px; @@ -226,6 +254,11 @@ html, body, #root { gap: 4px; } +.io-left { + flex: 1; + min-width: 0; +} + .io-label { font-size: 10px; color: #94a3b8; @@ -280,8 +313,36 @@ html, body, #root { flex-shrink: 0; } +.io-inline-widget { + flex: 1; + min-width: 0; + margin-left: 8px; + display: flex; + align-items: center; +} + +.io-inline-widget .widget-row, +.io-inline-widget label { + display: none; +} + +.io-inline-widget input[type="text"], +.io-inline-widget input[type="number"], +.io-inline-widget input[type="color"], +.io-inline-widget select { + background: #0f172a; + color: #e0e0e0; + border: 1px solid #334155; + border-radius: 3px; + padding: 2px 5px; + font-size: 11px; + flex: 1; + min-width: 0; +} + .widget-row input[type="text"], .widget-row input[type="number"], +.widget-row input[type="color"], .widget-row select { background: #0f172a; color: #e0e0e0; @@ -293,6 +354,11 @@ html, body, #root { min-width: 0; } +.widget-row input[type="color"] { + padding: 2px; + height: 24px; +} + .widget-row input[type="checkbox"] { accent-color: #3a7abf; } @@ -314,6 +380,87 @@ html, body, #root { border-color: #3a7abf; } +.widget-linked-state { + flex: 1; + min-width: 0; + padding: 4px 8px; + border: 1px dashed rgba(244, 114, 182, 0.45); + border-radius: 4px; + background: rgba(30, 41, 59, 0.55); + color: #f9a8d4; + font-size: 10px; + text-transform: uppercase; + letter-spacing: 0.08em; + text-align: center; +} + +.colormap-editor { + flex: 1; + min-width: 0; + display: flex; + flex-direction: column; + gap: 6px; +} + +.colormap-preview { + width: 100%; + height: 18px; + border-radius: 999px; + border: 1px solid #334155; + background-color: #0f172a; +} + +.colormap-stop-list { + display: flex; + flex-direction: column; + gap: 4px; +} + +.colormap-stop-row { + display: grid; + grid-template-columns: 34px 34px minmax(0, 1fr) auto; + gap: 6px; + align-items: center; +} + +.colormap-stop-label, +.colormap-stop-boundary { + font-size: 10px; + color: #94a3b8; +} + +.colormap-stop-color { + width: 34px; + height: 24px; + padding: 0; + border: 1px solid #334155; + border-radius: 4px; + background: #0f172a; +} + +.colormap-stop-position { + width: 100%; +} + +.colormap-stop-action { + background: #172554; + color: #dbeafe; + border: 1px solid #334155; + border-radius: 4px; + padding: 4px 8px; + font-size: 10px; + cursor: pointer; +} + +.colormap-stop-action:disabled { + opacity: 0.45; + cursor: default; +} + +.colormap-add-stop { + margin-top: 2px; +} + .slider-control { display: flex; align-items: center; @@ -438,6 +585,54 @@ html, body, #root { display: block; } +.layer-gallery { + display: flex; + flex-direction: column; + gap: 6px; +} + +.layer-gallery-toolbar { + display: grid; + grid-template-columns: 28px minmax(0, 1fr) 28px; + gap: 6px; + align-items: center; +} + +.layer-gallery-btn { + height: 26px; + border: 1px solid #334155; + border-radius: 6px; + background: #0f172a; + color: #e2e8f0; + font-size: 14px; + cursor: pointer; +} + +.layer-gallery-btn:disabled { + opacity: 0.4; + cursor: default; +} + +.layer-gallery-name { + min-width: 0; + padding: 4px 8px; + border: 1px solid rgba(51, 65, 85, 0.9); + border-radius: 6px; + background: rgba(15, 23, 42, 0.8); + color: #cbd5e1; + font-size: 10px; + text-align: center; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.layer-gallery-count { + font-size: 10px; + color: #64748b; + text-align: center; +} + /* ── Cross-section overlay ────────────────────────────────────────── */ .cs-overlay { position: relative; @@ -609,6 +804,60 @@ html, body, #root { z-index: 2; } +.markup-overlay { + position: relative; + overflow: hidden; + user-select: none; + touch-action: none; + background: #0f172a; + border: 1px solid #334155; + border-radius: 6px; + cursor: crosshair; +} + +.markup-overlay-drawing { + cursor: crosshair; +} + +.markup-image { + width: 100%; + display: block; +} + +.markup-svg { + position: absolute; + inset: 0; + width: 100%; + height: 100%; + pointer-events: none; + overflow: visible; +} + +.markup-toolbar { + position: absolute; + top: 8px; + right: 8px; + display: flex; + gap: 6px; + z-index: 2; +} + +.markup-tool-btn { + border: 1px solid rgba(148, 163, 184, 0.35); + background: rgba(15, 23, 42, 0.88); + color: #e2e8f0; + border-radius: 999px; + padding: 4px 9px; + font-size: 10px; + line-height: 1; + cursor: pointer; +} + +.markup-tool-btn:disabled { + opacity: 0.45; + cursor: default; +} + /* ── 3D surface view ──────────────────────────────────────────────── */ .surface-view-container { width: 100%; @@ -830,7 +1079,7 @@ html, body, #root { .fb-header { display: flex; align-items: center; - justify-content: space-between; + gap: 8px; padding: 10px 14px; border-bottom: 1px solid #0f3460; } @@ -852,6 +1101,17 @@ html, body, #root { padding: 2px 6px; } .fb-close:hover { color: #e94560; } +.fb-select-btn { + background: #0f3460; + color: #e0e0e0; + border: 1px solid #334155; + border-radius: 4px; + padding: 4px 8px; + font-size: 11px; + cursor: pointer; + white-space: nowrap; +} +.fb-select-btn:hover { background: #1a4a8a; } .fb-list { overflow-y: auto; padding: 6px 0; @@ -868,6 +1128,13 @@ html, body, #root { .fb-entry:hover { background: #0f3460; } .fb-dir { color: #90caf9; } .fb-file { color: #e0e0e0; } +.fb-file-disabled { + cursor: default; + opacity: 0.5; +} +.fb-file-disabled:hover { + background: transparent; +} .fb-loading { padding: 16px; text-align: center; diff --git a/frontend/src/workflowHydration.js b/frontend/src/workflowHydration.js index 46039de..39691ee 100644 --- a/frontend/src/workflowHydration.js +++ b/frontend/src/workflowHydration.js @@ -34,6 +34,26 @@ function getInputType(definition, inputName) { return getSocketType(required[inputName] ?? optional[inputName]); } +function getInputEntries(definition) { + return [ + ...Object.entries(definition?.input?.required || {}), + ...Object.entries(definition?.input?.optional || {}), + ]; +} + +function sanitizeWidgetValues(widgetValues, definition) { + const nextValues = { ...(widgetValues || {}) }; + + getInputEntries(definition).forEach(([inputName, inputDef]) => { + const type = getSocketType(inputDef); + if (type === 'FILE_PICKER' || type === 'FOLDER_PICKER') { + nextValues[inputName] = ''; + } + }); + + return nextValues; +} + function remapLegacyHandle(handleId, kind, nodeData) { if (typeof handleId !== 'string') return handleId; @@ -63,22 +83,26 @@ export function hydrateWorkflowState(data, defs = {}) { const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : []; const loadedEdges = Array.isArray(data?.edges) ? data.edges : []; - const nodes = loadedNodes.map((node) => ({ - ...node, - type: node.type || 'custom', - dragHandle: node.dragHandle || '.drag-handle', - data: { - ...node.data, - label: node.data?.label || node.data?.className || 'Node', - widgetValues: node.data?.widgetValues || {}, - definition: mergeDefinition(node.data, defs), - previewImage: null, - tableRows: null, - meshData: null, - overlay: null, - scalarValue: null, - }, - })); + const nodes = loadedNodes.map((node) => { + const definition = mergeDefinition(node.data, defs); + + return { + ...node, + type: node.type || 'custom', + dragHandle: node.dragHandle || '.drag-handle', + data: { + ...node.data, + label: node.data?.label || node.data?.className || 'Node', + widgetValues: sanitizeWidgetValues(node.data?.widgetValues, definition), + definition, + previewImage: null, + tableRows: null, + meshData: null, + overlay: null, + scalarValue: null, + }, + }; + }); const nodeById = new Map(nodes.map((node) => [String(node.id), node.data])); diff --git a/frontend/tests/defaultWorkflow.test.mjs b/frontend/tests/defaultWorkflow.test.mjs new file mode 100644 index 0000000..de6dffc --- /dev/null +++ b/frontend/tests/defaultWorkflow.test.mjs @@ -0,0 +1,117 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { embedWorkflow } from '../src/pngMetadata.js'; +import { loadDefaultWorkflowAsset } from '../src/defaultWorkflow.js'; + +const PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+aF9sAAAAASUVORK5CYII='; + +function makePngBlob() { + return new Blob([Buffer.from(PNG_BASE64, 'base64')], { type: 'image/png' }); +} + +test('loadDefaultWorkflowAsset prefers checked-in JSON when present', async () => { + const workflow = { version: 1, nodes: [{ id: '1' }], edges: [] }; + const requests = []; + const fetchImpl = async (url) => { + requests.push(url); + if (url === '/default-workflow.json') { + return { + ok: true, + status: 200, + async json() { + return workflow; + }, + }; + } + throw new Error('PNG fallback should not be requested when JSON exists'); + }; + + const loaded = await loadDefaultWorkflowAsset({ fetchImpl }); + + assert.deepEqual(loaded, { + source: '/default-workflow.json', + format: 'json', + workflow, + }); + assert.deepEqual(requests, ['/default-workflow.json']); +}); + +test('loadDefaultWorkflowAsset falls back to PNG workflow metadata when JSON is missing', async () => { + const workflow = { version: 1, nodes: [{ id: '2' }], edges: [] }; + const pngWithWorkflow = await embedWorkflow(makePngBlob(), workflow); + const requests = []; + const fetchImpl = async (url) => { + requests.push(url); + if (url === '/default-workflow.json') { + return { ok: false, status: 404 }; + } + if (url === '/default-workflow.png') { + return { + ok: true, + status: 200, + async blob() { + return pngWithWorkflow; + }, + }; + } + throw new Error(`Unexpected URL ${url}`); + }; + + const loaded = await loadDefaultWorkflowAsset({ fetchImpl }); + + assert.deepEqual(loaded, { + source: '/default-workflow.png', + format: 'png', + workflow, + }); + assert.deepEqual(requests, ['/default-workflow.json', '/default-workflow.png']); +}); + +test('loadDefaultWorkflowAsset returns null when no default workflow asset is present', async () => { + const fetchImpl = async () => ({ ok: false, status: 404 }); + + const loaded = await loadDefaultWorkflowAsset({ fetchImpl }); + + assert.equal(loaded, null); +}); + +test('loadDefaultWorkflowAsset stays quiet when default assets are simply absent in the host runtime', async () => { + const fetchImpl = async () => { + throw new TypeError('Failed to fetch'); + }; + + const loaded = await loadDefaultWorkflowAsset({ fetchImpl }); + + assert.equal(loaded, null); +}); + +test('loadDefaultWorkflowAsset stays quiet when the host serves app HTML for missing default assets', async () => { + const fetchImpl = async (url) => ({ + ok: true, + status: 200, + headers: { + get(name) { + return name.toLowerCase() === 'content-type' ? 'text/html; charset=utf-8' : null; + }, + }, + async json() { + throw new SyntaxError(`Unexpected token '<' while parsing ${url}`); + }, + async blob() { + return new Blob([''], { type: 'text/html' }); + }, + }); + + const loaded = await loadDefaultWorkflowAsset({ fetchImpl }); + + assert.equal(loaded, null); +}); + +test('loadDefaultWorkflowAsset stays quiet when the host reports missing assets with status 0', async () => { + const fetchImpl = async () => ({ ok: false, status: 0 }); + + const loaded = await loadDefaultWorkflowAsset({ fetchImpl }); + + assert.equal(loaded, null); +}); diff --git a/frontend/tests/executionGraph.test.mjs b/frontend/tests/executionGraph.test.mjs new file mode 100644 index 0000000..d5558d5 --- /dev/null +++ b/frontend/tests/executionGraph.test.mjs @@ -0,0 +1,339 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { + serializeExecutionGraph, + getAutoRunnableNodes, + hasBlockingAutoRunInput, +} from '../src/executionGraph.js'; + +test('serializeExecutionGraph excludes isolated nodes from the backend prompt', () => { + const nodes = [ + { + id: '1', + data: { + className: 'LoadFile', + definition: { + input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { filename: 'scan.gwy' }, + }, + }, + { + id: '2', + data: { + className: 'PreviewImage', + definition: { + input: { required: { field: ['DATA_FIELD', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: {}, + }, + }, + { + id: '3', + data: { + className: 'LoadFile', + definition: { + input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: {}, + }, + }, + ]; + const edges = [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + }, + ]; + + const prompt = serializeExecutionGraph(nodes, edges); + + assert.deepEqual(prompt, { + '1': { + class_type: 'LoadFile', + inputs: { filename: 'scan.gwy' }, + }, + '2': { + class_type: 'PreviewImage', + inputs: { field: ['1', 0] }, + }, + }); + assert.equal('3' in prompt, false); +}); + +test('serializeExecutionGraph includes isolated preview-load nodes alongside connected subgraphs', () => { + const nodes = [ + { + id: '1', + data: { + className: 'LoadFile', + definition: { + input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { filename: 'first.gwy' }, + }, + }, + { + id: '2', + data: { + className: 'PreviewImage', + definition: { + input: { required: { field: ['DATA_FIELD', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: {}, + }, + }, + { + id: '3', + data: { + className: 'LoadDemo', + definition: { + input: { required: { name: [['demo.npy'], {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { name: 'demo.npy' }, + }, + }, + { + id: '4', + data: { + className: 'LoadFile', + definition: { + input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { filename: '' }, + }, + }, + ]; + const edges = [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + }, + ]; + + const prompt = serializeExecutionGraph(nodes, edges); + + assert.deepEqual(prompt, { + '1': { + class_type: 'LoadFile', + inputs: { filename: 'first.gwy' }, + }, + '2': { + class_type: 'PreviewImage', + inputs: { field: ['1', 0] }, + }, + '3': { + class_type: 'LoadDemo', + inputs: { name: 'demo.npy' }, + }, + }); + assert.equal('4' in prompt, false); +}); + +test('serializeExecutionGraph allows a singleton LoadFile graph so previews can run', () => { + const nodes = [ + { + id: '1', + data: { + className: 'LoadFile', + definition: { + input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { filename: 'scan.gwy' }, + }, + }, + ]; + + const prompt = serializeExecutionGraph(nodes, []); + + assert.deepEqual(prompt, { + '1': { + class_type: 'LoadFile', + inputs: { filename: 'scan.gwy' }, + }, + }); +}); + +test('serializeExecutionGraph allows a singleton LoadDemo graph so previews can run', () => { + const nodes = [ + { + id: '1', + data: { + className: 'LoadDemo', + definition: { + input: { required: { name: [['demo.npy'], {}] }, optional: {} }, + manual_trigger: false, + }, + widgetValues: { name: 'demo.npy' }, + }, + }, + ]; + + const prompt = serializeExecutionGraph(nodes, []); + + assert.deepEqual(prompt, { + '1': { + class_type: 'LoadDemo', + inputs: { name: 'demo.npy' }, + }, + }); +}); + +test('getAutoRunnableNodes ignores disconnected nodes when deciding what can auto-run', () => { + const nodes = [ + { id: '1', data: { definition: {}, widgetValues: {} } }, + { id: '2', data: { definition: {}, widgetValues: {} } }, + { id: '3', data: { definition: {}, widgetValues: {} } }, + ]; + const edges = [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + }, + ]; + + const runnable = getAutoRunnableNodes(nodes, edges); + + assert.deepEqual(runnable.map((node) => node.id), ['1', '2']); +}); + +test('getAutoRunnableNodes includes isolated preview-load nodes with selections', () => { + const nodes = [ + { id: '1', data: { className: 'LoadFile', definition: {}, widgetValues: { filename: 'first.gwy' } } }, + { id: '2', data: { className: 'PreviewImage', definition: {}, widgetValues: {} } }, + { id: '3', data: { className: 'LoadDemo', definition: {}, widgetValues: { name: 'demo.npy' } } }, + { id: '4', data: { className: 'LoadFile', definition: {}, widgetValues: { filename: '' } } }, + ]; + const edges = [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + }, + ]; + + const runnable = getAutoRunnableNodes(nodes, edges); + + assert.deepEqual(runnable.map((node) => node.id), ['1', '2', '3']); +}); + +test('getAutoRunnableNodes allows a singleton LoadFile graph', () => { + const nodes = [ + { + id: '1', + data: { + className: 'LoadFile', + definition: {}, + widgetValues: { filename: 'scan.gwy' }, + }, + }, + ]; + + const runnable = getAutoRunnableNodes(nodes, []); + + assert.deepEqual(runnable.map((node) => node.id), ['1']); +}); + +test('getAutoRunnableNodes allows a singleton LoadDemo graph', () => { + const nodes = [ + { + id: '1', + data: { + className: 'LoadDemo', + definition: {}, + widgetValues: { name: 'demo.npy' }, + }, + }, + ]; + + const runnable = getAutoRunnableNodes(nodes, []); + + assert.deepEqual(runnable.map((node) => node.id), ['1']); +}); + +test('hasBlockingAutoRunInput only blocks connected nodes with incomplete required inputs', () => { + const node = { + id: '2', + data: { + definition: { + manual_trigger: false, + input: { + required: { + field: ['DATA_FIELD', {}], + filename: ['FILE_PICKER', {}], + }, + }, + }, + widgetValues: { filename: '' }, + }, + }; + const completeEdges = [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + }, + ]; + + assert.equal(hasBlockingAutoRunInput(node, completeEdges), true); + assert.equal( + hasBlockingAutoRunInput( + { + ...node, + data: { + ...node.data, + widgetValues: { filename: 'scan.gwy' }, + }, + }, + completeEdges, + ), + false, + ); +}); + +test('hasBlockingAutoRunInput skips required file widgets when a connected socket overrides them', () => { + const node = { + id: '2', + data: { + definition: { + manual_trigger: false, + input: { + required: { + filename: ['FILE_PICKER', { hide_when_input_connected: 'path' }], + }, + optional: { + path: ['FILE_PATH', {}], + }, + }, + }, + widgetValues: { filename: '' }, + }, + }; + const edges = [ + { + source: '1', + sourceHandle: 'output::0::FILE_PATH', + target: '2', + targetHandle: 'input::path::FILE_PATH', + }, + ]; + + assert.equal(hasBlockingAutoRunInput(node, edges), false); +}); diff --git a/frontend/tests/workflowSerialization.test.mjs b/frontend/tests/workflowSerialization.test.mjs index ca60ec7..f691437 100644 --- a/frontend/tests/workflowSerialization.test.mjs +++ b/frontend/tests/workflowSerialization.test.mjs @@ -95,7 +95,7 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload assert.equal('selected' in serialized.edges[0], false); }); -test('hydrateWorkflowState restores saved dynamic outputs on top of current node definitions', () => { +test('hydrateWorkflowState clears shared path widgets while restoring saved dynamic outputs', () => { const saved = { version: 1, nodes: [ @@ -140,12 +140,14 @@ test('hydrateWorkflowState restores saved dynamic outputs on top of current node assert.equal(hydrated.nodes[0].dragHandle, '.drag-handle'); assert.equal(hydrated.nodes[0].data.label, 'LoadFile'); assert.equal(hydrated.nodes[0].data.previewImage, null); + assert.equal(hydrated.nodes[0].data.widgetValues.filename, ''); + assert.equal(hydrated.nodes[0].data.widgetValues.colormap, 'viridis'); assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']); assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']); assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.LoadFile.input); }); -test('serializeWorkflowState and hydrateWorkflowState preserve reload-critical metadata for dynamic nodes', () => { +test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets but preserve other metadata', () => { const nodes = [ { id: '7', @@ -185,8 +187,42 @@ test('serializeWorkflowState and hydrateWorkflowState preserve reload-critical m const serialized = serializeWorkflowState(nodes, edges); const hydrated = hydrateWorkflowState(serialized, defs); - assert.deepEqual(hydrated.nodes[0].data.widgetValues, nodes[0].data.widgetValues); + assert.deepEqual(hydrated.nodes[0].data.widgetValues, { filename: '', colormap: 'gray' }); assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD']); assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Topography', 'Error', 'Mask']); assert.deepEqual(hydrated.edges, edges); }); + +test('hydrateWorkflowState clears saved folder selections on shared workflows', () => { + const saved = { + version: 1, + nodes: [ + { + id: '21', + position: { x: 0, y: 0 }, + data: { + className: 'Folder', + widgetValues: { folder: '/Users/alice/Desktop/shared-dataset' }, + output: ['PATH', 'PATH'], + output_name: ['scan1.png', 'scan2.png'], + }, + }, + ], + edges: [], + }; + + const defs = { + Folder: { + category: 'io', + input: { required: { folder: ['FOLDER_PICKER', {}] } }, + output: ['PATH'], + output_name: ['path'], + }, + }; + + const hydrated = hydrateWorkflowState(saved, defs); + + assert.equal(hydrated.nodes[0].data.widgetValues.folder, ''); + assert.deepEqual(hydrated.nodes[0].data.definition.output, ['PATH', 'PATH']); + assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['scan1.png', 'scan2.png']); +}); diff --git a/frontend/vite.config.js b/frontend/vite.config.js index 3072098..f674bd5 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -10,6 +10,7 @@ export default defineConfig({ '/nodes': 'http://127.0.0.1:8188', '/files': 'http://127.0.0.1:8188', '/browse': 'http://127.0.0.1:8188', + '/folder-files': 'http://127.0.0.1:8188', '/channels': 'http://127.0.0.1:8188', '/upload': 'http://127.0.0.1:8188', '/download': 'http://127.0.0.1:8188', diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 71a13b4..5a7d463 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -8,10 +8,11 @@ import json import sys import os import tempfile +from pathlib import Path import numpy as np sys.path.insert(0, ".") -from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8 +from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8, render_datafield_preview def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): @@ -79,6 +80,7 @@ def test_crop_resize_field(): yoff=20.0, si_unit_xy="nm", si_unit_z="nm", + overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}], ) overlays = [] @@ -103,6 +105,7 @@ def test_crop_resize_field(): assert cropped.yoff == 21.0 assert cropped.si_unit_xy == field.si_unit_xy assert cropped.si_unit_z == field.si_unit_z + assert cropped.overlays == [] assert len(overlays) == 1 assert overlays[0]["kind"] == "crop_box" assert overlays[0]["image"].startswith("data:image/png;base64,") @@ -192,6 +195,7 @@ def test_rotate_field(): assert rotated_90.yoff == 19.0 assert rotated_90.si_unit_xy == field.si_unit_xy assert rotated_90.si_unit_z == field.si_unit_z + assert rotated_90.overlays == [] rotated_180, = node.process( field, @@ -224,6 +228,34 @@ def test_rotate_field(): print(" PASS\n") +def test_rotate_field_overlay_warning(): + print("=== Test: RotateField overlay warning ===") + from backend.nodes.modify import RotateField + + node = RotateField() + warnings = [] + RotateField._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) + RotateField._current_node_id = "test" + + field = DataField( + data=np.arange(16, dtype=np.float64).reshape(4, 4), + overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}], + ) + + rotated, = node.process( + field, + angle=30.0, + interpolation="bilinear", + expand_canvas=True, + ) + assert rotated.overlays == [] + assert len(warnings) == 1 + assert "clears annotation/markup overlays" in warnings[0] + + RotateField._broadcast_warning_fn = None + print(" PASS\n") + + def test_colormap_adjust(): print("=== Test: ColormapAdjust ===") from backend.nodes.modify import ColormapAdjust @@ -833,16 +865,36 @@ def test_load_file(): result_npy = node.load(filename=path_npy) assert np.allclose(result_npy[0].data, data_npy) + custom_colormap = { + "mode": "custom", + "stops": [ + {"position": 0.0, "color": "#000000"}, + {"position": 0.5, "color": "#ff0000"}, + {"position": 1.0, "color": "#ffffff"}, + ], + } + result_custom = node.load(filename=path, colormap_map=custom_colormap) + assert isinstance(result_custom[0].colormap, dict) + assert result_custom[0].colormap["mode"] == "custom" + assert len(result_custom[0].colormap["stops"]) == 3 + + result_from_path = node.load(filename="", path=path) + assert len(result_from_path) == 1 + assert result_from_path[0].data.shape == (48, 64) + print(" PASS\n") def test_save_image(): print("=== Test: SaveImage (Save Layers) ===") from backend.nodes.io import SaveImage + import tifffile node = SaveImage() field_a = make_field(data=np.random.default_rng(4).random((32, 32))) field_b = make_field(data=np.random.default_rng(5).random((32, 32))) + annotated = np.zeros((24, 24, 3), dtype=np.uint8) + annotated[..., 0] = 255 with tempfile.TemporaryDirectory() as tmpdir: # Save single layer as TIFF @@ -861,20 +913,57 @@ def test_save_image(): im2 = Image.open(tiff_path2) assert im2.n_frames == 2 - # Save as NPZ + # Save annotated image as TIFF with layer name + annotated_tiff = os.path.join(tmpdir, "annotated.tiff") + node.save( + filename=annotated_tiff, + format="TIFF", + field_0=annotated, + layer_name_0="annotated overview", + ) + with tifffile.TiffFile(annotated_tiff) as tif: + assert len(tif.pages) == 1 + assert tif.pages[0].description == "annotated overview" + assert tif.pages[0].asarray().shape == annotated.shape + + # Save as NPZ with layer names npz_path = os.path.join(tmpdir, "out.npz") - node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b) + node.save( + filename=npz_path, + format="NPZ", + field_0=field_a, + field_1=annotated, + layer_name_0="height map", + layer_name_1="annotated-overview", + ) assert os.path.exists(npz_path) npz = np.load(npz_path) assert len(npz.files) == 2 - assert np.allclose(npz["layer_0"], field_a.data) - assert np.allclose(npz["layer_1"], field_b.data) + assert np.allclose(npz["height_map"], field_a.data) + assert np.array_equal(npz["annotated_overview"], annotated) # Extension is forced to match format wrong_ext = os.path.join(tmpdir, "output.png") node.save(filename=wrong_ext, format="TIFF", field_0=field_a) assert os.path.exists(os.path.join(tmpdir, "output.tiff")) + # Directory input can drive the destination folder while filename supplies the basename + driven_dir = os.path.join(tmpdir, "nested-output") + node.save(filename="driven_name", directory=driven_dir, format="NPZ", field_0=field_a) + assert os.path.exists(os.path.join(driven_dir, "driven_name.npz")) + + # Directory input rejects file paths + try: + node.save( + filename="bad", + directory=os.path.join(tmpdir, "looks_like_file.txt"), + format="TIFF", + field_0=field_a, + ) + assert False, "Should have raised ValueError for file-like directory path" + except ValueError: + pass + # No fields connected → error try: node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF") @@ -896,6 +985,50 @@ def test_save_image(): # Display (limited testing — these are output nodes with WS callbacks) # ========================================================================= +def test_color_map_node(): + print("=== Test: ColorMap ===") + from backend.nodes.display import ColorMap + + node = ColorMap() + + preset, = node.build(mode="preset", preset="magma", stops_json="[]") + assert preset["mode"] == "preset" + assert preset["preset"] == "magma" + + custom, = node.build( + mode="custom", + preset="viridis", + stops_json=json.dumps([ + {"position": 0.0, "color": "#000000"}, + {"position": 0.4, "color": "#00ff00"}, + {"position": 1.0, "color": "#ffffff"}, + ]), + ) + assert custom["mode"] == "custom" + assert custom["stops"][0]["position"] == 0.0 + assert custom["stops"][-1]["position"] == 1.0 + assert len(custom["stops"]) == 3 + print(" PASS\n") + + +def test_font_node(): + print("=== Test: Font ===") + from backend.nodes.display import Font + from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT + + node = Font() + + system_default, = node.build(SYSTEM_DEFAULT_FONT) + assert system_default is None + + named, = node.build("Arial") + assert named == {"family": "Arial", "path": ""} + + custom, = node.build(CUSTOM_FILE_FONT, "/tmp/example-font.ttf") + assert custom == {"family": "", "path": "/tmp/example-font.ttf"} + print(" PASS\n") + + def test_preview_image(): print("=== Test: PreviewImage ===") from backend.nodes.display import PreviewImage @@ -912,6 +1045,27 @@ def test_preview_image(): assert len(captured) == 1 assert captured[0].startswith("data:image/png;base64,") + # Preview with field overlay metadata + captured.clear() + field_with_overlay = field.replace(overlays=[{"kind": "annotation", "show_scale_bar": True, "show_color_map": False, "text_size": 14.0}]) + node.preview(colormap="viridis", field=field_with_overlay) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") + + # Preview with a custom colormap input + captured.clear() + custom_colormap = { + "mode": "custom", + "stops": [ + {"position": 0.0, "color": "#000000"}, + {"position": 0.5, "color": "#ff0000"}, + {"position": 1.0, "color": "#ffffff"}, + ], + } + node.preview(colormap="auto", field=field, colormap_map=custom_colormap) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") + # Preview with an IMAGE array captured.clear() arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8) @@ -923,6 +1077,128 @@ def test_preview_image(): print(" PASS\n") +def test_annotations(): + print("=== Test: Annotations ===") + from backend.nodes.display import Annotations, Font + + node = Annotations() + font_node = Font() + field = DataField( + data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64), + xreal=1e-6, + yreal=1e-6, + si_unit_xy="m", + si_unit_z="V", + colormap="viridis", + ) + + base = datafield_to_uint8(field, "viridis") + plain_preview = render_datafield_preview(field, "viridis") + assert np.array_equal(plain_preview, base) + + plain_field, = node.render(field, colormap="auto", show_scale_bar=False, show_color_map=False) + assert isinstance(plain_field, DataField) + assert np.array_equal(plain_field.data, field.data) + assert plain_field.colormap == "viridis" + assert plain_field.overlays[-1]["kind"] == "annotation" + plain = render_datafield_preview(plain_field, plain_field.colormap) + assert plain.shape == base.shape + assert np.array_equal(plain, base) + + with_scale_field, = node.render(field, colormap="auto", show_scale_bar=True, show_color_map=False) + with_scale = render_datafield_preview(with_scale_field, with_scale_field.colormap) + assert with_scale.shape == base.shape + assert not np.array_equal(with_scale, base) + + with_legend_field, = node.render(field, colormap="auto", show_scale_bar=False, show_color_map=True) + with_legend = render_datafield_preview(with_legend_field, with_legend_field.colormap) + assert with_legend.shape[0] == base.shape[0] + assert with_legend.shape[1] > base.shape[1] + assert with_legend.shape[2] == 3 + + larger_legend_field, = node.render( + field, + colormap="auto", + show_scale_bar=False, + show_color_map=True, + text_size=28.0, + ) + larger_legend_text = render_datafield_preview(larger_legend_field, larger_legend_field.colormap) + assert larger_legend_text.shape == with_legend.shape + assert not np.array_equal(larger_legend_text, with_legend) + + annotation_font, = font_node.build("Arial") + with_font_field, = node.render( + field, + colormap="auto", + show_scale_bar=False, + show_color_map=True, + text_size=28.0, + font=annotation_font, + ) + assert with_font_field.overlays[-1]["font"] == {"family": "Arial", "path": ""} + with_font = render_datafield_preview(with_font_field, with_font_field.colormap) + assert with_font.shape == with_legend.shape + + with_both_field, = node.render(field, colormap="auto", show_scale_bar=True, show_color_map=True) + with_both = render_datafield_preview(with_both_field, with_both_field.colormap) + assert with_both.shape == with_legend.shape + assert not np.array_equal(with_both[:, :base.shape[1]], base) + + print(" PASS\n") + + +def test_markup(): + print("=== Test: Markup ===") + from backend.nodes.display import Markup + from backend.data_types import _preview_markup_stroke_width + + node = Markup() + field = make_field(data=np.linspace(0.0, 1.0, 48 * 48, dtype=np.float64).reshape(48, 48)) + base = render_datafield_preview(field, field.colormap) + + assert _preview_markup_stroke_width(5, 128, 128) == 5 + assert _preview_markup_stroke_width(5, 2048, 2048) > 5 + + overlays = [] + Markup._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + Markup._current_node_id = "test" + + plain_field, = node.process( + field=field, + shape="line", + stroke_color="#ffd54f", + stroke_width=3, + markup_shapes="[]", + ) + assert isinstance(plain_field, DataField) + assert plain_field.overlays[-1]["kind"] == "markup" + plain = render_datafield_preview(plain_field, plain_field.colormap) + assert np.array_equal(plain, base) + assert overlays[-1]["kind"] == "markup" + assert overlays[-1]["image"].startswith("data:image/png;base64,") + + shapes = json.dumps([ + {"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 3, "color": "#ff0000"}, + {"kind": "rectangle", "x1": 0.2, "y1": 0.2, "x2": 0.8, "y2": 0.5, "width": 2, "color": "#00ff00"}, + {"kind": "circle", "x1": 0.25, "y1": 0.55, "x2": 0.55, "y2": 0.85, "width": 2, "color": "#4fc3f7"}, + {"kind": "arrow", "x1": 0.15, "y1": 0.85, "x2": 0.85, "y2": 0.2, "width": 4, "color": "#ffffff"}, + ]) + marked_field, = node.process( + field=field, + shape="arrow", + stroke_color="#ffffff", + stroke_width=4, + markup_shapes=shapes, + ) + marked = render_datafield_preview(marked_field, marked_field.colormap) + assert marked.shape == base.shape + assert not np.array_equal(marked, base) + + Markup._broadcast_overlay_fn = None + print(" PASS\n") + + def test_print_table(): print("=== Test: PrintTable ===") from backend.nodes.display import PrintTable @@ -1086,7 +1362,8 @@ def test_load_file_warning(): def test_list_channels(): print("=== Test: list_channels ===") - from backend.nodes.io import list_channels + from backend.nodes.io import list_channels, list_folder_paths, Folder + from PIL import Image # Non-existent file → default ch = list_channels("/nonexistent/file.ibw") @@ -1105,7 +1382,6 @@ def test_list_channels(): # Plain image → single default channel with tempfile.TemporaryDirectory() as tmpdir: - from PIL import Image img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) path = os.path.join(tmpdir, "test.png") img.save(path) @@ -1122,6 +1398,32 @@ def test_list_channels(): ch = list_channels(path) assert len(ch) == 1 + with tempfile.TemporaryDirectory() as tmpdir: + img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) + png_path = os.path.join(tmpdir, "a.png") + npy_path = os.path.join(tmpdir, "b.npy") + gwy_path = os.path.join(tmpdir, "c.gwy") + sxm_path = os.path.join(tmpdir, "d.sxm") + ibw_path = os.path.join(tmpdir, "e.ibw") + txt_path = os.path.join(tmpdir, "notes.txt") + img.save(png_path) + np.save(npy_path, np.zeros((4, 4))) + Path(gwy_path).write_bytes(b"gwy") + Path(sxm_path).write_bytes(b"sxm") + Path(ibw_path).write_bytes(b"ibw") + with open(txt_path, "w", encoding="utf-8") as fh: + fh.write("ignore me") + + paths = list_folder_paths(tmpdir) + assert [entry["name"] for entry in paths] == ["directory", "a.png", "b.npy", "c.gwy", "d.sxm", "e.ibw"] + assert Path(paths[0]["path"]).resolve() == Path(tmpdir).resolve() + assert paths[0]["type"] == "DIRECTORY" + assert all(entry["type"] == "FILE_PATH" for entry in paths[1:]) + + folder_node = Folder() + folder_result = folder_node.list_files(tmpdir) + assert folder_result == tuple(entry["path"] for entry in paths) + print(" PASS\n") @@ -1157,6 +1459,35 @@ def test_load_demo(): print(" PASS\n") +def test_load_demo_multi_layer_preview_payload(): + print("=== Test: LoadDemo multi-layer preview payload ===") + from backend.execution import ExecutionEngine + import backend.nodes # noqa: F401 + + previews = [] + prompt = { + "1": { + "class_type": "LoadDemo", + "inputs": { + "name": "whiskers.ibw", + "colormap": "viridis", + }, + }, + } + + ExecutionEngine().execute(prompt, on_preview=lambda node_id, payload: previews.append((node_id, payload))) + + assert len(previews) == 1 + node_id, payload = previews[0] + assert node_id == "1" + assert payload["kind"] == "layer_gallery" + assert len(payload["layers"]) == 4 + assert all(isinstance(layer["name"], str) and layer["name"] for layer in payload["layers"]) + assert all(layer["image"].startswith("data:image/png;base64,") for layer in payload["layers"]) + + print(" PASS\n") + + # ========================================================================= # I/O — Coordinate # ========================================================================= @@ -1181,6 +1512,25 @@ def test_coordinate(): print(" PASS\n") +# ========================================================================= +# I/O — Number +# ========================================================================= + +def test_number(): + print("=== Test: Number ===") + from backend.nodes.io import Number + + node = Number() + + result = node.process(value=1.25) + assert result == (1.25,) + + result_neg = node.process(value=-3.5) + assert result_neg == (-3.5,) + + print(" PASS\n") + + def test_range_slider(): print("=== Test: RangeSlider ===") from backend.nodes.io import RangeSlider @@ -1205,6 +1555,62 @@ def test_range_slider(): print(" PASS\n") +def test_execution_engine_numeric_socket_coercion(): + print("=== Test: ExecutionEngine numeric socket coercion ===") + from backend.execution import ExecutionEngine + from backend.node_registry import register_node + + @register_node(display_name="Test Echo Int") + class TestEchoInt: + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("INT",)}} + + RETURN_TYPES = ("INT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "tests" + + def process(self, value): + return (value,) + + @register_node(display_name="Test Echo Float") + class TestEchoFloat: + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "tests" + + def process(self, value): + return (value,) + + engine = ExecutionEngine() + prompt = { + "1": { + "class_type": "Number", + "inputs": {"value": 3.6}, + }, + "2": { + "class_type": "TestEchoInt", + "inputs": {"value": ["1", 0]}, + }, + "3": { + "class_type": "TestEchoFloat", + "inputs": {"value": ["1", 0]}, + }, + } + + outputs = engine.execute(prompt) + assert outputs["2"] == (4,) + assert outputs["3"] == (3.6,) + + print(" PASS\n") + + # ========================================================================= # Analysis — LineCursors # =========================================================================