add snapshot tool, masks, and build for mac
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Gwyddion Features Not Yet in Argonode
|
||||
# Gwyddion Features Not Yet in argonode
|
||||
|
||||
Reference for future implementation. Grouped by value to typical SPM workflows.
|
||||
|
||||
@@ -11,9 +11,9 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|
||||
| 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. |
|
||||
| 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). |
|
||||
| 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. |
|
||||
| 4 | Morphological Mask Ops | mask_morph.c | Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks. |
|
||||
| 5 | 1D FFT Filter | fft_filter_1d.c | Bandpass/lowpass/highpass filtering of LINE profiles. |
|
||||
| 6 | 2D FFT Filter | fft_filter_2d.c | Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.). |
|
||||
| ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** |
|
||||
| ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** |
|
||||
| ~~6~~ | ~~2D FFT Filter~~ | ~~fft_filter_2d.c~~ | ~~Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.).~~ **DONE** |
|
||||
| 7 | Autocorrelation (ACF) | acf2d.c | 2D autocorrelation function. Reveals periodic structures and correlation lengths. |
|
||||
| 8 | PSDF | psdf2d.c | Radial/2D power spectral density function. Complementary to ACF for roughness characterization. |
|
||||
| 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. |
|
||||
@@ -61,11 +61,11 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|
||||
|
||||
---
|
||||
|
||||
## Already Implemented in Argonode
|
||||
## Already Implemented in argonode
|
||||
|
||||
For reference, these Gwyddion equivalents are already covered:
|
||||
|
||||
| Argonode Node | Category | Gwyddion Equivalent |
|
||||
| argonode Node | Category | Gwyddion Equivalent |
|
||||
|--------------|----------|-------------------|
|
||||
| Load Image / Load SPM File | io | File import (gwy, sxm, ibw) |
|
||||
| Save Image | io | File export |
|
||||
@@ -76,12 +76,17 @@ For reference, these Gwyddion equivalents are already covered:
|
||||
| Gaussian Filter | filters | filters.c (gaussian) |
|
||||
| Median Filter | filters | filters.c (median) |
|
||||
| Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) |
|
||||
| 1D FFT Filter | filters | fft_filter_1d.c (lowpass, highpass, bandpass, notch) |
|
||||
| 2D FFT Filter | filters | fft_filter_2d.c (lowpass, highpass, bandpass, notch) |
|
||||
| Statistics | analysis | stats.c |
|
||||
| Height Histogram | analysis | linestats.c (dh) |
|
||||
| 2D FFT | analysis | fft.c |
|
||||
| Cross Section | analysis | profile tool |
|
||||
| Profile Roughness | analysis | roughness.c (Ra, Rq, Rsk, Rku, Rp, Rv, Rt) |
|
||||
| Line Math | analysis | linestats.c |
|
||||
| Threshold Mask | grains | threshold.c, otsu_threshold.c |
|
||||
| Grain Analysis | grains | grain_stat.c |
|
||||
| Threshold Mask | mask | threshold.c, otsu_threshold.c |
|
||||
| Mask Morphology | mask | mask_morph.c (erode, dilate, open, close) |
|
||||
| Mask Invert | mask | — |
|
||||
| Mask Combine | mask | — (boolean AND, OR, XOR, subtract) |
|
||||
| Particle Analysis | grains | grain_stat.c |
|
||||
| Preview / 3D View / Print Table | display | Presentation, 3D view |
|
||||
|
||||
14
README.md
14
README.md
@@ -1,6 +1,6 @@
|
||||
# Argonode
|
||||
# argonode
|
||||
|
||||
Argonode is a node-based image analysis application with:
|
||||
argonode is a node-based image analysis application with:
|
||||
|
||||
- a Python backend built on `aiohttp`
|
||||
- a React + Vite frontend
|
||||
@@ -135,13 +135,13 @@ powershell -ExecutionPolicy Bypass -File scripts\build-desktop.ps1
|
||||
The packaged app is written to:
|
||||
|
||||
```text
|
||||
desktop-dist/Argonode/
|
||||
desktop-dist/argonode/
|
||||
```
|
||||
|
||||
Main executable:
|
||||
|
||||
```text
|
||||
desktop-dist/Argonode/Argonode.exe
|
||||
desktop-dist/argonode/argonode.exe
|
||||
```
|
||||
|
||||
### One-File Build
|
||||
@@ -161,14 +161,14 @@ During normal source-based development, input/output folders live under the repo
|
||||
In the packaged desktop app, writable data is stored under:
|
||||
|
||||
```text
|
||||
%LOCALAPPDATA%\Argonode\
|
||||
%LOCALAPPDATA%\argonode\
|
||||
```
|
||||
|
||||
Specifically:
|
||||
|
||||
```text
|
||||
%LOCALAPPDATA%\Argonode\input
|
||||
%LOCALAPPDATA%\Argonode\output
|
||||
%LOCALAPPDATA%\argonode\input
|
||||
%LOCALAPPDATA%\argonode\output
|
||||
```
|
||||
|
||||
You can override the packaged app data directory with:
|
||||
|
||||
@@ -177,13 +177,19 @@ class ExecutionEngine:
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection
|
||||
from backend.nodes.analysis import CrossSection, LineCursors
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.io import SaveImage
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
ThresholdMask._broadcast_fn = on_preview
|
||||
MaskMorphology._broadcast_fn = on_preview
|
||||
MaskInvert._broadcast_fn = on_preview
|
||||
MaskCombine._broadcast_fn = on_preview
|
||||
View3D._broadcast_mesh_fn = on_mesh
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
LineCursors._broadcast_overlay_fn = on_overlay
|
||||
SaveImage._broadcast_preview = (
|
||||
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
|
||||
)
|
||||
@@ -191,8 +197,10 @@ class ExecutionEngine:
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection
|
||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
|
||||
from backend.nodes.analysis import CrossSection, LineCursors
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
@@ -206,12 +214,16 @@ class ExecutionEngine:
|
||||
"""
|
||||
After every node executes, inspect its outputs and broadcast
|
||||
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
|
||||
Skip nodes that broadcast their own custom preview.
|
||||
"""
|
||||
import numpy as np
|
||||
from backend.data_types import (
|
||||
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
|
||||
)
|
||||
|
||||
if getattr(cls, "_CUSTOM_PREVIEW", False):
|
||||
return
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
for slot, type_name in enumerate(return_types):
|
||||
|
||||
@@ -36,7 +36,7 @@ def main() -> None:
|
||||
app = create_app(loop)
|
||||
|
||||
log.info("=" * 60)
|
||||
log.info(" Argonode — Node-based image analysis")
|
||||
log.info(" argonode — Node-based image analysis")
|
||||
log.info(" Open your browser at http://%s:%d", HOST, PORT)
|
||||
log.info("=" * 60)
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Import all node modules to trigger @register_node decorators.
|
||||
from . import io, filters, level, analysis, grains, display
|
||||
from . import io, filters, level, analysis, grains, mask, display
|
||||
|
||||
@@ -69,6 +69,7 @@ class HeightHistogram:
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
|
||||
"y_scale": (["linear", "log"],),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,13 +79,150 @@ class HeightHistogram:
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute the height distribution histogram (DH). "
|
||||
"Use log scale to reveal small peaks next to a dominant background. "
|
||||
"Equivalent to gwy_data_field_dh."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, n_bins: int) -> tuple:
|
||||
def process(self, field: DataField, n_bins: int, y_scale: str = "linear") -> tuple:
|
||||
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
|
||||
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
|
||||
return (counts.astype(np.float64), bin_centers)
|
||||
counts = counts.astype(np.float64)
|
||||
if y_scale == "log":
|
||||
counts = np.log10(1.0 + counts)
|
||||
return (counts, bin_centers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LineCursors — interactive measurement cursors on any LINE plot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Line Cursors")
|
||||
class LineCursors:
|
||||
"""Place two draggable cursors on any LINE plot to measure values and deltas."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"line": ("LINE",),
|
||||
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
},
|
||||
"optional": {
|
||||
"x_axis": ("LINE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("measurement",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Place two cursors on any line plot (histogram, cross section, profile) "
|
||||
"to measure positions, values, and deltas. Drag the markers to reposition."
|
||||
)
|
||||
|
||||
_broadcast_overlay_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(
|
||||
self, line, x1: float, y1: float, x2: float, y2: float,
|
||||
x_axis=None,
|
||||
) -> tuple:
|
||||
import io as _io
|
||||
import base64
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
y = np.asarray(line, dtype=np.float64).ravel()
|
||||
n = len(y)
|
||||
if x_axis is not None:
|
||||
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
|
||||
else:
|
||||
x = np.arange(n, dtype=np.float64)
|
||||
|
||||
# --- Render the base plot first to determine axes bounds ---
|
||||
fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100)
|
||||
fig.patch.set_facecolor("#1e293b")
|
||||
ax.set_facecolor("#0f172a")
|
||||
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
||||
ax.tick_params(colors="#94a3b8", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#334155")
|
||||
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
|
||||
fig.tight_layout(pad=0.4)
|
||||
|
||||
# Force a draw so transforms are valid
|
||||
fig.canvas.draw()
|
||||
|
||||
# Get axes position in figure-fraction coordinates
|
||||
ax_pos = ax.get_position()
|
||||
ax_l, ax_b = ax_pos.x0, ax_pos.y0
|
||||
ax_w, ax_h = ax_pos.width, ax_pos.height
|
||||
|
||||
# x1/y1 arrive as image-fraction from the frontend drag.
|
||||
# Convert image-fraction x → axes-fraction → nearest data index.
|
||||
def img_x_to_idx(ix):
|
||||
axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
|
||||
return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
|
||||
|
||||
idx_a = img_x_to_idx(x1)
|
||||
idx_b = img_x_to_idx(x2)
|
||||
|
||||
xa, ya = float(x[idx_a]), float(y[idx_a])
|
||||
xb, yb = float(x[idx_b]), float(y[idx_b])
|
||||
|
||||
# --- Draw cursor lines and markers on the plot ---
|
||||
ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
|
||||
ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
|
||||
ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
|
||||
ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
|
||||
ax.annotate(
|
||||
"", xy=(xb, yb), xytext=(xa, ya),
|
||||
arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
|
||||
)
|
||||
|
||||
# --- Broadcast overlay ---
|
||||
if LineCursors._broadcast_overlay_fn is not None:
|
||||
# Convert data-space positions back to image-fraction for markers
|
||||
fig.canvas.draw()
|
||||
inv = fig.transFigure.inverted()
|
||||
fig_a = inv.transform(ax.transData.transform([xa, ya]))
|
||||
fig_b = inv.transform(ax.transData.transform([xb, yb]))
|
||||
|
||||
buf = _io.BytesIO()
|
||||
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
||||
buf.seek(0)
|
||||
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
|
||||
|
||||
LineCursors._broadcast_overlay_fn(
|
||||
LineCursors._current_node_id,
|
||||
{
|
||||
"image": image_uri,
|
||||
"x1": float(fig_a[0]),
|
||||
"y1": float(1.0 - fig_a[1]), # flip: image y=0 is top
|
||||
"x2": float(fig_b[0]),
|
||||
"y2": float(1.0 - fig_b[1]),
|
||||
"a_locked": False,
|
||||
"b_locked": False,
|
||||
},
|
||||
)
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
# --- Output table ---
|
||||
table = [
|
||||
{"quantity": "A position", "value": xa, "unit": ""},
|
||||
{"quantity": "A value", "value": ya, "unit": ""},
|
||||
{"quantity": "B position", "value": xb, "unit": ""},
|
||||
{"quantity": "B value", "value": yb, "unit": ""},
|
||||
{"quantity": "delta X", "value": xb - xa, "unit": ""},
|
||||
{"quantity": "delta Y", "value": yb - ya, "unit": ""},
|
||||
]
|
||||
return (table,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -242,9 +380,9 @@ class CrossSection:
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"extend": (["none", "to_edges"],),
|
||||
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
|
||||
|
||||
@@ -5,6 +5,8 @@ Gwyddion equivalents:
|
||||
GaussianFilter → gwy_data_field_filter_gaussian
|
||||
MedianFilter → gwy_data_field_filter_median
|
||||
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
|
||||
FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles)
|
||||
FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -113,3 +115,190 @@ class EdgeDetect:
|
||||
raise ValueError(f"Unknown edge detection method: {method}")
|
||||
|
||||
return (field.replace(data=result),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Butterworth transfer function helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
|
||||
"""Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n))."""
|
||||
with np.errstate(divide="ignore", over="ignore"):
|
||||
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
|
||||
|
||||
|
||||
def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray:
|
||||
"""Butterworth highpass: H = 1 / (1 + (fc/f)^(2n))."""
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
|
||||
h = np.where(np.isfinite(h), h, 0.0)
|
||||
return h
|
||||
|
||||
|
||||
def _build_1d_transfer(n: int, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> np.ndarray:
|
||||
"""Build a 1-D transfer function for an FFT of length *n*.
|
||||
|
||||
Frequencies are normalised so that 1.0 = Nyquist (fs/2).
|
||||
The returned array has the same layout as np.fft.rfft output (length n//2+1).
|
||||
"""
|
||||
freq = np.linspace(0, 1, n // 2 + 1)
|
||||
|
||||
if filter_type == "lowpass":
|
||||
H = _butterworth_lp(freq, cutoff, order)
|
||||
elif filter_type == "highpass":
|
||||
H = _butterworth_hp(freq, cutoff, order)
|
||||
elif filter_type == "bandpass":
|
||||
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
|
||||
elif filter_type == "notch":
|
||||
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
|
||||
H = 1.0 - bp
|
||||
else:
|
||||
H = np.ones_like(freq)
|
||||
return H
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FFTFilter1D — frequency-domain filtering of LINE profiles
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="1D FFT Filter")
|
||||
class FFTFilter1D:
|
||||
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
|
||||
|
||||
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
|
||||
transfer function with configurable order for a smooth roll-off.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"line": ("LINE",),
|
||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
||||
"cutoff": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"cutoff_high": ("FLOAT", {
|
||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LINE",)
|
||||
RETURN_NAMES = ("filtered",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = (
|
||||
"Frequency-domain filtering of a 1-D line profile. "
|
||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
||||
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
|
||||
"Equivalent to Gwyddion fft_filter_1d."
|
||||
)
|
||||
|
||||
def process(self, line, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
z = np.asarray(line, dtype=np.float64).ravel()
|
||||
n = len(z)
|
||||
|
||||
# Forward FFT (real-valued)
|
||||
Z = np.fft.rfft(z)
|
||||
|
||||
# Build and apply transfer function
|
||||
H = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
|
||||
Z *= H
|
||||
|
||||
# Inverse FFT
|
||||
filtered = np.fft.irfft(Z, n=n)
|
||||
return (filtered,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FFTFilter2D — frequency-domain filtering of DATA_FIELDs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="2D FFT Filter")
|
||||
class FFTFilter2D:
|
||||
"""Frequency-domain filtering of 2-D data fields (images).
|
||||
|
||||
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
|
||||
Butterworth transfer function in the frequency domain to remove or
|
||||
isolate periodic features.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
||||
"cutoff": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"cutoff_high": ("FLOAT", {
|
||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("filtered",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = (
|
||||
"Frequency-domain filtering of a 2-D data field. "
|
||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
||||
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
|
||||
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
|
||||
"bandpass/notch to isolate or remove periodic noise. "
|
||||
"Equivalent to Gwyddion fft_filter_2d."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
# Subtract mean to avoid DC leakage artefacts
|
||||
mean_val = data.mean()
|
||||
data -= mean_val
|
||||
|
||||
# Forward 2D FFT
|
||||
F = np.fft.fft2(data)
|
||||
F = np.fft.fftshift(F)
|
||||
|
||||
# Build radial frequency grid normalised to [0, 1] (1 = Nyquist)
|
||||
fy = np.fft.fftshift(np.fft.fftfreq(yres)) # range [-0.5, 0.5)
|
||||
fx = np.fft.fftshift(np.fft.fftfreq(xres))
|
||||
FX, FY = np.meshgrid(fx, fy)
|
||||
# Normalise so that corner = 1 in each axis independently,
|
||||
# then take Euclidean norm; max radial value = 1.0 at Nyquist.
|
||||
R = np.sqrt((FX / 0.5) ** 2 + (FY / 0.5) ** 2)
|
||||
R = np.clip(R / R.max(), 0, 1) if R.max() > 0 else R
|
||||
|
||||
# Build transfer function
|
||||
if filter_type == "lowpass":
|
||||
H = _butterworth_lp(R, cutoff, order)
|
||||
elif filter_type == "highpass":
|
||||
H = _butterworth_hp(R, cutoff, order)
|
||||
elif filter_type == "bandpass":
|
||||
H = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
|
||||
elif filter_type == "notch":
|
||||
bp = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
|
||||
H = 1.0 - bp
|
||||
else:
|
||||
H = np.ones_like(R)
|
||||
|
||||
# Apply filter
|
||||
F *= H
|
||||
|
||||
# Inverse FFT
|
||||
F = np.fft.ifftshift(F)
|
||||
result = np.fft.ifft2(F).real
|
||||
|
||||
# Restore DC
|
||||
result += mean_val
|
||||
|
||||
return (field.replace(data=result),)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""
|
||||
Grain/feature detection nodes.
|
||||
Particle detection nodes.
|
||||
|
||||
Gwyddion equivalents:
|
||||
ThresholdMask → threshold.c / otsu_threshold.c
|
||||
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
|
||||
ParticleAnalysis → gwy_data_field_grains_get_values (grains-values.c)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -13,61 +12,11 @@ from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThresholdMask
|
||||
# ParticleAnalysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Threshold Mask")
|
||||
class ThresholdMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["otsu", "absolute", "relative"],),
|
||||
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
|
||||
"direction": (["above", "below"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "grains"
|
||||
DESCRIPTION = (
|
||||
"Create a binary mask by thresholding data. "
|
||||
"Otsu automatically finds the optimal threshold. "
|
||||
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
|
||||
data = field.data
|
||||
|
||||
if method == "otsu":
|
||||
from skimage.filters import threshold_otsu
|
||||
t = threshold_otsu(data)
|
||||
elif method == "absolute":
|
||||
t = float(threshold)
|
||||
elif method == "relative":
|
||||
# threshold is a fraction [0, 1] of the data range
|
||||
dmin, dmax = data.min(), data.max()
|
||||
t = dmin + float(threshold) * (dmax - dmin)
|
||||
else:
|
||||
raise ValueError(f"Unknown threshold method: {method}")
|
||||
|
||||
if direction == "above":
|
||||
mask = (data >= t).astype(np.uint8) * 255
|
||||
else:
|
||||
mask = (data < t).astype(np.uint8) * 255
|
||||
|
||||
return (mask,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GrainAnalysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Grain Analysis")
|
||||
class GrainAnalysis:
|
||||
@register_node(display_name="Particle Analysis")
|
||||
class ParticleAnalysis:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@@ -79,43 +28,43 @@ class GrainAnalysis:
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("grain_stats",)
|
||||
RETURN_NAMES = ("particle_stats",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "grains"
|
||||
DESCRIPTION = (
|
||||
"Label connected grain regions in a binary mask and compute per-grain statistics: "
|
||||
"area, equivalent diameter, mean/max height, bounding box. "
|
||||
"Label connected particle regions in a binary mask and compute per-particle "
|
||||
"statistics: area, equivalent diameter, mean/max height, bounding box. "
|
||||
"Equivalent to gwy_data_field_grains_get_values."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
|
||||
from scipy.ndimage import label, find_objects
|
||||
from scipy.ndimage import label
|
||||
|
||||
binary = (mask > 127).astype(np.int32)
|
||||
labeled, n_grains = label(binary)
|
||||
labeled, n_particles = label(binary)
|
||||
|
||||
pixel_area = field.dx * field.dy # m^2 per pixel
|
||||
|
||||
rows = []
|
||||
for grain_id in range(1, n_grains + 1):
|
||||
grain_pixels = labeled == grain_id
|
||||
area_px = int(grain_pixels.sum())
|
||||
for pid in range(1, n_particles + 1):
|
||||
particle_pixels = labeled == pid
|
||||
area_px = int(particle_pixels.sum())
|
||||
if area_px < min_size:
|
||||
continue
|
||||
|
||||
area_m2 = area_px * pixel_area
|
||||
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
|
||||
|
||||
heights = field.data[grain_pixels]
|
||||
heights = field.data[particle_pixels]
|
||||
mean_h = float(heights.mean())
|
||||
max_h = float(heights.max())
|
||||
|
||||
# Bounding box
|
||||
ys, xs = np.where(grain_pixels)
|
||||
ys, xs = np.where(particle_pixels)
|
||||
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
|
||||
|
||||
rows.append({
|
||||
"grain_id": grain_id,
|
||||
"particle_id": pid,
|
||||
"area_px": area_px,
|
||||
"area_m2": area_m2,
|
||||
"equiv_diam_m": equiv_diam,
|
||||
|
||||
@@ -9,12 +9,16 @@ from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, encode_preview, image_to_uint8
|
||||
from backend.runtime_paths import input_dir, output_dir
|
||||
from backend.runtime_paths import demo_dir, input_dir, output_dir
|
||||
|
||||
# Resolved at server startup so nodes know where to look
|
||||
DEMO_DIR = demo_dir()
|
||||
INPUT_DIR = input_dir()
|
||||
OUTPUT_DIR = output_dir()
|
||||
|
||||
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
|
||||
".gwy", ".sxm", ".ibw"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadImage
|
||||
@@ -68,6 +72,81 @@ class LoadImage:
|
||||
return (arr, field)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadDemo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _list_demo_files() -> list[str]:
|
||||
"""Return sorted list of demo filenames available in the demo/ directory."""
|
||||
if not DEMO_DIR.exists():
|
||||
return []
|
||||
return sorted(
|
||||
f.name for f in DEMO_DIR.iterdir()
|
||||
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
@register_node(display_name="Load Demo Image")
|
||||
class LoadDemo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
choices = _list_demo_files() or ["(no demo images found)"]
|
||||
return {
|
||||
"required": {
|
||||
"name": (choices,),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
|
||||
RETURN_NAMES = ("image", "field")
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data."
|
||||
|
||||
def load(self, name: str):
|
||||
path = DEMO_DIR / name
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Demo image not found: {name}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
# SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD
|
||||
if ext == ".gwy":
|
||||
field = LoadSPM()._load_gwy(path, "Z")
|
||||
arr = field.data
|
||||
return (arr, field)
|
||||
elif ext == ".sxm":
|
||||
field = LoadSPM()._load_sxm(path, "Z")
|
||||
arr = field.data
|
||||
return (arr, field)
|
||||
elif ext == ".ibw":
|
||||
field = LoadSPM()._load_ibw(path)
|
||||
arr = field.data
|
||||
return (arr, field)
|
||||
|
||||
# npy / npz
|
||||
if ext == ".npy":
|
||||
arr = np.load(str(path)).astype(np.float64)
|
||||
elif ext == ".npz":
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
arr = npz[key].astype(np.float64)
|
||||
else:
|
||||
from PIL import Image
|
||||
img = Image.open(str(path))
|
||||
arr = np.array(img)
|
||||
if arr.dtype != np.uint8:
|
||||
arr = arr.astype(np.float64)
|
||||
|
||||
if arr.ndim == 3:
|
||||
gray = np.mean(arr.astype(np.float64), axis=2)
|
||||
else:
|
||||
gray = arr.astype(np.float64)
|
||||
|
||||
field = DataField(data=gray)
|
||||
return (arr, field)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadSPM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
273
backend/nodes/mask.py
Normal file
273
backend/nodes/mask.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Mask operation nodes — creation, morphology, and boolean combination.
|
||||
|
||||
Gwyddion equivalents:
|
||||
ThresholdMask → threshold.c / otsu_threshold.c
|
||||
MaskMorphology → mask_morph.c (erode, dilate, open, close)
|
||||
MaskInvert → (bitwise NOT on mask)
|
||||
MaskCombine → (boolean ops between two masks)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, datafield_to_uint8, encode_preview
|
||||
|
||||
|
||||
def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
|
||||
"""Render greyscale base image with red shadow on masked (255) pixels.
|
||||
|
||||
Returns (H, W, 3) uint8 array.
|
||||
"""
|
||||
grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8
|
||||
overlay = grey.astype(np.float64)
|
||||
mask_bool = mask == 255
|
||||
alpha = 0.45
|
||||
overlay[mask_bool, 0] = overlay[mask_bool, 0] * (1 - alpha) + 255 * alpha
|
||||
overlay[mask_bool, 1] = overlay[mask_bool, 1] * (1 - alpha)
|
||||
overlay[mask_bool, 2] = overlay[mask_bool, 2] * (1 - alpha)
|
||||
return np.clip(overlay, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThresholdMask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Threshold Mask")
|
||||
class ThresholdMask:
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["otsu", "absolute", "relative"],),
|
||||
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
|
||||
"direction": (["above", "below"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "mask"
|
||||
DESCRIPTION = (
|
||||
"Create a binary mask by thresholding data. "
|
||||
"Otsu automatically finds the optimal threshold. "
|
||||
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
|
||||
)
|
||||
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
|
||||
data = field.data
|
||||
|
||||
if method == "otsu":
|
||||
from skimage.filters import threshold_otsu
|
||||
t = threshold_otsu(data)
|
||||
elif method == "absolute":
|
||||
t = float(threshold)
|
||||
elif method == "relative":
|
||||
# threshold is a fraction [0, 1] of the data range
|
||||
dmin, dmax = data.min(), data.max()
|
||||
t = dmin + float(threshold) * (dmax - dmin)
|
||||
else:
|
||||
raise ValueError(f"Unknown threshold method: {method}")
|
||||
|
||||
if direction == "above":
|
||||
mask = (data >= t).astype(np.uint8) * 255
|
||||
else:
|
||||
mask = (data < t).astype(np.uint8) * 255
|
||||
|
||||
if ThresholdMask._broadcast_fn is not None:
|
||||
overlay = _mask_overlay(field, mask)
|
||||
ThresholdMask._broadcast_fn(
|
||||
ThresholdMask._current_node_id, encode_preview(overlay),
|
||||
)
|
||||
|
||||
return (mask,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskMorphology
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Mask Morphology")
|
||||
class MaskMorphology:
|
||||
"""Morphological operations on binary masks.
|
||||
|
||||
Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close).
|
||||
"""
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("IMAGE",),
|
||||
"operation": (["dilate", "erode", "open", "close"],),
|
||||
"radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}),
|
||||
"shape": (["disk", "square"],),
|
||||
},
|
||||
"optional": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "mask"
|
||||
DESCRIPTION = (
|
||||
"Apply morphological operations to a binary mask. "
|
||||
"Dilate expands regions, erode shrinks them, "
|
||||
"open (erode then dilate) removes small spots, "
|
||||
"close (dilate then erode) fills small holes. "
|
||||
"Equivalent to Gwyddion mask_morph."
|
||||
)
|
||||
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(self, mask: np.ndarray, operation: str, radius: int, shape: str,
|
||||
field: DataField | None = None) -> tuple:
|
||||
from scipy.ndimage import binary_dilation, binary_erosion
|
||||
|
||||
binary = mask > 127
|
||||
|
||||
if shape == "disk":
|
||||
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
|
||||
struct = (x * x + y * y) <= radius * radius
|
||||
else:
|
||||
size = 2 * radius + 1
|
||||
struct = np.ones((size, size), dtype=bool)
|
||||
|
||||
if operation == "dilate":
|
||||
result = binary_dilation(binary, structure=struct)
|
||||
elif operation == "erode":
|
||||
result = binary_erosion(binary, structure=struct)
|
||||
elif operation == "open":
|
||||
result = binary_dilation(
|
||||
binary_erosion(binary, structure=struct),
|
||||
structure=struct,
|
||||
)
|
||||
elif operation == "close":
|
||||
result = binary_erosion(
|
||||
binary_dilation(binary, structure=struct),
|
||||
structure=struct,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown morphological operation: {operation}")
|
||||
|
||||
out = result.astype(np.uint8) * 255
|
||||
|
||||
if field is not None and MaskMorphology._broadcast_fn is not None:
|
||||
overlay = _mask_overlay(field, out)
|
||||
MaskMorphology._broadcast_fn(
|
||||
MaskMorphology._current_node_id, encode_preview(overlay),
|
||||
)
|
||||
|
||||
return (out,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskInvert
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Mask Invert")
|
||||
class MaskInvert:
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "mask"
|
||||
DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions."
|
||||
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple:
|
||||
out = np.where(mask > 127, np.uint8(0), np.uint8(255))
|
||||
|
||||
if field is not None and MaskInvert._broadcast_fn is not None:
|
||||
overlay = _mask_overlay(field, out)
|
||||
MaskInvert._broadcast_fn(
|
||||
MaskInvert._current_node_id, encode_preview(overlay),
|
||||
)
|
||||
|
||||
return (out,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskCombine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Mask Combine")
|
||||
class MaskCombine:
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask_a": ("IMAGE",),
|
||||
"mask_b": ("IMAGE",),
|
||||
"operation": (["and", "or", "xor", "subtract"],),
|
||||
},
|
||||
"optional": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "mask"
|
||||
DESCRIPTION = (
|
||||
"Combine two binary masks with a boolean operation. "
|
||||
"AND keeps overlap, OR merges, XOR keeps non-overlapping regions, "
|
||||
"subtract removes mask_b from mask_a."
|
||||
)
|
||||
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str,
|
||||
field: DataField | None = None) -> tuple:
|
||||
a = mask_a > 127
|
||||
b = mask_b > 127
|
||||
|
||||
if operation == "and":
|
||||
result = a & b
|
||||
elif operation == "or":
|
||||
result = a | b
|
||||
elif operation == "xor":
|
||||
result = a ^ b
|
||||
elif operation == "subtract":
|
||||
result = a & ~b
|
||||
else:
|
||||
raise ValueError(f"Unknown mask operation: {operation}")
|
||||
|
||||
out = result.astype(np.uint8) * 255
|
||||
|
||||
if field is not None and MaskCombine._broadcast_fn is not None:
|
||||
overlay = _mask_overlay(field, out)
|
||||
MaskCombine._broadcast_fn(
|
||||
MaskCombine._current_node_id, encode_preview(overlay),
|
||||
)
|
||||
|
||||
return (out,)
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
APP_NAME = "Argonode"
|
||||
APP_NAME = "argonode"
|
||||
|
||||
|
||||
def project_root() -> Path:
|
||||
@@ -34,13 +34,26 @@ def app_data_dir() -> Path:
|
||||
return Path(override).expanduser().resolve()
|
||||
|
||||
if getattr(sys, "frozen", False):
|
||||
local_appdata = os.getenv("LOCALAPPDATA")
|
||||
base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local"
|
||||
if sys.platform == "darwin":
|
||||
base_dir = Path.home() / "Library" / "Application Support"
|
||||
elif sys.platform == "linux":
|
||||
xdg = os.getenv("XDG_DATA_HOME")
|
||||
base_dir = Path(xdg) if xdg else Path.home() / ".local" / "share"
|
||||
else:
|
||||
local_appdata = os.getenv("LOCALAPPDATA")
|
||||
base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local"
|
||||
return (base_dir / APP_NAME).resolve()
|
||||
|
||||
return project_root()
|
||||
|
||||
|
||||
def demo_dir() -> Path:
|
||||
bundled = resource_root() / "demo"
|
||||
if bundled.exists():
|
||||
return bundled
|
||||
return project_root() / "demo"
|
||||
|
||||
|
||||
def input_dir() -> Path:
|
||||
return app_data_dir() / "input"
|
||||
|
||||
|
||||
0
demo/.gitkeep
Normal file
0
demo/.gitkeep
Normal file
BIN
demo/nanoparticles.npy
Normal file
BIN
demo/nanoparticles.npy
Normal file
Binary file not shown.
BIN
demo/nanoparticles.png
Normal file
BIN
demo/nanoparticles.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
36
desktop.py
36
desktop.py
@@ -14,7 +14,34 @@ from backend.runtime_paths import app_data_dir, ensure_runtime_dirs
|
||||
from backend.server import create_app
|
||||
|
||||
HOST = "127.0.0.1"
|
||||
WINDOW_TITLE = "Argonode"
|
||||
WINDOW_TITLE = "argonode"
|
||||
|
||||
|
||||
class _Api:
|
||||
"""Exposed to JavaScript as window.pywebview.api."""
|
||||
|
||||
def __init__(self, window_ref: list):
|
||||
self._window_ref = window_ref
|
||||
|
||||
def open_file_dialog(self) -> str | None:
|
||||
"""Open a native file picker and return the selected path (or None)."""
|
||||
win = self._window_ref[0]
|
||||
if win is None:
|
||||
return None
|
||||
result = win.create_file_dialog(
|
||||
webview.OPEN_DIALOG,
|
||||
allow_multiple=False,
|
||||
file_types=(
|
||||
"All supported (*.png;*.jpg;*.jpeg;*.tiff;*.tif;*.npy;*.npz;*.gwy;*.sxm;*.ibw)",
|
||||
"Images (*.png;*.jpg;*.jpeg;*.tiff;*.tif)",
|
||||
"NumPy (*.npy;*.npz)",
|
||||
"SPM (*.gwy;*.sxm;*.ibw)",
|
||||
"All files (*.*)",
|
||||
),
|
||||
)
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
|
||||
def _pick_free_port() -> int:
|
||||
@@ -85,17 +112,22 @@ def main() -> None:
|
||||
ready.wait(timeout=15.0)
|
||||
|
||||
if "error" in state:
|
||||
raise RuntimeError("Argonode server failed to start") from state["error"]
|
||||
raise RuntimeError("argonode server failed to start") from state["error"]
|
||||
|
||||
_wait_for_server(f"{base_url}/nodes")
|
||||
|
||||
window_ref: list[webview.Window | None] = [None]
|
||||
js_api = _Api(window_ref)
|
||||
|
||||
window = webview.create_window(
|
||||
WINDOW_TITLE,
|
||||
base_url,
|
||||
width=1600,
|
||||
height=1000,
|
||||
min_size=(1100, 720),
|
||||
js_api=js_api,
|
||||
)
|
||||
window_ref[0] = window
|
||||
|
||||
def _shutdown() -> None:
|
||||
loop = state.get("loop")
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Argonode — Image Analysis</title>
|
||||
<title>argonode — Image Analysis</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
||||
7
frontend/package-lock.json
generated
7
frontend/package-lock.json
generated
@@ -7,6 +7,7 @@
|
||||
"name": "argonode-frontend",
|
||||
"dependencies": {
|
||||
"@xyflow/react": "^12.0.0",
|
||||
"html-to-image": "^1.11.13",
|
||||
"react": "^18.3.0",
|
||||
"react-dom": "^18.3.0",
|
||||
"three": "^0.183.2"
|
||||
@@ -1539,6 +1540,12 @@
|
||||
"node": ">=6.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/html-to-image": {
|
||||
"version": "1.11.13",
|
||||
"resolved": "https://registry.npmjs.org/html-to-image/-/html-to-image-1.11.13.tgz",
|
||||
"integrity": "sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/js-tokens": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz",
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@xyflow/react": "^12.0.0",
|
||||
"html-to-image": "^1.11.13",
|
||||
"react": "^18.3.0",
|
||||
"react-dom": "^18.3.0",
|
||||
"three": "^0.183.2"
|
||||
|
||||
@@ -4,24 +4,26 @@ import React, {
|
||||
import {
|
||||
ReactFlow, Background, Controls, MiniMap,
|
||||
useNodesState, useEdgesState, addEdge, useReactFlow,
|
||||
ReactFlowProvider,
|
||||
ReactFlowProvider, getNodesBounds, getViewportForBounds,
|
||||
} from '@xyflow/react';
|
||||
import '@xyflow/react/dist/style.css';
|
||||
|
||||
import CustomNode, { NodeContext } from './CustomNode';
|
||||
import FileBrowser from './FileBrowser';
|
||||
import * as api from './api';
|
||||
import { toBlob } from 'html-to-image';
|
||||
import { embedWorkflow, extractWorkflow } from './pngMetadata';
|
||||
|
||||
// ── Constants ─────────────────────────────────────────────────────────
|
||||
|
||||
const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']);
|
||||
|
||||
const TYPE_COLORS = {
|
||||
DATA_FIELD: '#3a7abf',
|
||||
IMAGE: '#4caf50',
|
||||
LINE: '#ff9800',
|
||||
TABLE: '#fdd835',
|
||||
COORD: '#e91e63',
|
||||
DATA_FIELD: '#ff002f',
|
||||
IMAGE: '#00ff08a0',
|
||||
LINE: '#ffbe5c',
|
||||
TABLE: '#35e2fd',
|
||||
COORD: '#e91ed1',
|
||||
};
|
||||
|
||||
const NODE_TYPES = { custom: CustomNode };
|
||||
@@ -272,6 +274,13 @@ function Flow() {
|
||||
// ── File browser ────────────────────────────────────────────────────
|
||||
|
||||
const openFileBrowser = useCallback((callback) => {
|
||||
// Use native file picker when running inside pywebview (desktop app)
|
||||
if (window.pywebview?.api?.open_file_dialog) {
|
||||
window.pywebview.api.open_file_dialog().then((path) => {
|
||||
if (path) callback(path);
|
||||
});
|
||||
return;
|
||||
}
|
||||
setFileBrowserCb(() => callback);
|
||||
}, []);
|
||||
|
||||
@@ -427,60 +436,162 @@ function Flow() {
|
||||
setStatus({ text: 'Graph cleared.', level: 'info' });
|
||||
}, [setNodes, setEdges]);
|
||||
|
||||
const saveWorkflow = useCallback(() => {
|
||||
const currentNodes = reactFlow.getNodes().map((n) => ({
|
||||
const applyWorkflowData = useCallback((data) => {
|
||||
const loadedNodes = data.nodes || [];
|
||||
const loadedEdges = data.edges || [];
|
||||
const defs = nodeDefsRef.current;
|
||||
const hydrated = loadedNodes.map((n) => ({
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
definition: defs[n.data.className] || n.data.definition,
|
||||
previewImage: null, tableRows: null, meshData: null, overlay: null,
|
||||
},
|
||||
}));
|
||||
setNodes(hydrated);
|
||||
setEdges(loadedEdges);
|
||||
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
|
||||
nextIdRef.current = maxId + 1;
|
||||
}, [setNodes, setEdges]);
|
||||
|
||||
const getWorkflowBlob = useCallback(async () => {
|
||||
const viewportEl = document.querySelector('.react-flow__viewport');
|
||||
if (!viewportEl) throw new Error('Flow element not found');
|
||||
|
||||
const allNodes = reactFlow.getNodes();
|
||||
if (allNodes.length === 0) throw new Error('No nodes to capture');
|
||||
|
||||
const bounds = getNodesBounds(allNodes);
|
||||
const pad = 0.1; // 10% margin on each side
|
||||
const imageWidth = Math.ceil(bounds.width * (1 + pad * 2));
|
||||
const imageHeight = Math.ceil(bounds.height * (1 + pad * 2));
|
||||
const vp = getViewportForBounds(bounds, imageWidth, imageHeight, 0.5, 1, pad);
|
||||
|
||||
const blob = await toBlob(viewportEl, {
|
||||
backgroundColor: '#1a1a1a',
|
||||
width: imageWidth,
|
||||
height: imageHeight,
|
||||
style: {
|
||||
width: `${imageWidth}px`,
|
||||
height: `${imageHeight}px`,
|
||||
transform: `translate(${vp.x}px, ${vp.y}px) scale(${vp.zoom})`,
|
||||
},
|
||||
});
|
||||
if (!blob) throw new Error('Capture returned empty');
|
||||
|
||||
const currentNodes = allNodes.map((n) => ({
|
||||
...n,
|
||||
data: { ...n.data, previewImage: null, tableRows: null, meshData: null, overlay: null },
|
||||
}));
|
||||
const data = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() };
|
||||
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(blob);
|
||||
a.download = 'workflow.json';
|
||||
a.click();
|
||||
const workflow = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() };
|
||||
return embedWorkflow(blob, workflow);
|
||||
}, [reactFlow]);
|
||||
|
||||
const saveWorkflow = useCallback(async () => {
|
||||
setStatus({ text: 'Saving…', level: 'info' });
|
||||
try {
|
||||
const finalBlob = await getWorkflowBlob();
|
||||
|
||||
if (window.showSaveFilePicker) {
|
||||
const handle = await window.showSaveFilePicker({
|
||||
suggestedName: 'workflow.png',
|
||||
types: [{ description: 'PNG Image', accept: { 'image/png': ['.png'] } }],
|
||||
});
|
||||
const writable = await handle.createWritable();
|
||||
await writable.write(finalBlob);
|
||||
await writable.close();
|
||||
} else {
|
||||
// Fallback: programmatic download
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(finalBlob);
|
||||
a.download = 'workflow.png';
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(a.href);
|
||||
}
|
||||
|
||||
setStatus({ text: 'Workflow saved.', level: 'info' });
|
||||
} catch (err) {
|
||||
if (err.name === 'AbortError') {
|
||||
setStatus({ text: 'Save cancelled.', level: 'info' });
|
||||
} else {
|
||||
setStatus({ text: 'Save failed: ' + err.message, level: 'error' });
|
||||
}
|
||||
}
|
||||
}, [getWorkflowBlob]);
|
||||
|
||||
const copySnapshot = useCallback(() => {
|
||||
setStatus({ text: 'Copying snapshot…', level: 'info' });
|
||||
// Pass a Promise<Blob> to ClipboardItem so the clipboard.write() call
|
||||
// happens synchronously within the user gesture, avoiding permission errors.
|
||||
const blobPromise = getWorkflowBlob().catch((err) => {
|
||||
setStatus({ text: 'Snapshot failed: ' + err.message, level: 'error' });
|
||||
throw err;
|
||||
});
|
||||
navigator.clipboard.write([new ClipboardItem({ 'image/png': blobPromise })]).then(() => {
|
||||
setStatus({ text: 'Snapshot copied to clipboard.', level: 'info' });
|
||||
}).catch((err) => {
|
||||
setStatus({ text: 'Copy failed: ' + err.message, level: 'error' });
|
||||
});
|
||||
}, [getWorkflowBlob]);
|
||||
|
||||
const loadWorkflow = useCallback(() => {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = '.json';
|
||||
input.accept = '.json,.png';
|
||||
input.onchange = async (e) => {
|
||||
const file = e.target.files[0];
|
||||
if (!file) return;
|
||||
const text = await file.text();
|
||||
try {
|
||||
const data = JSON.parse(text);
|
||||
const loadedNodes = data.nodes || [];
|
||||
const loadedEdges = data.edges || [];
|
||||
|
||||
// Re-populate definitions from current nodeDefs
|
||||
const defs = nodeDefsRef.current;
|
||||
const hydrated = loadedNodes.map((n) => ({
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
definition: defs[n.data.className] || n.data.definition,
|
||||
previewImage: null,
|
||||
tableRows: null,
|
||||
meshData: null,
|
||||
overlay: null,
|
||||
},
|
||||
}));
|
||||
|
||||
setNodes(hydrated);
|
||||
setEdges(loadedEdges);
|
||||
|
||||
// Update ID counter to avoid collisions
|
||||
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
|
||||
nextIdRef.current = maxId + 1;
|
||||
|
||||
let data;
|
||||
if (file.name.endsWith('.png') || file.type === 'image/png') {
|
||||
data = await extractWorkflow(file);
|
||||
if (!data) {
|
||||
setStatus({ text: 'No workflow data found in image.', level: 'error' });
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
data = JSON.parse(await file.text());
|
||||
}
|
||||
applyWorkflowData(data);
|
||||
setStatus({ text: 'Workflow loaded.', level: 'info' });
|
||||
} catch {
|
||||
setStatus({ text: 'Invalid workflow JSON.', level: 'error' });
|
||||
setStatus({ text: 'Invalid workflow file.', level: 'error' });
|
||||
}
|
||||
};
|
||||
input.click();
|
||||
}, [setNodes, setEdges]);
|
||||
}, [applyWorkflowData]);
|
||||
|
||||
// ── Drag-and-drop workflow image loading ───────────────────────────
|
||||
|
||||
const onDropFile = useCallback(async (event) => {
|
||||
const files = event.dataTransfer?.files;
|
||||
if (!files || files.length === 0) return;
|
||||
event.preventDefault();
|
||||
|
||||
const file = files[0];
|
||||
if (file.type !== 'image/png') return;
|
||||
|
||||
try {
|
||||
const data = await extractWorkflow(file);
|
||||
if (!data) {
|
||||
setStatus({ text: 'No workflow data in this image.', level: 'error' });
|
||||
return;
|
||||
}
|
||||
applyWorkflowData(data);
|
||||
setStatus({ text: 'Workflow loaded from image.', level: 'info' });
|
||||
} catch (err) {
|
||||
setStatus({ text: 'Failed to load: ' + err.message, level: 'error' });
|
||||
}
|
||||
}, [applyWorkflowData]);
|
||||
|
||||
const onDragOver = useCallback((event) => {
|
||||
if (event.dataTransfer?.types?.includes('Files')) {
|
||||
event.preventDefault();
|
||||
event.dataTransfer.dropEffect = 'copy';
|
||||
}
|
||||
}, []);
|
||||
|
||||
// ── Keyboard shortcut ───────────────────────────────────────────────
|
||||
|
||||
@@ -509,7 +620,7 @@ function Flow() {
|
||||
<div className="app-container">
|
||||
{/* Toolbar */}
|
||||
<div id="toolbar">
|
||||
<span id="app-title">Argonode</span>
|
||||
<span id="app-title">argonode</span>
|
||||
|
||||
<div className="toolbar-group">
|
||||
<button className="btn btn-primary" onClick={runWorkflow} title="Run workflow (Ctrl+Enter)">
|
||||
@@ -521,12 +632,15 @@ function Flow() {
|
||||
</div>
|
||||
|
||||
<div className="toolbar-group">
|
||||
<button className="btn" onClick={saveWorkflow} title="Save workflow JSON">
|
||||
<button className="btn" onClick={saveWorkflow} title="Save workflow as PNG">
|
||||
⤓ Save
|
||||
</button>
|
||||
<button className="btn" onClick={loadWorkflow} title="Load workflow JSON">
|
||||
<button className="btn" onClick={loadWorkflow} title="Load workflow (JSON or PNG)">
|
||||
⤒ Load
|
||||
</button>
|
||||
<button className="btn" onClick={copySnapshot} title="Copy workflow screenshot to clipboard">
|
||||
⎘ Snapshot
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className={`status-bar ${status.level}`}>{status.text}</div>
|
||||
@@ -535,7 +649,7 @@ function Flow() {
|
||||
{/* React Flow canvas */}
|
||||
<div className="flow-container" onMouseDown={(e) => {
|
||||
if (!e.target.closest('.context-menu')) setContextMenu(null);
|
||||
}}>
|
||||
}} onDrop={onDropFile} onDragOver={onDragOver}>
|
||||
<ReactFlow
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
|
||||
129
frontend/src/pngMetadata.js
Normal file
129
frontend/src/pngMetadata.js
Normal file
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* PNG tEXt chunk utilities for embedding/extracting workflow metadata.
|
||||
*
|
||||
* PNG files are composed of chunks: [4-byte length][4-byte type][data][4-byte CRC].
|
||||
* We add a tEXt chunk with key "workflow" containing the JSON-serialised graph,
|
||||
* inserted just before the IEND chunk.
|
||||
*/
|
||||
|
||||
// ── CRC32 (PNG uses CRC-32/ISO 3309) ────────────────────────────────
|
||||
|
||||
const crcTable = new Uint32Array(256);
|
||||
for (let i = 0; i < 256; i++) {
|
||||
let c = i;
|
||||
for (let j = 0; j < 8; j++) {
|
||||
c = (c & 1) ? (0xEDB88320 ^ (c >>> 1)) : (c >>> 1);
|
||||
}
|
||||
crcTable[i] = c;
|
||||
}
|
||||
|
||||
function crc32(bytes) {
|
||||
let crc = 0xFFFFFFFF;
|
||||
for (let i = 0; i < bytes.length; i++) {
|
||||
crc = crcTable[(crc ^ bytes[i]) & 0xFF] ^ (crc >>> 8);
|
||||
}
|
||||
return (crc ^ 0xFFFFFFFF) >>> 0;
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
const PNG_SIG = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
|
||||
|
||||
function isPng(data) {
|
||||
if (data.length < 8) return false;
|
||||
for (let i = 0; i < 8; i++) {
|
||||
if (data[i] !== PNG_SIG[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function chunkType(data, offset) {
|
||||
return String.fromCharCode(
|
||||
data[offset + 4], data[offset + 5], data[offset + 6], data[offset + 7],
|
||||
);
|
||||
}
|
||||
|
||||
// ── Public API ───────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Embed a workflow object into a PNG blob as a tEXt chunk.
|
||||
* Returns a new Blob with the metadata inserted before IEND.
|
||||
*/
|
||||
export async function embedWorkflow(pngBlob, workflow) {
|
||||
const data = new Uint8Array(await pngBlob.arrayBuffer());
|
||||
if (!isPng(data)) throw new Error('Not a valid PNG file');
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
|
||||
// Build tEXt payload: keyword \0 text
|
||||
const key = encoder.encode('workflow');
|
||||
const val = encoder.encode(JSON.stringify(workflow));
|
||||
const payload = new Uint8Array(key.length + 1 + val.length);
|
||||
payload.set(key, 0);
|
||||
// payload[key.length] is already 0 (null separator)
|
||||
payload.set(val, key.length + 1);
|
||||
|
||||
// CRC covers type + payload
|
||||
const typeBytes = encoder.encode('tEXt');
|
||||
const forCrc = new Uint8Array(4 + payload.length);
|
||||
forCrc.set(typeBytes, 0);
|
||||
forCrc.set(payload, 4);
|
||||
|
||||
// Assemble chunk: length(4) + type(4) + payload + crc(4)
|
||||
const chunk = new Uint8Array(12 + payload.length);
|
||||
const view = new DataView(chunk.buffer);
|
||||
view.setUint32(0, payload.length);
|
||||
chunk.set(typeBytes, 4);
|
||||
chunk.set(payload, 8);
|
||||
view.setUint32(8 + payload.length, crc32(forCrc));
|
||||
|
||||
// Locate IEND
|
||||
let pos = 8;
|
||||
let iendPos = data.length;
|
||||
while (pos < data.length) {
|
||||
const len = new DataView(data.buffer, pos, 4).getUint32(0);
|
||||
if (chunkType(data, pos) === 'IEND') { iendPos = pos; break; }
|
||||
pos += 12 + len;
|
||||
}
|
||||
|
||||
// Splice: [before IEND] + [tEXt chunk] + [IEND]
|
||||
const result = new Uint8Array(data.length + chunk.length);
|
||||
result.set(data.subarray(0, iendPos), 0);
|
||||
result.set(chunk, iendPos);
|
||||
result.set(data.subarray(iendPos), iendPos + chunk.length);
|
||||
|
||||
return new Blob([result], { type: 'image/png' });
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the workflow object from a PNG blob's tEXt chunks.
|
||||
* Returns the parsed object, or null if no "workflow" key is found.
|
||||
*/
|
||||
export async function extractWorkflow(pngBlob) {
|
||||
const data = new Uint8Array(await pngBlob.arrayBuffer());
|
||||
if (!isPng(data)) return null;
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let pos = 8;
|
||||
|
||||
while (pos + 8 <= data.length) {
|
||||
const len = new DataView(data.buffer, pos, 4).getUint32(0);
|
||||
const type = chunkType(data, pos);
|
||||
|
||||
if (type === 'tEXt' && pos + 8 + len <= data.length) {
|
||||
const chunkData = data.subarray(pos + 8, pos + 8 + len);
|
||||
const nullIdx = chunkData.indexOf(0);
|
||||
if (nullIdx !== -1) {
|
||||
const k = decoder.decode(chunkData.subarray(0, nullIdx));
|
||||
if (k === 'workflow') {
|
||||
return JSON.parse(decoder.decode(chunkData.subarray(nullIdx + 1)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (type === 'IEND') break;
|
||||
pos += 12 + len;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
@@ -21,8 +21,8 @@ html, body, #root {
|
||||
/* ── Toolbar ───────────────────────────────────────────────────────── */
|
||||
#toolbar {
|
||||
height: 44px;
|
||||
background: #16213e;
|
||||
border-bottom: 1px solid #0f3460;
|
||||
background: #242424;
|
||||
border-bottom: 1px solid #000000;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 12px;
|
||||
@@ -36,7 +36,7 @@ html, body, #root {
|
||||
font-size: 15px;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.5px;
|
||||
color: #e94560;
|
||||
color: #ffffff;
|
||||
margin-right: 8px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
@@ -129,8 +129,17 @@ html, body, #root {
|
||||
cursor: grabbing;
|
||||
}
|
||||
|
||||
.custom-node.selected {
|
||||
/* Selected node — target via React Flow's wrapper class */
|
||||
.react-flow__node.selected .custom-node {
|
||||
border-color: #90caf9;
|
||||
box-shadow: 0 0 0 1px #90caf9, 0 0 12px rgba(144, 202, 249, 0.4);
|
||||
}
|
||||
|
||||
/* Selected edge */
|
||||
.react-flow__edge.selected .react-flow__edge-path {
|
||||
stroke: #90caf9 !important;
|
||||
stroke-width: 3px !important;
|
||||
filter: drop-shadow(0 0 4px rgba(144, 202, 249, 0.6));
|
||||
}
|
||||
|
||||
.node-title {
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
"preview": "npm --prefix frontend run preview",
|
||||
"backend": "python -m backend.main",
|
||||
"desktop": "python desktop.py",
|
||||
"build:desktop": "powershell -ExecutionPolicy Bypass -File scripts\\build-desktop.ps1"
|
||||
"build:windows": "powershell -ExecutionPolicy Bypass -File scripts\\build-windows.ps1",
|
||||
"build:mac": "bash scripts/build-mac.sh",
|
||||
"build:linux": "bash scripts/build-linux.sh"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,10 @@ readme = "GWYDDION_FEATURE_GAP.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"aiohttp>=3.9,<4",
|
||||
"gwyfile>=0.2",
|
||||
"igor>=0.3",
|
||||
"matplotlib>=3.8,<4",
|
||||
"nanonispy>=1.1",
|
||||
"numpy>=1.26,<3",
|
||||
"pillow>=10,<12",
|
||||
"scikit-image>=0.22,<1",
|
||||
@@ -18,11 +21,6 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
spm = [
|
||||
"gwyfile>=0.2",
|
||||
"igor>=0.3",
|
||||
"nanonispy>=1.1",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8,<9",
|
||||
]
|
||||
|
||||
65
scripts/build-linux.sh
Executable file
65
scripts/build-linux.sh
Executable file
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ONE_FILE=false
|
||||
CREATE_TAR=true
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--onefile) ONE_FILE=true; shift ;;
|
||||
--no-tar) CREATE_TAR=false; shift ;;
|
||||
*) echo "Unknown option: $1"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
if [ -d ".venv/bin" ]; then
|
||||
PYTHON=".venv/bin/python"
|
||||
else
|
||||
PYTHON="python3"
|
||||
fi
|
||||
|
||||
FRONTEND_DIST="$REPO_ROOT/frontend/dist"
|
||||
DEMO_DIR="$REPO_ROOT/demo"
|
||||
|
||||
echo "Building frontend bundle..."
|
||||
npm run build
|
||||
|
||||
echo "Installing desktop build dependencies..."
|
||||
uv pip install -e ".[desktop]"
|
||||
|
||||
if $ONE_FILE; then
|
||||
MODE="--onefile"
|
||||
else
|
||||
MODE="--onedir"
|
||||
fi
|
||||
|
||||
echo "Packaging desktop app with PyInstaller..."
|
||||
$PYTHON -m PyInstaller \
|
||||
desktop.py \
|
||||
--noconfirm \
|
||||
--clean \
|
||||
--name argonode \
|
||||
--windowed \
|
||||
$MODE \
|
||||
--distpath desktop-dist \
|
||||
--workpath desktop-build \
|
||||
--specpath desktop-build \
|
||||
--add-data "${FRONTEND_DIST}:frontend/dist" \
|
||||
--add-data "${DEMO_DIR}:demo" \
|
||||
--collect-all matplotlib \
|
||||
--collect-all scipy \
|
||||
--collect-all skimage \
|
||||
--collect-all webview
|
||||
|
||||
if $CREATE_TAR; then
|
||||
TAR_PATH="desktop-dist/argonode-linux.tar.gz"
|
||||
echo "Creating tarball..."
|
||||
tar -czf "$TAR_PATH" -C desktop-dist argonode
|
||||
echo "Tarball created: $TAR_PATH"
|
||||
fi
|
||||
|
||||
echo "Desktop build complete."
|
||||
echo "Output: $REPO_ROOT/desktop-dist/"
|
||||
110
scripts/build-mac.sh
Executable file
110
scripts/build-mac.sh
Executable file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ONE_FILE=false
|
||||
CREATE_DMG=true
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--onefile) ONE_FILE=true; shift ;;
|
||||
--no-dmg) CREATE_DMG=false; shift ;;
|
||||
*) echo "Unknown option: $1"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
if [ -d ".venv/bin" ]; then
|
||||
PYTHON=".venv/bin/python"
|
||||
else
|
||||
PYTHON="python3"
|
||||
fi
|
||||
|
||||
FRONTEND_DIST="$REPO_ROOT/frontend/dist"
|
||||
DEMO_DIR="$REPO_ROOT/demo"
|
||||
|
||||
echo "Building frontend bundle..."
|
||||
npm run build
|
||||
|
||||
echo "Installing desktop build dependencies..."
|
||||
uv pip install -e ".[desktop]"
|
||||
|
||||
if $ONE_FILE; then
|
||||
MODE="--onefile"
|
||||
else
|
||||
MODE="--onedir"
|
||||
fi
|
||||
|
||||
echo "Packaging desktop app with PyInstaller..."
|
||||
$PYTHON -m PyInstaller \
|
||||
desktop.py \
|
||||
--noconfirm \
|
||||
--clean \
|
||||
--name argonode \
|
||||
--windowed \
|
||||
$MODE \
|
||||
--distpath desktop-dist \
|
||||
--workpath desktop-build \
|
||||
--specpath desktop-build \
|
||||
--add-data "${FRONTEND_DIST}:frontend/dist" \
|
||||
--add-data "${DEMO_DIR}:demo" \
|
||||
--collect-all matplotlib \
|
||||
--collect-all scipy \
|
||||
--collect-all skimage \
|
||||
--collect-all webview \
|
||||
--icon resources/icon.icns 2>/dev/null || \
|
||||
$PYTHON -m PyInstaller \
|
||||
desktop.py \
|
||||
--noconfirm \
|
||||
--clean \
|
||||
--name argonode \
|
||||
--windowed \
|
||||
$MODE \
|
||||
--distpath desktop-dist \
|
||||
--workpath desktop-build \
|
||||
--specpath desktop-build \
|
||||
--add-data "${FRONTEND_DIST}:frontend/dist" \
|
||||
--add-data "${DEMO_DIR}:demo" \
|
||||
--collect-all matplotlib \
|
||||
--collect-all scipy \
|
||||
--collect-all skimage \
|
||||
--collect-all webview
|
||||
|
||||
APP_BUNDLE="desktop-dist/argonode.app"
|
||||
|
||||
if [ ! -d "$APP_BUNDLE" ]; then
|
||||
# --onedir puts it inside a folder
|
||||
if [ -d "desktop-dist/argonode/argonode.app" ]; then
|
||||
APP_BUNDLE="desktop-dist/argonode/argonode.app"
|
||||
else
|
||||
echo "Warning: .app bundle not found; skipping DMG creation."
|
||||
CREATE_DMG=false
|
||||
fi
|
||||
fi
|
||||
|
||||
if $CREATE_DMG; then
|
||||
DMG_PATH="desktop-dist/argonode.dmg"
|
||||
echo "Creating DMG installer..."
|
||||
rm -f "$DMG_PATH"
|
||||
|
||||
# Create a temporary directory for DMG contents
|
||||
DMG_STAGING="desktop-build/dmg-staging"
|
||||
rm -rf "$DMG_STAGING"
|
||||
mkdir -p "$DMG_STAGING"
|
||||
cp -R "$APP_BUNDLE" "$DMG_STAGING/"
|
||||
ln -s /Applications "$DMG_STAGING/Applications"
|
||||
|
||||
hdiutil create \
|
||||
-volname "argonode" \
|
||||
-srcfolder "$DMG_STAGING" \
|
||||
-ov \
|
||||
-format UDZO \
|
||||
"$DMG_PATH"
|
||||
|
||||
rm -rf "$DMG_STAGING"
|
||||
echo "DMG created: $DMG_PATH"
|
||||
fi
|
||||
|
||||
echo "Desktop build complete."
|
||||
echo "Output: $REPO_ROOT/desktop-dist/"
|
||||
@@ -14,6 +14,7 @@ $pythonExe = if (Test-Path ".\.venv\Scripts\python.exe") {
|
||||
"python"
|
||||
}
|
||||
$frontendDist = Join-Path $repoRoot "frontend\dist"
|
||||
$demoDir = Join-Path $repoRoot "demo"
|
||||
|
||||
Write-Host "Building frontend bundle..."
|
||||
npm run build
|
||||
@@ -28,13 +29,14 @@ $pyInstallerArgs = @(
|
||||
"desktop.py",
|
||||
"--noconfirm",
|
||||
"--clean",
|
||||
"--name", "Argonode",
|
||||
"--name", "argonode",
|
||||
"--windowed",
|
||||
$mode,
|
||||
"--distpath", "desktop-dist",
|
||||
"--workpath", "desktop-build",
|
||||
"--specpath", "desktop-build",
|
||||
"--add-data", "${frontendDist};frontend/dist",
|
||||
"--add-data", "${demoDir};demo",
|
||||
"--collect-all", "matplotlib",
|
||||
"--collect-all", "scipy",
|
||||
"--collect-all", "skimage",
|
||||
@@ -45,4 +47,4 @@ Write-Host "Packaging desktop app..."
|
||||
& $pythonExe @pyInstallerArgs
|
||||
|
||||
Write-Host "Desktop build complete."
|
||||
Write-Host "Output folder: $repoRoot\desktop-dist\Argonode"
|
||||
Write-Host "Output folder: $repoRoot\desktop-dist\argonode"
|
||||
79
scripts/generate_demo_particles.py
Normal file
79
scripts/generate_demo_particles.py
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate a synthetic nanoparticle image for the demo/ folder.
|
||||
|
||||
The image simulates an AFM scan of particles on a flat substrate:
|
||||
- Slightly noisy background
|
||||
- ~20 hemisphere-shaped particles with varying radii and heights
|
||||
- Saved as both .npy (calibrated float64) and .png (visual preview)
|
||||
|
||||
Run from project root:
|
||||
python scripts/generate_demo_particles.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
DEMO_DIR = Path(__file__).resolve().parent.parent / "demo"
|
||||
DEMO_DIR.mkdir(exist_ok=True)
|
||||
|
||||
RNG = np.random.default_rng(2024)
|
||||
|
||||
# --- Image parameters ---
|
||||
N = 256 # pixels
|
||||
SCAN_SIZE = 5e-6 # 5 µm scan
|
||||
PIXEL_SIZE = SCAN_SIZE / N # metres per pixel
|
||||
BG_NOISE_RMS = 0.3e-9 # 0.3 nm background noise
|
||||
|
||||
# --- Generate particles ---
|
||||
particles = []
|
||||
# Hand-placed cluster + random scatter to give a realistic spread
|
||||
fixed = [
|
||||
# (cx_frac, cy_frac, radius_nm, height_nm)
|
||||
(0.25, 0.30, 120, 30),
|
||||
(0.28, 0.34, 80, 20),
|
||||
(0.70, 0.25, 150, 45),
|
||||
(0.50, 0.55, 100, 25),
|
||||
(0.55, 0.60, 60, 15),
|
||||
(0.15, 0.75, 200, 55),
|
||||
(0.80, 0.80, 90, 22),
|
||||
]
|
||||
for cx_f, cy_f, r_nm, h_nm in fixed:
|
||||
particles.append((cx_f * N, cy_f * N, r_nm * 1e-9, h_nm * 1e-9))
|
||||
|
||||
# Random particles
|
||||
for _ in range(15):
|
||||
cx = RNG.uniform(20, N - 20)
|
||||
cy = RNG.uniform(20, N - 20)
|
||||
radius = RNG.uniform(30, 180) * 1e-9 # 30–180 nm
|
||||
height = RNG.uniform(8, 60) * 1e-9 # 8–60 nm
|
||||
particles.append((cx, cy, radius, height))
|
||||
|
||||
# --- Render height map ---
|
||||
image = RNG.normal(0, BG_NOISE_RMS, (N, N))
|
||||
|
||||
yy, xx = np.mgrid[0:N, 0:N]
|
||||
|
||||
for cx, cy, radius_m, height_m in particles:
|
||||
radius_px = radius_m / PIXEL_SIZE
|
||||
dist2 = (xx - cx) ** 2 + (yy - cy) ** 2
|
||||
inside = dist2 < radius_px ** 2
|
||||
# Hemisphere profile: z = h * sqrt(1 - (r/R)^2)
|
||||
z = np.zeros_like(image)
|
||||
z[inside] = height_m * np.sqrt(1.0 - dist2[inside] / radius_px ** 2)
|
||||
image = np.maximum(image, z) # particles don't subtract from each other
|
||||
|
||||
# --- Save .npy (float64 metres) ---
|
||||
npy_path = DEMO_DIR / "nanoparticles.npy"
|
||||
np.save(str(npy_path), image)
|
||||
print(f"Saved {npy_path} shape={image.shape} range=[{image.min():.2e}, {image.max():.2e}] m")
|
||||
|
||||
# --- Save .png (8-bit grayscale for quick visual) ---
|
||||
from PIL import Image
|
||||
|
||||
normed = (image - image.min()) / (image.max() - image.min())
|
||||
uint8 = (normed * 255).astype(np.uint8)
|
||||
png_path = DEMO_DIR / "nanoparticles.png"
|
||||
Image.fromarray(uint8, mode="L").save(str(png_path))
|
||||
print(f"Saved {png_path}")
|
||||
|
||||
print(f"\n{len(particles)} particles generated on a {SCAN_SIZE*1e6:.0f} µm × {SCAN_SIZE*1e6:.0f} µm scan")
|
||||
445
tests/test_grains.py
Normal file
445
tests/test_grains.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Thorough tests for the grain/particle analysis pipeline:
|
||||
ThresholdMask → GrainAnalysis
|
||||
|
||||
Covers synthetic geometry (known answers), the demo nanoparticles image,
|
||||
edge cases, and physical-unit correctness.
|
||||
|
||||
Run from project root:
|
||||
.venv/bin/python -m tests.test_grains
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
def make_field(data, xreal=1e-6, yreal=1e-6):
|
||||
return DataField(data=data.astype(np.float64), xreal=xreal, yreal=yreal,
|
||||
si_unit_xy="m", si_unit_z="m")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# ThresholdMask tests
|
||||
# =========================================================================
|
||||
|
||||
def test_threshold_otsu_bimodal():
|
||||
"""Otsu on a clean bimodal image should separate the two populations."""
|
||||
print("=== Test: Otsu on bimodal image ===")
|
||||
from backend.nodes.grains import ThresholdMask
|
||||
node = ThresholdMask()
|
||||
|
||||
data = np.zeros((128, 128))
|
||||
data[30:50, 30:50] = 10.0 # bright square
|
||||
data[70:100, 80:110] = 10.0 # another bright region
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
bright_pixels = (mask == 255)
|
||||
# Should capture both bright regions
|
||||
assert bright_pixels[40, 40], "Otsu missed bright region 1"
|
||||
assert bright_pixels[85, 95], "Otsu missed bright region 2"
|
||||
# Background should be dark
|
||||
assert not bright_pixels[0, 0], "Otsu false positive in background"
|
||||
assert not bright_pixels[60, 60], "Otsu false positive between regions"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_threshold_relative_range():
|
||||
"""Relative threshold at 0.5 should be the midpoint of [min, max]."""
|
||||
print("=== Test: Relative threshold at midpoint ===")
|
||||
from backend.nodes.grains import ThresholdMask
|
||||
node = ThresholdMask()
|
||||
|
||||
data = np.full((64, 64), 2.0)
|
||||
data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||
# Only the bright patch (value 8 >= 5) should be masked
|
||||
assert np.all(mask[10:20, 10:20] == 255)
|
||||
assert np.all(mask[0:10, :] == 0)
|
||||
assert np.all(mask[20:, :] == 0)
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_threshold_empty_mask():
|
||||
"""Very high absolute threshold on low data should produce an empty mask."""
|
||||
print("=== Test: Empty mask from high threshold ===")
|
||||
from backend.nodes.grains import ThresholdMask
|
||||
node = ThresholdMask()
|
||||
|
||||
data = np.ones((64, 64))
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="absolute", threshold=999.0, direction="above")
|
||||
assert mask.sum() == 0, "Mask should be completely empty"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_threshold_full_mask():
|
||||
"""Very low absolute threshold should produce an all-white mask."""
|
||||
print("=== Test: Full mask from low threshold ===")
|
||||
from backend.nodes.grains import ThresholdMask
|
||||
node = ThresholdMask()
|
||||
|
||||
data = np.ones((64, 64)) * 5.0
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="absolute", threshold=-1.0, direction="above")
|
||||
assert np.all(mask == 255), "Mask should be all white"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# GrainAnalysis tests
|
||||
# =========================================================================
|
||||
|
||||
def test_single_circle_area():
|
||||
"""A single filled circle — verify pixel count and physical area."""
|
||||
print("=== Test: Single circle area ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
N = 200
|
||||
XREAL = 2e-6 # 2 µm
|
||||
data = np.zeros((N, N))
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
|
||||
# Draw a filled circle, radius 30 px, centred at (100, 100)
|
||||
yy, xx = np.mgrid[0:N, 0:N]
|
||||
r = 30
|
||||
circle = ((xx - 100) ** 2 + (yy - 100) ** 2) <= r ** 2
|
||||
data[circle] = 5.0
|
||||
mask[circle] = 255
|
||||
|
||||
field = make_field(data, xreal=XREAL, yreal=XREAL)
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
|
||||
assert len(table) == 1, f"Expected 1 grain, got {len(table)}"
|
||||
grain = table[0]
|
||||
|
||||
# Pixel area of a discrete circle: should be close to π r²
|
||||
expected_px = np.pi * r ** 2
|
||||
assert abs(grain["area_px"] - expected_px) / expected_px < 0.02, \
|
||||
f"area_px={grain['area_px']}, expected≈{expected_px:.0f}"
|
||||
|
||||
# Physical area
|
||||
pixel_area = (XREAL / N) ** 2
|
||||
expected_m2 = grain["area_px"] * pixel_area
|
||||
assert abs(grain["area_m2"] - expected_m2) < 1e-20, \
|
||||
f"area_m2 mismatch: {grain['area_m2']} vs {expected_m2}"
|
||||
|
||||
# Equivalent diameter should be close to 2r in physical units
|
||||
expected_diam = 2 * r * (XREAL / N)
|
||||
assert abs(grain["equiv_diam_m"] - expected_diam) / expected_diam < 0.02, \
|
||||
f"equiv_diam={grain['equiv_diam_m']:.3e}, expected≈{expected_diam:.3e}"
|
||||
|
||||
# Heights
|
||||
assert abs(grain["mean_height"] - 5.0) < 1e-10
|
||||
assert abs(grain["max_height"] - 5.0) < 1e-10
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_multiple_grains_separation():
|
||||
"""Three well-separated grains of different sizes — check each is reported."""
|
||||
print("=== Test: Multiple grain separation ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
N = 128
|
||||
data = np.zeros((N, N))
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
|
||||
# Grain A: 20×20 block, height 10
|
||||
data[10:30, 10:30] = 10.0
|
||||
mask[10:30, 10:30] = 255
|
||||
|
||||
# Grain B: 10×10 block, height 7
|
||||
data[60:70, 60:70] = 7.0
|
||||
mask[60:70, 60:70] = 255
|
||||
|
||||
# Grain C: 5×5 block, height 3
|
||||
data[100:105, 100:105] = 3.0
|
||||
mask[100:105, 100:105] = 255
|
||||
|
||||
field = make_field(data)
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
|
||||
assert len(table) == 3, f"Expected 3 grains, got {len(table)}"
|
||||
|
||||
table.sort(key=lambda r: r["area_px"], reverse=True)
|
||||
assert table[0]["area_px"] == 400 # 20×20
|
||||
assert table[1]["area_px"] == 100 # 10×10
|
||||
assert table[2]["area_px"] == 25 # 5×5
|
||||
|
||||
assert abs(table[0]["mean_height"] - 10.0) < 1e-10
|
||||
assert abs(table[1]["mean_height"] - 7.0) < 1e-10
|
||||
assert abs(table[2]["mean_height"] - 3.0) < 1e-10
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_min_size_filtering():
|
||||
"""min_size should exclude grains smaller than the threshold."""
|
||||
print("=== Test: min_size filtering ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
N = 64
|
||||
data = np.zeros((N, N))
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
|
||||
# Large grain: 15×15 = 225 px
|
||||
data[5:20, 5:20] = 1.0
|
||||
mask[5:20, 5:20] = 255
|
||||
|
||||
# Medium grain: 8×8 = 64 px
|
||||
data[30:38, 30:38] = 1.0
|
||||
mask[30:38, 30:38] = 255
|
||||
|
||||
# Tiny grain: 3×3 = 9 px
|
||||
data[50:53, 50:53] = 1.0
|
||||
mask[50:53, 50:53] = 255
|
||||
|
||||
field = make_field(data)
|
||||
|
||||
# min_size=1: all three
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
assert len(table) == 3
|
||||
|
||||
# min_size=10: drops the 3×3
|
||||
table, = node.process(field, mask=mask, min_size=10)
|
||||
assert len(table) == 2
|
||||
|
||||
# min_size=100: drops the 3×3 and 8×8
|
||||
table, = node.process(field, mask=mask, min_size=100)
|
||||
assert len(table) == 1
|
||||
assert table[0]["area_px"] == 225
|
||||
|
||||
# min_size=300: drops everything
|
||||
table, = node.process(field, mask=mask, min_size=300)
|
||||
assert len(table) == 0
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_grain_bounding_box():
|
||||
"""Bounding box should match the grain extents."""
|
||||
print("=== Test: Grain bounding box ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
N = 64
|
||||
data = np.zeros((N, N))
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
# Place a grain at rows 20:35, cols 10:45
|
||||
data[20:35, 10:45] = 2.0
|
||||
mask[20:35, 10:45] = 255
|
||||
|
||||
field = make_field(data)
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
assert len(table) == 1
|
||||
|
||||
bbox = table[0]["bbox"]
|
||||
# Format: "(xmin,ymin)-(xmax,ymax)" = "(10,20)-(44,34)"
|
||||
assert bbox == "(10,20)-(44,34)", f"bbox={bbox}, expected (10,20)-(44,34)"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_empty_mask_produces_no_grains():
|
||||
"""An all-zero mask should yield zero grains."""
|
||||
print("=== Test: Empty mask → no grains ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
field = make_field(np.ones((64, 64)))
|
||||
mask = np.zeros((64, 64), dtype=np.uint8)
|
||||
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
assert len(table) == 0
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_grain_at_image_edge():
|
||||
"""A grain touching the image border should still be detected."""
|
||||
print("=== Test: Grain at image edge ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
N = 64
|
||||
data = np.zeros((N, N))
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
# Grain touching top-left corner
|
||||
data[0:10, 0:10] = 4.0
|
||||
mask[0:10, 0:10] = 255
|
||||
|
||||
field = make_field(data)
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
assert len(table) == 1
|
||||
assert table[0]["area_px"] == 100
|
||||
assert table[0]["bbox"] == "(0,0)-(9,9)"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_adjacent_grains_connectivity():
|
||||
"""Two diagonally-touching blocks should be separate grains
|
||||
(scipy.ndimage.label uses 4-connectivity by default)."""
|
||||
print("=== Test: Diagonal adjacency → separate grains ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
N = 32
|
||||
data = np.zeros((N, N))
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
|
||||
# Block A
|
||||
data[5:10, 5:10] = 1.0
|
||||
mask[5:10, 5:10] = 255
|
||||
|
||||
# Block B diagonally adjacent (touching only at corner 10,10)
|
||||
data[10:15, 10:15] = 1.0
|
||||
mask[10:15, 10:15] = 255
|
||||
|
||||
field = make_field(data)
|
||||
table, = node.process(field, mask=mask, min_size=1)
|
||||
# Default label() uses structure that connects diagonals? Let's verify.
|
||||
# scipy.ndimage.label default is cross-shaped (no diagonals) for 2D
|
||||
assert len(table) == 2, f"Expected 2 separate grains, got {len(table)}"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# End-to-end pipeline: ThresholdMask → GrainAnalysis
|
||||
# =========================================================================
|
||||
|
||||
def test_pipeline_synthetic():
|
||||
"""Full pipeline on a synthetic image with known geometry."""
|
||||
print("=== Test: Full pipeline on synthetic particles ===")
|
||||
from backend.nodes.grains import ThresholdMask, GrainAnalysis
|
||||
|
||||
N = 200
|
||||
XREAL = 10e-6 # 10 µm
|
||||
rng = np.random.default_rng(99)
|
||||
|
||||
# Background at 0 with small noise, particles as raised bumps
|
||||
bg = rng.normal(0, 0.1, (N, N))
|
||||
particles = np.zeros((N, N))
|
||||
|
||||
yy, xx = np.mgrid[0:N, 0:N]
|
||||
|
||||
specs = [
|
||||
(50, 50, 15, 5.0), # (cx, cy, radius_px, height)
|
||||
(150, 50, 20, 8.0),
|
||||
(100, 100, 10, 3.0),
|
||||
(50, 160, 25, 6.0),
|
||||
(160, 160, 12, 4.0),
|
||||
]
|
||||
for cx, cy, r, h in specs:
|
||||
inside = ((xx - cx) ** 2 + (yy - cy) ** 2) <= r ** 2
|
||||
particles[inside] = h
|
||||
|
||||
data = bg + particles
|
||||
field = make_field(data, xreal=XREAL, yreal=XREAL)
|
||||
|
||||
# Step 1: threshold
|
||||
thresh = ThresholdMask()
|
||||
mask, = thresh.process(field, method="absolute", threshold=1.0, direction="above")
|
||||
|
||||
# Particles are well above noise, so mask should capture all 5
|
||||
assert mask.max() == 255, "No particles detected"
|
||||
|
||||
# Step 2: grain analysis
|
||||
ga = GrainAnalysis()
|
||||
table, = ga.process(field, mask=mask, min_size=5)
|
||||
|
||||
assert len(table) == 5, f"Expected 5 grains, got {len(table)}"
|
||||
|
||||
# Verify that detected areas are in the right ballpark
|
||||
table.sort(key=lambda r: r["area_px"], reverse=True)
|
||||
expected_areas = sorted([np.pi * r ** 2 for _, _, r, _ in specs], reverse=True)
|
||||
|
||||
for grain, expected_px in zip(table, expected_areas):
|
||||
ratio = grain["area_px"] / expected_px
|
||||
assert 0.85 < ratio < 1.15, \
|
||||
f"grain area_px={grain['area_px']}, expected≈{expected_px:.0f}, ratio={ratio:.2f}"
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_pipeline_demo_image():
|
||||
"""Run the full pipeline on the bundled demo nanoparticles image."""
|
||||
print("=== Test: Full pipeline on demo nanoparticles.npy ===")
|
||||
from pathlib import Path
|
||||
from backend.nodes.grains import ThresholdMask, GrainAnalysis
|
||||
from backend.runtime_paths import demo_dir
|
||||
|
||||
npy_path = demo_dir() / "nanoparticles.npy"
|
||||
if not npy_path.exists():
|
||||
print(" SKIP (demo image not found)\n")
|
||||
return
|
||||
|
||||
data = np.load(str(npy_path)).astype(np.float64)
|
||||
# The demo image is a 5 µm × 5 µm scan
|
||||
field = make_field(data, xreal=5e-6, yreal=5e-6)
|
||||
|
||||
# Threshold to find particles (they are raised above background)
|
||||
thresh = ThresholdMask()
|
||||
mask, = thresh.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
|
||||
# Should detect particles
|
||||
assert mask.max() == 255, "No particles found in demo image"
|
||||
particle_fraction = (mask == 255).sum() / mask.size
|
||||
assert 0.01 < particle_fraction < 0.5, \
|
||||
f"Suspicious particle fraction: {particle_fraction:.3f}"
|
||||
print(f" Mask: {particle_fraction*100:.1f}% of pixels are particles")
|
||||
|
||||
# Grain analysis
|
||||
ga = GrainAnalysis()
|
||||
table, = ga.process(field, mask=mask, min_size=20)
|
||||
|
||||
assert len(table) > 0, "No grains detected"
|
||||
print(f" Found {len(table)} grains (min_size=20)")
|
||||
|
||||
# Sanity checks on grain properties
|
||||
for grain in table:
|
||||
assert grain["area_px"] >= 20
|
||||
assert grain["area_m2"] > 0
|
||||
assert grain["equiv_diam_m"] > 0
|
||||
assert grain["max_height"] >= grain["mean_height"]
|
||||
assert grain["mean_height"] > 0
|
||||
|
||||
# Physical size sanity: equivalent diameters should be in the nm–µm range
|
||||
diams_nm = [g["equiv_diam_m"] * 1e9 for g in table]
|
||||
print(f" Diameters: min={min(diams_nm):.0f} nm, max={max(diams_nm):.0f} nm")
|
||||
assert all(1 < d < 2000 for d in diams_nm), \
|
||||
f"Grain diameters out of expected range: {diams_nm}"
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Run all tests
|
||||
# =========================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ThresholdMask
|
||||
test_threshold_otsu_bimodal()
|
||||
test_threshold_relative_range()
|
||||
test_threshold_empty_mask()
|
||||
test_threshold_full_mask()
|
||||
|
||||
# GrainAnalysis
|
||||
test_single_circle_area()
|
||||
test_multiple_grains_separation()
|
||||
test_min_size_filtering()
|
||||
test_grain_bounding_box()
|
||||
test_empty_mask_produces_no_grains()
|
||||
test_grain_at_image_edge()
|
||||
test_adjacent_grains_connectivity()
|
||||
|
||||
# End-to-end pipeline
|
||||
test_pipeline_synthetic()
|
||||
test_pipeline_demo_image()
|
||||
|
||||
print("All grain tests passed!")
|
||||
@@ -85,6 +85,88 @@ def test_edge_detect():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_fft_filter_1d():
|
||||
print("=== Test: FFTFilter1D ===")
|
||||
from backend.nodes.filters import FFTFilter1D
|
||||
node = FFTFilter1D()
|
||||
|
||||
# Signal: low-frequency sine + high-frequency sine
|
||||
n = 256
|
||||
t = np.arange(n, dtype=np.float64) / n
|
||||
low = np.sin(2 * np.pi * 3 * t) # 3 cycles — low freq
|
||||
high = np.sin(2 * np.pi * 80 * t) # 80 cycles — high freq
|
||||
line = low + high
|
||||
|
||||
# Lowpass should keep low, suppress high
|
||||
filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||
assert len(filtered_lp) == n
|
||||
corr_low = np.corrcoef(filtered_lp, low)[0, 1]
|
||||
corr_high = np.corrcoef(filtered_lp, high)[0, 1]
|
||||
assert corr_low > 0.95, f"Lowpass: correlation with low={corr_low}"
|
||||
assert abs(corr_high) < 0.3, f"Lowpass: correlation with high={corr_high}"
|
||||
|
||||
# Highpass should keep high, suppress low
|
||||
filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||
corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1]
|
||||
corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1]
|
||||
assert abs(corr_low_hp) < 0.3, f"Highpass: correlation with low={corr_low_hp}"
|
||||
assert corr_high_hp > 0.95, f"Highpass: correlation with high={corr_high_hp}"
|
||||
|
||||
# Bandpass centred on the high frequency
|
||||
filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||
corr_low_bp = np.corrcoef(filtered_bp, low)[0, 1]
|
||||
corr_high_bp = np.corrcoef(filtered_bp, high)[0, 1]
|
||||
assert abs(corr_low_bp) < 0.3, f"Bandpass: correlation with low={corr_low_bp}"
|
||||
assert corr_high_bp > 0.9, f"Bandpass: correlation with high={corr_high_bp}"
|
||||
|
||||
# Notch (band-reject) centred on the high frequency — should remove it
|
||||
filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||
corr_low_notch = np.corrcoef(filtered_notch, low)[0, 1]
|
||||
corr_high_notch = np.corrcoef(filtered_notch, high)[0, 1]
|
||||
assert corr_low_notch > 0.95, f"Notch: correlation with low={corr_low_notch}"
|
||||
assert abs(corr_high_notch) < 0.3, f"Notch: correlation with high={corr_high_notch}"
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_fft_filter_2d():
|
||||
print("=== Test: FFTFilter2D ===")
|
||||
from backend.nodes.filters import FFTFilter2D
|
||||
node = FFTFilter2D()
|
||||
|
||||
N = 128
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
# Low-frequency 2D pattern + high-frequency pattern
|
||||
low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
|
||||
high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
|
||||
data = low_2d + high_2d
|
||||
field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
# Lowpass — should preserve low, remove high
|
||||
result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||
assert result_lp.data.shape == (N, N)
|
||||
assert result_lp.xreal == field.xreal
|
||||
assert result_lp.si_unit_z == field.si_unit_z
|
||||
corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
|
||||
corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
|
||||
assert corr_low > 0.9, f"2D lowpass: correlation with low={corr_low}"
|
||||
assert abs(corr_high) < 0.3, f"2D lowpass: correlation with high={corr_high}"
|
||||
|
||||
# Highpass — should preserve high, remove low
|
||||
result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||
corr_low_hp = np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]
|
||||
corr_high_hp = np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1]
|
||||
assert abs(corr_low_hp) < 0.3, f"2D highpass: correlation with low={corr_low_hp}"
|
||||
assert corr_high_hp > 0.9, f"2D highpass: correlation with high={corr_high_hp}"
|
||||
|
||||
# Constant field should be unchanged by lowpass (DC preservation)
|
||||
const = make_field(data=np.ones((32, 32)) * 7.0)
|
||||
result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
|
||||
assert np.allclose(result_const.data, 7.0, atol=1e-10), "Lowpass should preserve constant field"
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Level
|
||||
# =========================================================================
|
||||
@@ -199,7 +281,7 @@ def test_height_histogram():
|
||||
data = np.linspace(0, 1, 1000).reshape(25, 40)
|
||||
field = make_field(data=data)
|
||||
|
||||
counts, bin_centers = node.process(field, n_bins=10)
|
||||
counts, bin_centers = node.process(field, n_bins=10, y_scale="linear")
|
||||
assert len(counts) == 10
|
||||
assert len(bin_centers) == 10
|
||||
assert counts.dtype == np.float64
|
||||
@@ -265,7 +347,7 @@ def test_cross_section():
|
||||
|
||||
def test_threshold_mask():
|
||||
print("=== Test: ThresholdMask ===")
|
||||
from backend.nodes.grains import ThresholdMask
|
||||
from backend.nodes.mask import ThresholdMask
|
||||
node = ThresholdMask()
|
||||
|
||||
# Clear bimodal data: left half = 0, right half = 1
|
||||
@@ -273,6 +355,11 @@ def test_threshold_mask():
|
||||
data[:, 32:] = 1.0
|
||||
field = make_field(data=data)
|
||||
|
||||
# Capture overlay preview
|
||||
previews = []
|
||||
ThresholdMask._broadcast_fn = lambda nid, uri: previews.append(uri)
|
||||
ThresholdMask._current_node_id = "test"
|
||||
|
||||
# Absolute threshold at 0.5
|
||||
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
|
||||
assert mask.dtype == np.uint8
|
||||
@@ -280,6 +367,10 @@ def test_threshold_mask():
|
||||
assert np.all(mask[:, :32] == 0)
|
||||
assert np.all(mask[:, 32:] == 255)
|
||||
|
||||
# Verify overlay preview was broadcast
|
||||
assert len(previews) == 1
|
||||
assert previews[0].startswith("data:image/png;base64,")
|
||||
|
||||
# Direction "below"
|
||||
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
|
||||
assert np.all(mask_below[:, :32] == 255)
|
||||
@@ -292,20 +383,117 @@ def test_threshold_mask():
|
||||
# Otsu should find the bimodal threshold
|
||||
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
|
||||
|
||||
ThresholdMask._broadcast_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_grain_analysis():
|
||||
print("=== Test: GrainAnalysis ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
def test_mask_morphology():
|
||||
print("=== Test: MaskMorphology ===")
|
||||
from backend.nodes.mask import MaskMorphology
|
||||
node = MaskMorphology()
|
||||
|
||||
# Create a field with two distinct "grains"
|
||||
# Small square blob in the centre
|
||||
mask = np.zeros((64, 64), dtype=np.uint8)
|
||||
mask[28:36, 28:36] = 255 # 8x8 block
|
||||
orig_count = np.count_nonzero(mask)
|
||||
|
||||
# Dilate should grow the region
|
||||
dilated, = node.process(mask, operation="dilate", radius=1, shape="square")
|
||||
assert dilated.dtype == np.uint8
|
||||
assert np.count_nonzero(dilated) > orig_count
|
||||
|
||||
# Erode should shrink it
|
||||
eroded, = node.process(mask, operation="erode", radius=1, shape="square")
|
||||
assert np.count_nonzero(eroded) < orig_count
|
||||
|
||||
# Open on a clean block should give back roughly the same block
|
||||
opened, = node.process(mask, operation="open", radius=1, shape="square")
|
||||
assert np.count_nonzero(opened) <= orig_count
|
||||
|
||||
# Close on a mask with a 1-pixel hole should fill the hole
|
||||
mask_hole = mask.copy()
|
||||
mask_hole[32, 32] = 0 # poke a hole
|
||||
assert np.count_nonzero(mask_hole) == orig_count - 1
|
||||
closed, = node.process(mask_hole, operation="close", radius=1, shape="square")
|
||||
assert closed[32, 32] == 255, "Close should fill the 1-pixel hole"
|
||||
|
||||
# Disk structuring element should also work
|
||||
dilated_disk, = node.process(mask, operation="dilate", radius=2, shape="disk")
|
||||
assert np.count_nonzero(dilated_disk) > orig_count
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_mask_invert():
|
||||
print("=== Test: MaskInvert ===")
|
||||
from backend.nodes.mask import MaskInvert
|
||||
node = MaskInvert()
|
||||
|
||||
mask = np.zeros((64, 64), dtype=np.uint8)
|
||||
mask[10:20, 10:20] = 255
|
||||
|
||||
inverted, = node.process(mask)
|
||||
assert inverted.dtype == np.uint8
|
||||
assert np.all(inverted[10:20, 10:20] == 0)
|
||||
assert np.all(inverted[0:10, 0:10] == 255)
|
||||
# Double-invert should return to original
|
||||
double, = node.process(inverted)
|
||||
assert np.array_equal(double, mask)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_mask_combine():
|
||||
print("=== Test: MaskCombine ===")
|
||||
from backend.nodes.mask import MaskCombine
|
||||
node = MaskCombine()
|
||||
|
||||
# Two overlapping squares
|
||||
a = np.zeros((64, 64), dtype=np.uint8)
|
||||
a[10:30, 10:30] = 255 # 20x20
|
||||
b = np.zeros((64, 64), dtype=np.uint8)
|
||||
b[20:40, 20:40] = 255 # 20x20, overlaps 10x10
|
||||
|
||||
# AND — only the overlap
|
||||
result_and, = node.process(a, b, operation="and")
|
||||
assert np.all(result_and[20:30, 20:30] == 255)
|
||||
assert result_and[15, 15] == 0 # a-only region
|
||||
assert result_and[35, 35] == 0 # b-only region
|
||||
|
||||
# OR — union
|
||||
result_or, = node.process(a, b, operation="or")
|
||||
assert result_or[15, 15] == 255
|
||||
assert result_or[35, 35] == 255
|
||||
assert result_or[25, 25] == 255
|
||||
assert result_or[5, 5] == 0
|
||||
|
||||
# XOR — symmetric difference
|
||||
result_xor, = node.process(a, b, operation="xor")
|
||||
assert result_xor[15, 15] == 255 # a-only
|
||||
assert result_xor[35, 35] == 255 # b-only
|
||||
assert result_xor[25, 25] == 0 # overlap excluded
|
||||
|
||||
# Subtract — a minus b
|
||||
result_sub, = node.process(a, b, operation="subtract")
|
||||
assert result_sub[15, 15] == 255 # a-only kept
|
||||
assert result_sub[25, 25] == 0 # overlap removed
|
||||
assert result_sub[35, 35] == 0 # b-only not included
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_particle_analysis():
|
||||
print("=== Test: ParticleAnalysis ===")
|
||||
from backend.nodes.grains import ParticleAnalysis
|
||||
node = ParticleAnalysis()
|
||||
|
||||
# Create a field with two distinct particles
|
||||
N = 64
|
||||
data = np.zeros((N, N))
|
||||
# Grain 1: 10x10 block at top-left with height 5
|
||||
# Particle 1: 10x10 block at top-left with height 5
|
||||
data[5:15, 5:15] = 5.0
|
||||
# Grain 2: 8x8 block at bottom-right with height 3
|
||||
# Particle 2: 8x8 block at bottom-right with height 3
|
||||
data[45:53, 45:53] = 3.0
|
||||
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
@@ -315,7 +503,7 @@ def test_grain_analysis():
|
||||
mask[45:53, 45:53] = 255
|
||||
|
||||
table, = node.process(field, mask=mask, min_size=10)
|
||||
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
|
||||
assert len(table) == 2, f"Expected 2 particles, got {len(table)}"
|
||||
|
||||
# Sort by area descending
|
||||
table.sort(key=lambda r: r["area_px"], reverse=True)
|
||||
@@ -324,7 +512,7 @@ def test_grain_analysis():
|
||||
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
|
||||
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
|
||||
|
||||
# min_size filtering: only keep grains >= 80 px
|
||||
# min_size filtering: only keep particles >= 80 px
|
||||
table_filtered, = node.process(field, mask=mask, min_size=80)
|
||||
assert len(table_filtered) == 1
|
||||
assert table_filtered[0]["area_px"] == 100
|
||||
@@ -462,6 +650,8 @@ if __name__ == "__main__":
|
||||
test_gaussian_filter()
|
||||
test_median_filter()
|
||||
test_edge_detect()
|
||||
test_fft_filter_1d()
|
||||
test_fft_filter_2d()
|
||||
|
||||
# Level
|
||||
test_plane_level()
|
||||
@@ -473,9 +663,14 @@ if __name__ == "__main__":
|
||||
test_height_histogram()
|
||||
test_cross_section()
|
||||
|
||||
# Grains
|
||||
# Mask
|
||||
test_threshold_mask()
|
||||
test_grain_analysis()
|
||||
test_mask_morphology()
|
||||
test_mask_invert()
|
||||
test_mask_combine()
|
||||
|
||||
# Grains
|
||||
test_particle_analysis()
|
||||
|
||||
# I/O
|
||||
test_load_image()
|
||||
|
||||
Reference in New Issue
Block a user