refactor nodes into standalone file

This commit is contained in:
2026-03-26 19:50:03 -07:00
parent 711d7995b3
commit de0b49acc5
54 changed files with 3615 additions and 3710 deletions

View File

@@ -1,7 +1,54 @@
# Import all node modules to trigger @register_node decorators.
from . import io, filters, modify, level, analysis, mask, display
from backend.nodes import (
# IO
image,
image_demo,
folder,
coordinate,
coordinate_pair,
number,
range_slider,
save_image,
# Filters
gaussian_filter,
median_filter,
edge_detect,
fft_filter_1d,
fft_filter_2d,
# Modify
colormap_adjust,
crop_resize_field,
rotate_field,
# Level
plane_level_field,
poly_level_field,
fix_zero,
# Mask
draw_mask,
threshold_mask,
mask_morphology,
mask_invert,
mask_combine,
# Display
color_map,
font_node,
annotations,
markup,
preview_image,
view_3d,
print_table,
value_display,
# Analysis
statistics_node,
histogram,
cursors,
fft_2d,
inverse_fft_2d,
cross_section,
stats,
)
try:
from . import particle
from backend.nodes import particle_analysis
except ImportError:
from . import particless
pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DataField, normalize_font_spec, resolve_colormap_input
@register_node(display_name="Annotations")
class Annotations:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}),
"text_size": ("FLOAT", {
"default": 14.0,
"min": 6.0,
"max": 96.0,
"step": 1.0,
}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"font": ("FONT",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "render"
DESCRIPTION = (
"Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. "
"The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values."
)
def render(
self,
field: DataField,
colormap: str,
show_scale_bar: bool,
show_color_map: bool,
text_size: float = 1.0,
colormap_map=None,
font=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
out = field.replace(
colormap=resolved_colormap,
overlays=[
*field.overlays,
{
"kind": "annotation",
"show_scale_bar": bool(show_scale_bar),
"show_color_map": bool(show_color_map),
"text_size": text_size,
"font": normalize_font_spec(font),
},
],
)
return (out,)

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import json
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DEFAULT_CUSTOM_COLORMAP_STOPS, normalize_colormap_spec
@register_node(display_name="Color Map")
class ColorMap:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["preset", "custom"], {"default": "preset"}),
"preset": (list(COLORMAPS), {
"default": "viridis",
"show_when_widget_value": {"mode": ["preset"]},
}),
"stops": ("STRING", {
"default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)),
"colormap_stops": True,
"show_when_widget_value": {"mode": ["custom"]},
}),
}
}
RETURN_TYPES = ("COLORMAP",)
RETURN_NAMES = ("colormap",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours "
"and any number of intermediate stops."
)
def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple:
if mode == "preset":
return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},)
try:
raw_stops = stops if stops is not None else stops_json
stops_data = json.loads(raw_stops or "[]")
except json.JSONDecodeError as exc:
raise ValueError("Custom colormap stops must be valid JSON.") from exc
spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None)
if not (isinstance(spec, dict) and spec.get("mode") == "custom"):
raise ValueError("Custom colormap must include at least min and max colours.")
return (spec,)

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Colormap Adjust")
class ColormapAdjust:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}),
"auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. "
"offset and scale operate in normalized display coordinates; Auto resets to the full data range."
)
def process(self, field: DataField, offset: float, scale: float) -> tuple:
scale = float(scale)
if not np.isfinite(scale) or scale <= 0.0:
raise ValueError("Scale must be a positive number.")
return (field.replace(display_offset=float(offset), display_scale=scale),)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Coordinate")
class Coordinate:
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("COORD",)
RETURN_NAMES = ("point",)
FUNCTION = "process"
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
def process(self, x: float, y: float) -> tuple:
return ((float(x), float(y)),)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Coordinate Pair")
class CoordinatePair:
"""Provide a pair of Coordinates, for drawing lines between markers, etc."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("COORD",),
"b": ("COORD",),
}
}
RETURN_TYPES = ("COORDPAIR",)
RETURN_NAMES = ("coord pair",)
FUNCTION = "process"
DESCRIPTION = "Output a pair of coordinates."
def process(self, a: tuple, b: tuple) -> tuple:
return ((a, b),)

View File

@@ -1,52 +1,9 @@
"""
Modify nodes geometric transforms for DATA_FIELDs.
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
# ---------------------------------------------------------------------------
# ColormapAdjust
# ---------------------------------------------------------------------------
@register_node(display_name="Colormap Adjust")
class ColormapAdjust:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}),
"auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. "
"offset and scale operate in normalized display coordinates; Auto resets to the full data range."
)
def process(self, field: DataField, offset: float, scale: float) -> tuple:
scale = float(scale)
if not np.isfinite(scale) or scale <= 0.0:
raise ValueError("Scale must be a positive number.")
return (field.replace(display_offset=float(offset), display_scale=scale),)
# ---------------------------------------------------------------------------
# CropResizeField
# ---------------------------------------------------------------------------
@register_node(display_name="Crop / Resize")
class CropResizeField:
@classmethod
@@ -190,105 +147,3 @@ class CropResizeField:
target_height = max(1, int(round(height * (target_width / width))))
return (max(1, target_width), max(1, target_height))
# ---------------------------------------------------------------------------
# RotateField
# ---------------------------------------------------------------------------
@register_node(display_name="Rotate")
class RotateField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"angle": ("FLOAT", {"default": 90.0, "min": -360.0, "max": 360.0, "step": 1.0}),
"interpolation": (["bilinear", "nearest", "bicubic"],),
"expand_canvas": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Rotate a DATA_FIELD counterclockwise by an angle in degrees. "
"Optionally expand the canvas to keep the full rotated field while preserving the field center."
)
_broadcast_warning_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
angle: float,
interpolation: str,
expand_canvas: bool,
) -> tuple:
if field.overlays:
self._send_warning("Rotate clears annotation/markup overlays!")
angle = float(angle)
order_map = {
"nearest": 0,
"bilinear": 1,
"bicubic": 3,
}
if interpolation not in order_map:
raise ValueError(f"Unknown interpolation mode: {interpolation}")
normalized_angle = angle % 360.0
snapped_quarters = int(round(normalized_angle / 90.0)) % 4
snapped_angle = snapped_quarters * 90.0
is_right_angle = abs(normalized_angle - snapped_angle) < 1e-9
if is_right_angle and expand_canvas:
rotated = np.rot90(field.data, k=snapped_quarters).copy()
elif abs(normalized_angle) < 1e-9:
rotated = field.data.copy()
else:
from scipy.ndimage import rotate as nd_rotate
rotated = nd_rotate(
field.data,
angle=angle,
reshape=bool(expand_canvas),
order=order_map[interpolation],
mode="nearest",
prefilter=order_map[interpolation] > 1,
)
new_xreal, new_yreal = self._rotated_extents(field, angle, expand_canvas)
center_x = field.xoff + field.xreal / 2.0
center_y = field.yoff + field.yreal / 2.0
result = field.replace(
data=np.asarray(rotated, dtype=np.float64),
xreal=new_xreal,
yreal=new_yreal,
xoff=center_x - new_xreal / 2.0,
yoff=center_y - new_yreal / 2.0,
overlays=[],
)
return (result,)
def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:
if not expand_canvas:
return (field.xreal, field.yreal)
theta = np.deg2rad(angle)
cos_t = abs(float(np.cos(theta)))
sin_t = abs(float(np.sin(theta)))
new_xreal = field.xreal * cos_t + field.yreal * sin_t
new_yreal = field.xreal * sin_t + field.yreal * cos_t
return (new_xreal, new_yreal)

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _extend_to_edges
@register_node(display_name="Cross Section")
class CrossSection:
"""Extract a 1-D height profile along an arbitrary line across the image."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"extend": (["none", "to_edges"],),
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
},
"optional": {
"marker_pair": ("COORDPAIR", {"label": "marker pair"}),
},
}
RETURN_TYPES = ("LINE", "COORDPAIR",)
RETURN_NAMES = ("profile", "marker pair",)
FUNCTION = "process"
DESCRIPTION = (
"Extract a cross-section profile along a line between two points. "
"Drag the markers on the image to set the line endpoints. "
"Equivalent to gwy_data_field_get_profile."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, field: DataField,
x1: float, y1: float, x2: float, y2: float,
extend: str, n_samples: int,
marker_pair=None,
) -> tuple:
from scipy.ndimage import map_coordinates
if marker_pair is not None:
(x1, y1), (x2, y2) = marker_pair
marker_x1, marker_y1 = float(x1), float(y1)
marker_x2, marker_y2 = float(x2), float(y2)
xres, yres = field.xres, field.yres
if extend == "to_edges":
x1, y1, x2, y2 = _extend_to_edges(
float(x1), float(y1), float(x2), float(y2),
)
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
line_len_px = np.hypot(px2 - px1, py2 - py1)
if n_samples <= 0:
n_samples = max(2, int(np.ceil(line_len_px)))
t = np.linspace(0, 1, n_samples)
coords_y = py1 + t * (py2 - py1)
coords_x = px1 + t * (px2 - px1)
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
if CrossSection._broadcast_overlay_fn is not None:
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,
{
"image": image_uri,
"x1": marker_x1, "y1": marker_y1,
"x2": marker_x2, "y2": marker_y2,
"a_locked": marker_pair is not None,
"b_locked": marker_pair is not None,
},
)
dx_real = (x2 - x1) * field.xreal
dy_real = (y2 - y1) * field.yreal
distance_axis = np.linspace(0.0, float(np.hypot(dx_real, dy_real)), n_samples, dtype=np.float64)
return (
LineData(
data=profile.astype(np.float64),
x_axis=distance_axis,
x_unit=field.si_unit_xy,
y_unit=field.si_unit_z,
),
((marker_x1, marker_y1), (marker_x2, marker_y2)),
)

173
backend/nodes/cursors.py Normal file
View File

@@ -0,0 +1,173 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData, MeasureTable, encode_preview, render_datafield_preview
@register_node(display_name="Cursors")
class Cursors:
"""Place two draggable cursors on a line plot or field to measure deltas."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("CURSOR_SOURCE", {"label": "input"}),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
},
"optional": {
"coord_pair": ("COORDPAIR", {"label": "coord pair"}),
},
}
RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",)
RETURN_NAMES = ("measurement", "coord pair",)
FUNCTION = "process"
DESCRIPTION = (
"Place two cursors on a line plot or 2D field. "
"On lines it reports x/y positions and dx/dy. "
"On fields it reports x/y/z at both markers plus dx/dy/dz."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, line, x1: float, y1: float, x2: float, y2: float,
coord_pair=None,
) -> tuple:
if coord_pair is not None:
(x1, y1), (x2, y2) = coord_pair
locked = coord_pair is not None
if isinstance(line, DataField):
return self._process_field(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked)
return self._process_line(line, x1=x1, y1=y1, x2=x2, y2=y2, locked=locked)
def _process_line(
self,
line,
x1: float,
y1: float,
x2: float,
y2: float,
locked: bool = False,
) -> tuple:
y = np.asarray(line, dtype=np.float64).ravel()
x_unit = line.x_unit if isinstance(line, LineData) else ""
y_unit = line.y_unit if isinstance(line, LineData) else ""
n = len(y)
if isinstance(line, LineData) and line.x_axis is not None:
x = np.asarray(line.x_axis, dtype=np.float64).ravel()[:n]
else:
x = np.arange(n, dtype=np.float64)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
xmin = float(np.min(x)) if len(x) else 0.0
xmax = float(np.max(x)) if len(x) else 1.0
def x_frac_to_idx(frac):
if n <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(x - target_x)))
idx_a = x_frac_to_idx(x1)
idx_b = x_frac_to_idx(x2)
xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b])
if Cursors._broadcast_overlay_fn is not None:
Cursors._broadcast_overlay_fn(
Cursors._current_node_id,
{
"kind": "line_plot",
"section_title": "Cursors",
"line": y.tolist(),
"x_axis": x.tolist(),
"x1": x1,
"x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": locked,
"b_locked": locked,
},
)
table = MeasureTable([
{"quantity": "A x", "value": xa, "unit": x_unit},
{"quantity": "A y", "value": ya, "unit": y_unit},
{"quantity": "B x", "value": xb, "unit": x_unit},
{"quantity": "B y", "value": yb, "unit": y_unit},
{"quantity": "dx", "value": xb - xa, "unit": x_unit},
{"quantity": "dy", "value": yb - ya, "unit": y_unit},
])
return (table, ((x1, y1), (x2, y2)))
def _process_field(
self,
field: DataField,
x1: float,
y1: float,
x2: float,
y2: float,
locked: bool = False,
) -> tuple:
from scipy.ndimage import map_coordinates
x1 = float(np.clip(x1, 0.0, 1.0))
y1 = float(np.clip(y1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
y2 = float(np.clip(y2, 0.0, 1.0))
px1 = x1 * max(field.xres - 1, 0)
py1 = y1 * max(field.yres - 1, 0)
px2 = x2 * max(field.xres - 1, 0)
py2 = y2 * max(field.yres - 1, 0)
z1 = float(map_coordinates(field.data, [[py1], [px1]], order=1, mode="nearest")[0])
z2 = float(map_coordinates(field.data, [[py2], [px2]], order=1, mode="nearest")[0])
ax = float(field.xoff + x1 * field.xreal)
ay = float(field.yoff + y1 * field.yreal)
bx = float(field.xoff + x2 * field.xreal)
by = float(field.yoff + y2 * field.yreal)
if Cursors._broadcast_overlay_fn is not None:
Cursors._broadcast_overlay_fn(
Cursors._current_node_id,
{
"kind": "cursor_points",
"section_title": "Cursors",
"image": encode_preview(render_datafield_preview(field, field.colormap)),
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"a_locked": locked,
"b_locked": locked,
},
)
table = MeasureTable([
{"quantity": "A x", "value": ax, "unit": field.si_unit_xy},
{"quantity": "A y", "value": ay, "unit": field.si_unit_xy},
{"quantity": "A z", "value": z1, "unit": field.si_unit_z},
{"quantity": "B x", "value": bx, "unit": field.si_unit_xy},
{"quantity": "B y", "value": by, "unit": field.si_unit_xy},
{"quantity": "B z", "value": z2, "unit": field.si_unit_z},
{"quantity": "dx", "value": bx - ax, "unit": field.si_unit_xy},
{"quantity": "dy", "value": by - ay, "unit": field.si_unit_xy},
{"quantity": "dz", "value": z2 - z1, "unit": field.si_unit_z},
])
return (table, ((x1, y1), (x2, y2)))

View File

@@ -1,743 +0,0 @@
"""
Display / output nodes.
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
connect whichever type you have. The server injects _broadcast_fn
before execution begins.
"""
from __future__ import annotations
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
DataField,
MeasureTable,
COLORMAPS,
CUSTOM_FILE_FONT,
DEFAULT_CUSTOM_COLORMAP_STOPS,
SYSTEM_DEFAULT_FONT,
colormap_to_uint8,
datafield_to_uint8,
encode_preview,
image_to_uint8,
list_overlay_font_choices,
normalize_colormap_spec,
normalize_font_spec,
normalize_for_colormap,
render_datafield_preview,
resolve_colormap_input,
)
def _measurement_names(table: list) -> list[str]:
names = []
for row in table:
if not isinstance(row, dict):
continue
quantity = row.get("quantity")
if isinstance(quantity, str) and quantity and quantity not in names:
names.append(quantity)
return names
def _measurement_entry(table: list, selection: str) -> dict:
names = _measurement_names(table)
if not names:
raise ValueError("Measurement table has no selectable rows.")
target = selection if selection in names else names[0]
for row in table:
if isinstance(row, dict) and row.get("quantity") == target:
return row
raise ValueError(f"Measurement '{target}' was not found.")
def _measurement_value(table: list, selection: str) -> float:
row = _measurement_entry(table, selection)
value = row.get("value")
if isinstance(value, bool):
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
try:
numeric = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc
if np.isfinite(numeric):
return numeric
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
def _scalar_payload(value: float, unit: str = "") -> dict:
payload = {"value": float(value)}
if isinstance(unit, str) and unit.strip():
payload["unit"] = unit.strip()
return payload
_SI_PREFIXES = [
(1e24, "Y"),
(1e21, "Z"),
(1e18, "E"),
(1e15, "P"),
(1e12, "T"),
(1e9, "G"),
(1e6, "M"),
(1e3, "k"),
(1.0, ""),
(1e-3, "m"),
(1e-6, "u"),
(1e-9, "n"),
(1e-12, "p"),
(1e-15, "f"),
(1e-18, "a"),
(1e-21, "z"),
(1e-24, "y"),
]
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "Ω"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
unit = (unit or "").strip()
if not unit:
return _format_numeric(value)
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
abs_value = abs(value)
for scale, prefix in _SI_PREFIXES:
scaled = abs_value / scale
if 1 <= scaled < 1000:
signed = value / scale
return f"{_format_numeric(signed)} {prefix}{unit}"
return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field: DataField) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
from PIL import Image, ImageDraw, ImageFont
size_px = max(8, int(round(size_px)))
try:
font = ImageFont.truetype("DejaVuSans.ttf", size_px)
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_image)
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
return text_image
except Exception:
font = ImageFont.load_default()
probe = Image.new("L", (1, 1), 0)
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
mask = Image.new("L", (width, height), 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
scale = max(1.0, size_px / max(1, height))
scaled_width = max(1, int(round(width * scale)))
scaled_height = max(1, int(round(height * scale)))
resampling = getattr(Image, "Resampling", Image)
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR)
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
text_image.putalpha(scaled_mask)
return text_image
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _parse_markup_shapes(raw_shapes: str | list | None) -> list[dict[str, object]]:
if isinstance(raw_shapes, str):
try:
raw_shapes = json.loads(raw_shapes or "[]")
except json.JSONDecodeError:
raw_shapes = []
if not isinstance(raw_shapes, list):
return []
parsed: list[dict[str, object]] = []
for shape in raw_shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
coords = [x1, y1, x2, y2]
if not all(np.isfinite(value) for value in coords):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _draw_arrow(draw, start: tuple[float, float], end: tuple[float, float], color: str, width: int):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (
shaft_end[0] + px * head_width / 2.0,
shaft_end[1] + py * head_width / 2.0,
)
right = (
shaft_end[0] - px * head_width / 2.0,
shaft_end[1] - py * head_width / 2.0,
)
draw.polygon([end, left, right], fill=color)
def _render_markup_image(image: np.ndarray, shapes: list[dict[str, object]]) -> np.ndarray:
from PIL import Image, ImageDraw
base = image_to_uint8(image)
if base.ndim == 2:
base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
canvas = Image.fromarray(base.copy())
draw = ImageDraw.Draw(canvas)
height, width = base.shape[:2]
for shape in shapes:
x1 = float(shape["x1"]) * width
y1 = float(shape["y1"]) * height
x2 = float(shape["x2"]) * width
y2 = float(shape["y2"]) * height
color = str(shape["color"])
stroke_width = int(shape["width"])
kind = str(shape["kind"])
if kind == "line":
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
elif kind == "rectangle":
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "circle":
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "arrow":
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
return np.asarray(canvas, dtype=np.uint8)
@register_node(display_name="Color Map")
class ColorMap:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mode": (["preset", "custom"], {"default": "preset"}),
"preset": (list(COLORMAPS), {
"default": "viridis",
"show_when_widget_value": {"mode": ["preset"]},
}),
"stops": ("STRING", {
"default": json.dumps(list(DEFAULT_CUSTOM_COLORMAP_STOPS)),
"colormap_stops": True,
"show_when_widget_value": {"mode": ["custom"]},
}),
}
}
RETURN_TYPES = ("COLORMAP",)
RETURN_NAMES = ("colormap",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable colormap. Choose a preset, or create a custom gradient with min/max colours "
"and any number of intermediate stops."
)
def build(self, mode: str, preset: str, stops: str | None = None, stops_json: str | None = None) -> tuple:
if mode == "preset":
return ({"mode": "preset", "preset": normalize_colormap_spec(preset)},)
try:
raw_stops = stops if stops is not None else stops_json
stops_data = json.loads(raw_stops or "[]")
except json.JSONDecodeError as exc:
raise ValueError("Custom colormap stops must be valid JSON.") from exc
spec = normalize_colormap_spec({"mode": "custom", "stops": stops_data}, fallback=None)
if not (isinstance(spec, dict) and spec.get("mode") == "custom"):
raise ValueError("Custom colormap must include at least min and max colours.")
return (spec,)
@register_node(display_name="Font")
class Font:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], {
"default": SYSTEM_DEFAULT_FONT,
}),
"font_file": ("FILE_PICKER", {
"default": "",
"show_when_widget_value": {"family": [CUSTOM_FILE_FONT]},
}),
}
}
RETURN_TYPES = ("FONT",)
RETURN_NAMES = ("font",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable font spec for annotation overlays. Choose a discovered system font, "
"use the default fallback stack, or point to a custom font file."
)
def build(self, family: str, font_file: str = "") -> tuple:
if family == SYSTEM_DEFAULT_FONT:
return (None,)
if family == CUSTOM_FILE_FONT:
return (normalize_font_spec({"path": font_file}),)
return (normalize_font_spec({"family": family}),)
@register_node(display_name="Annotations")
class Annotations:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}),
"text_size": ("FLOAT", {
"default": 14.0,
"min": 6.0,
"max": 96.0,
"step": 1.0,
}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"font": ("FONT",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "render"
DESCRIPTION = (
"Attach optional publication-style annotations to a DATA_FIELD without flattening the raw data. "
"The preview shows a scale bar and/or side colour legend, while downstream field operations keep the underlying AFM values."
)
def render(
self,
field: DataField,
colormap: str,
show_scale_bar: bool,
show_color_map: bool,
text_size: float = 1.0,
colormap_map=None,
font=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
text_size = float(np.clip(text_size, 6.0, 96.0)) if np.isfinite(text_size) else 14.0
out = field.replace(
colormap=resolved_colormap,
overlays=[
*field.overlays,
{
"kind": "annotation",
"show_scale_bar": bool(show_scale_bar),
"show_color_map": bool(show_color_map),
"text_size": text_size,
"font": normalize_font_spec(font),
},
],
)
return (out,)
@register_node(display_name="Markup")
class Markup:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}),
"stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
"clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}),
"markup_shapes": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "process"
DESCRIPTION = (
"Draw simple vector markup over a DATA_FIELD without flattening the underlying data. "
"Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
shape: str,
stroke_color: str,
stroke_width: int,
markup_shapes: str,
) -> tuple:
shapes = _parse_markup_shapes(markup_shapes)
out = field.replace(
overlays=[
*field.overlays,
{
"kind": "markup",
"shapes": shapes,
},
],
)
if Markup._broadcast_overlay_fn is not None:
Markup._broadcast_overlay_fn(
Markup._current_node_id,
{
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
return (out,)
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ()
FUNCTION = "preview"
OUTPUT_NODE = True
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
_broadcast_fn = None
_current_node_id: str = ""
def preview(
self,
colormap: str,
image: np.ndarray | None = None,
field=None,
colormap_map=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap if field is not None else None,
default="gray",
)
# Prefer field if both are connected; accept whichever is provided
if field is not None:
arr_u8 = render_datafield_preview(field, resolved_colormap)
elif image is not None:
arr_u8 = image_to_uint8(image)
if arr_u8.ndim == 2:
if image.dtype == np.uint8:
normalized = arr_u8.astype(np.float64) / 255.0
else:
imin, imax = image.min(), image.max()
if imax > imin:
normalized = (image - imin) / (imax - imin)
else:
normalized = np.zeros_like(image, dtype=np.float64)
arr_u8 = colormap_to_uint8(normalized, resolved_colormap)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return ()
@register_node(display_name="3D View")
class View3D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ()
FUNCTION = "render"
OUTPUT_NODE = True
DESCRIPTION = (
"Interactive 3D surface view of a DATA_FIELD. "
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
)
_broadcast_mesh_fn = None
_current_node_id: str = ""
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int, colormap_map=None,
) -> tuple:
import base64
data = field.data
yres, xres = data.shape
# Downsample if larger than resolution
step_y = max(1, yres // resolution)
step_x = max(1, xres // resolution)
z = data[::step_y, ::step_x].astype(np.float32)
ny, nx = z.shape
# Normalize for colormap
zmin, zmax = float(z.min()), float(z.max())
z_norm = normalize_for_colormap(
z,
offset=field.display_offset,
scale=field.display_scale,
data_min=float(field.data.min()),
data_max=float(field.data.max()),
)
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
colors_u8 = colormap_to_uint8(z_norm, resolved_colormap)
# Base64-encode arrays for efficient WS transport
z_b64 = base64.b64encode(z.tobytes()).decode()
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
mesh_data = {
"width": nx,
"height": ny,
"z_data": z_b64,
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
return ()
@register_node(display_name="Print Table")
class PrintTable:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("ANY_TABLE",),
}
}
RETURN_TYPES = ()
FUNCTION = "print_table"
OUTPUT_NODE = True
DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display."
_broadcast_table_fn = None
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()
@register_node(display_name="Value Display")
class ValueDisplay:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("VALUE_SOURCE",),
"measurement": ("STRING", {
"default": "",
"choices_from_measure_input": "value",
"show_when_source_type": {
"value": ["MEASURE_TABLE"],
},
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "display_value"
DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged."
_broadcast_value_fn = None
_current_node_id: str = ""
def display_value(self, value, measurement: str = "") -> tuple:
unit = ""
if isinstance(value, MeasureTable):
row = _measurement_entry(value, measurement)
numeric = _measurement_value(value, measurement)
unit = row.get("unit", "") if isinstance(row.get("unit"), str) else ""
else:
numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None:
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit))
return (numeric,)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _parse_mask_strokes, _rasterize_mask
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Edge Detect")
class EdgeDetect:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["sobel", "prewitt", "laplacian", "log"],),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("edges",)
FUNCTION = "process"
DESCRIPTION = (
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
)
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data
if method == "sobel":
sx = sobel(data, axis=1)
sy = sobel(data, axis=0)
result = np.hypot(sx, sy)
elif method == "prewitt":
px = prewitt(data, axis=1)
py = prewitt(data, axis=0)
result = np.hypot(px, py)
elif method == "laplacian":
result = laplace(data)
elif method == "log":
result = gaussian_laplace(data, sigma=float(sigma))
else:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)

115
backend/nodes/fft_2d.py Normal file
View File

@@ -0,0 +1,115 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="2D FFT")
class FFT2D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"windowing": (["hann", "hamming", "blackman", "none"],),
"level": (["mean", "plane", "none"],),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf")
FUNCTION = "process"
DESCRIPTION = (
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
"Outputs log magnitude, magnitude, phase, and PSDF as separate channels. "
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
)
def process(self, field: DataField, windowing: str, level: str) -> tuple:
data = field.data.copy()
yres, xres = data.shape
if level == "mean":
data -= data.mean()
elif level == "plane":
yy, xx = np.mgrid[0:yres, 0:xres]
xx_f = xx.ravel().astype(np.float64)
yy_f = yy.ravel().astype(np.float64)
zz_f = data.ravel()
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
data -= plane
if windowing != "none":
t_y = (np.arange(yres) + 0.5) / yres
t_x = (np.arange(xres) + 0.5) / xres
if windowing == "hann":
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
elif windowing == "hamming":
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
elif windowing == "blackman":
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
else:
wy = np.ones(yres)
wx = np.ones(xres)
data *= np.outer(wy, wx)
F = np.fft.fftshift(np.fft.fft2(data))
n = xres * yres
magnitude = np.abs(F)
log_magnitude = np.log1p(magnitude)
phase = np.angle(F)
dx = field.xreal / xres
dy = field.yreal / yres
psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
spatial_freq_xreal = xres / field.xreal
spatial_freq_yreal = yres / field.yreal
angular_freq_xreal = 2.0 * np.pi * xres / field.xreal
angular_freq_yreal = 2.0 * np.pi * yres / field.yreal
return (
DataField(
data=log_magnitude,
xreal=spatial_freq_xreal,
yreal=spatial_freq_yreal,
si_unit_xy="1/m",
si_unit_z=field.si_unit_z,
domain="frequency",
colormap=field.colormap,
),
DataField(
data=magnitude,
xreal=spatial_freq_xreal,
yreal=spatial_freq_yreal,
si_unit_xy="1/m",
si_unit_z=field.si_unit_z,
domain="frequency",
colormap=field.colormap,
),
DataField(
data=phase,
xreal=spatial_freq_xreal,
yreal=spatial_freq_yreal,
si_unit_xy="1/m",
si_unit_z=field.si_unit_z,
domain="frequency",
colormap=field.colormap,
),
DataField(
data=psdf,
xreal=angular_freq_xreal,
yreal=angular_freq_yreal,
si_unit_xy="1/m",
si_unit_z=f"({field.si_unit_z})^2 m^2",
domain="frequency",
colormap=field.colormap,
),
)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import LineData
from backend.nodes.helpers import _cached_1d_transfer
@register_node(display_name="1D FFT Filter")
class FFTFilter1D:
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
transfer function with configurable order for a smooth roll-off.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 1-D line profile. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
"Equivalent to Gwyddion fft_filter_1d."
)
def process(self, line, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
n = len(z)
Z = np.fft.rfft(z)
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
Z *= H
filtered = np.fft.irfft(Z, n=n)
if isinstance(line, LineData):
return (
LineData(
data=filtered,
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
x_unit=line.x_unit,
y_unit=line.y_unit,
),
)
return (filtered,)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
from backend.nodes.helpers import _cached_2d_transfer
@register_node(display_name="2D FFT Filter")
class FFTFilter2D:
"""Frequency-domain filtering of 2-D data fields (images).
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
Butterworth transfer function in the frequency domain to remove or
isolate periodic features.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 2-D data field. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
"bandpass/notch to isolate or remove periodic noise. "
"Equivalent to Gwyddion fft_filter_2d."
)
def process(self, field: DataField, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
data = field.data
yres, xres = data.shape
mean_val = float(data.mean())
centered = data - mean_val
spectrum = np.fft.rfft2(centered)
transfer = _cached_2d_transfer(
yres, xres, filter_type,
float(cutoff), float(cutoff_high), int(order),
)
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
result += mean_val
return (field.replace(data=result),)

View File

@@ -1,332 +0,0 @@
"""
Filter nodes — Gwyddion-equivalent image filters.
Gwyddion equivalents:
GaussianFilter → gwy_data_field_filter_gaussian
MedianFilter → gwy_data_field_filter_median
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles)
FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs)
"""
from __future__ import annotations
from functools import lru_cache
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData
# ---------------------------------------------------------------------------
# GaussianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data, sigma=float(sigma))
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# MedianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Median Filter")
class MedianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data, size=size)
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# EdgeDetect
# ---------------------------------------------------------------------------
@register_node(display_name="Edge Detect")
class EdgeDetect:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["sobel", "prewitt", "laplacian", "log"],),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("edges",)
FUNCTION = "process"
DESCRIPTION = (
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
)
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data
if method == "sobel":
sx = sobel(data, axis=1)
sy = sobel(data, axis=0)
result = np.hypot(sx, sy)
elif method == "prewitt":
px = prewitt(data, axis=1)
py = prewitt(data, axis=0)
result = np.hypot(px, py)
elif method == "laplacian":
result = laplace(data)
elif method == "log":
result = gaussian_laplace(data, sigma=float(sigma))
else:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)
# ---------------------------------------------------------------------------
# Butterworth transfer function helpers
# ---------------------------------------------------------------------------
def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
"""Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n))."""
with np.errstate(divide="ignore", over="ignore"):
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
"""Butterworth highpass: H = 1 / (1 + (fc/f)^(2n))."""
with np.errstate(divide="ignore", invalid="ignore"):
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
h = np.where(np.isfinite(h), h, 0.0)
return h
def _build_1d_transfer(n: int, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> np.ndarray:
"""Build a 1-D transfer function for an FFT of length *n*.
Frequencies are normalised so that 1.0 = Nyquist (fs/2).
The returned array has the same layout as np.fft.rfft output (length n//2+1).
"""
freq = np.linspace(0, 1, n // 2 + 1)
if filter_type == "lowpass":
H = _butterworth_lp(freq, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(freq, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(freq)
return H
@lru_cache(maxsize=64)
def _cached_1d_transfer(n: int, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> np.ndarray:
transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
transfer.setflags(write=False)
return transfer
@lru_cache(maxsize=32)
def _fft_radius_grid(yres: int, xres: int) -> np.ndarray:
fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0
fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0
radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0)
np.clip(radius, 0.0, 1.0, out=radius)
radius.setflags(write=False)
return radius
@lru_cache(maxsize=128)
def _cached_2d_transfer(yres: int, xres: int, filter_type: str,
cutoff: float, cutoff_high: float, order: int) -> np.ndarray:
radius = _fft_radius_grid(yres, xres)
if filter_type == "lowpass":
transfer = _butterworth_lp(radius, cutoff, order)
elif filter_type == "highpass":
transfer = _butterworth_hp(radius, cutoff, order)
elif filter_type == "bandpass":
transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
elif filter_type == "notch":
band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
transfer = 1.0 - band
else:
transfer = np.ones_like(radius)
transfer.setflags(write=False)
return transfer
# ---------------------------------------------------------------------------
# FFTFilter1D — frequency-domain filtering of LINE profiles
# ---------------------------------------------------------------------------
@register_node(display_name="1D FFT Filter")
class FFTFilter1D:
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
transfer function with configurable order for a smooth roll-off.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 1-D line profile. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
"Equivalent to Gwyddion fft_filter_1d."
)
def process(self, line, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
n = len(z)
# Forward FFT (real-valued)
Z = np.fft.rfft(z)
# Build and apply transfer function
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
Z *= H
# Inverse FFT
filtered = np.fft.irfft(Z, n=n)
if isinstance(line, LineData):
return (
LineData(
data=filtered,
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
x_unit=line.x_unit,
y_unit=line.y_unit,
),
)
return (filtered,)
# ---------------------------------------------------------------------------
# FFTFilter2D — frequency-domain filtering of DATA_FIELDs
# ---------------------------------------------------------------------------
@register_node(display_name="2D FFT Filter")
class FFTFilter2D:
"""Frequency-domain filtering of 2-D data fields (images).
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
Butterworth transfer function in the frequency domain to remove or
isolate periodic features.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = (
"Frequency-domain filtering of a 2-D data field. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
"bandpass/notch to isolate or remove periodic noise. "
"Equivalent to Gwyddion fft_filter_2d."
)
def process(self, field: DataField, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
data = field.data
yres, xres = data.shape
# Subtract mean to avoid DC leakage artefacts.
mean_val = float(data.mean())
centered = data - mean_val
# Real-valued FFT keeps only the unique half-plane and avoids shift copies.
spectrum = np.fft.rfft2(centered)
transfer = _cached_2d_transfer(
yres,
xres,
filter_type,
float(cutoff),
float(cutoff_high),
int(order),
)
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
# Restore DC
result += mean_val
return (field.replace(data=result),)

37
backend/nodes/fix_zero.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Fix Zero")
class FixZero:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["min", "mean", "median"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("zeroed",)
FUNCTION = "process"
DESCRIPTION = (
"Shift data so that the minimum (or mean/median) is zero. "
"Equivalent to fix_zero in Gwyddion's level.c."
)
def process(self, field: DataField, method: str) -> tuple:
data = field.data.copy()
if method == "min":
data -= data.min()
elif method == "mean":
data -= data.mean()
elif method == "median":
data -= np.median(data)
else:
raise ValueError(f"Unknown method: {method}")
return (field.replace(data=data),)

29
backend/nodes/folder.py Normal file
View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.nodes.helpers import list_folder_paths
@register_node(display_name="Folder")
class Folder:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}),
}
}
RETURN_TYPES = ("DIRECTORY",)
RETURN_NAMES = ("directory",)
FUNCTION = "list_files"
DESCRIPTION = (
"Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. "
"Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans."
)
def list_files(self, folder: str) -> tuple:
entries = list_folder_paths(folder)
if not entries:
return tuple()
return tuple(item["path"] for item in entries)

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT, list_overlay_font_choices, normalize_font_spec
@register_node(display_name="Font")
class Font:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"family": ([SYSTEM_DEFAULT_FONT, *list_overlay_font_choices(), CUSTOM_FILE_FONT], {
"default": SYSTEM_DEFAULT_FONT,
}),
"font_file": ("FILE_PICKER", {
"default": "",
"show_when_widget_value": {"family": [CUSTOM_FILE_FONT]},
}),
}
}
RETURN_TYPES = ("FONT",)
RETURN_NAMES = ("font",)
FUNCTION = "build"
DESCRIPTION = (
"Build a reusable font spec for annotation overlays. Choose a discovered system font, "
"use the default fallback stack, or point to a custom font file."
)
def build(self, family: str, font_file: str = "") -> tuple:
if family == SYSTEM_DEFAULT_FONT:
return (None,)
if family == CUSTOM_FILE_FONT:
return (normalize_font_spec({"path": font_file}),)
return (normalize_font_spec({"family": family}),)

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data, sigma=float(sigma))
return (field.replace(data=data),)

873
backend/nodes/helpers.py Normal file
View File

@@ -0,0 +1,873 @@
"""
Shared helper functions for argonode nodes.
"""
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
from typing import Callable
import numpy as np
from backend.runtime_paths import demo_dir, input_dir, output_dir
# ---------------------------------------------------------------------------
# Scalar payload helpers (from display.py)
# ---------------------------------------------------------------------------
def _scalar_payload(value: float, unit: str = "") -> dict:
payload = {"value": float(value)}
if isinstance(unit, str) and unit.strip():
payload["unit"] = unit.strip()
return payload
# ---------------------------------------------------------------------------
# Measurement helpers (from display.py — used by ValueDisplay)
# ---------------------------------------------------------------------------
def _measurement_names(table: list) -> list[str]:
names = []
for row in table:
if not isinstance(row, dict):
continue
quantity = row.get("quantity")
if isinstance(quantity, str) and quantity and quantity not in names:
names.append(quantity)
return names
def _measurement_entry(table: list, selection: str) -> dict:
names = _measurement_names(table)
if not names:
raise ValueError("Measurement table has no selectable rows.")
target = selection if selection in names else names[0]
for row in table:
if isinstance(row, dict) and row.get("quantity") == target:
return row
raise ValueError(f"Measurement '{target}' was not found.")
def _measurement_value(table: list, selection: str) -> float:
row = _measurement_entry(table, selection)
value = row.get("value")
if isinstance(value, bool):
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
try:
numeric = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc
if np.isfinite(numeric):
return numeric
raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.")
# ---------------------------------------------------------------------------
# SI formatting helpers (from display.py — used by Annotations)
# ---------------------------------------------------------------------------
_SI_PREFIXES = [
(1e24, "Y"), (1e21, "Z"), (1e18, "E"), (1e15, "P"), (1e12, "T"),
(1e9, "G"), (1e6, "M"), (1e3, "k"), (1.0, ""), (1e-3, "m"),
(1e-6, "u"), (1e-9, "n"), (1e-12, "p"), (1e-15, "f"),
(1e-18, "a"), (1e-21, "z"), (1e-24, "y"),
]
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "\u03a9"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
unit = (unit or "").strip()
if not unit:
return _format_numeric(value)
if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
abs_value = abs(value)
for scale, prefix in _SI_PREFIXES:
scaled = abs_value / scale
if 1 <= scaled < 1000:
signed = value / scale
return f"{_format_numeric(signed)} {prefix}{unit}"
return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
from PIL import Image, ImageDraw, ImageFont
size_px = max(8, int(round(size_px)))
try:
font = ImageFont.truetype("DejaVuSans.ttf", size_px)
probe = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
text_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_image)
text_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=(*color, 255))
return text_image
except Exception:
font = ImageFont.load_default()
probe = Image.new("L", (1, 1), 0)
probe_draw = ImageDraw.Draw(probe)
bbox = probe_draw.textbbox((0, 0), text, font=font)
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
mask = Image.new("L", (width, height), 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.text((-bbox[0], -bbox[1]), text, font=font, fill=255)
scale = max(1.0, size_px / max(1, height))
scaled_width = max(1, int(round(width * scale)))
scaled_height = max(1, int(round(height * scale)))
resampling = getattr(Image, "Resampling", Image)
scaled_mask = mask.resize((scaled_width, scaled_height), resample=resampling.BILINEAR)
text_image = Image.new("RGBA", (scaled_width, scaled_height), (*color, 0))
text_image.putalpha(scaled_mask)
return text_image
# ---------------------------------------------------------------------------
# Markup helpers (from display.py — used by Markup)
# ---------------------------------------------------------------------------
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _parse_markup_shapes(raw_shapes) -> list[dict]:
if isinstance(raw_shapes, str):
try:
raw_shapes = json.loads(raw_shapes or "[]")
except json.JSONDecodeError:
raw_shapes = []
if not isinstance(raw_shapes, list):
return []
parsed = []
for shape in raw_shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
coords = [x1, y1, x2, y2]
if not all(np.isfinite(value) for value in coords):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _draw_arrow(draw, start, end, color, width):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (
shaft_end[0] + px * head_width / 2.0,
shaft_end[1] + py * head_width / 2.0,
)
right = (
shaft_end[0] - px * head_width / 2.0,
shaft_end[1] - py * head_width / 2.0,
)
draw.polygon([end, left, right], fill=color)
def _render_markup_image(image, shapes):
from PIL import Image as PILImage, ImageDraw
from backend.data_types import image_to_uint8
base = image_to_uint8(image)
if base.ndim == 2:
base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
canvas = PILImage.fromarray(base.copy())
draw = ImageDraw.Draw(canvas)
height, width = base.shape[:2]
for shape in shapes:
x1 = float(shape["x1"]) * width
y1 = float(shape["y1"]) * height
x2 = float(shape["x2"]) * width
y2 = float(shape["y2"]) * height
color = str(shape["color"])
stroke_width = int(shape["width"])
kind = str(shape["kind"])
if kind == "line":
draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
elif kind == "rectangle":
draw.rectangle((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "circle":
draw.ellipse((x1, y1, x2, y2), outline=color, width=stroke_width)
elif kind == "arrow":
_draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
return np.asarray(canvas, dtype=np.uint8)
# ---------------------------------------------------------------------------
# Mask helpers (from mask.py — used by multiple mask nodes)
# ---------------------------------------------------------------------------
def _mask_overlay(field, mask):
from backend.data_types import datafield_to_uint8
grey = datafield_to_uint8(field, "gray")
mask_bool = mask > 127
if not np.any(mask_bool):
return grey
overlay = grey.copy()
red = overlay[..., 0]
green = overlay[..., 1]
blue = overlay[..., 2]
red_vals = red[mask_bool].astype(np.uint16)
green_vals = green[mask_bool].astype(np.uint16)
blue_vals = blue[mask_bool].astype(np.uint16)
red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100
green[mask_bool] = ((green_vals * 55) + 50) // 100
blue[mask_bool] = ((blue_vals * 55) + 50) // 100
return overlay
@lru_cache(maxsize=128)
def _mask_structure(radius: int, shape: str):
radius = max(1, int(radius))
if shape == "disk":
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
struct = (x * x + y * y) <= radius * radius
else:
size = 2 * radius + 1
struct = np.ones((size, size), dtype=bool)
struct.setflags(write=False)
return struct
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width, height, strokes, default_pen_size):
from PIL import Image as PILImage, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = PILImage.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# Path / directory helpers (from io.py)
# ---------------------------------------------------------------------------
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
_MAX_SAVE_FIELDS = 8
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
def _resolve_path(filepath: str):
path = Path(filepath)
if path.is_absolute():
return path
candidate = INPUT_DIR / filepath
if candidate.exists():
return candidate
candidate = DEMO_DIR / filepath
if candidate.exists():
return candidate
return INPUT_DIR / filepath
def list_channels(filepath: str) -> list[dict]:
path = _resolve_path(filepath)
if not path.exists():
return [{"name": "field", "type": "DATA_FIELD"}]
ext = path.suffix.lower()
if ext == ".gwy":
try:
import gwyfile
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if channels:
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".sxm":
try:
import nanonispy as nap
sxm = nap.read.Scan(str(path))
if sxm.signals:
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".ibw":
try:
from igor.binarywave import load as load_ibw
wave = load_ibw(str(path))
raw = wave["wave"]["wData"]
labels = wave["wave"].get("labels", None)
if raw.ndim >= 3 and labels:
dim_idx = min(2, len(labels) - 1)
if dim_idx >= 0 and labels[dim_idx]:
decoded = []
for lbl in labels[dim_idx]:
if lbl:
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
if name:
decoded.append(name)
if decoded:
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
if raw.ndim >= 3 and raw.shape[2] > 1:
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
return [{"name": "field", "type": "DATA_FIELD"}]
def list_folder_paths(folderpath: str) -> list[dict]:
path = _resolve_path(folderpath)
if not path.exists() or not path.is_dir():
return []
resolved_dir = str(path.resolve())
results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}]
for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()):
if not entry.is_file() or entry.name.startswith("."):
continue
if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS:
continue
results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())})
return results
def _list_demo_files() -> list[str]:
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
)
# ---------------------------------------------------------------------------
# Butterworth / FFT helpers (from filters.py — used by FFTFilter1D, FFTFilter2D)
# ---------------------------------------------------------------------------
def _butterworth_lp(freq, cutoff, order):
with np.errstate(divide="ignore", over="ignore"):
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
def _butterworth_hp(freq, cutoff, order):
with np.errstate(divide="ignore", invalid="ignore"):
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
h = np.where(np.isfinite(h), h, 0.0)
return h
def _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order):
freq = np.linspace(0, 1, n // 2 + 1)
if filter_type == "lowpass":
H = _butterworth_lp(freq, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(freq, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(freq)
return H
@lru_cache(maxsize=64)
def _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order):
transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
transfer.setflags(write=False)
return transfer
@lru_cache(maxsize=32)
def _fft_radius_grid(yres, xres):
fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0
fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0
radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0)
np.clip(radius, 0.0, 1.0, out=radius)
radius.setflags(write=False)
return radius
@lru_cache(maxsize=128)
def _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order):
radius = _fft_radius_grid(yres, xres)
if filter_type == "lowpass":
transfer = _butterworth_lp(radius, cutoff, order)
elif filter_type == "highpass":
transfer = _butterworth_hp(radius, cutoff, order)
elif filter_type == "bandpass":
transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
elif filter_type == "notch":
band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
transfer = 1.0 - band
else:
transfer = np.ones_like(radius)
transfer.setflags(write=False)
return transfer
# ---------------------------------------------------------------------------
# Cross-section and stats helpers (from analysis.py)
# ---------------------------------------------------------------------------
def _extend_to_edges(x1, y1, x2, y2):
dx = x2 - x1
dy = y2 - y1
t_candidates = []
if abs(dx) > 1e-12:
for bx in (0.0, 1.0):
t = (bx - x1) / dx
y_at_t = y1 + t * dy
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if abs(dy) > 1e-12:
for by in (0.0, 1.0):
t = (by - y1) / dy
x_at_t = x1 + t * dx
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if len(t_candidates) < 2:
return x1, y1, x2, y2
t_min = min(t_candidates)
t_max = max(t_candidates)
return (
np.clip(x1 + t_min * dx, 0, 1),
np.clip(y1 + t_min * dy, 0, 1),
np.clip(x1 + t_max * dx, 0, 1),
np.clip(y1 + t_max * dy, 0, 1),
)
def _safe_rq(d):
return float(np.sqrt(np.mean(d * d)))
LINE_OPS: dict[str, tuple] = {}
def _line_op(name, unit=""):
def decorator(fn):
LINE_OPS[name] = (fn, unit)
return fn
return decorator
@_line_op("min")
def _op_min(z):
return float(z.min())
@_line_op("max")
def _op_max(z):
return float(z.max())
@_line_op("mean")
def _op_mean(z):
return float(z.mean())
@_line_op("median")
def _op_median(z):
return float(np.median(z))
@_line_op("sum")
def _op_sum(z):
return float(z.sum())
@_line_op("range")
def _op_range(z):
return float(z.max() - z.min())
@_line_op("length", unit="pts")
def _op_length(z):
return float(len(z))
@_line_op("rms")
def _op_rms(z):
return float(np.sqrt(np.mean(z * z)))
@_line_op("Ra")
def _op_ra(z):
return float(np.mean(np.abs(z - z.mean())))
@_line_op("Rq")
def _op_rq(z):
d = z - z.mean()
return _safe_rq(d)
@_line_op("Rsk")
def _op_rsk(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
@_line_op("Rku")
def _op_rku(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
@_line_op("Rp")
def _op_rp(z):
return float((z - z.mean()).max())
@_line_op("Rv")
def _op_rv(z):
return float(-(z - z.mean()).min())
@_line_op("Rt")
def _op_rt(z):
d = z - z.mean()
return float(d.max() - d.min())
@_line_op("Dq")
def _op_dq(z):
dz = np.diff(z)
return float(np.sqrt(np.mean(dz * dz)))
@_line_op("Da")
def _op_da(z):
return float(np.mean(np.abs(np.diff(z))))
TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = {
"min": lambda values: float(np.min(values)),
"max": lambda values: float(np.max(values)),
"avg": lambda values: float(np.mean(values)),
"mean": lambda values: float(np.mean(values)),
"median": lambda values: float(np.median(values)),
"sum": lambda values: float(np.sum(values)),
"range": lambda values: float(np.max(values) - np.min(values)),
"std": lambda values: float(np.std(values)),
"variance": lambda values: float(np.var(values)),
"count": lambda values: float(len(values)),
}
ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = {
"min": lambda values: float(np.min(values)),
"max": lambda values: float(np.max(values)),
"avg": lambda values: float(np.mean(values)),
"mean": lambda values: float(np.mean(values)),
"median": lambda values: float(np.median(values)),
"sum": lambda values: float(np.sum(values)),
"range": lambda values: float(np.max(values) - np.min(values)),
"std": lambda values: float(np.std(values)),
"variance": lambda values: float(np.var(values)),
"rms": lambda values: float(np.sqrt(np.mean(values * values))),
"count": lambda values: float(values.size),
}
def _square_unit(unit: str) -> str:
unit = str(unit or "").strip()
if not unit:
return ""
if any(token in unit for token in ("^", "(", ")", "/", "*", " ")):
return f"({unit})^2"
return f"{unit}^2"
def _apply_scalar_unit(base_unit: str, operation: str) -> str:
unit = str(base_unit or "").strip()
if operation == "count":
return "count"
if not unit:
return ""
if operation == "variance":
return _square_unit(unit)
return unit
def _common_table_unit(table: list, column: str) -> str:
candidates = []
seen = set()
unit_key = f"{column}_unit"
for row in table:
if not isinstance(row, dict):
continue
unit = None
if unit_key in row and isinstance(row.get(unit_key), str):
unit = row.get(unit_key)
elif column == "value" and isinstance(row.get("unit"), str):
unit = row.get("unit")
if unit is None:
continue
unit = unit.strip()
if not unit or unit in seen:
continue
seen.add(unit)
candidates.append(unit)
if len(candidates) == 1:
return candidates[0]
return ""
def extract_numeric_table_values(table: list, column: str) -> list[float]:
values = []
for row in table:
if not isinstance(row, dict) or column not in row:
continue
value = row[column]
if isinstance(value, bool):
continue
try:
numeric = float(value)
except (TypeError, ValueError):
continue
if np.isfinite(numeric):
values.append(numeric)
return values
def resolve_table_column_name(table: list, column: str) -> str:
requested = str(column or "").strip()
if requested:
return requested
if extract_numeric_table_values(table, "value"):
return "value"
numeric_columns = []
seen = set()
for row in table:
if not isinstance(row, dict):
continue
for key in row.keys():
if key in seen:
continue
seen.add(key)
if extract_numeric_table_values(table, key):
numeric_columns.append(key)
if len(numeric_columns) == 1:
return numeric_columns[0]
if not numeric_columns:
raise ValueError("Stats could not find any numeric columns in the input table.")
raise ValueError(
"Stats found multiple numeric columns; set the column name explicitly."
)

100
backend/nodes/histogram.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, MeasureTable
@register_node(display_name="Histogram")
class Histogram:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
"y_scale": (["linear", "log"],),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
}
}
RETURN_TYPES = ("MEASURE_TABLE", "COORDPAIR",)
RETURN_NAMES = ("measurements", "marker pair",)
FUNCTION = "process"
DESCRIPTION = (
"Compute the height distribution histogram (DH). "
"Use log scale to reveal small peaks next to a dominant background. "
"Outputs marker measurements while showing the histogram interactively in-node. "
"Equivalent to gwy_data_field_dh."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
n_bins: int,
y_scale: str = "linear",
x1: float = 0.25,
y1: float = 0.5,
x2: float = 0.75,
y2: float = 0.5,
) -> tuple:
raw_counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
counts = raw_counts.astype(np.float64)
if y_scale == "log":
counts = np.log10(1.0 + counts)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 0.0 + 1.0))
xmin = float(np.min(bin_centers)) if len(bin_centers) else 0.0
xmax = float(np.max(bin_centers)) if len(bin_centers) else 1.0
def x_frac_to_idx(frac):
if len(bin_centers) <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(bin_centers - target_x)))
idx_a = x_frac_to_idx(x1)
idx_b = x_frac_to_idx(x2)
xa = float(bin_centers[idx_a]) if len(bin_centers) else 0.0
xb = float(bin_centers[idx_b]) if len(bin_centers) else 0.0
ya = float(counts[idx_a]) if len(counts) else 0.0
yb = float(counts[idx_b]) if len(counts) else 0.0
count_unit = "count" if y_scale == "linear" else "log10(1+count)"
if Histogram._broadcast_overlay_fn is not None:
Histogram._broadcast_overlay_fn(
Histogram._current_node_id,
{
"kind": "line_plot",
"section_title": "Histogram",
"line": counts.tolist(),
"x_axis": bin_centers.astype(np.float64).tolist(),
"x1": float(np.clip(x1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y1": float(y1),
"y2": float(y2),
"a_locked": False,
"b_locked": False,
},
)
table = MeasureTable([
{"quantity": "A position", "value": xa, "unit": field.si_unit_z},
{"quantity": "A count", "value": ya, "unit": count_unit},
{"quantity": "B position", "value": xb, "unit": field.si_unit_z},
{"quantity": "B count", "value": yb, "unit": count_unit},
{"quantity": "delta X", "value": xb - xa, "unit": field.si_unit_z},
{"quantity": "delta Y", "value": yb - ya, "unit": count_unit},
])
return (table, ((x1, y1), (x2, y2)))

215
backend/nodes/image.py Normal file
View File

@@ -0,0 +1,215 @@
from __future__ import annotations
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DataField, resolve_colormap_input
from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS
@register_node(display_name="Image")
class Image:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"path": ("FILE_PATH", {"label": "path"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = (
"Load any supported file. "
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
"each channel gets its own output. "
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
)
_broadcast_warning_fn = None
_current_node_id = None
def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None):
selected_path = str(path).strip() if path is not None else str(filename).strip()
if not selected_path:
raise ValueError("No file selected — use Browse to pick a file.")
path_obj = _resolve_path(selected_path)
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {path_obj}")
if path_obj.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}")
ext = path_obj.suffix.lower()
resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis")
if ext in _SPM_EXTENSIONS:
fields = self._load_spm_all(path_obj, ext)
for f in fields:
f.colormap = resolved_colormap
return tuple(fields)
field = self._load_image_or_array(path_obj, ext)
field.colormap = resolved_colormap
self._send_warning("Uncalibrated data — no physical dimensions.")
return (field,)
def _send_warning(self, message: str):
fn = Image._broadcast_warning_fn
nid = Image._current_node_id
if fn and nid:
fn(nid, message)
def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
if ext == ".gwy":
return self._load_gwy_all(path)
elif ext == ".sxm":
return self._load_sxm_all(path)
elif ext == ".ibw":
return self._load_ibw_all(path)
else:
raise ValueError(f"Unsupported SPM format: {ext}")
def _load_gwy_all(self, path: Path) -> list[DataField]:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
fields = []
for ch in channels.values():
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
fields.append(DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
))
return fields
def _load_sxm_all(self, path: Path) -> list[DataField]:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
if not signals:
raise ValueError(f"No signals found in {path.name}")
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
fields = []
for sig in signals.values():
data = sig.get("forward", list(sig.values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
fields.append(DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
))
return fields
def _load_ibw_all(self, path: Path) -> list[DataField]:
try:
from igor.binarywave import load as load_ibw
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
wave = load_ibw(str(path))
wdata = wave["wave"]
header = wdata["wave_header"]
raw = wdata["wData"]
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
sfA = header.get("sfA", None)
def _decode_unit(raw_unit):
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
dim_units_raw = header.get("dimUnits", None)
data_units_raw = header.get("dataUnits", None)
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
si_unit_xy = _decode_unit(dim_units_raw[0])
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
si_unit_xy = _decode_unit(dim_units_raw[0])
else:
si_unit_xy = _decode_unit(dim_units_raw)
si_unit_z = _decode_unit(data_units_raw)
fields = []
for ch_idx in range(n_channels):
if raw.ndim >= 3:
ch_data = raw[:, :, ch_idx]
elif raw.ndim == 1:
ch_data = raw.reshape(-1, 1)
else:
ch_data = raw
data = np.flipud(ch_data.T).astype(np.float64)
yres, xres = data.shape
if sfA is not None and len(sfA) >= 2:
xreal = abs(float(sfA[0]) * xres) or 1e-6
yreal = abs(float(sfA[1]) * yres) or 1e-6
else:
hsA = header.get("hsA", 0.0)
xreal = abs(float(hsA) * xres) or 1e-6
yreal = xreal * (yres / xres) if xres else 1e-6
fields.append(DataField(
data=data, xreal=xreal, yreal=yreal,
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
))
return fields
def _load_image_or_array(self, path: Path, ext: str) -> DataField:
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image as PILImage
img = PILImage.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
return DataField(data=gray)

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import COLORMAPS
from backend.nodes.helpers import DEMO_DIR, _list_demo_files
@register_node(display_name="Image (Demo)")
class ImageDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo files found)"]
return {
"required": {
"name": (choices,),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
_broadcast_warning_fn = None
_current_node_id = None
def load(self, name: str = "", colormap: str = "viridis", colormap_map=None):
from backend.nodes.image import Image
loader = Image()
demo_path = DEMO_DIR / name
if not demo_path.exists():
raise FileNotFoundError(f"Demo file not found: {name}")
return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Inverse 2D FFT")
class InverseFFT2D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"spectrum": ("DATA_FIELD",),
"representation": (["magnitude", "log_magnitude", "psdf"],),
},
"optional": {
"phase": ("DATA_FIELD",),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("image",)
FUNCTION = "process"
DESCRIPTION = (
"Reconstruct a spatial-domain image from a 2D frequency spectrum. "
"For exact reconstruction, connect magnitude/phase (or log magnitude/phase, "
"or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed."
)
def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple:
if spectrum.domain != "frequency":
raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.")
if phase is not None:
if phase.data.shape != spectrum.data.shape:
raise ValueError("Phase input must have the same shape as the spectrum.")
if phase.domain != "frequency":
raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.")
amplitude = self._resolve_amplitude(spectrum, representation)
phase_data = phase.data if phase is not None else np.zeros_like(amplitude)
F = amplitude * np.exp(1j * phase_data)
spatial = np.fft.ifft2(np.fft.ifftshift(F)).real
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
z_unit = self._recover_z_unit(spectrum, representation, phase)
out_field = DataField(
data=spatial,
xreal=xreal,
yreal=yreal,
si_unit_xy="m",
si_unit_z=z_unit,
domain="spatial",
colormap=spectrum.colormap,
)
return (out_field,)
def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray:
data = np.asarray(spectrum.data, dtype=np.float64)
if representation == "magnitude":
return np.clip(data, 0.0, None)
if representation == "log_magnitude":
return np.expm1(data)
if representation == "psdf":
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
n = spectrum.xres * spectrum.yres
dx = xreal / spectrum.xres
dy = yreal / spectrum.yres
scale = n * 4.0 * np.pi ** 2 / (dx * dy)
return np.sqrt(np.clip(data, 0.0, None) * scale)
raise ValueError(f"Unsupported spectrum representation: {representation}")
def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]:
if representation == "psdf":
xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal
yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal
else:
xreal = spectrum.xres / spectrum.xreal
yreal = spectrum.yres / spectrum.yreal
return float(xreal), float(yreal)
def _recover_z_unit(
self,
spectrum: DataField,
representation: str,
phase: DataField | None,
) -> str:
if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip():
return phase.si_unit_z
if representation != "psdf":
return spectrum.si_unit_z
unit = str(spectrum.si_unit_z or "").strip()
if unit.startswith("(") and ")^2 m^2" in unit:
return unit.split(")^2 m^2", 1)[0][1:]
if unit.endswith("^2 m^2"):
return unit[:-6].removesuffix("^2").strip()
return ""

View File

@@ -1,721 +0,0 @@
"""
I/O nodes: load and save images and SPM data.
"""
from __future__ import annotations
import os
import re
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import COLORMAPS, DataField, encode_preview, image_to_uint8, resolve_colormap_input
from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
# ---------------------------------------------------------------------------
# Channel listing helper (used by the /channels endpoint)
# ---------------------------------------------------------------------------
def _resolve_path(filepath: str) -> Path:
path = Path(filepath)
if path.is_absolute():
return path
# Try input dir first, then demo dir
candidate = INPUT_DIR / filepath
if candidate.exists():
return candidate
candidate = DEMO_DIR / filepath
if candidate.exists():
return candidate
# Fall back to input dir (will trigger FileNotFoundError later)
return INPUT_DIR / filepath
def list_channels(filepath: str) -> list[dict]:
"""Return available channel info for a file.
Returns a list of {"name": str, "type": "DATA_FIELD"} dicts.
For SPM formats this inspects the file header.
For images / arrays, returns a single unnamed channel.
"""
path = _resolve_path(filepath)
if not path.exists():
return [{"name": "field", "type": "DATA_FIELD"}]
ext = path.suffix.lower()
if ext == ".gwy":
try:
import gwyfile
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if channels:
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".sxm":
try:
import nanonispy as nap
sxm = nap.read.Scan(str(path))
if sxm.signals:
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".ibw":
try:
from igor.binarywave import load as load_ibw
wave = load_ibw(str(path))
raw = wave["wave"]["wData"]
labels = wave["wave"].get("labels", None)
if raw.ndim >= 3 and labels:
dim_idx = min(2, len(labels) - 1)
if dim_idx >= 0 and labels[dim_idx]:
decoded = []
for lbl in labels[dim_idx]:
if lbl:
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
if name:
decoded.append(name)
if decoded:
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
# Multi-channel without labels — use numeric names
if raw.ndim >= 3 and raw.shape[2] > 1:
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
# Image or array — single channel
return [{"name": "field", "type": "DATA_FIELD"}]
def list_folder_paths(folderpath: str) -> list[dict]:
"""Return a folder DIRECTORY plus compatible image/array/SPM FILE_PATH outputs."""
path = _resolve_path(folderpath)
if not path.exists() or not path.is_dir():
return []
resolved_dir = str(path.resolve())
results = [{"name": "directory", "type": "DIRECTORY", "path": resolved_dir}]
for entry in sorted(path.iterdir(), key=lambda p: p.name.lower()):
if not entry.is_file() or entry.name.startswith("."):
continue
if entry.suffix.lower() not in _PATH_COMPATIBLE_EXTENSIONS:
continue
results.append({"name": entry.name, "type": "FILE_PATH", "path": str(entry.resolve())})
return results
# ---------------------------------------------------------------------------
# Image (unified loader — replaces LoadImage + LoadSPM)
# ---------------------------------------------------------------------------
@register_node(display_name="Image")
class Image:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": "", "hide_when_input_connected": "path"}),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"path": ("FILE_PATH", {"label": "path"}),
},
}
# Default outputs — overridden dynamically by the frontend for multi-channel files
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = (
"Load any supported file. "
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
"each channel gets its own output. "
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
)
# Set by execution engine for warning broadcast
_broadcast_warning_fn = None
_current_node_id = None
def load(self, filename: str = "", colormap: str = "viridis", colormap_map=None, path: str | None = None):
selected_path = str(path).strip() if path is not None else str(filename).strip()
if not selected_path:
raise ValueError("No file selected — use Browse to pick a file.")
path_obj = _resolve_path(selected_path)
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {path_obj}")
if path_obj.is_dir():
raise IsADirectoryError(f"Expected a file, got a directory: {path_obj}")
ext = path_obj.suffix.lower()
resolved_colormap = resolve_colormap_input(colormap, colormap_input=colormap_map, default="viridis")
if ext in _SPM_EXTENSIONS:
fields = self._load_spm_all(path_obj, ext)
for f in fields:
f.colormap = resolved_colormap
return tuple(fields)
# Image or array — uncalibrated, single output
field = self._load_image_or_array(path_obj, ext)
field.colormap = resolved_colormap
self._send_warning("Uncalibrated data — no physical dimensions.")
return (field,)
def _send_warning(self, message: str):
fn = Image._broadcast_warning_fn
nid = Image._current_node_id
if fn and nid:
fn(nid, message)
# -- SPM: load all channels ---------------------------------------------
def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
if ext == ".gwy":
return self._load_gwy_all(path)
elif ext == ".sxm":
return self._load_sxm_all(path)
elif ext == ".ibw":
return self._load_ibw_all(path)
else:
raise ValueError(f"Unsupported SPM format: {ext}")
# -- GWY ----------------------------------------------------------------
def _load_gwy_all(self, path: Path) -> list[DataField]:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
fields = []
for ch in channels.values():
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
fields.append(DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
))
return fields
# -- SXM ----------------------------------------------------------------
def _load_sxm_all(self, path: Path) -> list[DataField]:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
if not signals:
raise ValueError(f"No signals found in {path.name}")
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
fields = []
for sig in signals.values():
data = sig.get("forward", list(sig.values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
fields.append(DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
))
return fields
# -- IBW ----------------------------------------------------------------
def _load_ibw_all(self, path: Path) -> list[DataField]:
try:
from igor.binarywave import load as load_ibw
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
wave = load_ibw(str(path))
wdata = wave["wave"]
header = wdata["wave_header"]
raw = wdata["wData"]
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
# Physical scaling
sfA = header.get("sfA", None)
def _decode_unit(raw_unit):
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
dim_units_raw = header.get("dimUnits", None)
data_units_raw = header.get("dataUnits", None)
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
si_unit_xy = _decode_unit(dim_units_raw[0])
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
si_unit_xy = _decode_unit(dim_units_raw[0])
else:
si_unit_xy = _decode_unit(dim_units_raw)
si_unit_z = _decode_unit(data_units_raw)
fields = []
for ch_idx in range(n_channels):
if raw.ndim >= 3:
ch_data = raw[:, :, ch_idx]
elif raw.ndim == 1:
ch_data = raw.reshape(-1, 1)
else:
ch_data = raw
# Transpose from (xres, yres) Igor order to (yres, xres) DataField order,
# then flip vertically to match gwyddion
data = np.flipud(ch_data.T).astype(np.float64)
yres, xres = data.shape
if sfA is not None and len(sfA) >= 2:
xreal = abs(float(sfA[0]) * xres) or 1e-6
yreal = abs(float(sfA[1]) * yres) or 1e-6
else:
hsA = header.get("hsA", 0.0)
xreal = abs(float(hsA) * xres) or 1e-6
yreal = xreal * (yres / xres) if xres else 1e-6
fields.append(DataField(
data=data, xreal=xreal, yreal=yreal,
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
))
return fields
# -- Image / array (uncalibrated) --------------------------------------
def _load_image_or_array(self, path: Path, ext: str) -> DataField:
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image
img = Image.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
return DataField(data=gray)
# ---------------------------------------------------------------------------
# ImageDemo
# ---------------------------------------------------------------------------
def _list_demo_files() -> list[str]:
"""Return sorted list of demo filenames available in the demo/ directory."""
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
)
@register_node(display_name="Image (Demo)")
class ImageDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo files found)"]
return {
"required": {
"name": (choices,),
"colormap": (list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
def load(self, name: str = "", colormap: str = "viridis", colormap_map=None):
loader = Image()
demo_path = DEMO_DIR / name
if not demo_path.exists():
raise FileNotFoundError(f"Demo file not found: {name}")
return loader.load(filename=str(demo_path), colormap=colormap, colormap_map=colormap_map)
@register_node(display_name="Folder")
class Folder:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("FOLDER_PICKER", {"default": "", "placement": "top"}),
}
}
RETURN_TYPES = ("DIRECTORY",)
RETURN_NAMES = ("directory",)
FUNCTION = "list_files"
DESCRIPTION = (
"Pick a folder and output its directory path plus one file socket per compatible image, array, or SPM file inside it. "
"Supported files include common images, .npy/.npz arrays, and .gwy/.sxm/.ibw scans."
)
def list_files(self, folder: str) -> tuple:
entries = list_folder_paths(folder)
if not entries:
return tuple()
return tuple(item["path"] for item in entries)
# ---------------------------------------------------------------------------
# Coordinate
# ---------------------------------------------------------------------------
@register_node(display_name="Coordinate")
class Coordinate:
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("COORD",)
RETURN_NAMES = ("point",)
FUNCTION = "process"
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
def process(self, x: float, y: float) -> tuple:
return ((float(x), float(y)),)
@register_node(display_name="Coordinate Pair")
class CoordinatePair:
"""Provide a pair of Coordinates, for drawing lines between markers, etc."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("COORD",),
"b": ("COORD",),
}
}
RETURN_TYPES = ("COORDPAIR",)
RETURN_NAMES = ("coord pair",)
FUNCTION = "process"
DESCRIPTION = "Output a pair of coordinates."
def process(self, a: tuple, b: tuple) -> tuple:
return ((a, b),)
# ---------------------------------------------------------------------------
# Number
# ---------------------------------------------------------------------------
@register_node(display_name="Number")
class Number:
"""Provide a fixed scalar value that can feed FLOAT or INT widget sockets."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "step": 0.01}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Output a fixed numeric value. "
"When connected to FLOAT inputs the exact value is used; "
"INT inputs round to the nearest integer at execution time."
)
def process(self, value: float) -> tuple:
return (float(value),)
# ---------------------------------------------------------------------------
# RangeSlider
# ---------------------------------------------------------------------------
@register_node(display_name="Float Slider")
class RangeSlider:
"""Interactive float control node with min/max bounds and a slider value."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"min_value": ("FLOAT", {"default": 0.0, "step": 0.01}),
"max_value": ("FLOAT", {"default": 1.0, "step": 0.01}),
"value": ("FLOAT", {
"default": 0.5,
"step": 0.01,
"slider": True,
"min_widget": "min_value",
"max_widget": "max_value",
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value."
)
def process(self, min_value: float, max_value: float, value: float) -> tuple:
lo = min(float(min_value), float(max_value))
hi = max(float(min_value), float(max_value))
if hi == lo:
return (lo,)
return (float(np.clip(float(value), lo, hi)),)
# ---------------------------------------------------------------------------
# SaveImage
# ---------------------------------------------------------------------------
_MAX_SAVE_FIELDS = 8
@register_node(display_name="Save Layers")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {
"directory": ("DIRECTORY", {"label": "directory"}),
}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
optional[f"layer_name_{i}"] = ("STRING", {
"default": "",
"placeholder": "name",
"show_when_input_visible": f"field_{i}",
"inline_with_input": f"field_{i}",
"hide_label": True,
})
return {
"required": {
"filename": ("STRING", {
"default": "",
"placeholder": "filename",
"placement": "top",
}),
"directory_path": ("FOLDER_PICKER", {
"default": "",
"label": "directory",
"placement": "top",
"hide_when_input_connected": "directory",
"top_socket_input": "directory",
}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more layers to a single file. "
"Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. "
"Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. "
"A new slot appears as each one is filled, with a matching per-layer name field. "
"TIFF writes multi-page data and stores layer names as page descriptions; "
"NPZ writes named arrays using those layer names as keys. "
"Click Save to write (does not auto-run)."
)
_broadcast_warning_fn = None
_current_node_id = None
def save(
self,
filename: str,
directory_path: str = "",
format: str = "TIFF",
directory: str | None = None,
**kwargs,
):
layers = []
layer_names = []
for i in range(_MAX_SAVE_FIELDS):
layer = kwargs.get(f"field_{i}")
if layer is not None:
layers.append(layer)
layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i))
if not layers:
raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.")
path = self._resolve_save_path(filename, format, directory, directory_path)
if format == "TIFF":
self._save_tiff(path, layers, layer_names)
else:
self._save_npz(path, layers, layer_names)
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
import tifffile
with tifffile.TiffWriter(str(path)) as tif:
for layer, layer_name in zip(layers, layer_names):
tif.write(self._layer_array_for_tiff(layer), description=layer_name)
def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
arrays = {}
used_keys = set()
for i, (layer, layer_name) in enumerate(zip(layers, layer_names)):
arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer)
np.savez(str(path), **arrays)
def _resolve_layer_name(self, raw_name: object, index: int) -> str:
text = str(raw_name).strip() if raw_name is not None else ""
return text or f"layer_{index}"
def _resolve_save_path(
self,
filename: str,
format: str,
directory: str | None,
directory_path: str = "",
) -> Path:
ext = ".tiff" if format == "TIFF" else ".npz"
raw_filename = str(filename).strip() if filename is not None else ""
raw_directory = str(directory).strip() if directory is not None else ""
if not raw_directory:
raw_directory = str(directory_path).strip() if directory_path is not None else ""
if raw_directory:
dir_path = Path(raw_directory).expanduser()
if dir_path.exists() and not dir_path.is_dir():
raise ValueError("Directory input expects a folder path, not a file path.")
if not dir_path.exists():
if dir_path.suffix:
raise ValueError("Directory input expects a folder path, not a file path.")
dir_path.mkdir(parents=True, exist_ok=True)
filename_part = Path(raw_filename).name if raw_filename else ""
if not filename_part:
raise ValueError("No output filename selected — enter a file name when using a directory input.")
path = dir_path / filename_part
else:
if not raw_filename:
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(raw_filename).expanduser()
path.parent.mkdir(parents=True, exist_ok=True)
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
return path
def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str:
key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_")
if not key:
key = f"layer_{index}"
if key[0].isdigit():
key = f"layer_{key}"
candidate = key
suffix = 2
while candidate in used_keys:
candidate = f"{key}_{suffix}"
suffix += 1
used_keys.add(candidate)
return candidate
def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data, dtype=np.float32)
if isinstance(layer, np.ndarray):
return image_to_uint8(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data)
if isinstance(layer, np.ndarray):
return np.asarray(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return ()

View File

@@ -1,150 +0,0 @@
"""
Leveling nodes — background removal and zero correction.
Gwyddion equivalents:
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
FixZero → fix_zero in level.c
Plane-fit algorithm follows Gwyddion's level.h definition:
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# PlaneLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Plane Level")
class PlaneLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("leveled",)
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a least-squares plane from the data. "
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
)
def process(self, field: DataField) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Normalised coordinate grids in [0, 1]
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Design matrix: [1, x, y] shape (N, 3)
A = np.column_stack([
np.ones(xres * yres),
xx.ravel(),
yy.ravel(),
])
z = data.ravel()
# Least-squares: solve A @ [pa, pbx, pby] = z
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
pa, pbx, pby = coeffs
plane = (pa + pbx * xx + pby * yy)
return (field.replace(data=data - plane),)
# ---------------------------------------------------------------------------
# PolyLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Polynomial Level")
class PolyLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("leveled", "background")
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a polynomial background of given degree in x and y. "
"Equivalent to gwy_data_field_fit_polynom."
)
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Build Vandermonde-style design matrix with all monomials x^i * y^j
cols = []
for i in range(degree_x + 1):
for j in range(degree_y + 1):
cols.append((xx ** i * yy ** j).ravel())
A = np.column_stack(cols)
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
background = (A @ coeffs).reshape(yres, xres)
leveled = data - background
return (field.replace(data=leveled), field.replace(data=background))
# ---------------------------------------------------------------------------
# FixZero
# ---------------------------------------------------------------------------
@register_node(display_name="Fix Zero")
class FixZero:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["min", "mean", "median"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("zeroed",)
FUNCTION = "process"
DESCRIPTION = (
"Shift data so that the minimum (or mean/median) is zero. "
"Equivalent to fix_zero in Gwyddion's level.c."
)
def process(self, field: DataField, method: str) -> tuple:
data = field.data.copy()
if method == "min":
data -= data.min()
elif method == "mean":
data -= data.mean()
elif method == "median":
data -= np.median(data)
else:
raise ValueError(f"Unknown method: {method}")
return (field.replace(data=data),)

68
backend/nodes/markup.py Normal file
View File

@@ -0,0 +1,68 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.helpers import _parse_markup_shapes, _normalize_markup_color
@register_node(display_name="Markup")
class Markup:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "line"}),
"stroke_color": ("STRING", {"default": "#ffd54f", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
"clear_shapes": ("BUTTON", {"label": "Clear Shapes", "set_widgets": {"markup_shapes": "[]"}}),
"markup_shapes": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("annotated",)
FUNCTION = "process"
DESCRIPTION = (
"Draw simple vector markup over a DATA_FIELD without flattening the underlying data. "
"Choose a shape mode, colour, and stroke width, then drag directly on the preview to place lines, rectangles, circles, or arrows."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
shape: str,
stroke_color: str,
stroke_width: int,
markup_shapes: str,
) -> tuple:
shapes = _parse_markup_shapes(markup_shapes)
out = field.replace(
overlays=[
*field.overlays,
{
"kind": "markup",
"shapes": shapes,
},
],
)
if Markup._broadcast_overlay_fn is not None:
Markup._broadcast_overlay_fn(
Markup._current_node_id,
{
"kind": "markup",
"section_title": "Markup",
"image": encode_preview(datafield_to_uint8(field, field.colormap)),
"shape": str(shape),
"stroke_color": _normalize_markup_color(stroke_color),
"stroke_width": max(1, int(stroke_width)),
},
)
return (out,)

View File

@@ -1,437 +0,0 @@
"""
Mask operation nodes — creation, morphology, and boolean combination.
Gwyddion equivalents:
ThresholdMask → threshold.c / otsu_threshold.c
MaskMorphology → mask_morph.c (erode, dilate, open, close)
MaskInvert → (bitwise NOT on mask)
MaskCombine → (boolean ops between two masks)
"""
from __future__ import annotations
from functools import lru_cache
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
"""Render greyscale base image with red shadow on masked (255) pixels.
Returns (H, W, 3) uint8 array.
"""
grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8
mask_bool = mask > 127
if not np.any(mask_bool):
return grey
overlay = grey.copy()
red = overlay[..., 0]
green = overlay[..., 1]
blue = overlay[..., 2]
# Integer alpha blend equivalent to a 45% red overlay, without float64 work.
red_vals = red[mask_bool].astype(np.uint16)
green_vals = green[mask_bool].astype(np.uint16)
blue_vals = blue[mask_bool].astype(np.uint16)
red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100
green[mask_bool] = ((green_vals * 55) + 50) // 100
blue[mask_bool] = ((blue_vals * 55) + 50) // 100
return overlay
@lru_cache(maxsize=128)
def _mask_structure(radius: int, shape: str) -> np.ndarray:
radius = max(1, int(radius))
if shape == "disk":
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
struct = (x * x + y * y) <= radius * radius
else:
size = 2 * radius + 1
struct = np.ones((size, size), dtype=bool)
struct.setflags(write=False)
return struct
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray:
from PIL import Image, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# DrawMask
# ---------------------------------------------------------------------------
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------
@register_node(display_name="Threshold Mask")
class ThresholdMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
# threshold is a fraction [0, 1] of the data range
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None:
overlay = _mask_overlay(field, mask)
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
return (mask,)
# ---------------------------------------------------------------------------
# MaskMorphology
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Morphology")
class MaskMorphology:
"""Morphological operations on binary masks.
Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
"""
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
"operation": (["dilate", "erode", "open", "close"],),
"radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
"shape": (["disk", "square"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Apply morphological operations to a binary mask. "
"Dilate expands regions, erode shrinks them, "
"open (erode then dilate) removes small spots, "
"close (dilate then erode) fills small holes. "
"Equivalent to Gwyddion mask_morph."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
field: DataField | None = None) -> tuple:
from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening
binary = mask > 127
struct = _mask_structure(radius, shape)
if operation == "dilate":
result = binary_dilation(binary, structure=struct)
elif operation == "erode":
result = binary_erosion(binary, structure=struct)
elif operation == "open":
result = binary_opening(binary, structure=struct)
elif operation == "close":
result = binary_closing(binary, structure=struct)
else:
raise ValueError(f"Unknown morphological operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn(
MaskMorphology._current_node_id, encode_preview(overlay),
)
return (out,)
# ---------------------------------------------------------------------------
# MaskInvert
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Invert")
class MaskInvert:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn(
MaskInvert._current_node_id, encode_preview(overlay),
)
return (out,)
# ---------------------------------------------------------------------------
# MaskCombine
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Combine")
class MaskCombine:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask_a": ("IMAGE",),
"mask_b": ("IMAGE",),
"operation": (["and", "or", "xor", "subtract"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Combine two binary masks with a boolean operation. "
"AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
"subtract removes mask_b from mask_a."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
field: DataField | None = None) -> tuple:
a = mask_a > 127
b = mask_b > 127
if operation == "and":
result = a & b
elif operation == "or":
result = a | b
elif operation == "xor":
result = a ^ b
elif operation == "subtract":
result = a & ~b
else:
raise ValueError(f"Unknown mask operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn(
MaskCombine._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@register_node(display_name="Mask Combine")
class MaskCombine:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask_a": ("IMAGE",),
"mask_b": ("IMAGE",),
"operation": (["and", "or", "xor", "subtract"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Combine two binary masks with a boolean operation. "
"AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
"subtract removes mask_b from mask_a."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
field: DataField | None = None) -> tuple:
a = mask_a > 127
b = mask_b > 127
if operation == "and":
result = a & b
elif operation == "or":
result = a | b
elif operation == "xor":
result = a ^ b
elif operation == "subtract":
result = a & ~b
else:
raise ValueError(f"Unknown mask operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn(
MaskCombine._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@register_node(display_name="Mask Invert")
class MaskInvert:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn(
MaskInvert._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay, _mask_structure
@register_node(display_name="Mask Morphology")
class MaskMorphology:
"""Morphological operations on binary masks.
Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
"""
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
"operation": (["dilate", "erode", "open", "close"],),
"radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
"shape": (["disk", "square"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Apply morphological operations to a binary mask. "
"Dilate expands regions, erode shrinks them, "
"open (erode then dilate) removes small spots, "
"close (dilate then erode) fills small holes. "
"Equivalent to Gwyddion mask_morph."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
field: DataField | None = None) -> tuple:
from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening
binary = mask > 127
struct = _mask_structure(radius, shape)
if operation == "dilate":
result = binary_dilation(binary, structure=struct)
elif operation == "erode":
result = binary_erosion(binary, structure=struct)
elif operation == "open":
result = binary_opening(binary, structure=struct)
elif operation == "close":
result = binary_closing(binary, structure=struct)
else:
raise ValueError(f"Unknown morphological operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn(
MaskMorphology._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Median Filter")
class MedianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data, size=size)
return (field.replace(data=data),)

28
backend/nodes/number.py Normal file
View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Number")
class Number:
"""Provide a fixed scalar value that can feed FLOAT or INT widget sockets."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "step": 0.01}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Output a fixed numeric value. "
"When connected to FLOAT inputs the exact value is used; "
"INT inputs round to the nearest integer at execution time."
)
def process(self, value: float) -> tuple:
return (float(value),)

View File

@@ -1,20 +1,9 @@
"""
Particle detection nodes.
Gwyddion equivalents:
ParticleAnalysis gwy_data_field_particles_get_values (particles-values.c)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, RecordTable
# ---------------------------------------------------------------------------
# ParticleAnalysis
# ---------------------------------------------------------------------------
@register_node(display_name="Particle Analysis")
class ParticleAnalysis:
@classmethod
@@ -43,7 +32,7 @@ class ParticleAnalysis:
binary = (mask > 127).astype(np.int32)
labeled, n_particles = label(binary)
pixel_area = field.dx * field.dy # m^2 per pixel
pixel_area = field.dx * field.dy
rows = RecordTable()
for pid in range(1, n_particles + 1):
@@ -59,7 +48,6 @@ class ParticleAnalysis:
mean_h = float(heights.mean())
max_h = float(heights.max())
# Bounding box
ys, xs = np.where(particle_pixels)
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Plane Level")
class PlaneLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("leveled",)
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a least-squares plane from the data. "
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
)
def process(self, field: DataField) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
A = np.column_stack([
np.ones(xres * yres),
xx.ravel(),
yy.ravel(),
])
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
pa, pbx, pby = coeffs
plane = (pa + pbx * xx + pby * yy)
return (field.replace(data=data - plane),)

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Polynomial Level")
class PolyLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("leveled", "background")
FUNCTION = "process"
DESCRIPTION = (
"Fit and subtract a polynomial background of given degree in x and y. "
"Equivalent to gwy_data_field_fit_polynom."
)
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
cols = []
for i in range(degree_x + 1):
for j in range(degree_y + 1):
cols.append((xx ** i * yy ** j).ravel())
A = np.column_stack(cols)
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
background = (A @ coeffs).reshape(yres, xres)
leveled = data - background
return (field.replace(data=leveled), field.replace(data=background))

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
COLORMAPS,
colormap_to_uint8,
encode_preview,
image_to_uint8,
render_datafield_preview,
resolve_colormap_input,
)
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ()
FUNCTION = "preview"
OUTPUT_NODE = True
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
_broadcast_fn = None
_current_node_id: str = ""
def preview(
self,
colormap: str,
image: np.ndarray | None = None,
field=None,
colormap_map=None,
) -> tuple:
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap if field is not None else None,
default="gray",
)
if field is not None:
arr_u8 = render_datafield_preview(field, resolved_colormap)
elif image is not None:
arr_u8 = image_to_uint8(image)
if arr_u8.ndim == 2:
if image.dtype == np.uint8:
normalized = arr_u8.astype(np.float64) / 255.0
else:
imin, imax = image.min(), image.max()
if imax > imin:
normalized = (image - imin) / (imax - imin)
else:
normalized = np.zeros_like(image, dtype=np.float64)
arr_u8 = colormap_to_uint8(normalized, resolved_colormap)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return ()

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from backend.node_registry import register_node
@register_node(display_name="Print Table")
class PrintTable:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("ANY_TABLE",),
}
}
RETURN_TYPES = ()
FUNCTION = "print_table"
OUTPUT_NODE = True
DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display."
_broadcast_table_fn = None
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
@register_node(display_name="Float Slider")
class RangeSlider:
"""Interactive float control node with min/max bounds and a slider value."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"min_value": ("FLOAT", {"default": 0.0, "step": 0.01}),
"max_value": ("FLOAT", {"default": 1.0, "step": 0.01}),
"value": ("FLOAT", {
"default": 0.5,
"step": 0.01,
"slider": True,
"min_widget": "min_value",
"max_widget": "max_value",
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value."
)
def process(self, min_value: float, max_value: float, value: float) -> tuple:
lo = min(float(min_value), float(max_value))
hi = max(float(min_value), float(max_value))
if hi == lo:
return (lo,)
return (float(np.clip(float(value), lo, hi)),)

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Rotate")
class RotateField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"angle": ("FLOAT", {"default": 90.0, "min": -360.0, "max": 360.0, "step": 1.0}),
"interpolation": (["bilinear", "nearest", "bicubic"],),
"expand_canvas": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
DESCRIPTION = (
"Rotate a DATA_FIELD counterclockwise by an angle in degrees. "
"Optionally expand the canvas to keep the full rotated field while preserving the field center."
)
_broadcast_warning_fn = None
_current_node_id: str = ""
def process(
self,
field: DataField,
angle: float,
interpolation: str,
expand_canvas: bool,
) -> tuple:
if field.overlays:
self._send_warning("Rotate clears annotation/markup overlays!")
angle = float(angle)
order_map = {
"nearest": 0,
"bilinear": 1,
"bicubic": 3,
}
if interpolation not in order_map:
raise ValueError(f"Unknown interpolation mode: {interpolation}")
normalized_angle = angle % 360.0
snapped_quarters = int(round(normalized_angle / 90.0)) % 4
snapped_angle = snapped_quarters * 90.0
is_right_angle = abs(normalized_angle - snapped_angle) < 1e-9
if is_right_angle and expand_canvas:
rotated = np.rot90(field.data, k=snapped_quarters).copy()
elif abs(normalized_angle) < 1e-9:
rotated = field.data.copy()
else:
from scipy.ndimage import rotate as nd_rotate
rotated = nd_rotate(
field.data,
angle=angle,
reshape=bool(expand_canvas),
order=order_map[interpolation],
mode="nearest",
prefilter=order_map[interpolation] > 1,
)
new_xreal, new_yreal = self._rotated_extents(field, angle, expand_canvas)
center_x = field.xoff + field.xreal / 2.0
center_y = field.yoff + field.yreal / 2.0
result = field.replace(
data=np.asarray(rotated, dtype=np.float64),
xreal=new_xreal,
yreal=new_yreal,
xoff=center_x - new_xreal / 2.0,
yoff=center_y - new_yreal / 2.0,
overlays=[],
)
return (result,)
def _send_warning(self, message: str):
fn = RotateField._broadcast_warning_fn
nid = RotateField._current_node_id
if fn and nid:
fn(nid, message)
@staticmethod
def _rotated_extents(field: DataField, angle: float, expand_canvas: bool) -> tuple[float, float]:
if not expand_canvas:
return (field.xreal, field.yreal)
theta = np.deg2rad(angle)
cos_t = abs(float(np.cos(theta)))
sin_t = abs(float(np.sin(theta)))
new_xreal = field.xreal * cos_t + field.yreal * sin_t
new_yreal = field.xreal * sin_t + field.yreal * cos_t
return (new_xreal, new_yreal)

182
backend/nodes/save_image.py Normal file
View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import re
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, image_to_uint8
from backend.nodes.helpers import _MAX_SAVE_FIELDS
@register_node(display_name="Save Layers")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {
"directory": ("DIRECTORY", {"label": "directory"}),
}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
optional[f"layer_name_{i}"] = ("STRING", {
"default": "",
"placeholder": "name",
"show_when_input_visible": f"field_{i}",
"inline_with_input": f"field_{i}",
"hide_label": True,
})
return {
"required": {
"filename": ("STRING", {
"default": "",
"placeholder": "filename",
"placement": "top",
}),
"directory_path": ("FOLDER_PICKER", {
"default": "",
"label": "directory",
"placement": "top",
"hide_when_input_connected": "directory",
"top_socket_input": "directory",
}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more layers to a single file. "
"Each layer input accepts either a DATA_FIELD or an IMAGE, including annotated images. "
"Optionally drive the output directory from a folder/path node, while keeping the filename widget for the file name. "
"A new slot appears as each one is filled, with a matching per-layer name field. "
"TIFF writes multi-page data and stores layer names as page descriptions; "
"NPZ writes named arrays using those layer names as keys. "
"Click Save to write (does not auto-run)."
)
_broadcast_warning_fn = None
_current_node_id = None
def save(
self,
filename: str,
directory_path: str = "",
format: str = "TIFF",
directory: str | None = None,
**kwargs,
):
layers = []
layer_names = []
for i in range(_MAX_SAVE_FIELDS):
layer = kwargs.get(f"field_{i}")
if layer is not None:
layers.append(layer)
layer_names.append(self._resolve_layer_name(kwargs.get(f"layer_name_{i}"), i))
if not layers:
raise ValueError("No layers connected — connect at least one DATA_FIELD or IMAGE input.")
path = self._resolve_save_path(filename, format, directory, directory_path)
if format == "TIFF":
self._save_tiff(path, layers, layer_names)
else:
self._save_npz(path, layers, layer_names)
self._send_warning(f"Saved {len(layers)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
import tifffile
with tifffile.TiffWriter(str(path)) as tif:
for layer, layer_name in zip(layers, layer_names):
tif.write(self._layer_array_for_tiff(layer), description=layer_name)
def _save_npz(self, path: Path, layers: list[DataField | np.ndarray], layer_names: list[str]):
arrays = {}
used_keys = set()
for i, (layer, layer_name) in enumerate(zip(layers, layer_names)):
arrays[self._unique_npz_key(layer_name, used_keys, i)] = self._layer_array_for_npz(layer)
np.savez(str(path), **arrays)
def _resolve_layer_name(self, raw_name: object, index: int) -> str:
text = str(raw_name).strip() if raw_name is not None else ""
return text or f"layer_{index}"
def _resolve_save_path(
self,
filename: str,
format: str,
directory: str | None,
directory_path: str = "",
) -> Path:
ext = ".tiff" if format == "TIFF" else ".npz"
raw_filename = str(filename).strip() if filename is not None else ""
raw_directory = str(directory).strip() if directory is not None else ""
if not raw_directory:
raw_directory = str(directory_path).strip() if directory_path is not None else ""
if raw_directory:
dir_path = Path(raw_directory).expanduser()
if dir_path.exists() and not dir_path.is_dir():
raise ValueError("Directory input expects a folder path, not a file path.")
if not dir_path.exists():
if dir_path.suffix:
raise ValueError("Directory input expects a folder path, not a file path.")
dir_path.mkdir(parents=True, exist_ok=True)
filename_part = Path(raw_filename).name if raw_filename else ""
if not filename_part:
raise ValueError("No output filename selected — enter a file name when using a directory input.")
path = dir_path / filename_part
else:
if not raw_filename:
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(raw_filename).expanduser()
path.parent.mkdir(parents=True, exist_ok=True)
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
return path
def _unique_npz_key(self, raw_name: str, used_keys: set[str], index: int) -> str:
key = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw_name).strip()).strip("_")
if not key:
key = f"layer_{index}"
if key[0].isdigit():
key = f"layer_{key}"
candidate = key
suffix = 2
while candidate in used_keys:
candidate = f"{key}_{suffix}"
suffix += 1
used_keys.add(candidate)
return candidate
def _layer_array_for_tiff(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data, dtype=np.float32)
if isinstance(layer, np.ndarray):
return image_to_uint8(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _layer_array_for_npz(self, layer: DataField | np.ndarray) -> np.ndarray:
if isinstance(layer, DataField):
return np.asarray(layer.data)
if isinstance(layer, np.ndarray):
return np.asarray(layer)
raise ValueError(f"Unsupported save layer type: {type(layer).__name__}")
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return ()

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, MeasureTable
@register_node(display_name="Statistics")
class Statistics:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("MEASURE_TABLE",)
RETURN_NAMES = ("stats",)
FUNCTION = "process"
DESCRIPTION = (
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
)
def process(self, field: DataField) -> tuple:
d = field.data
mean = float(d.mean())
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
table = MeasureTable([
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
{"quantity": "skewness", "value": skewness, "unit": ""},
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
])
return (table,)

130
backend/nodes/stats.py Normal file
View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, LineData, MeasureTable
from backend.nodes.helpers import (
LINE_OPS,
TABLE_OPS,
ARRAY_OPS,
_scalar_payload,
_apply_scalar_unit,
_common_table_unit,
extract_numeric_table_values,
resolve_table_column_name,
)
@register_node(display_name="Stats")
class Stats:
"""Polymorphic scalar stats node for LINE, RECORD_TABLE, DATA_FIELD, or IMAGE inputs."""
_broadcast_value_fn = None
_current_node_id: str = ""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input": ("STATS_SOURCE",),
"column": ("STRING", {
"default": "value",
"choices_from_table_input": "input",
"show_when_source_type": {
"input": ["RECORD_TABLE"],
},
}),
"operation": ("STRING", {
"default": "mean",
"choices_by_source_type": {
"LINE": list(LINE_OPS.keys()),
"RECORD_TABLE": list(TABLE_OPS.keys()),
"DATA_FIELD": list(ARRAY_OPS.keys()),
"IMAGE": list(ARRAY_OPS.keys()),
},
"source_type_input": "input",
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
DESCRIPTION = (
"Compute a contextual scalar statistic from a LINE, record table, DATA_FIELD, or IMAGE. "
"The available operations adapt to the connected input type."
)
def process(self, input, operation: str, column: str = "value") -> tuple:
source_type, values, resolved_column = self._resolve_input_values(input, column)
if source_type == "RECORD_TABLE":
ops = TABLE_OPS
elif source_type == "LINE":
ops = LINE_OPS
else:
ops = ARRAY_OPS
if operation not in ops:
raise ValueError(f"Operation '{operation}' is not valid for {source_type} input.")
op_entry = ops[operation]
fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry
result = fn(values)
if Stats._broadcast_value_fn is not None:
Stats._broadcast_value_fn(
Stats._current_node_id,
_scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)),
)
return (result,)
def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str:
if source_type == "DATA_FIELD" and isinstance(input_value, DataField):
return _apply_scalar_unit(input_value.si_unit_z, operation)
if source_type == "LINE":
line_entry = LINE_OPS.get(operation)
explicit_unit = line_entry[1] if isinstance(line_entry, tuple) and len(line_entry) > 1 else ""
if explicit_unit:
return _apply_scalar_unit(explicit_unit, operation)
if isinstance(input_value, LineData):
return _apply_scalar_unit(input_value.y_unit, operation)
return ""
if source_type == "RECORD_TABLE" and isinstance(input_value, list) and column:
return _apply_scalar_unit(_common_table_unit(input_value, column), operation)
return ""
def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray, str | None]:
if isinstance(input_value, DataField):
values = np.asarray(input_value.data, dtype=np.float64)
return ("DATA_FIELD", values.ravel(), None)
if isinstance(input_value, MeasureTable):
raise ValueError("Stats only accepts record tables, not measurement tables.")
if isinstance(input_value, list):
if not input_value:
raise ValueError("Stats requires a non-empty record table input.")
column_name = resolve_table_column_name(input_value, column)
values = extract_numeric_table_values(input_value, column_name)
if not values:
raise ValueError(f"Column '{column_name}' has no numeric values.")
return ("RECORD_TABLE", np.asarray(values, dtype=np.float64), column_name)
if isinstance(input_value, LineData):
values = np.asarray(input_value.data, dtype=np.float64)
if values.size == 0:
raise ValueError("Stats requires a non-empty input.")
return ("LINE", values.ravel(), None)
if isinstance(input_value, np.ndarray):
values = np.asarray(input_value, dtype=np.float64)
if values.size == 0:
raise ValueError("Stats requires a non-empty input.")
if values.ndim == 1:
return ("LINE", values.ravel(), None)
return ("IMAGE", values.ravel(), None)
raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}")

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview
from backend.nodes.helpers import _mask_overlay
@register_node(display_name="Threshold Mask")
class ThresholdMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None:
overlay = _mask_overlay(field, mask)
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
return (mask,)

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from backend.node_registry import register_node
from backend.data_types import MeasureTable
from backend.nodes.helpers import _measurement_entry, _measurement_value, _scalar_payload
@register_node(display_name="Value Display")
class ValueDisplay:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("VALUE_SOURCE",),
"measurement": ("STRING", {
"default": "",
"choices_from_measure_input": "value",
"show_when_source_type": {
"value": ["MEASURE_TABLE"],
},
}),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "display_value"
DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged."
_broadcast_value_fn = None
_current_node_id: str = ""
def display_value(self, value, measurement: str = "") -> tuple:
unit = ""
if isinstance(value, MeasureTable):
row = _measurement_entry(value, measurement)
numeric = _measurement_value(value, measurement)
unit = row.get("unit", "") if isinstance(row.get("unit"), str) else ""
else:
numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None:
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit))
return (numeric,)

90
backend/nodes/view_3d.py Normal file
View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import (
COLORMAPS,
DataField,
colormap_to_uint8,
normalize_for_colormap,
resolve_colormap_input,
)
@register_node(display_name="3D View")
class View3D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
},
"optional": {
"colormap_map": ("COLORMAP", {"label": "colormap"}),
},
}
RETURN_TYPES = ()
FUNCTION = "render"
OUTPUT_NODE = True
DESCRIPTION = (
"Interactive 3D surface view of a DATA_FIELD. "
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
)
_broadcast_mesh_fn = None
_current_node_id: str = ""
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int, colormap_map=None,
) -> tuple:
import base64
data = field.data
yres, xres = data.shape
step_y = max(1, yres // resolution)
step_x = max(1, xres // resolution)
z = data[::step_y, ::step_x].astype(np.float32)
ny, nx = z.shape
zmin, zmax = float(z.min()), float(z.max())
z_norm = normalize_for_colormap(
z,
offset=field.display_offset,
scale=field.display_scale,
data_min=float(field.data.min()),
data_max=float(field.data.max()),
)
resolved_colormap = resolve_colormap_input(
colormap,
colormap_input=colormap_map,
inherited=field.colormap,
default="gray",
)
colors_u8 = colormap_to_uint8(z_norm, resolved_colormap)
z_b64 = base64.b64encode(z.tobytes()).decode()
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
mesh_data = {
"width": nx,
"height": ny,
"z_data": z_b64,
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
return ()