initial commit
This commit is contained in:
2
backend/nodes/__init__.py
Normal file
2
backend/nodes/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Import all node modules to trigger @register_node decorators.
|
||||
from . import io, filters, level, analysis, grains, display
|
||||
471
backend/nodes/analysis.py
Normal file
471
backend/nodes/analysis.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Analysis nodes — statistics, histograms, FFT, cross sections.
|
||||
|
||||
Gwyddion equivalents:
|
||||
StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h)
|
||||
HeightHistogram → DH (height distribution), gwy_data_field_dh
|
||||
FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf
|
||||
CrossSection → gwy_data_field_get_profile (libprocess/datafield.c)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StatisticsNode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Statistics")
|
||||
class StatisticsNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("stats",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
|
||||
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
|
||||
)
|
||||
|
||||
def process(self, field: DataField) -> tuple:
|
||||
d = field.data
|
||||
mean = float(d.mean())
|
||||
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
|
||||
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
|
||||
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
|
||||
|
||||
table = [
|
||||
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
|
||||
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
|
||||
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
|
||||
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
|
||||
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
|
||||
{"quantity": "skewness", "value": skewness, "unit": ""},
|
||||
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
|
||||
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
|
||||
]
|
||||
return (table,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HeightHistogram
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Height Histogram")
|
||||
class HeightHistogram:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LINE", "LINE")
|
||||
RETURN_NAMES = ("counts", "bin_centers")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute the height distribution histogram (DH). "
|
||||
"Equivalent to gwy_data_field_dh."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, n_bins: int) -> tuple:
|
||||
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
|
||||
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
|
||||
return (counts.astype(np.float64), bin_centers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FFT2D
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="2D FFT")
|
||||
class FFT2D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"windowing": (["hann", "hamming", "blackman", "none"],),
|
||||
"level": (["mean", "plane", "none"],),
|
||||
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("spectrum",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
|
||||
"Output can be log magnitude, magnitude, phase, or PSDF. "
|
||||
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
# Level subtraction (Gwyddion-style, before windowing)
|
||||
if level == "mean":
|
||||
data -= data.mean()
|
||||
elif level == "plane":
|
||||
# Fit and subtract a plane: z = a + b*x + c*y
|
||||
yy, xx = np.mgrid[0:yres, 0:xres]
|
||||
xx_f = xx.ravel().astype(np.float64)
|
||||
yy_f = yy.ravel().astype(np.float64)
|
||||
zz_f = data.ravel()
|
||||
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
|
||||
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
|
||||
data -= plane
|
||||
|
||||
# Windowing (Gwyddion uses (i+0.5)/n centred formulation)
|
||||
if windowing != "none":
|
||||
t_y = (np.arange(yres) + 0.5) / yres
|
||||
t_x = (np.arange(xres) + 0.5) / xres
|
||||
if windowing == "hann":
|
||||
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
|
||||
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
|
||||
elif windowing == "hamming":
|
||||
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
|
||||
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
|
||||
elif windowing == "blackman":
|
||||
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
|
||||
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
|
||||
else:
|
||||
wy = np.ones(yres)
|
||||
wx = np.ones(xres)
|
||||
data *= np.outer(wy, wx)
|
||||
|
||||
# 2D FFT, shifted so DC is at centre
|
||||
F = np.fft.fftshift(np.fft.fft2(data))
|
||||
n = xres * yres
|
||||
|
||||
if output == "log_magnitude":
|
||||
mag = np.abs(F)
|
||||
# Log scale with floor to avoid log(0)
|
||||
result = np.log1p(mag)
|
||||
elif output == "magnitude":
|
||||
result = np.abs(F)
|
||||
elif output == "phase":
|
||||
result = np.angle(F)
|
||||
elif output == "psdf":
|
||||
# Gwyddion-equivalent PSDF: |F|^2 * dx * dy / (n * 4π²)
|
||||
dx = field.xreal / xres
|
||||
dy = field.yreal / yres
|
||||
result = (np.abs(F) ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
|
||||
else:
|
||||
result = np.abs(F)
|
||||
|
||||
# Calibrate the output field in spatial-frequency units
|
||||
if output == "psdf":
|
||||
# Gwyddion uses angular frequency: 2π/dx, 2π/dy
|
||||
freq_xreal = 2.0 * np.pi * xres / field.xreal
|
||||
freq_yreal = 2.0 * np.pi * yres / field.yreal
|
||||
z_unit = f"({field.si_unit_z})^2 m^2"
|
||||
else:
|
||||
freq_xreal = xres / field.xreal
|
||||
freq_yreal = yres / field.yreal
|
||||
z_unit = field.si_unit_z
|
||||
|
||||
out_field = DataField(
|
||||
data=result,
|
||||
xreal=freq_xreal,
|
||||
yreal=freq_yreal,
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=z_unit,
|
||||
domain="frequency",
|
||||
)
|
||||
return (out_field,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CrossSection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extend_to_edges(x1, y1, x2, y2):
|
||||
"""
|
||||
Extend the line through (x1,y1)-(x2,y2) to the boundaries of [0,1]x[0,1].
|
||||
Returns the two intersection points (clipped to the unit square).
|
||||
"""
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Collect parametric t values where line hits each boundary
|
||||
t_candidates = []
|
||||
if abs(dx) > 1e-12:
|
||||
for bx in (0.0, 1.0):
|
||||
t = (bx - x1) / dx
|
||||
y_at_t = y1 + t * dy
|
||||
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
|
||||
t_candidates.append(t)
|
||||
if abs(dy) > 1e-12:
|
||||
for by in (0.0, 1.0):
|
||||
t = (by - y1) / dy
|
||||
x_at_t = x1 + t * dx
|
||||
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
|
||||
t_candidates.append(t)
|
||||
|
||||
if len(t_candidates) < 2:
|
||||
return x1, y1, x2, y2
|
||||
|
||||
t_min = min(t_candidates)
|
||||
t_max = max(t_candidates)
|
||||
|
||||
return (
|
||||
np.clip(x1 + t_min * dx, 0, 1),
|
||||
np.clip(y1 + t_min * dy, 0, 1),
|
||||
np.clip(x1 + t_max * dx, 0, 1),
|
||||
np.clip(y1 + t_max * dy, 0, 1),
|
||||
)
|
||||
|
||||
|
||||
@register_node(display_name="Cross Section")
|
||||
class CrossSection:
|
||||
"""Extract a 1-D height profile along an arbitrary line across the image."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"extend": (["none", "to_edges"],),
|
||||
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"point_a": ("COORD",),
|
||||
"point_b": ("COORD",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LINE",)
|
||||
RETURN_NAMES = ("profile",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Extract a cross-section profile along a line between two points. "
|
||||
"Drag the markers on the image to set the line endpoints. "
|
||||
"Equivalent to gwy_data_field_get_profile."
|
||||
)
|
||||
|
||||
_broadcast_overlay_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(
|
||||
self, field: DataField,
|
||||
x1: float, y1: float, x2: float, y2: float,
|
||||
extend: str, n_samples: int,
|
||||
point_a=None, point_b=None,
|
||||
) -> tuple:
|
||||
from scipy.ndimage import map_coordinates
|
||||
import io, base64
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
# COORD inputs override widget values
|
||||
if point_a is not None:
|
||||
x1, y1 = float(point_a[0]), float(point_a[1])
|
||||
if point_b is not None:
|
||||
x2, y2 = float(point_b[0]), float(point_b[1])
|
||||
|
||||
# Remember marker positions (before extend)
|
||||
marker_x1, marker_y1 = float(x1), float(y1)
|
||||
marker_x2, marker_y2 = float(x2), float(y2)
|
||||
|
||||
xres, yres = field.xres, field.yres
|
||||
|
||||
if extend == "to_edges":
|
||||
x1, y1, x2, y2 = _extend_to_edges(
|
||||
float(x1), float(y1), float(x2), float(y2),
|
||||
)
|
||||
|
||||
# Convert fractional [0,1] to pixel indices [0, res-1]
|
||||
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
|
||||
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
|
||||
|
||||
# Number of sample points
|
||||
line_len_px = np.hypot(px2 - px1, py2 - py1)
|
||||
if n_samples <= 0:
|
||||
n_samples = max(2, int(np.ceil(line_len_px)))
|
||||
|
||||
# Sample coordinates along the line
|
||||
t = np.linspace(0, 1, n_samples)
|
||||
coords_y = py1 + t * (py2 - py1)
|
||||
coords_x = px1 + t * (px2 - px1)
|
||||
|
||||
# Interpolate values along the line (cubic spline)
|
||||
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
|
||||
|
||||
# Broadcast overlay image with marker positions
|
||||
if CrossSection._broadcast_overlay_fn is not None:
|
||||
fig = Figure(figsize=(3, 3), dpi=100)
|
||||
ax = fig.add_axes([0, 0, 1, 1])
|
||||
ax.imshow(field.data, cmap="viridis", aspect="auto")
|
||||
ax.axis("off")
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
||||
buf.seek(0)
|
||||
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
|
||||
|
||||
CrossSection._broadcast_overlay_fn(
|
||||
CrossSection._current_node_id,
|
||||
{
|
||||
"image": image_uri,
|
||||
"x1": marker_x1, "y1": marker_y1,
|
||||
"x2": marker_x2, "y2": marker_y2,
|
||||
"a_locked": point_a is not None,
|
||||
"b_locked": point_b is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return (profile.astype(np.float64),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LineMath — single scalar measurement from a LINE profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_rq(d):
|
||||
"""RMS of deviations from mean."""
|
||||
return float(np.sqrt(np.mean(d * d)))
|
||||
|
||||
# Registry: name → (function(z) → float, unit_label)
|
||||
# All functions receive the raw 1-D profile as float64.
|
||||
LINE_OPS: dict[str, tuple] = {}
|
||||
|
||||
|
||||
def _line_op(name, unit=""):
|
||||
"""Decorator to register a LINE operation."""
|
||||
def decorator(fn):
|
||||
LINE_OPS[name] = (fn, unit)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
|
||||
# ── Basic statistics ──────────────────────────────────────────────────────
|
||||
|
||||
@_line_op("min")
|
||||
def _op_min(z):
|
||||
return float(z.min())
|
||||
|
||||
@_line_op("max")
|
||||
def _op_max(z):
|
||||
return float(z.max())
|
||||
|
||||
@_line_op("mean")
|
||||
def _op_mean(z):
|
||||
return float(z.mean())
|
||||
|
||||
@_line_op("median")
|
||||
def _op_median(z):
|
||||
return float(np.median(z))
|
||||
|
||||
@_line_op("sum")
|
||||
def _op_sum(z):
|
||||
return float(z.sum())
|
||||
|
||||
@_line_op("range")
|
||||
def _op_range(z):
|
||||
return float(z.max() - z.min())
|
||||
|
||||
@_line_op("length", unit="pts")
|
||||
def _op_length(z):
|
||||
return float(len(z))
|
||||
|
||||
@_line_op("rms")
|
||||
def _op_rms(z):
|
||||
return float(np.sqrt(np.mean(z * z)))
|
||||
|
||||
|
||||
# ── Roughness parameters ──────────────────────────
|
||||
|
||||
@_line_op("Ra")
|
||||
def _op_ra(z):
|
||||
return float(np.mean(np.abs(z - z.mean())))
|
||||
|
||||
@_line_op("Rq")
|
||||
def _op_rq(z):
|
||||
d = z - z.mean()
|
||||
return _safe_rq(d)
|
||||
|
||||
@_line_op("Rsk")
|
||||
def _op_rsk(z):
|
||||
d = z - z.mean()
|
||||
rq = _safe_rq(d)
|
||||
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
|
||||
|
||||
@_line_op("Rku")
|
||||
def _op_rku(z):
|
||||
d = z - z.mean()
|
||||
rq = _safe_rq(d)
|
||||
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
|
||||
|
||||
@_line_op("Rp")
|
||||
def _op_rp(z):
|
||||
return float((z - z.mean()).max())
|
||||
|
||||
@_line_op("Rv")
|
||||
def _op_rv(z):
|
||||
return float(-(z - z.mean()).min())
|
||||
|
||||
@_line_op("Rt")
|
||||
def _op_rt(z):
|
||||
d = z - z.mean()
|
||||
return float(d.max() - d.min())
|
||||
|
||||
@_line_op("Dq")
|
||||
def _op_dq(z):
|
||||
"""RMS slope (first derivative RMS)."""
|
||||
dz = np.diff(z)
|
||||
return float(np.sqrt(np.mean(dz * dz)))
|
||||
|
||||
@_line_op("Da")
|
||||
def _op_da(z):
|
||||
"""Mean absolute slope."""
|
||||
return float(np.mean(np.abs(np.diff(z))))
|
||||
|
||||
|
||||
@register_node(display_name="Line Math")
|
||||
class LineMath:
|
||||
"""Compute a single scalar value from a LINE profile."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"line": ("LINE",),
|
||||
"operation": (list(LINE_OPS.keys()),),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("result",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute a single scalar measurement from a LINE profile. "
|
||||
"Includes basic stats and Gwyddion-convention roughness parameters."
|
||||
)
|
||||
|
||||
def process(self, line, operation: str) -> tuple:
|
||||
z = np.asarray(line, dtype=np.float64).ravel()
|
||||
fn, unit = LINE_OPS[operation]
|
||||
value = fn(z)
|
||||
table = [{"quantity": operation, "value": value, "unit": unit}]
|
||||
return (table,)
|
||||
165
backend/nodes/display.py
Normal file
165
backend/nodes/display.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Display / output nodes.
|
||||
|
||||
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
|
||||
connect whichever type you have. The server injects _broadcast_fn
|
||||
before execution begins.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
|
||||
|
||||
|
||||
@register_node(display_name="Preview")
|
||||
class PreviewImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
|
||||
},
|
||||
"optional": {
|
||||
"image": ("IMAGE",),
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "preview"
|
||||
CATEGORY = "display"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
|
||||
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
|
||||
# Prefer field if both are connected; accept whichever is provided
|
||||
if field is not None:
|
||||
arr_u8 = datafield_to_uint8(field, colormap)
|
||||
elif image is not None:
|
||||
if image.dtype != np.uint8:
|
||||
imin, imax = image.min(), image.max()
|
||||
if imax > imin:
|
||||
norm = (image - imin) / (imax - imin)
|
||||
else:
|
||||
norm = np.zeros_like(image)
|
||||
arr_u8 = (norm * 255).astype(np.uint8)
|
||||
else:
|
||||
arr_u8 = image
|
||||
|
||||
if arr_u8.ndim == 2 and colormap != "gray":
|
||||
import matplotlib.cm as cm
|
||||
cmap = cm.get_cmap(colormap)
|
||||
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
|
||||
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
else:
|
||||
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
|
||||
|
||||
data_uri = encode_preview(arr_u8)
|
||||
|
||||
if PreviewImage._broadcast_fn is not None:
|
||||
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
|
||||
|
||||
return ()
|
||||
|
||||
|
||||
@register_node(display_name="3D View")
|
||||
class View3D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
|
||||
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
|
||||
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "render"
|
||||
CATEGORY = "display"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = (
|
||||
"Interactive 3D surface view of a DATA_FIELD. "
|
||||
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
|
||||
)
|
||||
|
||||
_broadcast_mesh_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def render(
|
||||
self, field: DataField,
|
||||
colormap: str, z_scale: float, resolution: int,
|
||||
) -> tuple:
|
||||
import matplotlib.cm as cm
|
||||
import base64
|
||||
|
||||
data = field.data
|
||||
yres, xres = data.shape
|
||||
|
||||
# Downsample if larger than resolution
|
||||
step_y = max(1, yres // resolution)
|
||||
step_x = max(1, xres // resolution)
|
||||
z = data[::step_y, ::step_x].astype(np.float32)
|
||||
ny, nx = z.shape
|
||||
|
||||
# Normalize for colormap
|
||||
zmin, zmax = float(z.min()), float(z.max())
|
||||
if zmax > zmin:
|
||||
z_norm = (z - zmin) / (zmax - zmin)
|
||||
else:
|
||||
z_norm = np.zeros_like(z)
|
||||
|
||||
cmap = cm.get_cmap(colormap)
|
||||
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
|
||||
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
|
||||
# Base64-encode arrays for efficient WS transport
|
||||
z_b64 = base64.b64encode(z.tobytes()).decode()
|
||||
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
|
||||
|
||||
mesh_data = {
|
||||
"width": nx,
|
||||
"height": ny,
|
||||
"z_data": z_b64,
|
||||
"colors": colors_b64,
|
||||
"z_min": zmin,
|
||||
"z_max": zmax,
|
||||
"z_scale": float(z_scale),
|
||||
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
|
||||
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
|
||||
}
|
||||
|
||||
if View3D._broadcast_mesh_fn is not None:
|
||||
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
|
||||
|
||||
return ()
|
||||
|
||||
|
||||
@register_node(display_name="Print Table")
|
||||
class PrintTable:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"table": ("TABLE",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "print_table"
|
||||
CATEGORY = "display"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Send a TABLE to the browser as a WebSocket message for display."
|
||||
|
||||
_broadcast_table_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def print_table(self, table: list) -> tuple:
|
||||
if PrintTable._broadcast_table_fn is not None:
|
||||
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
|
||||
return ()
|
||||
115
backend/nodes/filters.py
Normal file
115
backend/nodes/filters.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Filter nodes — Gwyddion-equivalent image filters.
|
||||
|
||||
Gwyddion equivalents:
|
||||
GaussianFilter → gwy_data_field_filter_gaussian
|
||||
MedianFilter → gwy_data_field_filter_median
|
||||
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GaussianFilter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Gaussian Filter")
|
||||
class GaussianFilter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("filtered",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
|
||||
|
||||
def process(self, field: DataField, sigma: float) -> tuple:
|
||||
from scipy.ndimage import gaussian_filter
|
||||
data = gaussian_filter(field.data.copy(), sigma=float(sigma))
|
||||
return (field.replace(data=data),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MedianFilter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Median Filter")
|
||||
class MedianFilter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("filtered",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
|
||||
|
||||
def process(self, field: DataField, size: int) -> tuple:
|
||||
from scipy.ndimage import median_filter
|
||||
size = max(1, int(size))
|
||||
data = median_filter(field.data.copy(), size=size)
|
||||
return (field.replace(data=data),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EdgeDetect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Edge Detect")
|
||||
class EdgeDetect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["sobel", "prewitt", "laplacian", "log"],),
|
||||
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("edges",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = (
|
||||
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
|
||||
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str, sigma: float) -> tuple:
|
||||
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
|
||||
data = field.data.copy()
|
||||
|
||||
if method == "sobel":
|
||||
sx = sobel(data, axis=1)
|
||||
sy = sobel(data, axis=0)
|
||||
result = np.hypot(sx, sy)
|
||||
elif method == "prewitt":
|
||||
px = prewitt(data, axis=1)
|
||||
py = prewitt(data, axis=0)
|
||||
result = np.hypot(px, py)
|
||||
elif method == "laplacian":
|
||||
result = laplace(data)
|
||||
elif method == "log":
|
||||
result = gaussian_laplace(data, sigma=float(sigma))
|
||||
else:
|
||||
raise ValueError(f"Unknown edge detection method: {method}")
|
||||
|
||||
return (field.replace(data=result),)
|
||||
127
backend/nodes/grains.py
Normal file
127
backend/nodes/grains.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Grain/feature detection nodes.
|
||||
|
||||
Gwyddion equivalents:
|
||||
ThresholdMask → threshold.c / otsu_threshold.c
|
||||
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThresholdMask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Threshold Mask")
|
||||
class ThresholdMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["otsu", "absolute", "relative"],),
|
||||
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
|
||||
"direction": (["above", "below"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "grains"
|
||||
DESCRIPTION = (
|
||||
"Create a binary mask by thresholding data. "
|
||||
"Otsu automatically finds the optimal threshold. "
|
||||
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
|
||||
data = field.data
|
||||
|
||||
if method == "otsu":
|
||||
from skimage.filters import threshold_otsu
|
||||
t = threshold_otsu(data)
|
||||
elif method == "absolute":
|
||||
t = float(threshold)
|
||||
elif method == "relative":
|
||||
# threshold is a fraction [0, 1] of the data range
|
||||
dmin, dmax = data.min(), data.max()
|
||||
t = dmin + float(threshold) * (dmax - dmin)
|
||||
else:
|
||||
raise ValueError(f"Unknown threshold method: {method}")
|
||||
|
||||
if direction == "above":
|
||||
mask = (data >= t).astype(np.uint8) * 255
|
||||
else:
|
||||
mask = (data < t).astype(np.uint8) * 255
|
||||
|
||||
return (mask,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GrainAnalysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Grain Analysis")
|
||||
class GrainAnalysis:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"mask": ("IMAGE",),
|
||||
"min_size": ("INT", {"default": 10, "min": 1, "max": 100000, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("grain_stats",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "grains"
|
||||
DESCRIPTION = (
|
||||
"Label connected grain regions in a binary mask and compute per-grain statistics: "
|
||||
"area, equivalent diameter, mean/max height, bounding box. "
|
||||
"Equivalent to gwy_data_field_grains_get_values."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
|
||||
from scipy.ndimage import label, find_objects
|
||||
|
||||
binary = (mask > 127).astype(np.int32)
|
||||
labeled, n_grains = label(binary)
|
||||
|
||||
pixel_area = field.dx * field.dy # m^2 per pixel
|
||||
|
||||
rows = []
|
||||
for grain_id in range(1, n_grains + 1):
|
||||
grain_pixels = labeled == grain_id
|
||||
area_px = int(grain_pixels.sum())
|
||||
if area_px < min_size:
|
||||
continue
|
||||
|
||||
area_m2 = area_px * pixel_area
|
||||
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
|
||||
|
||||
heights = field.data[grain_pixels]
|
||||
mean_h = float(heights.mean())
|
||||
max_h = float(heights.max())
|
||||
|
||||
# Bounding box
|
||||
ys, xs = np.where(grain_pixels)
|
||||
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
|
||||
|
||||
rows.append({
|
||||
"grain_id": grain_id,
|
||||
"area_px": area_px,
|
||||
"area_m2": area_m2,
|
||||
"equiv_diam_m": equiv_diam,
|
||||
"mean_height": mean_h,
|
||||
"max_height": max_h,
|
||||
"bbox": bbox,
|
||||
})
|
||||
|
||||
return (rows,)
|
||||
277
backend/nodes/io.py
Normal file
277
backend/nodes/io.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
I/O nodes: load and save images and SPM data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, encode_preview, image_to_uint8
|
||||
|
||||
# Resolved at server startup so nodes know where to look
|
||||
INPUT_DIR = Path(__file__).parent.parent.parent / "input"
|
||||
OUTPUT_DIR = Path(__file__).parent.parent.parent / "output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadImage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load Image")
|
||||
class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
|
||||
RETURN_NAMES = ("image", "field")
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
|
||||
|
||||
def load(self, filename: str):
|
||||
# Accept absolute paths or filenames relative to input/
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = INPUT_DIR / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
if ext in (".npy",):
|
||||
arr = np.load(str(path)).astype(np.float64)
|
||||
elif ext in (".npz",):
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
arr = npz[key].astype(np.float64)
|
||||
else:
|
||||
from PIL import Image
|
||||
img = Image.open(str(path))
|
||||
arr = np.array(img)
|
||||
if arr.dtype != np.uint8:
|
||||
arr = arr.astype(np.float64)
|
||||
|
||||
# Convert to float64 grayscale for the DATA_FIELD output
|
||||
if arr.ndim == 3:
|
||||
gray = np.mean(arr.astype(np.float64), axis=2)
|
||||
else:
|
||||
gray = arr.astype(np.float64)
|
||||
|
||||
field = DataField(data=gray)
|
||||
return (arr, field)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadSPM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load SPM File")
|
||||
class LoadSPM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
"channel": ("STRING", {"default": "Z"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("field",)
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
|
||||
|
||||
def load(self, filename: str, channel: str = "Z"):
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = INPUT_DIR / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
if ext == ".gwy":
|
||||
return (self._load_gwy(path, channel),)
|
||||
elif ext == ".sxm":
|
||||
return (self._load_sxm(path, channel),)
|
||||
elif ext in (".ibw",):
|
||||
return (self._load_ibw(path),)
|
||||
elif ext in (".npy",):
|
||||
data = np.load(str(path)).astype(np.float64)
|
||||
return (DataField(data=data),)
|
||||
elif ext in (".npz",):
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
return (DataField(data=npz[key].astype(np.float64)),)
|
||||
else:
|
||||
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
|
||||
|
||||
def _load_gwy(self, path: Path, channel: str) -> DataField:
|
||||
try:
|
||||
import gwyfile
|
||||
except ImportError:
|
||||
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
|
||||
|
||||
obj = gwyfile.load(str(path))
|
||||
channels = gwyfile.util.get_datafields(obj)
|
||||
if not channels:
|
||||
raise ValueError(f"No data channels found in {path.name}")
|
||||
|
||||
# Try requested channel name, fall back to first available
|
||||
ch = None
|
||||
for key, df in channels.items():
|
||||
if channel.lower() in key.lower():
|
||||
ch = df
|
||||
break
|
||||
if ch is None:
|
||||
ch = next(iter(channels.values()))
|
||||
|
||||
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
|
||||
return DataField(
|
||||
data=data,
|
||||
xreal=float(ch.xreal),
|
||||
yreal=float(ch.yreal),
|
||||
xoff=float(getattr(ch, "xoff", 0.0)),
|
||||
yoff=float(getattr(ch, "yoff", 0.0)),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
def _load_sxm(self, path: Path, channel: str) -> DataField:
|
||||
try:
|
||||
import nanonispy as nap
|
||||
except ImportError:
|
||||
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
|
||||
|
||||
sxm = nap.read.Scan(str(path))
|
||||
signals = sxm.signals
|
||||
|
||||
# Pick channel
|
||||
ch_key = None
|
||||
for key in signals:
|
||||
if channel.upper() in key.upper():
|
||||
ch_key = key
|
||||
break
|
||||
if ch_key is None:
|
||||
ch_key = next(iter(signals))
|
||||
|
||||
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
|
||||
data = np.asarray(data, dtype=np.float64)
|
||||
if data.ndim != 2:
|
||||
data = data.reshape(data.shape[-2], data.shape[-1])
|
||||
|
||||
header = sxm.header
|
||||
scan_range = header.get("scan_range", [1e-6, 1e-6])
|
||||
return DataField(
|
||||
data=data,
|
||||
xreal=float(scan_range[0]),
|
||||
yreal=float(scan_range[1]),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
def _load_ibw(self, path: Path) -> DataField:
|
||||
try:
|
||||
import igor.igorpy as igorpy
|
||||
wave = igorpy.load(str(path))
|
||||
data = wave.wave["wData"].squeeze().astype(np.float64)
|
||||
except ImportError:
|
||||
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
|
||||
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
elif data.ndim != 2:
|
||||
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
|
||||
|
||||
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coordinate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Coordinate")
|
||||
class Coordinate:
|
||||
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("COORD",)
|
||||
RETURN_NAMES = ("point",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
|
||||
|
||||
def process(self, x: float, y: float) -> tuple:
|
||||
return ((float(x), float(y)),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SaveImage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Save Image")
|
||||
class SaveImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"filename_prefix": ("STRING", {"default": "output"}),
|
||||
"format": (["PNG", "TIFF", "NPY"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
CATEGORY = "io"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Save an image or array to the output folder."
|
||||
|
||||
# Injected by server.py before execution begins
|
||||
_broadcast_preview = None
|
||||
|
||||
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Find next available filename
|
||||
idx = 1
|
||||
while True:
|
||||
name = f"{filename_prefix}_{idx:04d}"
|
||||
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
|
||||
if not candidate.exists():
|
||||
break
|
||||
idx += 1
|
||||
|
||||
if format == "NPY":
|
||||
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
|
||||
else:
|
||||
from PIL import Image
|
||||
arr = image_to_uint8(image)
|
||||
if arr.ndim == 2:
|
||||
pil_img = Image.fromarray(arr, mode="L")
|
||||
else:
|
||||
pil_img = Image.fromarray(arr, mode="RGB")
|
||||
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
|
||||
|
||||
# Emit preview over WebSocket if callback is set
|
||||
if SaveImage._broadcast_preview is not None:
|
||||
arr_u8 = image_to_uint8(image)
|
||||
data_uri = encode_preview(arr_u8)
|
||||
SaveImage._broadcast_preview(data_uri)
|
||||
|
||||
return ()
|
||||
150
backend/nodes/level.py
Normal file
150
backend/nodes/level.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Leveling nodes — background removal and zero correction.
|
||||
|
||||
Gwyddion equivalents:
|
||||
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
|
||||
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
|
||||
FixZero → fix_zero in level.c
|
||||
|
||||
Plane-fit algorithm follows Gwyddion's level.h definition:
|
||||
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PlaneLevelField
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Plane Level")
|
||||
class PlaneLevelField:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("leveled",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "level"
|
||||
DESCRIPTION = (
|
||||
"Fit and subtract a least-squares plane from the data. "
|
||||
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
|
||||
)
|
||||
|
||||
def process(self, field: DataField) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
# Normalised coordinate grids in [0, 1]
|
||||
x = np.linspace(0.0, 1.0, xres)
|
||||
y = np.linspace(0.0, 1.0, yres)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
# Design matrix: [1, x, y] shape (N, 3)
|
||||
A = np.column_stack([
|
||||
np.ones(xres * yres),
|
||||
xx.ravel(),
|
||||
yy.ravel(),
|
||||
])
|
||||
z = data.ravel()
|
||||
|
||||
# Least-squares: solve A @ [pa, pbx, pby] = z
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||
pa, pbx, pby = coeffs
|
||||
|
||||
plane = (pa + pbx * xx + pby * yy)
|
||||
return (field.replace(data=data - plane),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PolyLevelField
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Polynomial Level")
|
||||
class PolyLevelField:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
|
||||
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
|
||||
RETURN_NAMES = ("leveled", "background")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "level"
|
||||
DESCRIPTION = (
|
||||
"Fit and subtract a polynomial background of given degree in x and y. "
|
||||
"Equivalent to gwy_data_field_fit_polynom."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
x = np.linspace(0.0, 1.0, xres)
|
||||
y = np.linspace(0.0, 1.0, yres)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
# Build Vandermonde-style design matrix with all monomials x^i * y^j
|
||||
cols = []
|
||||
for i in range(degree_x + 1):
|
||||
for j in range(degree_y + 1):
|
||||
cols.append((xx ** i * yy ** j).ravel())
|
||||
A = np.column_stack(cols)
|
||||
z = data.ravel()
|
||||
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||
|
||||
background = (A @ coeffs).reshape(yres, xres)
|
||||
leveled = data - background
|
||||
|
||||
return (field.replace(data=leveled), field.replace(data=background))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FixZero
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Fix Zero")
|
||||
class FixZero:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["min", "mean", "median"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("zeroed",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "level"
|
||||
DESCRIPTION = (
|
||||
"Shift data so that the minimum (or mean/median) is zero. "
|
||||
"Equivalent to fix_zero in Gwyddion's level.c."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str) -> tuple:
|
||||
data = field.data.copy()
|
||||
if method == "min":
|
||||
data -= data.min()
|
||||
elif method == "mean":
|
||||
data -= data.mean()
|
||||
elif method == "median":
|
||||
data -= np.median(data)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
return (field.replace(data=data),)
|
||||
Reference in New Issue
Block a user