diff --git a/GWYDDION_FEATURE_GAP.md b/GWYDDION_FEATURE_GAP.md index b393f82..762e41a 100644 --- a/GWYDDION_FEATURE_GAP.md +++ b/GWYDDION_FEATURE_GAP.md @@ -1,4 +1,4 @@ -# Gwyddion Features Not Yet in Argonode +# Gwyddion Features Not Yet in argonode Reference for future implementation. Grouped by value to typical SPM workflows. @@ -11,9 +11,9 @@ Reference for future implementation. Grouped by value to typical SPM workflows. | 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. | | 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). | | 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. | -| 4 | Morphological Mask Ops | mask_morph.c | Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks. | -| 5 | 1D FFT Filter | fft_filter_1d.c | Bandpass/lowpass/highpass filtering of LINE profiles. | -| 6 | 2D FFT Filter | fft_filter_2d.c | Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.). | +| ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** | +| ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** | +| ~~6~~ | ~~2D FFT Filter~~ | ~~fft_filter_2d.c~~ | ~~Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.).~~ **DONE** | | 7 | Autocorrelation (ACF) | acf2d.c | 2D autocorrelation function. Reveals periodic structures and correlation lengths. | | 8 | PSDF | psdf2d.c | Radial/2D power spectral density function. Complementary to ACF for roughness characterization. | | 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. | @@ -61,11 +61,11 @@ Reference for future implementation. Grouped by value to typical SPM workflows. --- -## Already Implemented in Argonode +## Already Implemented in argonode For reference, these Gwyddion equivalents are already covered: -| Argonode Node | Category | Gwyddion Equivalent | +| argonode Node | Category | Gwyddion Equivalent | |--------------|----------|-------------------| | Load Image / Load SPM File | io | File import (gwy, sxm, ibw) | | Save Image | io | File export | @@ -76,12 +76,17 @@ For reference, these Gwyddion equivalents are already covered: | Gaussian Filter | filters | filters.c (gaussian) | | Median Filter | filters | filters.c (median) | | Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) | +| 1D FFT Filter | filters | fft_filter_1d.c (lowpass, highpass, bandpass, notch) | +| 2D FFT Filter | filters | fft_filter_2d.c (lowpass, highpass, bandpass, notch) | | Statistics | analysis | stats.c | | Height Histogram | analysis | linestats.c (dh) | | 2D FFT | analysis | fft.c | | Cross Section | analysis | profile tool | | Profile Roughness | analysis | roughness.c (Ra, Rq, Rsk, Rku, Rp, Rv, Rt) | | Line Math | analysis | linestats.c | -| Threshold Mask | grains | threshold.c, otsu_threshold.c | -| Grain Analysis | grains | grain_stat.c | +| Threshold Mask | mask | threshold.c, otsu_threshold.c | +| Mask Morphology | mask | mask_morph.c (erode, dilate, open, close) | +| Mask Invert | mask | — | +| Mask Combine | mask | — (boolean AND, OR, XOR, subtract) | +| Particle Analysis | grains | grain_stat.c | | Preview / 3D View / Print Table | display | Presentation, 3D view | diff --git a/README.md b/README.md index f3902a5..999d091 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# Argonode +# argonode -Argonode is a node-based image analysis application with: +argonode is a node-based image analysis application with: - a Python backend built on `aiohttp` - a React + Vite frontend @@ -135,13 +135,13 @@ powershell -ExecutionPolicy Bypass -File scripts\build-desktop.ps1 The packaged app is written to: ```text -desktop-dist/Argonode/ +desktop-dist/argonode/ ``` Main executable: ```text -desktop-dist/Argonode/Argonode.exe +desktop-dist/argonode/argonode.exe ``` ### One-File Build @@ -161,14 +161,14 @@ During normal source-based development, input/output folders live under the repo In the packaged desktop app, writable data is stored under: ```text -%LOCALAPPDATA%\Argonode\ +%LOCALAPPDATA%\argonode\ ``` Specifically: ```text -%LOCALAPPDATA%\Argonode\input -%LOCALAPPDATA%\Argonode\output +%LOCALAPPDATA%\argonode\input +%LOCALAPPDATA%\argonode\output ``` You can override the packaged app data directory with: diff --git a/backend/execution.py b/backend/execution.py index cb11b62..b9a52af 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -177,13 +177,19 @@ class ExecutionEngine: ) -> None: """Wire up broadcast callbacks on display node classes.""" from backend.nodes.display import PreviewImage, PrintTable, View3D - from backend.nodes.analysis import CrossSection + from backend.nodes.analysis import CrossSection, LineCursors + from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine from backend.nodes.io import SaveImage PreviewImage._broadcast_fn = on_preview + ThresholdMask._broadcast_fn = on_preview + MaskMorphology._broadcast_fn = on_preview + MaskInvert._broadcast_fn = on_preview + MaskCombine._broadcast_fn = on_preview View3D._broadcast_mesh_fn = on_mesh PrintTable._broadcast_table_fn = on_table CrossSection._broadcast_overlay_fn = on_overlay + LineCursors._broadcast_overlay_fn = on_overlay SaveImage._broadcast_preview = ( (lambda data_uri: on_preview("save", data_uri)) if on_preview else None ) @@ -191,8 +197,10 @@ class ExecutionEngine: def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" from backend.nodes.display import PreviewImage, PrintTable, View3D - from backend.nodes.analysis import CrossSection - if cls in (PreviewImage, PrintTable, View3D, CrossSection): + from backend.nodes.analysis import CrossSection, LineCursors + from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine + if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, + ThresholdMask, MaskMorphology, MaskInvert, MaskCombine): cls._current_node_id = node_id def _auto_preview( @@ -206,12 +214,16 @@ class ExecutionEngine: """ After every node executes, inspect its outputs and broadcast a preview for the first DATA_FIELD, IMAGE, or TABLE found. + Skip nodes that broadcast their own custom preview. """ import numpy as np from backend.data_types import ( DataField, datafield_to_uint8, image_to_uint8, encode_preview, ) + if getattr(cls, "_CUSTOM_PREVIEW", False): + return + return_types = getattr(cls, "RETURN_TYPES", ()) for slot, type_name in enumerate(return_types): diff --git a/backend/main.py b/backend/main.py index 657bd14..ad8bc28 100644 --- a/backend/main.py +++ b/backend/main.py @@ -36,7 +36,7 @@ def main() -> None: app = create_app(loop) log.info("=" * 60) - log.info(" Argonode — Node-based image analysis") + log.info(" argonode — Node-based image analysis") log.info(" Open your browser at http://%s:%d", HOST, PORT) log.info("=" * 60) diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index 588daa7..cae9bb6 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -1,2 +1,2 @@ # Import all node modules to trigger @register_node decorators. -from . import io, filters, level, analysis, grains, display +from . import io, filters, level, analysis, grains, mask, display diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index d112a95..60c6b30 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -69,6 +69,7 @@ class HeightHistogram: "required": { "field": ("DATA_FIELD",), "n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}), + "y_scale": (["linear", "log"],), } } @@ -78,13 +79,150 @@ class HeightHistogram: CATEGORY = "analysis" DESCRIPTION = ( "Compute the height distribution histogram (DH). " + "Use log scale to reveal small peaks next to a dominant background. " "Equivalent to gwy_data_field_dh." ) - def process(self, field: DataField, n_bins: int) -> tuple: + def process(self, field: DataField, n_bins: int, y_scale: str = "linear") -> tuple: counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins)) bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) - return (counts.astype(np.float64), bin_centers) + counts = counts.astype(np.float64) + if y_scale == "log": + counts = np.log10(1.0 + counts) + return (counts, bin_centers) + + +# --------------------------------------------------------------------------- +# LineCursors — interactive measurement cursors on any LINE plot +# --------------------------------------------------------------------------- + +@register_node(display_name="Line Cursors") +class LineCursors: + """Place two draggable cursors on any LINE plot to measure values and deltas.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "line": ("LINE",), + "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + }, + "optional": { + "x_axis": ("LINE",), + }, + } + + RETURN_TYPES = ("TABLE",) + RETURN_NAMES = ("measurement",) + FUNCTION = "process" + CATEGORY = "analysis" + DESCRIPTION = ( + "Place two cursors on any line plot (histogram, cross section, profile) " + "to measure positions, values, and deltas. Drag the markers to reposition." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, line, x1: float, y1: float, x2: float, y2: float, + x_axis=None, + ) -> tuple: + import io as _io + import base64 + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + y = np.asarray(line, dtype=np.float64).ravel() + n = len(y) + if x_axis is not None: + x = np.asarray(x_axis, dtype=np.float64).ravel()[:n] + else: + x = np.arange(n, dtype=np.float64) + + # --- Render the base plot first to determine axes bounds --- + fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100) + fig.patch.set_facecolor("#1e293b") + ax.set_facecolor("#0f172a") + ax.plot(x, y, color="#ff9800", linewidth=1.2) + ax.tick_params(colors="#94a3b8", labelsize=7) + for spine in ax.spines.values(): + spine.set_color("#334155") + ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5) + fig.tight_layout(pad=0.4) + + # Force a draw so transforms are valid + fig.canvas.draw() + + # Get axes position in figure-fraction coordinates + ax_pos = ax.get_position() + ax_l, ax_b = ax_pos.x0, ax_pos.y0 + ax_w, ax_h = ax_pos.width, ax_pos.height + + # x1/y1 arrive as image-fraction from the frontend drag. + # Convert image-fraction x → axes-fraction → nearest data index. + def img_x_to_idx(ix): + axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1) + return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1)) + + idx_a = img_x_to_idx(x1) + idx_b = img_x_to_idx(x2) + + xa, ya = float(x[idx_a]), float(y[idx_a]) + xb, yb = float(x[idx_b]), float(y[idx_b]) + + # --- Draw cursor lines and markers on the plot --- + ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9) + ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9) + ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5) + ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5) + ax.annotate( + "", xy=(xb, yb), xytext=(xa, ya), + arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5), + ) + + # --- Broadcast overlay --- + if LineCursors._broadcast_overlay_fn is not None: + # Convert data-space positions back to image-fraction for markers + fig.canvas.draw() + inv = fig.transFigure.inverted() + fig_a = inv.transform(ax.transData.transform([xa, ya])) + fig_b = inv.transform(ax.transData.transform([xb, yb])) + + buf = _io.BytesIO() + fig.savefig(buf, format="png", facecolor=fig.get_facecolor()) + buf.seek(0) + image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode() + + LineCursors._broadcast_overlay_fn( + LineCursors._current_node_id, + { + "image": image_uri, + "x1": float(fig_a[0]), + "y1": float(1.0 - fig_a[1]), # flip: image y=0 is top + "x2": float(fig_b[0]), + "y2": float(1.0 - fig_b[1]), + "a_locked": False, + "b_locked": False, + }, + ) + + plt.close(fig) + + # --- Output table --- + table = [ + {"quantity": "A position", "value": xa, "unit": ""}, + {"quantity": "A value", "value": ya, "unit": ""}, + {"quantity": "B position", "value": xb, "unit": ""}, + {"quantity": "B value", "value": yb, "unit": ""}, + {"quantity": "delta X", "value": xb - xa, "unit": ""}, + {"quantity": "delta Y", "value": yb - ya, "unit": ""}, + ] + return (table,) # --------------------------------------------------------------------------- @@ -242,9 +380,9 @@ class CrossSection: return { "required": { "field": ("DATA_FIELD",), - "x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x1": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), - "x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "extend": (["none", "to_edges"],), "n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), diff --git a/backend/nodes/filters.py b/backend/nodes/filters.py index 783a04b..8a5f872 100644 --- a/backend/nodes/filters.py +++ b/backend/nodes/filters.py @@ -5,6 +5,8 @@ Gwyddion equivalents: GaussianFilter → gwy_data_field_filter_gaussian MedianFilter → gwy_data_field_filter_median EdgeDetect → gwy_data_field_filter_sobel / laplacian / log + FFTFilter1D → fft_filter_1d.c (bandpass/lowpass/highpass on LINE profiles) + FFTFilter2D → fft_filter_2d.c (frequency-domain filtering of DATA_FIELDs) """ from __future__ import annotations @@ -113,3 +115,190 @@ class EdgeDetect: raise ValueError(f"Unknown edge detection method: {method}") return (field.replace(data=result),) + + +# --------------------------------------------------------------------------- +# Butterworth transfer function helpers +# --------------------------------------------------------------------------- + +def _butterworth_lp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray: + """Butterworth lowpass: H = 1 / (1 + (f/fc)^(2n)).""" + with np.errstate(divide="ignore", over="ignore"): + return 1.0 / (1.0 + (freq / cutoff) ** (2 * order)) + + +def _butterworth_hp(freq: np.ndarray, cutoff: float, order: int) -> np.ndarray: + """Butterworth highpass: H = 1 / (1 + (fc/f)^(2n)).""" + with np.errstate(divide="ignore", invalid="ignore"): + h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order)) + h = np.where(np.isfinite(h), h, 0.0) + return h + + +def _build_1d_transfer(n: int, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> np.ndarray: + """Build a 1-D transfer function for an FFT of length *n*. + + Frequencies are normalised so that 1.0 = Nyquist (fs/2). + The returned array has the same layout as np.fft.rfft output (length n//2+1). + """ + freq = np.linspace(0, 1, n // 2 + 1) + + if filter_type == "lowpass": + H = _butterworth_lp(freq, cutoff, order) + elif filter_type == "highpass": + H = _butterworth_hp(freq, cutoff, order) + elif filter_type == "bandpass": + H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order) + elif filter_type == "notch": + bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order) + H = 1.0 - bp + else: + H = np.ones_like(freq) + return H + + +# --------------------------------------------------------------------------- +# FFTFilter1D — frequency-domain filtering of LINE profiles +# --------------------------------------------------------------------------- + +@register_node(display_name="1D FFT Filter") +class FFTFilter1D: + """Bandpass / lowpass / highpass / notch filtering of 1-D line profiles. + + Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth + transfer function with configurable order for a smooth roll-off. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "line": ("LINE",), + "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), + "cutoff": ("FLOAT", { + "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "cutoff_high": ("FLOAT", { + "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + } + } + + RETURN_TYPES = ("LINE",) + RETURN_NAMES = ("filtered",) + FUNCTION = "process" + CATEGORY = "filters" + DESCRIPTION = ( + "Frequency-domain filtering of a 1-D line profile. " + "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " + "with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. " + "Equivalent to Gwyddion fft_filter_1d." + ) + + def process(self, line, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + z = np.asarray(line, dtype=np.float64).ravel() + n = len(z) + + # Forward FFT (real-valued) + Z = np.fft.rfft(z) + + # Build and apply transfer function + H = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order) + Z *= H + + # Inverse FFT + filtered = np.fft.irfft(Z, n=n) + return (filtered,) + + +# --------------------------------------------------------------------------- +# FFTFilter2D — frequency-domain filtering of DATA_FIELDs +# --------------------------------------------------------------------------- + +@register_node(display_name="2D FFT Filter") +class FFTFilter2D: + """Frequency-domain filtering of 2-D data fields (images). + + Equivalent to Gwyddion's fft_filter_2d module. Applies a radial + Butterworth transfer function in the frequency domain to remove or + isolate periodic features. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), + "cutoff": ("FLOAT", { + "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "cutoff_high": ("FLOAT", { + "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("filtered",) + FUNCTION = "process" + CATEGORY = "filters" + DESCRIPTION = ( + "Frequency-domain filtering of a 2-D data field. " + "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " + "with a radial Butterworth roll-off. Cutoffs are fractions of the " + "Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or " + "bandpass/notch to isolate or remove periodic noise. " + "Equivalent to Gwyddion fft_filter_2d." + ) + + def process(self, field: DataField, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + data = field.data.copy() + yres, xres = data.shape + + # Subtract mean to avoid DC leakage artefacts + mean_val = data.mean() + data -= mean_val + + # Forward 2D FFT + F = np.fft.fft2(data) + F = np.fft.fftshift(F) + + # Build radial frequency grid normalised to [0, 1] (1 = Nyquist) + fy = np.fft.fftshift(np.fft.fftfreq(yres)) # range [-0.5, 0.5) + fx = np.fft.fftshift(np.fft.fftfreq(xres)) + FX, FY = np.meshgrid(fx, fy) + # Normalise so that corner = 1 in each axis independently, + # then take Euclidean norm; max radial value = 1.0 at Nyquist. + R = np.sqrt((FX / 0.5) ** 2 + (FY / 0.5) ** 2) + R = np.clip(R / R.max(), 0, 1) if R.max() > 0 else R + + # Build transfer function + if filter_type == "lowpass": + H = _butterworth_lp(R, cutoff, order) + elif filter_type == "highpass": + H = _butterworth_hp(R, cutoff, order) + elif filter_type == "bandpass": + H = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order) + elif filter_type == "notch": + bp = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order) + H = 1.0 - bp + else: + H = np.ones_like(R) + + # Apply filter + F *= H + + # Inverse FFT + F = np.fft.ifftshift(F) + result = np.fft.ifft2(F).real + + # Restore DC + result += mean_val + + return (field.replace(data=result),) diff --git a/backend/nodes/grains.py b/backend/nodes/grains.py index 6228955..7ef5219 100644 --- a/backend/nodes/grains.py +++ b/backend/nodes/grains.py @@ -1,9 +1,8 @@ """ -Grain/feature detection nodes. +Particle detection nodes. Gwyddion equivalents: - ThresholdMask → threshold.c / otsu_threshold.c - GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c) + ParticleAnalysis → gwy_data_field_grains_get_values (grains-values.c) """ from __future__ import annotations @@ -13,61 +12,11 @@ from backend.data_types import DataField # --------------------------------------------------------------------------- -# ThresholdMask +# ParticleAnalysis # --------------------------------------------------------------------------- -@register_node(display_name="Threshold Mask") -class ThresholdMask: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "method": (["otsu", "absolute", "relative"],), - "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}), - "direction": (["above", "below"],), - } - } - - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("mask",) - FUNCTION = "process" - CATEGORY = "grains" - DESCRIPTION = ( - "Create a binary mask by thresholding data. " - "Otsu automatically finds the optimal threshold. " - "Equivalent to Gwyddion's threshold and otsu_threshold modules." - ) - - def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple: - data = field.data - - if method == "otsu": - from skimage.filters import threshold_otsu - t = threshold_otsu(data) - elif method == "absolute": - t = float(threshold) - elif method == "relative": - # threshold is a fraction [0, 1] of the data range - dmin, dmax = data.min(), data.max() - t = dmin + float(threshold) * (dmax - dmin) - else: - raise ValueError(f"Unknown threshold method: {method}") - - if direction == "above": - mask = (data >= t).astype(np.uint8) * 255 - else: - mask = (data < t).astype(np.uint8) * 255 - - return (mask,) - - -# --------------------------------------------------------------------------- -# GrainAnalysis -# --------------------------------------------------------------------------- - -@register_node(display_name="Grain Analysis") -class GrainAnalysis: +@register_node(display_name="Particle Analysis") +class ParticleAnalysis: @classmethod def INPUT_TYPES(cls): return { @@ -79,43 +28,43 @@ class GrainAnalysis: } RETURN_TYPES = ("TABLE",) - RETURN_NAMES = ("grain_stats",) + RETURN_NAMES = ("particle_stats",) FUNCTION = "process" CATEGORY = "grains" DESCRIPTION = ( - "Label connected grain regions in a binary mask and compute per-grain statistics: " - "area, equivalent diameter, mean/max height, bounding box. " + "Label connected particle regions in a binary mask and compute per-particle " + "statistics: area, equivalent diameter, mean/max height, bounding box. " "Equivalent to gwy_data_field_grains_get_values." ) def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple: - from scipy.ndimage import label, find_objects + from scipy.ndimage import label binary = (mask > 127).astype(np.int32) - labeled, n_grains = label(binary) + labeled, n_particles = label(binary) pixel_area = field.dx * field.dy # m^2 per pixel rows = [] - for grain_id in range(1, n_grains + 1): - grain_pixels = labeled == grain_id - area_px = int(grain_pixels.sum()) + for pid in range(1, n_particles + 1): + particle_pixels = labeled == pid + area_px = int(particle_pixels.sum()) if area_px < min_size: continue area_m2 = area_px * pixel_area equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi)) - heights = field.data[grain_pixels] + heights = field.data[particle_pixels] mean_h = float(heights.mean()) max_h = float(heights.max()) # Bounding box - ys, xs = np.where(grain_pixels) + ys, xs = np.where(particle_pixels) bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})" rows.append({ - "grain_id": grain_id, + "particle_id": pid, "area_px": area_px, "area_m2": area_m2, "equiv_diam_m": equiv_diam, diff --git a/backend/nodes/io.py b/backend/nodes/io.py index 40cfae9..7a44f46 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -9,12 +9,16 @@ from pathlib import Path from backend.node_registry import register_node from backend.data_types import DataField, encode_preview, image_to_uint8 -from backend.runtime_paths import input_dir, output_dir +from backend.runtime_paths import demo_dir, input_dir, output_dir # Resolved at server startup so nodes know where to look +DEMO_DIR = demo_dir() INPUT_DIR = input_dir() OUTPUT_DIR = output_dir() +_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", + ".gwy", ".sxm", ".ibw"} + # --------------------------------------------------------------------------- # LoadImage @@ -68,6 +72,81 @@ class LoadImage: return (arr, field) +# --------------------------------------------------------------------------- +# LoadDemo +# --------------------------------------------------------------------------- + +def _list_demo_files() -> list[str]: + """Return sorted list of demo filenames available in the demo/ directory.""" + if not DEMO_DIR.exists(): + return [] + return sorted( + f.name for f in DEMO_DIR.iterdir() + if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + ) + + +@register_node(display_name="Load Demo Image") +class LoadDemo: + @classmethod + def INPUT_TYPES(cls): + choices = _list_demo_files() or ["(no demo images found)"] + return { + "required": { + "name": (choices,), + } + } + + RETURN_TYPES = ("IMAGE", "DATA_FIELD") + RETURN_NAMES = ("image", "field") + FUNCTION = "load" + CATEGORY = "io" + DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data." + + def load(self, name: str): + path = DEMO_DIR / name + if not path.exists(): + raise FileNotFoundError(f"Demo image not found: {name}") + + ext = path.suffix.lower() + + # SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD + if ext == ".gwy": + field = LoadSPM()._load_gwy(path, "Z") + arr = field.data + return (arr, field) + elif ext == ".sxm": + field = LoadSPM()._load_sxm(path, "Z") + arr = field.data + return (arr, field) + elif ext == ".ibw": + field = LoadSPM()._load_ibw(path) + arr = field.data + return (arr, field) + + # npy / npz + if ext == ".npy": + arr = np.load(str(path)).astype(np.float64) + elif ext == ".npz": + npz = np.load(str(path)) + key = list(npz.files)[0] + arr = npz[key].astype(np.float64) + else: + from PIL import Image + img = Image.open(str(path)) + arr = np.array(img) + if arr.dtype != np.uint8: + arr = arr.astype(np.float64) + + if arr.ndim == 3: + gray = np.mean(arr.astype(np.float64), axis=2) + else: + gray = arr.astype(np.float64) + + field = DataField(data=gray) + return (arr, field) + + # --------------------------------------------------------------------------- # LoadSPM # --------------------------------------------------------------------------- diff --git a/backend/nodes/mask.py b/backend/nodes/mask.py new file mode 100644 index 0000000..9203488 --- /dev/null +++ b/backend/nodes/mask.py @@ -0,0 +1,273 @@ +""" +Mask operation nodes — creation, morphology, and boolean combination. + +Gwyddion equivalents: + ThresholdMask → threshold.c / otsu_threshold.c + MaskMorphology → mask_morph.c (erode, dilate, open, close) + MaskInvert → (bitwise NOT on mask) + MaskCombine → (boolean ops between two masks) +""" + +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, datafield_to_uint8, encode_preview + + +def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray: + """Render greyscale base image with red shadow on masked (255) pixels. + + Returns (H, W, 3) uint8 array. + """ + grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8 + overlay = grey.astype(np.float64) + mask_bool = mask == 255 + alpha = 0.45 + overlay[mask_bool, 0] = overlay[mask_bool, 0] * (1 - alpha) + 255 * alpha + overlay[mask_bool, 1] = overlay[mask_bool, 1] * (1 - alpha) + overlay[mask_bool, 2] = overlay[mask_bool, 2] * (1 - alpha) + return np.clip(overlay, 0, 255).astype(np.uint8) + + +# --------------------------------------------------------------------------- +# ThresholdMask +# --------------------------------------------------------------------------- + +@register_node(display_name="Threshold Mask") +class ThresholdMask: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "method": (["otsu", "absolute", "relative"],), + "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}), + "direction": (["above", "below"],), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + CATEGORY = "mask" + DESCRIPTION = ( + "Create a binary mask by thresholding data. " + "Otsu automatically finds the optimal threshold. " + "Equivalent to Gwyddion's threshold and otsu_threshold modules." + ) + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple: + data = field.data + + if method == "otsu": + from skimage.filters import threshold_otsu + t = threshold_otsu(data) + elif method == "absolute": + t = float(threshold) + elif method == "relative": + # threshold is a fraction [0, 1] of the data range + dmin, dmax = data.min(), data.max() + t = dmin + float(threshold) * (dmax - dmin) + else: + raise ValueError(f"Unknown threshold method: {method}") + + if direction == "above": + mask = (data >= t).astype(np.uint8) * 255 + else: + mask = (data < t).astype(np.uint8) * 255 + + if ThresholdMask._broadcast_fn is not None: + overlay = _mask_overlay(field, mask) + ThresholdMask._broadcast_fn( + ThresholdMask._current_node_id, encode_preview(overlay), + ) + + return (mask,) + + +# --------------------------------------------------------------------------- +# MaskMorphology +# --------------------------------------------------------------------------- + +@register_node(display_name="Mask Morphology") +class MaskMorphology: + """Morphological operations on binary masks. + + Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close). + """ + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("IMAGE",), + "operation": (["dilate", "erode", "open", "close"],), + "radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}), + "shape": (["disk", "square"],), + }, + "optional": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + CATEGORY = "mask" + DESCRIPTION = ( + "Apply morphological operations to a binary mask. " + "Dilate expands regions, erode shrinks them, " + "open (erode then dilate) removes small spots, " + "close (dilate then erode) fills small holes. " + "Equivalent to Gwyddion mask_morph." + ) + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, mask: np.ndarray, operation: str, radius: int, shape: str, + field: DataField | None = None) -> tuple: + from scipy.ndimage import binary_dilation, binary_erosion + + binary = mask > 127 + + if shape == "disk": + y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1] + struct = (x * x + y * y) <= radius * radius + else: + size = 2 * radius + 1 + struct = np.ones((size, size), dtype=bool) + + if operation == "dilate": + result = binary_dilation(binary, structure=struct) + elif operation == "erode": + result = binary_erosion(binary, structure=struct) + elif operation == "open": + result = binary_dilation( + binary_erosion(binary, structure=struct), + structure=struct, + ) + elif operation == "close": + result = binary_erosion( + binary_dilation(binary, structure=struct), + structure=struct, + ) + else: + raise ValueError(f"Unknown morphological operation: {operation}") + + out = result.astype(np.uint8) * 255 + + if field is not None and MaskMorphology._broadcast_fn is not None: + overlay = _mask_overlay(field, out) + MaskMorphology._broadcast_fn( + MaskMorphology._current_node_id, encode_preview(overlay), + ) + + return (out,) + + +# --------------------------------------------------------------------------- +# MaskInvert +# --------------------------------------------------------------------------- + +@register_node(display_name="Mask Invert") +class MaskInvert: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("IMAGE",), + }, + "optional": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + CATEGORY = "mask" + DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions." + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: + out = np.where(mask > 127, np.uint8(0), np.uint8(255)) + + if field is not None and MaskInvert._broadcast_fn is not None: + overlay = _mask_overlay(field, out) + MaskInvert._broadcast_fn( + MaskInvert._current_node_id, encode_preview(overlay), + ) + + return (out,) + + +# --------------------------------------------------------------------------- +# MaskCombine +# --------------------------------------------------------------------------- + +@register_node(display_name="Mask Combine") +class MaskCombine: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask_a": ("IMAGE",), + "mask_b": ("IMAGE",), + "operation": (["and", "or", "xor", "subtract"],), + }, + "optional": { + "field": ("DATA_FIELD",), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + CATEGORY = "mask" + DESCRIPTION = ( + "Combine two binary masks with a boolean operation. " + "AND keeps overlap, OR merges, XOR keeps non-overlapping regions, " + "subtract removes mask_b from mask_a." + ) + + _broadcast_fn = None + _current_node_id: str = "" + + def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str, + field: DataField | None = None) -> tuple: + a = mask_a > 127 + b = mask_b > 127 + + if operation == "and": + result = a & b + elif operation == "or": + result = a | b + elif operation == "xor": + result = a ^ b + elif operation == "subtract": + result = a & ~b + else: + raise ValueError(f"Unknown mask operation: {operation}") + + out = result.astype(np.uint8) * 255 + + if field is not None and MaskCombine._broadcast_fn is not None: + overlay = _mask_overlay(field, out) + MaskCombine._broadcast_fn( + MaskCombine._current_node_id, encode_preview(overlay), + ) + + return (out,) diff --git a/backend/runtime_paths.py b/backend/runtime_paths.py index bab42db..7ddbe72 100644 --- a/backend/runtime_paths.py +++ b/backend/runtime_paths.py @@ -4,7 +4,7 @@ import os import sys from pathlib import Path -APP_NAME = "Argonode" +APP_NAME = "argonode" def project_root() -> Path: @@ -34,13 +34,26 @@ def app_data_dir() -> Path: return Path(override).expanduser().resolve() if getattr(sys, "frozen", False): - local_appdata = os.getenv("LOCALAPPDATA") - base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local" + if sys.platform == "darwin": + base_dir = Path.home() / "Library" / "Application Support" + elif sys.platform == "linux": + xdg = os.getenv("XDG_DATA_HOME") + base_dir = Path(xdg) if xdg else Path.home() / ".local" / "share" + else: + local_appdata = os.getenv("LOCALAPPDATA") + base_dir = Path(local_appdata) if local_appdata else Path.home() / "AppData" / "Local" return (base_dir / APP_NAME).resolve() return project_root() +def demo_dir() -> Path: + bundled = resource_root() / "demo" + if bundled.exists(): + return bundled + return project_root() / "demo" + + def input_dir() -> Path: return app_data_dir() / "input" diff --git a/demo/.gitkeep b/demo/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/demo/nanoparticles.npy b/demo/nanoparticles.npy new file mode 100644 index 0000000..d679672 Binary files /dev/null and b/demo/nanoparticles.npy differ diff --git a/demo/nanoparticles.png b/demo/nanoparticles.png new file mode 100644 index 0000000..dda6204 Binary files /dev/null and b/demo/nanoparticles.png differ diff --git a/desktop.py b/desktop.py index 4c70508..dd9145b 100644 --- a/desktop.py +++ b/desktop.py @@ -14,7 +14,34 @@ from backend.runtime_paths import app_data_dir, ensure_runtime_dirs from backend.server import create_app HOST = "127.0.0.1" -WINDOW_TITLE = "Argonode" +WINDOW_TITLE = "argonode" + + +class _Api: + """Exposed to JavaScript as window.pywebview.api.""" + + def __init__(self, window_ref: list): + self._window_ref = window_ref + + def open_file_dialog(self) -> str | None: + """Open a native file picker and return the selected path (or None).""" + win = self._window_ref[0] + if win is None: + return None + result = win.create_file_dialog( + webview.OPEN_DIALOG, + allow_multiple=False, + file_types=( + "All supported (*.png;*.jpg;*.jpeg;*.tiff;*.tif;*.npy;*.npz;*.gwy;*.sxm;*.ibw)", + "Images (*.png;*.jpg;*.jpeg;*.tiff;*.tif)", + "NumPy (*.npy;*.npz)", + "SPM (*.gwy;*.sxm;*.ibw)", + "All files (*.*)", + ), + ) + if result and len(result) > 0: + return result[0] + return None def _pick_free_port() -> int: @@ -85,17 +112,22 @@ def main() -> None: ready.wait(timeout=15.0) if "error" in state: - raise RuntimeError("Argonode server failed to start") from state["error"] + raise RuntimeError("argonode server failed to start") from state["error"] _wait_for_server(f"{base_url}/nodes") + window_ref: list[webview.Window | None] = [None] + js_api = _Api(window_ref) + window = webview.create_window( WINDOW_TITLE, base_url, width=1600, height=1000, min_size=(1100, 720), + js_api=js_api, ) + window_ref[0] = window def _shutdown() -> None: loop = state.get("loop") diff --git a/frontend/index.html b/frontend/index.html index 7418eec..146d241 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -3,7 +3,7 @@ - Argonode — Image Analysis + argonode — Image Analysis
diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 471f968..98af0fe 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -7,6 +7,7 @@ "name": "argonode-frontend", "dependencies": { "@xyflow/react": "^12.0.0", + "html-to-image": "^1.11.13", "react": "^18.3.0", "react-dom": "^18.3.0", "three": "^0.183.2" @@ -1539,6 +1540,12 @@ "node": ">=6.9.0" } }, + "node_modules/html-to-image": { + "version": "1.11.13", + "resolved": "https://registry.npmjs.org/html-to-image/-/html-to-image-1.11.13.tgz", + "integrity": "sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==", + "license": "MIT" + }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index b489587..2082038 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -13,6 +13,7 @@ }, "dependencies": { "@xyflow/react": "^12.0.0", + "html-to-image": "^1.11.13", "react": "^18.3.0", "react-dom": "^18.3.0", "three": "^0.183.2" diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index a75acc7..5a818a4 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -4,24 +4,26 @@ import React, { import { ReactFlow, Background, Controls, MiniMap, useNodesState, useEdgesState, addEdge, useReactFlow, - ReactFlowProvider, + ReactFlowProvider, getNodesBounds, getViewportForBounds, } from '@xyflow/react'; import '@xyflow/react/dist/style.css'; import CustomNode, { NodeContext } from './CustomNode'; import FileBrowser from './FileBrowser'; import * as api from './api'; +import { toBlob } from 'html-to-image'; +import { embedWorkflow, extractWorkflow } from './pngMetadata'; // ── Constants ───────────────────────────────────────────────────────── const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']); const TYPE_COLORS = { - DATA_FIELD: '#3a7abf', - IMAGE: '#4caf50', - LINE: '#ff9800', - TABLE: '#fdd835', - COORD: '#e91e63', + DATA_FIELD: '#ff002f', + IMAGE: '#00ff08a0', + LINE: '#ffbe5c', + TABLE: '#35e2fd', + COORD: '#e91ed1', }; const NODE_TYPES = { custom: CustomNode }; @@ -272,6 +274,13 @@ function Flow() { // ── File browser ──────────────────────────────────────────────────── const openFileBrowser = useCallback((callback) => { + // Use native file picker when running inside pywebview (desktop app) + if (window.pywebview?.api?.open_file_dialog) { + window.pywebview.api.open_file_dialog().then((path) => { + if (path) callback(path); + }); + return; + } setFileBrowserCb(() => callback); }, []); @@ -427,60 +436,162 @@ function Flow() { setStatus({ text: 'Graph cleared.', level: 'info' }); }, [setNodes, setEdges]); - const saveWorkflow = useCallback(() => { - const currentNodes = reactFlow.getNodes().map((n) => ({ + const applyWorkflowData = useCallback((data) => { + const loadedNodes = data.nodes || []; + const loadedEdges = data.edges || []; + const defs = nodeDefsRef.current; + const hydrated = loadedNodes.map((n) => ({ + ...n, + data: { + ...n.data, + definition: defs[n.data.className] || n.data.definition, + previewImage: null, tableRows: null, meshData: null, overlay: null, + }, + })); + setNodes(hydrated); + setEdges(loadedEdges); + const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0)); + nextIdRef.current = maxId + 1; + }, [setNodes, setEdges]); + + const getWorkflowBlob = useCallback(async () => { + const viewportEl = document.querySelector('.react-flow__viewport'); + if (!viewportEl) throw new Error('Flow element not found'); + + const allNodes = reactFlow.getNodes(); + if (allNodes.length === 0) throw new Error('No nodes to capture'); + + const bounds = getNodesBounds(allNodes); + const pad = 0.1; // 10% margin on each side + const imageWidth = Math.ceil(bounds.width * (1 + pad * 2)); + const imageHeight = Math.ceil(bounds.height * (1 + pad * 2)); + const vp = getViewportForBounds(bounds, imageWidth, imageHeight, 0.5, 1, pad); + + const blob = await toBlob(viewportEl, { + backgroundColor: '#1a1a1a', + width: imageWidth, + height: imageHeight, + style: { + width: `${imageWidth}px`, + height: `${imageHeight}px`, + transform: `translate(${vp.x}px, ${vp.y}px) scale(${vp.zoom})`, + }, + }); + if (!blob) throw new Error('Capture returned empty'); + + const currentNodes = allNodes.map((n) => ({ ...n, data: { ...n.data, previewImage: null, tableRows: null, meshData: null, overlay: null }, })); - const data = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() }; - const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' }); - const a = document.createElement('a'); - a.href = URL.createObjectURL(blob); - a.download = 'workflow.json'; - a.click(); + const workflow = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() }; + return embedWorkflow(blob, workflow); }, [reactFlow]); + const saveWorkflow = useCallback(async () => { + setStatus({ text: 'Saving…', level: 'info' }); + try { + const finalBlob = await getWorkflowBlob(); + + if (window.showSaveFilePicker) { + const handle = await window.showSaveFilePicker({ + suggestedName: 'workflow.png', + types: [{ description: 'PNG Image', accept: { 'image/png': ['.png'] } }], + }); + const writable = await handle.createWritable(); + await writable.write(finalBlob); + await writable.close(); + } else { + // Fallback: programmatic download + const a = document.createElement('a'); + a.href = URL.createObjectURL(finalBlob); + a.download = 'workflow.png'; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(a.href); + } + + setStatus({ text: 'Workflow saved.', level: 'info' }); + } catch (err) { + if (err.name === 'AbortError') { + setStatus({ text: 'Save cancelled.', level: 'info' }); + } else { + setStatus({ text: 'Save failed: ' + err.message, level: 'error' }); + } + } + }, [getWorkflowBlob]); + + const copySnapshot = useCallback(() => { + setStatus({ text: 'Copying snapshot…', level: 'info' }); + // Pass a Promise to ClipboardItem so the clipboard.write() call + // happens synchronously within the user gesture, avoiding permission errors. + const blobPromise = getWorkflowBlob().catch((err) => { + setStatus({ text: 'Snapshot failed: ' + err.message, level: 'error' }); + throw err; + }); + navigator.clipboard.write([new ClipboardItem({ 'image/png': blobPromise })]).then(() => { + setStatus({ text: 'Snapshot copied to clipboard.', level: 'info' }); + }).catch((err) => { + setStatus({ text: 'Copy failed: ' + err.message, level: 'error' }); + }); + }, [getWorkflowBlob]); + const loadWorkflow = useCallback(() => { const input = document.createElement('input'); input.type = 'file'; - input.accept = '.json'; + input.accept = '.json,.png'; input.onchange = async (e) => { const file = e.target.files[0]; if (!file) return; - const text = await file.text(); try { - const data = JSON.parse(text); - const loadedNodes = data.nodes || []; - const loadedEdges = data.edges || []; - - // Re-populate definitions from current nodeDefs - const defs = nodeDefsRef.current; - const hydrated = loadedNodes.map((n) => ({ - ...n, - data: { - ...n.data, - definition: defs[n.data.className] || n.data.definition, - previewImage: null, - tableRows: null, - meshData: null, - overlay: null, - }, - })); - - setNodes(hydrated); - setEdges(loadedEdges); - - // Update ID counter to avoid collisions - const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0)); - nextIdRef.current = maxId + 1; - + let data; + if (file.name.endsWith('.png') || file.type === 'image/png') { + data = await extractWorkflow(file); + if (!data) { + setStatus({ text: 'No workflow data found in image.', level: 'error' }); + return; + } + } else { + data = JSON.parse(await file.text()); + } + applyWorkflowData(data); setStatus({ text: 'Workflow loaded.', level: 'info' }); } catch { - setStatus({ text: 'Invalid workflow JSON.', level: 'error' }); + setStatus({ text: 'Invalid workflow file.', level: 'error' }); } }; input.click(); - }, [setNodes, setEdges]); + }, [applyWorkflowData]); + + // ── Drag-and-drop workflow image loading ─────────────────────────── + + const onDropFile = useCallback(async (event) => { + const files = event.dataTransfer?.files; + if (!files || files.length === 0) return; + event.preventDefault(); + + const file = files[0]; + if (file.type !== 'image/png') return; + + try { + const data = await extractWorkflow(file); + if (!data) { + setStatus({ text: 'No workflow data in this image.', level: 'error' }); + return; + } + applyWorkflowData(data); + setStatus({ text: 'Workflow loaded from image.', level: 'info' }); + } catch (err) { + setStatus({ text: 'Failed to load: ' + err.message, level: 'error' }); + } + }, [applyWorkflowData]); + + const onDragOver = useCallback((event) => { + if (event.dataTransfer?.types?.includes('Files')) { + event.preventDefault(); + event.dataTransfer.dropEffect = 'copy'; + } + }, []); // ── Keyboard shortcut ─────────────────────────────────────────────── @@ -509,7 +620,7 @@ function Flow() {
{/* Toolbar */}
- Argonode + argonode
- - +
{status.text}
@@ -535,7 +649,7 @@ function Flow() { {/* React Flow canvas */}
{ if (!e.target.closest('.context-menu')) setContextMenu(null); - }}> + }} onDrop={onDropFile} onDragOver={onDragOver}> >> 1)) : (c >>> 1); + } + crcTable[i] = c; +} + +function crc32(bytes) { + let crc = 0xFFFFFFFF; + for (let i = 0; i < bytes.length; i++) { + crc = crcTable[(crc ^ bytes[i]) & 0xFF] ^ (crc >>> 8); + } + return (crc ^ 0xFFFFFFFF) >>> 0; +} + +// ── Helpers ────────────────────────────────────────────────────────── + +const PNG_SIG = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]; + +function isPng(data) { + if (data.length < 8) return false; + for (let i = 0; i < 8; i++) { + if (data[i] !== PNG_SIG[i]) return false; + } + return true; +} + +function chunkType(data, offset) { + return String.fromCharCode( + data[offset + 4], data[offset + 5], data[offset + 6], data[offset + 7], + ); +} + +// ── Public API ─────────────────────────────────────────────────────── + +/** + * Embed a workflow object into a PNG blob as a tEXt chunk. + * Returns a new Blob with the metadata inserted before IEND. + */ +export async function embedWorkflow(pngBlob, workflow) { + const data = new Uint8Array(await pngBlob.arrayBuffer()); + if (!isPng(data)) throw new Error('Not a valid PNG file'); + + const encoder = new TextEncoder(); + + // Build tEXt payload: keyword \0 text + const key = encoder.encode('workflow'); + const val = encoder.encode(JSON.stringify(workflow)); + const payload = new Uint8Array(key.length + 1 + val.length); + payload.set(key, 0); + // payload[key.length] is already 0 (null separator) + payload.set(val, key.length + 1); + + // CRC covers type + payload + const typeBytes = encoder.encode('tEXt'); + const forCrc = new Uint8Array(4 + payload.length); + forCrc.set(typeBytes, 0); + forCrc.set(payload, 4); + + // Assemble chunk: length(4) + type(4) + payload + crc(4) + const chunk = new Uint8Array(12 + payload.length); + const view = new DataView(chunk.buffer); + view.setUint32(0, payload.length); + chunk.set(typeBytes, 4); + chunk.set(payload, 8); + view.setUint32(8 + payload.length, crc32(forCrc)); + + // Locate IEND + let pos = 8; + let iendPos = data.length; + while (pos < data.length) { + const len = new DataView(data.buffer, pos, 4).getUint32(0); + if (chunkType(data, pos) === 'IEND') { iendPos = pos; break; } + pos += 12 + len; + } + + // Splice: [before IEND] + [tEXt chunk] + [IEND] + const result = new Uint8Array(data.length + chunk.length); + result.set(data.subarray(0, iendPos), 0); + result.set(chunk, iendPos); + result.set(data.subarray(iendPos), iendPos + chunk.length); + + return new Blob([result], { type: 'image/png' }); +} + +/** + * Extract the workflow object from a PNG blob's tEXt chunks. + * Returns the parsed object, or null if no "workflow" key is found. + */ +export async function extractWorkflow(pngBlob) { + const data = new Uint8Array(await pngBlob.arrayBuffer()); + if (!isPng(data)) return null; + + const decoder = new TextDecoder(); + let pos = 8; + + while (pos + 8 <= data.length) { + const len = new DataView(data.buffer, pos, 4).getUint32(0); + const type = chunkType(data, pos); + + if (type === 'tEXt' && pos + 8 + len <= data.length) { + const chunkData = data.subarray(pos + 8, pos + 8 + len); + const nullIdx = chunkData.indexOf(0); + if (nullIdx !== -1) { + const k = decoder.decode(chunkData.subarray(0, nullIdx)); + if (k === 'workflow') { + return JSON.parse(decoder.decode(chunkData.subarray(nullIdx + 1))); + } + } + } + + if (type === 'IEND') break; + pos += 12 + len; + } + + return null; +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index 27a0fab..5b84a67 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -21,8 +21,8 @@ html, body, #root { /* ── Toolbar ───────────────────────────────────────────────────────── */ #toolbar { height: 44px; - background: #16213e; - border-bottom: 1px solid #0f3460; + background: #242424; + border-bottom: 1px solid #000000; display: flex; align-items: center; padding: 0 12px; @@ -36,7 +36,7 @@ html, body, #root { font-size: 15px; font-weight: 700; letter-spacing: 0.5px; - color: #e94560; + color: #ffffff; margin-right: 8px; flex-shrink: 0; } @@ -129,8 +129,17 @@ html, body, #root { cursor: grabbing; } -.custom-node.selected { +/* Selected node — target via React Flow's wrapper class */ +.react-flow__node.selected .custom-node { border-color: #90caf9; + box-shadow: 0 0 0 1px #90caf9, 0 0 12px rgba(144, 202, 249, 0.4); +} + +/* Selected edge */ +.react-flow__edge.selected .react-flow__edge-path { + stroke: #90caf9 !important; + stroke-width: 3px !important; + filter: drop-shadow(0 0 4px rgba(144, 202, 249, 0.6)); } .node-title { diff --git a/package.json b/package.json index b27a7b7..5fe014e 100644 --- a/package.json +++ b/package.json @@ -12,6 +12,8 @@ "preview": "npm --prefix frontend run preview", "backend": "python -m backend.main", "desktop": "python desktop.py", - "build:desktop": "powershell -ExecutionPolicy Bypass -File scripts\\build-desktop.ps1" + "build:windows": "powershell -ExecutionPolicy Bypass -File scripts\\build-windows.ps1", + "build:mac": "bash scripts/build-mac.sh", + "build:linux": "bash scripts/build-linux.sh" } } diff --git a/pyproject.toml b/pyproject.toml index 538b571..e83b80c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,10 @@ readme = "GWYDDION_FEATURE_GAP.md" requires-python = ">=3.10" dependencies = [ "aiohttp>=3.9,<4", + "gwyfile>=0.2", + "igor>=0.3", "matplotlib>=3.8,<4", + "nanonispy>=1.1", "numpy>=1.26,<3", "pillow>=10,<12", "scikit-image>=0.22,<1", @@ -18,11 +21,6 @@ dependencies = [ ] [project.optional-dependencies] -spm = [ - "gwyfile>=0.2", - "igor>=0.3", - "nanonispy>=1.1", -] dev = [ "pytest>=8,<9", ] diff --git a/scripts/build-linux.sh b/scripts/build-linux.sh new file mode 100755 index 0000000..df926c5 --- /dev/null +++ b/scripts/build-linux.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +set -euo pipefail + +ONE_FILE=false +CREATE_TAR=true + +while [[ $# -gt 0 ]]; do + case "$1" in + --onefile) ONE_FILE=true; shift ;; + --no-tar) CREATE_TAR=false; shift ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$REPO_ROOT" + +if [ -d ".venv/bin" ]; then + PYTHON=".venv/bin/python" +else + PYTHON="python3" +fi + +FRONTEND_DIST="$REPO_ROOT/frontend/dist" +DEMO_DIR="$REPO_ROOT/demo" + +echo "Building frontend bundle..." +npm run build + +echo "Installing desktop build dependencies..." +uv pip install -e ".[desktop]" + +if $ONE_FILE; then + MODE="--onefile" +else + MODE="--onedir" +fi + +echo "Packaging desktop app with PyInstaller..." +$PYTHON -m PyInstaller \ + desktop.py \ + --noconfirm \ + --clean \ + --name argonode \ + --windowed \ + $MODE \ + --distpath desktop-dist \ + --workpath desktop-build \ + --specpath desktop-build \ + --add-data "${FRONTEND_DIST}:frontend/dist" \ + --add-data "${DEMO_DIR}:demo" \ + --collect-all matplotlib \ + --collect-all scipy \ + --collect-all skimage \ + --collect-all webview + +if $CREATE_TAR; then + TAR_PATH="desktop-dist/argonode-linux.tar.gz" + echo "Creating tarball..." + tar -czf "$TAR_PATH" -C desktop-dist argonode + echo "Tarball created: $TAR_PATH" +fi + +echo "Desktop build complete." +echo "Output: $REPO_ROOT/desktop-dist/" diff --git a/scripts/build-mac.sh b/scripts/build-mac.sh new file mode 100755 index 0000000..c6f92ff --- /dev/null +++ b/scripts/build-mac.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash +set -euo pipefail + +ONE_FILE=false +CREATE_DMG=true + +while [[ $# -gt 0 ]]; do + case "$1" in + --onefile) ONE_FILE=true; shift ;; + --no-dmg) CREATE_DMG=false; shift ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$REPO_ROOT" + +if [ -d ".venv/bin" ]; then + PYTHON=".venv/bin/python" +else + PYTHON="python3" +fi + +FRONTEND_DIST="$REPO_ROOT/frontend/dist" +DEMO_DIR="$REPO_ROOT/demo" + +echo "Building frontend bundle..." +npm run build + +echo "Installing desktop build dependencies..." +uv pip install -e ".[desktop]" + +if $ONE_FILE; then + MODE="--onefile" +else + MODE="--onedir" +fi + +echo "Packaging desktop app with PyInstaller..." +$PYTHON -m PyInstaller \ + desktop.py \ + --noconfirm \ + --clean \ + --name argonode \ + --windowed \ + $MODE \ + --distpath desktop-dist \ + --workpath desktop-build \ + --specpath desktop-build \ + --add-data "${FRONTEND_DIST}:frontend/dist" \ + --add-data "${DEMO_DIR}:demo" \ + --collect-all matplotlib \ + --collect-all scipy \ + --collect-all skimage \ + --collect-all webview \ + --icon resources/icon.icns 2>/dev/null || \ +$PYTHON -m PyInstaller \ + desktop.py \ + --noconfirm \ + --clean \ + --name argonode \ + --windowed \ + $MODE \ + --distpath desktop-dist \ + --workpath desktop-build \ + --specpath desktop-build \ + --add-data "${FRONTEND_DIST}:frontend/dist" \ + --add-data "${DEMO_DIR}:demo" \ + --collect-all matplotlib \ + --collect-all scipy \ + --collect-all skimage \ + --collect-all webview + +APP_BUNDLE="desktop-dist/argonode.app" + +if [ ! -d "$APP_BUNDLE" ]; then + # --onedir puts it inside a folder + if [ -d "desktop-dist/argonode/argonode.app" ]; then + APP_BUNDLE="desktop-dist/argonode/argonode.app" + else + echo "Warning: .app bundle not found; skipping DMG creation." + CREATE_DMG=false + fi +fi + +if $CREATE_DMG; then + DMG_PATH="desktop-dist/argonode.dmg" + echo "Creating DMG installer..." + rm -f "$DMG_PATH" + + # Create a temporary directory for DMG contents + DMG_STAGING="desktop-build/dmg-staging" + rm -rf "$DMG_STAGING" + mkdir -p "$DMG_STAGING" + cp -R "$APP_BUNDLE" "$DMG_STAGING/" + ln -s /Applications "$DMG_STAGING/Applications" + + hdiutil create \ + -volname "argonode" \ + -srcfolder "$DMG_STAGING" \ + -ov \ + -format UDZO \ + "$DMG_PATH" + + rm -rf "$DMG_STAGING" + echo "DMG created: $DMG_PATH" +fi + +echo "Desktop build complete." +echo "Output: $REPO_ROOT/desktop-dist/" diff --git a/scripts/build-desktop.ps1 b/scripts/build-windows.ps1 similarity index 86% rename from scripts/build-desktop.ps1 rename to scripts/build-windows.ps1 index ac4c302..b129820 100644 --- a/scripts/build-desktop.ps1 +++ b/scripts/build-windows.ps1 @@ -14,6 +14,7 @@ $pythonExe = if (Test-Path ".\.venv\Scripts\python.exe") { "python" } $frontendDist = Join-Path $repoRoot "frontend\dist" +$demoDir = Join-Path $repoRoot "demo" Write-Host "Building frontend bundle..." npm run build @@ -28,13 +29,14 @@ $pyInstallerArgs = @( "desktop.py", "--noconfirm", "--clean", - "--name", "Argonode", + "--name", "argonode", "--windowed", $mode, "--distpath", "desktop-dist", "--workpath", "desktop-build", "--specpath", "desktop-build", "--add-data", "${frontendDist};frontend/dist", + "--add-data", "${demoDir};demo", "--collect-all", "matplotlib", "--collect-all", "scipy", "--collect-all", "skimage", @@ -45,4 +47,4 @@ Write-Host "Packaging desktop app..." & $pythonExe @pyInstallerArgs Write-Host "Desktop build complete." -Write-Host "Output folder: $repoRoot\desktop-dist\Argonode" +Write-Host "Output folder: $repoRoot\desktop-dist\argonode" diff --git a/scripts/generate_demo_particles.py b/scripts/generate_demo_particles.py new file mode 100644 index 0000000..7d8899e --- /dev/null +++ b/scripts/generate_demo_particles.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +"""Generate a synthetic nanoparticle image for the demo/ folder. + +The image simulates an AFM scan of particles on a flat substrate: + - Slightly noisy background + - ~20 hemisphere-shaped particles with varying radii and heights + - Saved as both .npy (calibrated float64) and .png (visual preview) + +Run from project root: + python scripts/generate_demo_particles.py +""" + +import numpy as np +from pathlib import Path + +DEMO_DIR = Path(__file__).resolve().parent.parent / "demo" +DEMO_DIR.mkdir(exist_ok=True) + +RNG = np.random.default_rng(2024) + +# --- Image parameters --- +N = 256 # pixels +SCAN_SIZE = 5e-6 # 5 µm scan +PIXEL_SIZE = SCAN_SIZE / N # metres per pixel +BG_NOISE_RMS = 0.3e-9 # 0.3 nm background noise + +# --- Generate particles --- +particles = [] +# Hand-placed cluster + random scatter to give a realistic spread +fixed = [ + # (cx_frac, cy_frac, radius_nm, height_nm) + (0.25, 0.30, 120, 30), + (0.28, 0.34, 80, 20), + (0.70, 0.25, 150, 45), + (0.50, 0.55, 100, 25), + (0.55, 0.60, 60, 15), + (0.15, 0.75, 200, 55), + (0.80, 0.80, 90, 22), +] +for cx_f, cy_f, r_nm, h_nm in fixed: + particles.append((cx_f * N, cy_f * N, r_nm * 1e-9, h_nm * 1e-9)) + +# Random particles +for _ in range(15): + cx = RNG.uniform(20, N - 20) + cy = RNG.uniform(20, N - 20) + radius = RNG.uniform(30, 180) * 1e-9 # 30–180 nm + height = RNG.uniform(8, 60) * 1e-9 # 8–60 nm + particles.append((cx, cy, radius, height)) + +# --- Render height map --- +image = RNG.normal(0, BG_NOISE_RMS, (N, N)) + +yy, xx = np.mgrid[0:N, 0:N] + +for cx, cy, radius_m, height_m in particles: + radius_px = radius_m / PIXEL_SIZE + dist2 = (xx - cx) ** 2 + (yy - cy) ** 2 + inside = dist2 < radius_px ** 2 + # Hemisphere profile: z = h * sqrt(1 - (r/R)^2) + z = np.zeros_like(image) + z[inside] = height_m * np.sqrt(1.0 - dist2[inside] / radius_px ** 2) + image = np.maximum(image, z) # particles don't subtract from each other + +# --- Save .npy (float64 metres) --- +npy_path = DEMO_DIR / "nanoparticles.npy" +np.save(str(npy_path), image) +print(f"Saved {npy_path} shape={image.shape} range=[{image.min():.2e}, {image.max():.2e}] m") + +# --- Save .png (8-bit grayscale for quick visual) --- +from PIL import Image + +normed = (image - image.min()) / (image.max() - image.min()) +uint8 = (normed * 255).astype(np.uint8) +png_path = DEMO_DIR / "nanoparticles.png" +Image.fromarray(uint8, mode="L").save(str(png_path)) +print(f"Saved {png_path}") + +print(f"\n{len(particles)} particles generated on a {SCAN_SIZE*1e6:.0f} µm × {SCAN_SIZE*1e6:.0f} µm scan") diff --git a/tests/test_grains.py b/tests/test_grains.py new file mode 100644 index 0000000..7b94f85 --- /dev/null +++ b/tests/test_grains.py @@ -0,0 +1,445 @@ +""" +Thorough tests for the grain/particle analysis pipeline: + ThresholdMask → GrainAnalysis + +Covers synthetic geometry (known answers), the demo nanoparticles image, +edge cases, and physical-unit correctness. + +Run from project root: + .venv/bin/python -m tests.test_grains +""" + +import sys +import numpy as np + +sys.path.insert(0, ".") +from backend.data_types import DataField + + +def make_field(data, xreal=1e-6, yreal=1e-6): + return DataField(data=data.astype(np.float64), xreal=xreal, yreal=yreal, + si_unit_xy="m", si_unit_z="m") + + +# ========================================================================= +# ThresholdMask tests +# ========================================================================= + +def test_threshold_otsu_bimodal(): + """Otsu on a clean bimodal image should separate the two populations.""" + print("=== Test: Otsu on bimodal image ===") + from backend.nodes.grains import ThresholdMask + node = ThresholdMask() + + data = np.zeros((128, 128)) + data[30:50, 30:50] = 10.0 # bright square + data[70:100, 80:110] = 10.0 # another bright region + field = make_field(data) + + mask, = node.process(field, method="otsu", threshold=0.0, direction="above") + bright_pixels = (mask == 255) + # Should capture both bright regions + assert bright_pixels[40, 40], "Otsu missed bright region 1" + assert bright_pixels[85, 95], "Otsu missed bright region 2" + # Background should be dark + assert not bright_pixels[0, 0], "Otsu false positive in background" + assert not bright_pixels[60, 60], "Otsu false positive between regions" + print(" PASS\n") + + +def test_threshold_relative_range(): + """Relative threshold at 0.5 should be the midpoint of [min, max].""" + print("=== Test: Relative threshold at midpoint ===") + from backend.nodes.grains import ThresholdMask + node = ThresholdMask() + + data = np.full((64, 64), 2.0) + data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5 + field = make_field(data) + + mask, = node.process(field, method="relative", threshold=0.5, direction="above") + # Only the bright patch (value 8 >= 5) should be masked + assert np.all(mask[10:20, 10:20] == 255) + assert np.all(mask[0:10, :] == 0) + assert np.all(mask[20:, :] == 0) + print(" PASS\n") + + +def test_threshold_empty_mask(): + """Very high absolute threshold on low data should produce an empty mask.""" + print("=== Test: Empty mask from high threshold ===") + from backend.nodes.grains import ThresholdMask + node = ThresholdMask() + + data = np.ones((64, 64)) + field = make_field(data) + + mask, = node.process(field, method="absolute", threshold=999.0, direction="above") + assert mask.sum() == 0, "Mask should be completely empty" + print(" PASS\n") + + +def test_threshold_full_mask(): + """Very low absolute threshold should produce an all-white mask.""" + print("=== Test: Full mask from low threshold ===") + from backend.nodes.grains import ThresholdMask + node = ThresholdMask() + + data = np.ones((64, 64)) * 5.0 + field = make_field(data) + + mask, = node.process(field, method="absolute", threshold=-1.0, direction="above") + assert np.all(mask == 255), "Mask should be all white" + print(" PASS\n") + + +# ========================================================================= +# GrainAnalysis tests +# ========================================================================= + +def test_single_circle_area(): + """A single filled circle — verify pixel count and physical area.""" + print("=== Test: Single circle area ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + N = 200 + XREAL = 2e-6 # 2 µm + data = np.zeros((N, N)) + mask = np.zeros((N, N), dtype=np.uint8) + + # Draw a filled circle, radius 30 px, centred at (100, 100) + yy, xx = np.mgrid[0:N, 0:N] + r = 30 + circle = ((xx - 100) ** 2 + (yy - 100) ** 2) <= r ** 2 + data[circle] = 5.0 + mask[circle] = 255 + + field = make_field(data, xreal=XREAL, yreal=XREAL) + table, = node.process(field, mask=mask, min_size=1) + + assert len(table) == 1, f"Expected 1 grain, got {len(table)}" + grain = table[0] + + # Pixel area of a discrete circle: should be close to π r² + expected_px = np.pi * r ** 2 + assert abs(grain["area_px"] - expected_px) / expected_px < 0.02, \ + f"area_px={grain['area_px']}, expected≈{expected_px:.0f}" + + # Physical area + pixel_area = (XREAL / N) ** 2 + expected_m2 = grain["area_px"] * pixel_area + assert abs(grain["area_m2"] - expected_m2) < 1e-20, \ + f"area_m2 mismatch: {grain['area_m2']} vs {expected_m2}" + + # Equivalent diameter should be close to 2r in physical units + expected_diam = 2 * r * (XREAL / N) + assert abs(grain["equiv_diam_m"] - expected_diam) / expected_diam < 0.02, \ + f"equiv_diam={grain['equiv_diam_m']:.3e}, expected≈{expected_diam:.3e}" + + # Heights + assert abs(grain["mean_height"] - 5.0) < 1e-10 + assert abs(grain["max_height"] - 5.0) < 1e-10 + print(" PASS\n") + + +def test_multiple_grains_separation(): + """Three well-separated grains of different sizes — check each is reported.""" + print("=== Test: Multiple grain separation ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + N = 128 + data = np.zeros((N, N)) + mask = np.zeros((N, N), dtype=np.uint8) + + # Grain A: 20×20 block, height 10 + data[10:30, 10:30] = 10.0 + mask[10:30, 10:30] = 255 + + # Grain B: 10×10 block, height 7 + data[60:70, 60:70] = 7.0 + mask[60:70, 60:70] = 255 + + # Grain C: 5×5 block, height 3 + data[100:105, 100:105] = 3.0 + mask[100:105, 100:105] = 255 + + field = make_field(data) + table, = node.process(field, mask=mask, min_size=1) + + assert len(table) == 3, f"Expected 3 grains, got {len(table)}" + + table.sort(key=lambda r: r["area_px"], reverse=True) + assert table[0]["area_px"] == 400 # 20×20 + assert table[1]["area_px"] == 100 # 10×10 + assert table[2]["area_px"] == 25 # 5×5 + + assert abs(table[0]["mean_height"] - 10.0) < 1e-10 + assert abs(table[1]["mean_height"] - 7.0) < 1e-10 + assert abs(table[2]["mean_height"] - 3.0) < 1e-10 + print(" PASS\n") + + +def test_min_size_filtering(): + """min_size should exclude grains smaller than the threshold.""" + print("=== Test: min_size filtering ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + N = 64 + data = np.zeros((N, N)) + mask = np.zeros((N, N), dtype=np.uint8) + + # Large grain: 15×15 = 225 px + data[5:20, 5:20] = 1.0 + mask[5:20, 5:20] = 255 + + # Medium grain: 8×8 = 64 px + data[30:38, 30:38] = 1.0 + mask[30:38, 30:38] = 255 + + # Tiny grain: 3×3 = 9 px + data[50:53, 50:53] = 1.0 + mask[50:53, 50:53] = 255 + + field = make_field(data) + + # min_size=1: all three + table, = node.process(field, mask=mask, min_size=1) + assert len(table) == 3 + + # min_size=10: drops the 3×3 + table, = node.process(field, mask=mask, min_size=10) + assert len(table) == 2 + + # min_size=100: drops the 3×3 and 8×8 + table, = node.process(field, mask=mask, min_size=100) + assert len(table) == 1 + assert table[0]["area_px"] == 225 + + # min_size=300: drops everything + table, = node.process(field, mask=mask, min_size=300) + assert len(table) == 0 + print(" PASS\n") + + +def test_grain_bounding_box(): + """Bounding box should match the grain extents.""" + print("=== Test: Grain bounding box ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + N = 64 + data = np.zeros((N, N)) + mask = np.zeros((N, N), dtype=np.uint8) + # Place a grain at rows 20:35, cols 10:45 + data[20:35, 10:45] = 2.0 + mask[20:35, 10:45] = 255 + + field = make_field(data) + table, = node.process(field, mask=mask, min_size=1) + assert len(table) == 1 + + bbox = table[0]["bbox"] + # Format: "(xmin,ymin)-(xmax,ymax)" = "(10,20)-(44,34)" + assert bbox == "(10,20)-(44,34)", f"bbox={bbox}, expected (10,20)-(44,34)" + print(" PASS\n") + + +def test_empty_mask_produces_no_grains(): + """An all-zero mask should yield zero grains.""" + print("=== Test: Empty mask → no grains ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + field = make_field(np.ones((64, 64))) + mask = np.zeros((64, 64), dtype=np.uint8) + + table, = node.process(field, mask=mask, min_size=1) + assert len(table) == 0 + print(" PASS\n") + + +def test_grain_at_image_edge(): + """A grain touching the image border should still be detected.""" + print("=== Test: Grain at image edge ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + N = 64 + data = np.zeros((N, N)) + mask = np.zeros((N, N), dtype=np.uint8) + # Grain touching top-left corner + data[0:10, 0:10] = 4.0 + mask[0:10, 0:10] = 255 + + field = make_field(data) + table, = node.process(field, mask=mask, min_size=1) + assert len(table) == 1 + assert table[0]["area_px"] == 100 + assert table[0]["bbox"] == "(0,0)-(9,9)" + print(" PASS\n") + + +def test_adjacent_grains_connectivity(): + """Two diagonally-touching blocks should be separate grains + (scipy.ndimage.label uses 4-connectivity by default).""" + print("=== Test: Diagonal adjacency → separate grains ===") + from backend.nodes.grains import GrainAnalysis + node = GrainAnalysis() + + N = 32 + data = np.zeros((N, N)) + mask = np.zeros((N, N), dtype=np.uint8) + + # Block A + data[5:10, 5:10] = 1.0 + mask[5:10, 5:10] = 255 + + # Block B diagonally adjacent (touching only at corner 10,10) + data[10:15, 10:15] = 1.0 + mask[10:15, 10:15] = 255 + + field = make_field(data) + table, = node.process(field, mask=mask, min_size=1) + # Default label() uses structure that connects diagonals? Let's verify. + # scipy.ndimage.label default is cross-shaped (no diagonals) for 2D + assert len(table) == 2, f"Expected 2 separate grains, got {len(table)}" + print(" PASS\n") + + +# ========================================================================= +# End-to-end pipeline: ThresholdMask → GrainAnalysis +# ========================================================================= + +def test_pipeline_synthetic(): + """Full pipeline on a synthetic image with known geometry.""" + print("=== Test: Full pipeline on synthetic particles ===") + from backend.nodes.grains import ThresholdMask, GrainAnalysis + + N = 200 + XREAL = 10e-6 # 10 µm + rng = np.random.default_rng(99) + + # Background at 0 with small noise, particles as raised bumps + bg = rng.normal(0, 0.1, (N, N)) + particles = np.zeros((N, N)) + + yy, xx = np.mgrid[0:N, 0:N] + + specs = [ + (50, 50, 15, 5.0), # (cx, cy, radius_px, height) + (150, 50, 20, 8.0), + (100, 100, 10, 3.0), + (50, 160, 25, 6.0), + (160, 160, 12, 4.0), + ] + for cx, cy, r, h in specs: + inside = ((xx - cx) ** 2 + (yy - cy) ** 2) <= r ** 2 + particles[inside] = h + + data = bg + particles + field = make_field(data, xreal=XREAL, yreal=XREAL) + + # Step 1: threshold + thresh = ThresholdMask() + mask, = thresh.process(field, method="absolute", threshold=1.0, direction="above") + + # Particles are well above noise, so mask should capture all 5 + assert mask.max() == 255, "No particles detected" + + # Step 2: grain analysis + ga = GrainAnalysis() + table, = ga.process(field, mask=mask, min_size=5) + + assert len(table) == 5, f"Expected 5 grains, got {len(table)}" + + # Verify that detected areas are in the right ballpark + table.sort(key=lambda r: r["area_px"], reverse=True) + expected_areas = sorted([np.pi * r ** 2 for _, _, r, _ in specs], reverse=True) + + for grain, expected_px in zip(table, expected_areas): + ratio = grain["area_px"] / expected_px + assert 0.85 < ratio < 1.15, \ + f"grain area_px={grain['area_px']}, expected≈{expected_px:.0f}, ratio={ratio:.2f}" + + print(" PASS\n") + + +def test_pipeline_demo_image(): + """Run the full pipeline on the bundled demo nanoparticles image.""" + print("=== Test: Full pipeline on demo nanoparticles.npy ===") + from pathlib import Path + from backend.nodes.grains import ThresholdMask, GrainAnalysis + from backend.runtime_paths import demo_dir + + npy_path = demo_dir() / "nanoparticles.npy" + if not npy_path.exists(): + print(" SKIP (demo image not found)\n") + return + + data = np.load(str(npy_path)).astype(np.float64) + # The demo image is a 5 µm × 5 µm scan + field = make_field(data, xreal=5e-6, yreal=5e-6) + + # Threshold to find particles (they are raised above background) + thresh = ThresholdMask() + mask, = thresh.process(field, method="otsu", threshold=0.0, direction="above") + + # Should detect particles + assert mask.max() == 255, "No particles found in demo image" + particle_fraction = (mask == 255).sum() / mask.size + assert 0.01 < particle_fraction < 0.5, \ + f"Suspicious particle fraction: {particle_fraction:.3f}" + print(f" Mask: {particle_fraction*100:.1f}% of pixels are particles") + + # Grain analysis + ga = GrainAnalysis() + table, = ga.process(field, mask=mask, min_size=20) + + assert len(table) > 0, "No grains detected" + print(f" Found {len(table)} grains (min_size=20)") + + # Sanity checks on grain properties + for grain in table: + assert grain["area_px"] >= 20 + assert grain["area_m2"] > 0 + assert grain["equiv_diam_m"] > 0 + assert grain["max_height"] >= grain["mean_height"] + assert grain["mean_height"] > 0 + + # Physical size sanity: equivalent diameters should be in the nm–µm range + diams_nm = [g["equiv_diam_m"] * 1e9 for g in table] + print(f" Diameters: min={min(diams_nm):.0f} nm, max={max(diams_nm):.0f} nm") + assert all(1 < d < 2000 for d in diams_nm), \ + f"Grain diameters out of expected range: {diams_nm}" + + print(" PASS\n") + + +# ========================================================================= +# Run all tests +# ========================================================================= + +if __name__ == "__main__": + # ThresholdMask + test_threshold_otsu_bimodal() + test_threshold_relative_range() + test_threshold_empty_mask() + test_threshold_full_mask() + + # GrainAnalysis + test_single_circle_area() + test_multiple_grains_separation() + test_min_size_filtering() + test_grain_bounding_box() + test_empty_mask_produces_no_grains() + test_grain_at_image_edge() + test_adjacent_grains_connectivity() + + # End-to-end pipeline + test_pipeline_synthetic() + test_pipeline_demo_image() + + print("All grain tests passed!") diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5f90f4d..541ca60 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -85,6 +85,88 @@ def test_edge_detect(): print(" PASS\n") +def test_fft_filter_1d(): + print("=== Test: FFTFilter1D ===") + from backend.nodes.filters import FFTFilter1D + node = FFTFilter1D() + + # Signal: low-frequency sine + high-frequency sine + n = 256 + t = np.arange(n, dtype=np.float64) / n + low = np.sin(2 * np.pi * 3 * t) # 3 cycles — low freq + high = np.sin(2 * np.pi * 80 * t) # 80 cycles — high freq + line = low + high + + # Lowpass should keep low, suppress high + filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) + assert len(filtered_lp) == n + corr_low = np.corrcoef(filtered_lp, low)[0, 1] + corr_high = np.corrcoef(filtered_lp, high)[0, 1] + assert corr_low > 0.95, f"Lowpass: correlation with low={corr_low}" + assert abs(corr_high) < 0.3, f"Lowpass: correlation with high={corr_high}" + + # Highpass should keep high, suppress low + filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) + corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1] + corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1] + assert abs(corr_low_hp) < 0.3, f"Highpass: correlation with low={corr_low_hp}" + assert corr_high_hp > 0.95, f"Highpass: correlation with high={corr_high_hp}" + + # Bandpass centred on the high frequency + filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4) + corr_low_bp = np.corrcoef(filtered_bp, low)[0, 1] + corr_high_bp = np.corrcoef(filtered_bp, high)[0, 1] + assert abs(corr_low_bp) < 0.3, f"Bandpass: correlation with low={corr_low_bp}" + assert corr_high_bp > 0.9, f"Bandpass: correlation with high={corr_high_bp}" + + # Notch (band-reject) centred on the high frequency — should remove it + filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4) + corr_low_notch = np.corrcoef(filtered_notch, low)[0, 1] + corr_high_notch = np.corrcoef(filtered_notch, high)[0, 1] + assert corr_low_notch > 0.95, f"Notch: correlation with low={corr_low_notch}" + assert abs(corr_high_notch) < 0.3, f"Notch: correlation with high={corr_high_notch}" + + print(" PASS\n") + + +def test_fft_filter_2d(): + print("=== Test: FFTFilter2D ===") + from backend.nodes.filters import FFTFilter2D + node = FFTFilter2D() + + N = 128 + y, x = np.mgrid[0:N, 0:N] / N + # Low-frequency 2D pattern + high-frequency pattern + low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y) + high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y) + data = low_2d + high_2d + field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6) + + # Lowpass — should preserve low, remove high + result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) + assert result_lp.data.shape == (N, N) + assert result_lp.xreal == field.xreal + assert result_lp.si_unit_z == field.si_unit_z + corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1] + corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1] + assert corr_low > 0.9, f"2D lowpass: correlation with low={corr_low}" + assert abs(corr_high) < 0.3, f"2D lowpass: correlation with high={corr_high}" + + # Highpass — should preserve high, remove low + result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) + corr_low_hp = np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1] + corr_high_hp = np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] + assert abs(corr_low_hp) < 0.3, f"2D highpass: correlation with low={corr_low_hp}" + assert corr_high_hp > 0.9, f"2D highpass: correlation with high={corr_high_hp}" + + # Constant field should be unchanged by lowpass (DC preservation) + const = make_field(data=np.ones((32, 32)) * 7.0) + result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2) + assert np.allclose(result_const.data, 7.0, atol=1e-10), "Lowpass should preserve constant field" + + print(" PASS\n") + + # ========================================================================= # Level # ========================================================================= @@ -199,7 +281,7 @@ def test_height_histogram(): data = np.linspace(0, 1, 1000).reshape(25, 40) field = make_field(data=data) - counts, bin_centers = node.process(field, n_bins=10) + counts, bin_centers = node.process(field, n_bins=10, y_scale="linear") assert len(counts) == 10 assert len(bin_centers) == 10 assert counts.dtype == np.float64 @@ -265,7 +347,7 @@ def test_cross_section(): def test_threshold_mask(): print("=== Test: ThresholdMask ===") - from backend.nodes.grains import ThresholdMask + from backend.nodes.mask import ThresholdMask node = ThresholdMask() # Clear bimodal data: left half = 0, right half = 1 @@ -273,6 +355,11 @@ def test_threshold_mask(): data[:, 32:] = 1.0 field = make_field(data=data) + # Capture overlay preview + previews = [] + ThresholdMask._broadcast_fn = lambda nid, uri: previews.append(uri) + ThresholdMask._current_node_id = "test" + # Absolute threshold at 0.5 mask, = node.process(field, method="absolute", threshold=0.5, direction="above") assert mask.dtype == np.uint8 @@ -280,6 +367,10 @@ def test_threshold_mask(): assert np.all(mask[:, :32] == 0) assert np.all(mask[:, 32:] == 255) + # Verify overlay preview was broadcast + assert len(previews) == 1 + assert previews[0].startswith("data:image/png;base64,") + # Direction "below" mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below") assert np.all(mask_below[:, :32] == 255) @@ -292,20 +383,117 @@ def test_threshold_mask(): # Otsu should find the bimodal threshold mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above") assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum() + + ThresholdMask._broadcast_fn = None print(" PASS\n") -def test_grain_analysis(): - print("=== Test: GrainAnalysis ===") - from backend.nodes.grains import GrainAnalysis - node = GrainAnalysis() +def test_mask_morphology(): + print("=== Test: MaskMorphology ===") + from backend.nodes.mask import MaskMorphology + node = MaskMorphology() - # Create a field with two distinct "grains" + # Small square blob in the centre + mask = np.zeros((64, 64), dtype=np.uint8) + mask[28:36, 28:36] = 255 # 8x8 block + orig_count = np.count_nonzero(mask) + + # Dilate should grow the region + dilated, = node.process(mask, operation="dilate", radius=1, shape="square") + assert dilated.dtype == np.uint8 + assert np.count_nonzero(dilated) > orig_count + + # Erode should shrink it + eroded, = node.process(mask, operation="erode", radius=1, shape="square") + assert np.count_nonzero(eroded) < orig_count + + # Open on a clean block should give back roughly the same block + opened, = node.process(mask, operation="open", radius=1, shape="square") + assert np.count_nonzero(opened) <= orig_count + + # Close on a mask with a 1-pixel hole should fill the hole + mask_hole = mask.copy() + mask_hole[32, 32] = 0 # poke a hole + assert np.count_nonzero(mask_hole) == orig_count - 1 + closed, = node.process(mask_hole, operation="close", radius=1, shape="square") + assert closed[32, 32] == 255, "Close should fill the 1-pixel hole" + + # Disk structuring element should also work + dilated_disk, = node.process(mask, operation="dilate", radius=2, shape="disk") + assert np.count_nonzero(dilated_disk) > orig_count + + print(" PASS\n") + + +def test_mask_invert(): + print("=== Test: MaskInvert ===") + from backend.nodes.mask import MaskInvert + node = MaskInvert() + + mask = np.zeros((64, 64), dtype=np.uint8) + mask[10:20, 10:20] = 255 + + inverted, = node.process(mask) + assert inverted.dtype == np.uint8 + assert np.all(inverted[10:20, 10:20] == 0) + assert np.all(inverted[0:10, 0:10] == 255) + # Double-invert should return to original + double, = node.process(inverted) + assert np.array_equal(double, mask) + + print(" PASS\n") + + +def test_mask_combine(): + print("=== Test: MaskCombine ===") + from backend.nodes.mask import MaskCombine + node = MaskCombine() + + # Two overlapping squares + a = np.zeros((64, 64), dtype=np.uint8) + a[10:30, 10:30] = 255 # 20x20 + b = np.zeros((64, 64), dtype=np.uint8) + b[20:40, 20:40] = 255 # 20x20, overlaps 10x10 + + # AND — only the overlap + result_and, = node.process(a, b, operation="and") + assert np.all(result_and[20:30, 20:30] == 255) + assert result_and[15, 15] == 0 # a-only region + assert result_and[35, 35] == 0 # b-only region + + # OR — union + result_or, = node.process(a, b, operation="or") + assert result_or[15, 15] == 255 + assert result_or[35, 35] == 255 + assert result_or[25, 25] == 255 + assert result_or[5, 5] == 0 + + # XOR — symmetric difference + result_xor, = node.process(a, b, operation="xor") + assert result_xor[15, 15] == 255 # a-only + assert result_xor[35, 35] == 255 # b-only + assert result_xor[25, 25] == 0 # overlap excluded + + # Subtract — a minus b + result_sub, = node.process(a, b, operation="subtract") + assert result_sub[15, 15] == 255 # a-only kept + assert result_sub[25, 25] == 0 # overlap removed + assert result_sub[35, 35] == 0 # b-only not included + + print(" PASS\n") + + +def test_particle_analysis(): + print("=== Test: ParticleAnalysis ===") + from backend.nodes.grains import ParticleAnalysis + node = ParticleAnalysis() + + # Create a field with two distinct particles N = 64 data = np.zeros((N, N)) - # Grain 1: 10x10 block at top-left with height 5 + # Particle 1: 10x10 block at top-left with height 5 data[5:15, 5:15] = 5.0 - # Grain 2: 8x8 block at bottom-right with height 3 + # Particle 2: 8x8 block at bottom-right with height 3 data[45:53, 45:53] = 3.0 field = make_field(data=data, xreal=1e-6, yreal=1e-6) @@ -315,7 +503,7 @@ def test_grain_analysis(): mask[45:53, 45:53] = 255 table, = node.process(field, mask=mask, min_size=10) - assert len(table) == 2, f"Expected 2 grains, got {len(table)}" + assert len(table) == 2, f"Expected 2 particles, got {len(table)}" # Sort by area descending table.sort(key=lambda r: r["area_px"], reverse=True) @@ -324,7 +512,7 @@ def test_grain_analysis(): assert abs(table[0]["mean_height"] - 5.0) < 1e-10 assert abs(table[1]["mean_height"] - 3.0) < 1e-10 - # min_size filtering: only keep grains >= 80 px + # min_size filtering: only keep particles >= 80 px table_filtered, = node.process(field, mask=mask, min_size=80) assert len(table_filtered) == 1 assert table_filtered[0]["area_px"] == 100 @@ -462,6 +650,8 @@ if __name__ == "__main__": test_gaussian_filter() test_median_filter() test_edge_detect() + test_fft_filter_1d() + test_fft_filter_2d() # Level test_plane_level() @@ -473,9 +663,14 @@ if __name__ == "__main__": test_height_histogram() test_cross_section() - # Grains + # Mask test_threshold_mask() - test_grain_analysis() + test_mask_morphology() + test_mask_invert() + test_mask_combine() + + # Grains + test_particle_analysis() # I/O test_load_image()