add snapshot tool, masks, and build for mac

This commit is contained in:
2026-03-23 21:52:17 -07:00
parent 080eefbef6
commit a34b1c980d
29 changed files with 2016 additions and 170 deletions

View File

@@ -177,13 +177,19 @@ class ExecutionEngine:
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import SaveImage
PreviewImage._broadcast_fn = on_preview
ThresholdMask._broadcast_fn = on_preview
MaskMorphology._broadcast_fn = on_preview
MaskInvert._broadcast_fn = on_preview
MaskCombine._broadcast_fn = on_preview
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
@@ -191,8 +197,10 @@ class ExecutionEngine:
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine):
cls._current_node_id = node_id
def _auto_preview(
@@ -206,12 +214,16 @@ class ExecutionEngine:
"""
After every node executes, inspect its outputs and broadcast
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
Skip nodes that broadcast their own custom preview.
"""
import numpy as np
from backend.data_types import (
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
)
if getattr(cls, "_CUSTOM_PREVIEW", False):
return
return_types = getattr(cls, "RETURN_TYPES", ())
for slot, type_name in enumerate(return_types):

View File

@@ -36,7 +36,7 @@ def main() -> None:
app = create_app(loop)
log.info("=" * 60)
log.info(" Argonode — Node-based image analysis")
log.info(" argonode — Node-based image analysis")
log.info(" Open your browser at http://%s:%d", HOST, PORT)
log.info("=" * 60)

View File

@@ -1,2 +1,2 @@
# Import all node modules to trigger @register_node decorators.
from . import io, filters, level, analysis, grains, display
from . import io, filters, level, analysis, grains, mask, display

View File

@@ -69,6 +69,7 @@ class HeightHistogram:
"required": {
"field": ("DATA_FIELD",),
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
"y_scale": (["linear", "log"],),
}
}
@@ -78,13 +79,150 @@ class HeightHistogram:
CATEGORY = "analysis"
DESCRIPTION = (
"Compute the height distribution histogram (DH). "
"Use log scale to reveal small peaks next to a dominant background. "
"Equivalent to gwy_data_field_dh."
)
def process(self, field: DataField, n_bins: int) -> tuple:
def process(self, field: DataField, n_bins: int, y_scale: str = "linear") -> tuple:
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
return (counts.astype(np.float64), bin_centers)
counts = counts.astype(np.float64)
if y_scale == "log":
counts = np.log10(1.0 + counts)
return (counts, bin_centers)
# ---------------------------------------------------------------------------
# LineCursors — interactive measurement cursors on any LINE plot
# ---------------------------------------------------------------------------
@register_node(display_name="Line Cursors")
class LineCursors:
"""Place two draggable cursors on any LINE plot to measure values and deltas."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
},
"optional": {
"x_axis": ("LINE",),
},
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("measurement",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Place two cursors on any line plot (histogram, cross section, profile) "
"to measure positions, values, and deltas. Drag the markers to reposition."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, line, x1: float, y1: float, x2: float, y2: float,
x_axis=None,
) -> tuple:
import io as _io
import base64
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
y = np.asarray(line, dtype=np.float64).ravel()
n = len(y)
if x_axis is not None:
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
else:
x = np.arange(n, dtype=np.float64)
# --- Render the base plot first to determine axes bounds ---
fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100)
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
ax.plot(x, y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
fig.tight_layout(pad=0.4)
# Force a draw so transforms are valid
fig.canvas.draw()
# Get axes position in figure-fraction coordinates
ax_pos = ax.get_position()
ax_l, ax_b = ax_pos.x0, ax_pos.y0
ax_w, ax_h = ax_pos.width, ax_pos.height
# x1/y1 arrive as image-fraction from the frontend drag.
# Convert image-fraction x → axes-fraction → nearest data index.
def img_x_to_idx(ix):
axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
idx_a = img_x_to_idx(x1)
idx_b = img_x_to_idx(x2)
xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b])
# --- Draw cursor lines and markers on the plot ---
ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
ax.annotate(
"", xy=(xb, yb), xytext=(xa, ya),
arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
)
# --- Broadcast overlay ---
if LineCursors._broadcast_overlay_fn is not None:
# Convert data-space positions back to image-fraction for markers
fig.canvas.draw()
inv = fig.transFigure.inverted()
fig_a = inv.transform(ax.transData.transform([xa, ya]))
fig_b = inv.transform(ax.transData.transform([xb, yb]))
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
LineCursors._broadcast_overlay_fn(
LineCursors._current_node_id,
{
"image": image_uri,
"x1": float(fig_a[0]),
"y1": float(1.0 - fig_a[1]), # flip: image y=0 is top
"x2": float(fig_b[0]),
"y2": float(1.0 - fig_b[1]),
"a_locked": False,
"b_locked": False,
},
)
plt.close(fig)
# --- Output table ---
table = [
{"quantity": "A position", "value": xa, "unit": ""},
{"quantity": "A value", "value": ya, "unit": ""},
{"quantity": "B position", "value": xb, "unit": ""},
{"quantity": "B value", "value": yb, "unit": ""},
{"quantity": "delta X", "value": xb - xa, "unit": ""},
{"quantity": "delta Y", "value": yb - ya, "unit": ""},
]
return (table,)
# ---------------------------------------------------------------------------
@@ -242,9 +380,9 @@ class CrossSection:
return {
"required": {
"field": ("DATA_FIELD",),
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"extend": (["none", "to_edges"],),
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),

View File

@@ -5,6 +5,8 @@ Gwyddion equivalents:
GaussianFilter → gwy_data_field_filter_gaussian
MedianFilter → gwy_data_field_filter_median
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles)
FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs)
"""
from __future__ import annotations
@@ -113,3 +115,190 @@ class EdgeDetect:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)
# ---------------------------------------------------------------------------
# Butterworth transfer function helpers
# ---------------------------------------------------------------------------
def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
"""Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n))."""
with np.errstate(divide="ignore", over="ignore"):
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
"""Butterworth highpass: H = 1 / (1 + (fc/f)^(2n))."""
with np.errstate(divide="ignore", invalid="ignore"):
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
h = np.where(np.isfinite(h), h, 0.0)
return h
def _build_1d_transfer(n: int, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> np.ndarray:
"""Build a 1-D transfer function for an FFT of length *n*.
Frequencies are normalised so that 1.0 = Nyquist (fs/2).
The returned array has the same layout as np.fft.rfft output (length n//2+1).
"""
freq = np.linspace(0, 1, n // 2 + 1)
if filter_type == "lowpass":
H = _butterworth_lp(freq, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(freq, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(freq)
return H
# ---------------------------------------------------------------------------
# FFTFilter1D — frequency-domain filtering of LINE profiles
# ---------------------------------------------------------------------------
@register_node(display_name="1D FFT Filter")
class FFTFilter1D:
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
transfer function with configurable order for a smooth roll-off.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = (
"Frequency-domain filtering of a 1-D line profile. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
"Equivalent to Gwyddion fft_filter_1d."
)
def process(self, line, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
n = len(z)
# Forward FFT (real-valued)
Z = np.fft.rfft(z)
# Build and apply transfer function
H = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
Z *= H
# Inverse FFT
filtered = np.fft.irfft(Z, n=n)
return (filtered,)
# ---------------------------------------------------------------------------
# FFTFilter2D — frequency-domain filtering of DATA_FIELDs
# ---------------------------------------------------------------------------
@register_node(display_name="2D FFT Filter")
class FFTFilter2D:
"""Frequency-domain filtering of 2-D data fields (images).
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
Butterworth transfer function in the frequency domain to remove or
isolate periodic features.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
"cutoff": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"cutoff_high": ("FLOAT", {
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
}),
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = (
"Frequency-domain filtering of a 2-D data field. "
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
"bandpass/notch to isolate or remove periodic noise. "
"Equivalent to Gwyddion fft_filter_2d."
)
def process(self, field: DataField, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Subtract mean to avoid DC leakage artefacts
mean_val = data.mean()
data -= mean_val
# Forward 2D FFT
F = np.fft.fft2(data)
F = np.fft.fftshift(F)
# Build radial frequency grid normalised to [0, 1] (1 = Nyquist)
fy = np.fft.fftshift(np.fft.fftfreq(yres)) # range [-0.5, 0.5)
fx = np.fft.fftshift(np.fft.fftfreq(xres))
FX, FY = np.meshgrid(fx, fy)
# Normalise so that corner = 1 in each axis independently,
# then take Euclidean norm; max radial value = 1.0 at Nyquist.
R = np.sqrt((FX / 0.5) ** 2 + (FY / 0.5) ** 2)
R = np.clip(R / R.max(), 0, 1) if R.max() > 0 else R
# Build transfer function
if filter_type == "lowpass":
H = _butterworth_lp(R, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(R, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(R)
# Apply filter
F *= H
# Inverse FFT
F = np.fft.ifftshift(F)
result = np.fft.ifft2(F).real
# Restore DC
result += mean_val
return (field.replace(data=result),)

View File

@@ -1,9 +1,8 @@
"""
Grain/feature detection nodes.
Particle detection nodes.
Gwyddion equivalents:
ThresholdMask → threshold.c / otsu_threshold.c
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
ParticleAnalysis → gwy_data_field_grains_get_values (grains-values.c)
"""
from __future__ import annotations
@@ -13,61 +12,11 @@ from backend.data_types import DataField
# ---------------------------------------------------------------------------
# ThresholdMask
# ParticleAnalysis
# ---------------------------------------------------------------------------
@register_node(display_name="Threshold Mask")
class ThresholdMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
# threshold is a fraction [0, 1] of the data range
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
return (mask,)
# ---------------------------------------------------------------------------
# GrainAnalysis
# ---------------------------------------------------------------------------
@register_node(display_name="Grain Analysis")
class GrainAnalysis:
@register_node(display_name="Particle Analysis")
class ParticleAnalysis:
@classmethod
def INPUT_TYPES(cls):
return {
@@ -79,43 +28,43 @@ class GrainAnalysis:
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("grain_stats",)
RETURN_NAMES = ("particle_stats",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
"Label connected grain regions in a binary mask and compute per-grain statistics: "
"area, equivalent diameter, mean/max height, bounding box. "
"Label connected particle regions in a binary mask and compute per-particle "
"statistics: area, equivalent diameter, mean/max height, bounding box. "
"Equivalent to gwy_data_field_grains_get_values."
)
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
from scipy.ndimage import label, find_objects
from scipy.ndimage import label
binary = (mask > 127).astype(np.int32)
labeled, n_grains = label(binary)
labeled, n_particles = label(binary)
pixel_area = field.dx * field.dy # m^2 per pixel
rows = []
for grain_id in range(1, n_grains + 1):
grain_pixels = labeled == grain_id
area_px = int(grain_pixels.sum())
for pid in range(1, n_particles + 1):
particle_pixels = labeled == pid
area_px = int(particle_pixels.sum())
if area_px < min_size:
continue
area_m2 = area_px * pixel_area
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
heights = field.data[grain_pixels]
heights = field.data[particle_pixels]
mean_h = float(heights.mean())
max_h = float(heights.max())
# Bounding box
ys, xs = np.where(grain_pixels)
ys, xs = np.where(particle_pixels)
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
rows.append({
"grain_id": grain_id,
"particle_id": pid,
"area_px": area_px,
"area_m2": area_m2,
"equiv_diam_m": equiv_diam,

View File

@@ -9,12 +9,16 @@ from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview, image_to_uint8
from backend.runtime_paths import input_dir, output_dir
from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
# ---------------------------------------------------------------------------
# LoadImage
@@ -68,6 +72,81 @@ class LoadImage:
return (arr, field)
# ---------------------------------------------------------------------------
# LoadDemo
# ---------------------------------------------------------------------------
def _list_demo_files() -> list[str]:
"""Return sorted list of demo filenames available in the demo/ directory."""
if not DEMO_DIR.exists():
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
)
@register_node(display_name="Load Demo Image")
class LoadDemo:
@classmethod
def INPUT_TYPES(cls):
choices = _list_demo_files() or ["(no demo images found)"]
return {
"required": {
"name": (choices,),
}
}
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
RETURN_NAMES = ("image", "field")
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data."
def load(self, name: str):
path = DEMO_DIR / name
if not path.exists():
raise FileNotFoundError(f"Demo image not found: {name}")
ext = path.suffix.lower()
# SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD
if ext == ".gwy":
field = LoadSPM()._load_gwy(path, "Z")
arr = field.data
return (arr, field)
elif ext == ".sxm":
field = LoadSPM()._load_sxm(path, "Z")
arr = field.data
return (arr, field)
elif ext == ".ibw":
field = LoadSPM()._load_ibw(path)
arr = field.data
return (arr, field)
# npy / npz
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image
img = Image.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
field = DataField(data=gray)
return (arr, field)
# ---------------------------------------------------------------------------
# LoadSPM
# ---------------------------------------------------------------------------

273
backend/nodes/mask.py Normal file
View File

@@ -0,0 +1,273 @@
"""
Mask operation nodes — creation, morphology, and boolean combination.
Gwyddion equivalents:
ThresholdMask → threshold.c / otsu_threshold.c
MaskMorphology → mask_morph.c (erode, dilate, open, close)
MaskInvert → (bitwise NOT on mask)
MaskCombine → (boolean ops between two masks)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
"""Render greyscale base image with red shadow on masked (255) pixels.
Returns (H, W, 3) uint8 array.
"""
grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8
overlay = grey.astype(np.float64)
mask_bool = mask == 255
alpha = 0.45
overlay[mask_bool, 0] = overlay[mask_bool, 0] * (1 - alpha) + 255 * alpha
overlay[mask_bool, 1] = overlay[mask_bool, 1] * (1 - alpha)
overlay[mask_bool, 2] = overlay[mask_bool, 2] * (1 - alpha)
return np.clip(overlay, 0, 255).astype(np.uint8)
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------
@register_node(display_name="Threshold Mask")
class ThresholdMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
# threshold is a fraction [0, 1] of the data range
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
if ThresholdMask._broadcast_fn is not None:
overlay = _mask_overlay(field, mask)
ThresholdMask._broadcast_fn(
ThresholdMask._current_node_id, encode_preview(overlay),
)
return (mask,)
# ---------------------------------------------------------------------------
# MaskMorphology
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Morphology")
class MaskMorphology:
"""Morphological operations on binary masks.
Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
"""
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
"operation": (["dilate", "erode", "open", "close"],),
"radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
"shape": (["disk", "square"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = (
"Apply morphological operations to a binary mask. "
"Dilate expands regions, erode shrinks them, "
"open (erode then dilate) removes small spots, "
"close (dilate then erode) fills small holes. "
"Equivalent to Gwyddion mask_morph."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
field: DataField | None = None) -> tuple:
from scipy.ndimage import binary_dilation, binary_erosion
binary = mask > 127
if shape == "disk":
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
struct = (x * x + y * y) <= radius * radius
else:
size = 2 * radius + 1
struct = np.ones((size, size), dtype=bool)
if operation == "dilate":
result = binary_dilation(binary, structure=struct)
elif operation == "erode":
result = binary_erosion(binary, structure=struct)
elif operation == "open":
result = binary_dilation(
binary_erosion(binary, structure=struct),
structure=struct,
)
elif operation == "close":
result = binary_erosion(
binary_dilation(binary, structure=struct),
structure=struct,
)
else:
raise ValueError(f"Unknown morphological operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskMorphology._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskMorphology._broadcast_fn(
MaskMorphology._current_node_id, encode_preview(overlay),
)
return (out,)
# ---------------------------------------------------------------------------
# MaskInvert
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Invert")
class MaskInvert:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("IMAGE",),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
if field is not None and MaskInvert._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskInvert._broadcast_fn(
MaskInvert._current_node_id, encode_preview(overlay),
)
return (out,)
# ---------------------------------------------------------------------------
# MaskCombine
# ---------------------------------------------------------------------------
@register_node(display_name="Mask Combine")
class MaskCombine:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask_a": ("IMAGE",),
"mask_b": ("IMAGE",),
"operation": (["and", "or", "xor", "subtract"],),
},
"optional": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = (
"Combine two binary masks with a boolean operation. "
"AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
"subtract removes mask_b from mask_a."
)
_broadcast_fn = None
_current_node_id: str = ""
def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
field: DataField | None = None) -> tuple:
a = mask_a > 127
b = mask_b > 127
if operation == "and":
result = a & b
elif operation == "or":
result = a | b
elif operation == "xor":
result = a ^ b
elif operation == "subtract":
result = a & ~b
else:
raise ValueError(f"Unknown mask operation: {operation}")
out = result.astype(np.uint8) * 255
if field is not None and MaskCombine._broadcast_fn is not None:
overlay = _mask_overlay(field, out)
MaskCombine._broadcast_fn(
MaskCombine._current_node_id, encode_preview(overlay),
)
return (out,)

View File

@@ -4,7 +4,7 @@ import os
import sys
from pathlib import Path
APP_NAME = "Argonode"
APP_NAME = "argonode"
def project_root() -> Path:
@@ -34,13 +34,26 @@ def app_data_dir() -> Path:
return Path(override).expanduser().resolve()
if getattr(sys, "frozen", False):
local_appdata = os.getenv("LOCALAPPDATA")
base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local"
if sys.platform == "darwin":
base_dir = Path.home() / "Library" / "Application Support"
elif sys.platform == "linux":
xdg = os.getenv("XDG_DATA_HOME")
base_dir = Path(xdg) if xdg else Path.home() / ".local" / "share"
else:
local_appdata = os.getenv("LOCALAPPDATA")
base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local"
return (base_dir / APP_NAME).resolve()
return project_root()
def demo_dir() -> Path:
bundled = resource_root() / "demo"
if bundled.exists():
return bundled
return project_root() / "demo"
def input_dir() -> Path:
return app_data_dir() / "input"