refactor nodes into standalone file

This commit is contained in:
2026-03-26 19:50:03 -07:00
parent 711d7995b3
commit de0b49acc5
54 changed files with 3615 additions and 3710 deletions

View File

@@ -218,11 +218,25 @@ class ExecutionEngine:
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
from backend.nodes.analysis import CrossSection, Cursors, Stats, Histogram
from backend.nodes.modify import CropResizeField, RotateField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import SaveImage, Image, ImageDemo
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.save_image import SaveImage
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
@@ -246,11 +260,25 @@ class ExecutionEngine:
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup
from backend.nodes.analysis import CrossSection, Cursors, Stats, Histogram
from backend.nodes.modify import CropResizeField, RotateField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import Image, ImageDemo, SaveImage
from backend.nodes.preview_image import PreviewImage
from backend.nodes.print_table import PrintTable
from backend.nodes.view_3d import View3D
from backend.nodes.value_display import ValueDisplay
from backend.nodes.markup import Markup
from backend.nodes.cross_section import CrossSection
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
from backend.nodes.histogram import Histogram
from backend.nodes.crop_resize_field import CropResizeField
from backend.nodes.rotate_field import RotateField
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.mask_morphology import MaskMorphology
from backend.nodes.mask_invert import MaskInvert
from backend.nodes.mask_combine import MaskCombine
from backend.nodes.draw_mask import DrawMask
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
from backend.nodes.save_image import SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, Stats, Histogram, CrossSection, Cursors, CropResizeField, RotateField, Markup,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
Image, ImageDemo, SaveImage):
@@ -274,7 +302,8 @@ class ExecutionEngine:
from backend.data_types import (
DataField, LineData, image_to_uint8, encode_preview, render_datafield_preview,
)
from backend.nodes.io import Image, ImageDemo
from backend.nodes.image import Image
from backend.nodes.image_demo import ImageDemo
if getattr(cls, "_CUSTOM_PREVIEW", False):
return
@@ -318,7 +347,7 @@ class ExecutionEngine:
inputs: dict[str, Any],
) -> dict | None:
from backend.data_types import DataField, encode_preview, render_datafield_preview
from backend.nodes.io import list_channels
from backend.nodes.helpers import list_channels
fields = [value for value in result if isinstance(value, DataField)]
if not fields:

View File

@@ -1,7 +1,54 @@
# Import all node modules to trigger @register_node decorators.
from . import io, filters, modify, level, analysis, mask, display
from backend.nodes import (
# IO
image,
image_demo,
folder,
coordinate,
coordinate_pair,
number,
range_slider,
save_image,
# Filters
gaussian_filter,
median_filter,
edge_detect,
fft_filter_1d,
fft_filter_2d,
# Modify
colormap_adjust,
crop_resize_field,
rotate_field,
# Level
plane_level_field,
poly_level_field,
fix_zero,
# Mask
draw_mask,
threshold_mask,
mask_morphology,
mask_invert,
mask_combine,
# Display
color_map,
font_node,
annotations,
markup,
preview_image,
view_3d,
print_table,
value_display,
# Analysis
statistics_node,
histogram,
cursors,
fft_2d,
inverse_fft_2d,
cross_section,
stats,
)
try:
from . import particle
from backend.nodes import particle_analysis
except ImportError:
from . import particless
pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DataField, normalize_font_spec, resolve_colormap_input
@register_node(display_name="Annotations")
class Annotations:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}),
"text_size": ("FLOAT", {
"default": 14.0,
"min": 6.0,
"max": 96.0,
"step": 1.0,
}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"font": ("FONT",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "render"
DESCRIPTION = (
"Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. "
"The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values."
)
def render(
self,
field: DataField,
colormap: str,
show_scale_bar: bool,
show_color_map: bool,
text_size: float = 1.0,
colormap_map=None,
font=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
out = field.replace(
colormap=resolved_colormap,
overlays=[
*field.overlays,
{
"kind": "annotation",
"show_scale_bar": bool(show_scale_bar),
"show_color_map": bool(show_color_map),
"text_size": text_size,
"font": normalize_font_spec(font),
},
],
)
return (out,)

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import json
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DEFAULT_CUSTOM_COLORMAP_STOPS, normalize_colormap_spec
@register_node(display_name="Color Map")
class ColorMap:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["preset", "custom"], {"default": "preset"}),
"preset": (list(COLORMAPS), {
"default": "viridis",
"show_when_widget_value": {"mode": ["preset"]},
}),
"stops": ("STRING", {
"default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)),
"colormap_stops": True,
"show_when_widget_value": {"mode": ["custom"]},
}),
}
}
RETURN_TYPES = ("COLORMAP",)
RETURN_NAMES = ("colormap",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours "
"and any number of intermediate stops."
)
def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple:
if mode == "preset":
return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},)
try:
raw_stops = stops if stops is not None else stops_json
stops_data = json.loads(raw_stops or "[]")
except json.JSONDecodeError as exc:
raise ValueError("Custom colormap stops must be valid JSON.") from exc
spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None)
if not (isinstance(spec, dict) and spec.get("mode") == "custom"):
raise ValueError("Custom colormap must include at least min and max colours.")
return (spec,)

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Colormap Adjust")
class ColormapAdjust:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}),
"auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. "
"offset and scale operate in normalized display coordinates; Auto resets to the full data range."
)
def process(self, field: DataField, offset: float, scale: float) -> tuple:
scale = float(scale)
if not np.isfinite(scale) or scale <= 0.0:
raise ValueError("Scale must be a positive number.")
return (field.replace(display_offset=float(offset), display_scale=scale),)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Coordinate")
class Coordinate:
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("COORD",)
RETURN_NAMES = ("point",)
FUNCTION = "process"
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
def process(self, x: float, y: float) -> tuple:
return ((float(x), float(y)),)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Coordinate Pair")
class CoordinatePair:
"""Provide a pair of Coordinates, for drawing lines between markers, etc."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("COORD",),
"b": ("COORD",),
}
}
RETURN_TYPES = ("COORDPAIR",)
RETURN_NAMES = ("coord pair",)
FUNCTION = "process"
DESCRIPTION = "Output a pair of coordinates."
def process(self, a: tuple, b: tuple) -> tuple:
return ((a, b),)

View File

@@ -1,52 +1,9 @@
"""
Modify nodes geometric transforms for DATA_FIELDs.
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
# ---------------------------------------------------------------------------
# ColormapAdjust
# ---------------------------------------------------------------------------
@register_node(display_name="Colormap Adjust")
class ColormapAdjust:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}),
"auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. "
"offset and scale operate in normalized display coordinates; Auto resets to the full data range."
)
def process(self, field: DataField, offset: float, scale: float) -> tuple:
scale = float(scale)
if not np.isfinite(scale) or scale <= 0.0:
raise ValueError("Scale must be a positive number.")
return (field.replace(display_offset=float(offset), display_scale=scale),)
# ---------------------------------------------------------------------------
# CropResizeField
# ---------------------------------------------------------------------------
@register_node(display_name="Crop / Resize")
class CropResizeField:
@classmethod
@@ -190,105 +147,3 @@ class CropResizeField:
target_height = max(1, int(round(height * (target_width / width))))
return (max(1, target_width), max(1, target_height))
# ---------------------------------------------------------------------------
# RotateField
# ---------------------------------------------------------------------------
@register_node(display_name="Rotate")
class RotateField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"angle": ("FLOAT", {"default": 90.0, "min": -360.0, "max": 360.0, "step": 1.0}),
"interpolation": (["bilinear", "nearest", "bicubic"],),
"expand_canvas": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Rotate a DATA_FIELD counterclockwise by an angle in degrees. "
"Optionally expand the canvas to keep the full rotated field while preserving the field center."
)
_broadcast_warning_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
angle: float,
interpolation: str,
expand_canvas: bool,
) -> tuple:
if field.overlays:
self._send_warning("Rotate clears annotation/markup overlays!")
angle = float(angle)
order_map = {
"nearest": 0,
"bilinear": 1,
"bicubic": 3,
}
if interpolation not in order_map:
raise ValueError(f"Unknown interpolation mode: {interpolation}")
normalized_angle = angle % 360.0
snapped_quarters = int(round(normalized_angle / 90.0)) % 4
snapped_angle = snapped_quarters * 90.0
is_right_angle = abs(normalized_angle - snapped_angle) < 1e-9
if is_right_angle and expand_canvas:
rotated = np.rot90(field.data, k=snapped_quarters).copy()
elif abs(normalized_angle) < 1e-9:
rotated = field.data.copy()
else:
from scipy.ndimage import rotate as nd_rotate
rotated = nd_rotate(
field.data,
angle=angle,
reshape=bool(expand_canvas),
order=order_map[interpolation],
mode="nearest",
prefilter=order_map[interpolation] > 1,
)
new_xreal, new_yreal = self._rotated_extents(field, angle, expand_canvas)
center_x = field.xoff + field.xreal / 2.0
center_y = field.yoff + field.yreal / 2.0
result = field.replace(
data=np.asarray(rotated, dtype=np.float64),
xreal=new_xreal,
yreal=new_yreal,
xoff=center_x - new_xreal / 2.0,
yoff=center_y - new_yreal / 2.0,
overlays=[],
)
return (result,)
def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:
if not expand_canvas:
return (field.xreal, field.yreal)
theta = np.deg2rad(angle)
cos_t = abs(float(np.cos(theta)))
sin_t = abs(float(np.sin(theta)))
new_xreal = field.xreal * cos_t + field.yreal * sin_t
new_yreal = field.xreal * sin_t + field.yreal * cos_t
return (new_xreal, new_yreal)

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _extend_to_edges
@register_node(display_name="Cross Section")
class CrossSection:
"""Extract a 1-D height profile along an arbitrary line across the image."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"extend": (["none", "to_edges"],),
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
},
"optional": {
"marker_pair": ("COORDPAIR", {"label": "marker pair"}),
},
}
RETURN_TYPES = ("LINE", "COORDPAIR",)
RETURN_NAMES = ("profile", "marker pair",)
FUNCTION = "process"
DESCRIPTION = (
"Extract a cross-section profile along a line between two points. "
"Drag the markers on the image to set the line endpoints. "
"Equivalent to gwy_data_field_get_profile."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, field: DataField,
x1: float, y1: float, x2: float, y2: float,
extend: str, n_samples: int,
marker_pair=None,
) -> tuple:
from scipy.ndimage import map_coordinates
if marker_pair is not None:
(x1, y1), (x2, y2) = marker_pair
marker_x1, marker_y1 = float(x1), float(y1)
marker_x2, marker_y2 = float(x2), float(y2)
xres, yres = field.xres, field.yres
if extend == "to_edges":
x1, y1, x2, y2 = _extend_to_edges(
float(x1), float(y1), float(x2), float(y2),
)
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
line_len_px = np.hypot(px2 - px1, py2 - py1)
if n_samples <= 0:
n_samples = max(2, int(np.ceil(line_len_px)))
t = np.linspace(0, 1, n_samples)
coords_y = py1 + t * (py2 - py1)
coords_x = px1 + t * (px2 - px1)
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
if CrossSection._broadcast_overlay_fn is not None:
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,
{
"image": image_uri,
"x1": marker_x1, "y1": marker_y1,
"x2": marker_x2, "y2": marker_y2,
"a_locked": marker_pair is not None,
"b_locked": marker_pair is not None,
},
)
dx_real = (x2 - x1) * field.xreal
dy_real = (y2 - y1) * field.yreal
distance_axis = np.linspace(0.0, float(np.hypot(dx_real, dy_real)), n_samples, dtype=np.float64)
return (
LineData(
data=profile.astype(np.float64),
x_axis=distance_axis,
x_unit=field.si_unit_xy,
y_unit=field.si_unit_z,
),
((marker_x1, marker_y1), (marker_x2, marker_y2)),
)

173
backend/nodes/cursors.py Normal file
View File

@@ -0,0 +1,173 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview
@register_node(display_name="Cursors")
class Cursors:
"""Place two draggable cursors on a line plot or field to measure deltas."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("CURSOR_SOURCE", {"label": "input"}),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
},
"optional": {
"coord_pair": ("COORDPAIR", {"label": "coord pair"}),
},
}
RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",)
RETURN_NAMES = ("measurement", "coord pair",)
FUNCTION = "process"
DESCRIPTION = (
"Place two cursors on a line plot or 2D field. "
"On lines it reports x/y positions and dx/dy. "
"On fields it reports x/y/z at both markers plus dx/dy/dz."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, line, x1: float, y1: float, x2: float, y2: float,
coord_pair=None,
) -> tuple:
if coord_pair is not None:
(x1, y1), (x2, y2) = coord_pair
locked = coord_pair is not None
if isinstance(line, DataField):
return self._process_field(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked)
return self._process_line(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked)
def _process_line(
self,
line,
x1: float,
y1: float,
x2: float,
y2: float,
locked: bool = False,
) -> tuple:
y = np.asarray(line, dtype=np.float64).ravel()
x_unit = line.x_unit if isinstance(line, LineData) else ""
y_unit = line.y_unit if isinstance(line, LineData) else ""
n = len(y)
if isinstance(line, LineData) and line.x_axis is not None:
x = np.asarray(line.x_axis, dtype=np.float64).ravel()[:n]
else:
x = np.arange(n, dtype=np.float64)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
xmin = float(np.min(x)) if len(x) else 0.0
xmax = float(np.max(x)) if len(x) else 1.0
def x_frac_to_idx(frac):
if n <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(x - target_x)))
idx_a = x_frac_to_idx(x1)
idx_b = x_frac_to_idx(x2)
xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b])
if Cursors._broadcast_overlay_fn is not None:
Cursors._broadcast_overlay_fn(
Cursors._current_node_id,
{
"kind": "line_plot",
"section_title": "Cursors",
"line": y.tolist(),
"x_axis": x.tolist(),
"x1": x1,
"x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": locked,
"b_locked": locked,
},
)
table = MeasureTable([
{"quantity": "A x", "value": xa, "unit": x_unit},
{"quantity": "A y", "value": ya, "unit": y_unit},
{"quantity": "B x", "value": xb, "unit": x_unit},
{"quantity": "B y", "value": yb, "unit": y_unit},
{"quantity": "dx", "value": xb - xa, "unit": x_unit},
{"quantity": "dy", "value": yb - ya, "unit": y_unit},
])
return (table, ((x1, y1), (x2, y2)))
def _process_field(
self,
field: DataField,
x1: float,
y1: float,
x2: float,
y2: float,
locked: bool = False,
) -> tuple:
from scipy.ndimage import map_coordinates
x1 = float(np.clip(x1, 0.0, 1.0))
y1 = float(np.clip(y1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
y2 = float(np.clip(y2, 0.0, 1.0))
px1 = x1 * max(field.xres - 1, 0)
py1 = y1 * max(field.yres - 1, 0)
px2 = x2 * max(field.xres - 1, 0)
py2 = y2 * max(field.yres - 1, 0)
z1 = float(map_coordinates(field.data, [[py1], [px1]], order=1, mode="nearest")[0])
z2 = float(map_coordinates(field.data, [[py2], [px2]], order=1, mode="nearest")[0])
ax = float(field.xoff + x1 * field.xreal)
ay = float(field.yoff + y1 * field.yreal)
bx = float(field.xoff + x2 * field.xreal)
by = float(field.yoff + y2 * field.yreal)
if Cursors._broadcast_overlay_fn is not None:
Cursors._broadcast_overlay_fn(
Cursors._current_node_id,
{
"kind": "cursor_points",
"section_title": "Cursors",
"image": encode_preview(render_datafield_preview(field, field.colormap)),
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"a_locked": locked,
"b_locked": locked,
},
)
table = MeasureTable([
{"quantity": "A x", "value": ax, "unit": field.si_unit_xy},
{"quantity": "A y", "value": ay, "unit": field.si_unit_xy},
{"quantity": "A z", "value": z1, "unit": field.si_unit_z},
{"quantity": "B x", "value": bx, "unit": field.si_unit_xy},
{"quantity": "B y", "value": by, "unit": field.si_unit_xy},
{"quantity": "B z", "value": z2, "unit": field.si_unit_z},
{"quantity": "dx", "value": bx - ax, "unit": field.si_unit_xy},
{"quantity": "dy", "value": by - ay, "unit": field.si_unit_xy},
{"quantity": "dz", "value": z2 - z1, "unit": field.si_unit_z},
])
return (table, ((x1, y1), (x2, y2)))

View File

@@ -1,743 +0,0 @@
"""
Display / output nodes.
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
connect whichever type you have. The server injects _broadcast_fn
before execution begins.
"""
from __future__ import annotations
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
DataField,
MeasureTable,
COLORMAPS,
CUSTOM_FILE_FONT,
DEFAULT_CUSTOM_COLORMAP_STOPS,
SYSTEM_DEFAULT_FONT,
colormap_to_uint8,
datafield_to_uint8,
encode_preview,
image_to_uint8,
list_overlay_font_choices,
normalize_colormap_spec,
normalize_font_spec,
normalize_for_colormap,
render_datafield_preview,
resolve_colormap_input,
)
def _measurement_names(table: list) -> list[str]:
names = []
for row in table:
if not isinstance(row, dict):
continue
quantity = row.get("quantity")
if isinstance(quantity, str) and quantity and quantity not in names:
names.append(quantity)
return names
def _measurement_entry(table: list, selection: str) -> dict:
names = _measurement_names(table)
if not names:
raise ValueError("Measurement table has no selectable rows.")
target = selection if selection in names else names[0]
for row in table:
if isinstance(row, dict) and row.get("quantity") == target:
return row
raise ValueError(f"Measurement '{target}' was not found.")
def _measurement_value(table: list, selection: str) -> float:
row = _measurement_entry(table, selection)
value = row.get("value")
if isinstance(value, bool):
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
try:
numeric = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc
if np.isfinite(numeric):
return numeric
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
def _scalar_payload(value: float, unit: str = "") -> dict:
payload = {"value": float(value)}
if isinstance(unit, str) and unit.strip():
payload["unit"] = unit.strip()
return payload
_SI_PREFIXES = [
(1e24, "Y"),
(1e21, "Z"),
(1e18, "E"),
(1e15, "P"),
(1e12, "T"),
(1e9, "G"),
(1e6, "M"),
(1e3, "k"),
(1.0, ""),
(1e-3, "m"),
(1e-6, "u"),
(1e-9, "n"),
(1e-12, "p"),
(1e-15, "f"),
(1e-18, "a"),
(1e-21, "z"),
(1e-24, "y"),
]
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
unit = (unit or "").strip()
if not unit:
return _format_numeric(value)
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
abs_value = abs(value)
for scale, prefix in _SI_PREFIXES:
scaled = abs_value / scale
if 1 <= scaled < 1000:
signed = value / scale
return f"{_format_numeric(signed)} {prefix}{unit}"
return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field: DataField) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
from PIL import Image, ImageDraw, ImageFont
size_px = max(8, int(round(size_px)))
try:
font = ImageFont.truetype("DejaVuSans.ttf", size_px)
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_image)
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
return text_image
except Exception:
font = ImageFont.load_default()
probe = Image.new("L", (1, 1), 0)
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
mask = Image.new("L", (width, height), 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
scale = max(1.0, size_px / max(1, height))
scaled_width = max(1, int(round(width * scale)))
scaled_height = max(1, int(round(height * scale)))
resampling = getattr(Image, "Resampling", Image)
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR)
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
text_image.putalpha(scaled_mask)
return text_image
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _parse_markup_shapes(raw_shapes: str | list | None) -> list[dict[str, object]]:
if isinstance(raw_shapes, str):
try:
raw_shapes = json.loads(raw_shapes or "[]")
except json.JSONDecodeError:
raw_shapes = []
if not isinstance(raw_shapes, list):
return []
parsed: list[dict[str, object]] = []
for shape in raw_shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
coords = [x1, y1, x2, y2]
if not all(np.isfinite(value) for value in coords):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (
shaft_end[0] + px * head_width / 2.0,
shaft_end[1] + py * head_width / 2.0,
)
right = (
shaft_end[0] - px * head_width / 2.0,
shaft_end[1] - py * head_width / 2.0,
)
draw.polygon([end, left, right], fill=color)
def _render_markup_image(image: np.ndarray, shapes: list[dict[str, object]]) -> np.ndarray:
from PIL import Image, ImageDraw
base = image_to_uint8(image)
if base.ndim == 2:
base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
canvas = Image.fromarray(base.copy())
draw = ImageDraw.Draw(canvas)
height, width = base.shape[:2]
for shape in shapes:
x1 = float(shape["x1"]) * width
y1 = float(shape["y1"]) * height
x2 = float(shape["x2"]) * width
y2 = float(shape["y2"]) * height
color = str(shape["color"])
stroke_width = int(shape["width"])
kind = str(shape["kind"])
if kind == "line":
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
elif kind == "rectangle":
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "circle":
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "arrow":
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
return np.asarray(canvas, dtype=np.uint8)
@register_node(display_name="Color Map")
class ColorMap:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["preset", "custom"], {"default": "preset"}),
"preset": (list(COLORMAPS), {
"default": "viridis",
"show_when_widget_value": {"mode": ["preset"]},
}),
"stops": ("STRING", {
"default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)),
"colormap_stops": True,
"show_when_widget_value": {"mode": ["custom"]},
}),
}
}
RETURN_TYPES = ("COLORMAP",)
RETURN_NAMES = ("colormap",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours "
"and any number of intermediate stops."
)
def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple:
if mode == "preset":
return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},)
try:
raw_stops = stops if stops is not None else stops_json
stops_data = json.loads(raw_stops or "[]")
except json.JSONDecodeError as exc:
raise ValueError("Custom colormap stops must be valid JSON.") from exc
spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None)
if not (isinstance(spec, dict) and spec.get("mode") == "custom"):
raise ValueError("Custom colormap must include at least min and max colours.")
return (spec,)
@register_node(display_name="Font")
class Font:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], {
"default": SYSTEM_DEFAULT_FONT,
}),
"font_file": ("FILE_PICKER", {
"default": "",
"show_when_widget_value": {"family": [CUSTOM_FILE_FONT]},
}),
}
}
RETURN_TYPES = ("FONT",)
RETURN_NAMES = ("font",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable font spec for annotation overlays. Choose a discovered system font, "
"use the default fallback stack, or point to a custom font file."
)
def build(self, family: str, font_file: str = "") -> tuple:
if family == SYSTEM_DEFAULT_FONT:
return (None,)
if family == CUSTOM_FILE_FONT:
return (normalize_font_spec({"path": font_file}),)
return (normalize_font_spec({"family": family}),)
@register_node(display_name="Annotations")
class Annotations:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}),
"text_size": ("FLOAT", {
"default": 14.0,
"min": 6.0,
"max": 96.0,
"step": 1.0,
}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"font": ("FONT",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "render"
DESCRIPTION = (
"Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. "
"The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values."
)
def render(
self,
field: DataField,
colormap: str,
show_scale_bar: bool,
show_color_map: bool,
text_size: float = 1.0,
colormap_map=None,
font=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
out = field.replace(
colormap=resolved_colormap,
overlays=[
*field.overlays,
{
"kind": "annotation",
"show_scale_bar": bool(show_scale_bar),
"show_color_map": bool(show_color_map),
"text_size": text_size,
"font": normalize_font_spec(font),
},
],
)
return (out,)
@register_node(display_name="Markup")
class Markup:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}),
"stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
"clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}),
"markup_shapes": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "process"
DESCRIPTION = (
"Draw simple vector markup over a DATA_FIELD without flattening the underlying data. "
"Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
shape: str,
stroke_color: str,
stroke_width: int,
markup_shapes: str,
) -> tuple:
shapes = _parse_markup_shapes(markup_shapes)
out = field.replace(
overlays=[
*field.overlays,
{
"kind": "markup",
"shapes": shapes,
},
],
)
if Markup._broadcast_overlay_fn is not None:
Markup._broadcast_overlay_fn(
Markup._current_node_id,
{
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
return (out,)
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ()
FUNCTION = "preview"
OUTPUT_NODE = True
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
_broadcast_fn = None
_current_node_id: str = ""
def preview(
self,
colormap: str,
image: np.ndarray | None = None,
field=None,
colormap_map=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap if field is not None else None,
default="gray",
)
# Prefer field if both are connected; accept whichever is provided
if field is not None:
arr_u8 = render_datafield_preview(field, resolved_colormap)
elif image is not None:
arr_u8 = image_to_uint8(image)
if arr_u8.ndim == 2:
if image.dtype == np.uint8:
normalized = arr_u8.astype(np.float64) / 255.0
else:
imin, imax = image.min(), image.max()
if imax > imin:
normalized = (image - imin) / (imax - imin)
else:
normalized = np.zeros_like(image, dtype=np.float64)
arr_u8 = colormap_to_uint8(normalized, resolved_colormap)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return ()
@register_node(display_name="3D View")
class View3D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ()
FUNCTION = "render"
OUTPUT_NODE = True
DESCRIPTION = (
"Interactive 3D surface view of a DATA_FIELD. "
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
)
_broadcast_mesh_fn = None
_current_node_id: str = ""
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int, colormap_map=None,
) -> tuple:
import base64
data = field.data
yres, xres = data.shape
# Downsample if larger than resolution
step_y = max(1, yres // resolution)
step_x = max(1, xres // resolution)
z = data[::step_y, ::step_x].astype(np.float32)
ny, nx = z.shape
# Normalize for colormap
zmin, zmax = float(z.min()), float(z.max())
z_norm = normalize_for_colormap(
z,
offset=field.display_offset,
scale=field.display_scale,
data_min=float(field.data.min()),
data_max=float(field.data.max()),
)
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
colors_u8 = colormap_to_uint8(z_norm, resolved_colormap)
# Base64-encode arrays for efficient WS transport
z_b64 = base64.b64encode(z.tobytes()).decode()
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
mesh_data = {
"width": nx,
"height": ny,
"z_data": z_b64,
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
return ()
@register_node(display_name="Print Table")
class PrintTable:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("ANY_TABLE",),
}
}
RETURN_TYPES = ()
FUNCTION = "print_table"
OUTPUT_NODE = True
DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display."
_broadcast_table_fn = None
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()
@register_node(display_name="Value Display")
class ValueDisplay:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("VALUE_SOURCE",),
"measurement": ("STRING", {
"default": "",
"choices_from_measure_input": "value",
"show_when_source_type": {
"value": ["MEASURE_TABLE"],
},
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "display_value"
DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged."
_broadcast_value_fn = None
_current_node_id: str = ""
def display_value(self, value, measurement: str = "") -> tuple:
unit = ""
if isinstance(value, MeasureTable):
row = _measurement_entry(value, measurement)
numeric = _measurement_value(value, measurement)
unit = row.get("unit", "") if isinstance(row.get("unit"), str) else ""
else:
numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None:
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit))
return (numeric,)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Edge Detect")
class EdgeDetect:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["sobel", "prewitt", "laplacian", "log"],),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("edges",)
FUNCTION = "process"
DESCRIPTION = (
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
)
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data
if method == "sobel":
sx = sobel(data, axis=1)
sy = sobel(data, axis=0)
result = np.hypot(sx, sy)
elif method == "prewitt":
px = prewitt(data, axis=1)
py = prewitt(data, axis=0)
result = np.hypot(px, py)
elif method == "laplacian":
result = laplace(data)
elif method == "log":
result = gaussian_laplace(data, sigma=float(sigma))
else:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)

115
backend/nodes/fft_2d.py Normal file
View File

@@ -0,0 +1,115 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="2D FFT")
class FFT2D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"windowing": (["hann", "hamming", "blackman", "none"],),
"level": (["mean", "plane", "none"],),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf")
FUNCTION = "process"
DESCRIPTION = (
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
"Outputs log magnitude, magnitude, phase, and PSDF as separate channels. "
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
)
def process(self, field: DataField, windowing: str, level: str) -> tuple:
data = field.data.copy()
yres, xres = data.shape
if level == "mean":
data -= data.mean()
elif level == "plane":
yy, xx = np.mgrid[0:yres, 0:xres]
xx_f = xx.ravel().astype(np.float64)
yy_f = yy.ravel().astype(np.float64)
zz_f = data.ravel()
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
data -= plane
if windowing != "none":
t_y = (np.arange(yres) + 0.5) / yres
t_x = (np.arange(xres) + 0.5) / xres
if windowing == "hann":
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
elif windowing == "hamming":
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
elif windowing == "blackman":
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
else:
wy = np.ones(yres)
wx = np.ones(xres)
data *= np.outer(wy, wx)
F = np.fft.fftshift(np.fft.fft2(data))
n = xres * yres
magnitude = np.abs(F)
log_magnitude = np.log1p(magnitude)
phase = np.angle(F)
dx = field.xreal / xres
dy = field.yreal / yres
psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
spatial_freq_xreal = xres / field.xreal
spatial_freq_yreal = yres / field.yreal
angular_freq_xreal = 2.0 * np.pi * xres / field.xreal
angular_freq_yreal = 2.0 * np.pi * yres / field.yreal
return (
DataField(
data=log_magnitude,
xreal=spatial_freq_xreal,
yreal=spatial_freq_yreal,
si_unit_xy="1/m",
si_unit_z=field.si_unit_z,
domain="frequency",
colormap=field.colormap,
),
DataField(
data=magnitude,
xreal=spatial_freq_xreal,
yreal=spatial_freq_yreal,
si_unit_xy="1/m",
si_unit_z=field.si_unit_z,
domain="frequency",
colormap=field.colormap,
),
DataField(
data=phase,
xreal=spatial_freq_xreal,
yreal=spatial_freq_yreal,
si_unit_xy="1/m",
si_unit_z=field.si_unit_z,
domain="frequency",
colormap=field.colormap,
),
DataField(
data=psdf,
xreal=angular_freq_xreal,
yreal=angular_freq_yreal,
si_unit_xy="1/m",
si_unit_z=f"({field.si_unit_z})^2 m^2",
domain="frequency",
colormap=field.colormap,
),
)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import LineData
from backend.nodes.helpers import _cached_1d_transfer
@register_node(display_name="1D FFT Filter")
class FFTFilter1D:
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
transfer function with configurable order for a smooth roll-off.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 1-D line profile. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
"Equivalent to Gwyddion fft_filter_1d."
)
def process(self, line, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
n = len(z)
Z = np.fft.rfft(z)
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
Z *= H
filtered = np.fft.irfft(Z, n=n)
if isinstance(line, LineData):
return (
LineData(
data=filtered,
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
x_unit=line.x_unit,
y_unit=line.y_unit,
),
)
return (filtered,)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
from backend.nodes.helpers import _cached_2d_transfer
@register_node(display_name="2D FFT Filter")
class FFTFilter2D:
"""Frequency-domain filtering of 2-D data fields (images).
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
Butterworth transfer function in the frequency domain to remove or
isolate periodic features.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 2-D data field. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
"bandpass/notch to isolate or remove periodic noise. "
"Equivalent to Gwyddion fft_filter_2d."
)
def process(self, field: DataField, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
data = field.data
yres, xres = data.shape
mean_val = float(data.mean())
centered = data - mean_val
spectrum = np.fft.rfft2(centered)
transfer = _cached_2d_transfer(
yres, xres, filter_type,
float(cutoff), float(cutoff_high), int(order),
)
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
result += mean_val
return (field.replace(data=result),)

View File

@@ -1,332 +0,0 @@
"""
Filter nodes — Gwyddion-equivalent image filters.
Gwyddion equivalents:
GaussianFilter → gwy_data_field_filter_gaussian
MedianFilter → gwy_data_field_filter_median
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles)
FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs)
"""
from __future__ import annotations
from functools import lru_cache
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData
# ---------------------------------------------------------------------------
# GaussianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data, sigma=float(sigma))
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# MedianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Median Filter")
class MedianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data, size=size)
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# EdgeDetect
# ---------------------------------------------------------------------------
@register_node(display_name="Edge Detect")
class EdgeDetect:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["sobel", "prewitt", "laplacian", "log"],),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("edges",)
FUNCTION = "process"
DESCRIPTION = (
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
)
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data
if method == "sobel":
sx = sobel(data, axis=1)
sy = sobel(data, axis=0)
result = np.hypot(sx, sy)
elif method == "prewitt":
px = prewitt(data, axis=1)
py = prewitt(data, axis=0)
result = np.hypot(px, py)
elif method == "laplacian":
result = laplace(data)
elif method == "log":
result = gaussian_laplace(data, sigma=float(sigma))
else:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)
# ---------------------------------------------------------------------------
# Butterworth transfer function helpers
# ---------------------------------------------------------------------------
def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
"""Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n))."""
with np.errstate(divide="ignore", over="ignore"):
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
"""Butterworth highpass: H = 1 / (1 + (fc/f)^(2n))."""
with np.errstate(divide="ignore", invalid="ignore"):
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
h = np.where(np.isfinite(h), h, 0.0)
return h
def _build_1d_transfer(n: int, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> np.ndarray:
"""Build a 1-D transfer function for an FFT of length *n*.
Frequencies are normalised so that 1.0 = Nyquist (fs/2).
The returned array has the same layout as np.fft.rfft output (length n//2+1).
"""
freq = np.linspace(0, 1, n // 2 + 1)
if filter_type == "lowpass":
H = _butterworth_lp(freq, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(freq, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(freq)
return H
@lru_cache(maxsize=64)
def _cached_1d_transfer(n: int, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> np.ndarray:
transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
transfer.setflags(write=False)
return transfer
@lru_cache(maxsize=32)
def _fft_radius_grid(yres: int, xres: int) -> np.ndarray:
fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0
fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0
radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0)
np.clip(radius, 0.0, 1.0, out=radius)
radius.setflags(write=False)
return radius
@lru_cache(maxsize=128)
def _cached_2d_transfer(yres: int, xres: int, filter_type: str,
cutoff: float, cutoff_high: float, order: int) -> np.ndarray:
radius = _fft_radius_grid(yres, xres)
if filter_type == "lowpass":
transfer = _butterworth_lp(radius, cutoff, order)
elif filter_type == "highpass":
transfer = _butterworth_hp(radius, cutoff, order)
elif filter_type == "bandpass":
transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
elif filter_type == "notch":
band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
transfer = 1.0 - band
else:
transfer = np.ones_like(radius)
transfer.setflags(write=False)
return transfer
# ---------------------------------------------------------------------------
# FFTFilter1D — frequency-domain filtering of LINE profiles
# ---------------------------------------------------------------------------
@register_node(display_name="1D FFT Filter")
class FFTFilter1D:
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
transfer function with configurable order for a smooth roll-off.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 1-D line profile. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
"Equivalent to Gwyddion fft_filter_1d."
)
def process(self, line, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
n = len(z)
# Forward FFT (real-valued)
Z = np.fft.rfft(z)
# Build and apply transfer function
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
Z *= H
# Inverse FFT
filtered = np.fft.irfft(Z, n=n)
if isinstance(line, LineData):
return (
LineData(
data=filtered,
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
x_unit=line.x_unit,
y_unit=line.y_unit,
),
)
return (filtered,)
# ---------------------------------------------------------------------------
# FFTFilter2D — frequency-domain filtering of DATA_FIELDs
# ---------------------------------------------------------------------------
@register_node(display_name="2D FFT Filter")
class FFTFilter2D:
"""Frequency-domain filtering of 2-D data fields (images).
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
Butterworth transfer function in the frequency domain to remove or
isolate periodic features.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 2-D data field. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
"bandpass/notch to isolate or remove periodic noise. "
"Equivalent to Gwyddion fft_filter_2d."
)
def process(self, field: DataField, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
data = field.data
yres, xres = data.shape
# Subtract mean to avoid DC leakage artefacts.
mean_val = float(data.mean())
centered = data - mean_val
# Real-valued FFT keeps only the unique half-plane and avoids shift copies.
spectrum = np.fft.rfft2(centered)
transfer = _cached_2d_transfer(
yres,
xres,
filter_type,
float(cutoff),
float(cutoff_high),
int(order),
)
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
# Restore DC
result += mean_val
return (field.replace(data=result),)

37
backend/nodes/fix_zero.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Fix Zero")
class FixZero:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["min", "mean", "median"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("zeroed",)
FUNCTION = "process"
DESCRIPTION = (
"Shift data so that the minimum (or mean/median) is zero. "
"Equivalent to fix_zero in Gwyddion's level.c."
)
def process(self, field: DataField, method: str) -> tuple:
data = field.data.copy()
if method == "min":
data -= data.min()
elif method == "mean":
data -= data.mean()
elif method == "median":
data -= np.median(data)
else:
raise ValueError(f"Unknown method: {method}")
return (field.replace(data=data),)

29
backend/nodes/folder.py Normal file
View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.nodes.helpers import list_folder_paths
@register_node(display_name="Folder")
class Folder:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}),
}
}
RETURN_TYPES = ("DIRECTORY",)
RETURN_NAMES = ("directory",)
FUNCTION = "list_files"
DESCRIPTION = (
"Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. "
"Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans."
)
def list_files(self, folder: str) -> tuple:
entries = list_folder_paths(folder)
if not entries:
return tuple()
return tuple(item["path"] for item in entries)

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT, list_overlay_font_choices, normalize_font_spec
@register_node(display_name="Font")
class Font:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], {
"default": SYSTEM_DEFAULT_FONT,
}),
"font_file": ("FILE_PICKER", {
"default": "",
"show_when_widget_value": {"family": [CUSTOM_FILE_FONT]},
}),
}
}
RETURN_TYPES = ("FONT",)
RETURN_NAMES = ("font",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable font spec for annotation overlays. Choose a discovered system font, "
"use the default fallback stack, or point to a custom font file."
)
def build(self, family: str, font_file: str = "") -> tuple:
if family == SYSTEM_DEFAULT_FONT:
return (None,)
if family == CUSTOM_FILE_FONT:
return (normalize_font_spec({"path": font_file}),)
return (normalize_font_spec({"family": family}),)

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data, sigma=float(sigma))
return (field.replace(data=data),)

873
backend/nodes/helpers.py Normal file
View File

@@ -0,0 +1,873 @@
"""
Shared helper functions for argonode nodes.
"""
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
from typing import Callable
import numpy as np
from backend.runtime_paths import demo_dir, input_dir, output_dir
# ---------------------------------------------------------------------------
# Scalar payload helpers (from display.py)
# ---------------------------------------------------------------------------
def _scalar_payload(value: float, unit: str = "") -> dict:
payload = {"value": float(value)}
if isinstance(unit, str) and unit.strip():
payload["unit"] = unit.strip()
return payload
# ---------------------------------------------------------------------------
# Measurement helpers (from display.py — used by ValueDisplay)
# ---------------------------------------------------------------------------
def _measurement_names(table: list) -> list[str]:
names = []
for row in table:
if not isinstance(row, dict):
continue
quantity = row.get("quantity")
if isinstance(quantity, str) and quantity and quantity not in names:
names.append(quantity)
return names
def _measurement_entry(table: list, selection: str) -> dict:
names = _measurement_names(table)
if not names:
raise ValueError("Measurement table has no selectable rows.")
target = selection if selection in names else names[0]
for row in table:
if isinstance(row, dict) and row.get("quantity") == target:
return row
raise ValueError(f"Measurement '{target}' was not found.")
def _measurement_value(table: list, selection: str) -> float:
row = _measurement_entry(table, selection)
value = row.get("value")
if isinstance(value, bool):
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
try:
numeric = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc
if np.isfinite(numeric):
return numeric
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
# ---------------------------------------------------------------------------
# SI formatting helpers (from display.py — used by Annotations)
# ---------------------------------------------------------------------------
_SI_PREFIXES = [
(1e24, "Y"), (1e21, "Z"), (1e18, "E"), (1e15, "P"), (1e12, "T"),
(1e9, "G"), (1e6, "M"), (1e3, "k"), (1.0, ""), (1e-3, "m"),
(1e-6, "u"), (1e-9, "n"), (1e-12, "p"), (1e-15, "f"),
(1e-18, "a"), (1e-21, "z"), (1e-24, "y"),
]
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "\u03a9"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
unit = (unit or "").strip()
if not unit:
return _format_numeric(value)
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
abs_value = abs(value)
for scale, prefix in _SI_PREFIXES:
scaled = abs_value / scale
if 1 <= scaled < 1000:
signed = value / scale
return f"{_format_numeric(signed)} {prefix}{unit}"
return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
from PIL import Image, ImageDraw, ImageFont
size_px = max(8, int(round(size_px)))
try:
font = ImageFont.truetype("DejaVuSans.ttf", size_px)
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_image)
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
return text_image
except Exception:
font = ImageFont.load_default()
probe = Image.new("L", (1, 1), 0)
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
mask = Image.new("L", (width, height), 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
scale = max(1.0, size_px / max(1, height))
scaled_width = max(1, int(round(width * scale)))
scaled_height = max(1, int(round(height * scale)))
resampling = getattr(Image, "Resampling", Image)
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR)
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
text_image.putalpha(scaled_mask)
return text_image
# ---------------------------------------------------------------------------
# Markup helpers (from display.py — used by Markup)
# ---------------------------------------------------------------------------
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _parse_markup_shapes(raw_shapes) -> list[dict]:
if isinstance(raw_shapes, str):
try:
raw_shapes = json.loads(raw_shapes or "[]")
except json.JSONDecodeError:
raw_shapes = []
if not isinstance(raw_shapes, list):
return []
parsed = []
for shape in raw_shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
coords = [x1, y1, x2, y2]
if not all(np.isfinite(value) for value in coords):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _draw_arrow(draw, start, end, color, width):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (
shaft_end[0] + px * head_width / 2.0,
shaft_end[1] + py * head_width / 2.0,
)
right = (
shaft_end[0] - px * head_width / 2.0,
shaft_end[1] - py * head_width / 2.0,
)
draw.polygon([end, left, right], fill=color)
def _render_markup_image(image, shapes):
from PIL import Image as PILImage, ImageDraw
from backend.data_types import image_to_uint8
base = image_to_uint8(image)
if base.ndim == 2:
base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
canvas = PILImage.fromarray(base.copy())
draw = ImageDraw.Draw(canvas)
height, width = base.shape[:2]
for shape in shapes:
x1 = float(shape["x1"]) * width
y1 = float(shape["y1"]) * height
x2 = float(shape["x2"]) * width
y2 = float(shape["y2"]) * height
color = str(shape["color"])
stroke_width = int(shape["width"])
kind = str(shape["kind"])
if kind == "line":
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
elif kind == "rectangle":
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "circle":
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "arrow":
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
return np.asarray(canvas, dtype=np.uint8)
# ---------------------------------------------------------------------------
# Mask helpers (from mask.py — used by multiple mask nodes)
# ---------------------------------------------------------------------------
def _mask_overlay(field, mask):
from backend.data_types import datafield_to_uint8
grey = datafield_to_uint8(field, "gray")
mask_bool = mask > 127
if not np.any(mask_bool):
return grey
overlay = grey.copy()
red = overlay[..., 0]
green = overlay[..., 1]
blue = overlay[..., 2]
red_vals = red[mask_bool].astype(np.uint16)
green_vals = green[mask_bool].astype(np.uint16)
blue_vals = blue[mask_bool].astype(np.uint16)
red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100
green[mask_bool] = ((green_vals * 55) + 50) // 100
blue[mask_bool] = ((blue_vals * 55) + 50) // 100
return overlay
@lru_cache(maxsize=128)
def _mask_structure(radius: int, shape: str):
radius = max(1, int(radius))
if shape == "disk":
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
struct = (x * x + y * y) <= radius * radius
else:
size = 2 * radius + 1
struct = np.ones((size, size), dtype=bool)
struct.setflags(write=False)
return struct
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width, height, strokes, default_pen_size):
from PIL import Image as PILImage, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = PILImage.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# Path / directory helpers (from io.py)
# ---------------------------------------------------------------------------
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
_MAX_SAVE_FIELDS = 8
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
def _resolve_path(filepath: str):
path = Path(filepath)
if path.is_absolute():
return path
candidate = INPUT_DIR / filepath
if candidate.exists():
return candidate
candidate = DEMO_DIR / filepath
if candidate.exists():
return candidate
return INPUT_DIR / filepath
def list_channels(filepath: str) -> list[dict]:
path = _resolve_path(filepath)
if not path.exists():
return [{"name": "field", "type": "DATA_FIELD"}]
ext = path.suffix.lower()
if ext == ".gwy":
try:
import gwyfile
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if channels:
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".sxm":
try:
import nanonispy as nap
sxm = nap.read.Scan(str(path))
if sxm.signals:
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".ibw":
try:
from igor.binarywave import load as load_ibw
wave = load_ibw(str(path))
raw = wave["wave"]["wData"]
labels = wave["wave"].get("labels", None)
if raw.ndim >= 3 and labels:
dim_idx = min(2, len(labels) - 1)
if dim_idx >= 0 and labels[dim_idx]:
decoded = []
for lbl in labels[dim_idx]:
if lbl:
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
if name:
decoded.append(name)
if decoded:
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
if raw.ndim >= 3 and raw.shape[2] > 1:
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
return [{"name": "field", "type": "DATA_FIELD"}]
def list_folder_paths(folderpath: str) -> list[dict]:
path = _resolve_path(folderpath)
if not path.exists() or not path.is_dir():
return []
resolved_dir = str(path.resolve())
results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}]
for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()):
if not entry.is_file() or entry.name.startswith("."):
continue
if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS:
continue
results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())})
return results
def _list_demo_files() -> list[str]:
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
)
# ---------------------------------------------------------------------------
# Butterworth / FFT helpers (from filters.py — used by FFTFilter1D, FFTFilter2D)
# ---------------------------------------------------------------------------
def _butterworth_lp(freq, cutoff, order):
with np.errstate(divide="ignore", over="ignore"):
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
def _butterworth_hp(freq, cutoff, order):
with np.errstate(divide="ignore", invalid="ignore"):
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
h = np.where(np.isfinite(h), h, 0.0)
return h
def _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order):
freq = np.linspace(0, 1, n // 2 + 1)
if filter_type == "lowpass":
H = _butterworth_lp(freq, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(freq, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(freq)
return H
@lru_cache(maxsize=64)
def _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order):
transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
transfer.setflags(write=False)
return transfer
@lru_cache(maxsize=32)
def _fft_radius_grid(yres, xres):
fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0
fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0
radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0)
np.clip(radius, 0.0, 1.0, out=radius)
radius.setflags(write=False)
return radius
@lru_cache(maxsize=128)
def _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order):
radius = _fft_radius_grid(yres, xres)
if filter_type == "lowpass":
transfer = _butterworth_lp(radius, cutoff, order)
elif filter_type == "highpass":
transfer = _butterworth_hp(radius, cutoff, order)
elif filter_type == "bandpass":
transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
elif filter_type == "notch":
band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
transfer = 1.0 - band
else:
transfer = np.ones_like(radius)
transfer.setflags(write=False)
return transfer
# ---------------------------------------------------------------------------
# Cross-section and stats helpers (from analysis.py)
# ---------------------------------------------------------------------------
def _extend_to_edges(x1, y1, x2, y2):
dx = x2 - x1
dy = y2 - y1
t_candidates = []
if abs(dx) > 1e-12:
for bx in (0.0, 1.0):
t = (bx - x1) / dx
y_at_t = y1 + t * dy
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if abs(dy) > 1e-12:
for by in (0.0, 1.0):
t = (by - y1) / dy
x_at_t = x1 + t * dx
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if len(t_candidates) < 2:
return x1, y1, x2, y2
t_min = min(t_candidates)
t_max = max(t_candidates)
return (
np.clip(x1 + t_min * dx, 0, 1),
np.clip(y1 + t_min * dy, 0, 1),
np.clip(x1 + t_max * dx, 0, 1),
np.clip(y1 + t_max * dy, 0, 1),
)
def _safe_rq(d):
return float(np.sqrt(np.mean(d * d)))
LINE_OPS: dict[str, tuple] = {}
def _line_op(name, unit=""):
def decorator(fn):
LINE_OPS[name] = (fn, unit)
return fn
return decorator
@_line_op("min")
def _op_min(z):
return float(z.min())
@_line_op("max")
def _op_max(z):
return float(z.max())
@_line_op("mean")
def _op_mean(z):
return float(z.mean())
@_line_op("median")
def _op_median(z):
return float(np.median(z))
@_line_op("sum")
def _op_sum(z):
return float(z.sum())
@_line_op("range")
def _op_range(z):
return float(z.max() - z.min())
@_line_op("length", unit="pts")
def _op_length(z):
return float(len(z))
@_line_op("rms")
def _op_rms(z):
return float(np.sqrt(np.mean(z * z)))
@_line_op("Ra")
def _op_ra(z):
return float(np.mean(np.abs(z - z.mean())))
@_line_op("Rq")
def _op_rq(z):
d = z - z.mean()
return _safe_rq(d)
@_line_op("Rsk")
def _op_rsk(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
@_line_op("Rku")
def _op_rku(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
@_line_op("Rp")
def _op_rp(z):
return float((z - z.mean()).max())
@_line_op("Rv")
def _op_rv(z):
return float(-(z - z.mean()).min())
@_line_op("Rt")
def _op_rt(z):
d = z - z.mean()
return float(d.max() - d.min())
@_line_op("Dq")
def _op_dq(z):
dz = np.diff(z)
return float(np.sqrt(np.mean(dz * dz)))
@_line_op("Da")
def _op_da(z):
return float(np.mean(np.abs(np.diff(z))))
TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = {
"min": lambda values: float(np.min(values)),
"max": lambda values: float(np.max(values)),
"avg": lambda values: float(np.mean(values)),
"mean": lambda values: float(np.mean(values)),
"median": lambda values: float(np.median(values)),
"sum": lambda values: float(np.sum(values)),
"range": lambda values: float(np.max(values) - np.min(values)),
"std": lambda values: float(np.std(values)),
"variance": lambda values: float(np.var(values)),
"count": lambda values: float(len(values)),
}
ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = {
"min": lambda values: float(np.min(values)),
"max": lambda values: float(np.max(values)),
"avg": lambda values: float(np.mean(values)),
"mean": lambda values: float(np.mean(values)),
"median": lambda values: float(np.median(values)),
"sum": lambda values: float(np.sum(values)),
"range": lambda values: float(np.max(values) - np.min(values)),
"std": lambda values: float(np.std(values)),
"variance": lambda values: float(np.var(values)),
"rms": lambda values: float(np.sqrt(np.mean(values * values))),
"count": lambda values: float(values.size),
}
def _square_unit(unit: str) -> str:
unit = str(unit or "").strip()
if not unit:
return ""
if any(token in unit for token in ("^", "(", ")", "/", "*", " ")):
return f"({unit})^2"
return f"{unit}^2"
def _apply_scalar_unit(base_unit: str, operation: str) -> str:
unit = str(base_unit or "").strip()
if operation == "count":
return "count"
if not unit:
return ""
if operation == "variance":
return _square_unit(unit)
return unit
def _common_table_unit(table: list, column: str) -> str:
candidates = []
seen = set()
unit_key = f"{column}_unit"
for row in table:
if not isinstance(row, dict):
continue
unit = None
if unit_key in row and isinstance(row.get(unit_key), str):
unit = row.get(unit_key)
elif column == "value" and isinstance(row.get("unit"), str):
unit = row.get("unit")
if unit is None:
continue
unit = unit.strip()
if not unit or unit in seen:
continue
seen.add(unit)
candidates.append(unit)
if len(candidates) == 1:
return candidates[0]
return ""
def extract_numeric_table_values(table: list, column: str) -> list[float]:
values = []
for row in table:
if not isinstance(row, dict) or column not in row:
continue
value = row[column]
if isinstance(value, bool):
continue
try:
numeric = float(value)
except (TypeError, ValueError):
continue
if np.isfinite(numeric):
values.append(numeric)
return values
def resolve_table_column_name(table: list, column: str) -> str:
requested = str(column or "").strip()
if requested:
return requested
if extract_numeric_table_values(table, "value"):
return "value"
numeric_columns = []
seen = set()
for row in table:
if not isinstance(row, dict):
continue
for key in row.keys():
if key in seen:
continue
seen.add(key)
if extract_numeric_table_values(table, key):
numeric_columns.append(key)
if len(numeric_columns) == 1:
return numeric_columns[0]
if not numeric_columns:
raise ValueError("Stats could not find any numeric columns in the input table.")
raise ValueError(
"Stats found multiple numeric columns; set the column name explicitly."
)

100
backend/nodes/histogram.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, MeasureTable
@register_node(display_name="Histogram")
class Histogram:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
"y_scale": (["linear", "log"],),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
}
}
RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",)
RETURN_NAMES = ("measurements", "marker pair",)
FUNCTION = "process"
DESCRIPTION = (
"Compute the height distribution histogram (DH). "
"Use log scale to reveal small peaks next to a dominant background. "
"Outputs marker measurements while showing the histogram interactively in-node. "
"Equivalent to gwy_data_field_dh."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
n_bins: int,
y_scale: str = "linear",
x1: float = 0.25,
y1: float = 0.5,
x2: float = 0.75,
y2: float = 0.5,
) -> tuple:
raw_counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
counts = raw_counts.astype(np.float64)
if y_scale == "log":
counts = np.log10(1.0 + counts)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 0.0 + 1.0))
xmin = float(np.min(bin_centers)) if len(bin_centers) else 0.0
xmax = float(np.max(bin_centers)) if len(bin_centers) else 1.0
def x_frac_to_idx(frac):
if len(bin_centers) <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(bin_centers - target_x)))
idx_a = x_frac_to_idx(x1)
idx_b = x_frac_to_idx(x2)
xa = float(bin_centers[idx_a]) if len(bin_centers) else 0.0
xb = float(bin_centers[idx_b]) if len(bin_centers) else 0.0
ya = float(counts[idx_a]) if len(counts) else 0.0
yb = float(counts[idx_b]) if len(counts) else 0.0
count_unit = "count" if y_scale == "linear" else "log10(1+count)"
if Histogram._broadcast_overlay_fn is not None:
Histogram._broadcast_overlay_fn(
Histogram._current_node_id,
{
"kind": "line_plot",
"section_title": "Histogram",
"line": counts.tolist(),
"x_axis": bin_centers.astype(np.float64).tolist(),
"x1": float(np.clip(x1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y1": float(y1),
"y2": float(y2),
"a_locked": False,
"b_locked": False,
},
)
table = MeasureTable([
{"quantity": "A position", "value": xa, "unit": field.si_unit_z},
{"quantity": "A count", "value": ya, "unit": count_unit},
{"quantity": "B position", "value": xb, "unit": field.si_unit_z},
{"quantity": "B count", "value": yb, "unit": count_unit},
{"quantity": "delta X", "value": xb - xa, "unit": field.si_unit_z},
{"quantity": "delta Y", "value": yb - ya, "unit": count_unit},
])
return (table, ((x1, y1), (x2, y2)))

215
backend/nodes/image.py Normal file
View File

@@ -0,0 +1,215 @@
from __future__ import annotations
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DataField, resolve_colormap_input
from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS
@register_node(display_name="Image")
class Image:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"path": ("FILE_PATH", {"label": "path"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = (
"Load any supported file. "
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
"each channel gets its own output. "
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
)
_broadcast_warning_fn = None
_current_node_id = None
def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None):
selected_path = str(path).strip() if path is not None else str(filename).strip()
if not selected_path:
raise ValueError("No file selected — use Browse to pick a file.")
path_obj = _resolve_path(selected_path)
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {path_obj}")
if path_obj.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}")
ext = path_obj.suffix.lower()
resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis")
if ext in _SPM_EXTENSIONS:
fields = self._load_spm_all(path_obj, ext)
for f in fields:
f.colormap = resolved_colormap
return tuple(fields)
field = self._load_image_or_array(path_obj, ext)
field.colormap = resolved_colormap
self._send_warning("Uncalibrated data — no physical dimensions.")
return (field,)
def _send_warning(self, message: str):
fn = Image._broadcast_warning_fn
nid = Image._current_node_id
if fn and nid:
fn(nid, message)
def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
if ext == ".gwy":
return self._load_gwy_all(path)
elif ext == ".sxm":
return self._load_sxm_all(path)
elif ext == ".ibw":
return self._load_ibw_all(path)
else:
raise ValueError(f"Unsupported SPM format: {ext}")
def _load_gwy_all(self, path: Path) -> list[DataField]:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
fields = []
for ch in channels.values():
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
fields.append(DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
))
return fields
def _load_sxm_all(self, path: Path) -> list[DataField]:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
if not signals:
raise ValueError(f"No signals found in {path.name}")
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
fields = []
for sig in signals.values():
data = sig.get("forward", list(sig.values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
fields.append(DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
))
return fields
def _load_ibw_all(self, path: Path) -> list[DataField]:
try:
from igor.binarywave import load as load_ibw
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
wave = load_ibw(str(path))
wdata = wave["wave"]
header = wdata["wave_header"]
raw = wdata["wData"]
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
sfA = header.get("sfA", None)
def _decode_unit(raw_unit):
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
dim_units_raw = header.get("dimUnits", None)
data_units_raw = header.get("dataUnits", None)
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
si_unit_xy = _decode_unit(dim_units_raw[0])
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
si_unit_xy = _decode_unit(dim_units_raw[0])
else:
si_unit_xy = _decode_unit(dim_units_raw)
si_unit_z = _decode_unit(data_units_raw)
fields = []
for ch_idx in range(n_channels):
if raw.ndim >= 3:
ch_data = raw[:, :, ch_idx]
elif raw.ndim == 1:
ch_data = raw.reshape(-1, 1)
else:
ch_data = raw
data = np.flipud(ch_data.T).astype(np.float64)
yres, xres = data.shape
if sfA is not None and len(sfA) >= 2:
xreal = abs(float(sfA[0]) * xres) or 1e-6
yreal = abs(float(sfA[1]) * yres) or 1e-6
else:
hsA = header.get("hsA", 0.0)
xreal = abs(float(hsA) * xres) or 1e-6
yreal = xreal * (yres / xres) if xres else 1e-6
fields.append(DataField(
data=data, xreal=xreal, yreal=yreal,
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
))
return fields
def _load_image_or_array(self, path: Path, ext: str) -> DataField:
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image as PILImage
img = PILImage.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
return DataField(data=gray)

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import COLORMAPS
from backend.nodes.helpers import DEMO_DIR, _list_demo_files
@register_node(display_name="Image (Demo)")
class ImageDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo files found)"]
return {
"required": {
"name": (choices,),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
_broadcast_warning_fn = None
_current_node_id = None
def load(self, name: str = "", colormap: str = "viridis", colormap_map=None):
from backend.nodes.image import Image
loader = Image()
demo_path = DEMO_DIR / name
if not demo_path.exists():
raise FileNotFoundError(f"Demo file not found: {name}")
return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Inverse 2D FFT")
class InverseFFT2D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"spectrum": ("DATA_FIELD",),
"representation": (["magnitude", "log_magnitude", "psdf"],),
},
"optional": {
"phase": ("DATA_FIELD",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("image",)
FUNCTION = "process"
DESCRIPTION = (
"Reconstruct a spatial-domain image from a 2D frequency spectrum. "
"For exact reconstruction, connect magnitude/phase (or log magnitude/phase, "
"or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed."
)
def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple:
if spectrum.domain != "frequency":
raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.")
if phase is not None:
if phase.data.shape != spectrum.data.shape:
raise ValueError("Phase input must have the same shape as the spectrum.")
if phase.domain != "frequency":
raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.")
amplitude = self._resolve_amplitude(spectrum, representation)
phase_data = phase.data if phase is not None else np.zeros_like(amplitude)
F = amplitude * np.exp(1j * phase_data)
spatial = np.fft.ifft2(np.fft.ifftshift(F)).real
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
z_unit = self._recover_z_unit(spectrum, representation, phase)
out_field = DataField(
data=spatial,
xreal=xreal,
yreal=yreal,
si_unit_xy="m",
si_unit_z=z_unit,
domain="spatial",
colormap=spectrum.colormap,
)
return (out_field,)
def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray:
data = np.asarray(spectrum.data, dtype=np.float64)
if representation == "magnitude":
return np.clip(data, 0.0, None)
if representation == "log_magnitude":
return np.expm1(data)
if representation == "psdf":
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
n = spectrum.xres * spectrum.yres
dx = xreal / spectrum.xres
dy = yreal / spectrum.yres
scale = n * 4.0 * np.pi ** 2 / (dx * dy)
return np.sqrt(np.clip(data, 0.0, None) * scale)
raise ValueError(f"Unsupported spectrum representation: {representation}")
def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]:
if representation == "psdf":
xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal
yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal
else:
xreal = spectrum.xres / spectrum.xreal
yreal = spectrum.yres / spectrum.yreal
return float(xreal), float(yreal)
def _recover_z_unit(
self,
spectrum: DataField,
representation: str,
phase: DataField | None,
) -> str:
if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip():
return phase.si_unit_z
if representation != "psdf":
return spectrum.si_unit_z
unit = str(spectrum.si_unit_z or "").strip()
if unit.startswith("(") and ")^2 m^2" in unit:
return unit.split(")^2 m^2", 1)[0][1:]
if unit.endswith("^2 m^2"):
return unit[:-6].removesuffix("^2").strip()
return ""

View File

@@ -1,721 +0,0 @@
"""
I/O nodes: load and save images and SPM data.
"""
from __future__ import annotations
import os
import re
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DataField, encode_preview, image_to_uint8, resolve_colormap_input
from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
# ---------------------------------------------------------------------------
# Channel listing helper (used by the /channels endpoint)
# ---------------------------------------------------------------------------
def _resolve_path(filepath: str) -> Path:
path = Path(filepath)
if path.is_absolute():
return path
# Try input dir first, then demo dir
candidate = INPUT_DIR / filepath
if candidate.exists():
return candidate
candidate = DEMO_DIR / filepath
if candidate.exists():
return candidate
# Fall back to input dir (will trigger FileNotFoundError later)
return INPUT_DIR / filepath
def list_channels(filepath: str) -> list[dict]:
"""Return available channel info for a file.
Returns a list of {"name": str, "type": "DATA_FIELD"} dicts.
For SPM formats this inspects the file header.
For images / arrays, returns a single unnamed channel.
"""
path = _resolve_path(filepath)
if not path.exists():
return [{"name": "field", "type": "DATA_FIELD"}]
ext = path.suffix.lower()
if ext == ".gwy":
try:
import gwyfile
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if channels:
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".sxm":
try:
import nanonispy as nap
sxm = nap.read.Scan(str(path))
if sxm.signals:
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".ibw":
try:
from igor.binarywave import load as load_ibw
wave = load_ibw(str(path))
raw = wave["wave"]["wData"]
labels = wave["wave"].get("labels", None)
if raw.ndim >= 3 and labels:
dim_idx = min(2, len(labels) - 1)
if dim_idx >= 0 and labels[dim_idx]:
decoded = []
for lbl in labels[dim_idx]:
if lbl:
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
if name:
decoded.append(name)
if decoded:
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
# Multi-channel without labels — use numeric names
if raw.ndim >= 3 and raw.shape[2] > 1:
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
# Image or array — single channel
return [{"name": "field", "type": "DATA_FIELD"}]
def list_folder_paths(folderpath: str) -> list[dict]:
"""Return a folder DIRECTORY plus compatible image/array/SPM FILE_PATH outputs."""
path = _resolve_path(folderpath)
if not path.exists() or not path.is_dir():
return []
resolved_dir = str(path.resolve())
results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}]
for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()):
if not entry.is_file() or entry.name.startswith("."):
continue
if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS:
continue
results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())})
return results
# ---------------------------------------------------------------------------
# Image (unified loader — replaces LoadImage + LoadSPM)
# ---------------------------------------------------------------------------
@register_node(display_name="Image")
class Image:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"path": ("FILE_PATH", {"label": "path"}),
},
}
# Default outputs — overridden dynamically by the frontend for multi-channel files
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = (
"Load any supported file. "
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
"each channel gets its own output. "
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
)
# Set by execution engine for warning broadcast
_broadcast_warning_fn = None
_current_node_id = None
def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None):
selected_path = str(path).strip() if path is not None else str(filename).strip()
if not selected_path:
raise ValueError("No file selected — use Browse to pick a file.")
path_obj = _resolve_path(selected_path)
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {path_obj}")
if path_obj.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}")
ext = path_obj.suffix.lower()
resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis")
if ext in _SPM_EXTENSIONS:
fields = self._load_spm_all(path_obj, ext)
for f in fields:
f.colormap = resolved_colormap
return tuple(fields)
# Image or array — uncalibrated, single output
field = self._load_image_or_array(path_obj, ext)
field.colormap = resolved_colormap
self._send_warning("Uncalibrated data — no physical dimensions.")
return (field,)
def _send_warning(self, message: str):
fn = Image._broadcast_warning_fn
nid = Image._current_node_id
if fn and nid:
fn(nid, message)
# -- SPM: load all channels ---------------------------------------------
def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
if ext == ".gwy":
return self._load_gwy_all(path)
elif ext == ".sxm":
return self._load_sxm_all(path)
elif ext == ".ibw":
return self._load_ibw_all(path)
else:
raise ValueError(f"Unsupported SPM format: {ext}")
# -- GWY ----------------------------------------------------------------
def _load_gwy_all(self, path: Path) -> list[DataField]:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
fields = []
for ch in channels.values():
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
fields.append(DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
))
return fields
# -- SXM ----------------------------------------------------------------
def _load_sxm_all(self, path: Path) -> list[DataField]:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
if not signals:
raise ValueError(f"No signals found in {path.name}")
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
fields = []
for sig in signals.values():
data = sig.get("forward", list(sig.values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
fields.append(DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
))
return fields
# -- IBW ----------------------------------------------------------------
def _load_ibw_all(self, path: Path) -> list[DataField]:
try:
from igor.binarywave import load as load_ibw
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
wave = load_ibw(str(path))
wdata = wave["wave"]
header = wdata["wave_header"]
raw = wdata["wData"]
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
# Physical scaling
sfA = header.get("sfA", None)
def _decode_unit(raw_unit):
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
dim_units_raw = header.get("dimUnits", None)
data_units_raw = header.get("dataUnits", None)
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
si_unit_xy = _decode_unit(dim_units_raw[0])
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
si_unit_xy = _decode_unit(dim_units_raw[0])
else:
si_unit_xy = _decode_unit(dim_units_raw)
si_unit_z = _decode_unit(data_units_raw)
fields = []
for ch_idx in range(n_channels):
if raw.ndim >= 3:
ch_data = raw[:, :, ch_idx]
elif raw.ndim == 1:
ch_data = raw.reshape(-1, 1)
else:
ch_data = raw
# Transpose from (xres, yres) Igor order to (yres, xres) DataField order,
# then flip vertically to match gwyddion
data = np.flipud(ch_data.T).astype(np.float64)
yres, xres = data.shape
if sfA is not None and len(sfA) >= 2:
xreal = abs(float(sfA[0]) * xres) or 1e-6
yreal = abs(float(sfA[1]) * yres) or 1e-6
else:
hsA = header.get("hsA", 0.0)
xreal = abs(float(hsA) * xres) or 1e-6
yreal = xreal * (yres / xres) if xres else 1e-6
fields.append(DataField(
data=data, xreal=xreal, yreal=yreal,
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
))
return fields
# -- Image / array (uncalibrated) --------------------------------------
def _load_image_or_array(self, path: Path, ext: str) -> DataField:
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image
img = Image.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
return DataField(data=gray)
# ---------------------------------------------------------------------------
# ImageDemo
# ---------------------------------------------------------------------------
def _list_demo_files() -> list[str]:
"""Return sorted list of demo filenames available in the demo/ directory."""
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
)
@register_node(display_name="Image (Demo)")
class ImageDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo files found)"]
return {
"required": {
"name": (choices,),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
def load(self, name: str = "", colormap: str = "viridis", colormap_map=None):
loader = Image()
demo_path = DEMO_DIR / name
if not demo_path.exists():
raise FileNotFoundError(f"Demo file not found: {name}")
return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map)
@register_node(display_name="Folder")
class Folder:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}),
}
}
RETURN_TYPES = ("DIRECTORY",)
RETURN_NAMES = ("directory",)
FUNCTION = "list_files"
DESCRIPTION = (
"Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. "
"Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans."
)
def list_files(self, folder: str) -> tuple:
entries = list_folder_paths(folder)
if not entries:
return tuple()
return tuple(item["path"] for item in entries)
# ---------------------------------------------------------------------------
# Coordinate
# ---------------------------------------------------------------------------
@register_node(display_name="Coordinate")
class Coordinate:
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("COORD",)
RETURN_NAMES = ("point",)
FUNCTION = "process"
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
def process(self, x: float, y: float) -> tuple:
return ((float(x), float(y)),)
@register_node(display_name="Coordinate Pair")
class CoordinatePair:
"""Provide a pair of Coordinates, for drawing lines between markers, etc."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("COORD",),
"b": ("COORD",),
}
}
RETURN_TYPES = ("COORDPAIR",)
RETURN_NAMES = ("coord pair",)
FUNCTION = "process"
DESCRIPTION = "Output a pair of coordinates."
def process(self, a: tuple, b: tuple) -> tuple:
return ((a, b),)
# ---------------------------------------------------------------------------
# Number
# ---------------------------------------------------------------------------
@register_node(display_name="Number")
class Number:
"""Provide a fixed scalar value that can feed FLOAT or INT widget sockets."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "step": 0.01}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Output a fixed numeric value. "
"When connected to FLOAT inputs the exact value is used; "
"INT inputs round to the nearest integer at execution time."
)
def process(self, value: float) -> tuple:
return (float(value),)
# ---------------------------------------------------------------------------
# RangeSlider
# ---------------------------------------------------------------------------
@register_node(display_name="Float Slider")
class RangeSlider:
"""Interactive float control node with min/max bounds and a slider value."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"min_value": ("FLOAT", {"default": 0.0, "step": 0.01}),
"max_value": ("FLOAT", {"default": 1.0, "step": 0.01}),
"value": ("FLOAT", {
"default": 0.5,
"step": 0.01,
"slider": True,
"min_widget": "min_value",
"max_widget": "max_value",
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value."
)
def process(self, min_value: float, max_value: float, value: float) -> tuple:
lo = min(float(min_value), float(max_value))
hi = max(float(min_value), float(max_value))
if hi == lo:
return (lo,)
return (float(np.clip(float(value), lo, hi)),)
# ---------------------------------------------------------------------------
# SaveImage
# ---------------------------------------------------------------------------
_MAX_SAVE_FIELDS = 8
@register_node(display_name="Save Layers")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {
"directory": ("DIRECTORY", {"label": "directory"}),
}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
optional[f"layer_name_{i}"] = ("STRING", {
"default": "",
"placeholder": "name",
"show_when_input_visible": f"field_{i}",
"inline_with_input": f"field_{i}",
"hide_label": True,
})
return {
"required": {
"filename": ("STRING", {
"default": "",
"placeholder": "filename",
"placement": "top",
}),
"directory_path": ("FOLDER_PICKER", {
"default": "",
"label": "directory",
"placement": "top",
"hide_when_input_connected": "directory",
"top_socket_input": "directory",
}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more layers to a single file. "
"Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. "
"Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. "
"A new slot appears as each one is filled, with a matching per-layer name field. "
"TIFF writes multi-page data and stores layer names as page descriptions; "
"NPZ writes named arrays using those layer names as keys. "
"Click Save to write (does not auto-run)."
)
_broadcast_warning_fn = None
_current_node_id = None
def save(
self,
filename: str,
directory_path: str = "",
format: str = "TIFF",
directory: str | None = None,
**kwargs,
):
layers = []
layer_names = []
for i in range(_MAX_SAVE_FIELDS):
layer = kwargs.get(f"field_{i}")
if layer is not None:
layers.append(layer)
layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i))
if not layers:
raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.")
path = self._resolve_save_path(filename, format, directory, directory_path)
if format == "TIFF":
self._save_tiff(path, layers, layer_names)
else:
self._save_npz(path, layers, layer_names)
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
import tifffile
with tifffile.TiffWriter(str(path)) as tif:
for layer, layer_name in zip(layers, layer_names):
tif.write(self._layer_array_for_tiff(layer), description=layer_name)
def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
arrays = {}
used_keys = set()
for i, (layer, layer_name) in enumerate(zip(layers, layer_names)):
arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer)
np.savez(str(path), **arrays)
def _resolve_layer_name(self, raw_name: object, index: int) -> str:
text = str(raw_name).strip() if raw_name is not None else ""
return text or f"layer_{index}"
def _resolve_save_path(
self,
filename: str,
format: str,
directory: str | None,
directory_path: str = "",
) -> Path:
ext = ".tiff" if format == "TIFF" else ".npz"
raw_filename = str(filename).strip() if filename is not None else ""
raw_directory = str(directory).strip() if directory is not None else ""
if not raw_directory:
raw_directory = str(directory_path).strip() if directory_path is not None else ""
if raw_directory:
dir_path = Path(raw_directory).expanduser()
if dir_path.exists() and not dir_path.is_dir():
raise ValueError("Directory input expects a folder path, not a file path.")
if not dir_path.exists():
if dir_path.suffix:
raise ValueError("Directory input expects a folder path, not a file path.")
dir_path.mkdir(parents=True, exist_ok=True)
filename_part = Path(raw_filename).name if raw_filename else ""
if not filename_part:
raise ValueError("No output filename selected — enter a file name when using a directory input.")
path = dir_path / filename_part
else:
if not raw_filename:
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(raw_filename).expanduser()
path.parent.mkdir(parents=True, exist_ok=True)
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
return path
def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str:
key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_")
if not key:
key = f"layer_{index}"
if key[0].isdigit():
key = f"layer_{key}"
candidate = key
suffix = 2
while candidate in used_keys:
candidate = f"{key}_{suffix}"
suffix += 1
used_keys.add(candidate)
return candidate
def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data, dtype=np.float32)
if isinstance(layer, np.ndarray):
return image_to_uint8(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data)
if isinstance(layer, np.ndarray):
return np.asarray(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return ()

View File

@@ -1,150 +0,0 @@
"""
Leveling nodes — background removal and zero correction.
Gwyddion equivalents:
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
FixZero → fix_zero in level.c
Plane-fit algorithm follows Gwyddion's level.h definition:
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# PlaneLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Plane Level")
class PlaneLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("leveled",)
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a least-squares plane from the data. "
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
)
def process(self, field: DataField) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Normalised coordinate grids in [0, 1]
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Design matrix: [1, x, y] shape (N, 3)
A = np.column_stack([
np.ones(xres * yres),
xx.ravel(),
yy.ravel(),
])
z = data.ravel()
# Least-squares: solve A @ [pa, pbx, pby] = z
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
pa, pbx, pby = coeffs
plane = (pa + pbx * xx + pby * yy)
return (field.replace(data=data - plane),)
# ---------------------------------------------------------------------------
# PolyLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Polynomial Level")
class PolyLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("leveled", "background")
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a polynomial background of given degree in x and y. "
"Equivalent to gwy_data_field_fit_polynom."
)
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Build Vandermonde-style design matrix with all monomials x^i * y^j
cols = []
for i in range(degree_x + 1):
for j in range(degree_y + 1):
cols.append((xx ** i * yy ** j).ravel())
A = np.column_stack(cols)
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
background = (A @ coeffs).reshape(yres, xres)
leveled = data - background
return (field.replace(data=leveled), field.replace(data=background))
# ---------------------------------------------------------------------------
# FixZero
# ---------------------------------------------------------------------------
@register_node(display_name="Fix Zero")
class FixZero:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["min", "mean", "median"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("zeroed",)
FUNCTION = "process"
DESCRIPTION = (
"Shift data so that the minimum (or mean/median) is zero. "
"Equivalent to fix_zero in Gwyddion's level.c."
)
def process(self, field: DataField, method: str) -> tuple:
data = field.data.copy()
if method == "min":
data -= data.min()
elif method == "mean":
data -= data.mean()
elif method == "median":
data -= np.median(data)
else:
raise ValueError(f"Unknown method: {method}")
return (field.replace(data=data),)

68
backend/nodes/markup.py Normal file
View File

@@ -0,0 +1,68 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _parse_markup_shapes, _normalize_markup_color
@register_node(display_name="Markup")
class Markup:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}),
"stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
"clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}),
"markup_shapes": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "process"
DESCRIPTION = (
"Draw simple vector markup over a DATA_FIELD without flattening the underlying data. "
"Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
shape: str,
stroke_color: str,
stroke_width: int,
markup_shapes: str,
) -> tuple:
shapes = _parse_markup_shapes(markup_shapes)
out = field.replace(
overlays=[
*field.overlays,
{
"kind": "markup",
"shapes": shapes,
},
],
)
if Markup._broadcast_overlay_fn is not None:
Markup._broadcast_overlay_fn(
Markup._current_node_id,
{
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
return (out,)

View File

@@ -1,437 +0,0 @@
"""
Mask operation nodes — creation, morphology, and boolean combination.
Gwyddion equivalents:
ThresholdMask → threshold.c / otsu_threshold.c
MaskMorphology → mask_morph.c (erode, dilate, open, close)
MaskInvert → (bitwise NOT on mask)
MaskCombine → (boolean ops between two masks)
"""
from __future__ import annotations
from functools import lru_cache
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
"""Render greyscale base image with red shadow on masked (255) pixels.
Returns (H, W, 3) uint8 array.
"""
grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8
mask_bool = mask > 127
if not np.any(mask_bool):
return grey
overlay = grey.copy()
red = overlay[..., 0]
green = overlay[..., 1]
blue = overlay[..., 2]
# Integer alpha blend equivalent to a 45% red overlay, without float64 work.
red_vals = red[mask_bool].astype(np.uint16)
green_vals = green[mask_bool].astype(np.uint16)
blue_vals = blue[mask_bool].astype(np.uint16)
red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100
green[mask_bool] = ((green_vals * 55) + 50) // 100
blue[mask_bool] = ((blue_vals * 55) + 50) // 100
return overlay
@lru_cache(maxsize=128)
def _mask_structure(radius: int, shape: str) -> np.ndarray:
radius = max(1, int(radius))
if shape == "disk":
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
struct = (x * x + y * y) <= radius * radius
else:
size = 2 * radius + 1
struct = np.ones((size, size), dtype=bool)
struct.setflags(write=False)
return struct
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray:
from PIL import Image, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# DrawMask
# ---------------------------------------------------------------------------
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------
@register_node(display_name="Threshold Mask")
class ThresholdMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
# threshold is a fraction [0, 1] of the data range
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None:
overlay = _mask_overlay(field, mask)
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
return (mask,)
# ---------------------------------------------------------------------------
# MaskMorphology
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Morphology")
class MaskMorphology:
"""Morphological operations on binary masks.
Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
"""
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
"operation": (["dilate", "erode", "open", "close"],),
"radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
"shape": (["disk", "square"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Apply morphological operations to a binary mask. "
"Dilate expands regions, erode shrinks them, "
"open (erode then dilate) removes small spots, "
"close (dilate then erode) fills small holes. "
"Equivalent to Gwyddion mask_morph."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
field: DataField | None = None) -> tuple:
from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening
binary = mask > 127
struct = _mask_structure(radius, shape)
if operation == "dilate":
result = binary_dilation(binary, structure=struct)
elif operation == "erode":
result = binary_erosion(binary, structure=struct)
elif operation == "open":
result = binary_opening(binary, structure=struct)
elif operation == "close":
result = binary_closing(binary, structure=struct)
else:
raise ValueError(f"Unknown morphological operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn(
MaskMorphology._current_node_id, encode_preview(overlay),
)
return (out,)
# ---------------------------------------------------------------------------
# MaskInvert
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Invert")
class MaskInvert:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn(
MaskInvert._current_node_id, encode_preview(overlay),
)
return (out,)
# ---------------------------------------------------------------------------
# MaskCombine
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Combine")
class MaskCombine:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask_a": ("IMAGE",),
"mask_b": ("IMAGE",),
"operation": (["and", "or", "xor", "subtract"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Combine two binary masks with a boolean operation. "
"AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
"subtract removes mask_b from mask_a."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
field: DataField | None = None) -> tuple:
a = mask_a > 127
b = mask_b > 127
if operation == "and":
result = a & b
elif operation == "or":
result = a | b
elif operation == "xor":
result = a ^ b
elif operation == "subtract":
result = a & ~b
else:
raise ValueError(f"Unknown mask operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn(
MaskCombine._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@register_node(display_name="Mask Combine")
class MaskCombine:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask_a": ("IMAGE",),
"mask_b": ("IMAGE",),
"operation": (["and", "or", "xor", "subtract"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Combine two binary masks with a boolean operation. "
"AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
"subtract removes mask_b from mask_a."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
field: DataField | None = None) -> tuple:
a = mask_a > 127
b = mask_b > 127
if operation == "and":
result = a & b
elif operation == "or":
result = a | b
elif operation == "xor":
result = a ^ b
elif operation == "subtract":
result = a & ~b
else:
raise ValueError(f"Unknown mask operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn(
MaskCombine._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@register_node(display_name="Mask Invert")
class MaskInvert:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn(
MaskInvert._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay, _mask_structure
@register_node(display_name="Mask Morphology")
class MaskMorphology:
"""Morphological operations on binary masks.
Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
"""
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
"operation": (["dilate", "erode", "open", "close"],),
"radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
"shape": (["disk", "square"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Apply morphological operations to a binary mask. "
"Dilate expands regions, erode shrinks them, "
"open (erode then dilate) removes small spots, "
"close (dilate then erode) fills small holes. "
"Equivalent to Gwyddion mask_morph."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
field: DataField | None = None) -> tuple:
from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening
binary = mask > 127
struct = _mask_structure(radius, shape)
if operation == "dilate":
result = binary_dilation(binary, structure=struct)
elif operation == "erode":
result = binary_erosion(binary, structure=struct)
elif operation == "open":
result = binary_opening(binary, structure=struct)
elif operation == "close":
result = binary_closing(binary, structure=struct)
else:
raise ValueError(f"Unknown morphological operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn(
MaskMorphology._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Median Filter")
class MedianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data, size=size)
return (field.replace(data=data),)

28
backend/nodes/number.py Normal file
View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Number")
class Number:
"""Provide a fixed scalar value that can feed FLOAT or INT widget sockets."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "step": 0.01}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Output a fixed numeric value. "
"When connected to FLOAT inputs the exact value is used; "
"INT inputs round to the nearest integer at execution time."
)
def process(self, value: float) -> tuple:
return (float(value),)

View File

@@ -1,20 +1,9 @@
"""
Particle detection nodes.
Gwyddion equivalents:
ParticleAnalysis gwy_data_field_particles_get_values (particles-values.c)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, RecordTable
# ---------------------------------------------------------------------------
# ParticleAnalysis
# ---------------------------------------------------------------------------
@register_node(display_name="Particle Analysis")
class ParticleAnalysis:
@classmethod
@@ -43,7 +32,7 @@ class ParticleAnalysis:
binary = (mask > 127).astype(np.int32)
labeled, n_particles = label(binary)
pixel_area = field.dx * field.dy # m^2 per pixel
pixel_area = field.dx * field.dy
rows = RecordTable()
for pid in range(1, n_particles + 1):
@@ -59,7 +48,6 @@ class ParticleAnalysis:
mean_h = float(heights.mean())
max_h = float(heights.max())
# Bounding box
ys, xs = np.where(particle_pixels)
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Plane Level")
class PlaneLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("leveled",)
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a least-squares plane from the data. "
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
)
def process(self, field: DataField) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
A = np.column_stack([
np.ones(xres * yres),
xx.ravel(),
yy.ravel(),
])
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
pa, pbx, pby = coeffs
plane = (pa + pbx * xx + pby * yy)
return (field.replace(data=data - plane),)

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Polynomial Level")
class PolyLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("leveled", "background")
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a polynomial background of given degree in x and y. "
"Equivalent to gwy_data_field_fit_polynom."
)
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
cols = []
for i in range(degree_x + 1):
for j in range(degree_y + 1):
cols.append((xx ** i * yy ** j).ravel())
A = np.column_stack(cols)
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
background = (A @ coeffs).reshape(yres, xres)
leveled = data - background
return (field.replace(data=leveled), field.replace(data=background))

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
COLORMAPS,
colormap_to_uint8,
encode_preview,
image_to_uint8,
render_datafield_preview,
resolve_colormap_input,
)
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ()
FUNCTION = "preview"
OUTPUT_NODE = True
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
_broadcast_fn = None
_current_node_id: str = ""
def preview(
self,
colormap: str,
image: np.ndarray | None = None,
field=None,
colormap_map=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap if field is not None else None,
default="gray",
)
if field is not None:
arr_u8 = render_datafield_preview(field, resolved_colormap)
elif image is not None:
arr_u8 = image_to_uint8(image)
if arr_u8.ndim == 2:
if image.dtype == np.uint8:
normalized = arr_u8.astype(np.float64) / 255.0
else:
imin, imax = image.min(), image.max()
if imax > imin:
normalized = (image - imin) / (imax - imin)
else:
normalized = np.zeros_like(image, dtype=np.float64)
arr_u8 = colormap_to_uint8(normalized, resolved_colormap)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return ()

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Print Table")
class PrintTable:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("ANY_TABLE",),
}
}
RETURN_TYPES = ()
FUNCTION = "print_table"
OUTPUT_NODE = True
DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display."
_broadcast_table_fn = None
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
@register_node(display_name="Float Slider")
class RangeSlider:
"""Interactive float control node with min/max bounds and a slider value."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"min_value": ("FLOAT", {"default": 0.0, "step": 0.01}),
"max_value": ("FLOAT", {"default": 1.0, "step": 0.01}),
"value": ("FLOAT", {
"default": 0.5,
"step": 0.01,
"slider": True,
"min_widget": "min_value",
"max_widget": "max_value",
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value."
)
def process(self, min_value: float, max_value: float, value: float) -> tuple:
lo = min(float(min_value), float(max_value))
hi = max(float(min_value), float(max_value))
if hi == lo:
return (lo,)
return (float(np.clip(float(value), lo, hi)),)

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Rotate")
class RotateField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"angle": ("FLOAT", {"default": 90.0, "min": -360.0, "max": 360.0, "step": 1.0}),
"interpolation": (["bilinear", "nearest", "bicubic"],),
"expand_canvas": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Rotate a DATA_FIELD counterclockwise by an angle in degrees. "
"Optionally expand the canvas to keep the full rotated field while preserving the field center."
)
_broadcast_warning_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
angle: float,
interpolation: str,
expand_canvas: bool,
) -> tuple:
if field.overlays:
self._send_warning("Rotate clears annotation/markup overlays!")
angle = float(angle)
order_map = {
"nearest": 0,
"bilinear": 1,
"bicubic": 3,
}
if interpolation not in order_map:
raise ValueError(f"Unknown interpolation mode: {interpolation}")
normalized_angle = angle % 360.0
snapped_quarters = int(round(normalized_angle / 90.0)) % 4
snapped_angle = snapped_quarters * 90.0
is_right_angle = abs(normalized_angle - snapped_angle) < 1e-9
if is_right_angle and expand_canvas:
rotated = np.rot90(field.data, k=snapped_quarters).copy()
elif abs(normalized_angle) < 1e-9:
rotated = field.data.copy()
else:
from scipy.ndimage import rotate as nd_rotate
rotated = nd_rotate(
field.data,
angle=angle,
reshape=bool(expand_canvas),
order=order_map[interpolation],
mode="nearest",
prefilter=order_map[interpolation] > 1,
)
new_xreal, new_yreal = self._rotated_extents(field, angle, expand_canvas)
center_x = field.xoff + field.xreal / 2.0
center_y = field.yoff + field.yreal / 2.0
result = field.replace(
data=np.asarray(rotated, dtype=np.float64),
xreal=new_xreal,
yreal=new_yreal,
xoff=center_x - new_xreal / 2.0,
yoff=center_y - new_yreal / 2.0,
overlays=[],
)
return (result,)
def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:
if not expand_canvas:
return (field.xreal, field.yreal)
theta = np.deg2rad(angle)
cos_t = abs(float(np.cos(theta)))
sin_t = abs(float(np.sin(theta)))
new_xreal = field.xreal * cos_t + field.yreal * sin_t
new_yreal = field.xreal * sin_t + field.yreal * cos_t
return (new_xreal, new_yreal)

182
backend/nodes/save_image.py Normal file
View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import re
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, image_to_uint8
from backend.nodes.helpers import _MAX_SAVE_FIELDS
@register_node(display_name="Save Layers")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {
"directory": ("DIRECTORY", {"label": "directory"}),
}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
optional[f"layer_name_{i}"] = ("STRING", {
"default": "",
"placeholder": "name",
"show_when_input_visible": f"field_{i}",
"inline_with_input": f"field_{i}",
"hide_label": True,
})
return {
"required": {
"filename": ("STRING", {
"default": "",
"placeholder": "filename",
"placement": "top",
}),
"directory_path": ("FOLDER_PICKER", {
"default": "",
"label": "directory",
"placement": "top",
"hide_when_input_connected": "directory",
"top_socket_input": "directory",
}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more layers to a single file. "
"Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. "
"Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. "
"A new slot appears as each one is filled, with a matching per-layer name field. "
"TIFF writes multi-page data and stores layer names as page descriptions; "
"NPZ writes named arrays using those layer names as keys. "
"Click Save to write (does not auto-run)."
)
_broadcast_warning_fn = None
_current_node_id = None
def save(
self,
filename: str,
directory_path: str = "",
format: str = "TIFF",
directory: str | None = None,
**kwargs,
):
layers = []
layer_names = []
for i in range(_MAX_SAVE_FIELDS):
layer = kwargs.get(f"field_{i}")
if layer is not None:
layers.append(layer)
layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i))
if not layers:
raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.")
path = self._resolve_save_path(filename, format, directory, directory_path)
if format == "TIFF":
self._save_tiff(path, layers, layer_names)
else:
self._save_npz(path, layers, layer_names)
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
import tifffile
with tifffile.TiffWriter(str(path)) as tif:
for layer, layer_name in zip(layers, layer_names):
tif.write(self._layer_array_for_tiff(layer), description=layer_name)
def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
arrays = {}
used_keys = set()
for i, (layer, layer_name) in enumerate(zip(layers, layer_names)):
arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer)
np.savez(str(path), **arrays)
def _resolve_layer_name(self, raw_name: object, index: int) -> str:
text = str(raw_name).strip() if raw_name is not None else ""
return text or f"layer_{index}"
def _resolve_save_path(
self,
filename: str,
format: str,
directory: str | None,
directory_path: str = "",
) -> Path:
ext = ".tiff" if format == "TIFF" else ".npz"
raw_filename = str(filename).strip() if filename is not None else ""
raw_directory = str(directory).strip() if directory is not None else ""
if not raw_directory:
raw_directory = str(directory_path).strip() if directory_path is not None else ""
if raw_directory:
dir_path = Path(raw_directory).expanduser()
if dir_path.exists() and not dir_path.is_dir():
raise ValueError("Directory input expects a folder path, not a file path.")
if not dir_path.exists():
if dir_path.suffix:
raise ValueError("Directory input expects a folder path, not a file path.")
dir_path.mkdir(parents=True, exist_ok=True)
filename_part = Path(raw_filename).name if raw_filename else ""
if not filename_part:
raise ValueError("No output filename selected — enter a file name when using a directory input.")
path = dir_path / filename_part
else:
if not raw_filename:
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(raw_filename).expanduser()
path.parent.mkdir(parents=True, exist_ok=True)
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
return path
def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str:
key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_")
if not key:
key = f"layer_{index}"
if key[0].isdigit():
key = f"layer_{key}"
candidate = key
suffix = 2
while candidate in used_keys:
candidate = f"{key}_{suffix}"
suffix += 1
used_keys.add(candidate)
return candidate
def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data, dtype=np.float32)
if isinstance(layer, np.ndarray):
return image_to_uint8(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data)
if isinstance(layer, np.ndarray):
return np.asarray(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return ()

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, MeasureTable
@register_node(display_name="Statistics")
class Statistics:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("MEASURE_TABLE",)
RETURN_NAMES = ("stats",)
FUNCTION = "process"
DESCRIPTION = (
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
)
def process(self, field: DataField) -> tuple:
d = field.data
mean = float(d.mean())
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
table = MeasureTable([
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
{"quantity": "skewness", "value": skewness, "unit": ""},
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
])
return (table,)

130
backend/nodes/stats.py Normal file
View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData, MeasureTable
from backend.nodes.helpers import (
LINE_OPS,
TABLE_OPS,
ARRAY_OPS,
_scalar_payload,
_apply_scalar_unit,
_common_table_unit,
extract_numeric_table_values,
resolve_table_column_name,
)
@register_node(display_name="Stats")
class Stats:
"""Polymorphic scalar stats node for LINE, RECORD_TABLE, DATA_FIELD, or IMAGE inputs."""
_broadcast_value_fn = None
_current_node_id: str = ""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input": ("STATS_SOURCE",),
"column": ("STRING", {
"default": "value",
"choices_from_table_input": "input",
"show_when_source_type": {
"input": ["RECORD_TABLE"],
},
}),
"operation": ("STRING", {
"default": "mean",
"choices_by_source_type": {
"LINE": list(LINE_OPS.keys()),
"RECORD_TABLE": list(TABLE_OPS.keys()),
"DATA_FIELD": list(ARRAY_OPS.keys()),
"IMAGE": list(ARRAY_OPS.keys()),
},
"source_type_input": "input",
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Compute a contextual scalar statistic from a LINE, record table, DATA_FIELD, or IMAGE. "
"The available operations adapt to the connected input type."
)
def process(self, input, operation: str, column: str = "value") -> tuple:
source_type, values, resolved_column = self._resolve_input_values(input, column)
if source_type == "RECORD_TABLE":
ops = TABLE_OPS
elif source_type == "LINE":
ops = LINE_OPS
else:
ops = ARRAY_OPS
if operation not in ops:
raise ValueError(f"Operation '{operation}' is not valid for {source_type} input.")
op_entry = ops[operation]
fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry
result = fn(values)
if Stats._broadcast_value_fn is not None:
Stats._broadcast_value_fn(
Stats._current_node_id,
_scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)),
)
return (result,)
def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str:
if source_type == "DATA_FIELD" and isinstance(input_value, DataField):
return _apply_scalar_unit(input_value.si_unit_z, operation)
if source_type == "LINE":
line_entry = LINE_OPS.get(operation)
explicit_unit = line_entry[1] if isinstance(line_entry, tuple) and len(line_entry) > 1 else ""
if explicit_unit:
return _apply_scalar_unit(explicit_unit, operation)
if isinstance(input_value, LineData):
return _apply_scalar_unit(input_value.y_unit, operation)
return ""
if source_type == "RECORD_TABLE" and isinstance(input_value, list) and column:
return _apply_scalar_unit(_common_table_unit(input_value, column), operation)
return ""
def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray, str | None]:
if isinstance(input_value, DataField):
values = np.asarray(input_value.data, dtype=np.float64)
return ("DATA_FIELD", values.ravel(), None)
if isinstance(input_value, MeasureTable):
raise ValueError("Stats only accepts record tables, not measurement tables.")
if isinstance(input_value, list):
if not input_value:
raise ValueError("Stats requires a non-empty record table input.")
column_name = resolve_table_column_name(input_value, column)
values = extract_numeric_table_values(input_value, column_name)
if not values:
raise ValueError(f"Column '{column_name}' has no numeric values.")
return ("RECORD_TABLE", np.asarray(values, dtype=np.float64), column_name)
if isinstance(input_value, LineData):
values = np.asarray(input_value.data, dtype=np.float64)
if values.size == 0:
raise ValueError("Stats requires a non-empty input.")
return ("LINE", values.ravel(), None)
if isinstance(input_value, np.ndarray):
values = np.asarray(input_value, dtype=np.float64)
if values.size == 0:
raise ValueError("Stats requires a non-empty input.")
if values.ndim == 1:
return ("LINE", values.ravel(), None)
return ("IMAGE", values.ravel(), None)
raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}")

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@register_node(display_name="Threshold Mask")
class ThresholdMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None:
overlay = _mask_overlay(field, mask)
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
return (mask,)

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import MeasureTable
from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload
@register_node(display_name="Value Display")
class ValueDisplay:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("VALUE_SOURCE",),
"measurement": ("STRING", {
"default": "",
"choices_from_measure_input": "value",
"show_when_source_type": {
"value": ["MEASURE_TABLE"],
},
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "display_value"
DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged."
_broadcast_value_fn = None
_current_node_id: str = ""
def display_value(self, value, measurement: str = "") -> tuple:
unit = ""
if isinstance(value, MeasureTable):
row = _measurement_entry(value, measurement)
numeric = _measurement_value(value, measurement)
unit = row.get("unit", "") if isinstance(row.get("unit"), str) else ""
else:
numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None:
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit))
return (numeric,)

90
backend/nodes/view_3d.py Normal file
View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
COLORMAPS,
DataField,
colormap_to_uint8,
normalize_for_colormap,
resolve_colormap_input,
)
@register_node(display_name="3D View")
class View3D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ()
FUNCTION = "render"
OUTPUT_NODE = True
DESCRIPTION = (
"Interactive 3D surface view of a DATA_FIELD. "
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
)
_broadcast_mesh_fn = None
_current_node_id: str = ""
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int, colormap_map=None,
) -> tuple:
import base64
data = field.data
yres, xres = data.shape
step_y = max(1, yres // resolution)
step_x = max(1, xres // resolution)
z = data[::step_y, ::step_x].astype(np.float32)
ny, nx = z.shape
zmin, zmax = float(z.min()), float(z.max())
z_norm = normalize_for_colormap(
z,
offset=field.display_offset,
scale=field.display_scale,
data_min=float(field.data.min()),
data_max=float(field.data.max()),
)
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
colors_u8 = colormap_to_uint8(z_norm, resolved_colormap)
z_b64 = base64.b64encode(z.tobytes()).decode()
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
mesh_data = {
"width": nx,
"height": ny,
"z_data": z_b64,
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
return ()

View File

@@ -217,7 +217,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
async def get_folder_files(request: web.Request) -> web.Response:
folder_path = request.query.get("folder", "")
from backend.nodes.io import list_folder_paths
from backend.nodes.helpers import list_folder_paths
loop = asyncio.get_running_loop()
entries = await loop.run_in_executor(None, list_folder_paths, folder_path)
return web.Response(text=_dumps(entries), content_type="application/json")
@@ -267,7 +267,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
async def get_channels(request: web.Request) -> web.Response:
"""Return available channels for a given file path."""
from backend.nodes.io import list_channels
from backend.nodes.helpers import list_channels
filepath = request.query.get("file", "")
if not filepath:
return web.Response(

View File

@@ -9,7 +9,8 @@ import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField
from backend.nodes.analysis import FFT2D, InverseFFT2D
from backend.nodes.fft_2d import FFT2D
from backend.nodes.inverse_fft_2d import InverseFFT2D
def make_field(data, xreal=1e-6, yreal=1e-6):

View File

@@ -10,7 +10,7 @@ import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.analysis import FFT2D
from backend.nodes.fft_2d import FFT2D
OUT_DIR = os.path.join(os.path.dirname(__file__), "output")
os.makedirs(OUT_DIR, exist_ok=True)

View File

@@ -28,7 +28,7 @@ def make_field(data, xreal=1e-6, yreal=1e-6):
def test_threshold_otsu_bimodal():
"""Otsu on a clean bimodal image should separate the two populations."""
print("=== Test: Otsu on bimodal image ===")
from backend.nodes.particle import ThresholdMask
from backend.nodes.threshold_mask import ThresholdMask
node = ThresholdMask()
data = np.zeros((128, 128))
@@ -50,7 +50,7 @@ def test_threshold_otsu_bimodal():
def test_threshold_relative_range():
"""Relative threshold at 0.5 should be the midpoint of [min, max]."""
print("=== Test: Relative threshold at midpoint ===")
from backend.nodes.particle import ThresholdMask
from backend.nodes.threshold_mask import ThresholdMask
node = ThresholdMask()
data = np.full((64, 64), 2.0)
@@ -68,7 +68,7 @@ def test_threshold_relative_range():
def test_threshold_empty_mask():
"""Very high absolute threshold on low data should produce an empty mask."""
print("=== Test: Empty mask from high threshold ===")
from backend.nodes.particle import ThresholdMask
from backend.nodes.threshold_mask import ThresholdMask
node = ThresholdMask()
data = np.ones((64, 64))
@@ -82,7 +82,7 @@ def test_threshold_empty_mask():
def test_threshold_full_mask():
"""Very low absolute threshold should produce an all-white mask."""
print("=== Test: Full mask from low threshold ===")
from backend.nodes.particle import ThresholdMask
from backend.nodes.threshold_mask import ThresholdMask
node = ThresholdMask()
data = np.ones((64, 64)) * 5.0
@@ -100,7 +100,7 @@ def test_threshold_full_mask():
def test_single_circle_area():
"""A single filled circle — verify pixel count and physical area."""
print("=== Test: Single circle area ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
N = 200
@@ -146,7 +146,7 @@ def test_single_circle_area():
def test_multiple_particles_separation():
"""Three well-separated particles of different sizes — check each is reported."""
print("=== Test: Multiple particles separation ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
N = 128
@@ -184,7 +184,7 @@ def test_multiple_particles_separation():
def test_min_size_filtering():
"""min_size should exclude particles smaller than the threshold."""
print("=== Test: min_size filtering ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
N = 64
@@ -227,7 +227,7 @@ def test_min_size_filtering():
def test_particles_bounding_box():
"""Bounding box should match the particles extents."""
print("=== Test: Grain bounding box ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
N = 64
@@ -250,7 +250,7 @@ def test_particles_bounding_box():
def test_empty_mask_produces_no_particles():
"""An all-zero mask should yield zero particles."""
print("=== Test: Empty mask → no particles ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
field = make_field(np.ones((64, 64)))
@@ -264,7 +264,7 @@ def test_empty_mask_produces_no_particles():
def test_particles_at_image_edge():
"""A particles touching the image border should still be detected."""
print("=== Test: Grain at image edge ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
N = 64
@@ -286,7 +286,7 @@ def test_adjacent_particles_connectivity():
"""Two diagonally-touching blocks should be separate particles
(scipy.ndimage.label uses 4-connectivity by default)."""
print("=== Test: Diagonal adjacency → separate particles ===")
from backend.nodes.particle import GrainAnalysis
from backend.nodes.particle_analysis import GrainAnalysis
node = GrainAnalysis()
N = 32
@@ -316,7 +316,8 @@ def test_adjacent_particles_connectivity():
def test_pipeline_synthetic():
"""Full pipeline on a synthetic image with known geometry."""
print("=== Test: Full pipeline on synthetic particles ===")
from backend.nodes.particle import ThresholdMask, GrainAnalysis
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.particle_analysis import GrainAnalysis
N = 200
XREAL = 10e-6 # 10 µm
@@ -371,7 +372,8 @@ def test_pipeline_demo_image():
"""Run the full pipeline on the bundled demo nanoparticles image."""
print("=== Test: Full pipeline on demo nanoparticles.npy ===")
from pathlib import Path
from backend.nodes.particle import ThresholdMask, GrainAnalysis
from backend.nodes.threshold_mask import ThresholdMask
from backend.nodes.particle_analysis import GrainAnalysis
from backend.runtime_paths import demo_dir
npy_path = demo_dir() / "nanoparticles.npy"

View File

@@ -28,7 +28,7 @@ def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
def test_gaussian_filter():
print("=== Test: GaussianFilter ===")
from backend.nodes.filters import GaussianFilter
from backend.nodes.gaussian_filter import GaussianFilter
node = GaussianFilter()
field = make_field()
@@ -46,7 +46,7 @@ def test_gaussian_filter():
def test_median_filter():
print("=== Test: MedianFilter ===")
from backend.nodes.filters import MedianFilter
from backend.nodes.median_filter import MedianFilter
node = MedianFilter()
# Median filter should remove salt-and-pepper noise
@@ -68,7 +68,7 @@ def test_median_filter():
def test_crop_resize_field():
print("=== Test: CropResizeField ===")
from backend.nodes.modify import CropResizeField
from backend.nodes.crop_resize_field import CropResizeField
node = CropResizeField()
data = np.arange(32, dtype=np.float64).reshape(4, 8)
@@ -167,7 +167,7 @@ def test_crop_resize_field():
def test_rotate_field():
print("=== Test: RotateField ===")
from backend.nodes.modify import RotateField
from backend.nodes.rotate_field import RotateField
node = RotateField()
data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
@@ -230,7 +230,7 @@ def test_rotate_field():
def test_rotate_field_overlay_warning():
print("=== Test: RotateField overlay warning ===")
from backend.nodes.modify import RotateField
from backend.nodes.rotate_field import RotateField
node = RotateField()
warnings = []
@@ -258,7 +258,7 @@ def test_rotate_field_overlay_warning():
def test_colormap_adjust():
print("=== Test: ColormapAdjust ===")
from backend.nodes.modify import ColormapAdjust
from backend.nodes.colormap_adjust import ColormapAdjust
node = ColormapAdjust()
field = DataField(
@@ -299,7 +299,7 @@ def test_colormap_adjust():
def test_edge_detect():
print("=== Test: EdgeDetect ===")
from backend.nodes.filters import EdgeDetect
from backend.nodes.edge_detect import EdgeDetect
node = EdgeDetect()
# Create an image with a sharp vertical edge
@@ -320,7 +320,7 @@ def test_edge_detect():
def test_fft_filter_1d():
print("=== Test: FFTFilter1D ===")
from backend.nodes.filters import FFTFilter1D
from backend.nodes.fft_filter_1d import FFTFilter1D
node = FFTFilter1D()
# Signal: low-frequency sine + high-frequency sine
@@ -364,7 +364,7 @@ def test_fft_filter_1d():
def test_fft_filter_2d():
print("=== Test: FFTFilter2D ===")
from backend.nodes.filters import FFTFilter2D
from backend.nodes.fft_filter_2d import FFTFilter2D
node = FFTFilter2D()
N = 128
@@ -406,7 +406,7 @@ def test_fft_filter_2d():
def test_plane_level():
print("=== Test: PlaneLevelField ===")
from backend.nodes.level import PlaneLevelField
from backend.nodes.plane_level_field import PlaneLevelField
node = PlaneLevelField()
# Create a tilted plane + small signal
@@ -428,7 +428,7 @@ def test_plane_level():
def test_poly_level():
print("=== Test: PolyLevelField ===")
from backend.nodes.level import PolyLevelField
from backend.nodes.poly_level_field import PolyLevelField
node = PolyLevelField()
N = 64
@@ -455,7 +455,7 @@ def test_poly_level():
def test_fix_zero():
print("=== Test: FixZero ===")
from backend.nodes.level import FixZero
from backend.nodes.fix_zero import FixZero
node = FixZero()
field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))
@@ -477,7 +477,7 @@ def test_fix_zero():
def test_statistics():
print("=== Test: Statistics ===")
from backend.nodes.analysis import Statistics
from backend.nodes.statistics_node import Statistics
node = Statistics()
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
@@ -507,7 +507,7 @@ def test_statistics():
def test_height_histogram():
print("=== Test: Histogram ===")
from backend.nodes.analysis import Histogram
from backend.nodes.histogram import Histogram
node = Histogram()
# Uniform data should give a roughly flat histogram
@@ -556,7 +556,7 @@ def test_height_histogram():
def test_cross_section():
print("=== Test: CrossSection ===")
from backend.nodes.analysis import CrossSection
from backend.nodes.cross_section import CrossSection
node = CrossSection()
# Create a field with a known horizontal gradient
@@ -604,7 +604,8 @@ def test_cross_section():
)
assert len(profile_diag) == 50
from backend.nodes.analysis import Cursors, Stats
from backend.nodes.cursors import Cursors
from backend.nodes.stats import Stats
cursors = Cursors()
table, _ = cursors.process(profile, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
@@ -630,7 +631,7 @@ def test_cross_section():
def test_threshold_mask():
print("=== Test: ThresholdMask ===")
from backend.nodes.mask import ThresholdMask
from backend.nodes.threshold_mask import ThresholdMask
node = ThresholdMask()
# Clear bimodal data: left half = 0, right half = 1
@@ -673,7 +674,7 @@ def test_threshold_mask():
def test_mask_morphology():
print("=== Test: MaskMorphology ===")
from backend.nodes.mask import MaskMorphology
from backend.nodes.mask_morphology import MaskMorphology
node = MaskMorphology()
# Small square blob in the centre
@@ -710,7 +711,7 @@ def test_mask_morphology():
def test_mask_invert():
print("=== Test: MaskInvert ===")
from backend.nodes.mask import MaskInvert
from backend.nodes.mask_invert import MaskInvert
node = MaskInvert()
mask = np.zeros((64, 64), dtype=np.uint8)
@@ -729,7 +730,7 @@ def test_mask_invert():
def test_mask_combine():
print("=== Test: MaskCombine ===")
from backend.nodes.mask import MaskCombine
from backend.nodes.mask_combine import MaskCombine
node = MaskCombine()
# Two overlapping squares
@@ -768,7 +769,7 @@ def test_mask_combine():
def test_draw_mask():
print("=== Test: DrawMask ===")
from backend.nodes.mask import DrawMask
from backend.nodes.draw_mask import DrawMask
node = DrawMask()
field = make_field(data=np.zeros((32, 32), dtype=np.float64))
@@ -815,7 +816,7 @@ def test_draw_mask():
def test_particle_analysis():
print("=== Test: ParticleAnalysis ===")
from backend.nodes.particless import ParticleAnalysis
from backend.nodes.particle_analysis import ParticleAnalysis
node = ParticleAnalysis()
# Create a field with two distinct particles
@@ -855,7 +856,7 @@ def test_particle_analysis():
def test_load_file():
print("=== Test: Image ===")
from backend.nodes.io import Image as ImageNode
from backend.nodes.image import Image as ImageNode
from PIL import Image as PILImage
node = ImageNode()
@@ -912,7 +913,7 @@ def test_load_file():
def test_save_image():
print("=== Test: SaveImage (Save Layers) ===")
from backend.nodes.io import SaveImage
from backend.nodes.save_image import SaveImage
import tifffile
node = SaveImage()
@@ -1012,7 +1013,7 @@ def test_save_image():
def test_color_map_node():
print("=== Test: ColorMap ===")
from backend.nodes.display import ColorMap
from backend.nodes.color_map import ColorMap
node = ColorMap()
@@ -1038,7 +1039,7 @@ def test_color_map_node():
def test_font_node():
print("=== Test: Font ===")
from backend.nodes.display import Font
from backend.nodes.font_node import Font
from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT
node = Font()
@@ -1056,7 +1057,7 @@ def test_font_node():
def test_preview_image():
print("=== Test: PreviewImage ===")
from backend.nodes.display import PreviewImage
from backend.nodes.preview_image import PreviewImage
node = PreviewImage()
# Set up a capture for the broadcast
@@ -1104,7 +1105,8 @@ def test_preview_image():
def test_annotations():
print("=== Test: Annotations ===")
from backend.nodes.display import Annotations, Font
from backend.nodes.annotations import Annotations
from backend.nodes.font_node import Font
node = Annotations()
font_node = Font()
@@ -1175,7 +1177,7 @@ def test_annotations():
def test_markup():
print("=== Test: Markup ===")
from backend.nodes.display import Markup
from backend.nodes.markup import Markup
from backend.data_types import _preview_markup_stroke_width
node = Markup()
@@ -1226,7 +1228,7 @@ def test_markup():
def test_print_table():
print("=== Test: PrintTable ===")
from backend.nodes.display import PrintTable
from backend.nodes.print_table import PrintTable
node = PrintTable()
captured = []
@@ -1244,7 +1246,7 @@ def test_print_table():
def test_value_display():
print("=== Test: ValueDisplay ===")
from backend.nodes.display import ValueDisplay
from backend.nodes.value_display import ValueDisplay
node = ValueDisplay()
captured = []
@@ -1273,7 +1275,7 @@ def test_value_display():
def test_load_file_ibw():
print("=== Test: Image IBW multi-channel ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
@@ -1309,7 +1311,7 @@ def test_load_file_ibw():
def test_load_file_npz():
print("=== Test: Image .npz ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1326,7 +1328,7 @@ def test_load_file_npz():
def test_load_file_not_found():
print("=== Test: Image not found ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
try:
@@ -1340,7 +1342,7 @@ def test_load_file_not_found():
def test_load_file_unsupported():
print("=== Test: Image unsupported format ===")
from backend.nodes.io import Image
from backend.nodes.image import Image
node = Image()
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1358,7 +1360,7 @@ def test_load_file_unsupported():
def test_load_file_warning():
print("=== Test: Image warning for uncalibrated data ===")
from backend.nodes.io import Image as ImageNode
from backend.nodes.image import Image as ImageNode
from PIL import Image as PILImage
node = ImageNode()
@@ -1387,7 +1389,8 @@ def test_load_file_warning():
def test_list_channels():
print("=== Test: list_channels ===")
from backend.nodes.io import list_channels, list_folder_paths, Folder
from backend.nodes.helpers import list_channels, list_folder_paths
from backend.nodes.folder import Folder
from PIL import Image
# Non-existent file → default
@@ -1458,7 +1461,7 @@ def test_list_channels():
def test_load_demo():
print("=== Test: ImageDemo ===")
from backend.nodes.io import ImageDemo
from backend.nodes.image_demo import ImageDemo
node = ImageDemo()
@@ -1519,7 +1522,7 @@ def test_load_demo_multi_layer_preview_payload():
def test_coordinate():
print("=== Test: Coordinate ===")
from backend.nodes.io import Coordinate
from backend.nodes.coordinate import Coordinate
node = Coordinate()
@@ -1543,7 +1546,7 @@ def test_coordinate():
def test_number():
print("=== Test: Number ===")
from backend.nodes.io import Number
from backend.nodes.number import Number
node = Number()
@@ -1558,7 +1561,7 @@ def test_number():
def test_range_slider():
print("=== Test: RangeSlider ===")
from backend.nodes.io import RangeSlider
from backend.nodes.range_slider import RangeSlider
node = RangeSlider()
@@ -1642,7 +1645,7 @@ def test_execution_engine_numeric_socket_coercion():
def test_line_cursors():
print("=== Test: Cursors ===")
from backend.nodes.analysis import Cursors
from backend.nodes.cursors import Cursors
node = Cursors()
@@ -1718,7 +1721,7 @@ def test_line_cursors():
def test_fft2d():
print("=== Test: FFT2D ===")
from backend.nodes.analysis import FFT2D
from backend.nodes.fft_2d import FFT2D
node = FFT2D()
@@ -1777,7 +1780,7 @@ def test_fft2d():
def test_stats():
print("=== Test: Stats ===")
from backend.nodes.analysis import Stats
from backend.nodes.stats import Stats
node = Stats()
captured = []
@@ -1845,7 +1848,7 @@ def test_stats():
def test_view3d():
print("=== Test: View3D ===")
from backend.nodes.display import View3D
from backend.nodes.view_3d import View3D
node = View3D()
field = make_field()
@@ -1863,7 +1866,7 @@ def test_view3d():
assert "height" in mesh
assert "z_data" in mesh
assert "colors" in mesh
assert mesh["z_scale"] == 2.0
assert mesh["z_scale"] == 0.2
assert mesh["width"] <= 64
assert mesh["height"] <= 64
# z_min < z_max for non-constant data