diff --git a/backend/__init__.py b/backend/__init__.py index f22815b..9d48db4 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -1,18 +1 @@ from __future__ import annotations - -import numpy as np - - -def _apply_numpy_compat_aliases() -> None: - """Restore removed NumPy scalar aliases still used by some dependencies.""" - aliases = { - "complex": complex, - "float": float, - "int": int, - } - for name, value in aliases.items(): - if not hasattr(np, name): - setattr(np, name, value) - - -_apply_numpy_compat_aliases() diff --git a/backend/data_types.py b/backend/data_types.py index 2754a4f..38a3ddb 100644 --- a/backend/data_types.py +++ b/backend/data_types.py @@ -67,6 +67,43 @@ class LineData: return self.data[item] +@dataclass +class MeshModel: + vertices: np.ndarray + faces: np.ndarray + colors: np.ndarray | None = None + + def __post_init__(self) -> None: + self.vertices = np.asarray(self.vertices, dtype=np.float32).reshape(-1, 3) + self.faces = np.asarray(self.faces, dtype=np.int32).reshape(-1, 3) + if self.colors is not None: + self.colors = np.asarray(self.colors, dtype=np.uint8).reshape(-1, 3) + if len(self.colors) != len(self.vertices): + raise ValueError("MeshModel.colors must have one RGB triplet per vertex.") + + +class ImageData(np.ndarray): + def __new__(cls, data: Any, metadata: dict[str, Any] | None = None): + obj = np.asarray(data).view(cls) + obj.metadata = deepcopy(metadata) if isinstance(metadata, dict) else {} + return obj + + def __array_finalize__(self, obj): + self.metadata = deepcopy(getattr(obj, "metadata", {})) if obj is not None else {} + + def copy_with_metadata(self, *, data: Any | None = None, metadata: dict[str, Any] | None = None) -> "ImageData": + base = np.asarray(self if data is None else data) + merged = deepcopy(self.metadata) + if isinstance(metadata, dict): + merged.update(deepcopy(metadata)) + return ImageData(base, metadata=merged) + + +def image_metadata(image: Any) -> dict[str, Any]: + metadata = getattr(image, "metadata", None) + return deepcopy(metadata) if isinstance(metadata, dict) else {} + + def _normalize_hex_color(color: Any, default: str = "#000000") -> str: if isinstance(color, str): text = color.strip() @@ -638,10 +675,28 @@ def _sanitize_markup_shapes(shapes: Any) -> list[dict[str, Any]]: return parsed -def _apply_annotation_overlay( +def _annotation_context_from_field(field: DataField, colormap: Any) -> dict[str, Any]: + legend_min, legend_mid, legend_max = _display_value_range(field) + return { + "xreal": float(field.xreal), + "si_unit_xy": str(field.si_unit_xy), + "legend_min": float(legend_min), + "legend_mid": float(legend_mid), + "legend_max": float(legend_max), + "legend_unit": str(field.si_unit_z), + "colormap": normalize_colormap_spec(colormap, fallback=field.colormap), + } + + +def _annotation_context_from_image(image: Any) -> dict[str, Any] | None: + metadata = image_metadata(image) + context = metadata.get("annotation_context") + return deepcopy(context) if isinstance(context, dict) else None + + +def _apply_annotation_overlay_from_context( image: np.ndarray, - field: DataField, - colormap: Any, + context: dict[str, Any], spec: dict[str, Any], ) -> np.ndarray: from PIL import Image, ImageDraw @@ -657,8 +712,7 @@ def _apply_annotation_overlay( current = np.repeat(current[:, :, np.newaxis], 3, axis=2) height, current_width = current.shape[:2] - field_width = max(1, int(field.xres)) - legend_width = max(72, int(round(field_width * 0.18))) if show_color_map else 0 + legend_width = max(72, int(round(current_width * 0.18))) if show_color_map else 0 canvas_width = current_width + legend_width canvas = np.full((height, canvas_width, 3), 255, dtype=np.uint8) canvas[:, :current_width] = current @@ -667,20 +721,40 @@ def _apply_annotation_overlay( draw = ImageDraw.Draw(pil_image) base_font_px = max(6, int(round(text_size))) - if show_scale_bar and field_width > 0 and np.isfinite(field.xreal) and field.xreal > 0: - target_real = field.xreal / 5.0 + xreal_raw = context.get("xreal") + xreal = float(xreal_raw) if xreal_raw is not None else 0.0 + si_unit_xy = str(context.get("si_unit_xy", "") or "") + legend_unit = str(context.get("legend_unit", "") or "") + legend_min_raw = context.get("legend_min") + legend_mid_raw = context.get("legend_mid") + legend_max_raw = context.get("legend_max") + legend_min = float(legend_min_raw) if legend_min_raw is not None else 0.0 + legend_mid = float(legend_mid_raw) if legend_mid_raw is not None else 0.0 + legend_max = float(legend_max_raw) if legend_max_raw is not None else 0.0 + colormap = normalize_colormap_spec(context.get("colormap", "gray"), fallback="gray") + has_scale_bar = np.isfinite(xreal) and xreal > 0 and bool(si_unit_xy) + has_color_legend = ( + np.isfinite(legend_min) + and np.isfinite(legend_mid) + and np.isfinite(legend_max) + and bool(legend_unit) + ) + + if show_scale_bar and has_scale_bar and current_width > 0: + target_real = xreal / 5.0 bar_real = _nice_length(target_real) - if bar_real > 0 and np.isfinite(field.dx) and field.dx > 0: - bar_px = max(1, int(round(bar_real / field.dx))) - margin_x = max(8, field_width // 24) + if bar_real > 0: + px_per_real = current_width / xreal + bar_px = max(1, int(round(bar_real * px_per_real))) + margin_x = max(8, current_width // 24) margin_y = max(8, height // 24) bar_height = max(3, int(round(height * 0.012))) - bar_px = min(bar_px, max(1, field_width - 2 * margin_x)) + bar_px = min(bar_px, max(1, current_width - 2 * margin_x)) x0 = margin_x x1 = x0 + bar_px y1 = height - margin_y y0 = y1 - bar_height - text = _format_with_unit(bar_real, field.si_unit_xy) + text = _format_with_unit(bar_real, si_unit_xy) text_image = _render_overlay_text(text, base_font_px, (255, 255, 255), font_spec=font_spec) text_w, text_h = text_image.size label_pad = 2 @@ -692,7 +766,7 @@ def _apply_annotation_overlay( draw.rectangle((x0, y0, x1, y1), fill=(255, 255, 255)) pil_image.paste(text_image, (x0, bg_top + label_pad), text_image) - if show_color_map and legend_width > 0: + if show_color_map and has_color_legend and legend_width > 0: panel_x0 = current_width draw.rectangle((panel_x0, 0, canvas_width, height), fill=(245, 245, 245)) grad_x0 = panel_x0 + max(8, legend_width // 7) @@ -706,7 +780,6 @@ def _apply_annotation_overlay( pil_image.paste(Image.fromarray(gradient_rgb), (grad_x0, grad_y0)) draw.rectangle((grad_x0, grad_y0, grad_x0 + grad_w, grad_y1), outline=(40, 40, 40), width=1) - legend_min, legend_mid, legend_max = _display_value_range(field) labels = [ (legend_max, grad_y0), (legend_mid, grad_y0 + grad_h // 2), @@ -715,7 +788,7 @@ def _apply_annotation_overlay( text_x = grad_x0 + grad_w + 8 for value, y_center in labels: text_image = _render_overlay_text( - _format_with_unit(value, field.si_unit_z), + _format_with_unit(value, legend_unit), base_font_px, (20, 20, 20), font_spec=font_spec, @@ -727,7 +800,20 @@ def _apply_annotation_overlay( return np.asarray(pil_image, dtype=np.uint8) -def _apply_markup_overlay(image: np.ndarray, field: DataField, spec: dict[str, Any]) -> np.ndarray: +def _apply_annotation_overlay( + image: np.ndarray, + field: DataField, + colormap: Any, + spec: dict[str, Any], +) -> np.ndarray: + return _apply_annotation_overlay_from_context( + image, + _annotation_context_from_field(field, colormap), + spec, + ) + + +def _apply_markup_overlay(image: np.ndarray, field: DataField | None, spec: dict[str, Any]) -> np.ndarray: from PIL import Image, ImageDraw current = np.asarray(image, dtype=np.uint8) @@ -736,8 +822,8 @@ def _apply_markup_overlay(image: np.ndarray, field: DataField, spec: dict[str, A pil_image = Image.fromarray(current.copy()) draw = ImageDraw.Draw(pil_image) - field_width = max(1, int(field.xres)) - field_height = max(1, int(field.yres)) + field_width = max(1, int(field.xres)) if isinstance(field, DataField) else max(1, current.shape[1]) + field_height = max(1, int(field.yres)) if isinstance(field, DataField) else max(1, current.shape[0]) for shape in _sanitize_markup_shapes(spec.get("shapes")): x1 = float(shape["x1"]) * field_width diff --git a/backend/execution.py b/backend/execution.py index fcc331b..ccb4596 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -221,6 +221,7 @@ class ExecutionEngine: from backend.nodes.preview_image import PreviewImage from backend.nodes.print_table import PrintTable from backend.nodes.view_3d import View3D + from backend.nodes.annotations import Annotations from backend.nodes.value_display import ValueDisplay from backend.nodes.markup import Markup from backend.nodes.cross_section import CrossSection @@ -234,6 +235,7 @@ class ExecutionEngine: from backend.nodes.mask_invert import MaskInvert from backend.nodes.mask_combine import MaskCombine from backend.nodes.draw_mask import DrawMask + from backend.nodes.save import Save from backend.nodes.save_image import SaveImage from backend.nodes.image import Image from backend.nodes.image_demo import ImageDemo @@ -245,6 +247,7 @@ class ExecutionEngine: MaskCombine._broadcast_fn = on_preview DrawMask._broadcast_overlay_fn = on_overlay View3D._broadcast_mesh_fn = on_mesh + Annotations._broadcast_warning_fn = on_warning PrintTable._broadcast_table_fn = on_table ValueDisplay._broadcast_value_fn = on_value Stats._broadcast_value_fn = on_value @@ -256,6 +259,7 @@ class ExecutionEngine: Markup._broadcast_overlay_fn = on_overlay Image._broadcast_warning_fn = on_warning ImageDemo._broadcast_warning_fn = on_warning + Save._broadcast_warning_fn = on_warning SaveImage._broadcast_warning_fn = on_warning def _set_node_id_on_display(self, cls: type, node_id: str) -> None: @@ -263,6 +267,7 @@ class ExecutionEngine: from backend.nodes.preview_image import PreviewImage from backend.nodes.print_table import PrintTable from backend.nodes.view_3d import View3D + from backend.nodes.annotations import Annotations from backend.nodes.value_display import ValueDisplay from backend.nodes.markup import Markup from backend.nodes.cross_section import CrossSection @@ -278,10 +283,11 @@ class ExecutionEngine: from backend.nodes.draw_mask import DrawMask from backend.nodes.image import Image from backend.nodes.image_demo import ImageDemo + from backend.nodes.save import Save from backend.nodes.save_image import SaveImage - if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup, + if cls in (PreviewImage, PrintTable, View3D, Annotations, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, - Image, ImageDemo, SaveImage): + Image, ImageDemo, Save, SaveImage): cls._current_node_id = node_id def _auto_preview( @@ -331,6 +337,16 @@ class ExecutionEngine: on_preview(node_id, encode_preview(arr)) return + if type_name == "ANNOTATION_SOURCE" and on_preview: + if isinstance(value, DataField): + arr = render_datafield_preview(value, value.colormap) + on_preview(node_id, encode_preview(arr)) + return + if isinstance(value, np.ndarray): + arr = image_to_uint8(value) + on_preview(node_id, encode_preview(arr)) + return + if type_name == "LINE" and isinstance(value, (np.ndarray, LineData)) and on_preview: preview = self._render_line_preview(cls, slot, result) if preview: diff --git a/backend/node_menu.py b/backend/node_menu.py index 2bc2de8..f8175b6 100644 --- a/backend/node_menu.py +++ b/backend/node_menu.py @@ -25,6 +25,7 @@ MENU_LAYOUT: dict[str, list[str]] = { ], "Output": [ "PreviewImage", + "Save", "SaveImage", "View3D", "PrintTable", diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index 1b88dcf..234cf27 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -8,6 +8,7 @@ from backend.nodes import ( coordinate_pair, number, range_slider, + save, save_image, # Filters gaussian_filter, diff --git a/backend/nodes/annotations.py b/backend/nodes/annotations.py index 4d615c4..8e74c3e 100644 --- a/backend/nodes/annotations.py +++ b/backend/nodes/annotations.py @@ -1,16 +1,28 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node -from backend.data_types import COLORMAPS, DataField, normalize_font_spec, resolve_colormap_input +from backend.data_types import ( + COLORMAPS, + DataField, + ImageData, + _apply_annotation_overlay_from_context, + _annotation_context_from_image, + image_to_uint8, + normalize_font_spec, + resolve_colormap_input, +) @register_node(display_name="Annotations") class Annotations: + _broadcast_warning_fn = None + _current_node_id: str = "" + @classmethod def INPUT_TYPES(cls): return { "required": { - "field": ("DATA_FIELD",), + "input": ("ANNOTATION_SOURCE", {"label": "Input"}), "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), "show_scale_bar": ("BOOLEAN", {"default": True}), "show_color_map": ("BOOLEAN", {"default": True}), @@ -27,18 +39,18 @@ class Annotations: }, } - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("annotated",) + RETURN_TYPES = ("ANNOTATION_SOURCE",) + RETURN_NAMES = ("Output",) FUNCTION = "render" DESCRIPTION = ( - "Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. " - "The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values." + "Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data, " + "or annotate an IMAGE that carries viewport metadata from View3D." ) def render( self, - field: DataField, + input, colormap: str, show_scale_bar: bool, show_color_map: bool, @@ -46,24 +58,69 @@ class Annotations: colormap_map=None, font=None, ) -> tuple: + annotation_spec = { + "kind": "annotation", + "show_scale_bar": bool(show_scale_bar), + "show_color_map": bool(show_color_map), + "text_size": float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0, + "font": normalize_font_spec(font), + } + + if isinstance(input, DataField): + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=input.colormap, + default="gray", + ) + out = input.replace( + colormap=resolved_colormap, + overlays=[ + *input.overlays, + annotation_spec, + ], + ) + return (out,) + + context = _annotation_context_from_image(input) + if context is None: + self._send_warning( + "Annotations image input has no scale metadata, so scale bar and color-map legend cannot be added." + ) + return (ImageData(image_to_uint8(input)),) + resolved_colormap = resolve_colormap_input( colormap, colormap_input=colormap_map, - inherited=field.colormap, + inherited=context.get("colormap"), default="gray", ) - text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0 - out = field.replace( - colormap=resolved_colormap, - overlays=[ - *field.overlays, - { - "kind": "annotation", - "show_scale_bar": bool(show_scale_bar), - "show_color_map": bool(show_color_map), - "text_size": text_size, - "font": normalize_font_spec(font), - }, - ], + context["colormap"] = resolved_colormap + missing_features = [] + xreal = context.get("xreal") + if bool(show_scale_bar) and not (isinstance(xreal, (int, float)) and np.isfinite(float(xreal)) and float(xreal) > 0 and str(context.get("si_unit_xy", "")).strip()): + missing_features.append("scale bar") + if bool(show_color_map): + legend_values = (context.get("legend_min"), context.get("legend_mid"), context.get("legend_max")) + has_legend_values = all( + isinstance(value, (int, float)) and np.isfinite(float(value)) + for value in legend_values + ) + if not (has_legend_values and str(context.get("legend_unit", "")).strip()): + missing_features.append("color-map legend") + if missing_features: + self._send_warning( + f"Annotations image input is missing metadata for: {', '.join(missing_features)}." + ) + annotated = _apply_annotation_overlay_from_context( + image_to_uint8(input), + context, + annotation_spec, ) - return (out,) + return (ImageData(annotated, metadata={"annotation_context": context}),) + + def _send_warning(self, message: str): + fn = Annotations._broadcast_warning_fn + nid = Annotations._current_node_id + if fn and nid: + fn(nid, message) diff --git a/backend/nodes/image.py b/backend/nodes/image.py index a1d7b95..4cf8e18 100644 --- a/backend/nodes/image.py +++ b/backend/nodes/image.py @@ -1,4 +1,5 @@ from __future__ import annotations +from functools import lru_cache import numpy as np from pathlib import Path @@ -48,17 +49,21 @@ class Image: ext = path_obj.suffix.lower() resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis") + stat = path_obj.stat() + cached_fields = Image._load_fields_cached( + str(path_obj.resolve()), + int(stat.st_mtime_ns), + int(stat.st_size), + ) + fields = tuple(field.copy() for field in cached_fields) - if ext in _SPM_EXTENSIONS: - fields = self._load_spm_all(path_obj, ext) - for f in fields: - f.colormap = resolved_colormap - return tuple(fields) + for field in fields: + field.colormap = resolved_colormap - field = self._load_image_or_array(path_obj, ext) - field.colormap = resolved_colormap - self._send_warning("Uncalibrated data — no physical dimensions.") - return (field,) + if ext not in _SPM_EXTENSIONS: + self._send_warning("Uncalibrated data — no physical dimensions.") + + return fields def _send_warning(self, message: str): fn = Image._broadcast_warning_fn @@ -66,17 +71,28 @@ class Image: if fn and nid: fn(nid, message) - def _load_spm_all(self, path: Path, ext: str) -> list[DataField]: + @staticmethod + @lru_cache(maxsize=32) + def _load_fields_cached(path_str: str, mtime_ns: int, size_bytes: int) -> tuple[DataField, ...]: + path = Path(path_str) + ext = path.suffix.lower() + if ext in _SPM_EXTENSIONS: + return tuple(Image._load_spm_all(path, ext)) + return (Image._load_image_or_array(path, ext),) + + @staticmethod + def _load_spm_all(path: Path, ext: str) -> list[DataField]: if ext == ".gwy": - return self._load_gwy_all(path) + return Image._load_gwy_all(path) elif ext == ".sxm": - return self._load_sxm_all(path) + return Image._load_sxm_all(path) elif ext == ".ibw": - return self._load_ibw_all(path) + return Image._load_ibw_all(path) else: raise ValueError(f"Unsupported SPM format: {ext}") - def _load_gwy_all(self, path: Path) -> list[DataField]: + @staticmethod + def _load_gwy_all(path: Path) -> list[DataField]: try: import gwyfile except ImportError: @@ -101,7 +117,8 @@ class Image: )) return fields - def _load_sxm_all(self, path: Path) -> list[DataField]: + @staticmethod + def _load_sxm_all(path: Path) -> list[DataField]: try: import nanonispy as nap except ImportError: @@ -130,7 +147,8 @@ class Image: )) return fields - def _load_ibw_all(self, path: Path) -> list[DataField]: + @staticmethod + def _load_ibw_all(path: Path) -> list[DataField]: try: from igor.binarywave import load as load_ibw except ImportError: @@ -193,7 +211,8 @@ class Image: return fields - def _load_image_or_array(self, path: Path, ext: str) -> DataField: + @staticmethod + def _load_image_or_array(path: Path, ext: str) -> DataField: if ext == ".npy": arr = np.load(str(path)).astype(np.float64) elif ext == ".npz": diff --git a/backend/nodes/markup.py b/backend/nodes/markup.py index c9e6785..3190946 100644 --- a/backend/nodes/markup.py +++ b/backend/nodes/markup.py @@ -1,6 +1,14 @@ from __future__ import annotations from backend.node_registry import register_node -from backend.data_types import DataField, datafield_to_uint8, encode_preview +from backend.data_types import ( + DataField, + ImageData, + _apply_markup_overlay, + encode_preview, + image_metadata, + image_to_uint8, + render_datafield_preview, +) from backend.nodes.helpers import _parse_markup_shapes, _normalize_markup_color @@ -12,7 +20,7 @@ class Markup: def INPUT_TYPES(cls): return { "required": { - "field": ("DATA_FIELD",), + "input": ("ANNOTATION_SOURCE", {"label": "Input"}), "shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}), "stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}), "stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}), @@ -21,13 +29,13 @@ class Markup: } } - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("annotated",) + RETURN_TYPES = ("ANNOTATION_SOURCE",) + RETURN_NAMES = ("Output",) FUNCTION = "process" DESCRIPTION = ( - "Draw simple vector markup over a DATA_FIELD without flattening the underlying data. " - "Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows." + "Draw simple vector markup over a DATA_FIELD without flattening the underlying data, " + "or rasterize markup directly onto an IMAGE." ) _broadcast_overlay_fn = None @@ -35,22 +43,32 @@ class Markup: def process( self, - field: DataField, + input, shape: str, stroke_color: str, stroke_width: int, markup_shapes: str, ) -> tuple: shapes = _parse_markup_shapes(markup_shapes) - out = field.replace( - overlays=[ - *field.overlays, - { - "kind": "markup", - "shapes": shapes, - }, - ], - ) + markup_spec = { + "kind": "markup", + "shapes": shapes, + } + + if isinstance(input, DataField): + out = input.replace( + overlays=[ + *input.overlays, + markup_spec, + ], + ) + preview_base = render_datafield_preview(input, input.colormap) + else: + preview_base = image_to_uint8(input) + out = ImageData( + _apply_markup_overlay(preview_base, None, markup_spec), + metadata=image_metadata(input), + ) if Markup._broadcast_overlay_fn is not None: Markup._broadcast_overlay_fn( @@ -58,7 +76,7 @@ class Markup: { "kind": "markup", "section_title": "Markup", - "image": encode_preview(datafield_to_uint8(field, field.colormap)), + "image": encode_preview(preview_base), "shape": str(shape), "stroke_color": _normalize_markup_color(stroke_color), "stroke_width": max(1, int(stroke_width)), diff --git a/backend/nodes/save.py b/backend/nodes/save.py new file mode 100644 index 0000000..95ef18b --- /dev/null +++ b/backend/nodes/save.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import csv +import json +from pathlib import Path + +import numpy as np + +from backend.node_registry import register_node +from backend.data_types import DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8 + + +@register_node(display_name="Save") +class Save: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "filename": ("STRING", { + "default": "", + "placeholder": "filename", + "placement": "top", + }), + "directory_path": ("FOLDER_PICKER", { + "default": "", + "label": "directory", + "placement": "top", + "hide_when_input_connected": "directory", + "top_socket_input": "directory", + }), + "value": ("SAVE_VALUE", {"label": "value"}), + "format": ("STRING", { + "default": "TIFF", + "choices_by_source_type": { + "DATA_FIELD": ["TIFF", "PNG", "NPZ"], + "IMAGE": ["PNG", "TIFF", "NPZ"], + "LINE": ["CSV", "NPZ", "JSON"], + "MEASURE_TABLE": ["CSV", "JSON"], + "RECORD_TABLE": ["CSV", "JSON"], + "FLOAT": ["TXT", "JSON"], + "MESH_MODEL": ["OBJ", "STL"], + }, + "source_type_input": "value", + }), + }, + "optional": { + "directory": ("DIRECTORY", {"label": "directory"}), + }, + } + + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + MANUAL_TRIGGER = True + DESCRIPTION = ( + "Save a single graph value to disk. Supports fields, images, lines, tables, scalars, and 3D meshes." + ) + + _broadcast_warning_fn = None + _current_node_id = None + + def save( + self, + filename: str, + directory_path: str, + format: str, + value, + directory: str | None = None, + ): + path = self._resolve_save_path(filename, format, directory, directory_path) + + if isinstance(value, MeshModel): + self._save_mesh(path, value, format) + elif isinstance(value, DataField): + self._save_datafield(path, value, format) + elif isinstance(value, np.ndarray): + if value.ndim == 1: + self._save_line(path, LineData(data=value), format) + else: + self._save_image_or_array(path, value, format) + elif isinstance(value, LineData): + self._save_line(path, value, format) + elif isinstance(value, list): + self._save_table(path, value, format) + elif isinstance(value, (int, float, np.floating, np.integer)): + self._save_scalar(path, float(value), format) + else: + raise ValueError(f"Save does not support input type: {type(value).__name__}") + + self._send_warning(f"Saved to {path.name}") + return () + + def _resolve_save_path( + self, + filename: str, + format_name: str, + directory: str | None, + directory_path: str = "", + ) -> Path: + ext_map = { + "PNG": ".png", + "TIFF": ".tiff", + "NPZ": ".npz", + "CSV": ".csv", + "JSON": ".json", + "OBJ": ".obj", + "STL": ".stl", + "TXT": ".txt", + } + ext = ext_map[format_name] + + raw_filename = str(filename).strip() if filename is not None else "" + raw_directory = str(directory).strip() if directory is not None else "" + if not raw_directory: + raw_directory = str(directory_path).strip() if directory_path is not None else "" + + if not raw_filename: + raise ValueError("No output filename selected — enter a file name.") + + if raw_directory: + dir_path = Path(raw_directory).expanduser() + if dir_path.exists() and not dir_path.is_dir(): + raise ValueError("Directory input expects a folder path, not a file path.") + if not dir_path.exists(): + if dir_path.suffix: + raise ValueError("Directory input expects a folder path, not a file path.") + dir_path.mkdir(parents=True, exist_ok=True) + path = dir_path / Path(raw_filename).name + else: + path = Path(raw_filename).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + + if path.suffix.lower() != ext: + path = path.with_suffix(ext) + return path + + def _save_datafield(self, path: Path, field: DataField, format_name: str): + if format_name == "TIFF": + import tifffile + tifffile.imwrite(str(path), np.asarray(field.data, dtype=np.float32)) + return + if format_name == "NPZ": + np.savez(str(path), field=np.asarray(field.data)) + return + if format_name == "PNG": + from PIL import Image + Image.fromarray(datafield_to_uint8(field, field.colormap)).save(str(path)) + return + raise ValueError(f"Format {format_name} is not supported for DATA_FIELD.") + + def _save_image_or_array(self, path: Path, image: np.ndarray, format_name: str): + arr = np.asarray(image) + if format_name == "PNG": + from PIL import Image + Image.fromarray(image_to_uint8(arr)).save(str(path)) + return + if format_name == "TIFF": + import tifffile + tifffile.imwrite(str(path), image_to_uint8(arr)) + return + if format_name == "NPZ": + np.savez(str(path), image=arr) + return + raise ValueError(f"Format {format_name} is not supported for IMAGE.") + + def _save_line(self, path: Path, line: LineData, format_name: str): + y = np.asarray(line.data, dtype=np.float64).ravel() + x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)] if line.x_axis is not None else np.arange(len(y), dtype=np.float64) + if format_name == "CSV": + with path.open("w", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + writer.writerow(["x", "y", "x_unit", "y_unit"]) + for xv, yv in zip(x, y): + writer.writerow([xv, yv, line.x_unit, line.y_unit]) + return + if format_name == "NPZ": + np.savez(str(path), x=x, y=y) + return + if format_name == "JSON": + path.write_text(json.dumps({ + "x": x.tolist(), + "y": y.tolist(), + "x_unit": line.x_unit, + "y_unit": line.y_unit, + }, indent=2), encoding="utf-8") + return + raise ValueError(f"Format {format_name} is not supported for LINE.") + + def _save_table(self, path: Path, rows: list, format_name: str): + if format_name == "JSON": + path.write_text(json.dumps(rows, indent=2), encoding="utf-8") + return + if format_name == "CSV": + columns: list[str] = [] + for row in rows: + if isinstance(row, dict): + for key in row.keys(): + if key not in columns: + columns.append(str(key)) + with path.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=columns) + writer.writeheader() + for row in rows: + writer.writerow(row if isinstance(row, dict) else {"value": row}) + return + raise ValueError(f"Format {format_name} is not supported for table inputs.") + + def _save_scalar(self, path: Path, value: float, format_name: str): + if format_name == "TXT": + path.write_text(f"{value}\n", encoding="utf-8") + return + if format_name == "JSON": + path.write_text(json.dumps({"value": value}, indent=2), encoding="utf-8") + return + raise ValueError(f"Format {format_name} is not supported for scalar values.") + + def _save_mesh(self, path: Path, mesh: MeshModel, format_name: str): + if format_name == "OBJ": + self._save_obj(path, mesh) + return + if format_name == "STL": + self._save_stl(path, mesh) + return + raise ValueError(f"Format {format_name} is not supported for MESH_MODEL.") + + def _save_obj(self, path: Path, mesh: MeshModel): + lines = [] + for vertex in mesh.vertices: + lines.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}") + for face in mesh.faces: + lines.append(f"f {int(face[0]) + 1} {int(face[1]) + 1} {int(face[2]) + 1}") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + def _save_stl(self, path: Path, mesh: MeshModel): + def normal(a, b, c): + n = np.cross(b - a, c - a) + length = float(np.linalg.norm(n)) + return n / length if length > 0 else np.array([0.0, 1.0, 0.0], dtype=np.float32) + + lines = ["solid argonode"] + vertices = np.asarray(mesh.vertices, dtype=np.float32) + for face in np.asarray(mesh.faces, dtype=np.int32): + a, b, c = vertices[int(face[0])], vertices[int(face[1])], vertices[int(face[2])] + n = normal(a, b, c) + lines.append(f" facet normal {n[0]} {n[1]} {n[2]}") + lines.append(" outer loop") + lines.append(f" vertex {a[0]} {a[1]} {a[2]}") + lines.append(f" vertex {b[0]} {b[1]} {b[2]}") + lines.append(f" vertex {c[0]} {c[1]} {c[2]}") + lines.append(" endloop") + lines.append(" endfacet") + lines.append("endsolid argonode") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + def _send_warning(self, message: str): + fn = Save._broadcast_warning_fn + nid = Save._current_node_id + if fn and nid: + fn(nid, message) diff --git a/backend/nodes/save_image.py b/backend/nodes/save_image.py index 887d602..f122c5e 100644 --- a/backend/nodes/save_image.py +++ b/backend/nodes/save_image.py @@ -49,11 +49,11 @@ class SaveImage: OUTPUT_NODE = True MANUAL_TRIGGER = True DESCRIPTION = ( - "Save one or more layers to a single file. " + "Save one or more image/field layers to a single file. " "Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. " "Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. " "A new slot appears as each one is filled, with a matching per-layer name field. " - "TIFF writes multi-page data and stores layer names as page descriptions; " + "Use this for composing multi-channel stacks. TIFF writes multi-page data and stores layer names as page descriptions; " "NPZ writes named arrays using those layer names as keys. " "Click Save to write (does not auto-run)." ) diff --git a/backend/nodes/view_3d.py b/backend/nodes/view_3d.py index 9ffe623..8bfcef1 100644 --- a/backend/nodes/view_3d.py +++ b/backend/nodes/view_3d.py @@ -1,37 +1,119 @@ from __future__ import annotations +import base64 +import io import numpy as np from backend.node_registry import register_node from backend.data_types import ( COLORMAPS, DataField, + ImageData, + MeshModel, + _annotation_context_from_field, colormap_to_uint8, normalize_for_colormap, resolve_colormap_input, ) +def _darken_colors(colors: np.ndarray, factor: float) -> np.ndarray: + return np.clip(np.rint(colors.astype(np.float32) * factor), 0, 255).astype(np.uint8) + + +def _grid_triangle_indices(nx: int, ny: int, *, reverse: bool = False) -> list[list[int]]: + faces: list[list[int]] = [] + for iy in range(ny - 1): + for ix in range(nx - 1): + a = iy * nx + ix + b = a + 1 + c = a + nx + d = c + 1 + if reverse: + faces.append([a, b, c]) + faces.append([b, d, c]) + else: + faces.append([a, c, b]) + faces.append([b, c, d]) + return faces + + +def _build_mesh_model(z: np.ndarray, colors_u8: np.ndarray, z_scale: float, make_solid: bool) -> MeshModel: + ny, nx = z.shape + zmin = float(z.min()) + zmax = float(z.max()) + z_range = zmax - zmin if zmax != zmin else 1.0 + + top_vertices = np.empty((nx * ny, 3), dtype=np.float32) + top_colors = colors_u8.reshape(-1, 3).astype(np.uint8) + for iy in range(ny): + py = iy / max(ny - 1, 1) - 0.5 + for ix in range(nx): + idx = iy * nx + ix + px = ix / max(nx - 1, 1) - 0.5 + pz = ((float(z[iy, ix]) - zmin) / z_range - 0.5) * z_scale + top_vertices[idx] = (px, pz, py) + + faces = _grid_triangle_indices(nx, ny) + if not make_solid: + return MeshModel(vertices=top_vertices, faces=np.asarray(faces, dtype=np.int32), colors=top_colors) + + base_y = float(top_vertices[:, 1].min()) + bottom_vertices = top_vertices.copy() + bottom_vertices[:, 1] = base_y + bottom_colors = _darken_colors(top_colors, 0.35) + + vertices = np.vstack([top_vertices, bottom_vertices]).astype(np.float32) + colors = np.vstack([top_colors, bottom_colors]).astype(np.uint8) + + bottom_offset = len(top_vertices) + faces.extend([[a + bottom_offset, b + bottom_offset, c + bottom_offset] for a, b, c in _grid_triangle_indices(nx, ny, reverse=True)]) + + def _add_wall(a: int, b: int): + faces.append([a, a + bottom_offset, b]) + faces.append([b, a + bottom_offset, b + bottom_offset]) + + for ix in range(nx - 1): + _add_wall(ix, ix + 1) + top_row = (ny - 1) * nx + _add_wall(top_row + ix + 1, top_row + ix) + for iy in range(ny - 1): + _add_wall((iy + 1) * nx, iy * nx) + _add_wall(iy * nx + (nx - 1), (iy + 1) * nx + (nx - 1)) + + return MeshModel(vertices=vertices, faces=np.asarray(faces, dtype=np.int32), colors=colors) + + @register_node(display_name="3D View") class View3D: + _CUSTOM_PREVIEW = True + @classmethod def INPUT_TYPES(cls): return { "required": { - "field": ("DATA_FIELD",), + "field": ("DATA_FIELD", {"label": "mesh"}), "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), "z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}), "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), + "make_solid": ("BOOLEAN", {"default": False}), + "camera_azimuth": ("FLOAT", {"default": 0.0, "hidden": True}), + "camera_polar": ("FLOAT", {"default": 1.1, "hidden": True}), + "camera_distance": ("FLOAT", {"default": 1.8, "hidden": True}), + "viewport_snapshot": ("STRING", {"default": "", "hidden": True}), }, "optional": { + "map_field": ("DATA_FIELD", {"label": "map"}), "colormap_map": ("COLORMAP", {"label": "colormap"}), }, } - RETURN_TYPES = () + RETURN_TYPES = ("MESH_MODEL", "IMAGE") + RETURN_NAMES = ("mesh", "viewport") FUNCTION = "render" OUTPUT_NODE = True DESCRIPTION = ( "Interactive 3D surface view of a DATA_FIELD. " + "Use the mesh input for geometry and optionally a second map input for coloring. " "Drag to rotate, scroll to zoom. z_scale exaggerates height." ) @@ -40,9 +122,12 @@ class View3D: def render( self, field: DataField, - colormap: str, z_scale: float, resolution: int, colormap_map=None, + colormap: str, z_scale: float, resolution: int, make_solid: bool = False, + camera_azimuth: float = 0.0, camera_polar: float = 1.1, camera_distance: float = 1.8, + viewport_snapshot: str = "", + map_field: DataField | None = None, colormap_map=None, ) -> tuple: - import base64 + from scipy.ndimage import map_coordinates data = field.data yres, xres = data.shape @@ -53,33 +138,75 @@ class View3D: ny, nx = z.shape zmin, zmax = float(z.min()), float(z.max()) + color_field = map_field if map_field is not None else field + color_data = color_field.data + + if color_field is field and color_data.shape == z.shape: + color_samples = z + elif color_field is field: + color_samples = color_data[::step_y, ::step_x].astype(np.float32) + else: + x_phys = np.linspace(field.xoff, field.xoff + field.xreal, nx, dtype=np.float64) + y_phys = np.linspace(field.yoff, field.yoff + field.yreal, ny, dtype=np.float64) + grid_y, grid_x = np.meshgrid(y_phys, x_phys, indexing="ij") + + map_x = np.clip( + (grid_x - color_field.xoff) / max(color_field.xreal, 1e-12) * max(color_field.xres - 1, 0), + 0.0, + max(color_field.xres - 1, 0), + ) + map_y = np.clip( + (grid_y - color_field.yoff) / max(color_field.yreal, 1e-12) * max(color_field.yres - 1, 0), + 0.0, + max(color_field.yres - 1, 0), + ) + color_samples = map_coordinates( + color_data.astype(np.float64), + [map_y, map_x], + order=1, + mode="nearest", + ).astype(np.float32) + z_norm = normalize_for_colormap( - z, - offset=field.display_offset, - scale=field.display_scale, - data_min=float(field.data.min()), - data_max=float(field.data.max()), + color_samples, + offset=color_field.display_offset, + scale=color_field.display_scale, + data_min=float(color_field.data.min()), + data_max=float(color_field.data.max()), ) resolved_colormap = resolve_colormap_input( colormap, colormap_input=colormap_map, - inherited=field.colormap, + inherited=color_field.colormap, default="gray", ) colors_u8 = colormap_to_uint8(z_norm, resolved_colormap) + mesh_model = _build_mesh_model(z, colors_u8, float(z_scale * 0.1), bool(make_solid)) z_b64 = base64.b64encode(z.tobytes()).decode() colors_b64 = base64.b64encode(colors_u8.tobytes()).decode() + positions_b64 = base64.b64encode(np.asarray(mesh_model.vertices, dtype=np.float32).tobytes()).decode() + indices_b64 = base64.b64encode(np.asarray(mesh_model.faces, dtype=np.uint32).tobytes()).decode() + mesh_colors_b64 = None + if mesh_model.colors is not None: + mesh_colors_b64 = base64.b64encode(np.asarray(mesh_model.colors, dtype=np.uint8).tobytes()).decode() mesh_data = { "width": nx, "height": ny, "z_data": z_b64, "colors": colors_b64, + "positions": positions_b64, + "indices": indices_b64, + "vertex_colors": mesh_colors_b64, "z_min": zmin, "z_max": zmax, "z_scale": float(z_scale * 0.1), + "make_solid": bool(make_solid), + "camera_azimuth": float(camera_azimuth), + "camera_polar": float(camera_polar), + "camera_distance": float(camera_distance), "x_range": [float(field.xoff), float(field.xoff + field.xreal)], "y_range": [float(field.yoff), float(field.yoff + field.yreal)], } @@ -87,4 +214,32 @@ class View3D: if View3D._broadcast_mesh_fn is not None: View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data) - return () + annotation_context = _annotation_context_from_field(color_field, resolved_colormap) + annotation_context["xreal"] = float(field.xreal) + annotation_context["si_unit_xy"] = str(field.si_unit_xy) + viewport_image = ImageData( + self._decode_viewport_snapshot(viewport_snapshot), + metadata={ + "annotation_context": annotation_context, + "viewport_camera": { + "azimuth": float(camera_azimuth), + "polar": float(camera_polar), + "distance": float(camera_distance), + }, + }, + ) + return (mesh_model, viewport_image) + + def _decode_viewport_snapshot(self, snapshot: str) -> np.ndarray: + text = str(snapshot or "").strip() + if not text.startswith("data:image/"): + return np.zeros((1, 1, 3), dtype=np.uint8) + + try: + header, payload = text.split(",", 1) + raw = base64.b64decode(payload) + from PIL import Image + image = Image.open(io.BytesIO(raw)).convert("RGB") + return np.asarray(image, dtype=np.uint8) + except Exception: + return np.zeros((1, 1, 3), dtype=np.uint8) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 2d6cd2c..a89b236 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -16,6 +16,12 @@ import { embedWorkflow, extractWorkflow } from './pngMetadata'; import { captureViewportBlob as captureWorkflowViewportBlob } from './workflowCapture'; import { hydrateWorkflowState } from './workflowHydration'; import { serializeWorkflowState } from './workflowSerialization'; +import { + buildNodeClipboardPayload, + instantiateNodeClipboardPayload, + NODE_CLIPBOARD_MIME, + parseNodeClipboardPayload, +} from './nodeClipboard'; import { loadDefaultWorkflowAsset } from './defaultWorkflow'; import { serializeExecutionGraph, @@ -49,6 +55,12 @@ function sameStringArray(a = [], b = []) { return a.every((item, index) => item === b[index]); } +function isEditableTarget(target) { + if (!target || !(target instanceof Element)) return false; + if (target.closest('input, textarea, select')) return true; + return target.closest('[contenteditable="true"]') !== null; +} + function compareMenuNodes(a, b) { const orderA = Number.isFinite(a?.def?.menu_order) ? a.def.menu_order : Number.MAX_SAFE_INTEGER; const orderB = Number.isFinite(b?.def?.menu_order) ? b.def.menu_order : Number.MAX_SAFE_INTEGER; @@ -254,7 +266,11 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti }); if (!hasMatch) continue; } else { - if (!def.output.some((type) => socketTypesCompatible(type, filterType))) continue; + const hasMatch = def.output.some((type) => + socketTypesCompatible(type, filterType) + || (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE')) + ); + if (!hasMatch) continue; } } const cat = def.category || 'uncategorized'; @@ -454,6 +470,8 @@ function Flow() { const autoRunTimer = useRef(null); const autoRunRef = useRef(null); const defaultWorkflowLoadAttemptedRef = useRef(false); + const lastPastedClipboardTextRef = useRef(''); + const pasteRepeatCountRef = useRef(0); const reactFlow = useReactFlow(); // ── WebSocket ─────────────────────────────────────────────────────── @@ -554,6 +572,24 @@ function Flow() { } }, [reactFlow, refreshLoadNodeOutputs, setNodeOutputs]); + const refreshAnnotationNodeOutputs = useCallback((nodeId) => { + const node = reactFlow.getNode(nodeId); + if (!node) return; + + const inputEdge = reactFlow.getEdges().find( + (edge) => edge.target === nodeId && getInputName(edge.targetHandle) === 'input' + ); + const outputType = inputEdge ? getHandleType(inputEdge.sourceHandle) : 'ANNOTATION_SOURCE'; + setNodeOutputs(nodeId, [outputType], ['Output']); + + if (!inputEdge || outputType === 'ANNOTATION_SOURCE') return; + + setEdges((prev) => prev.filter((edge) => { + if (edge.source !== nodeId) return true; + return socketTypesCompatible(outputType, getHandleType(edge.targetHandle)); + })); + }, [reactFlow, setEdges, setNodeOutputs]); + useEffect(() => { api.setMessageHandler((msg) => { console.log('[argonode] WS:', msg.type, msg.data?.node_id || msg.data?.node || ''); @@ -639,14 +675,21 @@ function Flow() { refreshLoadNodeOutputs(params.target); }, 0); } + const targetNode = reactFlow.getNode(params.target); + if (targetNode && (targetNode.data.className === 'Annotations' || targetNode.data.className === 'Markup')) { + setTimeout(() => { + refreshAnnotationNodeOutputs(params.target); + }, 0); + } scheduleAutoRun(); - }, [refreshLoadNodeOutputs, setEdges]); // scheduleAutoRun is stable (no deps) + }, [reactFlow, refreshAnnotationNodeOutputs, refreshLoadNodeOutputs, setEdges]); // scheduleAutoRun is stable (no deps) const handleEdgesChange = useCallback((changes) => { const currentEdges = reactFlow.getEdges(); onEdgesChange(changes); const affectedPathTargets = new Set(); + const affectedAnnotationTargets = new Set(); for (const change of changes) { if (change.type !== 'remove') continue; const removedEdge = currentEdges.find((edge) => edge.id === change.id); @@ -654,6 +697,12 @@ function Flow() { if (getInputName(removedEdge.targetHandle) === 'path') { affectedPathTargets.add(removedEdge.target); } + if (getInputName(removedEdge.targetHandle) === 'input') { + const targetNode = reactFlow.getNode(removedEdge.target); + if (targetNode && (targetNode.data.className === 'Annotations' || targetNode.data.className === 'Markup')) { + affectedAnnotationTargets.add(removedEdge.target); + } + } } if (affectedPathTargets.size > 0) { @@ -663,7 +712,14 @@ function Flow() { }); }, 0); } - }, [onEdgesChange, reactFlow, refreshLoadNodeOutputs]); + if (affectedAnnotationTargets.size > 0) { + setTimeout(() => { + affectedAnnotationTargets.forEach((nodeId) => { + refreshAnnotationNodeOutputs(nodeId); + }); + }, 0); + } + }, [onEdgesChange, reactFlow, refreshAnnotationNodeOutputs, refreshLoadNodeOutputs]); // ── Drop-on-blank: open filtered context menu ────────────────────── @@ -749,12 +805,6 @@ function Flow() { }); }, [reactFlow]); - const contextValue = useMemo(() => ({ - onWidgetChange, - openFileBrowser, - onManualTrigger, - }), [onWidgetChange, openFileBrowser, onManualTrigger]); - // ── Add node from context menu ────────────────────────────────────── const addNode = useCallback((className, def) => { @@ -789,6 +839,7 @@ function Flow() { className, definition: def, widgetValues, + runtimeValues: {}, previewImage: null, tableRows: null, meshData: null, @@ -842,9 +893,12 @@ function Flow() { } } else { // Dragged from an input → connect from the first matching output on the new node - const outputIdx = def.output.findIndex((type) => socketTypesCompatible(type, filterType)); + const outputIdx = def.output.findIndex((type) => + socketTypesCompatible(type, filterType) + || (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE')) + ); if (outputIdx !== -1) { - const outputType = def.output[outputIdx]; + const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx]; const sourceHandle = `output::${outputIdx}::${outputType}`; const color = TYPE_COLORS[outputType] || 'var(--fallback-type)'; setEdges((eds) => addEdge({ @@ -907,6 +961,101 @@ function Flow() { autoRunTimer.current = setTimeout(() => autoRunRef.current?.(), 300); }, []); + const onRuntimeValuesChange = useCallback((nodeId, patch, { scheduleRun = false } = {}) => { + if (!patch || typeof patch !== 'object') return; + + setNodes((ns) => ns.map((n) => { + if (n.id !== nodeId) return n; + return { + ...n, + data: { + ...n.data, + runtimeValues: { ...(n.data.runtimeValues || {}), ...patch }, + }, + }; + })); + + if (scheduleRun) { + scheduleAutoRun(); + } + }, [setNodes, scheduleAutoRun]); + + const pasteClipboardSelection = useCallback((clipboardText) => { + const payload = parseNodeClipboardPayload(clipboardText); + if (!payload) return false; + + if (clipboardText === lastPastedClipboardTextRef.current) { + pasteRepeatCountRef.current += 1; + } else { + lastPastedClipboardTextRef.current = clipboardText; + pasteRepeatCountRef.current = 1; + } + + const offsetAmount = 36 * pasteRepeatCountRef.current; + const pasted = instantiateNodeClipboardPayload( + payload, + nodeDefsRef.current, + nextIdRef.current, + { x: offsetAmount, y: offsetAmount }, + ); + + if (pasted.nodes.length === 0) return false; + + nextIdRef.current = pasted.nextNodeId; + + setNodes((existing) => [ + ...existing.map((node) => ({ ...node, selected: false })), + ...pasted.nodes, + ]); + setEdges((existing) => [ + ...existing.map((edge) => ({ ...edge, selected: false })), + ...pasted.edges, + ]); + + setTimeout(() => { + pasted.nodes.forEach((node) => { + if (node.data.className === 'Folder' && node.data.widgetValues?.folder) { + refreshFolderNodeOutputs(node.id, node.data.widgetValues.folder); + } + }); + pasted.nodes.forEach((node) => { + if (node.data.className === 'Image' || node.data.className === 'ImageDemo') { + refreshLoadNodeOutputs(node.id); + } + }); + pasted.nodes.forEach((node) => { + if (node.data.className === 'Annotations' || node.data.className === 'Markup') { + refreshAnnotationNodeOutputs(node.id); + } + }); + pasted.nodes.forEach((node) => { + reactFlow.updateNodeInternals(node.id); + }); + }, 0); + + setStatus({ + text: `Pasted ${pasted.nodes.length} node${pasted.nodes.length === 1 ? '' : 's'}.`, + level: 'info', + }); + scheduleAutoRun(); + return true; + }, [ + reactFlow, + refreshAnnotationNodeOutputs, + refreshFolderNodeOutputs, + refreshLoadNodeOutputs, + scheduleAutoRun, + setEdges, + setNodes, + ]); + + const contextValue = useMemo(() => ({ + onWidgetChange, + onRuntimeValuesChange, + openFileBrowser, + onManualTrigger, + }), [onRuntimeValuesChange, onWidgetChange, openFileBrowser, onManualTrigger]); + const clearGraph = useCallback(() => { setNodes([]); setEdges([]); @@ -930,8 +1079,13 @@ function Flow() { refreshLoadNodeOutputs(node.id); } }); + hydrated.nodes.forEach((node) => { + if (node.data.className === 'Annotations' || node.data.className === 'Markup') { + refreshAnnotationNodeOutputs(node.id); + } + }); }, 0); - }, [refreshFolderNodeOutputs, refreshLoadNodeOutputs, setNodes, setEdges]); + }, [refreshAnnotationNodeOutputs, refreshFolderNodeOutputs, refreshLoadNodeOutputs, setNodes, setEdges]); const loadDefaultWorkflow = useCallback(async () => { if (defaultWorkflowLoadAttemptedRef.current) return; @@ -1168,6 +1322,45 @@ function Flow() { return () => window.removeEventListener('keydown', handler); }, [runWorkflow]); + useEffect(() => { + const handleCopy = (event) => { + if (isEditableTarget(event.target)) return; + + const payload = buildNodeClipboardPayload(reactFlow.getNodes(), reactFlow.getEdges()); + if (!payload) return; + + const serialized = JSON.stringify(payload); + event.preventDefault(); + event.clipboardData?.setData(NODE_CLIPBOARD_MIME, serialized); + event.clipboardData?.setData('text/plain', serialized); + setStatus({ + text: `Copied ${payload.nodes.length} node${payload.nodes.length === 1 ? '' : 's'}.`, + level: 'info', + }); + }; + + const handlePaste = (event) => { + if (isEditableTarget(event.target)) return; + + const clipboardText = event.clipboardData?.getData(NODE_CLIPBOARD_MIME) + || event.clipboardData?.getData('text/plain') + || ''; + if (!clipboardText) return; + + const pasted = pasteClipboardSelection(clipboardText); + if (pasted) { + event.preventDefault(); + } + }; + + window.addEventListener('copy', handleCopy); + window.addEventListener('paste', handlePaste); + return () => { + window.removeEventListener('copy', handleCopy); + window.removeEventListener('paste', handlePaste); + }; + }, [pasteClipboardSelection, reactFlow]); + // ── Context menu ──────────────────────────────────────────────────── const onPaneContextMenu = useCallback((event) => { diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 37d1b00..b5b4666 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -1075,7 +1075,13 @@ function CustomNode({ id, data }) { {data.meshData && ( Loading 3D...}> - + )} diff --git a/frontend/src/SurfaceView.jsx b/frontend/src/SurfaceView.jsx index 8247ed7..df8e53e 100644 --- a/frontend/src/SurfaceView.jsx +++ b/frontend/src/SurfaceView.jsx @@ -8,9 +8,13 @@ import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js'; * meshData: { width, height, z_data (b64 float32), colors (b64 uint8 RGB), * z_min, z_max, z_scale, x_range, y_range } */ -export default function SurfaceView({ meshData }) { +export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeValues, onRuntimeValuesChange }) { const containerRef = useRef(null); const threeRef = useRef(null); // { renderer, scene, camera, controls, mesh } + const syncTimerRef = useRef(null); + const lastSnapshotRef = useRef(''); + const lastAnglesRef = useRef({ azimuth: null, polar: null, distance: null }); + const hasSyncedInitialSnapshotRef = useRef(false); // Decode base64 to typed arrays const decode = useCallback((b64, ArrayType) => { @@ -20,6 +24,52 @@ export default function SurfaceView({ meshData }) { return new ArrayType(bytes.buffer); }, []); + const syncViewportState = useCallback((scheduleRun = false) => { + const state = threeRef.current; + if (!state || !nodeId || !onRuntimeValuesChange) return; + const { renderer, controls } = state; + const azimuth = Number(controls.getAzimuthalAngle().toFixed(4)); + const polar = Number(controls.getPolarAngle().toFixed(4)); + const distance = Number(controls.getDistance().toFixed(4)); + const snapshot = renderer.domElement.toDataURL('image/png'); + const previous = lastAnglesRef.current; + const patch = {}; + if (previous.azimuth !== azimuth) patch.camera_azimuth = azimuth; + if (previous.polar !== polar) patch.camera_polar = polar; + if (previous.distance !== distance) patch.camera_distance = distance; + if (snapshot !== lastSnapshotRef.current) patch.viewport_snapshot = snapshot; + if (Object.keys(patch).length > 0) { + onRuntimeValuesChange(nodeId, patch, { scheduleRun }); + lastAnglesRef.current = { azimuth, polar, distance }; + lastSnapshotRef.current = snapshot; + } + }, [nodeId, onRuntimeValuesChange]); + + const scheduleViewportSync = useCallback((delay = 120, scheduleRun = false) => { + if (syncTimerRef.current) { + clearTimeout(syncTimerRef.current); + } + syncTimerRef.current = setTimeout(() => { + syncTimerRef.current = null; + syncViewportState(scheduleRun); + }, delay); + }, [syncViewportState]); + + const applyCameraState = useCallback((azimuth, polar, distance) => { + const state = threeRef.current; + if (!state) return; + const { camera, controls } = state; + const target = controls.target.clone(); + const spherical = new THREE.Spherical( + Math.max(0.3, Number.isFinite(distance) ? distance : 1.8), + THREE.MathUtils.clamp(Number.isFinite(polar) ? polar : 1.1, 0.01, Math.PI - 0.01), + Number.isFinite(azimuth) ? azimuth : 0.0, + ); + const offset = new THREE.Vector3().setFromSpherical(spherical); + camera.position.copy(target).add(offset); + controls.update(); + }, []); + // Initialize Three.js scene once useEffect(() => { const container = containerRef.current; @@ -48,6 +98,8 @@ export default function SurfaceView({ meshData }) { controls.dampingFactor = 0.1; controls.minDistance = 0.3; controls.maxDistance = 10; + const handleControlsEnd = () => scheduleViewportSync(0, true); + controls.addEventListener('end', handleControlsEnd); // Lighting const ambient = new THREE.AmbientLight(0xffffff, 0.4); @@ -69,6 +121,11 @@ export default function SurfaceView({ meshData }) { animate(); threeRef.current = { renderer, scene, camera, controls, mesh: null, animId }; + applyCameraState( + Number(runtimeValues?.camera_azimuth ?? widgetValues?.camera_azimuth), + Number(runtimeValues?.camera_polar ?? widgetValues?.camera_polar), + Number(runtimeValues?.camera_distance ?? widgetValues?.camera_distance), + ); // Resize observer to maintain 1:1 aspect when node width changes const ro = new ResizeObserver((entries) => { @@ -86,6 +143,8 @@ export default function SurfaceView({ meshData }) { return () => { ro.disconnect(); cancelAnimationFrame(animId); + if (syncTimerRef.current) clearTimeout(syncTimerRef.current); + controls.removeEventListener('end', handleControlsEnd); controls.dispose(); renderer.dispose(); if (container.contains(renderer.domElement)) { @@ -93,18 +152,24 @@ export default function SurfaceView({ meshData }) { } threeRef.current = null; }; - }, []); + }, [applyCameraState, scheduleViewportSync]); // Update mesh when data changes useEffect(() => { if (!threeRef.current || !meshData) return; const { scene, camera, controls } = threeRef.current; - const { width: nx, height: ny, z_data, colors, z_min, z_max, z_scale, x_range, y_range } = meshData; + const { + width: nx, height: ny, z_data, colors, z_min, z_max, z_scale, + positions, indices, vertex_colors, camera_azimuth, camera_polar, camera_distance, + } = meshData; // Decode arrays - const zArr = decode(z_data, Float32Array); - const colArr = decode(colors, Uint8Array); + const zArr = z_data ? decode(z_data, Float32Array) : null; + const colArr = colors ? decode(colors, Uint8Array) : null; + const posArr = positions ? decode(positions, Float32Array) : null; + const indexArr = indices ? decode(indices, Uint32Array) : null; + const vertexColorArr = vertex_colors ? decode(vertex_colors, Uint8Array) : null; // Remove old mesh if (threeRef.current.mesh) { @@ -115,45 +180,51 @@ export default function SurfaceView({ meshData }) { // Build geometry const geom = new THREE.BufferGeometry(); - const positions = new Float32Array(nx * ny * 3); - const colorAttr = new Float32Array(nx * ny * 3); + const positionsArray = posArr ?? new Float32Array(nx * ny * 3); + const colorAttr = new Float32Array((vertexColorArr ? vertexColorArr.length : (nx * ny * 3))); - // Normalize coordinates to roughly [-0.5, 0.5] for good camera framing - const zRange = z_max - z_min || 1; + if (!posArr) { + const zRange = z_max - z_min || 1; + for (let iy = 0; iy < ny; iy++) { + for (let ix = 0; ix < nx; ix++) { + const idx = iy * nx + ix; + const px = ix / (nx - 1) - 0.5; + const py = iy / (ny - 1) - 0.5; + const pz = ((zArr[idx] - z_min) / zRange - 0.5) * z_scale; - for (let iy = 0; iy < ny; iy++) { - for (let ix = 0; ix < nx; ix++) { - const idx = iy * nx + ix; - const px = ix / (nx - 1) - 0.5; // [-0.5, 0.5] - const py = iy / (ny - 1) - 0.5; - const pz = ((zArr[idx] - z_min) / zRange - 0.5) * z_scale; - - positions[idx * 3] = px; - positions[idx * 3 + 1] = pz; // height on Y axis - positions[idx * 3 + 2] = py; - - colorAttr[idx * 3] = colArr[idx * 3] / 255; - colorAttr[idx * 3 + 1] = colArr[idx * 3 + 1] / 255; - colorAttr[idx * 3 + 2] = colArr[idx * 3 + 2] / 255; + positionsArray[idx * 3] = px; + positionsArray[idx * 3 + 1] = pz; + positionsArray[idx * 3 + 2] = py; + } } } - geom.setAttribute('position', new THREE.BufferAttribute(positions, 3)); + const sourceColors = vertexColorArr ?? colArr; + if (sourceColors) { + for (let i = 0; i < sourceColors.length; i += 1) { + colorAttr[i] = sourceColors[i] / 255; + } + } + + geom.setAttribute('position', new THREE.BufferAttribute(positionsArray, 3)); geom.setAttribute('color', new THREE.BufferAttribute(colorAttr, 3)); - // Build index (triangles from grid) - const indices = []; - for (let iy = 0; iy < ny - 1; iy++) { - for (let ix = 0; ix < nx - 1; ix++) { - const a = iy * nx + ix; - const b = a + 1; - const c = a + nx; - const d = c + 1; - indices.push(a, c, b); - indices.push(b, c, d); + if (indexArr) { + geom.setIndex(Array.from(indexArr)); + } else { + const gridIndices = []; + for (let iy = 0; iy < ny - 1; iy++) { + for (let ix = 0; ix < nx - 1; ix++) { + const a = iy * nx + ix; + const b = a + 1; + const c = a + nx; + const d = c + 1; + gridIndices.push(a, c, b); + gridIndices.push(b, c, d); + } } + geom.setIndex(gridIndices); } - geom.setIndex(indices); geom.computeVertexNormals(); const mat = new THREE.MeshPhongMaterial({ @@ -169,8 +240,16 @@ export default function SurfaceView({ meshData }) { // Reset camera target to center of mesh controls.target.set(0, 0, 0); - controls.update(); - }, [meshData, decode]); + if (!hasSyncedInitialSnapshotRef.current) { + applyCameraState( + Number.isFinite(camera_azimuth) ? camera_azimuth : Number(runtimeValues?.camera_azimuth ?? widgetValues?.camera_azimuth), + Number.isFinite(camera_polar) ? camera_polar : Number(runtimeValues?.camera_polar ?? widgetValues?.camera_polar), + Number.isFinite(camera_distance) ? camera_distance : Number(runtimeValues?.camera_distance ?? widgetValues?.camera_distance), + ); + hasSyncedInitialSnapshotRef.current = true; + } + scheduleViewportSync(0, false); + }, [meshData, decode, applyCameraState, runtimeValues, scheduleViewportSync, widgetValues]); // Prevent scroll events from propagating to React Flow const onWheel = useCallback((e) => { diff --git a/frontend/src/constants.js b/frontend/src/constants.js index cf398aa..691da4d 100644 --- a/frontend/src/constants.js +++ b/frontend/src/constants.js @@ -2,8 +2,8 @@ export const DATA_TYPES = new Set([ 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', - 'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'COLORMAP', - 'SAVE_LAYER', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR', + 'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'ANNOTATION_SOURCE', 'COLORMAP', + 'SAVE_LAYER', 'SAVE_VALUE', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR', ]); export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']); @@ -22,8 +22,11 @@ export const TYPE_COLORS = { STATS_SOURCE: '#c084fc', CURSOR_SOURCE: '#a78bfa', VALUE_SOURCE: '#60a5fa', + ANNOTATION_SOURCE: '#06b6d4', COLORMAP: '#f472b6', SAVE_LAYER: '#22c55e', + SAVE_VALUE: '#4ade80', + MESH_MODEL: '#14b8a6', FONT: '#fb7185', FILE_PATH: '#f59e0b', DIRECTORY: '#f97316', @@ -44,7 +47,9 @@ export const SOCKET_COMPATIBILITY = { CURSOR_SOURCE: new Set(['DATA_FIELD', 'LINE']), ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']), VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']), + ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']), SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']), + SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']), FLOAT: new Set(['INT']), INT: new Set(['FLOAT']), LINE: new Set(['COORDPAIR']), diff --git a/frontend/src/executionGraph.js b/frontend/src/executionGraph.js index 25e9b2f..1bc0cf7 100644 --- a/frontend/src/executionGraph.js +++ b/frontend/src/executionGraph.js @@ -52,11 +52,12 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f for (const node of nodes) { if (!runnableNodeIds.has(node.id)) continue; - const { className, definition, widgetValues } = node.data; + const { className, definition, widgetValues, runtimeValues } = node.data; if (!definition) continue; if (excludeManualTrigger && definition.manual_trigger) continue; const inputs = {}; + const valueBag = { ...(widgetValues || {}), ...(runtimeValues || {}) }; const allWidgets = { ...(definition.input.required || {}), @@ -66,8 +67,8 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f const [type] = Array.isArray(spec) ? spec : [spec]; if (DATA_TYPES.has(type)) continue; if (type === 'BUTTON') continue; - if (widgetValues[name] !== undefined) { - inputs[name] = widgetValues[name]; + if (valueBag[name] !== undefined) { + inputs[name] = valueBag[name]; } } diff --git a/frontend/src/nodeClipboard.js b/frontend/src/nodeClipboard.js new file mode 100644 index 0000000..b87e533 --- /dev/null +++ b/frontend/src/nodeClipboard.js @@ -0,0 +1,128 @@ +export const NODE_CLIPBOARD_KIND = 'argonode/node-selection'; +export const NODE_CLIPBOARD_MIME = 'application/x-argonode-node-selection'; + +function cloneValue(value) { + if (value == null) return value; + if (typeof structuredClone === 'function') { + try { + return structuredClone(value); + } catch { + // Fall through to JSON clone for simple plain data. + } + } + return JSON.parse(JSON.stringify(value)); +} + +function clonePlainObject(value) { + if (!value || typeof value !== 'object' || Array.isArray(value)) return {}; + return cloneValue(value) || {}; +} + +export function buildNodeClipboardPayload(nodes, edges) { + const selectedNodes = Array.isArray(nodes) ? nodes.filter((node) => node?.selected) : []; + if (selectedNodes.length === 0) return null; + + const selectedIds = new Set(selectedNodes.map((node) => String(node.id))); + const internalEdges = Array.isArray(edges) + ? edges.filter((edge) => selectedIds.has(String(edge.source)) && selectedIds.has(String(edge.target))) + : []; + + return { + kind: NODE_CLIPBOARD_KIND, + version: 1, + nodes: selectedNodes.map((node) => ({ + id: String(node.id), + type: node.type || 'custom', + position: { + x: Number(node.position?.x) || 0, + y: Number(node.position?.y) || 0, + }, + dragHandle: node.dragHandle || '.drag-handle', + data: { + label: node.data?.label || node.data?.className || 'Node', + className: node.data?.className || '', + widgetValues: clonePlainObject(node.data?.widgetValues), + runtimeValues: clonePlainObject(node.data?.runtimeValues), + }, + })), + edges: internalEdges.map((edge) => ({ + source: String(edge.source), + sourceHandle: edge.sourceHandle, + target: String(edge.target), + targetHandle: edge.targetHandle, + ...(edge.style ? { style: { ...edge.style } } : {}), + })), + }; +} + +export function parseNodeClipboardPayload(text) { + if (typeof text !== 'string' || !text.trim()) return null; + + try { + const parsed = JSON.parse(text); + if (parsed?.kind !== NODE_CLIPBOARD_KIND) return null; + if (!Array.isArray(parsed.nodes) || !Array.isArray(parsed.edges)) return null; + return parsed; + } catch { + return null; + } +} + +export function instantiateNodeClipboardPayload(payload, defs = {}, nextNodeId = 1, offset = { x: 40, y: 40 }) { + if (!payload || !Array.isArray(payload.nodes) || payload.nodes.length === 0) { + return { nodes: [], edges: [], nextNodeId }; + } + + const idMap = new Map(); + let currentId = Number(nextNodeId) || 1; + + const nodes = payload.nodes.map((node) => { + const newId = String(currentId++); + idMap.set(String(node.id), newId); + const className = node.data?.className || ''; + const definition = className ? defs[className] || null : null; + + return { + id: newId, + type: node.type || 'custom', + position: { + x: (Number(node.position?.x) || 0) + (Number(offset?.x) || 0), + y: (Number(node.position?.y) || 0) + (Number(offset?.y) || 0), + }, + dragHandle: node.dragHandle || '.drag-handle', + selected: true, + data: { + label: node.data?.label || className || 'Node', + className, + widgetValues: clonePlainObject(node.data?.widgetValues), + runtimeValues: clonePlainObject(node.data?.runtimeValues), + definition, + previewImage: null, + tableRows: null, + meshData: null, + overlay: null, + scalarValue: null, + processingTimeMs: null, + warning: null, + }, + }; + }); + + const edges = payload.edges + .filter((edge) => idMap.has(String(edge.source)) && idMap.has(String(edge.target))) + .map((edge, index) => ({ + id: `e${idMap.get(String(edge.source))}-${idMap.get(String(edge.target))}-${index}`, + source: idMap.get(String(edge.source)), + sourceHandle: edge.sourceHandle, + target: idMap.get(String(edge.target)), + targetHandle: edge.targetHandle, + selected: false, + ...(edge.style ? { style: { ...edge.style } } : {}), + })); + + return { + nodes, + edges, + nextNodeId: currentId, + }; +} diff --git a/frontend/src/pngMetadata.js b/frontend/src/pngMetadata.js index 716d3e7..2e76afa 100644 --- a/frontend/src/pngMetadata.js +++ b/frontend/src/pngMetadata.js @@ -3,7 +3,7 @@ * * PNG files are composed of chunks: [4-byte length][4-byte type][data][4-byte CRC]. * We add an iTXt chunk with key "workflow" containing the JSON-serialised graph, - * inserted just before the IEND chunk. We still read legacy tEXt chunks. + * inserted just before the IEND chunk. */ // ── CRC32 (PNG uses CRC-32/ISO 3309) ──────────────────────────────── @@ -71,10 +71,6 @@ function parseTextChunk(type, chunkData) { const keyword = decoder.decode(chunkData.subarray(0, keywordEnd)); if (keyword !== 'workflow') return null; - if (type === 'tEXt') { - return JSON.parse(decoder.decode(chunkData.subarray(keywordEnd + 1))); - } - if (type !== 'iTXt') return null; const compressionFlagIdx = keywordEnd + 1; @@ -139,7 +135,7 @@ export async function embedWorkflow(pngBlob, workflow) { } /** - * Extract the workflow object from a PNG blob's iTXt/tEXt chunks. + * Extract the workflow object from a PNG blob's iTXt chunks. * Returns the parsed object, or null if no "workflow" key is found. */ export async function extractWorkflow(pngBlob) { @@ -154,7 +150,7 @@ export async function extractWorkflow(pngBlob) { if (pos + 12 + len > data.length) break; const type = chunkType(data, pos); - if (type === 'tEXt' || type === 'iTXt') { + if (type === 'iTXt') { const chunkData = data.subarray(pos + 8, pos + 8 + len); const parsed = parseTextChunk(type, chunkData); if (parsed) found = parsed; diff --git a/frontend/src/workflowHydration.js b/frontend/src/workflowHydration.js index 39691ee..3e97032 100644 --- a/frontend/src/workflowHydration.js +++ b/frontend/src/workflowHydration.js @@ -1,25 +1,7 @@ function mergeDefinition(nodeData, defs) { const savedData = nodeData || {}; - const savedDefinition = savedData.definition && typeof savedData.definition === 'object' - ? savedData.definition - : null; const registryDefinition = savedData.className ? defs[savedData.className] : null; - const definition = registryDefinition || savedDefinition; - - if (!definition) return null; - - const output = Array.isArray(savedData.output) - ? savedData.output - : (Array.isArray(savedDefinition?.output) ? savedDefinition.output : null); - const outputName = Array.isArray(savedData.output_name) - ? savedData.output_name - : (Array.isArray(savedDefinition?.output_name) ? savedDefinition.output_name : null); - - return { - ...definition, - ...(output ? { output } : {}), - ...(outputName ? { output_name: outputName } : {}), - }; + return registryDefinition || null; } function getSocketType(inputDef) { @@ -28,12 +10,6 @@ function getSocketType(inputDef) { return Array.isArray(type) ? type[0] : type; } -function getInputType(definition, inputName) { - const required = definition?.input?.required || {}; - const optional = definition?.input?.optional || {}; - return getSocketType(required[inputName] ?? optional[inputName]); -} - function getInputEntries(definition) { return [ ...Object.entries(definition?.input?.required || {}), @@ -54,31 +30,6 @@ function sanitizeWidgetValues(widgetValues, definition) { return nextValues; } -function remapLegacyHandle(handleId, kind, nodeData) { - if (typeof handleId !== 'string') return handleId; - - const parts = handleId.split('::'); - if (parts.length !== 3 || parts[2] !== 'TABLE') return handleId; - - if (kind === 'source' && parts[0] === 'output') { - const outputSlot = Number.parseInt(parts[1], 10); - const outputType = nodeData?.definition?.output?.[outputSlot]; - if (typeof outputType === 'string' && outputType !== 'TABLE') { - return `output::${outputSlot}::${outputType}`; - } - return handleId; - } - - if (kind === 'target' && parts[0] === 'input') { - const inputType = getInputType(nodeData?.definition, parts[1]); - if (typeof inputType === 'string' && inputType !== 'TABLE') { - return `input::${parts[1]}::${inputType}`; - } - } - - return handleId; -} - export function hydrateWorkflowState(data, defs = {}) { const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : []; const loadedEdges = Array.isArray(data?.edges) ? data.edges : []; @@ -94,6 +45,7 @@ export function hydrateWorkflowState(data, defs = {}) { ...node.data, label: node.data?.label || node.data?.className || 'Node', widgetValues: sanitizeWidgetValues(node.data?.widgetValues, definition), + runtimeValues: {}, definition, previewImage: null, tableRows: null, @@ -104,13 +56,7 @@ export function hydrateWorkflowState(data, defs = {}) { }; }); - const nodeById = new Map(nodes.map((node) => [String(node.id), node.data])); - - const edges = loadedEdges.map((edge) => ({ - ...edge, - sourceHandle: remapLegacyHandle(edge.sourceHandle, 'source', nodeById.get(String(edge.source))), - targetHandle: remapLegacyHandle(edge.targetHandle, 'target', nodeById.get(String(edge.target))), - })); + const edges = loadedEdges.map((edge) => ({ ...edge })); const nextNodeId = Math.max(0, ...loadedNodes.map((node) => parseInt(node.id, 10) || 0)) + 1; diff --git a/frontend/tests/nodeClipboard.test.mjs b/frontend/tests/nodeClipboard.test.mjs new file mode 100644 index 0000000..7ea25b0 --- /dev/null +++ b/frontend/tests/nodeClipboard.test.mjs @@ -0,0 +1,179 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { + buildNodeClipboardPayload, + instantiateNodeClipboardPayload, + NODE_CLIPBOARD_KIND, + parseNodeClipboardPayload, +} from '../src/nodeClipboard.js'; + +test('buildNodeClipboardPayload keeps only selected nodes and internal edges', () => { + const nodes = [ + { + id: '1', + selected: true, + type: 'custom', + position: { x: 10, y: 20 }, + data: { + label: 'Image', + className: 'Image', + widgetValues: { filename: 'scan.ibw' }, + runtimeValues: { layerIndex: 2 }, + }, + }, + { + id: '2', + selected: true, + position: { x: 100, y: 200 }, + data: { + label: 'Preview', + className: 'Preview', + widgetValues: { mode: 'auto' }, + }, + }, + { + id: '3', + selected: false, + position: { x: 500, y: 600 }, + data: { + label: 'Save', + className: 'Save', + }, + }, + ]; + + const edges = [ + { + id: 'e1-2', + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + style: { stroke: '#fff', strokeWidth: 2 }, + }, + { + id: 'e2-3', + source: '2', + sourceHandle: 'output::0::IMAGE', + target: '3', + targetHandle: 'input::value::SAVE_VALUE', + }, + ]; + + const payload = buildNodeClipboardPayload(nodes, edges); + + assert.equal(payload.kind, NODE_CLIPBOARD_KIND); + assert.equal(payload.nodes.length, 2); + assert.deepEqual(payload.nodes.map((node) => node.id), ['1', '2']); + assert.equal(payload.edges.length, 1); + assert.deepEqual(payload.edges[0], { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + style: { stroke: '#fff', strokeWidth: 2 }, + }); + + const reparsed = parseNodeClipboardPayload(JSON.stringify(payload)); + assert.deepEqual(reparsed, payload); +}); + +test('instantiateNodeClipboardPayload remaps ids, offsets positions, and hydrates node shells', () => { + const payload = { + kind: NODE_CLIPBOARD_KIND, + version: 1, + nodes: [ + { + id: '1', + position: { x: 10, y: 20 }, + data: { + label: 'Image', + className: 'Image', + widgetValues: { filename: 'scan.ibw', colormap: 'viridis' }, + runtimeValues: { layerIndex: 1 }, + }, + }, + { + id: '2', + position: { x: 100, y: 200 }, + data: { + label: 'Preview', + className: 'Preview', + widgetValues: { colormap: 'gray' }, + }, + }, + ], + edges: [ + { + source: '1', + sourceHandle: 'output::0::DATA_FIELD', + target: '2', + targetHandle: 'input::field::DATA_FIELD', + style: { stroke: '#abc', strokeWidth: 2 }, + }, + ], + }; + + const defs = { + Image: { output: ['DATA_FIELD'], output_name: ['field'] }, + Preview: { output: ['IMAGE'], output_name: ['preview'] }, + }; + + const instantiated = instantiateNodeClipboardPayload(payload, defs, 12, { x: 32, y: 48 }); + + assert.equal(instantiated.nextNodeId, 14); + assert.deepEqual(instantiated.nodes.map((node) => node.id), ['12', '13']); + assert.deepEqual(instantiated.nodes.map((node) => node.position), [ + { x: 42, y: 68 }, + { x: 132, y: 248 }, + ]); + assert.equal(instantiated.nodes[0].selected, true); + assert.deepEqual(instantiated.nodes[0].data.widgetValues, { filename: 'scan.ibw', colormap: 'viridis' }); + assert.deepEqual(instantiated.nodes[0].data.runtimeValues, { layerIndex: 1 }); + assert.equal(instantiated.nodes[0].data.previewImage, null); + assert.deepEqual(instantiated.nodes[0].data.definition, defs.Image); + + assert.deepEqual(instantiated.edges, [ + { + id: 'e12-13-0', + source: '12', + sourceHandle: 'output::0::DATA_FIELD', + target: '13', + targetHandle: 'input::field::DATA_FIELD', + selected: false, + style: { stroke: '#abc', strokeWidth: 2 }, + }, + ]); +}); + +test('clipboard payload deep-copies local widget and runtime fields', () => { + const nodes = [ + { + id: '9', + selected: true, + position: { x: 0, y: 0 }, + data: { + label: 'Markup', + className: 'Markup', + widgetValues: { + stroke_width: 3, + markup_shapes: [ + { kind: 'line', points: [0.1, 0.2, 0.3, 0.4] }, + ], + }, + runtimeValues: { + camera: { azimuth: 15, polar: 60 }, + }, + }, + }, + ]; + + const payload = buildNodeClipboardPayload(nodes, []); + + nodes[0].data.widgetValues.markup_shapes[0].points[0] = 0.9; + nodes[0].data.runtimeValues.camera.azimuth = 90; + + assert.equal(payload.nodes[0].data.widgetValues.markup_shapes[0].points[0], 0.1); + assert.equal(payload.nodes[0].data.runtimeValues.camera.azimuth, 15); +}); diff --git a/frontend/tests/pngMetadata.test.mjs b/frontend/tests/pngMetadata.test.mjs index 16831fe..cad6065 100644 --- a/frontend/tests/pngMetadata.test.mjs +++ b/frontend/tests/pngMetadata.test.mjs @@ -9,63 +9,6 @@ function makePngBlob() { return new Blob([Buffer.from(PNG_BASE64, 'base64')], { type: 'image/png' }); } -function crc32(bytes) { - const table = new Uint32Array(256); - for (let i = 0; i < 256; i++) { - let c = i; - for (let j = 0; j < 8; j++) { - c = (c & 1) ? (0xEDB88320 ^ (c >>> 1)) : (c >>> 1); - } - table[i] = c; - } - - let crc = 0xFFFFFFFF; - for (let i = 0; i < bytes.length; i++) { - crc = table[(crc ^ bytes[i]) & 0xFF] ^ (crc >>> 8); - } - return (crc ^ 0xFFFFFFFF) >>> 0; -} - -function buildChunk(type, payload) { - const typeBytes = new TextEncoder().encode(type); - const crcInput = new Uint8Array(4 + payload.length); - crcInput.set(typeBytes, 0); - crcInput.set(payload, 4); - - const chunk = new Uint8Array(12 + payload.length); - const view = new DataView(chunk.buffer); - view.setUint32(0, payload.length); - chunk.set(typeBytes, 4); - chunk.set(payload, 8); - view.setUint32(8 + payload.length, crc32(crcInput)); - return chunk; -} - -async function insertTextChunk(blob, workflow) { - const png = new Uint8Array(await blob.arrayBuffer()); - const encoder = new TextEncoder(); - const key = encoder.encode('workflow'); - const text = encoder.encode(JSON.stringify(workflow)); - const payload = new Uint8Array(key.length + 1 + text.length); - payload.set(key, 0); - payload.set(text, key.length + 1); - const chunk = buildChunk('tEXt', payload); - - let pos = 8; - while (pos < png.length) { - const len = new DataView(png.buffer, pos, 4).getUint32(0); - const type = String.fromCharCode(png[pos + 4], png[pos + 5], png[pos + 6], png[pos + 7]); - if (type === 'IEND') break; - pos += 12 + len; - } - - const out = new Uint8Array(png.length + chunk.length); - out.set(png.subarray(0, pos), 0); - out.set(chunk, pos); - out.set(png.subarray(pos), pos + chunk.length); - return new Blob([out], { type: 'image/png' }); -} - test('embedWorkflow roundtrips workflow data through an iTXt chunk', async () => { const workflow = { version: 1, @@ -88,15 +31,6 @@ test('embedWorkflow roundtrips workflow data through an iTXt chunk', async () => assert.match(Buffer.from(bytes).toString('latin1'), /iTXt/); }); -test('extractWorkflow still supports legacy tEXt metadata chunks', async () => { - const workflow = { version: 1, legacy: true, nodes: [], edges: [] }; - const legacyBlob = await insertTextChunk(makePngBlob(), workflow); - - const extracted = await extractWorkflow(legacyBlob); - - assert.deepEqual(extracted, workflow); -}); - test('extractWorkflow returns the last workflow chunk when an image is re-saved', async () => { const first = { version: 1, name: 'old', nodes: [], edges: [] }; const second = { version: 1, name: 'new', nodes: [], edges: [] }; diff --git a/frontend/tests/workflowSerialization.test.mjs b/frontend/tests/workflowSerialization.test.mjs index 0e0e154..0bf80e4 100644 --- a/frontend/tests/workflowSerialization.test.mjs +++ b/frontend/tests/workflowSerialization.test.mjs @@ -95,7 +95,7 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload assert.equal('selected' in serialized.edges[0], false); }); -test('hydrateWorkflowState clears shared path widgets while restoring saved dynamic outputs', () => { +test('hydrateWorkflowState clears shared path widgets and uses registry definitions', () => { const saved = { version: 1, nodes: [ @@ -142,12 +142,12 @@ test('hydrateWorkflowState clears shared path widgets while restoring saved dyna assert.equal(hydrated.nodes[0].data.previewImage, null); assert.equal(hydrated.nodes[0].data.widgetValues.filename, ''); assert.equal(hydrated.nodes[0].data.widgetValues.colormap, 'viridis'); - assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']); - assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']); + assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD']); + assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['field']); assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.Image.input); }); -test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets but preserve other metadata', () => { +test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets without restoring saved outputs', () => { const nodes = [ { id: '7', @@ -188,8 +188,8 @@ test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets bu const hydrated = hydrateWorkflowState(serialized, defs); assert.deepEqual(hydrated.nodes[0].data.widgetValues, { filename: '', colormap: 'gray' }); - assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD']); - assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Topography', 'Error', 'Mask']); + assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD']); + assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['field']); assert.deepEqual(hydrated.edges, edges); }); @@ -223,6 +223,6 @@ test('hydrateWorkflowState clears saved folder selections on shared workflows', const hydrated = hydrateWorkflowState(saved, defs); assert.equal(hydrated.nodes[0].data.widgetValues.folder, ''); - assert.deepEqual(hydrated.nodes[0].data.definition.output, ['PATH', 'PATH']); - assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['scan1.png', 'scan2.png']); + assert.deepEqual(hydrated.nodes[0].data.definition.output, ['PATH']); + assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['path']); }); diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 0c8b6bd..114a17e 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1107,9 +1107,13 @@ def test_annotations(): print("=== Test: Annotations ===") from backend.nodes.annotations import Annotations from backend.nodes.font_node import Font + from backend.data_types import ImageData node = Annotations() font_node = Font() + warnings = [] + Annotations._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) + Annotations._current_node_id = "test" field = DataField( data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64), xreal=1e-6, @@ -1123,7 +1127,7 @@ def test_annotations(): plain_preview = render_datafield_preview(field, "viridis") assert np.array_equal(plain_preview, base) - plain_field, = node.render(field, colormap="auto", show_scale_bar=False, show_color_map=False) + plain_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=False) assert isinstance(plain_field, DataField) assert np.array_equal(plain_field.data, field.data) assert plain_field.colormap == "viridis" @@ -1132,19 +1136,19 @@ def test_annotations(): assert plain.shape == base.shape assert np.array_equal(plain, base) - with_scale_field, = node.render(field, colormap="auto", show_scale_bar=True, show_color_map=False) + with_scale_field, = node.render(input=field, colormap="auto", show_scale_bar=True, show_color_map=False) with_scale = render_datafield_preview(with_scale_field, with_scale_field.colormap) assert with_scale.shape == base.shape assert not np.array_equal(with_scale, base) - with_legend_field, = node.render(field, colormap="auto", show_scale_bar=False, show_color_map=True) + with_legend_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=True) with_legend = render_datafield_preview(with_legend_field, with_legend_field.colormap) assert with_legend.shape[0] == base.shape[0] assert with_legend.shape[1] > base.shape[1] assert with_legend.shape[2] == 3 larger_legend_field, = node.render( - field, + input=field, colormap="auto", show_scale_bar=False, show_color_map=True, @@ -1156,7 +1160,7 @@ def test_annotations(): annotation_font, = font_node.build("Arial") with_font_field, = node.render( - field, + input=field, colormap="auto", show_scale_bar=False, show_color_map=True, @@ -1167,18 +1171,62 @@ def test_annotations(): with_font = render_datafield_preview(with_font_field, with_font_field.colormap) assert with_font.shape == with_legend.shape - with_both_field, = node.render(field, colormap="auto", show_scale_bar=True, show_color_map=True) + with_both_field, = node.render(input=field, colormap="auto", show_scale_bar=True, show_color_map=True) with_both = render_datafield_preview(with_both_field, with_both_field.colormap) assert with_both.shape == with_legend.shape assert not np.array_equal(with_both[:, :base.shape[1]], base) + viewport_image = ImageData( + np.zeros((48, 64, 3), dtype=np.uint8), + metadata={ + "annotation_context": { + "xreal": 2e-6, + "si_unit_xy": "m", + "legend_min": -1.5, + "legend_mid": 0.0, + "legend_max": 1.5, + "legend_unit": "V", + "colormap": "viridis", + }, + }, + ) + annotated_image, = node.render( + input=viewport_image, + colormap="auto", + show_scale_bar=True, + show_color_map=True, + text_size=18.0, + ) + assert isinstance(annotated_image, ImageData) + assert annotated_image.shape[0] == viewport_image.shape[0] + assert annotated_image.shape[1] > viewport_image.shape[1] + assert annotated_image.metadata["annotation_context"]["legend_unit"] == "V" + assert not np.array_equal(np.asarray(annotated_image)[:, :viewport_image.shape[1]], np.asarray(viewport_image)) + assert warnings == [] + + plain_image = ImageData(np.zeros((32, 40, 3), dtype=np.uint8)) + passthrough_image, = node.render( + input=plain_image, + colormap="auto", + show_scale_bar=True, + show_color_map=True, + text_size=18.0, + ) + assert isinstance(passthrough_image, ImageData) + assert passthrough_image.shape == plain_image.shape + assert np.array_equal(np.asarray(passthrough_image), np.asarray(plain_image)) + assert len(warnings) == 1 + assert "no scale metadata" in warnings[0] + + Annotations._broadcast_warning_fn = None + print(" PASS\n") def test_markup(): print("=== Test: Markup ===") from backend.nodes.markup import Markup - from backend.data_types import _preview_markup_stroke_width + from backend.data_types import ImageData, _preview_markup_stroke_width node = Markup() field = make_field(data=np.linspace(0.0, 1.0, 48 * 48, dtype=np.float64).reshape(48, 48)) @@ -1192,7 +1240,7 @@ def test_markup(): Markup._current_node_id = "test" plain_field, = node.process( - field=field, + input=field, shape="line", stroke_color="#ffd54f", stroke_width=3, @@ -1212,7 +1260,7 @@ def test_markup(): {"kind": "arrow", "x1": 0.15, "y1": 0.85, "x2": 0.85, "y2": 0.2, "width": 4, "color": "#ffffff"}, ]) marked_field, = node.process( - field=field, + input=field, shape="arrow", stroke_color="#ffffff", stroke_width=4, @@ -1222,6 +1270,23 @@ def test_markup(): assert marked.shape == base.shape assert not np.array_equal(marked, base) + viewport_image = ImageData( + np.zeros((48, 48, 3), dtype=np.uint8), + metadata={"annotation_context": {"xreal": 1e-6, "si_unit_xy": "m"}}, + ) + image_markup, = node.process( + input=viewport_image, + shape="line", + stroke_color="#ff0000", + stroke_width=4, + markup_shapes=json.dumps([ + {"kind": "line", "x1": 0.1, "y1": 0.2, "x2": 0.9, "y2": 0.8, "width": 4, "color": "#ff0000"}, + ]), + ) + assert isinstance(image_markup, ImageData) + assert image_markup.metadata["annotation_context"]["si_unit_xy"] == "m" + assert not np.array_equal(np.asarray(image_markup), np.asarray(viewport_image)) + Markup._broadcast_overlay_fn = None print(" PASS\n") @@ -1326,6 +1391,36 @@ def test_load_file_npz(): print(" PASS\n") +def test_load_file_cache(): + print("=== Test: Image cache ===") + from unittest.mock import patch + from backend.nodes.image import Image + + node = Image() + Image._load_fields_cached.cache_clear() + + with tempfile.TemporaryDirectory() as tmpdir: + data = np.arange(16, dtype=np.float64).reshape(4, 4) + path = os.path.join(tmpdir, "cached.npy") + np.save(path, data) + + with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: + first, = node.load(filename=path) + second, = node.load(filename=path) + assert loader.call_count == 1 + + assert np.allclose(first.data, data) + assert np.allclose(second.data, data) + assert first is not second + first.data[0, 0] = -999.0 + + third, = node.load(filename=path) + assert third.data[0, 0] == data[0, 0] + + Image._load_fields_cached.cache_clear() + print(" PASS\n") + + def test_load_file_not_found(): print("=== Test: Image not found ===") from backend.nodes.image import Image @@ -1487,6 +1582,31 @@ def test_load_demo(): print(" PASS\n") +def test_load_demo_cache(): + print("=== Test: ImageDemo cache ===") + from unittest.mock import patch + from backend.nodes.image import Image + from backend.nodes.image_demo import ImageDemo + + node = ImageDemo() + Image._load_fields_cached.cache_clear() + + with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: + first, = node.load(name="nanoparticles.npy") + second, = node.load(name="nanoparticles.npy") + assert loader.call_count == 1 + + assert np.allclose(first.data, second.data) + assert first is not second + first.data[0, 0] = -999.0 + + third, = node.load(name="nanoparticles.npy") + assert third.data[0, 0] != -999.0 + + Image._load_fields_cached.cache_clear() + print(" PASS\n") + + def test_load_demo_multi_layer_preview_payload(): print("=== Test: ImageDemo multi-layer preview payload ===") from backend.execution import ExecutionEngine @@ -1849,6 +1969,10 @@ def test_stats(): def test_view3d(): print("=== Test: View3D ===") from backend.nodes.view_3d import View3D + from backend.data_types import ImageData, MeshModel + import base64 + import io + from PIL import Image node = View3D() field = make_field() @@ -1857,8 +1981,25 @@ def test_view3d(): View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh) View3D._current_node_id = "test" - result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64) - assert result == () + preview_image = Image.new("RGB", (12, 10), (255, 0, 0)) + preview_buffer = io.BytesIO() + preview_image.save(preview_buffer, format="PNG") + viewport_snapshot = "data:image/png;base64," + base64.b64encode(preview_buffer.getvalue()).decode() + + result = node.render( + field, + colormap="viridis", + z_scale=2.0, + resolution=64, + make_solid=False, + viewport_snapshot=viewport_snapshot, + ) + assert len(result) == 2 + assert isinstance(result[0], MeshModel) + assert isinstance(result[1], ImageData) + assert result[1].shape == (10, 12, 3) + assert np.all(result[1][0, 0] == np.array([255, 0, 0], dtype=np.uint8)) + assert result[1].metadata["annotation_context"]["si_unit_xy"] == field.si_unit_xy assert len(captured) == 1 mesh = captured[0] @@ -1873,7 +2014,6 @@ def test_view3d(): assert mesh["z_min"] < mesh["z_max"] # Verify base64 data can be decoded - import base64 z_bytes = base64.b64decode(mesh["z_data"]) assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32 @@ -1883,14 +2023,178 @@ def test_view3d(): # High-res input should be downsampled big_field = make_field(shape=(256, 256)) captured.clear() - node.render(big_field, colormap="hot", z_scale=1.0, resolution=64) + node.render(big_field, colormap="hot", z_scale=1.0, resolution=64, make_solid=False) assert captured[0]["width"] <= 64 assert captured[0]["height"] <= 64 + # Separate map input should affect colors without changing mesh geometry + mesh_field = make_field(data=np.zeros((64, 64), dtype=np.float64), xreal=2.0, yreal=3.0) + map_field = make_field(data=np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float64), (64, 1)), xreal=2.0, yreal=3.0) + captured.clear() + mapped_result = node.render(mesh_field, map_field=map_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) + mapped_mesh = captured[0] + assert mapped_mesh["x_range"] == [float(mesh_field.xoff), float(mesh_field.xoff + mesh_field.xreal)] + assert mapped_mesh["y_range"] == [float(mesh_field.yoff), float(mesh_field.yoff + mesh_field.yreal)] + mapped_z = np.frombuffer(base64.b64decode(mapped_mesh["z_data"]), dtype=np.float32) + assert np.allclose(mapped_z, 0.0) + mapped_colors = np.frombuffer(base64.b64decode(mapped_mesh["colors"]), dtype=np.uint8) + + captured.clear() + node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) + mesh_only = captured[0] + mesh_only_colors = np.frombuffer(base64.b64decode(mesh_only["colors"]), dtype=np.uint8) + assert not np.array_equal(mapped_colors, mesh_only_colors) + + # make_solid should add extra geometry beyond the top surface grid + solid_mesh = mapped_result[0] + assert isinstance(solid_mesh, MeshModel) + captured.clear() + solid_result = node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=16, make_solid=True) + assert len(solid_result[0].vertices) > 16 * 16 + assert len(solid_result[0].faces) > (15 * 15 * 2) + solid_payload = captured[0] + assert solid_payload["make_solid"] is True + assert "positions" in solid_payload + assert "indices" in solid_payload + assert "vertex_colors" in solid_payload + View3D._broadcast_mesh_fn = None print(" PASS\n") +def test_save_generic(): + print("=== Test: Save ===") + from backend.nodes.save import Save + from backend.data_types import DataField, LineData, MeasureTable, MeshModel, RecordTable + import tifffile + from PIL import Image as PILImage + + node = Save() + + with tempfile.TemporaryDirectory() as tmpdir: + # Save scalar as TXT and JSON + node.save(filename="scalar", directory_path=tmpdir, format="TXT", value=3.5) + assert Path(tmpdir, "scalar.txt").read_text(encoding="utf-8").strip() == "3.5" + node.save(filename="scalar_json", directory_path=tmpdir, format="JSON", value=3.5) + assert json.loads(Path(tmpdir, "scalar_json.json").read_text(encoding="utf-8")) == {"value": 3.5} + + # Save line as CSV, NPZ, and JSON + line = LineData(data=np.array([1.0, 2.0, 3.0]), x_axis=np.array([0.0, 0.5, 1.0]), x_unit="um", y_unit="nm") + node.save(filename="profile", directory_path=tmpdir, format="CSV", value=line) + csv_text = Path(tmpdir, "profile.csv").read_text(encoding="utf-8") + assert "x,y,x_unit,y_unit" in csv_text + assert "um" in csv_text and "nm" in csv_text + node.save(filename="profile_npz", directory_path=tmpdir, format="NPZ", value=line) + line_npz = np.load(Path(tmpdir, "profile_npz.npz")) + assert np.allclose(line_npz["x"], line.x_axis) + assert np.allclose(line_npz["y"], line.data) + node.save(filename="profile_json", directory_path=tmpdir, format="JSON", value=line) + line_json = json.loads(Path(tmpdir, "profile_json.json").read_text(encoding="utf-8")) + assert line_json["x_unit"] == "um" + assert line_json["y_unit"] == "nm" + assert line_json["x"] == [0.0, 0.5, 1.0] + assert line_json["y"] == [1.0, 2.0, 3.0] + + # Save DATA_FIELD as TIFF, PNG, and NPZ + field = DataField( + data=np.array([[1.0, 2.0], [3.0, 4.5]], dtype=np.float64), + xreal=2e-6, + yreal=1e-6, + si_unit_xy="m", + si_unit_z="m", + colormap="viridis", + ) + node.save(filename="field_tiff", directory_path=tmpdir, format="TIFF", value=field) + field_tiff = tifffile.imread(Path(tmpdir, "field_tiff.tiff")) + assert field_tiff.shape == field.data.shape + assert field_tiff.dtype == np.float32 + assert np.allclose(field_tiff, field.data.astype(np.float32)) + + node.save(filename="field_png", directory_path=tmpdir, format="PNG", value=field) + field_png = np.asarray(PILImage.open(Path(tmpdir, "field_png.png"))) + assert field_png.shape == (2, 2, 3) + assert field_png.dtype == np.uint8 + + node.save(filename="field_npz", directory_path=tmpdir, format="NPZ", value=field) + field_npz = np.load(Path(tmpdir, "field_npz.npz")) + assert np.allclose(field_npz["field"], field.data) + + # Save IMAGE as PNG, TIFF, and NPZ + image = np.array( + [ + [[255, 0, 0], [0, 255, 0]], + [[0, 0, 255], [255, 255, 0]], + ], + dtype=np.uint8, + ) + node.save(filename="image_png", directory_path=tmpdir, format="PNG", value=image) + image_png = np.asarray(PILImage.open(Path(tmpdir, "image_png.png"))) + assert image_png.shape == image.shape + assert np.array_equal(image_png, image) + + node.save(filename="image_tiff", directory_path=tmpdir, format="TIFF", value=image) + image_tiff = tifffile.imread(Path(tmpdir, "image_tiff.tiff")) + assert image_tiff.shape == image.shape + assert image_tiff.dtype == np.uint8 + assert np.array_equal(image_tiff, image) + + node.save(filename="image_npz", directory_path=tmpdir, format="NPZ", value=image) + image_npz = np.load(Path(tmpdir, "image_npz.npz")) + assert np.array_equal(image_npz["image"], image) + + # Save tables as CSV and JSON + measure_table = MeasureTable([ + {"quantity": "Rq", "value": 1.23, "unit": "nm"}, + {"quantity": "Ra", "value": 0.98, "unit": "nm"}, + ]) + node.save(filename="measurements_csv", directory_path=tmpdir, format="CSV", value=measure_table) + measure_csv = Path(tmpdir, "measurements_csv.csv").read_text(encoding="utf-8") + assert "quantity,value,unit" in measure_csv + assert "Rq,1.23,nm" in measure_csv + node.save(filename="measurements_json", directory_path=tmpdir, format="JSON", value=measure_table) + assert json.loads(Path(tmpdir, "measurements_json.json").read_text(encoding="utf-8")) == list(measure_table) + + record_table = RecordTable([ + {"label": "particle-1", "height": 12.0, "area": 44.0}, + {"label": "particle-2", "height": 8.0, "area": 21.0}, + ]) + node.save(filename="records_csv", directory_path=tmpdir, format="CSV", value=record_table) + record_csv = Path(tmpdir, "records_csv.csv").read_text(encoding="utf-8") + assert "label,height,area" in record_csv + assert "particle-1,12.0,44.0" in record_csv + node.save(filename="records_json", directory_path=tmpdir, format="JSON", value=record_table) + assert json.loads(Path(tmpdir, "records_json.json").read_text(encoding="utf-8")) == list(record_table) + + # Save mesh as OBJ and STL + mesh = MeshModel( + vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32), + faces=np.array([[0, 1, 2]], dtype=np.int32), + ) + node.save(filename="triangle", directory_path=tmpdir, format="OBJ", value=mesh) + obj_text = Path(tmpdir, "triangle.obj").read_text(encoding="utf-8") + assert "v 0.0 0.0 0.0" in obj_text + assert "f 1 2 3" in obj_text + + node.save(filename="triangle", directory_path=tmpdir, format="STL", value=mesh) + stl_text = Path(tmpdir, "triangle.stl").read_text(encoding="utf-8") + assert stl_text.startswith("solid argonode") + assert "facet normal" in stl_text + + try: + node.save(filename="triangle", directory_path=tmpdir, format="PNG", value=mesh) + assert False, "Mesh should only be saveable as OBJ or STL" + except ValueError: + pass + + try: + node.save(filename="field_bad", directory_path=tmpdir, format="CSV", value=field) + assert False, "DATA_FIELD should reject unsupported save formats" + except ValueError: + pass + + print(" PASS\n") + + # ========================================================================= # Run all tests # ========================================================================= @@ -1940,6 +2244,7 @@ if __name__ == "__main__": test_load_demo() test_coordinate() test_range_slider() + test_save_generic() test_save_image() # Display diff --git a/tests/test_numpy_compat.py b/tests/test_numpy_compat.py deleted file mode 100644 index a13faa1..0000000 --- a/tests/test_numpy_compat.py +++ /dev/null @@ -1,8 +0,0 @@ -import backend # noqa: F401 -import numpy as np - - -def test_numpy_compat_aliases_are_available_after_backend_import(): - assert np.complex is complex - assert np.float is float - assert np.int is int