diff --git a/GWYDDION_FEATURE_GAP.md b/GWYDDION_FEATURE_GAP.md
index b393f82..762e41a 100644
--- a/GWYDDION_FEATURE_GAP.md
+++ b/GWYDDION_FEATURE_GAP.md
@@ -1,4 +1,4 @@
-# Gwyddion Features Not Yet in Argonode
+# Gwyddion Features Not Yet in argonode
Reference for future implementation. Grouped by value to typical SPM workflows.
@@ -11,9 +11,9 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
| 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. |
| 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). |
| 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. |
-| 4 | Morphological Mask Ops | mask_morph.c | Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks. |
-| 5 | 1D FFT Filter | fft_filter_1d.c | Bandpass/lowpass/highpass filtering of LINE profiles. |
-| 6 | 2D FFT Filter | fft_filter_2d.c | Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.). |
+| ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** |
+| ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** |
+| ~~6~~ | ~~2D FFT Filter~~ | ~~fft_filter_2d.c~~ | ~~Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.).~~ **DONE** |
| 7 | Autocorrelation (ACF) | acf2d.c | 2D autocorrelation function. Reveals periodic structures and correlation lengths. |
| 8 | PSDF | psdf2d.c | Radial/2D power spectral density function. Complementary to ACF for roughness characterization. |
| 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. |
@@ -61,11 +61,11 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
---
-## Already Implemented in Argonode
+## Already Implemented in argonode
For reference, these Gwyddion equivalents are already covered:
-| Argonode Node | Category | Gwyddion Equivalent |
+| argonode Node | Category | Gwyddion Equivalent |
|--------------|----------|-------------------|
| Load Image / Load SPM File | io | File import (gwy, sxm, ibw) |
| Save Image | io | File export |
@@ -76,12 +76,17 @@ For reference, these Gwyddion equivalents are already covered:
| Gaussian Filter | filters | filters.c (gaussian) |
| Median Filter | filters | filters.c (median) |
| Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) |
+| 1D FFT Filter | filters | fft_filter_1d.c (lowpass, highpass, bandpass, notch) |
+| 2D FFT Filter | filters | fft_filter_2d.c (lowpass, highpass, bandpass, notch) |
| Statistics | analysis | stats.c |
| Height Histogram | analysis | linestats.c (dh) |
| 2D FFT | analysis | fft.c |
| Cross Section | analysis | profile tool |
| Profile Roughness | analysis | roughness.c (Ra, Rq, Rsk, Rku, Rp, Rv, Rt) |
| Line Math | analysis | linestats.c |
-| Threshold Mask | grains | threshold.c, otsu_threshold.c |
-| Grain Analysis | grains | grain_stat.c |
+| Threshold Mask | mask | threshold.c, otsu_threshold.c |
+| Mask Morphology | mask | mask_morph.c (erode, dilate, open, close) |
+| Mask Invert | mask | — |
+| Mask Combine | mask | — (boolean AND, OR, XOR, subtract) |
+| Particle Analysis | grains | grain_stat.c |
| Preview / 3D View / Print Table | display | Presentation, 3D view |
diff --git a/README.md b/README.md
index f3902a5..999d091 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-# Argonode
+# argonode
-Argonode is a node-based image analysis application with:
+argonode is a node-based image analysis application with:
- a Python backend built on `aiohttp`
- a React + Vite frontend
@@ -135,13 +135,13 @@ powershell -ExecutionPolicy Bypass -File scripts\build-desktop.ps1
The packaged app is written to:
```text
-desktop-dist/Argonode/
+desktop-dist/argonode/
```
Main executable:
```text
-desktop-dist/Argonode/Argonode.exe
+desktop-dist/argonode/argonode.exe
```
### One-File Build
@@ -161,14 +161,14 @@ During normal source-based development, input/output folders live under the repo
In the packaged desktop app, writable data is stored under:
```text
-%LOCALAPPDATA%\Argonode\
+%LOCALAPPDATA%\argonode\
```
Specifically:
```text
-%LOCALAPPDATA%\Argonode\input
-%LOCALAPPDATA%\Argonode\output
+%LOCALAPPDATA%\argonode\input
+%LOCALAPPDATA%\argonode\output
```
You can override the packaged app data directory with:
diff --git a/backend/execution.py b/backend/execution.py
index cb11b62..b9a52af 100644
--- a/backend/execution.py
+++ b/backend/execution.py
@@ -177,13 +177,19 @@ class ExecutionEngine:
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
- from backend.nodes.analysis import CrossSection
+ from backend.nodes.analysis import CrossSection, LineCursors
+ from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import SaveImage
PreviewImage._broadcast_fn = on_preview
+ ThresholdMask._broadcast_fn = on_preview
+ MaskMorphology._broadcast_fn = on_preview
+ MaskInvert._broadcast_fn = on_preview
+ MaskCombine._broadcast_fn = on_preview
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
CrossSection._broadcast_overlay_fn = on_overlay
+ LineCursors._broadcast_overlay_fn = on_overlay
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
@@ -191,8 +197,10 @@ class ExecutionEngine:
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
- from backend.nodes.analysis import CrossSection
- if cls in (PreviewImage, PrintTable, View3D, CrossSection):
+ from backend.nodes.analysis import CrossSection, LineCursors
+ from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
+ if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
+ ThresholdMask, MaskMorphology, MaskInvert, MaskCombine):
cls._current_node_id = node_id
def _auto_preview(
@@ -206,12 +214,16 @@ class ExecutionEngine:
"""
After every node executes, inspect its outputs and broadcast
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
+ Skip nodes that broadcast their own custom preview.
"""
import numpy as np
from backend.data_types import (
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
)
+ if getattr(cls, "_CUSTOM_PREVIEW", False):
+ return
+
return_types = getattr(cls, "RETURN_TYPES", ())
for slot, type_name in enumerate(return_types):
diff --git a/backend/main.py b/backend/main.py
index 657bd14..ad8bc28 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -36,7 +36,7 @@ def main() -> None:
app = create_app(loop)
log.info("=" * 60)
- log.info(" Argonode — Node-based image analysis")
+ log.info(" argonode — Node-based image analysis")
log.info(" Open your browser at http://%s:%d", HOST, PORT)
log.info("=" * 60)
diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py
index 588daa7..cae9bb6 100644
--- a/backend/nodes/__init__.py
+++ b/backend/nodes/__init__.py
@@ -1,2 +1,2 @@
# Import all node modules to trigger @register_node decorators.
-from . import io, filters, level, analysis, grains, display
+from . import io, filters, level, analysis, grains, mask, display
diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py
index d112a95..60c6b30 100644
--- a/backend/nodes/analysis.py
+++ b/backend/nodes/analysis.py
@@ -69,6 +69,7 @@ class HeightHistogram:
"required": {
"field": ("DATA_FIELD",),
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
+ "y_scale": (["linear", "log"],),
}
}
@@ -78,13 +79,150 @@ class HeightHistogram:
CATEGORY = "analysis"
DESCRIPTION = (
"Compute the height distribution histogram (DH). "
+ "Use log scale to reveal small peaks next to a dominant background. "
"Equivalent to gwy_data_field_dh."
)
- def process(self, field: DataField, n_bins: int) -> tuple:
+ def process(self, field: DataField, n_bins: int, y_scale: str = "linear") -> tuple:
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
- return (counts.astype(np.float64), bin_centers)
+ counts = counts.astype(np.float64)
+ if y_scale == "log":
+ counts = np.log10(1.0 + counts)
+ return (counts, bin_centers)
+
+
+# ---------------------------------------------------------------------------
+# LineCursors — interactive measurement cursors on any LINE plot
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="Line Cursors")
+class LineCursors:
+ """Place two draggable cursors on any LINE plot to measure values and deltas."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "line": ("LINE",),
+ "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
+ "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
+ "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
+ "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
+ },
+ "optional": {
+ "x_axis": ("LINE",),
+ },
+ }
+
+ RETURN_TYPES = ("TABLE",)
+ RETURN_NAMES = ("measurement",)
+ FUNCTION = "process"
+ CATEGORY = "analysis"
+ DESCRIPTION = (
+ "Place two cursors on any line plot (histogram, cross section, profile) "
+ "to measure positions, values, and deltas. Drag the markers to reposition."
+ )
+
+ _broadcast_overlay_fn = None
+ _current_node_id: str = ""
+
+ def process(
+ self, line, x1: float, y1: float, x2: float, y2: float,
+ x_axis=None,
+ ) -> tuple:
+ import io as _io
+ import base64
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+
+ y = np.asarray(line, dtype=np.float64).ravel()
+ n = len(y)
+ if x_axis is not None:
+ x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
+ else:
+ x = np.arange(n, dtype=np.float64)
+
+ # --- Render the base plot first to determine axes bounds ---
+ fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100)
+ fig.patch.set_facecolor("#1e293b")
+ ax.set_facecolor("#0f172a")
+ ax.plot(x, y, color="#ff9800", linewidth=1.2)
+ ax.tick_params(colors="#94a3b8", labelsize=7)
+ for spine in ax.spines.values():
+ spine.set_color("#334155")
+ ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
+ fig.tight_layout(pad=0.4)
+
+ # Force a draw so transforms are valid
+ fig.canvas.draw()
+
+ # Get axes position in figure-fraction coordinates
+ ax_pos = ax.get_position()
+ ax_l, ax_b = ax_pos.x0, ax_pos.y0
+ ax_w, ax_h = ax_pos.width, ax_pos.height
+
+ # x1/y1 arrive as image-fraction from the frontend drag.
+ # Convert image-fraction x → axes-fraction → nearest data index.
+ def img_x_to_idx(ix):
+ axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
+ return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
+
+ idx_a = img_x_to_idx(x1)
+ idx_b = img_x_to_idx(x2)
+
+ xa, ya = float(x[idx_a]), float(y[idx_a])
+ xb, yb = float(x[idx_b]), float(y[idx_b])
+
+ # --- Draw cursor lines and markers on the plot ---
+ ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
+ ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
+ ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
+ ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
+ ax.annotate(
+ "", xy=(xb, yb), xytext=(xa, ya),
+ arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
+ )
+
+ # --- Broadcast overlay ---
+ if LineCursors._broadcast_overlay_fn is not None:
+ # Convert data-space positions back to image-fraction for markers
+ fig.canvas.draw()
+ inv = fig.transFigure.inverted()
+ fig_a = inv.transform(ax.transData.transform([xa, ya]))
+ fig_b = inv.transform(ax.transData.transform([xb, yb]))
+
+ buf = _io.BytesIO()
+ fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
+ buf.seek(0)
+ image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
+
+ LineCursors._broadcast_overlay_fn(
+ LineCursors._current_node_id,
+ {
+ "image": image_uri,
+ "x1": float(fig_a[0]),
+ "y1": float(1.0 - fig_a[1]), # flip: image y=0 is top
+ "x2": float(fig_b[0]),
+ "y2": float(1.0 - fig_b[1]),
+ "a_locked": False,
+ "b_locked": False,
+ },
+ )
+
+ plt.close(fig)
+
+ # --- Output table ---
+ table = [
+ {"quantity": "A position", "value": xa, "unit": ""},
+ {"quantity": "A value", "value": ya, "unit": ""},
+ {"quantity": "B position", "value": xb, "unit": ""},
+ {"quantity": "B value", "value": yb, "unit": ""},
+ {"quantity": "delta X", "value": xb - xa, "unit": ""},
+ {"quantity": "delta Y", "value": yb - ya, "unit": ""},
+ ]
+ return (table,)
# ---------------------------------------------------------------------------
@@ -242,9 +380,9 @@ class CrossSection:
return {
"required": {
"field": ("DATA_FIELD",),
- "x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
+ "x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
- "x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
+ "x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"extend": (["none", "to_edges"],),
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
diff --git a/backend/nodes/filters.py b/backend/nodes/filters.py
index 783a04b..8a5f872 100644
--- a/backend/nodes/filters.py
+++ b/backend/nodes/filters.py
@@ -5,6 +5,8 @@ Gwyddion equivalents:
GaussianFilter → gwy_data_field_filter_gaussian
MedianFilter → gwy_data_field_filter_median
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
+ FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles)
+ FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs)
"""
from __future__ import annotations
@@ -113,3 +115,190 @@ class EdgeDetect:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)
+
+
+# ---------------------------------------------------------------------------
+# Butterworth transfer function helpers
+# ---------------------------------------------------------------------------
+
+def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
+ """Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n))."""
+ with np.errstate(divide="ignore", over="ignore"):
+ return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
+
+
+def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
+ """Butterworth highpass: H = 1 / (1 + (fc/f)^(2n))."""
+ with np.errstate(divide="ignore", invalid="ignore"):
+ h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
+ h = np.where(np.isfinite(h), h, 0.0)
+ return h
+
+
+def _build_1d_transfer(n: int, filter_type: str, cutoff: float,
+ cutoff_high: float, order: int) -> np.ndarray:
+ """Build a 1-D transfer function for an FFT of length *n*.
+
+ Frequencies are normalised so that 1.0 = Nyquist (fs/2).
+ The returned array has the same layout as np.fft.rfft output (length n//2+1).
+ """
+ freq = np.linspace(0, 1, n // 2 + 1)
+
+ if filter_type == "lowpass":
+ H = _butterworth_lp(freq, cutoff, order)
+ elif filter_type == "highpass":
+ H = _butterworth_hp(freq, cutoff, order)
+ elif filter_type == "bandpass":
+ H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
+ elif filter_type == "notch":
+ bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
+ H = 1.0 - bp
+ else:
+ H = np.ones_like(freq)
+ return H
+
+
+# ---------------------------------------------------------------------------
+# FFTFilter1D — frequency-domain filtering of LINE profiles
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="1D FFT Filter")
+class FFTFilter1D:
+ """Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
+
+ Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
+ transfer function with configurable order for a smooth roll-off.
+ """
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "line": ("LINE",),
+ "filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
+ "cutoff": ("FLOAT", {
+ "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
+ }),
+ "cutoff_high": ("FLOAT", {
+ "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
+ }),
+ "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
+ }
+ }
+
+ RETURN_TYPES = ("LINE",)
+ RETURN_NAMES = ("filtered",)
+ FUNCTION = "process"
+ CATEGORY = "filters"
+ DESCRIPTION = (
+ "Frequency-domain filtering of a 1-D line profile. "
+ "Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
+ "with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
+ "Equivalent to Gwyddion fft_filter_1d."
+ )
+
+ def process(self, line, filter_type: str, cutoff: float,
+ cutoff_high: float, order: int) -> tuple:
+ z = np.asarray(line, dtype=np.float64).ravel()
+ n = len(z)
+
+ # Forward FFT (real-valued)
+ Z = np.fft.rfft(z)
+
+ # Build and apply transfer function
+ H = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
+ Z *= H
+
+ # Inverse FFT
+ filtered = np.fft.irfft(Z, n=n)
+ return (filtered,)
+
+
+# ---------------------------------------------------------------------------
+# FFTFilter2D — frequency-domain filtering of DATA_FIELDs
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="2D FFT Filter")
+class FFTFilter2D:
+ """Frequency-domain filtering of 2-D data fields (images).
+
+ Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
+ Butterworth transfer function in the frequency domain to remove or
+ isolate periodic features.
+ """
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "field": ("DATA_FIELD",),
+ "filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
+ "cutoff": ("FLOAT", {
+ "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
+ }),
+ "cutoff_high": ("FLOAT", {
+ "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
+ }),
+ "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
+ }
+ }
+
+ RETURN_TYPES = ("DATA_FIELD",)
+ RETURN_NAMES = ("filtered",)
+ FUNCTION = "process"
+ CATEGORY = "filters"
+ DESCRIPTION = (
+ "Frequency-domain filtering of a 2-D data field. "
+ "Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
+ "with a radial Butterworth roll-off. Cutoffs are fractions of the "
+ "Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
+ "bandpass/notch to isolate or remove periodic noise. "
+ "Equivalent to Gwyddion fft_filter_2d."
+ )
+
+ def process(self, field: DataField, filter_type: str, cutoff: float,
+ cutoff_high: float, order: int) -> tuple:
+ data = field.data.copy()
+ yres, xres = data.shape
+
+ # Subtract mean to avoid DC leakage artefacts
+ mean_val = data.mean()
+ data -= mean_val
+
+ # Forward 2D FFT
+ F = np.fft.fft2(data)
+ F = np.fft.fftshift(F)
+
+ # Build radial frequency grid normalised to [0, 1] (1 = Nyquist)
+ fy = np.fft.fftshift(np.fft.fftfreq(yres)) # range [-0.5, 0.5)
+ fx = np.fft.fftshift(np.fft.fftfreq(xres))
+ FX, FY = np.meshgrid(fx, fy)
+ # Normalise so that corner = 1 in each axis independently,
+ # then take Euclidean norm; max radial value = 1.0 at Nyquist.
+ R = np.sqrt((FX / 0.5) ** 2 + (FY / 0.5) ** 2)
+ R = np.clip(R / R.max(), 0, 1) if R.max() > 0 else R
+
+ # Build transfer function
+ if filter_type == "lowpass":
+ H = _butterworth_lp(R, cutoff, order)
+ elif filter_type == "highpass":
+ H = _butterworth_hp(R, cutoff, order)
+ elif filter_type == "bandpass":
+ H = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
+ elif filter_type == "notch":
+ bp = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
+ H = 1.0 - bp
+ else:
+ H = np.ones_like(R)
+
+ # Apply filter
+ F *= H
+
+ # Inverse FFT
+ F = np.fft.ifftshift(F)
+ result = np.fft.ifft2(F).real
+
+ # Restore DC
+ result += mean_val
+
+ return (field.replace(data=result),)
diff --git a/backend/nodes/grains.py b/backend/nodes/grains.py
index 6228955..7ef5219 100644
--- a/backend/nodes/grains.py
+++ b/backend/nodes/grains.py
@@ -1,9 +1,8 @@
"""
-Grain/feature detection nodes.
+Particle detection nodes.
Gwyddion equivalents:
- ThresholdMask → threshold.c / otsu_threshold.c
- GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
+ ParticleAnalysis → gwy_data_field_grains_get_values (grains-values.c)
"""
from __future__ import annotations
@@ -13,61 +12,11 @@ from backend.data_types import DataField
# ---------------------------------------------------------------------------
-# ThresholdMask
+# ParticleAnalysis
# ---------------------------------------------------------------------------
-@register_node(display_name="Threshold Mask")
-class ThresholdMask:
- @classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "field": ("DATA_FIELD",),
- "method": (["otsu", "absolute", "relative"],),
- "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
- "direction": (["above", "below"],),
- }
- }
-
- RETURN_TYPES = ("IMAGE",)
- RETURN_NAMES = ("mask",)
- FUNCTION = "process"
- CATEGORY = "grains"
- DESCRIPTION = (
- "Create a binary mask by thresholding data. "
- "Otsu automatically finds the optimal threshold. "
- "Equivalent to Gwyddion's threshold and otsu_threshold modules."
- )
-
- def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
- data = field.data
-
- if method == "otsu":
- from skimage.filters import threshold_otsu
- t = threshold_otsu(data)
- elif method == "absolute":
- t = float(threshold)
- elif method == "relative":
- # threshold is a fraction [0, 1] of the data range
- dmin, dmax = data.min(), data.max()
- t = dmin + float(threshold) * (dmax - dmin)
- else:
- raise ValueError(f"Unknown threshold method: {method}")
-
- if direction == "above":
- mask = (data >= t).astype(np.uint8) * 255
- else:
- mask = (data < t).astype(np.uint8) * 255
-
- return (mask,)
-
-
-# ---------------------------------------------------------------------------
-# GrainAnalysis
-# ---------------------------------------------------------------------------
-
-@register_node(display_name="Grain Analysis")
-class GrainAnalysis:
+@register_node(display_name="Particle Analysis")
+class ParticleAnalysis:
@classmethod
def INPUT_TYPES(cls):
return {
@@ -79,43 +28,43 @@ class GrainAnalysis:
}
RETURN_TYPES = ("TABLE",)
- RETURN_NAMES = ("grain_stats",)
+ RETURN_NAMES = ("particle_stats",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
- "Label connected grain regions in a binary mask and compute per-grain statistics: "
- "area, equivalent diameter, mean/max height, bounding box. "
+ "Label connected particle regions in a binary mask and compute per-particle "
+ "statistics: area, equivalent diameter, mean/max height, bounding box. "
"Equivalent to gwy_data_field_grains_get_values."
)
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
- from scipy.ndimage import label, find_objects
+ from scipy.ndimage import label
binary = (mask > 127).astype(np.int32)
- labeled, n_grains = label(binary)
+ labeled, n_particles = label(binary)
pixel_area = field.dx * field.dy # m^2 per pixel
rows = []
- for grain_id in range(1, n_grains + 1):
- grain_pixels = labeled == grain_id
- area_px = int(grain_pixels.sum())
+ for pid in range(1, n_particles + 1):
+ particle_pixels = labeled == pid
+ area_px = int(particle_pixels.sum())
if area_px < min_size:
continue
area_m2 = area_px * pixel_area
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
- heights = field.data[grain_pixels]
+ heights = field.data[particle_pixels]
mean_h = float(heights.mean())
max_h = float(heights.max())
# Bounding box
- ys, xs = np.where(grain_pixels)
+ ys, xs = np.where(particle_pixels)
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
rows.append({
- "grain_id": grain_id,
+ "particle_id": pid,
"area_px": area_px,
"area_m2": area_m2,
"equiv_diam_m": equiv_diam,
diff --git a/backend/nodes/io.py b/backend/nodes/io.py
index 40cfae9..7a44f46 100644
--- a/backend/nodes/io.py
+++ b/backend/nodes/io.py
@@ -9,12 +9,16 @@ from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview, image_to_uint8
-from backend.runtime_paths import input_dir, output_dir
+from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
+DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
+_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
+ ".gwy", ".sxm", ".ibw"}
+
# ---------------------------------------------------------------------------
# LoadImage
@@ -68,6 +72,81 @@ class LoadImage:
return (arr, field)
+# ---------------------------------------------------------------------------
+# LoadDemo
+# ---------------------------------------------------------------------------
+
+def _list_demo_files() -> list[str]:
+ """Return sorted list of demo filenames available in the demo/ directory."""
+ if not DEMO_DIR.exists():
+ return []
+ return sorted(
+ f.name for f in DEMO_DIR.iterdir()
+ if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
+ )
+
+
+@register_node(display_name="Load Demo Image")
+class LoadDemo:
+ @classmethod
+ def INPUT_TYPES(cls):
+ choices = _list_demo_files() or ["(no demo images found)"]
+ return {
+ "required": {
+ "name": (choices,),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE", "DATA_FIELD")
+ RETURN_NAMES = ("image", "field")
+ FUNCTION = "load"
+ CATEGORY = "io"
+ DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data."
+
+ def load(self, name: str):
+ path = DEMO_DIR / name
+ if not path.exists():
+ raise FileNotFoundError(f"Demo image not found: {name}")
+
+ ext = path.suffix.lower()
+
+ # SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD
+ if ext == ".gwy":
+ field = LoadSPM()._load_gwy(path, "Z")
+ arr = field.data
+ return (arr, field)
+ elif ext == ".sxm":
+ field = LoadSPM()._load_sxm(path, "Z")
+ arr = field.data
+ return (arr, field)
+ elif ext == ".ibw":
+ field = LoadSPM()._load_ibw(path)
+ arr = field.data
+ return (arr, field)
+
+ # npy / npz
+ if ext == ".npy":
+ arr = np.load(str(path)).astype(np.float64)
+ elif ext == ".npz":
+ npz = np.load(str(path))
+ key = list(npz.files)[0]
+ arr = npz[key].astype(np.float64)
+ else:
+ from PIL import Image
+ img = Image.open(str(path))
+ arr = np.array(img)
+ if arr.dtype != np.uint8:
+ arr = arr.astype(np.float64)
+
+ if arr.ndim == 3:
+ gray = np.mean(arr.astype(np.float64), axis=2)
+ else:
+ gray = arr.astype(np.float64)
+
+ field = DataField(data=gray)
+ return (arr, field)
+
+
# ---------------------------------------------------------------------------
# LoadSPM
# ---------------------------------------------------------------------------
diff --git a/backend/nodes/mask.py b/backend/nodes/mask.py
new file mode 100644
index 0000000..9203488
--- /dev/null
+++ b/backend/nodes/mask.py
@@ -0,0 +1,273 @@
+"""
+Mask operation nodes — creation, morphology, and boolean combination.
+
+Gwyddion equivalents:
+ ThresholdMask → threshold.c / otsu_threshold.c
+ MaskMorphology → mask_morph.c (erode, dilate, open, close)
+ MaskInvert → (bitwise NOT on mask)
+ MaskCombine → (boolean ops between two masks)
+"""
+
+from __future__ import annotations
+import numpy as np
+from backend.node_registry import register_node
+from backend.data_types import DataField, datafield_to_uint8, encode_preview
+
+
+def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
+ """Render greyscale base image with red shadow on masked (255) pixels.
+
+ Returns (H, W, 3) uint8 array.
+ """
+ grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8
+ overlay = grey.astype(np.float64)
+ mask_bool = mask == 255
+ alpha = 0.45
+ overlay[mask_bool, 0] = overlay[mask_bool, 0] * (1 - alpha) + 255 * alpha
+ overlay[mask_bool, 1] = overlay[mask_bool, 1] * (1 - alpha)
+ overlay[mask_bool, 2] = overlay[mask_bool, 2] * (1 - alpha)
+ return np.clip(overlay, 0, 255).astype(np.uint8)
+
+
+# ---------------------------------------------------------------------------
+# ThresholdMask
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="Threshold Mask")
+class ThresholdMask:
+ _CUSTOM_PREVIEW = True
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "field": ("DATA_FIELD",),
+ "method": (["otsu", "absolute", "relative"],),
+ "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
+ "direction": (["above", "below"],),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "process"
+ CATEGORY = "mask"
+ DESCRIPTION = (
+ "Create a binary mask by thresholding data. "
+ "Otsu automatically finds the optimal threshold. "
+ "Equivalent to Gwyddion's threshold and otsu_threshold modules."
+ )
+
+ _broadcast_fn = None
+ _current_node_id: str = ""
+
+ def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
+ data = field.data
+
+ if method == "otsu":
+ from skimage.filters import threshold_otsu
+ t = threshold_otsu(data)
+ elif method == "absolute":
+ t = float(threshold)
+ elif method == "relative":
+ # threshold is a fraction [0, 1] of the data range
+ dmin, dmax = data.min(), data.max()
+ t = dmin + float(threshold) * (dmax - dmin)
+ else:
+ raise ValueError(f"Unknown threshold method: {method}")
+
+ if direction == "above":
+ mask = (data >= t).astype(np.uint8) * 255
+ else:
+ mask = (data < t).astype(np.uint8) * 255
+
+ if ThresholdMask._broadcast_fn is not None:
+ overlay = _mask_overlay(field, mask)
+ ThresholdMask._broadcast_fn(
+ ThresholdMask._current_node_id, encode_preview(overlay),
+ )
+
+ return (mask,)
+
+
+# ---------------------------------------------------------------------------
+# MaskMorphology
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="Mask Morphology")
+class MaskMorphology:
+ """Morphological operations on binary masks.
+
+ Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
+ """
+ _CUSTOM_PREVIEW = True
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "mask": ("IMAGE",),
+ "operation": (["dilate", "erode", "open", "close"],),
+ "radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
+ "shape": (["disk", "square"],),
+ },
+ "optional": {
+ "field": ("DATA_FIELD",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "process"
+ CATEGORY = "mask"
+ DESCRIPTION = (
+ "Apply morphological operations to a binary mask. "
+ "Dilate expands regions, erode shrinks them, "
+ "open (erode then dilate) removes small spots, "
+ "close (dilate then erode) fills small holes. "
+ "Equivalent to Gwyddion mask_morph."
+ )
+
+ _broadcast_fn = None
+ _current_node_id: str = ""
+
+ def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
+ field: DataField | None = None) -> tuple:
+ from scipy.ndimage import binary_dilation, binary_erosion
+
+ binary = mask > 127
+
+ if shape == "disk":
+ y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
+ struct = (x * x + y * y) <= radius * radius
+ else:
+ size = 2 * radius + 1
+ struct = np.ones((size, size), dtype=bool)
+
+ if operation == "dilate":
+ result = binary_dilation(binary, structure=struct)
+ elif operation == "erode":
+ result = binary_erosion(binary, structure=struct)
+ elif operation == "open":
+ result = binary_dilation(
+ binary_erosion(binary, structure=struct),
+ structure=struct,
+ )
+ elif operation == "close":
+ result = binary_erosion(
+ binary_dilation(binary, structure=struct),
+ structure=struct,
+ )
+ else:
+ raise ValueError(f"Unknown morphological operation: {operation}")
+
+ out = result.astype(np.uint8) * 255
+
+ if field is not None and MaskMorphology._broadcast_fn is not None:
+ overlay = _mask_overlay(field, out)
+ MaskMorphology._broadcast_fn(
+ MaskMorphology._current_node_id, encode_preview(overlay),
+ )
+
+ return (out,)
+
+
+# ---------------------------------------------------------------------------
+# MaskInvert
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="Mask Invert")
+class MaskInvert:
+ _CUSTOM_PREVIEW = True
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "mask": ("IMAGE",),
+ },
+ "optional": {
+ "field": ("DATA_FIELD",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "process"
+ CATEGORY = "mask"
+ DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
+
+ _broadcast_fn = None
+ _current_node_id: str = ""
+
+ def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
+ out = np.where(mask > 127, np.uint8(0), np.uint8(255))
+
+ if field is not None and MaskInvert._broadcast_fn is not None:
+ overlay = _mask_overlay(field, out)
+ MaskInvert._broadcast_fn(
+ MaskInvert._current_node_id, encode_preview(overlay),
+ )
+
+ return (out,)
+
+
+# ---------------------------------------------------------------------------
+# MaskCombine
+# ---------------------------------------------------------------------------
+
+@register_node(display_name="Mask Combine")
+class MaskCombine:
+ _CUSTOM_PREVIEW = True
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "mask_a": ("IMAGE",),
+ "mask_b": ("IMAGE",),
+ "operation": (["and", "or", "xor", "subtract"],),
+ },
+ "optional": {
+ "field": ("DATA_FIELD",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("mask",)
+ FUNCTION = "process"
+ CATEGORY = "mask"
+ DESCRIPTION = (
+ "Combine two binary masks with a boolean operation. "
+ "AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
+ "subtract removes mask_b from mask_a."
+ )
+
+ _broadcast_fn = None
+ _current_node_id: str = ""
+
+ def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
+ field: DataField | None = None) -> tuple:
+ a = mask_a > 127
+ b = mask_b > 127
+
+ if operation == "and":
+ result = a & b
+ elif operation == "or":
+ result = a | b
+ elif operation == "xor":
+ result = a ^ b
+ elif operation == "subtract":
+ result = a & ~b
+ else:
+ raise ValueError(f"Unknown mask operation: {operation}")
+
+ out = result.astype(np.uint8) * 255
+
+ if field is not None and MaskCombine._broadcast_fn is not None:
+ overlay = _mask_overlay(field, out)
+ MaskCombine._broadcast_fn(
+ MaskCombine._current_node_id, encode_preview(overlay),
+ )
+
+ return (out,)
diff --git a/backend/runtime_paths.py b/backend/runtime_paths.py
index bab42db..7ddbe72 100644
--- a/backend/runtime_paths.py
+++ b/backend/runtime_paths.py
@@ -4,7 +4,7 @@ import os
import sys
from pathlib import Path
-APP_NAME = "Argonode"
+APP_NAME = "argonode"
def project_root() -> Path:
@@ -34,13 +34,26 @@ def app_data_dir() -> Path:
return Path(override).expanduser().resolve()
if getattr(sys, "frozen", False):
- local_appdata = os.getenv("LOCALAPPDATA")
- base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local"
+ if sys.platform == "darwin":
+ base_dir = Path.home() / "Library" / "Application Support"
+ elif sys.platform == "linux":
+ xdg = os.getenv("XDG_DATA_HOME")
+ base_dir = Path(xdg) if xdg else Path.home() / ".local" / "share"
+ else:
+ local_appdata = os.getenv("LOCALAPPDATA")
+ base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local"
return (base_dir / APP_NAME).resolve()
return project_root()
+def demo_dir() -> Path:
+ bundled = resource_root() / "demo"
+ if bundled.exists():
+ return bundled
+ return project_root() / "demo"
+
+
def input_dir() -> Path:
return app_data_dir() / "input"
diff --git a/demo/.gitkeep b/demo/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/demo/nanoparticles.npy b/demo/nanoparticles.npy
new file mode 100644
index 0000000..d679672
Binary files /dev/null and b/demo/nanoparticles.npy differ
diff --git a/demo/nanoparticles.png b/demo/nanoparticles.png
new file mode 100644
index 0000000..dda6204
Binary files /dev/null and b/demo/nanoparticles.png differ
diff --git a/desktop.py b/desktop.py
index 4c70508..dd9145b 100644
--- a/desktop.py
+++ b/desktop.py
@@ -14,7 +14,34 @@ from backend.runtime_paths import app_data_dir, ensure_runtime_dirs
from backend.server import create_app
HOST = "127.0.0.1"
-WINDOW_TITLE = "Argonode"
+WINDOW_TITLE = "argonode"
+
+
+class _Api:
+ """Exposed to JavaScript as window.pywebview.api."""
+
+ def __init__(self, window_ref: list):
+ self._window_ref = window_ref
+
+ def open_file_dialog(self) -> str | None:
+ """Open a native file picker and return the selected path (or None)."""
+ win = self._window_ref[0]
+ if win is None:
+ return None
+ result = win.create_file_dialog(
+ webview.OPEN_DIALOG,
+ allow_multiple=False,
+ file_types=(
+ "All supported (*.png;*.jpg;*.jpeg;*.tiff;*.tif;*.npy;*.npz;*.gwy;*.sxm;*.ibw)",
+ "Images (*.png;*.jpg;*.jpeg;*.tiff;*.tif)",
+ "NumPy (*.npy;*.npz)",
+ "SPM (*.gwy;*.sxm;*.ibw)",
+ "All files (*.*)",
+ ),
+ )
+ if result and len(result) > 0:
+ return result[0]
+ return None
def _pick_free_port() -> int:
@@ -85,17 +112,22 @@ def main() -> None:
ready.wait(timeout=15.0)
if "error" in state:
- raise RuntimeError("Argonode server failed to start") from state["error"]
+ raise RuntimeError("argonode server failed to start") from state["error"]
_wait_for_server(f"{base_url}/nodes")
+ window_ref: list[webview.Window | None] = [None]
+ js_api = _Api(window_ref)
+
window = webview.create_window(
WINDOW_TITLE,
base_url,
width=1600,
height=1000,
min_size=(1100, 720),
+ js_api=js_api,
)
+ window_ref[0] = window
def _shutdown() -> None:
loop = state.get("loop")
diff --git a/frontend/index.html b/frontend/index.html
index 7418eec..146d241 100644
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -3,7 +3,7 @@
- Argonode — Image Analysis
+ argonode — Image Analysis
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 471f968..98af0fe 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -7,6 +7,7 @@
"name": "argonode-frontend",
"dependencies": {
"@xyflow/react": "^12.0.0",
+ "html-to-image": "^1.11.13",
"react": "^18.3.0",
"react-dom": "^18.3.0",
"three": "^0.183.2"
@@ -1539,6 +1540,12 @@
"node": ">=6.9.0"
}
},
+ "node_modules/html-to-image": {
+ "version": "1.11.13",
+ "resolved": "https://registry.npmjs.org/html-to-image/-/html-to-image-1.11.13.tgz",
+ "integrity": "sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==",
+ "license": "MIT"
+ },
"node_modules/js-tokens": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
diff --git a/frontend/package.json b/frontend/package.json
index b489587..2082038 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -13,6 +13,7 @@
},
"dependencies": {
"@xyflow/react": "^12.0.0",
+ "html-to-image": "^1.11.13",
"react": "^18.3.0",
"react-dom": "^18.3.0",
"three": "^0.183.2"
diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx
index a75acc7..5a818a4 100644
--- a/frontend/src/App.jsx
+++ b/frontend/src/App.jsx
@@ -4,24 +4,26 @@ import React, {
import {
ReactFlow, Background, Controls, MiniMap,
useNodesState, useEdgesState, addEdge, useReactFlow,
- ReactFlowProvider,
+ ReactFlowProvider, getNodesBounds, getViewportForBounds,
} from '@xyflow/react';
import '@xyflow/react/dist/style.css';
import CustomNode, { NodeContext } from './CustomNode';
import FileBrowser from './FileBrowser';
import * as api from './api';
+import { toBlob } from 'html-to-image';
+import { embedWorkflow, extractWorkflow } from './pngMetadata';
// ── Constants ─────────────────────────────────────────────────────────
const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']);
const TYPE_COLORS = {
- DATA_FIELD: '#3a7abf',
- IMAGE: '#4caf50',
- LINE: '#ff9800',
- TABLE: '#fdd835',
- COORD: '#e91e63',
+ DATA_FIELD: '#ff002f',
+ IMAGE: '#00ff08a0',
+ LINE: '#ffbe5c',
+ TABLE: '#35e2fd',
+ COORD: '#e91ed1',
};
const NODE_TYPES = { custom: CustomNode };
@@ -272,6 +274,13 @@ function Flow() {
// ── File browser ────────────────────────────────────────────────────
const openFileBrowser = useCallback((callback) => {
+ // Use native file picker when running inside pywebview (desktop app)
+ if (window.pywebview?.api?.open_file_dialog) {
+ window.pywebview.api.open_file_dialog().then((path) => {
+ if (path) callback(path);
+ });
+ return;
+ }
setFileBrowserCb(() => callback);
}, []);
@@ -427,60 +436,162 @@ function Flow() {
setStatus({ text: 'Graph cleared.', level: 'info' });
}, [setNodes, setEdges]);
- const saveWorkflow = useCallback(() => {
- const currentNodes = reactFlow.getNodes().map((n) => ({
+ const applyWorkflowData = useCallback((data) => {
+ const loadedNodes = data.nodes || [];
+ const loadedEdges = data.edges || [];
+ const defs = nodeDefsRef.current;
+ const hydrated = loadedNodes.map((n) => ({
+ ...n,
+ data: {
+ ...n.data,
+ definition: defs[n.data.className] || n.data.definition,
+ previewImage: null, tableRows: null, meshData: null, overlay: null,
+ },
+ }));
+ setNodes(hydrated);
+ setEdges(loadedEdges);
+ const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
+ nextIdRef.current = maxId + 1;
+ }, [setNodes, setEdges]);
+
+ const getWorkflowBlob = useCallback(async () => {
+ const viewportEl = document.querySelector('.react-flow__viewport');
+ if (!viewportEl) throw new Error('Flow element not found');
+
+ const allNodes = reactFlow.getNodes();
+ if (allNodes.length === 0) throw new Error('No nodes to capture');
+
+ const bounds = getNodesBounds(allNodes);
+ const pad = 0.1; // 10% margin on each side
+ const imageWidth = Math.ceil(bounds.width * (1 + pad * 2));
+ const imageHeight = Math.ceil(bounds.height * (1 + pad * 2));
+ const vp = getViewportForBounds(bounds, imageWidth, imageHeight, 0.5, 1, pad);
+
+ const blob = await toBlob(viewportEl, {
+ backgroundColor: '#1a1a1a',
+ width: imageWidth,
+ height: imageHeight,
+ style: {
+ width: `${imageWidth}px`,
+ height: `${imageHeight}px`,
+ transform: `translate(${vp.x}px, ${vp.y}px) scale(${vp.zoom})`,
+ },
+ });
+ if (!blob) throw new Error('Capture returned empty');
+
+ const currentNodes = allNodes.map((n) => ({
...n,
data: { ...n.data, previewImage: null, tableRows: null, meshData: null, overlay: null },
}));
- const data = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() };
- const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
- const a = document.createElement('a');
- a.href = URL.createObjectURL(blob);
- a.download = 'workflow.json';
- a.click();
+ const workflow = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() };
+ return embedWorkflow(blob, workflow);
}, [reactFlow]);
+ const saveWorkflow = useCallback(async () => {
+ setStatus({ text: 'Saving…', level: 'info' });
+ try {
+ const finalBlob = await getWorkflowBlob();
+
+ if (window.showSaveFilePicker) {
+ const handle = await window.showSaveFilePicker({
+ suggestedName: 'workflow.png',
+ types: [{ description: 'PNG Image', accept: { 'image/png': ['.png'] } }],
+ });
+ const writable = await handle.createWritable();
+ await writable.write(finalBlob);
+ await writable.close();
+ } else {
+ // Fallback: programmatic download
+ const a = document.createElement('a');
+ a.href = URL.createObjectURL(finalBlob);
+ a.download = 'workflow.png';
+ document.body.appendChild(a);
+ a.click();
+ document.body.removeChild(a);
+ URL.revokeObjectURL(a.href);
+ }
+
+ setStatus({ text: 'Workflow saved.', level: 'info' });
+ } catch (err) {
+ if (err.name === 'AbortError') {
+ setStatus({ text: 'Save cancelled.', level: 'info' });
+ } else {
+ setStatus({ text: 'Save failed: ' + err.message, level: 'error' });
+ }
+ }
+ }, [getWorkflowBlob]);
+
+ const copySnapshot = useCallback(() => {
+ setStatus({ text: 'Copying snapshot…', level: 'info' });
+ // Pass a Promise to ClipboardItem so the clipboard.write() call
+ // happens synchronously within the user gesture, avoiding permission errors.
+ const blobPromise = getWorkflowBlob().catch((err) => {
+ setStatus({ text: 'Snapshot failed: ' + err.message, level: 'error' });
+ throw err;
+ });
+ navigator.clipboard.write([new ClipboardItem({ 'image/png': blobPromise })]).then(() => {
+ setStatus({ text: 'Snapshot copied to clipboard.', level: 'info' });
+ }).catch((err) => {
+ setStatus({ text: 'Copy failed: ' + err.message, level: 'error' });
+ });
+ }, [getWorkflowBlob]);
+
const loadWorkflow = useCallback(() => {
const input = document.createElement('input');
input.type = 'file';
- input.accept = '.json';
+ input.accept = '.json,.png';
input.onchange = async (e) => {
const file = e.target.files[0];
if (!file) return;
- const text = await file.text();
try {
- const data = JSON.parse(text);
- const loadedNodes = data.nodes || [];
- const loadedEdges = data.edges || [];
-
- // Re-populate definitions from current nodeDefs
- const defs = nodeDefsRef.current;
- const hydrated = loadedNodes.map((n) => ({
- ...n,
- data: {
- ...n.data,
- definition: defs[n.data.className] || n.data.definition,
- previewImage: null,
- tableRows: null,
- meshData: null,
- overlay: null,
- },
- }));
-
- setNodes(hydrated);
- setEdges(loadedEdges);
-
- // Update ID counter to avoid collisions
- const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
- nextIdRef.current = maxId + 1;
-
+ let data;
+ if (file.name.endsWith('.png') || file.type === 'image/png') {
+ data = await extractWorkflow(file);
+ if (!data) {
+ setStatus({ text: 'No workflow data found in image.', level: 'error' });
+ return;
+ }
+ } else {
+ data = JSON.parse(await file.text());
+ }
+ applyWorkflowData(data);
setStatus({ text: 'Workflow loaded.', level: 'info' });
} catch {
- setStatus({ text: 'Invalid workflow JSON.', level: 'error' });
+ setStatus({ text: 'Invalid workflow file.', level: 'error' });
}
};
input.click();
- }, [setNodes, setEdges]);
+ }, [applyWorkflowData]);
+
+ // ── Drag-and-drop workflow image loading ───────────────────────────
+
+ const onDropFile = useCallback(async (event) => {
+ const files = event.dataTransfer?.files;
+ if (!files || files.length === 0) return;
+ event.preventDefault();
+
+ const file = files[0];
+ if (file.type !== 'image/png') return;
+
+ try {
+ const data = await extractWorkflow(file);
+ if (!data) {
+ setStatus({ text: 'No workflow data in this image.', level: 'error' });
+ return;
+ }
+ applyWorkflowData(data);
+ setStatus({ text: 'Workflow loaded from image.', level: 'info' });
+ } catch (err) {
+ setStatus({ text: 'Failed to load: ' + err.message, level: 'error' });
+ }
+ }, [applyWorkflowData]);
+
+ const onDragOver = useCallback((event) => {
+ if (event.dataTransfer?.types?.includes('Files')) {
+ event.preventDefault();
+ event.dataTransfer.dropEffect = 'copy';
+ }
+ }, []);
// ── Keyboard shortcut ───────────────────────────────────────────────
@@ -509,7 +620,7 @@ function Flow() {
{/* Toolbar */}
-
Argonode
+
argonode
-
{status.text}
@@ -535,7 +649,7 @@ function Flow() {
{/* React Flow canvas */}
{
if (!e.target.closest('.context-menu')) setContextMenu(null);
- }}>
+ }} onDrop={onDropFile} onDragOver={onDragOver}>
>> 1)) : (c >>> 1);
+ }
+ crcTable[i] = c;
+}
+
+function crc32(bytes) {
+ let crc = 0xFFFFFFFF;
+ for (let i = 0; i < bytes.length; i++) {
+ crc = crcTable[(crc ^ bytes[i]) & 0xFF] ^ (crc >>> 8);
+ }
+ return (crc ^ 0xFFFFFFFF) >>> 0;
+}
+
+// ── Helpers ──────────────────────────────────────────────────────────
+
+const PNG_SIG = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
+
+function isPng(data) {
+ if (data.length < 8) return false;
+ for (let i = 0; i < 8; i++) {
+ if (data[i] !== PNG_SIG[i]) return false;
+ }
+ return true;
+}
+
+function chunkType(data, offset) {
+ return String.fromCharCode(
+ data[offset + 4], data[offset + 5], data[offset + 6], data[offset + 7],
+ );
+}
+
+// ── Public API ───────────────────────────────────────────────────────
+
+/**
+ * Embed a workflow object into a PNG blob as a tEXt chunk.
+ * Returns a new Blob with the metadata inserted before IEND.
+ */
+export async function embedWorkflow(pngBlob, workflow) {
+ const data = new Uint8Array(await pngBlob.arrayBuffer());
+ if (!isPng(data)) throw new Error('Not a valid PNG file');
+
+ const encoder = new TextEncoder();
+
+ // Build tEXt payload: keyword \0 text
+ const key = encoder.encode('workflow');
+ const val = encoder.encode(JSON.stringify(workflow));
+ const payload = new Uint8Array(key.length + 1 + val.length);
+ payload.set(key, 0);
+ // payload[key.length] is already 0 (null separator)
+ payload.set(val, key.length + 1);
+
+ // CRC covers type + payload
+ const typeBytes = encoder.encode('tEXt');
+ const forCrc = new Uint8Array(4 + payload.length);
+ forCrc.set(typeBytes, 0);
+ forCrc.set(payload, 4);
+
+ // Assemble chunk: length(4) + type(4) + payload + crc(4)
+ const chunk = new Uint8Array(12 + payload.length);
+ const view = new DataView(chunk.buffer);
+ view.setUint32(0, payload.length);
+ chunk.set(typeBytes, 4);
+ chunk.set(payload, 8);
+ view.setUint32(8 + payload.length, crc32(forCrc));
+
+ // Locate IEND
+ let pos = 8;
+ let iendPos = data.length;
+ while (pos < data.length) {
+ const len = new DataView(data.buffer, pos, 4).getUint32(0);
+ if (chunkType(data, pos) === 'IEND') { iendPos = pos; break; }
+ pos += 12 + len;
+ }
+
+ // Splice: [before IEND] + [tEXt chunk] + [IEND]
+ const result = new Uint8Array(data.length + chunk.length);
+ result.set(data.subarray(0, iendPos), 0);
+ result.set(chunk, iendPos);
+ result.set(data.subarray(iendPos), iendPos + chunk.length);
+
+ return new Blob([result], { type: 'image/png' });
+}
+
+/**
+ * Extract the workflow object from a PNG blob's tEXt chunks.
+ * Returns the parsed object, or null if no "workflow" key is found.
+ */
+export async function extractWorkflow(pngBlob) {
+ const data = new Uint8Array(await pngBlob.arrayBuffer());
+ if (!isPng(data)) return null;
+
+ const decoder = new TextDecoder();
+ let pos = 8;
+
+ while (pos + 8 <= data.length) {
+ const len = new DataView(data.buffer, pos, 4).getUint32(0);
+ const type = chunkType(data, pos);
+
+ if (type === 'tEXt' && pos + 8 + len <= data.length) {
+ const chunkData = data.subarray(pos + 8, pos + 8 + len);
+ const nullIdx = chunkData.indexOf(0);
+ if (nullIdx !== -1) {
+ const k = decoder.decode(chunkData.subarray(0, nullIdx));
+ if (k === 'workflow') {
+ return JSON.parse(decoder.decode(chunkData.subarray(nullIdx + 1)));
+ }
+ }
+ }
+
+ if (type === 'IEND') break;
+ pos += 12 + len;
+ }
+
+ return null;
+}
diff --git a/frontend/src/styles.css b/frontend/src/styles.css
index 27a0fab..5b84a67 100644
--- a/frontend/src/styles.css
+++ b/frontend/src/styles.css
@@ -21,8 +21,8 @@ html, body, #root {
/* ── Toolbar ───────────────────────────────────────────────────────── */
#toolbar {
height: 44px;
- background: #16213e;
- border-bottom: 1px solid #0f3460;
+ background: #242424;
+ border-bottom: 1px solid #000000;
display: flex;
align-items: center;
padding: 0 12px;
@@ -36,7 +36,7 @@ html, body, #root {
font-size: 15px;
font-weight: 700;
letter-spacing: 0.5px;
- color: #e94560;
+ color: #ffffff;
margin-right: 8px;
flex-shrink: 0;
}
@@ -129,8 +129,17 @@ html, body, #root {
cursor: grabbing;
}
-.custom-node.selected {
+/* Selected node — target via React Flow's wrapper class */
+.react-flow__node.selected .custom-node {
border-color: #90caf9;
+ box-shadow: 0 0 0 1px #90caf9, 0 0 12px rgba(144, 202, 249, 0.4);
+}
+
+/* Selected edge */
+.react-flow__edge.selected .react-flow__edge-path {
+ stroke: #90caf9 !important;
+ stroke-width: 3px !important;
+ filter: drop-shadow(0 0 4px rgba(144, 202, 249, 0.6));
}
.node-title {
diff --git a/package.json b/package.json
index b27a7b7..5fe014e 100644
--- a/package.json
+++ b/package.json
@@ -12,6 +12,8 @@
"preview": "npm --prefix frontend run preview",
"backend": "python -m backend.main",
"desktop": "python desktop.py",
- "build:desktop": "powershell -ExecutionPolicy Bypass -File scripts\\build-desktop.ps1"
+ "build:windows": "powershell -ExecutionPolicy Bypass -File scripts\\build-windows.ps1",
+ "build:mac": "bash scripts/build-mac.sh",
+ "build:linux": "bash scripts/build-linux.sh"
}
}
diff --git a/pyproject.toml b/pyproject.toml
index 538b571..e83b80c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,10 @@ readme = "GWYDDION_FEATURE_GAP.md"
requires-python = ">=3.10"
dependencies = [
"aiohttp>=3.9,<4",
+ "gwyfile>=0.2",
+ "igor>=0.3",
"matplotlib>=3.8,<4",
+ "nanonispy>=1.1",
"numpy>=1.26,<3",
"pillow>=10,<12",
"scikit-image>=0.22,<1",
@@ -18,11 +21,6 @@ dependencies = [
]
[project.optional-dependencies]
-spm = [
- "gwyfile>=0.2",
- "igor>=0.3",
- "nanonispy>=1.1",
-]
dev = [
"pytest>=8,<9",
]
diff --git a/scripts/build-linux.sh b/scripts/build-linux.sh
new file mode 100755
index 0000000..df926c5
--- /dev/null
+++ b/scripts/build-linux.sh
@@ -0,0 +1,65 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ONE_FILE=false
+CREATE_TAR=true
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --onefile) ONE_FILE=true; shift ;;
+ --no-tar) CREATE_TAR=false; shift ;;
+ *) echo "Unknown option: $1"; exit 1 ;;
+ esac
+done
+
+REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
+cd "$REPO_ROOT"
+
+if [ -d ".venv/bin" ]; then
+ PYTHON=".venv/bin/python"
+else
+ PYTHON="python3"
+fi
+
+FRONTEND_DIST="$REPO_ROOT/frontend/dist"
+DEMO_DIR="$REPO_ROOT/demo"
+
+echo "Building frontend bundle..."
+npm run build
+
+echo "Installing desktop build dependencies..."
+uv pip install -e ".[desktop]"
+
+if $ONE_FILE; then
+ MODE="--onefile"
+else
+ MODE="--onedir"
+fi
+
+echo "Packaging desktop app with PyInstaller..."
+$PYTHON -m PyInstaller \
+ desktop.py \
+ --noconfirm \
+ --clean \
+ --name argonode \
+ --windowed \
+ $MODE \
+ --distpath desktop-dist \
+ --workpath desktop-build \
+ --specpath desktop-build \
+ --add-data "${FRONTEND_DIST}:frontend/dist" \
+ --add-data "${DEMO_DIR}:demo" \
+ --collect-all matplotlib \
+ --collect-all scipy \
+ --collect-all skimage \
+ --collect-all webview
+
+if $CREATE_TAR; then
+ TAR_PATH="desktop-dist/argonode-linux.tar.gz"
+ echo "Creating tarball..."
+ tar -czf "$TAR_PATH" -C desktop-dist argonode
+ echo "Tarball created: $TAR_PATH"
+fi
+
+echo "Desktop build complete."
+echo "Output: $REPO_ROOT/desktop-dist/"
diff --git a/scripts/build-mac.sh b/scripts/build-mac.sh
new file mode 100755
index 0000000..c6f92ff
--- /dev/null
+++ b/scripts/build-mac.sh
@@ -0,0 +1,110 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ONE_FILE=false
+CREATE_DMG=true
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --onefile) ONE_FILE=true; shift ;;
+ --no-dmg) CREATE_DMG=false; shift ;;
+ *) echo "Unknown option: $1"; exit 1 ;;
+ esac
+done
+
+REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
+cd "$REPO_ROOT"
+
+if [ -d ".venv/bin" ]; then
+ PYTHON=".venv/bin/python"
+else
+ PYTHON="python3"
+fi
+
+FRONTEND_DIST="$REPO_ROOT/frontend/dist"
+DEMO_DIR="$REPO_ROOT/demo"
+
+echo "Building frontend bundle..."
+npm run build
+
+echo "Installing desktop build dependencies..."
+uv pip install -e ".[desktop]"
+
+if $ONE_FILE; then
+ MODE="--onefile"
+else
+ MODE="--onedir"
+fi
+
+echo "Packaging desktop app with PyInstaller..."
+$PYTHON -m PyInstaller \
+ desktop.py \
+ --noconfirm \
+ --clean \
+ --name argonode \
+ --windowed \
+ $MODE \
+ --distpath desktop-dist \
+ --workpath desktop-build \
+ --specpath desktop-build \
+ --add-data "${FRONTEND_DIST}:frontend/dist" \
+ --add-data "${DEMO_DIR}:demo" \
+ --collect-all matplotlib \
+ --collect-all scipy \
+ --collect-all skimage \
+ --collect-all webview \
+ --icon resources/icon.icns 2>/dev/null || \
+$PYTHON -m PyInstaller \
+ desktop.py \
+ --noconfirm \
+ --clean \
+ --name argonode \
+ --windowed \
+ $MODE \
+ --distpath desktop-dist \
+ --workpath desktop-build \
+ --specpath desktop-build \
+ --add-data "${FRONTEND_DIST}:frontend/dist" \
+ --add-data "${DEMO_DIR}:demo" \
+ --collect-all matplotlib \
+ --collect-all scipy \
+ --collect-all skimage \
+ --collect-all webview
+
+APP_BUNDLE="desktop-dist/argonode.app"
+
+if [ ! -d "$APP_BUNDLE" ]; then
+ # --onedir puts it inside a folder
+ if [ -d "desktop-dist/argonode/argonode.app" ]; then
+ APP_BUNDLE="desktop-dist/argonode/argonode.app"
+ else
+ echo "Warning: .app bundle not found; skipping DMG creation."
+ CREATE_DMG=false
+ fi
+fi
+
+if $CREATE_DMG; then
+ DMG_PATH="desktop-dist/argonode.dmg"
+ echo "Creating DMG installer..."
+ rm -f "$DMG_PATH"
+
+ # Create a temporary directory for DMG contents
+ DMG_STAGING="desktop-build/dmg-staging"
+ rm -rf "$DMG_STAGING"
+ mkdir -p "$DMG_STAGING"
+ cp -R "$APP_BUNDLE" "$DMG_STAGING/"
+ ln -s /Applications "$DMG_STAGING/Applications"
+
+ hdiutil create \
+ -volname "argonode" \
+ -srcfolder "$DMG_STAGING" \
+ -ov \
+ -format UDZO \
+ "$DMG_PATH"
+
+ rm -rf "$DMG_STAGING"
+ echo "DMG created: $DMG_PATH"
+fi
+
+echo "Desktop build complete."
+echo "Output: $REPO_ROOT/desktop-dist/"
diff --git a/scripts/build-desktop.ps1 b/scripts/build-windows.ps1
similarity index 86%
rename from scripts/build-desktop.ps1
rename to scripts/build-windows.ps1
index ac4c302..b129820 100644
--- a/scripts/build-desktop.ps1
+++ b/scripts/build-windows.ps1
@@ -14,6 +14,7 @@ $pythonExe = if (Test-Path ".\.venv\Scripts\python.exe") {
"python"
}
$frontendDist = Join-Path $repoRoot "frontend\dist"
+$demoDir = Join-Path $repoRoot "demo"
Write-Host "Building frontend bundle..."
npm run build
@@ -28,13 +29,14 @@ $pyInstallerArgs = @(
"desktop.py",
"--noconfirm",
"--clean",
- "--name", "Argonode",
+ "--name", "argonode",
"--windowed",
$mode,
"--distpath", "desktop-dist",
"--workpath", "desktop-build",
"--specpath", "desktop-build",
"--add-data", "${frontendDist};frontend/dist",
+ "--add-data", "${demoDir};demo",
"--collect-all", "matplotlib",
"--collect-all", "scipy",
"--collect-all", "skimage",
@@ -45,4 +47,4 @@ Write-Host "Packaging desktop app..."
& $pythonExe @pyInstallerArgs
Write-Host "Desktop build complete."
-Write-Host "Output folder: $repoRoot\desktop-dist\Argonode"
+Write-Host "Output folder: $repoRoot\desktop-dist\argonode"
diff --git a/scripts/generate_demo_particles.py b/scripts/generate_demo_particles.py
new file mode 100644
index 0000000..7d8899e
--- /dev/null
+++ b/scripts/generate_demo_particles.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+"""Generate a synthetic nanoparticle image for the demo/ folder.
+
+The image simulates an AFM scan of particles on a flat substrate:
+ - Slightly noisy background
+ - ~20 hemisphere-shaped particles with varying radii and heights
+ - Saved as both .npy (calibrated float64) and .png (visual preview)
+
+Run from project root:
+ python scripts/generate_demo_particles.py
+"""
+
+import numpy as np
+from pathlib import Path
+
+DEMO_DIR = Path(__file__).resolve().parent.parent / "demo"
+DEMO_DIR.mkdir(exist_ok=True)
+
+RNG = np.random.default_rng(2024)
+
+# --- Image parameters ---
+N = 256 # pixels
+SCAN_SIZE = 5e-6 # 5 µm scan
+PIXEL_SIZE = SCAN_SIZE / N # metres per pixel
+BG_NOISE_RMS = 0.3e-9 # 0.3 nm background noise
+
+# --- Generate particles ---
+particles = []
+# Hand-placed cluster + random scatter to give a realistic spread
+fixed = [
+ # (cx_frac, cy_frac, radius_nm, height_nm)
+ (0.25, 0.30, 120, 30),
+ (0.28, 0.34, 80, 20),
+ (0.70, 0.25, 150, 45),
+ (0.50, 0.55, 100, 25),
+ (0.55, 0.60, 60, 15),
+ (0.15, 0.75, 200, 55),
+ (0.80, 0.80, 90, 22),
+]
+for cx_f, cy_f, r_nm, h_nm in fixed:
+ particles.append((cx_f * N, cy_f * N, r_nm * 1e-9, h_nm * 1e-9))
+
+# Random particles
+for _ in range(15):
+ cx = RNG.uniform(20, N - 20)
+ cy = RNG.uniform(20, N - 20)
+ radius = RNG.uniform(30, 180) * 1e-9 # 30–180 nm
+ height = RNG.uniform(8, 60) * 1e-9 # 8–60 nm
+ particles.append((cx, cy, radius, height))
+
+# --- Render height map ---
+image = RNG.normal(0, BG_NOISE_RMS, (N, N))
+
+yy, xx = np.mgrid[0:N, 0:N]
+
+for cx, cy, radius_m, height_m in particles:
+ radius_px = radius_m / PIXEL_SIZE
+ dist2 = (xx - cx) ** 2 + (yy - cy) ** 2
+ inside = dist2 < radius_px ** 2
+ # Hemisphere profile: z = h * sqrt(1 - (r/R)^2)
+ z = np.zeros_like(image)
+ z[inside] = height_m * np.sqrt(1.0 - dist2[inside] / radius_px ** 2)
+ image = np.maximum(image, z) # particles don't subtract from each other
+
+# --- Save .npy (float64 metres) ---
+npy_path = DEMO_DIR / "nanoparticles.npy"
+np.save(str(npy_path), image)
+print(f"Saved {npy_path} shape={image.shape} range=[{image.min():.2e}, {image.max():.2e}] m")
+
+# --- Save .png (8-bit grayscale for quick visual) ---
+from PIL import Image
+
+normed = (image - image.min()) / (image.max() - image.min())
+uint8 = (normed * 255).astype(np.uint8)
+png_path = DEMO_DIR / "nanoparticles.png"
+Image.fromarray(uint8, mode="L").save(str(png_path))
+print(f"Saved {png_path}")
+
+print(f"\n{len(particles)} particles generated on a {SCAN_SIZE*1e6:.0f} µm × {SCAN_SIZE*1e6:.0f} µm scan")
diff --git a/tests/test_grains.py b/tests/test_grains.py
new file mode 100644
index 0000000..7b94f85
--- /dev/null
+++ b/tests/test_grains.py
@@ -0,0 +1,445 @@
+"""
+Thorough tests for the grain/particle analysis pipeline:
+ ThresholdMask → GrainAnalysis
+
+Covers synthetic geometry (known answers), the demo nanoparticles image,
+edge cases, and physical-unit correctness.
+
+Run from project root:
+ .venv/bin/python -m tests.test_grains
+"""
+
+import sys
+import numpy as np
+
+sys.path.insert(0, ".")
+from backend.data_types import DataField
+
+
+def make_field(data, xreal=1e-6, yreal=1e-6):
+ return DataField(data=data.astype(np.float64), xreal=xreal, yreal=yreal,
+ si_unit_xy="m", si_unit_z="m")
+
+
+# =========================================================================
+# ThresholdMask tests
+# =========================================================================
+
+def test_threshold_otsu_bimodal():
+ """Otsu on a clean bimodal image should separate the two populations."""
+ print("=== Test: Otsu on bimodal image ===")
+ from backend.nodes.grains import ThresholdMask
+ node = ThresholdMask()
+
+ data = np.zeros((128, 128))
+ data[30:50, 30:50] = 10.0 # bright square
+ data[70:100, 80:110] = 10.0 # another bright region
+ field = make_field(data)
+
+ mask, = node.process(field, method="otsu", threshold=0.0, direction="above")
+ bright_pixels = (mask == 255)
+ # Should capture both bright regions
+ assert bright_pixels[40, 40], "Otsu missed bright region 1"
+ assert bright_pixels[85, 95], "Otsu missed bright region 2"
+ # Background should be dark
+ assert not bright_pixels[0, 0], "Otsu false positive in background"
+ assert not bright_pixels[60, 60], "Otsu false positive between regions"
+ print(" PASS\n")
+
+
+def test_threshold_relative_range():
+ """Relative threshold at 0.5 should be the midpoint of [min, max]."""
+ print("=== Test: Relative threshold at midpoint ===")
+ from backend.nodes.grains import ThresholdMask
+ node = ThresholdMask()
+
+ data = np.full((64, 64), 2.0)
+ data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5
+ field = make_field(data)
+
+ mask, = node.process(field, method="relative", threshold=0.5, direction="above")
+ # Only the bright patch (value 8 >= 5) should be masked
+ assert np.all(mask[10:20, 10:20] == 255)
+ assert np.all(mask[0:10, :] == 0)
+ assert np.all(mask[20:, :] == 0)
+ print(" PASS\n")
+
+
+def test_threshold_empty_mask():
+ """Very high absolute threshold on low data should produce an empty mask."""
+ print("=== Test: Empty mask from high threshold ===")
+ from backend.nodes.grains import ThresholdMask
+ node = ThresholdMask()
+
+ data = np.ones((64, 64))
+ field = make_field(data)
+
+ mask, = node.process(field, method="absolute", threshold=999.0, direction="above")
+ assert mask.sum() == 0, "Mask should be completely empty"
+ print(" PASS\n")
+
+
+def test_threshold_full_mask():
+ """Very low absolute threshold should produce an all-white mask."""
+ print("=== Test: Full mask from low threshold ===")
+ from backend.nodes.grains import ThresholdMask
+ node = ThresholdMask()
+
+ data = np.ones((64, 64)) * 5.0
+ field = make_field(data)
+
+ mask, = node.process(field, method="absolute", threshold=-1.0, direction="above")
+ assert np.all(mask == 255), "Mask should be all white"
+ print(" PASS\n")
+
+
+# =========================================================================
+# GrainAnalysis tests
+# =========================================================================
+
+def test_single_circle_area():
+ """A single filled circle — verify pixel count and physical area."""
+ print("=== Test: Single circle area ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ N = 200
+ XREAL = 2e-6 # 2 µm
+ data = np.zeros((N, N))
+ mask = np.zeros((N, N), dtype=np.uint8)
+
+ # Draw a filled circle, radius 30 px, centred at (100, 100)
+ yy, xx = np.mgrid[0:N, 0:N]
+ r = 30
+ circle = ((xx - 100) ** 2 + (yy - 100) ** 2) <= r ** 2
+ data[circle] = 5.0
+ mask[circle] = 255
+
+ field = make_field(data, xreal=XREAL, yreal=XREAL)
+ table, = node.process(field, mask=mask, min_size=1)
+
+ assert len(table) == 1, f"Expected 1 grain, got {len(table)}"
+ grain = table[0]
+
+ # Pixel area of a discrete circle: should be close to π r²
+ expected_px = np.pi * r ** 2
+ assert abs(grain["area_px"] - expected_px) / expected_px < 0.02, \
+ f"area_px={grain['area_px']}, expected≈{expected_px:.0f}"
+
+ # Physical area
+ pixel_area = (XREAL / N) ** 2
+ expected_m2 = grain["area_px"] * pixel_area
+ assert abs(grain["area_m2"] - expected_m2) < 1e-20, \
+ f"area_m2 mismatch: {grain['area_m2']} vs {expected_m2}"
+
+ # Equivalent diameter should be close to 2r in physical units
+ expected_diam = 2 * r * (XREAL / N)
+ assert abs(grain["equiv_diam_m"] - expected_diam) / expected_diam < 0.02, \
+ f"equiv_diam={grain['equiv_diam_m']:.3e}, expected≈{expected_diam:.3e}"
+
+ # Heights
+ assert abs(grain["mean_height"] - 5.0) < 1e-10
+ assert abs(grain["max_height"] - 5.0) < 1e-10
+ print(" PASS\n")
+
+
+def test_multiple_grains_separation():
+ """Three well-separated grains of different sizes — check each is reported."""
+ print("=== Test: Multiple grain separation ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ N = 128
+ data = np.zeros((N, N))
+ mask = np.zeros((N, N), dtype=np.uint8)
+
+ # Grain A: 20×20 block, height 10
+ data[10:30, 10:30] = 10.0
+ mask[10:30, 10:30] = 255
+
+ # Grain B: 10×10 block, height 7
+ data[60:70, 60:70] = 7.0
+ mask[60:70, 60:70] = 255
+
+ # Grain C: 5×5 block, height 3
+ data[100:105, 100:105] = 3.0
+ mask[100:105, 100:105] = 255
+
+ field = make_field(data)
+ table, = node.process(field, mask=mask, min_size=1)
+
+ assert len(table) == 3, f"Expected 3 grains, got {len(table)}"
+
+ table.sort(key=lambda r: r["area_px"], reverse=True)
+ assert table[0]["area_px"] == 400 # 20×20
+ assert table[1]["area_px"] == 100 # 10×10
+ assert table[2]["area_px"] == 25 # 5×5
+
+ assert abs(table[0]["mean_height"] - 10.0) < 1e-10
+ assert abs(table[1]["mean_height"] - 7.0) < 1e-10
+ assert abs(table[2]["mean_height"] - 3.0) < 1e-10
+ print(" PASS\n")
+
+
+def test_min_size_filtering():
+ """min_size should exclude grains smaller than the threshold."""
+ print("=== Test: min_size filtering ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ N = 64
+ data = np.zeros((N, N))
+ mask = np.zeros((N, N), dtype=np.uint8)
+
+ # Large grain: 15×15 = 225 px
+ data[5:20, 5:20] = 1.0
+ mask[5:20, 5:20] = 255
+
+ # Medium grain: 8×8 = 64 px
+ data[30:38, 30:38] = 1.0
+ mask[30:38, 30:38] = 255
+
+ # Tiny grain: 3×3 = 9 px
+ data[50:53, 50:53] = 1.0
+ mask[50:53, 50:53] = 255
+
+ field = make_field(data)
+
+ # min_size=1: all three
+ table, = node.process(field, mask=mask, min_size=1)
+ assert len(table) == 3
+
+ # min_size=10: drops the 3×3
+ table, = node.process(field, mask=mask, min_size=10)
+ assert len(table) == 2
+
+ # min_size=100: drops the 3×3 and 8×8
+ table, = node.process(field, mask=mask, min_size=100)
+ assert len(table) == 1
+ assert table[0]["area_px"] == 225
+
+ # min_size=300: drops everything
+ table, = node.process(field, mask=mask, min_size=300)
+ assert len(table) == 0
+ print(" PASS\n")
+
+
+def test_grain_bounding_box():
+ """Bounding box should match the grain extents."""
+ print("=== Test: Grain bounding box ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ N = 64
+ data = np.zeros((N, N))
+ mask = np.zeros((N, N), dtype=np.uint8)
+ # Place a grain at rows 20:35, cols 10:45
+ data[20:35, 10:45] = 2.0
+ mask[20:35, 10:45] = 255
+
+ field = make_field(data)
+ table, = node.process(field, mask=mask, min_size=1)
+ assert len(table) == 1
+
+ bbox = table[0]["bbox"]
+ # Format: "(xmin,ymin)-(xmax,ymax)" = "(10,20)-(44,34)"
+ assert bbox == "(10,20)-(44,34)", f"bbox={bbox}, expected (10,20)-(44,34)"
+ print(" PASS\n")
+
+
+def test_empty_mask_produces_no_grains():
+ """An all-zero mask should yield zero grains."""
+ print("=== Test: Empty mask → no grains ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ field = make_field(np.ones((64, 64)))
+ mask = np.zeros((64, 64), dtype=np.uint8)
+
+ table, = node.process(field, mask=mask, min_size=1)
+ assert len(table) == 0
+ print(" PASS\n")
+
+
+def test_grain_at_image_edge():
+ """A grain touching the image border should still be detected."""
+ print("=== Test: Grain at image edge ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ N = 64
+ data = np.zeros((N, N))
+ mask = np.zeros((N, N), dtype=np.uint8)
+ # Grain touching top-left corner
+ data[0:10, 0:10] = 4.0
+ mask[0:10, 0:10] = 255
+
+ field = make_field(data)
+ table, = node.process(field, mask=mask, min_size=1)
+ assert len(table) == 1
+ assert table[0]["area_px"] == 100
+ assert table[0]["bbox"] == "(0,0)-(9,9)"
+ print(" PASS\n")
+
+
+def test_adjacent_grains_connectivity():
+ """Two diagonally-touching blocks should be separate grains
+ (scipy.ndimage.label uses 4-connectivity by default)."""
+ print("=== Test: Diagonal adjacency → separate grains ===")
+ from backend.nodes.grains import GrainAnalysis
+ node = GrainAnalysis()
+
+ N = 32
+ data = np.zeros((N, N))
+ mask = np.zeros((N, N), dtype=np.uint8)
+
+ # Block A
+ data[5:10, 5:10] = 1.0
+ mask[5:10, 5:10] = 255
+
+ # Block B diagonally adjacent (touching only at corner 10,10)
+ data[10:15, 10:15] = 1.0
+ mask[10:15, 10:15] = 255
+
+ field = make_field(data)
+ table, = node.process(field, mask=mask, min_size=1)
+ # Default label() uses structure that connects diagonals? Let's verify.
+ # scipy.ndimage.label default is cross-shaped (no diagonals) for 2D
+ assert len(table) == 2, f"Expected 2 separate grains, got {len(table)}"
+ print(" PASS\n")
+
+
+# =========================================================================
+# End-to-end pipeline: ThresholdMask → GrainAnalysis
+# =========================================================================
+
+def test_pipeline_synthetic():
+ """Full pipeline on a synthetic image with known geometry."""
+ print("=== Test: Full pipeline on synthetic particles ===")
+ from backend.nodes.grains import ThresholdMask, GrainAnalysis
+
+ N = 200
+ XREAL = 10e-6 # 10 µm
+ rng = np.random.default_rng(99)
+
+ # Background at 0 with small noise, particles as raised bumps
+ bg = rng.normal(0, 0.1, (N, N))
+ particles = np.zeros((N, N))
+
+ yy, xx = np.mgrid[0:N, 0:N]
+
+ specs = [
+ (50, 50, 15, 5.0), # (cx, cy, radius_px, height)
+ (150, 50, 20, 8.0),
+ (100, 100, 10, 3.0),
+ (50, 160, 25, 6.0),
+ (160, 160, 12, 4.0),
+ ]
+ for cx, cy, r, h in specs:
+ inside = ((xx - cx) ** 2 + (yy - cy) ** 2) <= r ** 2
+ particles[inside] = h
+
+ data = bg + particles
+ field = make_field(data, xreal=XREAL, yreal=XREAL)
+
+ # Step 1: threshold
+ thresh = ThresholdMask()
+ mask, = thresh.process(field, method="absolute", threshold=1.0, direction="above")
+
+ # Particles are well above noise, so mask should capture all 5
+ assert mask.max() == 255, "No particles detected"
+
+ # Step 2: grain analysis
+ ga = GrainAnalysis()
+ table, = ga.process(field, mask=mask, min_size=5)
+
+ assert len(table) == 5, f"Expected 5 grains, got {len(table)}"
+
+ # Verify that detected areas are in the right ballpark
+ table.sort(key=lambda r: r["area_px"], reverse=True)
+ expected_areas = sorted([np.pi * r ** 2 for _, _, r, _ in specs], reverse=True)
+
+ for grain, expected_px in zip(table, expected_areas):
+ ratio = grain["area_px"] / expected_px
+ assert 0.85 < ratio < 1.15, \
+ f"grain area_px={grain['area_px']}, expected≈{expected_px:.0f}, ratio={ratio:.2f}"
+
+ print(" PASS\n")
+
+
+def test_pipeline_demo_image():
+ """Run the full pipeline on the bundled demo nanoparticles image."""
+ print("=== Test: Full pipeline on demo nanoparticles.npy ===")
+ from pathlib import Path
+ from backend.nodes.grains import ThresholdMask, GrainAnalysis
+ from backend.runtime_paths import demo_dir
+
+ npy_path = demo_dir() / "nanoparticles.npy"
+ if not npy_path.exists():
+ print(" SKIP (demo image not found)\n")
+ return
+
+ data = np.load(str(npy_path)).astype(np.float64)
+ # The demo image is a 5 µm × 5 µm scan
+ field = make_field(data, xreal=5e-6, yreal=5e-6)
+
+ # Threshold to find particles (they are raised above background)
+ thresh = ThresholdMask()
+ mask, = thresh.process(field, method="otsu", threshold=0.0, direction="above")
+
+ # Should detect particles
+ assert mask.max() == 255, "No particles found in demo image"
+ particle_fraction = (mask == 255).sum() / mask.size
+ assert 0.01 < particle_fraction < 0.5, \
+ f"Suspicious particle fraction: {particle_fraction:.3f}"
+ print(f" Mask: {particle_fraction*100:.1f}% of pixels are particles")
+
+ # Grain analysis
+ ga = GrainAnalysis()
+ table, = ga.process(field, mask=mask, min_size=20)
+
+ assert len(table) > 0, "No grains detected"
+ print(f" Found {len(table)} grains (min_size=20)")
+
+ # Sanity checks on grain properties
+ for grain in table:
+ assert grain["area_px"] >= 20
+ assert grain["area_m2"] > 0
+ assert grain["equiv_diam_m"] > 0
+ assert grain["max_height"] >= grain["mean_height"]
+ assert grain["mean_height"] > 0
+
+ # Physical size sanity: equivalent diameters should be in the nm–µm range
+ diams_nm = [g["equiv_diam_m"] * 1e9 for g in table]
+ print(f" Diameters: min={min(diams_nm):.0f} nm, max={max(diams_nm):.0f} nm")
+ assert all(1 < d < 2000 for d in diams_nm), \
+ f"Grain diameters out of expected range: {diams_nm}"
+
+ print(" PASS\n")
+
+
+# =========================================================================
+# Run all tests
+# =========================================================================
+
+if __name__ == "__main__":
+ # ThresholdMask
+ test_threshold_otsu_bimodal()
+ test_threshold_relative_range()
+ test_threshold_empty_mask()
+ test_threshold_full_mask()
+
+ # GrainAnalysis
+ test_single_circle_area()
+ test_multiple_grains_separation()
+ test_min_size_filtering()
+ test_grain_bounding_box()
+ test_empty_mask_produces_no_grains()
+ test_grain_at_image_edge()
+ test_adjacent_grains_connectivity()
+
+ # End-to-end pipeline
+ test_pipeline_synthetic()
+ test_pipeline_demo_image()
+
+ print("All grain tests passed!")
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index 5f90f4d..541ca60 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -85,6 +85,88 @@ def test_edge_detect():
print(" PASS\n")
+def test_fft_filter_1d():
+ print("=== Test: FFTFilter1D ===")
+ from backend.nodes.filters import FFTFilter1D
+ node = FFTFilter1D()
+
+ # Signal: low-frequency sine + high-frequency sine
+ n = 256
+ t = np.arange(n, dtype=np.float64) / n
+ low = np.sin(2 * np.pi * 3 * t) # 3 cycles — low freq
+ high = np.sin(2 * np.pi * 80 * t) # 80 cycles — high freq
+ line = low + high
+
+ # Lowpass should keep low, suppress high
+ filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
+ assert len(filtered_lp) == n
+ corr_low = np.corrcoef(filtered_lp, low)[0, 1]
+ corr_high = np.corrcoef(filtered_lp, high)[0, 1]
+ assert corr_low > 0.95, f"Lowpass: correlation with low={corr_low}"
+ assert abs(corr_high) < 0.3, f"Lowpass: correlation with high={corr_high}"
+
+ # Highpass should keep high, suppress low
+ filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
+ corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1]
+ corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1]
+ assert abs(corr_low_hp) < 0.3, f"Highpass: correlation with low={corr_low_hp}"
+ assert corr_high_hp > 0.95, f"Highpass: correlation with high={corr_high_hp}"
+
+ # Bandpass centred on the high frequency
+ filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
+ corr_low_bp = np.corrcoef(filtered_bp, low)[0, 1]
+ corr_high_bp = np.corrcoef(filtered_bp, high)[0, 1]
+ assert abs(corr_low_bp) < 0.3, f"Bandpass: correlation with low={corr_low_bp}"
+ assert corr_high_bp > 0.9, f"Bandpass: correlation with high={corr_high_bp}"
+
+ # Notch (band-reject) centred on the high frequency — should remove it
+ filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
+ corr_low_notch = np.corrcoef(filtered_notch, low)[0, 1]
+ corr_high_notch = np.corrcoef(filtered_notch, high)[0, 1]
+ assert corr_low_notch > 0.95, f"Notch: correlation with low={corr_low_notch}"
+ assert abs(corr_high_notch) < 0.3, f"Notch: correlation with high={corr_high_notch}"
+
+ print(" PASS\n")
+
+
+def test_fft_filter_2d():
+ print("=== Test: FFTFilter2D ===")
+ from backend.nodes.filters import FFTFilter2D
+ node = FFTFilter2D()
+
+ N = 128
+ y, x = np.mgrid[0:N, 0:N] / N
+ # Low-frequency 2D pattern + high-frequency pattern
+ low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
+ high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
+ data = low_2d + high_2d
+ field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
+
+ # Lowpass — should preserve low, remove high
+ result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
+ assert result_lp.data.shape == (N, N)
+ assert result_lp.xreal == field.xreal
+ assert result_lp.si_unit_z == field.si_unit_z
+ corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
+ corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
+ assert corr_low > 0.9, f"2D lowpass: correlation with low={corr_low}"
+ assert abs(corr_high) < 0.3, f"2D lowpass: correlation with high={corr_high}"
+
+ # Highpass — should preserve high, remove low
+ result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
+ corr_low_hp = np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]
+ corr_high_hp = np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1]
+ assert abs(corr_low_hp) < 0.3, f"2D highpass: correlation with low={corr_low_hp}"
+ assert corr_high_hp > 0.9, f"2D highpass: correlation with high={corr_high_hp}"
+
+ # Constant field should be unchanged by lowpass (DC preservation)
+ const = make_field(data=np.ones((32, 32)) * 7.0)
+ result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
+ assert np.allclose(result_const.data, 7.0, atol=1e-10), "Lowpass should preserve constant field"
+
+ print(" PASS\n")
+
+
# =========================================================================
# Level
# =========================================================================
@@ -199,7 +281,7 @@ def test_height_histogram():
data = np.linspace(0, 1, 1000).reshape(25, 40)
field = make_field(data=data)
- counts, bin_centers = node.process(field, n_bins=10)
+ counts, bin_centers = node.process(field, n_bins=10, y_scale="linear")
assert len(counts) == 10
assert len(bin_centers) == 10
assert counts.dtype == np.float64
@@ -265,7 +347,7 @@ def test_cross_section():
def test_threshold_mask():
print("=== Test: ThresholdMask ===")
- from backend.nodes.grains import ThresholdMask
+ from backend.nodes.mask import ThresholdMask
node = ThresholdMask()
# Clear bimodal data: left half = 0, right half = 1
@@ -273,6 +355,11 @@ def test_threshold_mask():
data[:, 32:] = 1.0
field = make_field(data=data)
+ # Capture overlay preview
+ previews = []
+ ThresholdMask._broadcast_fn = lambda nid, uri: previews.append(uri)
+ ThresholdMask._current_node_id = "test"
+
# Absolute threshold at 0.5
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
assert mask.dtype == np.uint8
@@ -280,6 +367,10 @@ def test_threshold_mask():
assert np.all(mask[:, :32] == 0)
assert np.all(mask[:, 32:] == 255)
+ # Verify overlay preview was broadcast
+ assert len(previews) == 1
+ assert previews[0].startswith("data:image/png;base64,")
+
# Direction "below"
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
assert np.all(mask_below[:, :32] == 255)
@@ -292,20 +383,117 @@ def test_threshold_mask():
# Otsu should find the bimodal threshold
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
+
+ ThresholdMask._broadcast_fn = None
print(" PASS\n")
-def test_grain_analysis():
- print("=== Test: GrainAnalysis ===")
- from backend.nodes.grains import GrainAnalysis
- node = GrainAnalysis()
+def test_mask_morphology():
+ print("=== Test: MaskMorphology ===")
+ from backend.nodes.mask import MaskMorphology
+ node = MaskMorphology()
- # Create a field with two distinct "grains"
+ # Small square blob in the centre
+ mask = np.zeros((64, 64), dtype=np.uint8)
+ mask[28:36, 28:36] = 255 # 8x8 block
+ orig_count = np.count_nonzero(mask)
+
+ # Dilate should grow the region
+ dilated, = node.process(mask, operation="dilate", radius=1, shape="square")
+ assert dilated.dtype == np.uint8
+ assert np.count_nonzero(dilated) > orig_count
+
+ # Erode should shrink it
+ eroded, = node.process(mask, operation="erode", radius=1, shape="square")
+ assert np.count_nonzero(eroded) < orig_count
+
+ # Open on a clean block should give back roughly the same block
+ opened, = node.process(mask, operation="open", radius=1, shape="square")
+ assert np.count_nonzero(opened) <= orig_count
+
+ # Close on a mask with a 1-pixel hole should fill the hole
+ mask_hole = mask.copy()
+ mask_hole[32, 32] = 0 # poke a hole
+ assert np.count_nonzero(mask_hole) == orig_count - 1
+ closed, = node.process(mask_hole, operation="close", radius=1, shape="square")
+ assert closed[32, 32] == 255, "Close should fill the 1-pixel hole"
+
+ # Disk structuring element should also work
+ dilated_disk, = node.process(mask, operation="dilate", radius=2, shape="disk")
+ assert np.count_nonzero(dilated_disk) > orig_count
+
+ print(" PASS\n")
+
+
+def test_mask_invert():
+ print("=== Test: MaskInvert ===")
+ from backend.nodes.mask import MaskInvert
+ node = MaskInvert()
+
+ mask = np.zeros((64, 64), dtype=np.uint8)
+ mask[10:20, 10:20] = 255
+
+ inverted, = node.process(mask)
+ assert inverted.dtype == np.uint8
+ assert np.all(inverted[10:20, 10:20] == 0)
+ assert np.all(inverted[0:10, 0:10] == 255)
+ # Double-invert should return to original
+ double, = node.process(inverted)
+ assert np.array_equal(double, mask)
+
+ print(" PASS\n")
+
+
+def test_mask_combine():
+ print("=== Test: MaskCombine ===")
+ from backend.nodes.mask import MaskCombine
+ node = MaskCombine()
+
+ # Two overlapping squares
+ a = np.zeros((64, 64), dtype=np.uint8)
+ a[10:30, 10:30] = 255 # 20x20
+ b = np.zeros((64, 64), dtype=np.uint8)
+ b[20:40, 20:40] = 255 # 20x20, overlaps 10x10
+
+ # AND — only the overlap
+ result_and, = node.process(a, b, operation="and")
+ assert np.all(result_and[20:30, 20:30] == 255)
+ assert result_and[15, 15] == 0 # a-only region
+ assert result_and[35, 35] == 0 # b-only region
+
+ # OR — union
+ result_or, = node.process(a, b, operation="or")
+ assert result_or[15, 15] == 255
+ assert result_or[35, 35] == 255
+ assert result_or[25, 25] == 255
+ assert result_or[5, 5] == 0
+
+ # XOR — symmetric difference
+ result_xor, = node.process(a, b, operation="xor")
+ assert result_xor[15, 15] == 255 # a-only
+ assert result_xor[35, 35] == 255 # b-only
+ assert result_xor[25, 25] == 0 # overlap excluded
+
+ # Subtract — a minus b
+ result_sub, = node.process(a, b, operation="subtract")
+ assert result_sub[15, 15] == 255 # a-only kept
+ assert result_sub[25, 25] == 0 # overlap removed
+ assert result_sub[35, 35] == 0 # b-only not included
+
+ print(" PASS\n")
+
+
+def test_particle_analysis():
+ print("=== Test: ParticleAnalysis ===")
+ from backend.nodes.grains import ParticleAnalysis
+ node = ParticleAnalysis()
+
+ # Create a field with two distinct particles
N = 64
data = np.zeros((N, N))
- # Grain 1: 10x10 block at top-left with height 5
+ # Particle 1: 10x10 block at top-left with height 5
data[5:15, 5:15] = 5.0
- # Grain 2: 8x8 block at bottom-right with height 3
+ # Particle 2: 8x8 block at bottom-right with height 3
data[45:53, 45:53] = 3.0
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
@@ -315,7 +503,7 @@ def test_grain_analysis():
mask[45:53, 45:53] = 255
table, = node.process(field, mask=mask, min_size=10)
- assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
+ assert len(table) == 2, f"Expected 2 particles, got {len(table)}"
# Sort by area descending
table.sort(key=lambda r: r["area_px"], reverse=True)
@@ -324,7 +512,7 @@ def test_grain_analysis():
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
- # min_size filtering: only keep grains >= 80 px
+ # min_size filtering: only keep particles >= 80 px
table_filtered, = node.process(field, mask=mask, min_size=80)
assert len(table_filtered) == 1
assert table_filtered[0]["area_px"] == 100
@@ -462,6 +650,8 @@ if __name__ == "__main__":
test_gaussian_filter()
test_median_filter()
test_edge_detect()
+ test_fft_filter_1d()
+ test_fft_filter_2d()
# Level
test_plane_level()
@@ -473,9 +663,14 @@ if __name__ == "__main__":
test_height_histogram()
test_cross_section()
- # Grains
+ # Mask
test_threshold_mask()
- test_grain_analysis()
+ test_mask_morphology()
+ test_mask_invert()
+ test_mask_combine()
+
+ # Grains
+ test_particle_analysis()
# I/O
test_load_image()