diff --git a/backend/execution.py b/backend/execution.py index 223b136..fcc331b 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -218,11 +218,25 @@ class ExecutionEngine: on_warning: Callable | None = None, ) -> None: """Wire up broadcast callbacks on display node classes.""" - from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup - from backend.nodes.analysis import CrossSection, Cursors, Stats, Histogram - from backend.nodes.modify import CropResizeField, RotateField - from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import SaveImage, Image, ImageDemo + from backend.nodes.preview_image import PreviewImage + from backend.nodes.print_table import PrintTable + from backend.nodes.view_3d import View3D + from backend.nodes.value_display import ValueDisplay + from backend.nodes.markup import Markup + from backend.nodes.cross_section import CrossSection + from backend.nodes.cursors import Cursors + from backend.nodes.stats import Stats + from backend.nodes.histogram import Histogram + from backend.nodes.crop_resize_field import CropResizeField + from backend.nodes.rotate_field import RotateField + from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_morphology import MaskMorphology + from backend.nodes.mask_invert import MaskInvert + from backend.nodes.mask_combine import MaskCombine + from backend.nodes.draw_mask import DrawMask + from backend.nodes.save_image import SaveImage + from backend.nodes.image import Image + from backend.nodes.image_demo import ImageDemo PreviewImage._broadcast_fn = on_preview ThresholdMask._broadcast_fn = on_preview @@ -246,11 +260,25 @@ class ExecutionEngine: def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" - from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup - from backend.nodes.analysis import CrossSection, Cursors, Stats, Histogram - from backend.nodes.modify import CropResizeField, RotateField - from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import Image, ImageDemo, SaveImage + from backend.nodes.preview_image import PreviewImage + from backend.nodes.print_table import PrintTable + from backend.nodes.view_3d import View3D + from backend.nodes.value_display import ValueDisplay + from backend.nodes.markup import Markup + from backend.nodes.cross_section import CrossSection + from backend.nodes.cursors import Cursors + from backend.nodes.stats import Stats + from backend.nodes.histogram import Histogram + from backend.nodes.crop_resize_field import CropResizeField + from backend.nodes.rotate_field import RotateField + from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_morphology import MaskMorphology + from backend.nodes.mask_invert import MaskInvert + from backend.nodes.mask_combine import MaskCombine + from backend.nodes.draw_mask import DrawMask + from backend.nodes.image import Image + from backend.nodes.image_demo import ImageDemo + from backend.nodes.save_image import SaveImage if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, Image, ImageDemo, SaveImage): @@ -274,7 +302,8 @@ class ExecutionEngine: from backend.data_types import ( DataField, LineData, image_to_uint8, encode_preview, render_datafield_preview, ) - from backend.nodes.io import Image, ImageDemo + from backend.nodes.image import Image + from backend.nodes.image_demo import ImageDemo if getattr(cls, "_CUSTOM_PREVIEW", False): return @@ -318,7 +347,7 @@ class ExecutionEngine: inputs: dict[str, Any], ) -> dict | None: from backend.data_types import DataField, encode_preview, render_datafield_preview - from backend.nodes.io import list_channels + from backend.nodes.helpers import list_channels fields = [value for value in result if isinstance(value, DataField)] if not fields: diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index d2fc2ff..1b88dcf 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -1,7 +1,54 @@ # Import all node modules to trigger @register_node decorators. -from . import io, filters, modify, level, analysis, mask, display +from backend.nodes import ( + # IO + image, + image_demo, + folder, + coordinate, + coordinate_pair, + number, + range_slider, + save_image, + # Filters + gaussian_filter, + median_filter, + edge_detect, + fft_filter_1d, + fft_filter_2d, + # Modify + colormap_adjust, + crop_resize_field, + rotate_field, + # Level + plane_level_field, + poly_level_field, + fix_zero, + # Mask + draw_mask, + threshold_mask, + mask_morphology, + mask_invert, + mask_combine, + # Display + color_map, + font_node, + annotations, + markup, + preview_image, + view_3d, + print_table, + value_display, + # Analysis + statistics_node, + histogram, + cursors, + fft_2d, + inverse_fft_2d, + cross_section, + stats, +) try: - from . import particle + from backend.nodes import particle_analysis except ImportError: - from . import particless + pass diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py deleted file mode 100644 index 839592f..0000000 --- a/backend/nodes/analysis.py +++ /dev/null @@ -1,1091 +0,0 @@ -""" -Analysis nodes — statistics, histograms, FFT, cross sections. - -Gwyddion equivalents: - Statistics → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h) - Histogram → DH (height distribution), gwy_data_field_dh - FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf - CrossSection → gwy_data_field_get_profile (libprocess/datafield.c) -""" - -from __future__ import annotations -import numpy as np -from typing import Callable -from backend.node_registry import register_node -from backend.data_types import DataField, LineData, MeasureTable, RecordTable, datafield_to_uint8, encode_preview, render_datafield_preview -from backend.nodes.io import Coordinate, CoordinatePair - - -# --------------------------------------------------------------------------- -# Statistics -# --------------------------------------------------------------------------- - -@register_node(display_name="Statistics") -class Statistics: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - } - } - - RETURN_TYPES = ("MEASURE_TABLE",) - RETURN_NAMES = ("stats",) - FUNCTION = "process" - - DESCRIPTION = ( - "Compute basic surface statistics: min, max, mean, RMS roughness, median, " - "and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms." - ) - - def process(self, field: DataField) -> tuple: - d = field.data - mean = float(d.mean()) - rms = float(np.sqrt(np.mean((d - mean) ** 2))) - skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0 - kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0 - - table = MeasureTable([ - {"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z}, - {"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z}, - {"quantity": "mean", "value": mean, "unit": field.si_unit_z}, - {"quantity": "RMS", "value": rms, "unit": field.si_unit_z}, - {"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z}, - {"quantity": "skewness", "value": skewness, "unit": ""}, - {"quantity": "kurtosis", "value": kurtosis, "unit": ""}, - {"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z}, - ]) - return (table,) - - -# --------------------------------------------------------------------------- -# Histogram -# --------------------------------------------------------------------------- - -@register_node(display_name="Histogram") -class Histogram: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}), - "y_scale": (["linear", "log"],), - "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - } - } - - RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",) - RETURN_NAMES = ("measurements", "marker pair",) - FUNCTION = "process" - - DESCRIPTION = ( - "Compute the height distribution histogram (DH). " - "Use log scale to reveal small peaks next to a dominant background. " - "Outputs marker measurements while showing the histogram interactively in-node. " - "Equivalent to gwy_data_field_dh." - ) - - _broadcast_overlay_fn = None - _current_node_id: str = "" - - def process( - self, - field: DataField, - n_bins: int, - y_scale: str = "linear", - x1: float = 0.25, - y1: float = 0.5, - x2: float = 0.75, - y2: float = 0.5, - ) -> tuple: - raw_counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins)) - bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) - counts = raw_counts.astype(np.float64) - if y_scale == "log": - counts = np.log10(1.0 + counts) - - x1 = float(np.clip(x1, 0.0, 1.0)) - x2 = float(np.clip(x2, 0.0, 0.0 + 1.0)) - - xmin = float(np.min(bin_centers)) if len(bin_centers) else 0.0 - xmax = float(np.max(bin_centers)) if len(bin_centers) else 1.0 - - def x_frac_to_idx(frac): - if len(bin_centers) <= 1: - return 0 - if xmax == xmin: - return 0 - target_x = xmin + frac * (xmax - xmin) - return int(np.argmin(np.abs(bin_centers - target_x))) - - idx_a = x_frac_to_idx(x1) - idx_b = x_frac_to_idx(x2) - xa = float(bin_centers[idx_a]) if len(bin_centers) else 0.0 - xb = float(bin_centers[idx_b]) if len(bin_centers) else 0.0 - ya = float(counts[idx_a]) if len(counts) else 0.0 - yb = float(counts[idx_b]) if len(counts) else 0.0 - count_unit = "count" if y_scale == "linear" else "log10(1+count)" - - if Histogram._broadcast_overlay_fn is not None: - Histogram._broadcast_overlay_fn( - Histogram._current_node_id, - { - "kind": "line_plot", - "section_title": "Histogram", - "line": counts.tolist(), - "x_axis": bin_centers.astype(np.float64).tolist(), - "x1": float(np.clip(x1, 0.0, 1.0)), - "x2": float(np.clip(x2, 0.0, 1.0)), - "y1": float(y1), - "y2": float(y2), - "a_locked": False, - "b_locked": False, - }, - ) - - table = MeasureTable([ - {"quantity": "A position", "value": xa, "unit": field.si_unit_z}, - {"quantity": "A count", "value": ya, "unit": count_unit}, - {"quantity": "B position", "value": xb, "unit": field.si_unit_z}, - {"quantity": "B count", "value": yb, "unit": count_unit}, - {"quantity": "delta X", "value": xb - xa, "unit": field.si_unit_z}, - {"quantity": "delta Y", "value": yb - ya, "unit": count_unit}, - ]) - return (table, ((x1, y1), (x2, y2))) - - -# --------------------------------------------------------------------------- -# Cursors — interactive measurement cursors on lines or fields -# --------------------------------------------------------------------------- - -@register_node(display_name="Cursors") -class Cursors: - """Place two draggable cursors on a line plot or field to measure deltas.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "line": ("CURSOR_SOURCE", {"label": "input"}), - "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - }, - "optional": { - "coord_pair": ("COORDPAIR", {"label": "coord pair"}), - }, - } - - RETURN_TYPES = ("MEASURE_TABLE","COORDPAIR",) - RETURN_NAMES = ("measurement","coord pair",) - FUNCTION = "process" - - DESCRIPTION = ( - "Place two cursors on a line plot or 2D field. " - "On lines it reports x/y positions and dx/dy. " - "On fields it reports x/y/z at both markers plus dx/dy/dz." - ) - - _broadcast_overlay_fn = None - _current_node_id: str = "" - - def process( - self, line, x1: float, y1: float, x2: float, y2: float, - coord_pair=None, - ) -> tuple: - if coord_pair is not None: - (x1, y1), (x2, y2) = coord_pair - - locked = coord_pair is not None - - if isinstance(line, DataField): - return self._process_field(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked) - - return self._process_line(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked) - - def _process_line( - self, - line, - x1: float, - y1: float, - x2: float, - y2: float, - locked: bool = False, - ) -> tuple: - y = np.asarray(line, dtype=np.float64).ravel() - x_unit = line.x_unit if isinstance(line, LineData) else "" - y_unit = line.y_unit if isinstance(line, LineData) else "" - n = len(y) - if isinstance(line, LineData) and line.x_axis is not None: - x = np.asarray(line.x_axis, dtype=np.float64).ravel()[:n] - else: - x = np.arange(n, dtype=np.float64) - x1 = float(np.clip(x1, 0.0, 1.0)) - x2 = float(np.clip(x2, 0.0, 1.0)) - - xmin = float(np.min(x)) if len(x) else 0.0 - xmax = float(np.max(x)) if len(x) else 1.0 - - def x_frac_to_idx(frac): - if n <= 1: - return 0 - if xmax == xmin: - return 0 - target_x = xmin + frac * (xmax - xmin) - return int(np.argmin(np.abs(x - target_x))) - - idx_a = x_frac_to_idx(x1) - idx_b = x_frac_to_idx(x2) - - xa, ya = float(x[idx_a]), float(y[idx_a]) - xb, yb = float(x[idx_b]), float(y[idx_b]) - - # --- Broadcast overlay --- - if Cursors._broadcast_overlay_fn is not None: - Cursors._broadcast_overlay_fn( - Cursors._current_node_id, - { - "kind": "line_plot", - "section_title": "Cursors", - "line": y.tolist(), - "x_axis": x.tolist(), - "x1": x1, - "x2": x2, - "y1": float(y1), - "y2": float(y2), - "a_locked": locked, - "b_locked": locked, - }, - ) - - # --- Output table --- - table = MeasureTable([ - {"quantity": "A x", "value": xa, "unit": x_unit}, - {"quantity": "A y", "value": ya, "unit": y_unit}, - {"quantity": "B x", "value": xb, "unit": x_unit}, - {"quantity": "B y", "value": yb, "unit": y_unit}, - {"quantity": "dx", "value": xb - xa, "unit": x_unit}, - {"quantity": "dy", "value": yb - ya, "unit": y_unit}, - ]) - return (table, ((x1, y1), (x2, y2))) - - def _process_field( - self, - field: DataField, - x1: float, - y1: float, - x2: float, - y2: float, - locked: bool = False, - ) -> tuple: - from scipy.ndimage import map_coordinates - - x1 = float(np.clip(x1, 0.0, 1.0)) - y1 = float(np.clip(y1, 0.0, 1.0)) - x2 = float(np.clip(x2, 0.0, 1.0)) - y2 = float(np.clip(y2, 0.0, 1.0)) - - px1 = x1 * max(field.xres - 1, 0) - py1 = y1 * max(field.yres - 1, 0) - px2 = x2 * max(field.xres - 1, 0) - py2 = y2 * max(field.yres - 1, 0) - - z1 = float(map_coordinates(field.data, [[py1], [px1]], order=1, mode="nearest")[0]) - z2 = float(map_coordinates(field.data, [[py2], [px2]], order=1, mode="nearest")[0]) - - ax = float(field.xoff + x1 * field.xreal) - ay = float(field.yoff + y1 * field.yreal) - bx = float(field.xoff + x2 * field.xreal) - by = float(field.yoff + y2 * field.yreal) - - if Cursors._broadcast_overlay_fn is not None: - Cursors._broadcast_overlay_fn( - Cursors._current_node_id, - { - "kind": "cursor_points", - "section_title": "Cursors", - "image": encode_preview(render_datafield_preview(field, field.colormap)), - "x1": x1, - "y1": y1, - "x2": x2, - "y2": y2, - "a_locked": locked, - "b_locked": locked, - }, - ) - - table = MeasureTable([ - {"quantity": "A x", "value": ax, "unit": field.si_unit_xy}, - {"quantity": "A y", "value": ay, "unit": field.si_unit_xy}, - {"quantity": "A z", "value": z1, "unit": field.si_unit_z}, - {"quantity": "B x", "value": bx, "unit": field.si_unit_xy}, - {"quantity": "B y", "value": by, "unit": field.si_unit_xy}, - {"quantity": "B z", "value": z2, "unit": field.si_unit_z}, - {"quantity": "dx", "value": bx - ax, "unit": field.si_unit_xy}, - {"quantity": "dy", "value": by - ay, "unit": field.si_unit_xy}, - {"quantity": "dz", "value": z2 - z1, "unit": field.si_unit_z}, - ]) - return (table, ((x1, y1), (x2, y2))) - - -# --------------------------------------------------------------------------- -# FFT2D -# --------------------------------------------------------------------------- - -@register_node(display_name="2D FFT") -class FFT2D: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "windowing": (["hann", "hamming", "blackman", "none"],), - "level": (["mean", "plane", "none"],), - } - } - - RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD") - RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf") - FUNCTION = "process" - - DESCRIPTION = ( - "Compute the 2D FFT with optional windowing and mean/plane subtraction. " - "Outputs log magnitude, magnitude, phase, and PSDF as separate channels. " - "Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf." - ) - - def process(self, field: DataField, windowing: str, level: str) -> tuple: - data = field.data.copy() - yres, xres = data.shape - - # Level subtraction (Gwyddion-style, before windowing) - if level == "mean": - data -= data.mean() - elif level == "plane": - # Fit and subtract a plane: z = a + b*x + c*y - yy, xx = np.mgrid[0:yres, 0:xres] - xx_f = xx.ravel().astype(np.float64) - yy_f = yy.ravel().astype(np.float64) - zz_f = data.ravel() - A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f]) - coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None) - plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy) - data -= plane - - # Windowing (Gwyddion uses (i+0.5)/n centred formulation) - if windowing != "none": - t_y = (np.arange(yres) + 0.5) / yres - t_x = (np.arange(xres) + 0.5) / xres - if windowing == "hann": - wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y) - wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x) - elif windowing == "hamming": - wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y) - wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x) - elif windowing == "blackman": - wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y) - wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x) - else: - wy = np.ones(yres) - wx = np.ones(xres) - data *= np.outer(wy, wx) - - # 2D FFT, shifted so DC is at centre - F = np.fft.fftshift(np.fft.fft2(data)) - n = xres * yres - - magnitude = np.abs(F) - log_magnitude = np.log1p(magnitude) - phase = np.angle(F) - - dx = field.xreal / xres - dy = field.yreal / yres - psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2) - - spatial_freq_xreal = xres / field.xreal - spatial_freq_yreal = yres / field.yreal - angular_freq_xreal = 2.0 * np.pi * xres / field.xreal - angular_freq_yreal = 2.0 * np.pi * yres / field.yreal - - return ( - DataField( - data=log_magnitude, - xreal=spatial_freq_xreal, - yreal=spatial_freq_yreal, - si_unit_xy="1/m", - si_unit_z=field.si_unit_z, - domain="frequency", - colormap=field.colormap, - ), - DataField( - data=magnitude, - xreal=spatial_freq_xreal, - yreal=spatial_freq_yreal, - si_unit_xy="1/m", - si_unit_z=field.si_unit_z, - domain="frequency", - colormap=field.colormap, - ), - DataField( - data=phase, - xreal=spatial_freq_xreal, - yreal=spatial_freq_yreal, - si_unit_xy="1/m", - si_unit_z=field.si_unit_z, - domain="frequency", - colormap=field.colormap, - ), - DataField( - data=psdf, - xreal=angular_freq_xreal, - yreal=angular_freq_yreal, - si_unit_xy="1/m", - si_unit_z=f"({field.si_unit_z})^2 m^2", - domain="frequency", - colormap=field.colormap, - ), - ) - - if False: # Unreachable legacy block retained below. - # Log scale with floor to avoid log(0) - result = np.log1p(mag) - elif output == "magnitude": - result = np.abs(F) - elif output == "phase": - result = np.angle(F) - elif output == "psdf": - # Gwyddion-equivalent PSDF: |F|^2 * dx * dy / (n * 4π²) - dx = field.xreal / xres - dy = field.yreal / yres - result = (np.abs(F) ** 2) * dx * dy / (n * 4.0 * np.pi ** 2) - else: - result = np.abs(F) - - # Calibrate the output field in spatial-frequency units - if output == "psdf": - # Gwyddion uses angular frequency: 2π/dx, 2π/dy - freq_xreal = 2.0 * np.pi * xres / field.xreal - freq_yreal = 2.0 * np.pi * yres / field.yreal - z_unit = f"({field.si_unit_z})^2 m^2" - else: - freq_xreal = xres / field.xreal - freq_yreal = yres / field.yreal - z_unit = field.si_unit_z - - out_field = DataField( - data=result, - xreal=freq_xreal, - yreal=freq_yreal, - si_unit_xy="1/m", - si_unit_z=z_unit, - domain="frequency", - colormap=field.colormap, - ) - return (out_field,) - - -# --------------------------------------------------------------------------- -# InverseFFT2D -# --------------------------------------------------------------------------- - -@register_node(display_name="Inverse 2D FFT") -class InverseFFT2D: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "spectrum": ("DATA_FIELD",), - "representation": (["magnitude", "log_magnitude", "psdf"],), - }, - "optional": { - "phase": ("DATA_FIELD",), - }, - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("image",) - FUNCTION = "process" - - DESCRIPTION = ( - "Reconstruct a spatial-domain image from a 2D frequency spectrum. " - "For exact reconstruction, connect magnitude/phase (or log magnitude/phase, " - "or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed." - ) - - def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple: - if spectrum.domain != "frequency": - raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.") - - if phase is not None: - if phase.data.shape != spectrum.data.shape: - raise ValueError("Phase input must have the same shape as the spectrum.") - if phase.domain != "frequency": - raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.") - - amplitude = self._resolve_amplitude(spectrum, representation) - phase_data = phase.data if phase is not None else np.zeros_like(amplitude) - F = amplitude * np.exp(1j * phase_data) - - spatial = np.fft.ifft2(np.fft.ifftshift(F)).real - xreal, yreal = self._recover_spatial_extent(spectrum, representation) - z_unit = self._recover_z_unit(spectrum, representation, phase) - - out_field = DataField( - data=spatial, - xreal=xreal, - yreal=yreal, - si_unit_xy="m", - si_unit_z=z_unit, - domain="spatial", - colormap=spectrum.colormap, - ) - return (out_field,) - - def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray: - data = np.asarray(spectrum.data, dtype=np.float64) - - if representation == "magnitude": - return np.clip(data, 0.0, None) - if representation == "log_magnitude": - return np.expm1(data) - if representation == "psdf": - xreal, yreal = self._recover_spatial_extent(spectrum, representation) - n = spectrum.xres * spectrum.yres - dx = xreal / spectrum.xres - dy = yreal / spectrum.yres - scale = n * 4.0 * np.pi ** 2 / (dx * dy) - return np.sqrt(np.clip(data, 0.0, None) * scale) - - raise ValueError(f"Unsupported spectrum representation: {representation}") - - def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]: - if representation == "psdf": - xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal - yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal - else: - xreal = spectrum.xres / spectrum.xreal - yreal = spectrum.yres / spectrum.yreal - return float(xreal), float(yreal) - - def _recover_z_unit( - self, - spectrum: DataField, - representation: str, - phase: DataField | None, - ) -> str: - if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip(): - return phase.si_unit_z - - if representation != "psdf": - return spectrum.si_unit_z - - unit = str(spectrum.si_unit_z or "").strip() - if unit.startswith("(") and ")^2 m^2" in unit: - return unit.split(")^2 m^2", 1)[0][1:] - if unit.endswith("^2 m^2"): - return unit[:-6].removesuffix("^2").strip() - return "" - - -# --------------------------------------------------------------------------- -# CrossSection -# --------------------------------------------------------------------------- - -def _extend_to_edges(x1, y1, x2, y2): - """ - Extend the line through (x1,y1)-(x2,y2) to the boundaries of [0,1]x[0,1]. - Returns the two intersection points (clipped to the unit square). - """ - dx = x2 - x1 - dy = y2 - y1 - - # Collect parametric t values where line hits each boundary - t_candidates = [] - if abs(dx) > 1e-12: - for bx in (0.0, 1.0): - t = (bx - x1) / dx - y_at_t = y1 + t * dy - if -1e-9 <= y_at_t <= 1.0 + 1e-9: - t_candidates.append(t) - if abs(dy) > 1e-12: - for by in (0.0, 1.0): - t = (by - y1) / dy - x_at_t = x1 + t * dx - if -1e-9 <= x_at_t <= 1.0 + 1e-9: - t_candidates.append(t) - - if len(t_candidates) < 2: - return x1, y1, x2, y2 - - t_min = min(t_candidates) - t_max = max(t_candidates) - - return ( - np.clip(x1 + t_min * dx, 0, 1), - np.clip(y1 + t_min * dy, 0, 1), - np.clip(x1 + t_max * dx, 0, 1), - np.clip(y1 + t_max * dy, 0, 1), - ) - - -@register_node(display_name="Cross Section") -class CrossSection: - """Extract a 1-D height profile along an arbitrary line across the image.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "extend": (["none", "to_edges"],), - "n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), - }, - "optional": { - "marker_pair": ("COORDPAIR", {"label": "marker pair"}), - }, - } - - RETURN_TYPES = ("LINE", "COORDPAIR",) - RETURN_NAMES = ("profile", "marker pair",) - FUNCTION = "process" - - DESCRIPTION = ( - "Extract a cross-section profile along a line between two points. " - "Drag the markers on the image to set the line endpoints. " - "Equivalent to gwy_data_field_get_profile." - ) - - _broadcast_overlay_fn = None - _current_node_id: str = "" - - def process( - self, field: DataField, - x1: float, y1: float, x2: float, y2: float, - extend: str, n_samples: int, - marker_pair=None, - ) -> tuple: - from scipy.ndimage import map_coordinates - - # COORDPAIR input overrides widget values - if marker_pair is not None: - (x1, y1), (x2, y2) = marker_pair - - # Remember marker positions (before extend) - marker_x1, marker_y1 = float(x1), float(y1) - marker_x2, marker_y2 = float(x2), float(y2) - - xres, yres = field.xres, field.yres - - if extend == "to_edges": - x1, y1, x2, y2 = _extend_to_edges( - float(x1), float(y1), float(x2), float(y2), - ) - - # Convert fractional [0,1] to pixel indices [0, res-1] - px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1) - px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1) - - # Number of sample points - line_len_px = np.hypot(px2 - px1, py2 - py1) - if n_samples <= 0: - n_samples = max(2, int(np.ceil(line_len_px))) - - # Sample coordinates along the line - t = np.linspace(0, 1, n_samples) - coords_y = py1 + t * (py2 - py1) - coords_x = px1 + t * (px2 - px1) - - # Interpolate values along the line (cubic spline) - profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest") - - # Broadcast overlay image with marker positions - if CrossSection._broadcast_overlay_fn is not None: - # Use the field's native pixel grid for the overlay preview so enlarging - # the panel keeps the image as sharp as the source data allows. - image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) - - CrossSection._broadcast_overlay_fn( - CrossSection._current_node_id, - { - "image": image_uri, - "x1": marker_x1, "y1": marker_y1, - "x2": marker_x2, "y2": marker_y2, - "a_locked": marker_pair is not None, - "b_locked": marker_pair is not None, - }, - ) - - dx_real = (x2 - x1) * field.xreal - dy_real = (y2 - y1) * field.yreal - distance_axis = np.linspace(0.0, float(np.hypot(dx_real, dy_real)), n_samples, dtype=np.float64) - - return ( - LineData( - data=profile.astype(np.float64), - x_axis=distance_axis, - x_unit=field.si_unit_xy, - y_unit=field.si_unit_z, - ), - ((marker_x1, marker_y1), (marker_x2, marker_y2)), - ) - - -# --------------------------------------------------------------------------- -# Shared line-stat helpers used by Stats -# --------------------------------------------------------------------------- - -def _safe_rq(d): - """RMS of deviations from mean.""" - return float(np.sqrt(np.mean(d * d))) - -# Registry: name → (function(z) → float, unit_label) -# All functions receive the raw 1-D profile as float64. -LINE_OPS: dict[str, tuple] = {} - - -def _line_op(name, unit=""): - """Decorator to register a LINE operation.""" - def decorator(fn): - LINE_OPS[name] = (fn, unit) - return fn - return decorator - - -# ── Basic statistics ────────────────────────────────────────────────────── - -@_line_op("min") -def _op_min(z): - return float(z.min()) - -@_line_op("max") -def _op_max(z): - return float(z.max()) - -@_line_op("mean") -def _op_mean(z): - return float(z.mean()) - -@_line_op("median") -def _op_median(z): - return float(np.median(z)) - -@_line_op("sum") -def _op_sum(z): - return float(z.sum()) - -@_line_op("range") -def _op_range(z): - return float(z.max() - z.min()) - -@_line_op("length", unit="pts") -def _op_length(z): - return float(len(z)) - -@_line_op("rms") -def _op_rms(z): - return float(np.sqrt(np.mean(z * z))) - - -# ── Roughness parameters ────────────────────────── - -@_line_op("Ra") -def _op_ra(z): - return float(np.mean(np.abs(z - z.mean()))) - -@_line_op("Rq") -def _op_rq(z): - d = z - z.mean() - return _safe_rq(d) - -@_line_op("Rsk") -def _op_rsk(z): - d = z - z.mean() - rq = _safe_rq(d) - return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0 - -@_line_op("Rku") -def _op_rku(z): - d = z - z.mean() - rq = _safe_rq(d) - return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0 - -@_line_op("Rp") -def _op_rp(z): - return float((z - z.mean()).max()) - -@_line_op("Rv") -def _op_rv(z): - return float(-(z - z.mean()).min()) - -@_line_op("Rt") -def _op_rt(z): - d = z - z.mean() - return float(d.max() - d.min()) - -@_line_op("Dq") -def _op_dq(z): - """RMS slope (first derivative RMS).""" - dz = np.diff(z) - return float(np.sqrt(np.mean(dz * dz))) - -@_line_op("Da") -def _op_da(z): - """Mean absolute slope.""" - return float(np.mean(np.abs(np.diff(z)))) - - -# --------------------------------------------------------------------------- -# Shared record-table helpers used by Stats -# --------------------------------------------------------------------------- - -TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = { - "min": lambda values: float(np.min(values)), - "max": lambda values: float(np.max(values)), - "avg": lambda values: float(np.mean(values)), - "mean": lambda values: float(np.mean(values)), - "median": lambda values: float(np.median(values)), - "sum": lambda values: float(np.sum(values)), - "range": lambda values: float(np.max(values) - np.min(values)), - "std": lambda values: float(np.std(values)), - "variance": lambda values: float(np.var(values)), - "count": lambda values: float(len(values)), -} - -ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = { - "min": lambda values: float(np.min(values)), - "max": lambda values: float(np.max(values)), - "avg": lambda values: float(np.mean(values)), - "mean": lambda values: float(np.mean(values)), - "median": lambda values: float(np.median(values)), - "sum": lambda values: float(np.sum(values)), - "range": lambda values: float(np.max(values) - np.min(values)), - "std": lambda values: float(np.std(values)), - "variance": lambda values: float(np.var(values)), - "rms": lambda values: float(np.sqrt(np.mean(values * values))), - "count": lambda values: float(values.size), -} - - -def _square_unit(unit: str) -> str: - unit = str(unit or "").strip() - if not unit: - return "" - if any(token in unit for token in ("^", "(", ")", "/", "*", " ")): - return f"({unit})^2" - return f"{unit}^2" - - -def _apply_scalar_unit(base_unit: str, operation: str) -> str: - unit = str(base_unit or "").strip() - if operation == "count": - return "count" - if not unit: - return "" - if operation == "variance": - return _square_unit(unit) - return unit - - -def _common_table_unit(table: list, column: str) -> str: - candidates = [] - seen = set() - unit_key = f"{column}_unit" - - for row in table: - if not isinstance(row, dict): - continue - unit = None - if unit_key in row and isinstance(row.get(unit_key), str): - unit = row.get(unit_key) - elif column == "value" and isinstance(row.get("unit"), str): - unit = row.get("unit") - if unit is None: - continue - unit = unit.strip() - if not unit or unit in seen: - continue - seen.add(unit) - candidates.append(unit) - - if len(candidates) == 1: - return candidates[0] - return "" - - -def _scalar_payload(value: float, unit: str = "") -> dict: - payload = {"value": float(value)} - if isinstance(unit, str) and unit.strip(): - payload["unit"] = unit.strip() - return payload - - -def extract_numeric_table_values(table: list, column: str) -> list[float]: - values = [] - for row in table: - if not isinstance(row, dict) or column not in row: - continue - value = row[column] - if isinstance(value, bool): - continue - try: - numeric = float(value) - except (TypeError, ValueError): - continue - if np.isfinite(numeric): - values.append(numeric) - return values - - -def resolve_table_column_name(table: list, column: str) -> str: - requested = str(column or "").strip() - if requested: - return requested - - if extract_numeric_table_values(table, "value"): - return "value" - - numeric_columns = [] - seen = set() - for row in table: - if not isinstance(row, dict): - continue - for key in row.keys(): - if key in seen: - continue - seen.add(key) - if extract_numeric_table_values(table, key): - numeric_columns.append(key) - - if len(numeric_columns) == 1: - return numeric_columns[0] - if not numeric_columns: - raise ValueError("Stats could not find any numeric columns in the input table.") - raise ValueError( - "Stats found multiple numeric columns; set the column name explicitly." - ) - - -@register_node(display_name="Stats") -class Stats: - """Polymorphic scalar stats node for LINE, RECORD_TABLE, DATA_FIELD, or IMAGE inputs.""" - - _broadcast_value_fn = None - _current_node_id: str = "" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "input": ("STATS_SOURCE",), - "column": ("STRING", { - "default": "value", - "choices_from_table_input": "input", - "show_when_source_type": { - "input": ["RECORD_TABLE"], - }, - }), - "operation": ("STRING", { - "default": "mean", - "choices_by_source_type": { - "LINE": list(LINE_OPS.keys()), - "RECORD_TABLE": list(TABLE_OPS.keys()), - "DATA_FIELD": list(ARRAY_OPS.keys()), - "IMAGE": list(ARRAY_OPS.keys()), - }, - "source_type_input": "input", - }), - } - } - - RETURN_TYPES = ("FLOAT",) - RETURN_NAMES = ("value",) - FUNCTION = "process" - - DESCRIPTION = ( - "Compute a contextual scalar statistic from a LINE, record table, DATA_FIELD, or IMAGE. " - "The available operations adapt to the connected input type." - ) - - def process(self, input, operation: str, column: str = "value") -> tuple: - source_type, values, resolved_column = self._resolve_input_values(input, column) - - if source_type == "RECORD_TABLE": - ops = TABLE_OPS - elif source_type == "LINE": - ops = LINE_OPS - else: - ops = ARRAY_OPS - - if operation not in ops: - raise ValueError(f"Operation '{operation}' is not valid for {source_type} input.") - - op_entry = ops[operation] - fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry - result = fn(values) - if Stats._broadcast_value_fn is not None: - Stats._broadcast_value_fn( - Stats._current_node_id, - _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), - ) - return (result,) - - def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str: - if source_type == "DATA_FIELD" and isinstance(input_value, DataField): - return _apply_scalar_unit(input_value.si_unit_z, operation) - - if source_type == "LINE": - line_entry = LINE_OPS.get(operation) - explicit_unit = line_entry[1] if isinstance(line_entry, tuple) and len(line_entry) > 1 else "" - if explicit_unit: - return _apply_scalar_unit(explicit_unit, operation) - if isinstance(input_value, LineData): - return _apply_scalar_unit(input_value.y_unit, operation) - return "" - - if source_type == "RECORD_TABLE" and isinstance(input_value, list) and column: - return _apply_scalar_unit(_common_table_unit(input_value, column), operation) - - return "" - - def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray, str | None]: - if isinstance(input_value, DataField): - values = np.asarray(input_value.data, dtype=np.float64) - return ("DATA_FIELD", values.ravel(), None) - - if isinstance(input_value, MeasureTable): - raise ValueError("Stats only accepts record tables, not measurement tables.") - - if isinstance(input_value, list): - if not input_value: - raise ValueError("Stats requires a non-empty record table input.") - column_name = resolve_table_column_name(input_value, column) - values = extract_numeric_table_values(input_value, column_name) - if not values: - raise ValueError(f"Column '{column_name}' has no numeric values.") - return ("RECORD_TABLE", np.asarray(values, dtype=np.float64), column_name) - - if isinstance(input_value, LineData): - values = np.asarray(input_value.data, dtype=np.float64) - if values.size == 0: - raise ValueError("Stats requires a non-empty input.") - return ("LINE", values.ravel(), None) - - if isinstance(input_value, np.ndarray): - values = np.asarray(input_value, dtype=np.float64) - if values.size == 0: - raise ValueError("Stats requires a non-empty input.") - if values.ndim == 1: - return ("LINE", values.ravel(), None) - return ("IMAGE", values.ravel(), None) - - raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}") diff --git a/backend/nodes/annotations.py b/backend/nodes/annotations.py new file mode 100644 index 0000000..4d615c4 --- /dev/null +++ b/backend/nodes/annotations.py @@ -0,0 +1,69 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import COLORMAPS, DataField, normalize_font_spec, resolve_colormap_input + + +@register_node(display_name="Annotations") +class Annotations: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + "show_scale_bar": ("BOOLEAN", {"default": True}), + "show_color_map": ("BOOLEAN", {"default": True}), + "text_size": ("FLOAT", { + "default": 14.0, + "min": 6.0, + "max": 96.0, + "step": 1.0, + }), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + "font": ("FONT",), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("annotated",) + FUNCTION = "render" + + DESCRIPTION = ( + "Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. " + "The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values." + ) + + def render( + self, + field: DataField, + colormap: str, + show_scale_bar: bool, + show_color_map: bool, + text_size: float = 1.0, + colormap_map=None, + font=None, + ) -> tuple: + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=field.colormap, + default="gray", + ) + text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0 + out = field.replace( + colormap=resolved_colormap, + overlays=[ + *field.overlays, + { + "kind": "annotation", + "show_scale_bar": bool(show_scale_bar), + "show_color_map": bool(show_color_map), + "text_size": text_size, + "font": normalize_font_spec(font), + }, + ], + ) + return (out,) diff --git a/backend/nodes/color_map.py b/backend/nodes/color_map.py new file mode 100644 index 0000000..dad6e53 --- /dev/null +++ b/backend/nodes/color_map.py @@ -0,0 +1,48 @@ +from __future__ import annotations +import json +from backend.node_registry import register_node +from backend.data_types import COLORMAPS, DEFAULT_CUSTOM_COLORMAP_STOPS, normalize_colormap_spec + + +@register_node(display_name="Color Map") +class ColorMap: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mode": (["preset", "custom"], {"default": "preset"}), + "preset": (list(COLORMAPS), { + "default": "viridis", + "show_when_widget_value": {"mode": ["preset"]}, + }), + "stops": ("STRING", { + "default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)), + "colormap_stops": True, + "show_when_widget_value": {"mode": ["custom"]}, + }), + } + } + + RETURN_TYPES = ("COLORMAP",) + RETURN_NAMES = ("colormap",) + FUNCTION = "build" + + DESCRIPTION = ( + "Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours " + "and any number of intermediate stops." + ) + + def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple: + if mode == "preset": + return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},) + + try: + raw_stops = stops if stops is not None else stops_json + stops_data = json.loads(raw_stops or "[]") + except json.JSONDecodeError as exc: + raise ValueError("Custom colormap stops must be valid JSON.") from exc + + spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None) + if not (isinstance(spec, dict) and spec.get("mode") == "custom"): + raise ValueError("Custom colormap must include at least min and max colours.") + return (spec,) diff --git a/backend/nodes/colormap_adjust.py b/backend/nodes/colormap_adjust.py new file mode 100644 index 0000000..77b5762 --- /dev/null +++ b/backend/nodes/colormap_adjust.py @@ -0,0 +1,33 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Colormap Adjust") +class ColormapAdjust: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}), + "scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}), + "auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("field",) + FUNCTION = "process" + + DESCRIPTION = ( + "Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. " + "offset and scale operate in normalized display coordinates; Auto resets to the full data range." + ) + + def process(self, field: DataField, offset: float, scale: float) -> tuple: + scale = float(scale) + if not np.isfinite(scale) or scale <= 0.0: + raise ValueError("Scale must be a positive number.") + return (field.replace(display_offset=float(offset), display_scale=scale),) diff --git a/backend/nodes/coordinate.py b/backend/nodes/coordinate.py new file mode 100644 index 0000000..c0e7b02 --- /dev/null +++ b/backend/nodes/coordinate.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from backend.node_registry import register_node + + +@register_node(display_name="Coordinate") +class Coordinate: + """Provide a fractional (x, y) point for use with Cross Section or other nodes.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("COORD",) + RETURN_NAMES = ("point",) + FUNCTION = "process" + + DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]." + + def process(self, x: float, y: float) -> tuple: + return ((float(x), float(y)),) diff --git a/backend/nodes/coordinate_pair.py b/backend/nodes/coordinate_pair.py new file mode 100644 index 0000000..6bfeb6c --- /dev/null +++ b/backend/nodes/coordinate_pair.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from backend.node_registry import register_node + + +@register_node(display_name="Coordinate Pair") +class CoordinatePair: + """Provide a pair of Coordinates, for drawing lines between markers, etc.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("COORD",), + "b": ("COORD",), + } + } + + RETURN_TYPES = ("COORDPAIR",) + RETURN_NAMES = ("coord pair",) + FUNCTION = "process" + + DESCRIPTION = "Output a pair of coordinates." + + def process(self, a: tuple, b: tuple) -> tuple: + return ((a, b),) diff --git a/backend/nodes/modify.py b/backend/nodes/crop_resize_field.py similarity index 52% rename from backend/nodes/modify.py rename to backend/nodes/crop_resize_field.py index bb29a3f..3c69a21 100644 --- a/backend/nodes/modify.py +++ b/backend/nodes/crop_resize_field.py @@ -1,52 +1,9 @@ -""" -Modify nodes — geometric transforms for DATA_FIELDs. -""" - from __future__ import annotations - import numpy as np - from backend.node_registry import register_node from backend.data_types import DataField, datafield_to_uint8, encode_preview -# --------------------------------------------------------------------------- -# ColormapAdjust -# --------------------------------------------------------------------------- - -@register_node(display_name="Colormap Adjust") -class ColormapAdjust: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}), - "scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}), - "auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("field",) - FUNCTION = "process" - - DESCRIPTION = ( - "Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. " - "offset and scale operate in normalized display coordinates; Auto resets to the full data range." - ) - - def process(self, field: DataField, offset: float, scale: float) -> tuple: - scale = float(scale) - if not np.isfinite(scale) or scale <= 0.0: - raise ValueError("Scale must be a positive number.") - return (field.replace(display_offset=float(offset), display_scale=scale),) - - -# --------------------------------------------------------------------------- -# CropResizeField -# --------------------------------------------------------------------------- - @register_node(display_name="Crop / Resize") class CropResizeField: @classmethod @@ -190,105 +147,3 @@ class CropResizeField: target_height = max(1, int(round(height * (target_width / width)))) return (max(1, target_width), max(1, target_height)) - - -# --------------------------------------------------------------------------- -# RotateField -# --------------------------------------------------------------------------- - -@register_node(display_name="Rotate") -class RotateField: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "angle": ("FLOAT", {"default": 90.0, "min": -360.0, "max": 360.0, "step": 1.0}), - "interpolation": (["bilinear", "nearest", "bicubic"],), - "expand_canvas": ("BOOLEAN", {"default": True}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("field",) - FUNCTION = "process" - - DESCRIPTION = ( - "Rotate a DATA_FIELD counterclockwise by an angle in degrees. " - "Optionally expand the canvas to keep the full rotated field while preserving the field center." - ) - - _broadcast_warning_fn = None - _current_node_id: str = "" - - def process( - self, - field: DataField, - angle: float, - interpolation: str, - expand_canvas: bool, - ) -> tuple: - if field.overlays: - self._send_warning("Rotate clears annotation/markup overlays!") - - angle = float(angle) - order_map = { - "nearest": 0, - "bilinear": 1, - "bicubic": 3, - } - if interpolation not in order_map: - raise ValueError(f"Unknown interpolation mode: {interpolation}") - - normalized_angle = angle % 360.0 - snapped_quarters = int(round(normalized_angle / 90.0)) % 4 - snapped_angle = snapped_quarters * 90.0 - is_right_angle = abs(normalized_angle - snapped_angle) < 1e-9 - - if is_right_angle and expand_canvas: - rotated = np.rot90(field.data, k=snapped_quarters).copy() - elif abs(normalized_angle) < 1e-9: - rotated = field.data.copy() - else: - from scipy.ndimage import rotate as nd_rotate - - rotated = nd_rotate( - field.data, - angle=angle, - reshape=bool(expand_canvas), - order=order_map[interpolation], - mode="nearest", - prefilter=order_map[interpolation] > 1, - ) - - new_xreal, new_yreal = self._rotated_extents(field, angle, expand_canvas) - center_x = field.xoff + field.xreal / 2.0 - center_y = field.yoff + field.yreal / 2.0 - - result = field.replace( - data=np.asarray(rotated, dtype=np.float64), - xreal=new_xreal, - yreal=new_yreal, - xoff=center_x - new_xreal / 2.0, - yoff=center_y - new_yreal / 2.0, - overlays=[], - ) - return (result,) - - def _send_warning(self, message: str): - fn = RotateField._broadcast_warning_fn - nid = RotateField._current_node_id - if fn and nid: - fn(nid, message) - - @staticmethod - def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]: - if not expand_canvas: - return (field.xreal, field.yreal) - - theta = np.deg2rad(angle) - cos_t = abs(float(np.cos(theta))) - sin_t = abs(float(np.sin(theta))) - new_xreal = field.xreal * cos_t + field.yreal * sin_t - new_yreal = field.xreal * sin_t + field.yreal * cos_t - return (new_xreal, new_yreal) diff --git a/backend/nodes/cross_section.py b/backend/nodes/cross_section.py new file mode 100644 index 0000000..e63a74c --- /dev/null +++ b/backend/nodes/cross_section.py @@ -0,0 +1,102 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview +from backend.nodes.helpers import _extend_to_edges + + +@register_node(display_name="Cross Section") +class CrossSection: + """Extract a 1-D height profile along an arbitrary line across the image.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "extend": (["none", "to_edges"],), + "n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), + }, + "optional": { + "marker_pair": ("COORDPAIR", {"label": "marker pair"}), + }, + } + + RETURN_TYPES = ("LINE", "COORDPAIR",) + RETURN_NAMES = ("profile", "marker pair",) + FUNCTION = "process" + + DESCRIPTION = ( + "Extract a cross-section profile along a line between two points. " + "Drag the markers on the image to set the line endpoints. " + "Equivalent to gwy_data_field_get_profile." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, field: DataField, + x1: float, y1: float, x2: float, y2: float, + extend: str, n_samples: int, + marker_pair=None, + ) -> tuple: + from scipy.ndimage import map_coordinates + + if marker_pair is not None: + (x1, y1), (x2, y2) = marker_pair + + marker_x1, marker_y1 = float(x1), float(y1) + marker_x2, marker_y2 = float(x2), float(y2) + + xres, yres = field.xres, field.yres + + if extend == "to_edges": + x1, y1, x2, y2 = _extend_to_edges( + float(x1), float(y1), float(x2), float(y2), + ) + + px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1) + px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1) + + line_len_px = np.hypot(px2 - px1, py2 - py1) + if n_samples <= 0: + n_samples = max(2, int(np.ceil(line_len_px))) + + t = np.linspace(0, 1, n_samples) + coords_y = py1 + t * (py2 - py1) + coords_x = px1 + t * (px2 - px1) + + profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest") + + if CrossSection._broadcast_overlay_fn is not None: + image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) + + CrossSection._broadcast_overlay_fn( + CrossSection._current_node_id, + { + "image": image_uri, + "x1": marker_x1, "y1": marker_y1, + "x2": marker_x2, "y2": marker_y2, + "a_locked": marker_pair is not None, + "b_locked": marker_pair is not None, + }, + ) + + dx_real = (x2 - x1) * field.xreal + dy_real = (y2 - y1) * field.yreal + distance_axis = np.linspace(0.0, float(np.hypot(dx_real, dy_real)), n_samples, dtype=np.float64) + + return ( + LineData( + data=profile.astype(np.float64), + x_axis=distance_axis, + x_unit=field.si_unit_xy, + y_unit=field.si_unit_z, + ), + ((marker_x1, marker_y1), (marker_x2, marker_y2)), + ) diff --git a/backend/nodes/cursors.py b/backend/nodes/cursors.py new file mode 100644 index 0000000..c4655a1 --- /dev/null +++ b/backend/nodes/cursors.py @@ -0,0 +1,173 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview + + +@register_node(display_name="Cursors") +class Cursors: + """Place two draggable cursors on a line plot or field to measure deltas.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "line": ("CURSOR_SOURCE", {"label": "input"}), + "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + }, + "optional": { + "coord_pair": ("COORDPAIR", {"label": "coord pair"}), + }, + } + + RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",) + RETURN_NAMES = ("measurement", "coord pair",) + FUNCTION = "process" + + DESCRIPTION = ( + "Place two cursors on a line plot or 2D field. " + "On lines it reports x/y positions and dx/dy. " + "On fields it reports x/y/z at both markers plus dx/dy/dz." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, line, x1: float, y1: float, x2: float, y2: float, + coord_pair=None, + ) -> tuple: + if coord_pair is not None: + (x1, y1), (x2, y2) = coord_pair + + locked = coord_pair is not None + + if isinstance(line, DataField): + return self._process_field(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked) + + return self._process_line(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked) + + def _process_line( + self, + line, + x1: float, + y1: float, + x2: float, + y2: float, + locked: bool = False, + ) -> tuple: + y = np.asarray(line, dtype=np.float64).ravel() + x_unit = line.x_unit if isinstance(line, LineData) else "" + y_unit = line.y_unit if isinstance(line, LineData) else "" + n = len(y) + if isinstance(line, LineData) and line.x_axis is not None: + x = np.asarray(line.x_axis, dtype=np.float64).ravel()[:n] + else: + x = np.arange(n, dtype=np.float64) + x1 = float(np.clip(x1, 0.0, 1.0)) + x2 = float(np.clip(x2, 0.0, 1.0)) + + xmin = float(np.min(x)) if len(x) else 0.0 + xmax = float(np.max(x)) if len(x) else 1.0 + + def x_frac_to_idx(frac): + if n <= 1: + return 0 + if xmax == xmin: + return 0 + target_x = xmin + frac * (xmax - xmin) + return int(np.argmin(np.abs(x - target_x))) + + idx_a = x_frac_to_idx(x1) + idx_b = x_frac_to_idx(x2) + + xa, ya = float(x[idx_a]), float(y[idx_a]) + xb, yb = float(x[idx_b]), float(y[idx_b]) + + if Cursors._broadcast_overlay_fn is not None: + Cursors._broadcast_overlay_fn( + Cursors._current_node_id, + { + "kind": "line_plot", + "section_title": "Cursors", + "line": y.tolist(), + "x_axis": x.tolist(), + "x1": x1, + "x2": x2, + "y1": float(y1), + "y2": float(y2), + "a_locked": locked, + "b_locked": locked, + }, + ) + + table = MeasureTable([ + {"quantity": "A x", "value": xa, "unit": x_unit}, + {"quantity": "A y", "value": ya, "unit": y_unit}, + {"quantity": "B x", "value": xb, "unit": x_unit}, + {"quantity": "B y", "value": yb, "unit": y_unit}, + {"quantity": "dx", "value": xb - xa, "unit": x_unit}, + {"quantity": "dy", "value": yb - ya, "unit": y_unit}, + ]) + return (table, ((x1, y1), (x2, y2))) + + def _process_field( + self, + field: DataField, + x1: float, + y1: float, + x2: float, + y2: float, + locked: bool = False, + ) -> tuple: + from scipy.ndimage import map_coordinates + + x1 = float(np.clip(x1, 0.0, 1.0)) + y1 = float(np.clip(y1, 0.0, 1.0)) + x2 = float(np.clip(x2, 0.0, 1.0)) + y2 = float(np.clip(y2, 0.0, 1.0)) + + px1 = x1 * max(field.xres - 1, 0) + py1 = y1 * max(field.yres - 1, 0) + px2 = x2 * max(field.xres - 1, 0) + py2 = y2 * max(field.yres - 1, 0) + + z1 = float(map_coordinates(field.data, [[py1], [px1]], order=1, mode="nearest")[0]) + z2 = float(map_coordinates(field.data, [[py2], [px2]], order=1, mode="nearest")[0]) + + ax = float(field.xoff + x1 * field.xreal) + ay = float(field.yoff + y1 * field.yreal) + bx = float(field.xoff + x2 * field.xreal) + by = float(field.yoff + y2 * field.yreal) + + if Cursors._broadcast_overlay_fn is not None: + Cursors._broadcast_overlay_fn( + Cursors._current_node_id, + { + "kind": "cursor_points", + "section_title": "Cursors", + "image": encode_preview(render_datafield_preview(field, field.colormap)), + "x1": x1, + "y1": y1, + "x2": x2, + "y2": y2, + "a_locked": locked, + "b_locked": locked, + }, + ) + + table = MeasureTable([ + {"quantity": "A x", "value": ax, "unit": field.si_unit_xy}, + {"quantity": "A y", "value": ay, "unit": field.si_unit_xy}, + {"quantity": "A z", "value": z1, "unit": field.si_unit_z}, + {"quantity": "B x", "value": bx, "unit": field.si_unit_xy}, + {"quantity": "B y", "value": by, "unit": field.si_unit_xy}, + {"quantity": "B z", "value": z2, "unit": field.si_unit_z}, + {"quantity": "dx", "value": bx - ax, "unit": field.si_unit_xy}, + {"quantity": "dy", "value": by - ay, "unit": field.si_unit_xy}, + {"quantity": "dz", "value": z2 - z1, "unit": field.si_unit_z}, + ]) + return (table, ((x1, y1), (x2, y2))) diff --git a/backend/nodes/display.py b/backend/nodes/display.py deleted file mode 100644 index 69fbe68..0000000 --- a/backend/nodes/display.py +++ /dev/null @@ -1,743 +0,0 @@ -""" -Display / output nodes. - -Preview accepts both DATA_FIELD and IMAGE via optional inputs — -connect whichever type you have. The server injects _broadcast_fn -before execution begins. -""" - -from __future__ import annotations -import json -import numpy as np -from backend.node_registry import register_node -from backend.data_types import ( - DataField, - MeasureTable, - COLORMAPS, - CUSTOM_FILE_FONT, - DEFAULT_CUSTOM_COLORMAP_STOPS, - SYSTEM_DEFAULT_FONT, - colormap_to_uint8, - datafield_to_uint8, - encode_preview, - image_to_uint8, - list_overlay_font_choices, - normalize_colormap_spec, - normalize_font_spec, - normalize_for_colormap, - render_datafield_preview, - resolve_colormap_input, -) - - -def _measurement_names(table: list) -> list[str]: - names = [] - for row in table: - if not isinstance(row, dict): - continue - quantity = row.get("quantity") - if isinstance(quantity, str) and quantity and quantity not in names: - names.append(quantity) - return names - - -def _measurement_entry(table: list, selection: str) -> dict: - names = _measurement_names(table) - if not names: - raise ValueError("Measurement table has no selectable rows.") - - target = selection if selection in names else names[0] - for row in table: - if isinstance(row, dict) and row.get("quantity") == target: - return row - - raise ValueError(f"Measurement '{target}' was not found.") - - -def _measurement_value(table: list, selection: str) -> float: - row = _measurement_entry(table, selection) - value = row.get("value") - if isinstance(value, bool): - raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") - try: - numeric = float(value) - except (TypeError, ValueError) as exc: - raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc - if np.isfinite(numeric): - return numeric - raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") - - -def _scalar_payload(value: float, unit: str = "") -> dict: - payload = {"value": float(value)} - if isinstance(unit, str) and unit.strip(): - payload["unit"] = unit.strip() - return payload - - -_SI_PREFIXES = [ - (1e24, "Y"), - (1e21, "Z"), - (1e18, "E"), - (1e15, "P"), - (1e12, "T"), - (1e9, "G"), - (1e6, "M"), - (1e3, "k"), - (1.0, ""), - (1e-3, "m"), - (1e-6, "u"), - (1e-9, "n"), - (1e-12, "p"), - (1e-15, "f"), - (1e-18, "a"), - (1e-21, "z"), - (1e-24, "y"), -] -_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"} - - -def _format_numeric(value: float) -> str: - if not np.isfinite(value): - return str(value) - abs_value = abs(value) - if abs_value == 0: - return "0" - if abs_value >= 1e4 or abs_value < 1e-3: - return f"{value:.3e}" - return f"{value:.4g}" - - -def _format_with_unit(value: float, unit: str) -> str: - unit = (unit or "").strip() - if not unit: - return _format_numeric(value) - if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0: - abs_value = abs(value) - for scale, prefix in _SI_PREFIXES: - scaled = abs_value / scale - if 1 <= scaled < 1000: - signed = value / scale - return f"{_format_numeric(signed)} {prefix}{unit}" - return f"{_format_numeric(value)} {unit}" - - -def _nice_length(target: float) -> float: - if not np.isfinite(target) or target <= 0: - return 0.0 - exponent = np.floor(np.log10(target)) - base = 10.0 ** exponent - for step in (5.0, 2.0, 1.0): - candidate = step * base - if candidate <= target: - return candidate - return base - - -def _display_value_range(field: DataField) -> tuple[float, float, float]: - data = np.asarray(field.data, dtype=np.float64) - dmin = float(data.min()) - dmax = float(data.max()) - if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin: - return dmin, dmin, dmin - - offset = float(field.display_offset) - scale = float(field.display_scale) - if not np.isfinite(offset): - offset = 0.0 - if not np.isfinite(scale) or scale <= 0.0: - scale = 1.0 - - low_norm = float(np.clip(offset, 0.0, 1.0)) - high_norm = float(np.clip(offset + scale, 0.0, 1.0)) - if high_norm < low_norm: - low_norm, high_norm = high_norm, low_norm - mid_norm = 0.5 * (low_norm + high_norm) - - span = dmax - dmin - return ( - dmin + low_norm * span, - dmin + mid_norm * span, - dmin + high_norm * span, - ) - - -def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]): - from PIL import Image, ImageDraw, ImageFont - - size_px = max(8, int(round(size_px))) - try: - font = ImageFont.truetype("DejaVuSans.ttf", size_px) - probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0)) - probe_draw = ImageDraw.Draw(probe) - bbox = probe_draw.textbbox((0, 0), text, font=font) - width = max(1, bbox[2] - bbox[0]) - height = max(1, bbox[3] - bbox[1]) - text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0)) - text_draw = ImageDraw.Draw(text_image) - text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255)) - return text_image - except Exception: - font = ImageFont.load_default() - probe = Image.new("L", (1, 1), 0) - probe_draw = ImageDraw.Draw(probe) - bbox = probe_draw.textbbox((0, 0), text, font=font) - width = max(1, bbox[2] - bbox[0]) - height = max(1, bbox[3] - bbox[1]) - mask = Image.new("L", (width, height), 0) - mask_draw = ImageDraw.Draw(mask) - mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255) - - scale = max(1.0, size_px / max(1, height)) - scaled_width = max(1, int(round(width * scale))) - scaled_height = max(1, int(round(height * scale))) - resampling = getattr(Image, "Resampling", Image) - scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR) - - text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0)) - text_image.putalpha(scaled_mask) - return text_image - - -def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str: - if isinstance(color, str): - text = color.strip() - if len(text) == 4 and text.startswith("#"): - text = "#" + "".join(ch * 2 for ch in text[1:]) - if len(text) == 7 and text.startswith("#"): - try: - int(text[1:], 16) - return text.lower() - except ValueError: - pass - return default - - -def _parse_markup_shapes(raw_shapes: str | list | None) -> list[dict[str, object]]: - if isinstance(raw_shapes, str): - try: - raw_shapes = json.loads(raw_shapes or "[]") - except json.JSONDecodeError: - raw_shapes = [] - - if not isinstance(raw_shapes, list): - return [] - - parsed: list[dict[str, object]] = [] - for shape in raw_shapes: - if not isinstance(shape, dict): - continue - - kind = str(shape.get("kind", "")).strip().lower() - if kind not in {"line", "rectangle", "circle", "arrow"}: - continue - - try: - x1 = float(shape.get("x1")) - y1 = float(shape.get("y1")) - x2 = float(shape.get("x2")) - y2 = float(shape.get("y2")) - width = int(round(float(shape.get("width", 3)))) - except (TypeError, ValueError): - continue - - coords = [x1, y1, x2, y2] - if not all(np.isfinite(value) for value in coords): - continue - - parsed.append({ - "kind": kind, - "x1": float(np.clip(x1, 0.0, 1.0)), - "y1": float(np.clip(y1, 0.0, 1.0)), - "x2": float(np.clip(x2, 0.0, 1.0)), - "y2": float(np.clip(y2, 0.0, 1.0)), - "width": max(1, min(128, width)), - "color": _normalize_markup_color(shape.get("color")), - }) - - return parsed - - -def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int): - dx = end[0] - start[0] - dy = end[1] - start[1] - length = float(np.hypot(dx, dy)) - if length <= 1e-6: - radius = max(1.0, width / 2.0) - draw.ellipse( - (start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius), - fill=color, - ) - return - - ux = dx / length - uy = dy / length - head_length = max(10.0, width * 4.0) - head_width = max(8.0, width * 3.0) - shaft_end = ( - end[0] - ux * head_length, - end[1] - uy * head_length, - ) - - draw.line((start, shaft_end), fill=color, width=width) - - px = -uy - py = ux - left = ( - shaft_end[0] + px * head_width / 2.0, - shaft_end[1] + py * head_width / 2.0, - ) - right = ( - shaft_end[0] - px * head_width / 2.0, - shaft_end[1] - py * head_width / 2.0, - ) - draw.polygon([end, left, right], fill=color) - - -def _render_markup_image(image: np.ndarray, shapes: list[dict[str, object]]) -> np.ndarray: - from PIL import Image, ImageDraw - - base = image_to_uint8(image) - if base.ndim == 2: - base = np.repeat(base[:, :, np.newaxis], 3, axis=2) - - canvas = Image.fromarray(base.copy()) - draw = ImageDraw.Draw(canvas) - height, width = base.shape[:2] - - for shape in shapes: - x1 = float(shape["x1"]) * width - y1 = float(shape["y1"]) * height - x2 = float(shape["x2"]) * width - y2 = float(shape["y2"]) * height - color = str(shape["color"]) - stroke_width = int(shape["width"]) - kind = str(shape["kind"]) - - if kind == "line": - draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width) - elif kind == "rectangle": - draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width) - elif kind == "circle": - draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width) - elif kind == "arrow": - _draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width) - - return np.asarray(canvas, dtype=np.uint8) - - -@register_node(display_name="Color Map") -class ColorMap: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mode": (["preset", "custom"], {"default": "preset"}), - "preset": (list(COLORMAPS), { - "default": "viridis", - "show_when_widget_value": {"mode": ["preset"]}, - }), - "stops": ("STRING", { - "default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)), - "colormap_stops": True, - "show_when_widget_value": {"mode": ["custom"]}, - }), - } - } - - RETURN_TYPES = ("COLORMAP",) - RETURN_NAMES = ("colormap",) - FUNCTION = "build" - - DESCRIPTION = ( - "Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours " - "and any number of intermediate stops." - ) - - def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple: - if mode == "preset": - return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},) - - try: - raw_stops = stops if stops is not None else stops_json - stops_data = json.loads(raw_stops or "[]") - except json.JSONDecodeError as exc: - raise ValueError("Custom colormap stops must be valid JSON.") from exc - - spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None) - if not (isinstance(spec, dict) and spec.get("mode") == "custom"): - raise ValueError("Custom colormap must include at least min and max colours.") - return (spec,) - - -@register_node(display_name="Font") -class Font: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], { - "default": SYSTEM_DEFAULT_FONT, - }), - "font_file": ("FILE_PICKER", { - "default": "", - "show_when_widget_value": {"family": [CUSTOM_FILE_FONT]}, - }), - } - } - - RETURN_TYPES = ("FONT",) - RETURN_NAMES = ("font",) - FUNCTION = "build" - - DESCRIPTION = ( - "Build a reusable font spec for annotation overlays. Choose a discovered system font, " - "use the default fallback stack, or point to a custom font file." - ) - - def build(self, family: str, font_file: str = "") -> tuple: - if family == SYSTEM_DEFAULT_FONT: - return (None,) - if family == CUSTOM_FILE_FONT: - return (normalize_font_spec({"path": font_file}),) - return (normalize_font_spec({"family": family}),) - - -@register_node(display_name="Annotations") -class Annotations: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), - "show_scale_bar": ("BOOLEAN", {"default": True}), - "show_color_map": ("BOOLEAN", {"default": True}), - "text_size": ("FLOAT", { - "default": 14.0, - "min": 6.0, - "max": 96.0, - "step": 1.0, - }), - }, - "optional": { - "colormap_map": ("COLORMAP", {"label": "colormap"}), - "font": ("FONT",), - }, - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("annotated",) - FUNCTION = "render" - - DESCRIPTION = ( - "Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. " - "The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values." - ) - - def render( - self, - field: DataField, - colormap: str, - show_scale_bar: bool, - show_color_map: bool, - text_size: float = 1.0, - colormap_map=None, - font=None, - ) -> tuple: - resolved_colormap = resolve_colormap_input( - colormap, - colormap_input=colormap_map, - inherited=field.colormap, - default="gray", - ) - text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0 - out = field.replace( - colormap=resolved_colormap, - overlays=[ - *field.overlays, - { - "kind": "annotation", - "show_scale_bar": bool(show_scale_bar), - "show_color_map": bool(show_color_map), - "text_size": text_size, - "font": normalize_font_spec(font), - }, - ], - ) - return (out,) - - -@register_node(display_name="Markup") -class Markup: - _CUSTOM_PREVIEW = True - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}), - "stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}), - "stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}), - "clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}), - "markup_shapes": ("STRING", {"default": "[]", "hidden": True}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("annotated",) - FUNCTION = "process" - - DESCRIPTION = ( - "Draw simple vector markup over a DATA_FIELD without flattening the underlying data. " - "Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows." - ) - - _broadcast_overlay_fn = None - _current_node_id: str = "" - - def process( - self, - field: DataField, - shape: str, - stroke_color: str, - stroke_width: int, - markup_shapes: str, - ) -> tuple: - shapes = _parse_markup_shapes(markup_shapes) - out = field.replace( - overlays=[ - *field.overlays, - { - "kind": "markup", - "shapes": shapes, - }, - ], - ) - - if Markup._broadcast_overlay_fn is not None: - Markup._broadcast_overlay_fn( - Markup._current_node_id, - { - "kind": "markup", - "section_title": "Markup", - "image": encode_preview(datafield_to_uint8(field, field.colormap)), - "shape": str(shape), - "stroke_color": _normalize_markup_color(stroke_color), - "stroke_width": max(1, int(stroke_width)), - }, - ) - - return (out,) - - -@register_node(display_name="Preview") -class PreviewImage: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), - }, - "optional": { - "colormap_map": ("COLORMAP", {"label": "colormap"}), - "image": ("IMAGE",), - "field": ("DATA_FIELD",), - } - } - - RETURN_TYPES = () - FUNCTION = "preview" - - OUTPUT_NODE = True - DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input." - - _broadcast_fn = None - _current_node_id: str = "" - - def preview( - self, - colormap: str, - image: np.ndarray | None = None, - field=None, - colormap_map=None, - ) -> tuple: - resolved_colormap = resolve_colormap_input( - colormap, - colormap_input=colormap_map, - inherited=field.colormap if field is not None else None, - default="gray", - ) - - # Prefer field if both are connected; accept whichever is provided - if field is not None: - arr_u8 = render_datafield_preview(field, resolved_colormap) - elif image is not None: - arr_u8 = image_to_uint8(image) - if arr_u8.ndim == 2: - if image.dtype == np.uint8: - normalized = arr_u8.astype(np.float64) / 255.0 - else: - imin, imax = image.min(), image.max() - if imax > imin: - normalized = (image - imin) / (imax - imin) - else: - normalized = np.zeros_like(image, dtype=np.float64) - arr_u8 = colormap_to_uint8(normalized, resolved_colormap) - else: - raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.") - - data_uri = encode_preview(arr_u8) - - if PreviewImage._broadcast_fn is not None: - PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri) - - return () - - -@register_node(display_name="3D View") -class View3D: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), - "z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}), - "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), - }, - "optional": { - "colormap_map": ("COLORMAP", {"label": "colormap"}), - }, - } - - RETURN_TYPES = () - FUNCTION = "render" - - OUTPUT_NODE = True - DESCRIPTION = ( - "Interactive 3D surface view of a DATA_FIELD. " - "Drag to rotate, scroll to zoom. z_scale exaggerates height." - ) - - _broadcast_mesh_fn = None - _current_node_id: str = "" - - def render( - self, field: DataField, - colormap: str, z_scale: float, resolution: int, colormap_map=None, - ) -> tuple: - import base64 - - data = field.data - yres, xres = data.shape - - # Downsample if larger than resolution - step_y = max(1, yres // resolution) - step_x = max(1, xres // resolution) - z = data[::step_y, ::step_x].astype(np.float32) - ny, nx = z.shape - - # Normalize for colormap - zmin, zmax = float(z.min()), float(z.max()) - z_norm = normalize_for_colormap( - z, - offset=field.display_offset, - scale=field.display_scale, - data_min=float(field.data.min()), - data_max=float(field.data.max()), - ) - - resolved_colormap = resolve_colormap_input( - colormap, - colormap_input=colormap_map, - inherited=field.colormap, - default="gray", - ) - colors_u8 = colormap_to_uint8(z_norm, resolved_colormap) - - # Base64-encode arrays for efficient WS transport - z_b64 = base64.b64encode(z.tobytes()).decode() - colors_b64 = base64.b64encode(colors_u8.tobytes()).decode() - - mesh_data = { - "width": nx, - "height": ny, - "z_data": z_b64, - "colors": colors_b64, - "z_min": zmin, - "z_max": zmax, - "z_scale": float(z_scale * 0.1), - "x_range": [float(field.xoff), float(field.xoff + field.xreal)], - "y_range": [float(field.yoff), float(field.yoff + field.yreal)], - } - - if View3D._broadcast_mesh_fn is not None: - View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data) - - return () - - -@register_node(display_name="Print Table") -class PrintTable: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "table": ("ANY_TABLE",), - } - } - - RETURN_TYPES = () - FUNCTION = "print_table" - - OUTPUT_NODE = True - DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display." - - _broadcast_table_fn = None - _current_node_id: str = "" - - def print_table(self, table: list) -> tuple: - if PrintTable._broadcast_table_fn is not None: - PrintTable._broadcast_table_fn(PrintTable._current_node_id, table) - return () - - -@register_node(display_name="Value Display") -class ValueDisplay: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("VALUE_SOURCE",), - "measurement": ("STRING", { - "default": "", - "choices_from_measure_input": "value", - "show_when_source_type": { - "value": ["MEASURE_TABLE"], - }, - }), - } - } - - RETURN_TYPES = ("FLOAT",) - RETURN_NAMES = ("value",) - FUNCTION = "display_value" - - DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged." - - _broadcast_value_fn = None - _current_node_id: str = "" - - def display_value(self, value, measurement: str = "") -> tuple: - unit = "" - if isinstance(value, MeasureTable): - row = _measurement_entry(value, measurement) - numeric = _measurement_value(value, measurement) - unit = row.get("unit", "") if isinstance(row.get("unit"), str) else "" - else: - numeric = float(value) - if ValueDisplay._broadcast_value_fn is not None: - ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit)) - return (numeric,) diff --git a/backend/nodes/draw_mask.py b/backend/nodes/draw_mask.py new file mode 100644 index 0000000..e39a888 --- /dev/null +++ b/backend/nodes/draw_mask.py @@ -0,0 +1,56 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, datafield_to_uint8, encode_preview +from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask + + +@register_node(display_name="Draw Mask") +class DrawMask: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}), + "invert": ("BOOLEAN", {"default": False}), + "clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}), + "mask_paths": ("STRING", {"default": "[]", "hidden": True}), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + + DESCRIPTION = ( + "Paint a binary mask directly over an image preview. " + "Pen size controls newly drawn strokes, the overlay lets you clear the mask, " + "and invert flips the final binary output." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple: + strokes = _parse_mask_strokes(mask_paths) + mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size) + if invert: + mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) + + if DrawMask._broadcast_overlay_fn is not None: + DrawMask._broadcast_overlay_fn( + DrawMask._current_node_id, + { + "kind": "mask_paint", + "section_title": "Mask", + "image": encode_preview(datafield_to_uint8(field, "gray")), + "image_width": field.xres, + "image_height": field.yres, + "invert": bool(invert), + }, + ) + + return (mask,) diff --git a/backend/nodes/edge_detect.py b/backend/nodes/edge_detect.py new file mode 100644 index 0000000..902deb3 --- /dev/null +++ b/backend/nodes/edge_detect.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Edge Detect") +class EdgeDetect: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "method": (["sobel", "prewitt", "laplacian", "log"],), + "sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("edges",) + FUNCTION = "process" + + DESCRIPTION = ( + "Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. " + "Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian." + ) + + def process(self, field: DataField, method: str, sigma: float) -> tuple: + from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace + data = field.data + + if method == "sobel": + sx = sobel(data, axis=1) + sy = sobel(data, axis=0) + result = np.hypot(sx, sy) + elif method == "prewitt": + px = prewitt(data, axis=1) + py = prewitt(data, axis=0) + result = np.hypot(px, py) + elif method == "laplacian": + result = laplace(data) + elif method == "log": + result = gaussian_laplace(data, sigma=float(sigma)) + else: + raise ValueError(f"Unknown edge detection method: {method}") + + return (field.replace(data=result),) diff --git a/backend/nodes/fft_2d.py b/backend/nodes/fft_2d.py new file mode 100644 index 0000000..f604b0e --- /dev/null +++ b/backend/nodes/fft_2d.py @@ -0,0 +1,115 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="2D FFT") +class FFT2D: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "windowing": (["hann", "hamming", "blackman", "none"],), + "level": (["mean", "plane", "none"],), + } + } + + RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD") + RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf") + FUNCTION = "process" + + DESCRIPTION = ( + "Compute the 2D FFT with optional windowing and mean/plane subtraction. " + "Outputs log magnitude, magnitude, phase, and PSDF as separate channels. " + "Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf." + ) + + def process(self, field: DataField, windowing: str, level: str) -> tuple: + data = field.data.copy() + yres, xres = data.shape + + if level == "mean": + data -= data.mean() + elif level == "plane": + yy, xx = np.mgrid[0:yres, 0:xres] + xx_f = xx.ravel().astype(np.float64) + yy_f = yy.ravel().astype(np.float64) + zz_f = data.ravel() + A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f]) + coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None) + plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy) + data -= plane + + if windowing != "none": + t_y = (np.arange(yres) + 0.5) / yres + t_x = (np.arange(xres) + 0.5) / xres + if windowing == "hann": + wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y) + wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x) + elif windowing == "hamming": + wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y) + wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x) + elif windowing == "blackman": + wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y) + wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x) + else: + wy = np.ones(yres) + wx = np.ones(xres) + data *= np.outer(wy, wx) + + F = np.fft.fftshift(np.fft.fft2(data)) + n = xres * yres + + magnitude = np.abs(F) + log_magnitude = np.log1p(magnitude) + phase = np.angle(F) + + dx = field.xreal / xres + dy = field.yreal / yres + psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2) + + spatial_freq_xreal = xres / field.xreal + spatial_freq_yreal = yres / field.yreal + angular_freq_xreal = 2.0 * np.pi * xres / field.xreal + angular_freq_yreal = 2.0 * np.pi * yres / field.yreal + + return ( + DataField( + data=log_magnitude, + xreal=spatial_freq_xreal, + yreal=spatial_freq_yreal, + si_unit_xy="1/m", + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ), + DataField( + data=magnitude, + xreal=spatial_freq_xreal, + yreal=spatial_freq_yreal, + si_unit_xy="1/m", + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ), + DataField( + data=phase, + xreal=spatial_freq_xreal, + yreal=spatial_freq_yreal, + si_unit_xy="1/m", + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ), + DataField( + data=psdf, + xreal=angular_freq_xreal, + yreal=angular_freq_yreal, + si_unit_xy="1/m", + si_unit_z=f"({field.si_unit_z})^2 m^2", + domain="frequency", + colormap=field.colormap, + ), + ) diff --git a/backend/nodes/fft_filter_1d.py b/backend/nodes/fft_filter_1d.py new file mode 100644 index 0000000..0817c9e --- /dev/null +++ b/backend/nodes/fft_filter_1d.py @@ -0,0 +1,62 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import LineData +from backend.nodes.helpers import _cached_1d_transfer + + +@register_node(display_name="1D FFT Filter") +class FFTFilter1D: + """Bandpass / lowpass / highpass / notch filtering of 1-D line profiles. + + Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth + transfer function with configurable order for a smooth roll-off. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "line": ("LINE",), + "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), + "cutoff": ("FLOAT", { + "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "cutoff_high": ("FLOAT", { + "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + } + } + + RETURN_TYPES = ("LINE",) + RETURN_NAMES = ("filtered",) + FUNCTION = "process" + + DESCRIPTION = ( + "Frequency-domain filtering of a 1-D line profile. " + "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " + "with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. " + "Equivalent to Gwyddion fft_filter_1d." + ) + + def process(self, line, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + z = np.asarray(line, dtype=np.float64).ravel() + n = len(z) + + Z = np.fft.rfft(z) + H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order)) + Z *= H + filtered = np.fft.irfft(Z, n=n) + + if isinstance(line, LineData): + return ( + LineData( + data=filtered, + x_axis=line.x_axis.copy() if line.x_axis is not None else None, + x_unit=line.x_unit, + y_unit=line.y_unit, + ), + ) + return (filtered,) diff --git a/backend/nodes/fft_filter_2d.py b/backend/nodes/fft_filter_2d.py new file mode 100644 index 0000000..da80d3d --- /dev/null +++ b/backend/nodes/fft_filter_2d.py @@ -0,0 +1,62 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField +from backend.nodes.helpers import _cached_2d_transfer + + +@register_node(display_name="2D FFT Filter") +class FFTFilter2D: + """Frequency-domain filtering of 2-D data fields (images). + + Equivalent to Gwyddion's fft_filter_2d module. Applies a radial + Butterworth transfer function in the frequency domain to remove or + isolate periodic features. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), + "cutoff": ("FLOAT", { + "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "cutoff_high": ("FLOAT", { + "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("filtered",) + FUNCTION = "process" + + DESCRIPTION = ( + "Frequency-domain filtering of a 2-D data field. " + "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " + "with a radial Butterworth roll-off. Cutoffs are fractions of the " + "Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or " + "bandpass/notch to isolate or remove periodic noise. " + "Equivalent to Gwyddion fft_filter_2d." + ) + + def process(self, field: DataField, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + data = field.data + yres, xres = data.shape + + mean_val = float(data.mean()) + centered = data - mean_val + + spectrum = np.fft.rfft2(centered) + transfer = _cached_2d_transfer( + yres, xres, filter_type, + float(cutoff), float(cutoff_high), int(order), + ) + result = np.fft.irfft2(spectrum * transfer, s=(yres, xres)) + result += mean_val + + return (field.replace(data=result),) diff --git a/backend/nodes/filters.py b/backend/nodes/filters.py deleted file mode 100644 index f4aa850..0000000 --- a/backend/nodes/filters.py +++ /dev/null @@ -1,332 +0,0 @@ -""" -Filter nodes — Gwyddion-equivalent image filters. - -Gwyddion equivalents: - GaussianFilter → gwy_data_field_filter_gaussian - MedianFilter → gwy_data_field_filter_median - EdgeDetect → gwy_data_field_filter_sobel / laplacian / log - FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles) - FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs) -""" - -from __future__ import annotations -from functools import lru_cache -import numpy as np -from backend.node_registry import register_node -from backend.data_types import DataField, LineData - - -# --------------------------------------------------------------------------- -# GaussianFilter -# --------------------------------------------------------------------------- - -@register_node(display_name="Gaussian Filter") -class GaussianFilter: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("filtered",) - FUNCTION = "process" - - DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian." - - def process(self, field: DataField, sigma: float) -> tuple: - from scipy.ndimage import gaussian_filter - data = gaussian_filter(field.data, sigma=float(sigma)) - return (field.replace(data=data),) - - -# --------------------------------------------------------------------------- -# MedianFilter -# --------------------------------------------------------------------------- - -@register_node(display_name="Median Filter") -class MedianFilter: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("filtered",) - FUNCTION = "process" - - DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median." - - def process(self, field: DataField, size: int) -> tuple: - from scipy.ndimage import median_filter - size = max(1, int(size)) - data = median_filter(field.data, size=size) - return (field.replace(data=data),) - - -# --------------------------------------------------------------------------- -# EdgeDetect -# --------------------------------------------------------------------------- - -@register_node(display_name="Edge Detect") -class EdgeDetect: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "method": (["sobel", "prewitt", "laplacian", "log"],), - "sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("edges",) - FUNCTION = "process" - - DESCRIPTION = ( - "Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. " - "Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian." - ) - - def process(self, field: DataField, method: str, sigma: float) -> tuple: - from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace - data = field.data - - if method == "sobel": - sx = sobel(data, axis=1) - sy = sobel(data, axis=0) - result = np.hypot(sx, sy) - elif method == "prewitt": - px = prewitt(data, axis=1) - py = prewitt(data, axis=0) - result = np.hypot(px, py) - elif method == "laplacian": - result = laplace(data) - elif method == "log": - result = gaussian_laplace(data, sigma=float(sigma)) - else: - raise ValueError(f"Unknown edge detection method: {method}") - - return (field.replace(data=result),) - - -# --------------------------------------------------------------------------- -# Butterworth transfer function helpers -# --------------------------------------------------------------------------- - -def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray: - """Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n)).""" - with np.errstate(divide="ignore", over="ignore"): - return 1.0 / (1.0 + (freq / cutoff) ** (2 * order)) - - -def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray: - """Butterworth highpass: H = 1 / (1 + (fc/f)^(2n)).""" - with np.errstate(divide="ignore", invalid="ignore"): - h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order)) - h = np.where(np.isfinite(h), h, 0.0) - return h - - -def _build_1d_transfer(n: int, filter_type: str, cutoff: float, - cutoff_high: float, order: int) -> np.ndarray: - """Build a 1-D transfer function for an FFT of length *n*. - - Frequencies are normalised so that 1.0 = Nyquist (fs/2). - The returned array has the same layout as np.fft.rfft output (length n//2+1). - """ - freq = np.linspace(0, 1, n // 2 + 1) - - if filter_type == "lowpass": - H = _butterworth_lp(freq, cutoff, order) - elif filter_type == "highpass": - H = _butterworth_hp(freq, cutoff, order) - elif filter_type == "bandpass": - H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order) - elif filter_type == "notch": - bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order) - H = 1.0 - bp - else: - H = np.ones_like(freq) - return H - - -@lru_cache(maxsize=64) -def _cached_1d_transfer(n: int, filter_type: str, cutoff: float, - cutoff_high: float, order: int) -> np.ndarray: - transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order) - transfer.setflags(write=False) - return transfer - - -@lru_cache(maxsize=32) -def _fft_radius_grid(yres: int, xres: int) -> np.ndarray: - fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0 - fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0 - radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0) - np.clip(radius, 0.0, 1.0, out=radius) - radius.setflags(write=False) - return radius - - -@lru_cache(maxsize=128) -def _cached_2d_transfer(yres: int, xres: int, filter_type: str, - cutoff: float, cutoff_high: float, order: int) -> np.ndarray: - radius = _fft_radius_grid(yres, xres) - - if filter_type == "lowpass": - transfer = _butterworth_lp(radius, cutoff, order) - elif filter_type == "highpass": - transfer = _butterworth_hp(radius, cutoff, order) - elif filter_type == "bandpass": - transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order) - elif filter_type == "notch": - band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order) - transfer = 1.0 - band - else: - transfer = np.ones_like(radius) - - transfer.setflags(write=False) - return transfer - - -# --------------------------------------------------------------------------- -# FFTFilter1D — frequency-domain filtering of LINE profiles -# --------------------------------------------------------------------------- - -@register_node(display_name="1D FFT Filter") -class FFTFilter1D: - """Bandpass / lowpass / highpass / notch filtering of 1-D line profiles. - - Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth - transfer function with configurable order for a smooth roll-off. - """ - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "line": ("LINE",), - "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), - "cutoff": ("FLOAT", { - "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "cutoff_high": ("FLOAT", { - "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), - } - } - - RETURN_TYPES = ("LINE",) - RETURN_NAMES = ("filtered",) - FUNCTION = "process" - - DESCRIPTION = ( - "Frequency-domain filtering of a 1-D line profile. " - "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " - "with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. " - "Equivalent to Gwyddion fft_filter_1d." - ) - - def process(self, line, filter_type: str, cutoff: float, - cutoff_high: float, order: int) -> tuple: - z = np.asarray(line, dtype=np.float64).ravel() - n = len(z) - - # Forward FFT (real-valued) - Z = np.fft.rfft(z) - - # Build and apply transfer function - H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order)) - Z *= H - - # Inverse FFT - filtered = np.fft.irfft(Z, n=n) - if isinstance(line, LineData): - return ( - LineData( - data=filtered, - x_axis=line.x_axis.copy() if line.x_axis is not None else None, - x_unit=line.x_unit, - y_unit=line.y_unit, - ), - ) - return (filtered,) - - -# --------------------------------------------------------------------------- -# FFTFilter2D — frequency-domain filtering of DATA_FIELDs -# --------------------------------------------------------------------------- - -@register_node(display_name="2D FFT Filter") -class FFTFilter2D: - """Frequency-domain filtering of 2-D data fields (images). - - Equivalent to Gwyddion's fft_filter_2d module. Applies a radial - Butterworth transfer function in the frequency domain to remove or - isolate periodic features. - """ - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), - "cutoff": ("FLOAT", { - "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "cutoff_high": ("FLOAT", { - "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("filtered",) - FUNCTION = "process" - - DESCRIPTION = ( - "Frequency-domain filtering of a 2-D data field. " - "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " - "with a radial Butterworth roll-off. Cutoffs are fractions of the " - "Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or " - "bandpass/notch to isolate or remove periodic noise. " - "Equivalent to Gwyddion fft_filter_2d." - ) - - def process(self, field: DataField, filter_type: str, cutoff: float, - cutoff_high: float, order: int) -> tuple: - data = field.data - yres, xres = data.shape - - # Subtract mean to avoid DC leakage artefacts. - mean_val = float(data.mean()) - centered = data - mean_val - - # Real-valued FFT keeps only the unique half-plane and avoids shift copies. - spectrum = np.fft.rfft2(centered) - transfer = _cached_2d_transfer( - yres, - xres, - filter_type, - float(cutoff), - float(cutoff_high), - int(order), - ) - result = np.fft.irfft2(spectrum * transfer, s=(yres, xres)) - - # Restore DC - result += mean_val - - return (field.replace(data=result),) diff --git a/backend/nodes/fix_zero.py b/backend/nodes/fix_zero.py new file mode 100644 index 0000000..f556147 --- /dev/null +++ b/backend/nodes/fix_zero.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Fix Zero") +class FixZero: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "method": (["min", "mean", "median"],), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("zeroed",) + FUNCTION = "process" + + DESCRIPTION = ( + "Shift data so that the minimum (or mean/median) is zero. " + "Equivalent to fix_zero in Gwyddion's level.c." + ) + + def process(self, field: DataField, method: str) -> tuple: + data = field.data.copy() + if method == "min": + data -= data.min() + elif method == "mean": + data -= data.mean() + elif method == "median": + data -= np.median(data) + else: + raise ValueError(f"Unknown method: {method}") + return (field.replace(data=data),) diff --git a/backend/nodes/folder.py b/backend/nodes/folder.py new file mode 100644 index 0000000..cd45137 --- /dev/null +++ b/backend/nodes/folder.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.nodes.helpers import list_folder_paths + + +@register_node(display_name="Folder") +class Folder: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}), + } + } + + RETURN_TYPES = ("DIRECTORY",) + RETURN_NAMES = ("directory",) + FUNCTION = "list_files" + + DESCRIPTION = ( + "Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. " + "Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans." + ) + + def list_files(self, folder: str) -> tuple: + entries = list_folder_paths(folder) + if not entries: + return tuple() + return tuple(item["path"] for item in entries) diff --git a/backend/nodes/font_node.py b/backend/nodes/font_node.py new file mode 100644 index 0000000..ae96f1f --- /dev/null +++ b/backend/nodes/font_node.py @@ -0,0 +1,36 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT, list_overlay_font_choices, normalize_font_spec + + +@register_node(display_name="Font") +class Font: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], { + "default": SYSTEM_DEFAULT_FONT, + }), + "font_file": ("FILE_PICKER", { + "default": "", + "show_when_widget_value": {"family": [CUSTOM_FILE_FONT]}, + }), + } + } + + RETURN_TYPES = ("FONT",) + RETURN_NAMES = ("font",) + FUNCTION = "build" + + DESCRIPTION = ( + "Build a reusable font spec for annotation overlays. Choose a discovered system font, " + "use the default fallback stack, or point to a custom font file." + ) + + def build(self, family: str, font_file: str = "") -> tuple: + if family == SYSTEM_DEFAULT_FONT: + return (None,) + if family == CUSTOM_FILE_FONT: + return (normalize_font_spec({"path": font_file}),) + return (normalize_font_spec({"family": family}),) diff --git a/backend/nodes/gaussian_filter.py b/backend/nodes/gaussian_filter.py new file mode 100644 index 0000000..b6d0be5 --- /dev/null +++ b/backend/nodes/gaussian_filter.py @@ -0,0 +1,26 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Gaussian Filter") +class GaussianFilter: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("filtered",) + FUNCTION = "process" + + DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian." + + def process(self, field: DataField, sigma: float) -> tuple: + from scipy.ndimage import gaussian_filter + data = gaussian_filter(field.data, sigma=float(sigma)) + return (field.replace(data=data),) diff --git a/backend/nodes/helpers.py b/backend/nodes/helpers.py new file mode 100644 index 0000000..756bc46 --- /dev/null +++ b/backend/nodes/helpers.py @@ -0,0 +1,873 @@ +""" +Shared helper functions for argonode nodes. +""" + +from __future__ import annotations +import json +from functools import lru_cache +from pathlib import Path +from typing import Callable + +import numpy as np + +from backend.runtime_paths import demo_dir, input_dir, output_dir + +# --------------------------------------------------------------------------- +# Scalar payload helpers (from display.py) +# --------------------------------------------------------------------------- + +def _scalar_payload(value: float, unit: str = "") -> dict: + payload = {"value": float(value)} + if isinstance(unit, str) and unit.strip(): + payload["unit"] = unit.strip() + return payload + + +# --------------------------------------------------------------------------- +# Measurement helpers (from display.py — used by ValueDisplay) +# --------------------------------------------------------------------------- + +def _measurement_names(table: list) -> list[str]: + names = [] + for row in table: + if not isinstance(row, dict): + continue + quantity = row.get("quantity") + if isinstance(quantity, str) and quantity and quantity not in names: + names.append(quantity) + return names + + +def _measurement_entry(table: list, selection: str) -> dict: + names = _measurement_names(table) + if not names: + raise ValueError("Measurement table has no selectable rows.") + + target = selection if selection in names else names[0] + for row in table: + if isinstance(row, dict) and row.get("quantity") == target: + return row + + raise ValueError(f"Measurement '{target}' was not found.") + + +def _measurement_value(table: list, selection: str) -> float: + row = _measurement_entry(table, selection) + value = row.get("value") + if isinstance(value, bool): + raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") + try: + numeric = float(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc + if np.isfinite(numeric): + return numeric + raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") + + +# --------------------------------------------------------------------------- +# SI formatting helpers (from display.py — used by Annotations) +# --------------------------------------------------------------------------- + +_SI_PREFIXES = [ + (1e24, "Y"), (1e21, "Z"), (1e18, "E"), (1e15, "P"), (1e12, "T"), + (1e9, "G"), (1e6, "M"), (1e3, "k"), (1.0, ""), (1e-3, "m"), + (1e-6, "u"), (1e-9, "n"), (1e-12, "p"), (1e-15, "f"), + (1e-18, "a"), (1e-21, "z"), (1e-24, "y"), +] +_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "\u03a9"} + + +def _format_numeric(value: float) -> str: + if not np.isfinite(value): + return str(value) + abs_value = abs(value) + if abs_value == 0: + return "0" + if abs_value >= 1e4 or abs_value < 1e-3: + return f"{value:.3e}" + return f"{value:.4g}" + + +def _format_with_unit(value: float, unit: str) -> str: + unit = (unit or "").strip() + if not unit: + return _format_numeric(value) + if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0: + abs_value = abs(value) + for scale, prefix in _SI_PREFIXES: + scaled = abs_value / scale + if 1 <= scaled < 1000: + signed = value / scale + return f"{_format_numeric(signed)} {prefix}{unit}" + return f"{_format_numeric(value)} {unit}" + + +def _nice_length(target: float) -> float: + if not np.isfinite(target) or target <= 0: + return 0.0 + exponent = np.floor(np.log10(target)) + base = 10.0 ** exponent + for step in (5.0, 2.0, 1.0): + candidate = step * base + if candidate <= target: + return candidate + return base + + +def _display_value_range(field) -> tuple[float, float, float]: + data = np.asarray(field.data, dtype=np.float64) + dmin = float(data.min()) + dmax = float(data.max()) + if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin: + return dmin, dmin, dmin + + offset = float(field.display_offset) + scale = float(field.display_scale) + if not np.isfinite(offset): + offset = 0.0 + if not np.isfinite(scale) or scale <= 0.0: + scale = 1.0 + + low_norm = float(np.clip(offset, 0.0, 1.0)) + high_norm = float(np.clip(offset + scale, 0.0, 1.0)) + if high_norm < low_norm: + low_norm, high_norm = high_norm, low_norm + mid_norm = 0.5 * (low_norm + high_norm) + + span = dmax - dmin + return ( + dmin + low_norm * span, + dmin + mid_norm * span, + dmin + high_norm * span, + ) + + +def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]): + from PIL import Image, ImageDraw, ImageFont + + size_px = max(8, int(round(size_px))) + try: + font = ImageFont.truetype("DejaVuSans.ttf", size_px) + probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0)) + probe_draw = ImageDraw.Draw(probe) + bbox = probe_draw.textbbox((0, 0), text, font=font) + width = max(1, bbox[2] - bbox[0]) + height = max(1, bbox[3] - bbox[1]) + text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + text_draw = ImageDraw.Draw(text_image) + text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255)) + return text_image + except Exception: + font = ImageFont.load_default() + probe = Image.new("L", (1, 1), 0) + probe_draw = ImageDraw.Draw(probe) + bbox = probe_draw.textbbox((0, 0), text, font=font) + width = max(1, bbox[2] - bbox[0]) + height = max(1, bbox[3] - bbox[1]) + mask = Image.new("L", (width, height), 0) + mask_draw = ImageDraw.Draw(mask) + mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255) + + scale = max(1.0, size_px / max(1, height)) + scaled_width = max(1, int(round(width * scale))) + scaled_height = max(1, int(round(height * scale))) + resampling = getattr(Image, "Resampling", Image) + scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR) + + text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0)) + text_image.putalpha(scaled_mask) + return text_image + + +# --------------------------------------------------------------------------- +# Markup helpers (from display.py — used by Markup) +# --------------------------------------------------------------------------- + +def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str: + if isinstance(color, str): + text = color.strip() + if len(text) == 4 and text.startswith("#"): + text = "#" + "".join(ch * 2 for ch in text[1:]) + if len(text) == 7 and text.startswith("#"): + try: + int(text[1:], 16) + return text.lower() + except ValueError: + pass + return default + + +def _parse_markup_shapes(raw_shapes) -> list[dict]: + if isinstance(raw_shapes, str): + try: + raw_shapes = json.loads(raw_shapes or "[]") + except json.JSONDecodeError: + raw_shapes = [] + + if not isinstance(raw_shapes, list): + return [] + + parsed = [] + for shape in raw_shapes: + if not isinstance(shape, dict): + continue + + kind = str(shape.get("kind", "")).strip().lower() + if kind not in {"line", "rectangle", "circle", "arrow"}: + continue + + try: + x1 = float(shape.get("x1")) + y1 = float(shape.get("y1")) + x2 = float(shape.get("x2")) + y2 = float(shape.get("y2")) + width = int(round(float(shape.get("width", 3)))) + except (TypeError, ValueError): + continue + + coords = [x1, y1, x2, y2] + if not all(np.isfinite(value) for value in coords): + continue + + parsed.append({ + "kind": kind, + "x1": float(np.clip(x1, 0.0, 1.0)), + "y1": float(np.clip(y1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y2": float(np.clip(y2, 0.0, 1.0)), + "width": max(1, min(128, width)), + "color": _normalize_markup_color(shape.get("color")), + }) + + return parsed + + +def _draw_arrow(draw, start, end, color, width): + dx = end[0] - start[0] + dy = end[1] - start[1] + length = float(np.hypot(dx, dy)) + if length <= 1e-6: + radius = max(1.0, width / 2.0) + draw.ellipse( + (start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius), + fill=color, + ) + return + + ux = dx / length + uy = dy / length + head_length = max(10.0, width * 4.0) + head_width = max(8.0, width * 3.0) + shaft_end = ( + end[0] - ux * head_length, + end[1] - uy * head_length, + ) + + draw.line((start, shaft_end), fill=color, width=width) + + px = -uy + py = ux + left = ( + shaft_end[0] + px * head_width / 2.0, + shaft_end[1] + py * head_width / 2.0, + ) + right = ( + shaft_end[0] - px * head_width / 2.0, + shaft_end[1] - py * head_width / 2.0, + ) + draw.polygon([end, left, right], fill=color) + + +def _render_markup_image(image, shapes): + from PIL import Image as PILImage, ImageDraw + from backend.data_types import image_to_uint8 + + base = image_to_uint8(image) + if base.ndim == 2: + base = np.repeat(base[:, :, np.newaxis], 3, axis=2) + + canvas = PILImage.fromarray(base.copy()) + draw = ImageDraw.Draw(canvas) + height, width = base.shape[:2] + + for shape in shapes: + x1 = float(shape["x1"]) * width + y1 = float(shape["y1"]) * height + x2 = float(shape["x2"]) * width + y2 = float(shape["y2"]) * height + color = str(shape["color"]) + stroke_width = int(shape["width"]) + kind = str(shape["kind"]) + + if kind == "line": + draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width) + elif kind == "rectangle": + draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width) + elif kind == "circle": + draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width) + elif kind == "arrow": + _draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width) + + return np.asarray(canvas, dtype=np.uint8) + + +# --------------------------------------------------------------------------- +# Mask helpers (from mask.py — used by multiple mask nodes) +# --------------------------------------------------------------------------- + +def _mask_overlay(field, mask): + from backend.data_types import datafield_to_uint8 + grey = datafield_to_uint8(field, "gray") + mask_bool = mask > 127 + if not np.any(mask_bool): + return grey + + overlay = grey.copy() + red = overlay[..., 0] + green = overlay[..., 1] + blue = overlay[..., 2] + + red_vals = red[mask_bool].astype(np.uint16) + green_vals = green[mask_bool].astype(np.uint16) + blue_vals = blue[mask_bool].astype(np.uint16) + red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100 + green[mask_bool] = ((green_vals * 55) + 50) // 100 + blue[mask_bool] = ((blue_vals * 55) + 50) // 100 + return overlay + + +@lru_cache(maxsize=128) +def _mask_structure(radius: int, shape: str): + radius = max(1, int(radius)) + if shape == "disk": + y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1] + struct = (x * x + y * y) <= radius * radius + else: + size = 2 * radius + 1 + struct = np.ones((size, size), dtype=bool) + struct.setflags(write=False) + return struct + + +def _clamp_fraction(value) -> float: + try: + numeric = float(value) + except (TypeError, ValueError): + return 0.0 + return max(0.0, min(1.0, numeric)) + + +def _parse_mask_strokes(mask_paths) -> list[dict]: + if isinstance(mask_paths, list): + raw_strokes = mask_paths + elif isinstance(mask_paths, str) and mask_paths.strip(): + try: + parsed = json.loads(mask_paths) + except json.JSONDecodeError: + return [] + raw_strokes = parsed if isinstance(parsed, list) else [] + else: + return [] + + strokes = [] + for stroke in raw_strokes: + if not isinstance(stroke, dict): + continue + raw_points = stroke.get("points") + if not isinstance(raw_points, list): + continue + + points = [] + for point in raw_points: + if not isinstance(point, dict): + continue + if "x" not in point or "y" not in point: + continue + points.append({ + "x": _clamp_fraction(point.get("x")), + "y": _clamp_fraction(point.get("y")), + }) + + if not points: + continue + + try: + size = max(1, int(round(float(stroke.get("size", 1))))) + except (TypeError, ValueError): + size = 1 + + strokes.append({ + "size": size, + "points": points, + }) + + return strokes + + +def _rasterize_mask(width, height, strokes, default_pen_size): + from PIL import Image as PILImage, ImageDraw + + width = max(1, int(width)) + height = max(1, int(height)) + default_pen_size = max(1, int(default_pen_size)) + + mask_image = PILImage.new("L", (width, height), 0) + draw = ImageDraw.Draw(mask_image) + + for stroke in strokes: + points = stroke.get("points") or [] + if not points: + continue + + size = stroke.get("size", default_pen_size) + try: + size = max(1, int(round(float(size)))) + except (TypeError, ValueError): + size = default_pen_size + + pixel_points = [] + for point in points: + px = int(round(_clamp_fraction(point.get("x")) * (width - 1))) + py = int(round(_clamp_fraction(point.get("y")) * (height - 1))) + pixel_points.append((px, py)) + + radius = max(0.5, size / 2.0) + + if len(pixel_points) == 1: + x, y = pixel_points[0] + draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) + continue + + draw.line(pixel_points, fill=255, width=size) + for x, y in pixel_points: + draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) + + return np.asarray(mask_image, dtype=np.uint8) + + +# --------------------------------------------------------------------------- +# Path / directory helpers (from io.py) +# --------------------------------------------------------------------------- + +DEMO_DIR = demo_dir() +INPUT_DIR = input_dir() +OUTPUT_DIR = output_dir() + +_MAX_SAVE_FIELDS = 8 + +_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", + ".gwy", ".sxm", ".ibw"} + +_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"} +_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"} +_ARRAY_EXTENSIONS = {".npy", ".npz"} +_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS + + +def _resolve_path(filepath: str): + path = Path(filepath) + if path.is_absolute(): + return path + candidate = INPUT_DIR / filepath + if candidate.exists(): + return candidate + candidate = DEMO_DIR / filepath + if candidate.exists(): + return candidate + return INPUT_DIR / filepath + + +def list_channels(filepath: str) -> list[dict]: + path = _resolve_path(filepath) + if not path.exists(): + return [{"name": "field", "type": "DATA_FIELD"}] + + ext = path.suffix.lower() + + if ext == ".gwy": + try: + import gwyfile + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if channels: + return [{"name": k, "type": "DATA_FIELD"} for k in channels] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + if ext == ".sxm": + try: + import nanonispy as nap + sxm = nap.read.Scan(str(path)) + if sxm.signals: + return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + if ext == ".ibw": + try: + from igor.binarywave import load as load_ibw + wave = load_ibw(str(path)) + raw = wave["wave"]["wData"] + labels = wave["wave"].get("labels", None) + if raw.ndim >= 3 and labels: + dim_idx = min(2, len(labels) - 1) + if dim_idx >= 0 and labels[dim_idx]: + decoded = [] + for lbl in labels[dim_idx]: + if lbl: + name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip() + if name: + decoded.append(name) + if decoded: + return [{"name": n, "type": "DATA_FIELD"} for n in decoded] + if raw.ndim >= 3 and raw.shape[2] > 1: + return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])] + except Exception: + pass + return [{"name": "field", "type": "DATA_FIELD"}] + + return [{"name": "field", "type": "DATA_FIELD"}] + + +def list_folder_paths(folderpath: str) -> list[dict]: + path = _resolve_path(folderpath) + if not path.exists() or not path.is_dir(): + return [] + + resolved_dir = str(path.resolve()) + results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}] + for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()): + if not entry.is_file() or entry.name.startswith("."): + continue + if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS: + continue + results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())}) + return results + + +def _list_demo_files() -> list[str]: + if not DEMO_DIR.exists(): + return [] + return sorted( + f.name for f in DEMO_DIR.iterdir() + if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + ) + + +# --------------------------------------------------------------------------- +# Butterworth / FFT helpers (from filters.py — used by FFTFilter1D, FFTFilter2D) +# --------------------------------------------------------------------------- + +def _butterworth_lp(freq, cutoff, order): + with np.errstate(divide="ignore", over="ignore"): + return 1.0 / (1.0 + (freq / cutoff) ** (2 * order)) + + +def _butterworth_hp(freq, cutoff, order): + with np.errstate(divide="ignore", invalid="ignore"): + h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order)) + h = np.where(np.isfinite(h), h, 0.0) + return h + + +def _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order): + freq = np.linspace(0, 1, n // 2 + 1) + + if filter_type == "lowpass": + H = _butterworth_lp(freq, cutoff, order) + elif filter_type == "highpass": + H = _butterworth_hp(freq, cutoff, order) + elif filter_type == "bandpass": + H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order) + elif filter_type == "notch": + bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order) + H = 1.0 - bp + else: + H = np.ones_like(freq) + return H + + +@lru_cache(maxsize=64) +def _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order): + transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order) + transfer.setflags(write=False) + return transfer + + +@lru_cache(maxsize=32) +def _fft_radius_grid(yres, xres): + fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0 + fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0 + radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0) + np.clip(radius, 0.0, 1.0, out=radius) + radius.setflags(write=False) + return radius + + +@lru_cache(maxsize=128) +def _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order): + radius = _fft_radius_grid(yres, xres) + + if filter_type == "lowpass": + transfer = _butterworth_lp(radius, cutoff, order) + elif filter_type == "highpass": + transfer = _butterworth_hp(radius, cutoff, order) + elif filter_type == "bandpass": + transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order) + elif filter_type == "notch": + band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order) + transfer = 1.0 - band + else: + transfer = np.ones_like(radius) + + transfer.setflags(write=False) + return transfer + + +# --------------------------------------------------------------------------- +# Cross-section and stats helpers (from analysis.py) +# --------------------------------------------------------------------------- + +def _extend_to_edges(x1, y1, x2, y2): + dx = x2 - x1 + dy = y2 - y1 + + t_candidates = [] + if abs(dx) > 1e-12: + for bx in (0.0, 1.0): + t = (bx - x1) / dx + y_at_t = y1 + t * dy + if -1e-9 <= y_at_t <= 1.0 + 1e-9: + t_candidates.append(t) + if abs(dy) > 1e-12: + for by in (0.0, 1.0): + t = (by - y1) / dy + x_at_t = x1 + t * dx + if -1e-9 <= x_at_t <= 1.0 + 1e-9: + t_candidates.append(t) + + if len(t_candidates) < 2: + return x1, y1, x2, y2 + + t_min = min(t_candidates) + t_max = max(t_candidates) + + return ( + np.clip(x1 + t_min * dx, 0, 1), + np.clip(y1 + t_min * dy, 0, 1), + np.clip(x1 + t_max * dx, 0, 1), + np.clip(y1 + t_max * dy, 0, 1), + ) + + +def _safe_rq(d): + return float(np.sqrt(np.mean(d * d))) + + +LINE_OPS: dict[str, tuple] = {} + + +def _line_op(name, unit=""): + def decorator(fn): + LINE_OPS[name] = (fn, unit) + return fn + return decorator + + +@_line_op("min") +def _op_min(z): + return float(z.min()) + +@_line_op("max") +def _op_max(z): + return float(z.max()) + +@_line_op("mean") +def _op_mean(z): + return float(z.mean()) + +@_line_op("median") +def _op_median(z): + return float(np.median(z)) + +@_line_op("sum") +def _op_sum(z): + return float(z.sum()) + +@_line_op("range") +def _op_range(z): + return float(z.max() - z.min()) + +@_line_op("length", unit="pts") +def _op_length(z): + return float(len(z)) + +@_line_op("rms") +def _op_rms(z): + return float(np.sqrt(np.mean(z * z))) + +@_line_op("Ra") +def _op_ra(z): + return float(np.mean(np.abs(z - z.mean()))) + +@_line_op("Rq") +def _op_rq(z): + d = z - z.mean() + return _safe_rq(d) + +@_line_op("Rsk") +def _op_rsk(z): + d = z - z.mean() + rq = _safe_rq(d) + return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0 + +@_line_op("Rku") +def _op_rku(z): + d = z - z.mean() + rq = _safe_rq(d) + return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0 + +@_line_op("Rp") +def _op_rp(z): + return float((z - z.mean()).max()) + +@_line_op("Rv") +def _op_rv(z): + return float(-(z - z.mean()).min()) + +@_line_op("Rt") +def _op_rt(z): + d = z - z.mean() + return float(d.max() - d.min()) + +@_line_op("Dq") +def _op_dq(z): + dz = np.diff(z) + return float(np.sqrt(np.mean(dz * dz))) + +@_line_op("Da") +def _op_da(z): + return float(np.mean(np.abs(np.diff(z)))) + + +TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = { + "min": lambda values: float(np.min(values)), + "max": lambda values: float(np.max(values)), + "avg": lambda values: float(np.mean(values)), + "mean": lambda values: float(np.mean(values)), + "median": lambda values: float(np.median(values)), + "sum": lambda values: float(np.sum(values)), + "range": lambda values: float(np.max(values) - np.min(values)), + "std": lambda values: float(np.std(values)), + "variance": lambda values: float(np.var(values)), + "count": lambda values: float(len(values)), +} + +ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = { + "min": lambda values: float(np.min(values)), + "max": lambda values: float(np.max(values)), + "avg": lambda values: float(np.mean(values)), + "mean": lambda values: float(np.mean(values)), + "median": lambda values: float(np.median(values)), + "sum": lambda values: float(np.sum(values)), + "range": lambda values: float(np.max(values) - np.min(values)), + "std": lambda values: float(np.std(values)), + "variance": lambda values: float(np.var(values)), + "rms": lambda values: float(np.sqrt(np.mean(values * values))), + "count": lambda values: float(values.size), +} + + +def _square_unit(unit: str) -> str: + unit = str(unit or "").strip() + if not unit: + return "" + if any(token in unit for token in ("^", "(", ")", "/", "*", " ")): + return f"({unit})^2" + return f"{unit}^2" + + +def _apply_scalar_unit(base_unit: str, operation: str) -> str: + unit = str(base_unit or "").strip() + if operation == "count": + return "count" + if not unit: + return "" + if operation == "variance": + return _square_unit(unit) + return unit + + +def _common_table_unit(table: list, column: str) -> str: + candidates = [] + seen = set() + unit_key = f"{column}_unit" + + for row in table: + if not isinstance(row, dict): + continue + unit = None + if unit_key in row and isinstance(row.get(unit_key), str): + unit = row.get(unit_key) + elif column == "value" and isinstance(row.get("unit"), str): + unit = row.get("unit") + if unit is None: + continue + unit = unit.strip() + if not unit or unit in seen: + continue + seen.add(unit) + candidates.append(unit) + + if len(candidates) == 1: + return candidates[0] + return "" + + +def extract_numeric_table_values(table: list, column: str) -> list[float]: + values = [] + for row in table: + if not isinstance(row, dict) or column not in row: + continue + value = row[column] + if isinstance(value, bool): + continue + try: + numeric = float(value) + except (TypeError, ValueError): + continue + if np.isfinite(numeric): + values.append(numeric) + return values + + +def resolve_table_column_name(table: list, column: str) -> str: + requested = str(column or "").strip() + if requested: + return requested + + if extract_numeric_table_values(table, "value"): + return "value" + + numeric_columns = [] + seen = set() + for row in table: + if not isinstance(row, dict): + continue + for key in row.keys(): + if key in seen: + continue + seen.add(key) + if extract_numeric_table_values(table, key): + numeric_columns.append(key) + + if len(numeric_columns) == 1: + return numeric_columns[0] + if not numeric_columns: + raise ValueError("Stats could not find any numeric columns in the input table.") + raise ValueError( + "Stats found multiple numeric columns; set the column name explicitly." + ) diff --git a/backend/nodes/histogram.py b/backend/nodes/histogram.py new file mode 100644 index 0000000..4888aa4 --- /dev/null +++ b/backend/nodes/histogram.py @@ -0,0 +1,100 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, MeasureTable + + +@register_node(display_name="Histogram") +class Histogram: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}), + "y_scale": (["linear", "log"],), + "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + } + } + + RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",) + RETURN_NAMES = ("measurements", "marker pair",) + FUNCTION = "process" + + DESCRIPTION = ( + "Compute the height distribution histogram (DH). " + "Use log scale to reveal small peaks next to a dominant background. " + "Outputs marker measurements while showing the histogram interactively in-node. " + "Equivalent to gwy_data_field_dh." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, + field: DataField, + n_bins: int, + y_scale: str = "linear", + x1: float = 0.25, + y1: float = 0.5, + x2: float = 0.75, + y2: float = 0.5, + ) -> tuple: + raw_counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins)) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + counts = raw_counts.astype(np.float64) + if y_scale == "log": + counts = np.log10(1.0 + counts) + + x1 = float(np.clip(x1, 0.0, 1.0)) + x2 = float(np.clip(x2, 0.0, 0.0 + 1.0)) + + xmin = float(np.min(bin_centers)) if len(bin_centers) else 0.0 + xmax = float(np.max(bin_centers)) if len(bin_centers) else 1.0 + + def x_frac_to_idx(frac): + if len(bin_centers) <= 1: + return 0 + if xmax == xmin: + return 0 + target_x = xmin + frac * (xmax - xmin) + return int(np.argmin(np.abs(bin_centers - target_x))) + + idx_a = x_frac_to_idx(x1) + idx_b = x_frac_to_idx(x2) + xa = float(bin_centers[idx_a]) if len(bin_centers) else 0.0 + xb = float(bin_centers[idx_b]) if len(bin_centers) else 0.0 + ya = float(counts[idx_a]) if len(counts) else 0.0 + yb = float(counts[idx_b]) if len(counts) else 0.0 + count_unit = "count" if y_scale == "linear" else "log10(1+count)" + + if Histogram._broadcast_overlay_fn is not None: + Histogram._broadcast_overlay_fn( + Histogram._current_node_id, + { + "kind": "line_plot", + "section_title": "Histogram", + "line": counts.tolist(), + "x_axis": bin_centers.astype(np.float64).tolist(), + "x1": float(np.clip(x1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y1": float(y1), + "y2": float(y2), + "a_locked": False, + "b_locked": False, + }, + ) + + table = MeasureTable([ + {"quantity": "A position", "value": xa, "unit": field.si_unit_z}, + {"quantity": "A count", "value": ya, "unit": count_unit}, + {"quantity": "B position", "value": xb, "unit": field.si_unit_z}, + {"quantity": "B count", "value": yb, "unit": count_unit}, + {"quantity": "delta X", "value": xb - xa, "unit": field.si_unit_z}, + {"quantity": "delta Y", "value": yb - ya, "unit": count_unit}, + ]) + return (table, ((x1, y1), (x2, y2))) diff --git a/backend/nodes/image.py b/backend/nodes/image.py new file mode 100644 index 0000000..a1d7b95 --- /dev/null +++ b/backend/nodes/image.py @@ -0,0 +1,215 @@ +from __future__ import annotations +import numpy as np +from pathlib import Path + +from backend.node_registry import register_node +from backend.data_types import COLORMAPS, DataField, resolve_colormap_input +from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS + + +@register_node(display_name="Image") +class Image: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}), + "colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + "path": ("FILE_PATH", {"label": "path"}), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("field",) + FUNCTION = "load" + + DESCRIPTION = ( + "Load any supported file. " + "SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; " + "each channel gets its own output. " + "Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields." + ) + + _broadcast_warning_fn = None + _current_node_id = None + + def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None): + selected_path = str(path).strip() if path is not None else str(filename).strip() + if not selected_path: + raise ValueError("No file selected — use Browse to pick a file.") + path_obj = _resolve_path(selected_path) + if not path_obj.exists(): + raise FileNotFoundError(f"File not found: {path_obj}") + if path_obj.is_dir(): + raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}") + + ext = path_obj.suffix.lower() + resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis") + + if ext in _SPM_EXTENSIONS: + fields = self._load_spm_all(path_obj, ext) + for f in fields: + f.colormap = resolved_colormap + return tuple(fields) + + field = self._load_image_or_array(path_obj, ext) + field.colormap = resolved_colormap + self._send_warning("Uncalibrated data — no physical dimensions.") + return (field,) + + def _send_warning(self, message: str): + fn = Image._broadcast_warning_fn + nid = Image._current_node_id + if fn and nid: + fn(nid, message) + + def _load_spm_all(self, path: Path, ext: str) -> list[DataField]: + if ext == ".gwy": + return self._load_gwy_all(path) + elif ext == ".sxm": + return self._load_sxm_all(path) + elif ext == ".ibw": + return self._load_ibw_all(path) + else: + raise ValueError(f"Unsupported SPM format: {ext}") + + def _load_gwy_all(self, path: Path) -> list[DataField]: + try: + import gwyfile + except ImportError: + raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") + + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if not channels: + raise ValueError(f"No data channels found in {path.name}") + + fields = [] + for ch in channels.values(): + data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) + fields.append(DataField( + data=data, + xreal=float(ch.xreal), + yreal=float(ch.yreal), + xoff=float(getattr(ch, "xoff", 0.0)), + yoff=float(getattr(ch, "yoff", 0.0)), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + def _load_sxm_all(self, path: Path) -> list[DataField]: + try: + import nanonispy as nap + except ImportError: + raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") + + sxm = nap.read.Scan(str(path)) + signals = sxm.signals + if not signals: + raise ValueError(f"No signals found in {path.name}") + + header = sxm.header + scan_range = header.get("scan_range", [1e-6, 1e-6]) + + fields = [] + for sig in signals.values(): + data = sig.get("forward", list(sig.values())[0]) + data = np.asarray(data, dtype=np.float64) + if data.ndim != 2: + data = data.reshape(data.shape[-2], data.shape[-1]) + fields.append(DataField( + data=data, + xreal=float(scan_range[0]), + yreal=float(scan_range[1]), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + def _load_ibw_all(self, path: Path) -> list[DataField]: + try: + from igor.binarywave import load as load_ibw + except ImportError: + raise ImportError("Install 'igor' package to load .ibw files: pip install igor") + + wave = load_ibw(str(path)) + wdata = wave["wave"] + header = wdata["wave_header"] + raw = wdata["wData"] + + n_channels = raw.shape[2] if raw.ndim >= 3 else 1 + + sfA = header.get("sfA", None) + + def _decode_unit(raw_unit): + if raw_unit is None: + return "m" + if isinstance(raw_unit, bytes): + return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + if isinstance(raw_unit, np.ndarray): + return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + return str(raw_unit).strip() or "m" + + dim_units_raw = header.get("dimUnits", None) + data_units_raw = header.get("dataUnits", None) + + if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2: + si_unit_xy = _decode_unit(dim_units_raw[0]) + elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0: + si_unit_xy = _decode_unit(dim_units_raw[0]) + else: + si_unit_xy = _decode_unit(dim_units_raw) + + si_unit_z = _decode_unit(data_units_raw) + + fields = [] + for ch_idx in range(n_channels): + if raw.ndim >= 3: + ch_data = raw[:, :, ch_idx] + elif raw.ndim == 1: + ch_data = raw.reshape(-1, 1) + else: + ch_data = raw + + data = np.flipud(ch_data.T).astype(np.float64) + yres, xres = data.shape + + if sfA is not None and len(sfA) >= 2: + xreal = abs(float(sfA[0]) * xres) or 1e-6 + yreal = abs(float(sfA[1]) * yres) or 1e-6 + else: + hsA = header.get("hsA", 0.0) + xreal = abs(float(hsA) * xres) or 1e-6 + yreal = xreal * (yres / xres) if xres else 1e-6 + + fields.append(DataField( + data=data, xreal=xreal, yreal=yreal, + si_unit_xy=si_unit_xy, si_unit_z=si_unit_z, + )) + + return fields + + def _load_image_or_array(self, path: Path, ext: str) -> DataField: + if ext == ".npy": + arr = np.load(str(path)).astype(np.float64) + elif ext == ".npz": + npz = np.load(str(path)) + key = list(npz.files)[0] + arr = npz[key].astype(np.float64) + else: + from PIL import Image as PILImage + img = PILImage.open(str(path)) + arr = np.array(img) + if arr.dtype != np.uint8: + arr = arr.astype(np.float64) + + if arr.ndim == 3: + gray = np.mean(arr.astype(np.float64), axis=2) + else: + gray = arr.astype(np.float64) + + return DataField(data=gray) diff --git a/backend/nodes/image_demo.py b/backend/nodes/image_demo.py new file mode 100644 index 0000000..be7d1d3 --- /dev/null +++ b/backend/nodes/image_demo.py @@ -0,0 +1,37 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.data_types import COLORMAPS +from backend.nodes.helpers import DEMO_DIR, _list_demo_files + + +@register_node(display_name="Image (Demo)") +class ImageDemo: + @classmethod + def INPUT_TYPES(cls): + choices = _list_demo_files() or ["(no demo files found)"] + return { + "required": { + "name": (choices,), + "colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("field",) + FUNCTION = "load" + + DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." + + _broadcast_warning_fn = None + _current_node_id = None + + def load(self, name: str = "", colormap: str = "viridis", colormap_map=None): + from backend.nodes.image import Image + loader = Image() + demo_path = DEMO_DIR / name + if not demo_path.exists(): + raise FileNotFoundError(f"Demo file not found: {name}") + return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map) diff --git a/backend/nodes/inverse_fft_2d.py b/backend/nodes/inverse_fft_2d.py new file mode 100644 index 0000000..2764d95 --- /dev/null +++ b/backend/nodes/inverse_fft_2d.py @@ -0,0 +1,103 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Inverse 2D FFT") +class InverseFFT2D: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "spectrum": ("DATA_FIELD",), + "representation": (["magnitude", "log_magnitude", "psdf"],), + }, + "optional": { + "phase": ("DATA_FIELD",), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("image",) + FUNCTION = "process" + + DESCRIPTION = ( + "Reconstruct a spatial-domain image from a 2D frequency spectrum. " + "For exact reconstruction, connect magnitude/phase (or log magnitude/phase, " + "or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed." + ) + + def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple: + if spectrum.domain != "frequency": + raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.") + + if phase is not None: + if phase.data.shape != spectrum.data.shape: + raise ValueError("Phase input must have the same shape as the spectrum.") + if phase.domain != "frequency": + raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.") + + amplitude = self._resolve_amplitude(spectrum, representation) + phase_data = phase.data if phase is not None else np.zeros_like(amplitude) + F = amplitude * np.exp(1j * phase_data) + + spatial = np.fft.ifft2(np.fft.ifftshift(F)).real + xreal, yreal = self._recover_spatial_extent(spectrum, representation) + z_unit = self._recover_z_unit(spectrum, representation, phase) + + out_field = DataField( + data=spatial, + xreal=xreal, + yreal=yreal, + si_unit_xy="m", + si_unit_z=z_unit, + domain="spatial", + colormap=spectrum.colormap, + ) + return (out_field,) + + def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray: + data = np.asarray(spectrum.data, dtype=np.float64) + + if representation == "magnitude": + return np.clip(data, 0.0, None) + if representation == "log_magnitude": + return np.expm1(data) + if representation == "psdf": + xreal, yreal = self._recover_spatial_extent(spectrum, representation) + n = spectrum.xres * spectrum.yres + dx = xreal / spectrum.xres + dy = yreal / spectrum.yres + scale = n * 4.0 * np.pi ** 2 / (dx * dy) + return np.sqrt(np.clip(data, 0.0, None) * scale) + + raise ValueError(f"Unsupported spectrum representation: {representation}") + + def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]: + if representation == "psdf": + xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal + yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal + else: + xreal = spectrum.xres / spectrum.xreal + yreal = spectrum.yres / spectrum.yreal + return float(xreal), float(yreal) + + def _recover_z_unit( + self, + spectrum: DataField, + representation: str, + phase: DataField | None, + ) -> str: + if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip(): + return phase.si_unit_z + + if representation != "psdf": + return spectrum.si_unit_z + + unit = str(spectrum.si_unit_z or "").strip() + if unit.startswith("(") and ")^2 m^2" in unit: + return unit.split(")^2 m^2", 1)[0][1:] + if unit.endswith("^2 m^2"): + return unit[:-6].removesuffix("^2").strip() + return "" diff --git a/backend/nodes/io.py b/backend/nodes/io.py deleted file mode 100644 index e7afd46..0000000 --- a/backend/nodes/io.py +++ /dev/null @@ -1,721 +0,0 @@ -""" -I/O nodes: load and save images and SPM data. -""" - -from __future__ import annotations -import os -import re -import numpy as np -from pathlib import Path - -from backend.node_registry import register_node -from backend.data_types import COLORMAPS, DataField, encode_preview, image_to_uint8, resolve_colormap_input -from backend.runtime_paths import demo_dir, input_dir, output_dir - -# Resolved at server startup so nodes know where to look -DEMO_DIR = demo_dir() -INPUT_DIR = input_dir() -OUTPUT_DIR = output_dir() - -_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", - ".gwy", ".sxm", ".ibw"} - -_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"} -_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"} -_ARRAY_EXTENSIONS = {".npy", ".npz"} -_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS - - -# --------------------------------------------------------------------------- -# Channel listing helper (used by the /channels endpoint) -# --------------------------------------------------------------------------- - -def _resolve_path(filepath: str) -> Path: - path = Path(filepath) - if path.is_absolute(): - return path - # Try input dir first, then demo dir - candidate = INPUT_DIR / filepath - if candidate.exists(): - return candidate - candidate = DEMO_DIR / filepath - if candidate.exists(): - return candidate - # Fall back to input dir (will trigger FileNotFoundError later) - return INPUT_DIR / filepath - - -def list_channels(filepath: str) -> list[dict]: - """Return available channel info for a file. - - Returns a list of {"name": str, "type": "DATA_FIELD"} dicts. - For SPM formats this inspects the file header. - For images / arrays, returns a single unnamed channel. - """ - path = _resolve_path(filepath) - if not path.exists(): - return [{"name": "field", "type": "DATA_FIELD"}] - - ext = path.suffix.lower() - - if ext == ".gwy": - try: - import gwyfile - obj = gwyfile.load(str(path)) - channels = gwyfile.util.get_datafields(obj) - if channels: - return [{"name": k, "type": "DATA_FIELD"} for k in channels] - except Exception: - pass - return [{"name": "field", "type": "DATA_FIELD"}] - - if ext == ".sxm": - try: - import nanonispy as nap - sxm = nap.read.Scan(str(path)) - if sxm.signals: - return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals] - except Exception: - pass - return [{"name": "field", "type": "DATA_FIELD"}] - - if ext == ".ibw": - try: - from igor.binarywave import load as load_ibw - wave = load_ibw(str(path)) - raw = wave["wave"]["wData"] - labels = wave["wave"].get("labels", None) - if raw.ndim >= 3 and labels: - dim_idx = min(2, len(labels) - 1) - if dim_idx >= 0 and labels[dim_idx]: - decoded = [] - for lbl in labels[dim_idx]: - if lbl: - name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip() - if name: - decoded.append(name) - if decoded: - return [{"name": n, "type": "DATA_FIELD"} for n in decoded] - # Multi-channel without labels — use numeric names - if raw.ndim >= 3 and raw.shape[2] > 1: - return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])] - except Exception: - pass - return [{"name": "field", "type": "DATA_FIELD"}] - - # Image or array — single channel - return [{"name": "field", "type": "DATA_FIELD"}] - - -def list_folder_paths(folderpath: str) -> list[dict]: - """Return a folder DIRECTORY plus compatible image/array/SPM FILE_PATH outputs.""" - path = _resolve_path(folderpath) - if not path.exists() or not path.is_dir(): - return [] - - resolved_dir = str(path.resolve()) - results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}] - for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()): - if not entry.is_file() or entry.name.startswith("."): - continue - if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS: - continue - results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())}) - return results - - -# --------------------------------------------------------------------------- -# Image (unified loader — replaces LoadImage + LoadSPM) -# --------------------------------------------------------------------------- - -@register_node(display_name="Image") -class Image: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}), - "colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), - }, - "optional": { - "colormap_map": ("COLORMAP", {"label": "colormap"}), - "path": ("FILE_PATH", {"label": "path"}), - }, - } - - # Default outputs — overridden dynamically by the frontend for multi-channel files - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("field",) - FUNCTION = "load" - - DESCRIPTION = ( - "Load any supported file. " - "SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; " - "each channel gets its own output. " - "Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields." - ) - - # Set by execution engine for warning broadcast - _broadcast_warning_fn = None - _current_node_id = None - - def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None): - selected_path = str(path).strip() if path is not None else str(filename).strip() - if not selected_path: - raise ValueError("No file selected — use Browse to pick a file.") - path_obj = _resolve_path(selected_path) - if not path_obj.exists(): - raise FileNotFoundError(f"File not found: {path_obj}") - if path_obj.is_dir(): - raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}") - - ext = path_obj.suffix.lower() - resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis") - - if ext in _SPM_EXTENSIONS: - fields = self._load_spm_all(path_obj, ext) - for f in fields: - f.colormap = resolved_colormap - return tuple(fields) - - # Image or array — uncalibrated, single output - field = self._load_image_or_array(path_obj, ext) - field.colormap = resolved_colormap - self._send_warning("Uncalibrated data — no physical dimensions.") - return (field,) - - def _send_warning(self, message: str): - fn = Image._broadcast_warning_fn - nid = Image._current_node_id - if fn and nid: - fn(nid, message) - - # -- SPM: load all channels --------------------------------------------- - - def _load_spm_all(self, path: Path, ext: str) -> list[DataField]: - if ext == ".gwy": - return self._load_gwy_all(path) - elif ext == ".sxm": - return self._load_sxm_all(path) - elif ext == ".ibw": - return self._load_ibw_all(path) - else: - raise ValueError(f"Unsupported SPM format: {ext}") - - # -- GWY ---------------------------------------------------------------- - - def _load_gwy_all(self, path: Path) -> list[DataField]: - try: - import gwyfile - except ImportError: - raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") - - obj = gwyfile.load(str(path)) - channels = gwyfile.util.get_datafields(obj) - if not channels: - raise ValueError(f"No data channels found in {path.name}") - - fields = [] - for ch in channels.values(): - data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) - fields.append(DataField( - data=data, - xreal=float(ch.xreal), - yreal=float(ch.yreal), - xoff=float(getattr(ch, "xoff", 0.0)), - yoff=float(getattr(ch, "yoff", 0.0)), - si_unit_xy="m", - si_unit_z="m", - )) - return fields - - # -- SXM ---------------------------------------------------------------- - - def _load_sxm_all(self, path: Path) -> list[DataField]: - try: - import nanonispy as nap - except ImportError: - raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") - - sxm = nap.read.Scan(str(path)) - signals = sxm.signals - if not signals: - raise ValueError(f"No signals found in {path.name}") - - header = sxm.header - scan_range = header.get("scan_range", [1e-6, 1e-6]) - - fields = [] - for sig in signals.values(): - data = sig.get("forward", list(sig.values())[0]) - data = np.asarray(data, dtype=np.float64) - if data.ndim != 2: - data = data.reshape(data.shape[-2], data.shape[-1]) - fields.append(DataField( - data=data, - xreal=float(scan_range[0]), - yreal=float(scan_range[1]), - si_unit_xy="m", - si_unit_z="m", - )) - return fields - - # -- IBW ---------------------------------------------------------------- - - def _load_ibw_all(self, path: Path) -> list[DataField]: - try: - from igor.binarywave import load as load_ibw - except ImportError: - raise ImportError("Install 'igor' package to load .ibw files: pip install igor") - - wave = load_ibw(str(path)) - wdata = wave["wave"] - header = wdata["wave_header"] - raw = wdata["wData"] - - n_channels = raw.shape[2] if raw.ndim >= 3 else 1 - - # Physical scaling - sfA = header.get("sfA", None) - - def _decode_unit(raw_unit): - if raw_unit is None: - return "m" - if isinstance(raw_unit, bytes): - return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" - if isinstance(raw_unit, np.ndarray): - return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" - return str(raw_unit).strip() or "m" - - dim_units_raw = header.get("dimUnits", None) - data_units_raw = header.get("dataUnits", None) - - if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2: - si_unit_xy = _decode_unit(dim_units_raw[0]) - elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0: - si_unit_xy = _decode_unit(dim_units_raw[0]) - else: - si_unit_xy = _decode_unit(dim_units_raw) - - si_unit_z = _decode_unit(data_units_raw) - - fields = [] - for ch_idx in range(n_channels): - if raw.ndim >= 3: - ch_data = raw[:, :, ch_idx] - elif raw.ndim == 1: - ch_data = raw.reshape(-1, 1) - else: - ch_data = raw - - # Transpose from (xres, yres) Igor order to (yres, xres) DataField order, - # then flip vertically to match gwyddion - data = np.flipud(ch_data.T).astype(np.float64) - yres, xres = data.shape - - if sfA is not None and len(sfA) >= 2: - xreal = abs(float(sfA[0]) * xres) or 1e-6 - yreal = abs(float(sfA[1]) * yres) or 1e-6 - else: - hsA = header.get("hsA", 0.0) - xreal = abs(float(hsA) * xres) or 1e-6 - yreal = xreal * (yres / xres) if xres else 1e-6 - - fields.append(DataField( - data=data, xreal=xreal, yreal=yreal, - si_unit_xy=si_unit_xy, si_unit_z=si_unit_z, - )) - - return fields - - # -- Image / array (uncalibrated) -------------------------------------- - - def _load_image_or_array(self, path: Path, ext: str) -> DataField: - if ext == ".npy": - arr = np.load(str(path)).astype(np.float64) - elif ext == ".npz": - npz = np.load(str(path)) - key = list(npz.files)[0] - arr = npz[key].astype(np.float64) - else: - from PIL import Image - img = Image.open(str(path)) - arr = np.array(img) - if arr.dtype != np.uint8: - arr = arr.astype(np.float64) - - if arr.ndim == 3: - gray = np.mean(arr.astype(np.float64), axis=2) - else: - gray = arr.astype(np.float64) - - return DataField(data=gray) - - -# --------------------------------------------------------------------------- -# ImageDemo -# --------------------------------------------------------------------------- - -def _list_demo_files() -> list[str]: - """Return sorted list of demo filenames available in the demo/ directory.""" - if not DEMO_DIR.exists(): - return [] - return sorted( - f.name for f in DEMO_DIR.iterdir() - if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS - ) - - -@register_node(display_name="Image (Demo)") -class ImageDemo: - @classmethod - def INPUT_TYPES(cls): - choices = _list_demo_files() or ["(no demo files found)"] - return { - "required": { - "name": (choices,), - "colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), - }, - "optional": { - "colormap_map": ("COLORMAP", {"label": "colormap"}), - }, - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("field",) - FUNCTION = "load" - - DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." - - def load(self, name: str = "", colormap: str = "viridis", colormap_map=None): - loader = Image() - demo_path = DEMO_DIR / name - if not demo_path.exists(): - raise FileNotFoundError(f"Demo file not found: {name}") - return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map) - - -@register_node(display_name="Folder") -class Folder: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}), - } - } - - RETURN_TYPES = ("DIRECTORY",) - RETURN_NAMES = ("directory",) - FUNCTION = "list_files" - - DESCRIPTION = ( - "Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. " - "Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans." - ) - - def list_files(self, folder: str) -> tuple: - entries = list_folder_paths(folder) - if not entries: - return tuple() - return tuple(item["path"] for item in entries) - - -# --------------------------------------------------------------------------- -# Coordinate -# --------------------------------------------------------------------------- - -@register_node(display_name="Coordinate") -class Coordinate: - """Provide a fractional (x, y) point for use with Cross Section or other nodes.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - "y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - - RETURN_TYPES = ("COORD",) - RETURN_NAMES = ("point",) - FUNCTION = "process" - - DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]." - - def process(self, x: float, y: float) -> tuple: - return ((float(x), float(y)),) - - -@register_node(display_name="Coordinate Pair") -class CoordinatePair: - """Provide a pair of Coordinates, for drawing lines between markers, etc.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "a": ("COORD",), - "b": ("COORD",), - } - } - - RETURN_TYPES = ("COORDPAIR",) - RETURN_NAMES = ("coord pair",) - FUNCTION = "process" - - DESCRIPTION = "Output a pair of coordinates." - - def process(self, a: tuple, b: tuple) -> tuple: - return ((a, b),) - - -# --------------------------------------------------------------------------- -# Number -# --------------------------------------------------------------------------- - -@register_node(display_name="Number") -class Number: - """Provide a fixed scalar value that can feed FLOAT or INT widget sockets.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("FLOAT", {"default": 0.0, "step": 0.01}), - } - } - - RETURN_TYPES = ("FLOAT",) - RETURN_NAMES = ("value",) - FUNCTION = "process" - - DESCRIPTION = ( - "Output a fixed numeric value. " - "When connected to FLOAT inputs the exact value is used; " - "INT inputs round to the nearest integer at execution time." - ) - - def process(self, value: float) -> tuple: - return (float(value),) - - -# --------------------------------------------------------------------------- -# RangeSlider -# --------------------------------------------------------------------------- - -@register_node(display_name="Float Slider") -class RangeSlider: - """Interactive float control node with min/max bounds and a slider value.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "min_value": ("FLOAT", {"default": 0.0, "step": 0.01}), - "max_value": ("FLOAT", {"default": 1.0, "step": 0.01}), - "value": ("FLOAT", { - "default": 0.5, - "step": 0.01, - "slider": True, - "min_widget": "min_value", - "max_widget": "max_value", - }), - } - } - - RETURN_TYPES = ("FLOAT",) - RETURN_NAMES = ("value",) - FUNCTION = "process" - - DESCRIPTION = ( - "Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value." - ) - - def process(self, min_value: float, max_value: float, value: float) -> tuple: - lo = min(float(min_value), float(max_value)) - hi = max(float(min_value), float(max_value)) - if hi == lo: - return (lo,) - return (float(np.clip(float(value), lo, hi)),) - - -# --------------------------------------------------------------------------- -# SaveImage -# --------------------------------------------------------------------------- - -_MAX_SAVE_FIELDS = 8 - -@register_node(display_name="Save Layers") -class SaveImage: - @classmethod - def INPUT_TYPES(cls): - optional = { - "directory": ("DIRECTORY", {"label": "directory"}), - } - for i in range(_MAX_SAVE_FIELDS): - optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"}) - optional[f"layer_name_{i}"] = ("STRING", { - "default": "", - "placeholder": "name", - "show_when_input_visible": f"field_{i}", - "inline_with_input": f"field_{i}", - "hide_label": True, - }) - return { - "required": { - "filename": ("STRING", { - "default": "", - "placeholder": "filename", - "placement": "top", - }), - "directory_path": ("FOLDER_PICKER", { - "default": "", - "label": "directory", - "placement": "top", - "hide_when_input_connected": "directory", - "top_socket_input": "directory", - }), - "format": (["TIFF", "NPZ"],), - }, - "optional": optional, - } - - RETURN_TYPES = () - FUNCTION = "save" - - OUTPUT_NODE = True - MANUAL_TRIGGER = True - DESCRIPTION = ( - "Save one or more layers to a single file. " - "Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. " - "Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. " - "A new slot appears as each one is filled, with a matching per-layer name field. " - "TIFF writes multi-page data and stores layer names as page descriptions; " - "NPZ writes named arrays using those layer names as keys. " - "Click Save to write (does not auto-run)." - ) - - _broadcast_warning_fn = None - _current_node_id = None - - def save( - self, - filename: str, - directory_path: str = "", - format: str = "TIFF", - directory: str | None = None, - **kwargs, - ): - layers = [] - layer_names = [] - for i in range(_MAX_SAVE_FIELDS): - layer = kwargs.get(f"field_{i}") - if layer is not None: - layers.append(layer) - layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i)) - - if not layers: - raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.") - - path = self._resolve_save_path(filename, format, directory, directory_path) - - if format == "TIFF": - self._save_tiff(path, layers, layer_names) - else: - self._save_npz(path, layers, layer_names) - - self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}") - return () - - def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]): - import tifffile - - with tifffile.TiffWriter(str(path)) as tif: - for layer, layer_name in zip(layers, layer_names): - tif.write(self._layer_array_for_tiff(layer), description=layer_name) - - def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]): - arrays = {} - used_keys = set() - for i, (layer, layer_name) in enumerate(zip(layers, layer_names)): - arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer) - np.savez(str(path), **arrays) - - def _resolve_layer_name(self, raw_name: object, index: int) -> str: - text = str(raw_name).strip() if raw_name is not None else "" - return text or f"layer_{index}" - - def _resolve_save_path( - self, - filename: str, - format: str, - directory: str | None, - directory_path: str = "", - ) -> Path: - ext = ".tiff" if format == "TIFF" else ".npz" - raw_filename = str(filename).strip() if filename is not None else "" - raw_directory = str(directory).strip() if directory is not None else "" - if not raw_directory: - raw_directory = str(directory_path).strip() if directory_path is not None else "" - - if raw_directory: - dir_path = Path(raw_directory).expanduser() - if dir_path.exists() and not dir_path.is_dir(): - raise ValueError("Directory input expects a folder path, not a file path.") - if not dir_path.exists(): - if dir_path.suffix: - raise ValueError("Directory input expects a folder path, not a file path.") - dir_path.mkdir(parents=True, exist_ok=True) - - filename_part = Path(raw_filename).name if raw_filename else "" - if not filename_part: - raise ValueError("No output filename selected — enter a file name when using a directory input.") - path = dir_path / filename_part - else: - if not raw_filename: - raise ValueError("No output path selected — use Browse to pick a location.") - path = Path(raw_filename).expanduser() - path.parent.mkdir(parents=True, exist_ok=True) - - if path.suffix.lower() != ext: - path = path.with_suffix(ext) - return path - - def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str: - key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_") - if not key: - key = f"layer_{index}" - if key[0].isdigit(): - key = f"layer_{key}" - - candidate = key - suffix = 2 - while candidate in used_keys: - candidate = f"{key}_{suffix}" - suffix += 1 - used_keys.add(candidate) - return candidate - - def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray: - if isinstance(layer, DataField): - return np.asarray(layer.data, dtype=np.float32) - if isinstance(layer, np.ndarray): - return image_to_uint8(layer) - raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") - - def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray: - if isinstance(layer, DataField): - return np.asarray(layer.data) - if isinstance(layer, np.ndarray): - return np.asarray(layer) - raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") - - def _send_warning(self, message: str): - fn = SaveImage._broadcast_warning_fn - nid = SaveImage._current_node_id - if fn and nid: - fn(nid, message) - - return () diff --git a/backend/nodes/level.py b/backend/nodes/level.py deleted file mode 100644 index f254289..0000000 --- a/backend/nodes/level.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Leveling nodes — background removal and zero correction. - -Gwyddion equivalents: - PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level - PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module) - FixZero → fix_zero in level.c - -Plane-fit algorithm follows Gwyddion's level.h definition: - z_fit = pa + pbx * x + pby * y (least-squares over all pixels) -""" - -from __future__ import annotations -import numpy as np -from backend.node_registry import register_node -from backend.data_types import DataField - - -# --------------------------------------------------------------------------- -# PlaneLevelField -# --------------------------------------------------------------------------- - -@register_node(display_name="Plane Level") -class PlaneLevelField: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("leveled",) - FUNCTION = "process" - - DESCRIPTION = ( - "Fit and subtract a least-squares plane from the data. " - "Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level." - ) - - def process(self, field: DataField) -> tuple: - data = field.data.copy() - yres, xres = data.shape - - # Normalised coordinate grids in [0, 1] - x = np.linspace(0.0, 1.0, xres) - y = np.linspace(0.0, 1.0, yres) - xx, yy = np.meshgrid(x, y) - - # Design matrix: [1, x, y] shape (N, 3) - A = np.column_stack([ - np.ones(xres * yres), - xx.ravel(), - yy.ravel(), - ]) - z = data.ravel() - - # Least-squares: solve A @ [pa, pbx, pby] = z - coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None) - pa, pbx, pby = coeffs - - plane = (pa + pbx * xx + pby * yy) - return (field.replace(data=data - plane),) - - -# --------------------------------------------------------------------------- -# PolyLevelField -# --------------------------------------------------------------------------- - -@register_node(display_name="Polynomial Level") -class PolyLevelField: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}), - "degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}), - } - } - - RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD") - RETURN_NAMES = ("leveled", "background") - FUNCTION = "process" - - DESCRIPTION = ( - "Fit and subtract a polynomial background of given degree in x and y. " - "Equivalent to gwy_data_field_fit_polynom." - ) - - def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple: - data = field.data.copy() - yres, xres = data.shape - - x = np.linspace(0.0, 1.0, xres) - y = np.linspace(0.0, 1.0, yres) - xx, yy = np.meshgrid(x, y) - - # Build Vandermonde-style design matrix with all monomials x^i * y^j - cols = [] - for i in range(degree_x + 1): - for j in range(degree_y + 1): - cols.append((xx ** i * yy ** j).ravel()) - A = np.column_stack(cols) - z = data.ravel() - - coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None) - - background = (A @ coeffs).reshape(yres, xres) - leveled = data - background - - return (field.replace(data=leveled), field.replace(data=background)) - - -# --------------------------------------------------------------------------- -# FixZero -# --------------------------------------------------------------------------- - -@register_node(display_name="Fix Zero") -class FixZero: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "method": (["min", "mean", "median"],), - } - } - - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("zeroed",) - FUNCTION = "process" - - DESCRIPTION = ( - "Shift data so that the minimum (or mean/median) is zero. " - "Equivalent to fix_zero in Gwyddion's level.c." - ) - - def process(self, field: DataField, method: str) -> tuple: - data = field.data.copy() - if method == "min": - data -= data.min() - elif method == "mean": - data -= data.mean() - elif method == "median": - data -= np.median(data) - else: - raise ValueError(f"Unknown method: {method}") - return (field.replace(data=data),) diff --git a/backend/nodes/markup.py b/backend/nodes/markup.py new file mode 100644 index 0000000..c9e6785 --- /dev/null +++ b/backend/nodes/markup.py @@ -0,0 +1,68 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.data_types import DataField, datafield_to_uint8, encode_preview +from backend.nodes.helpers import _parse_markup_shapes, _normalize_markup_color + + +@register_node(display_name="Markup") +class Markup: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}), + "stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}), + "stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}), + "clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}), + "markup_shapes": ("STRING", {"default": "[]", "hidden": True}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("annotated",) + FUNCTION = "process" + + DESCRIPTION = ( + "Draw simple vector markup over a DATA_FIELD without flattening the underlying data. " + "Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, + field: DataField, + shape: str, + stroke_color: str, + stroke_width: int, + markup_shapes: str, + ) -> tuple: + shapes = _parse_markup_shapes(markup_shapes) + out = field.replace( + overlays=[ + *field.overlays, + { + "kind": "markup", + "shapes": shapes, + }, + ], + ) + + if Markup._broadcast_overlay_fn is not None: + Markup._broadcast_overlay_fn( + Markup._current_node_id, + { + "kind": "markup", + "section_title": "Markup", + "image": encode_preview(datafield_to_uint8(field, field.colormap)), + "shape": str(shape), + "stroke_color": _normalize_markup_color(stroke_color), + "stroke_width": max(1, int(stroke_width)), + }, + ) + + return (out,) diff --git a/backend/nodes/mask.py b/backend/nodes/mask.py deleted file mode 100644 index e92ef63..0000000 --- a/backend/nodes/mask.py +++ /dev/null @@ -1,437 +0,0 @@ -""" -Mask operation nodes — creation, morphology, and boolean combination. - -Gwyddion equivalents: - ThresholdMask → threshold.c / otsu_threshold.c - MaskMorphology → mask_morph.c (erode, dilate, open, close) - MaskInvert → (bitwise NOT on mask) - MaskCombine → (boolean ops between two masks) -""" - -from __future__ import annotations -from functools import lru_cache -import json -import numpy as np -from backend.node_registry import register_node -from backend.data_types import DataField, datafield_to_uint8, encode_preview - - -def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray: - """Render greyscale base image with red shadow on masked (255) pixels. - - Returns (H, W, 3) uint8 array. - """ - grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8 - mask_bool = mask > 127 - if not np.any(mask_bool): - return grey - - overlay = grey.copy() - red = overlay[..., 0] - green = overlay[..., 1] - blue = overlay[..., 2] - - # Integer alpha blend equivalent to a 45% red overlay, without float64 work. - red_vals = red[mask_bool].astype(np.uint16) - green_vals = green[mask_bool].astype(np.uint16) - blue_vals = blue[mask_bool].astype(np.uint16) - red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100 - green[mask_bool] = ((green_vals * 55) + 50) // 100 - blue[mask_bool] = ((blue_vals * 55) + 50) // 100 - return overlay - - -@lru_cache(maxsize=128) -def _mask_structure(radius: int, shape: str) -> np.ndarray: - radius = max(1, int(radius)) - if shape == "disk": - y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1] - struct = (x * x + y * y) <= radius * radius - else: - size = 2 * radius + 1 - struct = np.ones((size, size), dtype=bool) - struct.setflags(write=False) - return struct - - -def _clamp_fraction(value) -> float: - try: - numeric = float(value) - except (TypeError, ValueError): - return 0.0 - return max(0.0, min(1.0, numeric)) - - -def _parse_mask_strokes(mask_paths) -> list[dict]: - if isinstance(mask_paths, list): - raw_strokes = mask_paths - elif isinstance(mask_paths, str) and mask_paths.strip(): - try: - parsed = json.loads(mask_paths) - except json.JSONDecodeError: - return [] - raw_strokes = parsed if isinstance(parsed, list) else [] - else: - return [] - - strokes = [] - for stroke in raw_strokes: - if not isinstance(stroke, dict): - continue - raw_points = stroke.get("points") - if not isinstance(raw_points, list): - continue - - points = [] - for point in raw_points: - if not isinstance(point, dict): - continue - if "x" not in point or "y" not in point: - continue - points.append({ - "x": _clamp_fraction(point.get("x")), - "y": _clamp_fraction(point.get("y")), - }) - - if not points: - continue - - try: - size = max(1, int(round(float(stroke.get("size", 1))))) - except (TypeError, ValueError): - size = 1 - - strokes.append({ - "size": size, - "points": points, - }) - - return strokes - - -def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray: - from PIL import Image, ImageDraw - - width = max(1, int(width)) - height = max(1, int(height)) - default_pen_size = max(1, int(default_pen_size)) - - mask_image = Image.new("L", (width, height), 0) - draw = ImageDraw.Draw(mask_image) - - for stroke in strokes: - points = stroke.get("points") or [] - if not points: - continue - - size = stroke.get("size", default_pen_size) - try: - size = max(1, int(round(float(size)))) - except (TypeError, ValueError): - size = default_pen_size - - pixel_points = [] - for point in points: - px = int(round(_clamp_fraction(point.get("x")) * (width - 1))) - py = int(round(_clamp_fraction(point.get("y")) * (height - 1))) - pixel_points.append((px, py)) - - radius = max(0.5, size / 2.0) - - if len(pixel_points) == 1: - x, y = pixel_points[0] - draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) - continue - - draw.line(pixel_points, fill=255, width=size) - for x, y in pixel_points: - draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) - - return np.asarray(mask_image, dtype=np.uint8) - - -# --------------------------------------------------------------------------- -# DrawMask -# --------------------------------------------------------------------------- - -@register_node(display_name="Draw Mask") -class DrawMask: - _CUSTOM_PREVIEW = True - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}), - "invert": ("BOOLEAN", {"default": False}), - "clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}), - "mask_paths": ("STRING", {"default": "[]", "hidden": True}), - } - } - - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("mask",) - FUNCTION = "process" - - DESCRIPTION = ( - "Paint a binary mask directly over an image preview. " - "Pen size controls newly drawn strokes, the overlay lets you clear the mask, " - "and invert flips the final binary output." - ) - - _broadcast_overlay_fn = None - _current_node_id: str = "" - - def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple: - strokes = _parse_mask_strokes(mask_paths) - mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size) - if invert: - mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) - - if DrawMask._broadcast_overlay_fn is not None: - DrawMask._broadcast_overlay_fn( - DrawMask._current_node_id, - { - "kind": "mask_paint", - "section_title": "Mask", - "image": encode_preview(datafield_to_uint8(field, "gray")), - "image_width": field.xres, - "image_height": field.yres, - "invert": bool(invert), - }, - ) - - return (mask,) - - -# --------------------------------------------------------------------------- -# ThresholdMask -# --------------------------------------------------------------------------- - -@register_node(display_name="Threshold Mask") -class ThresholdMask: - _CUSTOM_PREVIEW = True - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "method": (["otsu", "absolute", "relative"],), - "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}), - "direction": (["above", "below"],), - } - } - - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("mask",) - FUNCTION = "process" - - DESCRIPTION = ( - "Create a binary mask by thresholding data. " - "Otsu automatically finds the optimal threshold. " - "Equivalent to Gwyddion's threshold and otsu_threshold modules." - ) - - _broadcast_fn = None - _current_node_id: str = "" - - def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple: - data = field.data - - if method == "otsu": - from skimage.filters import threshold_otsu - t = threshold_otsu(data) - elif method == "absolute": - t = float(threshold) - elif method == "relative": - # threshold is a fraction [0, 1] of the data range - dmin, dmax = data.min(), data.max() - t = dmin + float(threshold) * (dmax - dmin) - else: - raise ValueError(f"Unknown threshold method: {method}") - - if direction == "above": - mask = (data >= t).astype(np.uint8) * 255 - else: - mask = (data < t).astype(np.uint8) * 255 - - if ThresholdMask._broadcast_fn is not None: - overlay = _mask_overlay(field, mask) - ThresholdMask._broadcast_fn( - ThresholdMask._current_node_id, encode_preview(overlay), - ) - - return (mask,) - - -# --------------------------------------------------------------------------- -# MaskMorphology -# --------------------------------------------------------------------------- - -@register_node(display_name="Mask Morphology") -class MaskMorphology: - """Morphological operations on binary masks. - - Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close). - """ - _CUSTOM_PREVIEW = True - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("IMAGE",), - "operation": (["dilate", "erode", "open", "close"],), - "radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}), - "shape": (["disk", "square"],), - }, - "optional": { - "field": ("DATA_FIELD",), - } - } - - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("mask",) - FUNCTION = "process" - - DESCRIPTION = ( - "Apply morphological operations to a binary mask. " - "Dilate expands regions, erode shrinks them, " - "open (erode then dilate) removes small spots, " - "close (dilate then erode) fills small holes. " - "Equivalent to Gwyddion mask_morph." - ) - - _broadcast_fn = None - _current_node_id: str = "" - - def process(self, mask: np.ndarray, operation: str, radius: int, shape: str, - field: DataField | None = None) -> tuple: - from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening - - binary = mask > 127 - struct = _mask_structure(radius, shape) - - if operation == "dilate": - result = binary_dilation(binary, structure=struct) - elif operation == "erode": - result = binary_erosion(binary, structure=struct) - elif operation == "open": - result = binary_opening(binary, structure=struct) - elif operation == "close": - result = binary_closing(binary, structure=struct) - else: - raise ValueError(f"Unknown morphological operation: {operation}") - - out = result.astype(np.uint8) * 255 - - if field is not None and MaskMorphology._broadcast_fn is not None: - overlay = _mask_overlay(field, out) - MaskMorphology._broadcast_fn( - MaskMorphology._current_node_id, encode_preview(overlay), - ) - - return (out,) - - -# --------------------------------------------------------------------------- -# MaskInvert -# --------------------------------------------------------------------------- - -@register_node(display_name="Mask Invert") -class MaskInvert: - _CUSTOM_PREVIEW = True - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("IMAGE",), - }, - "optional": { - "field": ("DATA_FIELD",), - } - } - - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("mask",) - FUNCTION = "process" - - DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions." - - _broadcast_fn = None - _current_node_id: str = "" - - def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: - out = np.where(mask > 127, np.uint8(0), np.uint8(255)) - - if field is not None and MaskInvert._broadcast_fn is not None: - overlay = _mask_overlay(field, out) - MaskInvert._broadcast_fn( - MaskInvert._current_node_id, encode_preview(overlay), - ) - - return (out,) - - -# --------------------------------------------------------------------------- -# MaskCombine -# --------------------------------------------------------------------------- - -@register_node(display_name="Mask Combine") -class MaskCombine: - _CUSTOM_PREVIEW = True - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask_a": ("IMAGE",), - "mask_b": ("IMAGE",), - "operation": (["and", "or", "xor", "subtract"],), - }, - "optional": { - "field": ("DATA_FIELD",), - } - } - - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("mask",) - FUNCTION = "process" - - DESCRIPTION = ( - "Combine two binary masks with a boolean operation. " - "AND keeps overlap, OR merges, XOR keeps non-overlapping regions, " - "subtract removes mask_b from mask_a." - ) - - _broadcast_fn = None - _current_node_id: str = "" - - def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str, - field: DataField | None = None) -> tuple: - a = mask_a > 127 - b = mask_b > 127 - - if operation == "and": - result = a & b - elif operation == "or": - result = a | b - elif operation == "xor": - result = a ^ b - elif operation == "subtract": - result = a & ~b - else: - raise ValueError(f"Unknown mask operation: {operation}") - - out = result.astype(np.uint8) * 255 - - if field is not None and MaskCombine._broadcast_fn is not None: - overlay = _mask_overlay(field, out) - MaskCombine._broadcast_fn( - MaskCombine._current_node_id, encode_preview(overlay), - ) - - return (out,) diff --git a/backend/nodes/mask_combine.py b/backend/nodes/mask_combine.py new file mode 100644 index 0000000..28d7fae --- /dev/null +++ b/backend/nodes/mask_combine.py @@ -0,0 +1,62 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, encode_preview +from backend.nodes.helpers import _mask_overlay + + +@register_node(display_name="Mask Combine") +class MaskCombine: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask_a": ("IMAGE",), + "mask_b": ("IMAGE",), + "operation": (["and", "or", "xor", "subtract"],), + }, + "optional": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + + DESCRIPTION = ( + "Combine two binary masks with a boolean operation. " + "AND keeps overlap, OR merges, XOR keeps non-overlapping regions, " + "subtract removes mask_b from mask_a." + ) + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str, + field: DataField | None = None) -> tuple: + a = mask_a > 127 + b = mask_b > 127 + + if operation == "and": + result = a & b + elif operation == "or": + result = a | b + elif operation == "xor": + result = a ^ b + elif operation == "subtract": + result = a & ~b + else: + raise ValueError(f"Unknown mask operation: {operation}") + + out = result.astype(np.uint8) * 255 + + if field is not None and MaskCombine._broadcast_fn is not None: + overlay = _mask_overlay(field, out) + MaskCombine._broadcast_fn( + MaskCombine._current_node_id, encode_preview(overlay), + ) + + return (out,) diff --git a/backend/nodes/mask_invert.py b/backend/nodes/mask_invert.py new file mode 100644 index 0000000..1cf3529 --- /dev/null +++ b/backend/nodes/mask_invert.py @@ -0,0 +1,41 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, encode_preview +from backend.nodes.helpers import _mask_overlay + + +@register_node(display_name="Mask Invert") +class MaskInvert: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("IMAGE",), + }, + "optional": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + + DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions." + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: + out = np.where(mask > 127, np.uint8(0), np.uint8(255)) + + if field is not None and MaskInvert._broadcast_fn is not None: + overlay = _mask_overlay(field, out) + MaskInvert._broadcast_fn( + MaskInvert._current_node_id, encode_preview(overlay), + ) + + return (out,) diff --git a/backend/nodes/mask_morphology.py b/backend/nodes/mask_morphology.py new file mode 100644 index 0000000..e9603f0 --- /dev/null +++ b/backend/nodes/mask_morphology.py @@ -0,0 +1,71 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, encode_preview +from backend.nodes.helpers import _mask_overlay, _mask_structure + + +@register_node(display_name="Mask Morphology") +class MaskMorphology: + """Morphological operations on binary masks. + + Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close). + """ + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("IMAGE",), + "operation": (["dilate", "erode", "open", "close"],), + "radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}), + "shape": (["disk", "square"],), + }, + "optional": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + + DESCRIPTION = ( + "Apply morphological operations to a binary mask. " + "Dilate expands regions, erode shrinks them, " + "open (erode then dilate) removes small spots, " + "close (dilate then erode) fills small holes. " + "Equivalent to Gwyddion mask_morph." + ) + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, mask: np.ndarray, operation: str, radius: int, shape: str, + field: DataField | None = None) -> tuple: + from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening + + binary = mask > 127 + struct = _mask_structure(radius, shape) + + if operation == "dilate": + result = binary_dilation(binary, structure=struct) + elif operation == "erode": + result = binary_erosion(binary, structure=struct) + elif operation == "open": + result = binary_opening(binary, structure=struct) + elif operation == "close": + result = binary_closing(binary, structure=struct) + else: + raise ValueError(f"Unknown morphological operation: {operation}") + + out = result.astype(np.uint8) * 255 + + if field is not None and MaskMorphology._broadcast_fn is not None: + overlay = _mask_overlay(field, out) + MaskMorphology._broadcast_fn( + MaskMorphology._current_node_id, encode_preview(overlay), + ) + + return (out,) diff --git a/backend/nodes/median_filter.py b/backend/nodes/median_filter.py new file mode 100644 index 0000000..8ae7b3c --- /dev/null +++ b/backend/nodes/median_filter.py @@ -0,0 +1,27 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Median Filter") +class MedianFilter: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("filtered",) + FUNCTION = "process" + + DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median." + + def process(self, field: DataField, size: int) -> tuple: + from scipy.ndimage import median_filter + size = max(1, int(size)) + data = median_filter(field.data, size=size) + return (field.replace(data=data),) diff --git a/backend/nodes/number.py b/backend/nodes/number.py new file mode 100644 index 0000000..9f028f4 --- /dev/null +++ b/backend/nodes/number.py @@ -0,0 +1,28 @@ +from __future__ import annotations +from backend.node_registry import register_node + + +@register_node(display_name="Number") +class Number: + """Provide a fixed scalar value that can feed FLOAT or INT widget sockets.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + + DESCRIPTION = ( + "Output a fixed numeric value. " + "When connected to FLOAT inputs the exact value is used; " + "INT inputs round to the nearest integer at execution time." + ) + + def process(self, value: float) -> tuple: + return (float(value),) diff --git a/backend/nodes/particle.py b/backend/nodes/particle_analysis.py similarity index 83% rename from backend/nodes/particle.py rename to backend/nodes/particle_analysis.py index f6c3726..049fb9f 100644 --- a/backend/nodes/particle.py +++ b/backend/nodes/particle_analysis.py @@ -1,20 +1,9 @@ -""" -Particle detection nodes. - -Gwyddion equivalents: - ParticleAnalysis → gwy_data_field_particles_get_values (particles-values.c) -""" - from __future__ import annotations import numpy as np from backend.node_registry import register_node from backend.data_types import DataField, RecordTable -# --------------------------------------------------------------------------- -# ParticleAnalysis -# --------------------------------------------------------------------------- - @register_node(display_name="Particle Analysis") class ParticleAnalysis: @classmethod @@ -43,7 +32,7 @@ class ParticleAnalysis: binary = (mask > 127).astype(np.int32) labeled, n_particles = label(binary) - pixel_area = field.dx * field.dy # m^2 per pixel + pixel_area = field.dx * field.dy rows = RecordTable() for pid in range(1, n_particles + 1): @@ -59,7 +48,6 @@ class ParticleAnalysis: mean_h = float(heights.mean()) max_h = float(heights.max()) - # Bounding box ys, xs = np.where(particle_pixels) bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})" diff --git a/backend/nodes/plane_level_field.py b/backend/nodes/plane_level_field.py new file mode 100644 index 0000000..6bf8f10 --- /dev/null +++ b/backend/nodes/plane_level_field.py @@ -0,0 +1,45 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Plane Level") +class PlaneLevelField: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("leveled",) + FUNCTION = "process" + + DESCRIPTION = ( + "Fit and subtract a least-squares plane from the data. " + "Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level." + ) + + def process(self, field: DataField) -> tuple: + data = field.data.copy() + yres, xres = data.shape + + x = np.linspace(0.0, 1.0, xres) + y = np.linspace(0.0, 1.0, yres) + xx, yy = np.meshgrid(x, y) + + A = np.column_stack([ + np.ones(xres * yres), + xx.ravel(), + yy.ravel(), + ]) + z = data.ravel() + + coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None) + pa, pbx, pby = coeffs + + plane = (pa + pbx * xx + pby * yy) + return (field.replace(data=data - plane),) diff --git a/backend/nodes/poly_level_field.py b/backend/nodes/poly_level_field.py new file mode 100644 index 0000000..09ad5b8 --- /dev/null +++ b/backend/nodes/poly_level_field.py @@ -0,0 +1,48 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Polynomial Level") +class PolyLevelField: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}), + "degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}), + } + } + + RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD") + RETURN_NAMES = ("leveled", "background") + FUNCTION = "process" + + DESCRIPTION = ( + "Fit and subtract a polynomial background of given degree in x and y. " + "Equivalent to gwy_data_field_fit_polynom." + ) + + def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple: + data = field.data.copy() + yres, xres = data.shape + + x = np.linspace(0.0, 1.0, xres) + y = np.linspace(0.0, 1.0, yres) + xx, yy = np.meshgrid(x, y) + + cols = [] + for i in range(degree_x + 1): + for j in range(degree_y + 1): + cols.append((xx ** i * yy ** j).ravel()) + A = np.column_stack(cols) + z = data.ravel() + + coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None) + + background = (A @ coeffs).reshape(yres, xres) + leveled = data - background + + return (field.replace(data=leveled), field.replace(data=background)) diff --git a/backend/nodes/preview_image.py b/backend/nodes/preview_image.py new file mode 100644 index 0000000..a955458 --- /dev/null +++ b/backend/nodes/preview_image.py @@ -0,0 +1,74 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import ( + COLORMAPS, + colormap_to_uint8, + encode_preview, + image_to_uint8, + render_datafield_preview, + resolve_colormap_input, +) + + +@register_node(display_name="Preview") +class PreviewImage: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + "image": ("IMAGE",), + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = () + FUNCTION = "preview" + + OUTPUT_NODE = True + DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input." + + _broadcast_fn = None + _current_node_id: str = "" + + def preview( + self, + colormap: str, + image: np.ndarray | None = None, + field=None, + colormap_map=None, + ) -> tuple: + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=field.colormap if field is not None else None, + default="gray", + ) + + if field is not None: + arr_u8 = render_datafield_preview(field, resolved_colormap) + elif image is not None: + arr_u8 = image_to_uint8(image) + if arr_u8.ndim == 2: + if image.dtype == np.uint8: + normalized = arr_u8.astype(np.float64) / 255.0 + else: + imin, imax = image.min(), image.max() + if imax > imin: + normalized = (image - imin) / (imax - imin) + else: + normalized = np.zeros_like(image, dtype=np.float64) + arr_u8 = colormap_to_uint8(normalized, resolved_colormap) + else: + raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.") + + data_uri = encode_preview(arr_u8) + + if PreviewImage._broadcast_fn is not None: + PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri) + + return () diff --git a/backend/nodes/print_table.py b/backend/nodes/print_table.py new file mode 100644 index 0000000..f53f93b --- /dev/null +++ b/backend/nodes/print_table.py @@ -0,0 +1,27 @@ +from __future__ import annotations +from backend.node_registry import register_node + + +@register_node(display_name="Print Table") +class PrintTable: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "table": ("ANY_TABLE",), + } + } + + RETURN_TYPES = () + FUNCTION = "print_table" + + OUTPUT_NODE = True + DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display." + + _broadcast_table_fn = None + _current_node_id: str = "" + + def print_table(self, table: list) -> tuple: + if PrintTable._broadcast_table_fn is not None: + PrintTable._broadcast_table_fn(PrintTable._current_node_id, table) + return () diff --git a/backend/nodes/range_slider.py b/backend/nodes/range_slider.py new file mode 100644 index 0000000..11dcd01 --- /dev/null +++ b/backend/nodes/range_slider.py @@ -0,0 +1,39 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node + + +@register_node(display_name="Float Slider") +class RangeSlider: + """Interactive float control node with min/max bounds and a slider value.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "min_value": ("FLOAT", {"default": 0.0, "step": 0.01}), + "max_value": ("FLOAT", {"default": 1.0, "step": 0.01}), + "value": ("FLOAT", { + "default": 0.5, + "step": 0.01, + "slider": True, + "min_widget": "min_value", + "max_widget": "max_value", + }), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + + DESCRIPTION = ( + "Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value." + ) + + def process(self, min_value: float, max_value: float, value: float) -> tuple: + lo = min(float(min_value), float(max_value)) + hi = max(float(min_value), float(max_value)) + if hi == lo: + return (lo,) + return (float(np.clip(float(value), lo, hi)),) diff --git a/backend/nodes/rotate_field.py b/backend/nodes/rotate_field.py new file mode 100644 index 0000000..2c2a601 --- /dev/null +++ b/backend/nodes/rotate_field.py @@ -0,0 +1,102 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField + + +@register_node(display_name="Rotate") +class RotateField: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "angle": ("FLOAT", {"default": 90.0, "min": -360.0, "max": 360.0, "step": 1.0}), + "interpolation": (["bilinear", "nearest", "bicubic"],), + "expand_canvas": ("BOOLEAN", {"default": True}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("field",) + FUNCTION = "process" + + DESCRIPTION = ( + "Rotate a DATA_FIELD counterclockwise by an angle in degrees. " + "Optionally expand the canvas to keep the full rotated field while preserving the field center." + ) + + _broadcast_warning_fn = None + _current_node_id: str = "" + + def process( + self, + field: DataField, + angle: float, + interpolation: str, + expand_canvas: bool, + ) -> tuple: + if field.overlays: + self._send_warning("Rotate clears annotation/markup overlays!") + + angle = float(angle) + order_map = { + "nearest": 0, + "bilinear": 1, + "bicubic": 3, + } + if interpolation not in order_map: + raise ValueError(f"Unknown interpolation mode: {interpolation}") + + normalized_angle = angle % 360.0 + snapped_quarters = int(round(normalized_angle / 90.0)) % 4 + snapped_angle = snapped_quarters * 90.0 + is_right_angle = abs(normalized_angle - snapped_angle) < 1e-9 + + if is_right_angle and expand_canvas: + rotated = np.rot90(field.data, k=snapped_quarters).copy() + elif abs(normalized_angle) < 1e-9: + rotated = field.data.copy() + else: + from scipy.ndimage import rotate as nd_rotate + + rotated = nd_rotate( + field.data, + angle=angle, + reshape=bool(expand_canvas), + order=order_map[interpolation], + mode="nearest", + prefilter=order_map[interpolation] > 1, + ) + + new_xreal, new_yreal = self._rotated_extents(field, angle, expand_canvas) + center_x = field.xoff + field.xreal / 2.0 + center_y = field.yoff + field.yreal / 2.0 + + result = field.replace( + data=np.asarray(rotated, dtype=np.float64), + xreal=new_xreal, + yreal=new_yreal, + xoff=center_x - new_xreal / 2.0, + yoff=center_y - new_yreal / 2.0, + overlays=[], + ) + return (result,) + + def _send_warning(self, message: str): + fn = RotateField._broadcast_warning_fn + nid = RotateField._current_node_id + if fn and nid: + fn(nid, message) + + @staticmethod + def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]: + if not expand_canvas: + return (field.xreal, field.yreal) + + theta = np.deg2rad(angle) + cos_t = abs(float(np.cos(theta))) + sin_t = abs(float(np.sin(theta))) + new_xreal = field.xreal * cos_t + field.yreal * sin_t + new_yreal = field.xreal * sin_t + field.yreal * cos_t + return (new_xreal, new_yreal) diff --git a/backend/nodes/save_image.py b/backend/nodes/save_image.py new file mode 100644 index 0000000..887d602 --- /dev/null +++ b/backend/nodes/save_image.py @@ -0,0 +1,182 @@ +from __future__ import annotations +import re +import numpy as np +from pathlib import Path + +from backend.node_registry import register_node +from backend.data_types import DataField, image_to_uint8 +from backend.nodes.helpers import _MAX_SAVE_FIELDS + + +@register_node(display_name="Save Layers") +class SaveImage: + @classmethod + def INPUT_TYPES(cls): + optional = { + "directory": ("DIRECTORY", {"label": "directory"}), + } + for i in range(_MAX_SAVE_FIELDS): + optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"}) + optional[f"layer_name_{i}"] = ("STRING", { + "default": "", + "placeholder": "name", + "show_when_input_visible": f"field_{i}", + "inline_with_input": f"field_{i}", + "hide_label": True, + }) + return { + "required": { + "filename": ("STRING", { + "default": "", + "placeholder": "filename", + "placement": "top", + }), + "directory_path": ("FOLDER_PICKER", { + "default": "", + "label": "directory", + "placement": "top", + "hide_when_input_connected": "directory", + "top_socket_input": "directory", + }), + "format": (["TIFF", "NPZ"],), + }, + "optional": optional, + } + + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + MANUAL_TRIGGER = True + DESCRIPTION = ( + "Save one or more layers to a single file. " + "Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. " + "Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. " + "A new slot appears as each one is filled, with a matching per-layer name field. " + "TIFF writes multi-page data and stores layer names as page descriptions; " + "NPZ writes named arrays using those layer names as keys. " + "Click Save to write (does not auto-run)." + ) + + _broadcast_warning_fn = None + _current_node_id = None + + def save( + self, + filename: str, + directory_path: str = "", + format: str = "TIFF", + directory: str | None = None, + **kwargs, + ): + layers = [] + layer_names = [] + for i in range(_MAX_SAVE_FIELDS): + layer = kwargs.get(f"field_{i}") + if layer is not None: + layers.append(layer) + layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i)) + + if not layers: + raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.") + + path = self._resolve_save_path(filename, format, directory, directory_path) + + if format == "TIFF": + self._save_tiff(path, layers, layer_names) + else: + self._save_npz(path, layers, layer_names) + + self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}") + return () + + def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]): + import tifffile + + with tifffile.TiffWriter(str(path)) as tif: + for layer, layer_name in zip(layers, layer_names): + tif.write(self._layer_array_for_tiff(layer), description=layer_name) + + def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]): + arrays = {} + used_keys = set() + for i, (layer, layer_name) in enumerate(zip(layers, layer_names)): + arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer) + np.savez(str(path), **arrays) + + def _resolve_layer_name(self, raw_name: object, index: int) -> str: + text = str(raw_name).strip() if raw_name is not None else "" + return text or f"layer_{index}" + + def _resolve_save_path( + self, + filename: str, + format: str, + directory: str | None, + directory_path: str = "", + ) -> Path: + ext = ".tiff" if format == "TIFF" else ".npz" + raw_filename = str(filename).strip() if filename is not None else "" + raw_directory = str(directory).strip() if directory is not None else "" + if not raw_directory: + raw_directory = str(directory_path).strip() if directory_path is not None else "" + + if raw_directory: + dir_path = Path(raw_directory).expanduser() + if dir_path.exists() and not dir_path.is_dir(): + raise ValueError("Directory input expects a folder path, not a file path.") + if not dir_path.exists(): + if dir_path.suffix: + raise ValueError("Directory input expects a folder path, not a file path.") + dir_path.mkdir(parents=True, exist_ok=True) + + filename_part = Path(raw_filename).name if raw_filename else "" + if not filename_part: + raise ValueError("No output filename selected — enter a file name when using a directory input.") + path = dir_path / filename_part + else: + if not raw_filename: + raise ValueError("No output path selected — use Browse to pick a location.") + path = Path(raw_filename).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + + if path.suffix.lower() != ext: + path = path.with_suffix(ext) + return path + + def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str: + key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_") + if not key: + key = f"layer_{index}" + if key[0].isdigit(): + key = f"layer_{key}" + + candidate = key + suffix = 2 + while candidate in used_keys: + candidate = f"{key}_{suffix}" + suffix += 1 + used_keys.add(candidate) + return candidate + + def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray: + if isinstance(layer, DataField): + return np.asarray(layer.data, dtype=np.float32) + if isinstance(layer, np.ndarray): + return image_to_uint8(layer) + raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") + + def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray: + if isinstance(layer, DataField): + return np.asarray(layer.data) + if isinstance(layer, np.ndarray): + return np.asarray(layer) + raise ValueError(f"Unsupported save layer type: {type(layer).__name__}") + + def _send_warning(self, message: str): + fn = SaveImage._broadcast_warning_fn + nid = SaveImage._current_node_id + if fn and nid: + fn(nid, message) + + return () diff --git a/backend/nodes/statistics_node.py b/backend/nodes/statistics_node.py new file mode 100644 index 0000000..f508b64 --- /dev/null +++ b/backend/nodes/statistics_node.py @@ -0,0 +1,43 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, MeasureTable + + +@register_node(display_name="Statistics") +class Statistics: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("MEASURE_TABLE",) + RETURN_NAMES = ("stats",) + FUNCTION = "process" + + DESCRIPTION = ( + "Compute basic surface statistics: min, max, mean, RMS roughness, median, " + "and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms." + ) + + def process(self, field: DataField) -> tuple: + d = field.data + mean = float(d.mean()) + rms = float(np.sqrt(np.mean((d - mean) ** 2))) + skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0 + kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0 + + table = MeasureTable([ + {"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z}, + {"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z}, + {"quantity": "mean", "value": mean, "unit": field.si_unit_z}, + {"quantity": "RMS", "value": rms, "unit": field.si_unit_z}, + {"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z}, + {"quantity": "skewness", "value": skewness, "unit": ""}, + {"quantity": "kurtosis", "value": kurtosis, "unit": ""}, + {"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z}, + ]) + return (table,) diff --git a/backend/nodes/stats.py b/backend/nodes/stats.py new file mode 100644 index 0000000..59ea289 --- /dev/null +++ b/backend/nodes/stats.py @@ -0,0 +1,130 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, LineData, MeasureTable +from backend.nodes.helpers import ( + LINE_OPS, + TABLE_OPS, + ARRAY_OPS, + _scalar_payload, + _apply_scalar_unit, + _common_table_unit, + extract_numeric_table_values, + resolve_table_column_name, +) + + +@register_node(display_name="Stats") +class Stats: + """Polymorphic scalar stats node for LINE, RECORD_TABLE, DATA_FIELD, or IMAGE inputs.""" + + _broadcast_value_fn = None + _current_node_id: str = "" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": ("STATS_SOURCE",), + "column": ("STRING", { + "default": "value", + "choices_from_table_input": "input", + "show_when_source_type": { + "input": ["RECORD_TABLE"], + }, + }), + "operation": ("STRING", { + "default": "mean", + "choices_by_source_type": { + "LINE": list(LINE_OPS.keys()), + "RECORD_TABLE": list(TABLE_OPS.keys()), + "DATA_FIELD": list(ARRAY_OPS.keys()), + "IMAGE": list(ARRAY_OPS.keys()), + }, + "source_type_input": "input", + }), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + + DESCRIPTION = ( + "Compute a contextual scalar statistic from a LINE, record table, DATA_FIELD, or IMAGE. " + "The available operations adapt to the connected input type." + ) + + def process(self, input, operation: str, column: str = "value") -> tuple: + source_type, values, resolved_column = self._resolve_input_values(input, column) + + if source_type == "RECORD_TABLE": + ops = TABLE_OPS + elif source_type == "LINE": + ops = LINE_OPS + else: + ops = ARRAY_OPS + + if operation not in ops: + raise ValueError(f"Operation '{operation}' is not valid for {source_type} input.") + + op_entry = ops[operation] + fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry + result = fn(values) + if Stats._broadcast_value_fn is not None: + Stats._broadcast_value_fn( + Stats._current_node_id, + _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), + ) + return (result,) + + def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str: + if source_type == "DATA_FIELD" and isinstance(input_value, DataField): + return _apply_scalar_unit(input_value.si_unit_z, operation) + + if source_type == "LINE": + line_entry = LINE_OPS.get(operation) + explicit_unit = line_entry[1] if isinstance(line_entry, tuple) and len(line_entry) > 1 else "" + if explicit_unit: + return _apply_scalar_unit(explicit_unit, operation) + if isinstance(input_value, LineData): + return _apply_scalar_unit(input_value.y_unit, operation) + return "" + + if source_type == "RECORD_TABLE" and isinstance(input_value, list) and column: + return _apply_scalar_unit(_common_table_unit(input_value, column), operation) + + return "" + + def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray, str | None]: + if isinstance(input_value, DataField): + values = np.asarray(input_value.data, dtype=np.float64) + return ("DATA_FIELD", values.ravel(), None) + + if isinstance(input_value, MeasureTable): + raise ValueError("Stats only accepts record tables, not measurement tables.") + + if isinstance(input_value, list): + if not input_value: + raise ValueError("Stats requires a non-empty record table input.") + column_name = resolve_table_column_name(input_value, column) + values = extract_numeric_table_values(input_value, column_name) + if not values: + raise ValueError(f"Column '{column_name}' has no numeric values.") + return ("RECORD_TABLE", np.asarray(values, dtype=np.float64), column_name) + + if isinstance(input_value, LineData): + values = np.asarray(input_value.data, dtype=np.float64) + if values.size == 0: + raise ValueError("Stats requires a non-empty input.") + return ("LINE", values.ravel(), None) + + if isinstance(input_value, np.ndarray): + values = np.asarray(input_value, dtype=np.float64) + if values.size == 0: + raise ValueError("Stats requires a non-empty input.") + if values.ndim == 1: + return ("LINE", values.ravel(), None) + return ("IMAGE", values.ravel(), None) + + raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}") diff --git a/backend/nodes/threshold_mask.py b/backend/nodes/threshold_mask.py new file mode 100644 index 0000000..56bbac2 --- /dev/null +++ b/backend/nodes/threshold_mask.py @@ -0,0 +1,61 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, encode_preview +from backend.nodes.helpers import _mask_overlay + + +@register_node(display_name="Threshold Mask") +class ThresholdMask: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "method": (["otsu", "absolute", "relative"],), + "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}), + "direction": (["above", "below"],), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + + DESCRIPTION = ( + "Create a binary mask by thresholding data. " + "Otsu automatically finds the optimal threshold. " + "Equivalent to Gwyddion's threshold and otsu_threshold modules." + ) + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple: + data = field.data + + if method == "otsu": + from skimage.filters import threshold_otsu + t = threshold_otsu(data) + elif method == "absolute": + t = float(threshold) + elif method == "relative": + dmin, dmax = data.min(), data.max() + t = dmin + float(threshold) * (dmax - dmin) + else: + raise ValueError(f"Unknown threshold method: {method}") + + if direction == "above": + mask = (data >= t).astype(np.uint8) * 255 + else: + mask = (data < t).astype(np.uint8) * 255 + + if ThresholdMask._broadcast_fn is not None: + overlay = _mask_overlay(field, mask) + ThresholdMask._broadcast_fn( + ThresholdMask._current_node_id, encode_preview(overlay), + ) + + return (mask,) diff --git a/backend/nodes/value_display.py b/backend/nodes/value_display.py new file mode 100644 index 0000000..e4a58e8 --- /dev/null +++ b/backend/nodes/value_display.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from backend.node_registry import register_node +from backend.data_types import MeasureTable +from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload + + +@register_node(display_name="Value Display") +class ValueDisplay: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("VALUE_SOURCE",), + "measurement": ("STRING", { + "default": "", + "choices_from_measure_input": "value", + "show_when_source_type": { + "value": ["MEASURE_TABLE"], + }, + }), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "display_value" + + DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged." + + _broadcast_value_fn = None + _current_node_id: str = "" + + def display_value(self, value, measurement: str = "") -> tuple: + unit = "" + if isinstance(value, MeasureTable): + row = _measurement_entry(value, measurement) + numeric = _measurement_value(value, measurement) + unit = row.get("unit", "") if isinstance(row.get("unit"), str) else "" + else: + numeric = float(value) + if ValueDisplay._broadcast_value_fn is not None: + ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit)) + return (numeric,) diff --git a/backend/nodes/view_3d.py b/backend/nodes/view_3d.py new file mode 100644 index 0000000..9ffe623 --- /dev/null +++ b/backend/nodes/view_3d.py @@ -0,0 +1,90 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import ( + COLORMAPS, + DataField, + colormap_to_uint8, + normalize_for_colormap, + resolve_colormap_input, +) + + +@register_node(display_name="3D View") +class View3D: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), + "z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}), + "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), + }, + "optional": { + "colormap_map": ("COLORMAP", {"label": "colormap"}), + }, + } + + RETURN_TYPES = () + FUNCTION = "render" + + OUTPUT_NODE = True + DESCRIPTION = ( + "Interactive 3D surface view of a DATA_FIELD. " + "Drag to rotate, scroll to zoom. z_scale exaggerates height." + ) + + _broadcast_mesh_fn = None + _current_node_id: str = "" + + def render( + self, field: DataField, + colormap: str, z_scale: float, resolution: int, colormap_map=None, + ) -> tuple: + import base64 + + data = field.data + yres, xres = data.shape + + step_y = max(1, yres // resolution) + step_x = max(1, xres // resolution) + z = data[::step_y, ::step_x].astype(np.float32) + ny, nx = z.shape + + zmin, zmax = float(z.min()), float(z.max()) + z_norm = normalize_for_colormap( + z, + offset=field.display_offset, + scale=field.display_scale, + data_min=float(field.data.min()), + data_max=float(field.data.max()), + ) + + resolved_colormap = resolve_colormap_input( + colormap, + colormap_input=colormap_map, + inherited=field.colormap, + default="gray", + ) + colors_u8 = colormap_to_uint8(z_norm, resolved_colormap) + + z_b64 = base64.b64encode(z.tobytes()).decode() + colors_b64 = base64.b64encode(colors_u8.tobytes()).decode() + + mesh_data = { + "width": nx, + "height": ny, + "z_data": z_b64, + "colors": colors_b64, + "z_min": zmin, + "z_max": zmax, + "z_scale": float(z_scale * 0.1), + "x_range": [float(field.xoff), float(field.xoff + field.xreal)], + "y_range": [float(field.yoff), float(field.yoff + field.yreal)], + } + + if View3D._broadcast_mesh_fn is not None: + View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data) + + return () diff --git a/backend/server.py b/backend/server.py index 6af0284..4426c3d 100644 --- a/backend/server.py +++ b/backend/server.py @@ -217,7 +217,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: async def get_folder_files(request: web.Request) -> web.Response: folder_path = request.query.get("folder", "") - from backend.nodes.io import list_folder_paths + from backend.nodes.helpers import list_folder_paths loop = asyncio.get_running_loop() entries = await loop.run_in_executor(None, list_folder_paths, folder_path) return web.Response(text=_dumps(entries), content_type="application/json") @@ -267,7 +267,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: async def get_channels(request: web.Request) -> web.Response: """Return available channels for a given file path.""" - from backend.nodes.io import list_channels + from backend.nodes.helpers import list_channels filepath = request.query.get("file", "") if not filepath: return web.Response( diff --git a/tests/test_fft.py b/tests/test_fft.py index cb2e794..c69d593 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -9,7 +9,8 @@ import numpy as np sys.path.insert(0, ".") from backend.data_types import DataField -from backend.nodes.analysis import FFT2D, InverseFFT2D +from backend.nodes.fft_2d import FFT2D +from backend.nodes.inverse_fft_2d import InverseFFT2D def make_field(data, xreal=1e-6, yreal=1e-6): diff --git a/tests/test_fft_visual.py b/tests/test_fft_visual.py index fde9bfd..38ff616 100644 --- a/tests/test_fft_visual.py +++ b/tests/test_fft_visual.py @@ -10,7 +10,7 @@ import numpy as np sys.path.insert(0, ".") from backend.data_types import DataField, datafield_to_uint8, encode_preview -from backend.nodes.analysis import FFT2D +from backend.nodes.fft_2d import FFT2D OUT_DIR = os.path.join(os.path.dirname(__file__), "output") os.makedirs(OUT_DIR, exist_ok=True) diff --git a/tests/test_grains.py b/tests/test_grains.py index 30899b6..ade075d 100644 --- a/tests/test_grains.py +++ b/tests/test_grains.py @@ -28,7 +28,7 @@ def make_field(data, xreal=1e-6, yreal=1e-6): def test_threshold_otsu_bimodal(): """Otsu on a clean bimodal image should separate the two populations.""" print("=== Test: Otsu on bimodal image ===") - from backend.nodes.particle import ThresholdMask + from backend.nodes.threshold_mask import ThresholdMask node = ThresholdMask() data = np.zeros((128, 128)) @@ -50,7 +50,7 @@ def test_threshold_otsu_bimodal(): def test_threshold_relative_range(): """Relative threshold at 0.5 should be the midpoint of [min, max].""" print("=== Test: Relative threshold at midpoint ===") - from backend.nodes.particle import ThresholdMask + from backend.nodes.threshold_mask import ThresholdMask node = ThresholdMask() data = np.full((64, 64), 2.0) @@ -68,7 +68,7 @@ def test_threshold_relative_range(): def test_threshold_empty_mask(): """Very high absolute threshold on low data should produce an empty mask.""" print("=== Test: Empty mask from high threshold ===") - from backend.nodes.particle import ThresholdMask + from backend.nodes.threshold_mask import ThresholdMask node = ThresholdMask() data = np.ones((64, 64)) @@ -82,7 +82,7 @@ def test_threshold_empty_mask(): def test_threshold_full_mask(): """Very low absolute threshold should produce an all-white mask.""" print("=== Test: Full mask from low threshold ===") - from backend.nodes.particle import ThresholdMask + from backend.nodes.threshold_mask import ThresholdMask node = ThresholdMask() data = np.ones((64, 64)) * 5.0 @@ -100,7 +100,7 @@ def test_threshold_full_mask(): def test_single_circle_area(): """A single filled circle — verify pixel count and physical area.""" print("=== Test: Single circle area ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() N = 200 @@ -146,7 +146,7 @@ def test_single_circle_area(): def test_multiple_particles_separation(): """Three well-separated particles of different sizes — check each is reported.""" print("=== Test: Multiple particles separation ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() N = 128 @@ -184,7 +184,7 @@ def test_multiple_particles_separation(): def test_min_size_filtering(): """min_size should exclude particles smaller than the threshold.""" print("=== Test: min_size filtering ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() N = 64 @@ -227,7 +227,7 @@ def test_min_size_filtering(): def test_particles_bounding_box(): """Bounding box should match the particles extents.""" print("=== Test: Grain bounding box ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() N = 64 @@ -250,7 +250,7 @@ def test_particles_bounding_box(): def test_empty_mask_produces_no_particles(): """An all-zero mask should yield zero particles.""" print("=== Test: Empty mask → no particles ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() field = make_field(np.ones((64, 64))) @@ -264,7 +264,7 @@ def test_empty_mask_produces_no_particles(): def test_particles_at_image_edge(): """A particles touching the image border should still be detected.""" print("=== Test: Grain at image edge ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() N = 64 @@ -286,7 +286,7 @@ def test_adjacent_particles_connectivity(): """Two diagonally-touching blocks should be separate particles (scipy.ndimage.label uses 4-connectivity by default).""" print("=== Test: Diagonal adjacency → separate particles ===") - from backend.nodes.particle import GrainAnalysis + from backend.nodes.particle_analysis import GrainAnalysis node = GrainAnalysis() N = 32 @@ -316,7 +316,8 @@ def test_adjacent_particles_connectivity(): def test_pipeline_synthetic(): """Full pipeline on a synthetic image with known geometry.""" print("=== Test: Full pipeline on synthetic particles ===") - from backend.nodes.particle import ThresholdMask, GrainAnalysis + from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.particle_analysis import GrainAnalysis N = 200 XREAL = 10e-6 # 10 µm @@ -371,7 +372,8 @@ def test_pipeline_demo_image(): """Run the full pipeline on the bundled demo nanoparticles image.""" print("=== Test: Full pipeline on demo nanoparticles.npy ===") from pathlib import Path - from backend.nodes.particle import ThresholdMask, GrainAnalysis + from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.particle_analysis import GrainAnalysis from backend.runtime_paths import demo_dir npy_path = demo_dir() / "nanoparticles.npy" diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 8e057c9..0c8b6bd 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -28,7 +28,7 @@ def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): def test_gaussian_filter(): print("=== Test: GaussianFilter ===") - from backend.nodes.filters import GaussianFilter + from backend.nodes.gaussian_filter import GaussianFilter node = GaussianFilter() field = make_field() @@ -46,7 +46,7 @@ def test_gaussian_filter(): def test_median_filter(): print("=== Test: MedianFilter ===") - from backend.nodes.filters import MedianFilter + from backend.nodes.median_filter import MedianFilter node = MedianFilter() # Median filter should remove salt-and-pepper noise @@ -68,7 +68,7 @@ def test_median_filter(): def test_crop_resize_field(): print("=== Test: CropResizeField ===") - from backend.nodes.modify import CropResizeField + from backend.nodes.crop_resize_field import CropResizeField node = CropResizeField() data = np.arange(32, dtype=np.float64).reshape(4, 8) @@ -167,7 +167,7 @@ def test_crop_resize_field(): def test_rotate_field(): print("=== Test: RotateField ===") - from backend.nodes.modify import RotateField + from backend.nodes.rotate_field import RotateField node = RotateField() data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) @@ -230,7 +230,7 @@ def test_rotate_field(): def test_rotate_field_overlay_warning(): print("=== Test: RotateField overlay warning ===") - from backend.nodes.modify import RotateField + from backend.nodes.rotate_field import RotateField node = RotateField() warnings = [] @@ -258,7 +258,7 @@ def test_rotate_field_overlay_warning(): def test_colormap_adjust(): print("=== Test: ColormapAdjust ===") - from backend.nodes.modify import ColormapAdjust + from backend.nodes.colormap_adjust import ColormapAdjust node = ColormapAdjust() field = DataField( @@ -299,7 +299,7 @@ def test_colormap_adjust(): def test_edge_detect(): print("=== Test: EdgeDetect ===") - from backend.nodes.filters import EdgeDetect + from backend.nodes.edge_detect import EdgeDetect node = EdgeDetect() # Create an image with a sharp vertical edge @@ -320,7 +320,7 @@ def test_edge_detect(): def test_fft_filter_1d(): print("=== Test: FFTFilter1D ===") - from backend.nodes.filters import FFTFilter1D + from backend.nodes.fft_filter_1d import FFTFilter1D node = FFTFilter1D() # Signal: low-frequency sine + high-frequency sine @@ -364,7 +364,7 @@ def test_fft_filter_1d(): def test_fft_filter_2d(): print("=== Test: FFTFilter2D ===") - from backend.nodes.filters import FFTFilter2D + from backend.nodes.fft_filter_2d import FFTFilter2D node = FFTFilter2D() N = 128 @@ -406,7 +406,7 @@ def test_fft_filter_2d(): def test_plane_level(): print("=== Test: PlaneLevelField ===") - from backend.nodes.level import PlaneLevelField + from backend.nodes.plane_level_field import PlaneLevelField node = PlaneLevelField() # Create a tilted plane + small signal @@ -428,7 +428,7 @@ def test_plane_level(): def test_poly_level(): print("=== Test: PolyLevelField ===") - from backend.nodes.level import PolyLevelField + from backend.nodes.poly_level_field import PolyLevelField node = PolyLevelField() N = 64 @@ -455,7 +455,7 @@ def test_poly_level(): def test_fix_zero(): print("=== Test: FixZero ===") - from backend.nodes.level import FixZero + from backend.nodes.fix_zero import FixZero node = FixZero() field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64)) @@ -477,7 +477,7 @@ def test_fix_zero(): def test_statistics(): print("=== Test: Statistics ===") - from backend.nodes.analysis import Statistics + from backend.nodes.statistics_node import Statistics node = Statistics() data = np.array([[1, 2], [3, 4]], dtype=np.float64) @@ -507,7 +507,7 @@ def test_statistics(): def test_height_histogram(): print("=== Test: Histogram ===") - from backend.nodes.analysis import Histogram + from backend.nodes.histogram import Histogram node = Histogram() # Uniform data should give a roughly flat histogram @@ -556,7 +556,7 @@ def test_height_histogram(): def test_cross_section(): print("=== Test: CrossSection ===") - from backend.nodes.analysis import CrossSection + from backend.nodes.cross_section import CrossSection node = CrossSection() # Create a field with a known horizontal gradient @@ -604,7 +604,8 @@ def test_cross_section(): ) assert len(profile_diag) == 50 - from backend.nodes.analysis import Cursors, Stats + from backend.nodes.cursors import Cursors + from backend.nodes.stats import Stats cursors = Cursors() table, _ = cursors.process(profile, x1=0.25, y1=0.5, x2=0.75, y2=0.5) @@ -630,7 +631,7 @@ def test_cross_section(): def test_threshold_mask(): print("=== Test: ThresholdMask ===") - from backend.nodes.mask import ThresholdMask + from backend.nodes.threshold_mask import ThresholdMask node = ThresholdMask() # Clear bimodal data: left half = 0, right half = 1 @@ -673,7 +674,7 @@ def test_threshold_mask(): def test_mask_morphology(): print("=== Test: MaskMorphology ===") - from backend.nodes.mask import MaskMorphology + from backend.nodes.mask_morphology import MaskMorphology node = MaskMorphology() # Small square blob in the centre @@ -710,7 +711,7 @@ def test_mask_morphology(): def test_mask_invert(): print("=== Test: MaskInvert ===") - from backend.nodes.mask import MaskInvert + from backend.nodes.mask_invert import MaskInvert node = MaskInvert() mask = np.zeros((64, 64), dtype=np.uint8) @@ -729,7 +730,7 @@ def test_mask_invert(): def test_mask_combine(): print("=== Test: MaskCombine ===") - from backend.nodes.mask import MaskCombine + from backend.nodes.mask_combine import MaskCombine node = MaskCombine() # Two overlapping squares @@ -768,7 +769,7 @@ def test_mask_combine(): def test_draw_mask(): print("=== Test: DrawMask ===") - from backend.nodes.mask import DrawMask + from backend.nodes.draw_mask import DrawMask node = DrawMask() field = make_field(data=np.zeros((32, 32), dtype=np.float64)) @@ -815,7 +816,7 @@ def test_draw_mask(): def test_particle_analysis(): print("=== Test: ParticleAnalysis ===") - from backend.nodes.particless import ParticleAnalysis + from backend.nodes.particle_analysis import ParticleAnalysis node = ParticleAnalysis() # Create a field with two distinct particles @@ -855,7 +856,7 @@ def test_particle_analysis(): def test_load_file(): print("=== Test: Image ===") - from backend.nodes.io import Image as ImageNode + from backend.nodes.image import Image as ImageNode from PIL import Image as PILImage node = ImageNode() @@ -912,7 +913,7 @@ def test_load_file(): def test_save_image(): print("=== Test: SaveImage (Save Layers) ===") - from backend.nodes.io import SaveImage + from backend.nodes.save_image import SaveImage import tifffile node = SaveImage() @@ -1012,7 +1013,7 @@ def test_save_image(): def test_color_map_node(): print("=== Test: ColorMap ===") - from backend.nodes.display import ColorMap + from backend.nodes.color_map import ColorMap node = ColorMap() @@ -1038,7 +1039,7 @@ def test_color_map_node(): def test_font_node(): print("=== Test: Font ===") - from backend.nodes.display import Font + from backend.nodes.font_node import Font from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT node = Font() @@ -1056,7 +1057,7 @@ def test_font_node(): def test_preview_image(): print("=== Test: PreviewImage ===") - from backend.nodes.display import PreviewImage + from backend.nodes.preview_image import PreviewImage node = PreviewImage() # Set up a capture for the broadcast @@ -1104,7 +1105,8 @@ def test_preview_image(): def test_annotations(): print("=== Test: Annotations ===") - from backend.nodes.display import Annotations, Font + from backend.nodes.annotations import Annotations + from backend.nodes.font_node import Font node = Annotations() font_node = Font() @@ -1175,7 +1177,7 @@ def test_annotations(): def test_markup(): print("=== Test: Markup ===") - from backend.nodes.display import Markup + from backend.nodes.markup import Markup from backend.data_types import _preview_markup_stroke_width node = Markup() @@ -1226,7 +1228,7 @@ def test_markup(): def test_print_table(): print("=== Test: PrintTable ===") - from backend.nodes.display import PrintTable + from backend.nodes.print_table import PrintTable node = PrintTable() captured = [] @@ -1244,7 +1246,7 @@ def test_print_table(): def test_value_display(): print("=== Test: ValueDisplay ===") - from backend.nodes.display import ValueDisplay + from backend.nodes.value_display import ValueDisplay node = ValueDisplay() captured = [] @@ -1273,7 +1275,7 @@ def test_value_display(): def test_load_file_ibw(): print("=== Test: Image IBW multi-channel ===") - from backend.nodes.io import Image + from backend.nodes.image import Image node = Image() ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw") @@ -1309,7 +1311,7 @@ def test_load_file_ibw(): def test_load_file_npz(): print("=== Test: Image .npz ===") - from backend.nodes.io import Image + from backend.nodes.image import Image node = Image() with tempfile.TemporaryDirectory() as tmpdir: @@ -1326,7 +1328,7 @@ def test_load_file_npz(): def test_load_file_not_found(): print("=== Test: Image not found ===") - from backend.nodes.io import Image + from backend.nodes.image import Image node = Image() try: @@ -1340,7 +1342,7 @@ def test_load_file_not_found(): def test_load_file_unsupported(): print("=== Test: Image unsupported format ===") - from backend.nodes.io import Image + from backend.nodes.image import Image node = Image() with tempfile.TemporaryDirectory() as tmpdir: @@ -1358,7 +1360,7 @@ def test_load_file_unsupported(): def test_load_file_warning(): print("=== Test: Image warning for uncalibrated data ===") - from backend.nodes.io import Image as ImageNode + from backend.nodes.image import Image as ImageNode from PIL import Image as PILImage node = ImageNode() @@ -1387,7 +1389,8 @@ def test_load_file_warning(): def test_list_channels(): print("=== Test: list_channels ===") - from backend.nodes.io import list_channels, list_folder_paths, Folder + from backend.nodes.helpers import list_channels, list_folder_paths + from backend.nodes.folder import Folder from PIL import Image # Non-existent file → default @@ -1458,7 +1461,7 @@ def test_list_channels(): def test_load_demo(): print("=== Test: ImageDemo ===") - from backend.nodes.io import ImageDemo + from backend.nodes.image_demo import ImageDemo node = ImageDemo() @@ -1519,7 +1522,7 @@ def test_load_demo_multi_layer_preview_payload(): def test_coordinate(): print("=== Test: Coordinate ===") - from backend.nodes.io import Coordinate + from backend.nodes.coordinate import Coordinate node = Coordinate() @@ -1543,7 +1546,7 @@ def test_coordinate(): def test_number(): print("=== Test: Number ===") - from backend.nodes.io import Number + from backend.nodes.number import Number node = Number() @@ -1558,7 +1561,7 @@ def test_number(): def test_range_slider(): print("=== Test: RangeSlider ===") - from backend.nodes.io import RangeSlider + from backend.nodes.range_slider import RangeSlider node = RangeSlider() @@ -1642,7 +1645,7 @@ def test_execution_engine_numeric_socket_coercion(): def test_line_cursors(): print("=== Test: Cursors ===") - from backend.nodes.analysis import Cursors + from backend.nodes.cursors import Cursors node = Cursors() @@ -1718,7 +1721,7 @@ def test_line_cursors(): def test_fft2d(): print("=== Test: FFT2D ===") - from backend.nodes.analysis import FFT2D + from backend.nodes.fft_2d import FFT2D node = FFT2D() @@ -1777,7 +1780,7 @@ def test_fft2d(): def test_stats(): print("=== Test: Stats ===") - from backend.nodes.analysis import Stats + from backend.nodes.stats import Stats node = Stats() captured = [] @@ -1845,7 +1848,7 @@ def test_stats(): def test_view3d(): print("=== Test: View3D ===") - from backend.nodes.display import View3D + from backend.nodes.view_3d import View3D node = View3D() field = make_field() @@ -1863,7 +1866,7 @@ def test_view3d(): assert "height" in mesh assert "z_data" in mesh assert "colors" in mesh - assert mesh["z_scale"] == 2.0 + assert mesh["z_scale"] == 0.2 assert mesh["width"] <= 64 assert mesh["height"] <= 64 # z_min < z_max for non-constant data