From 61d7b0fdccae8408267cf4dc93bda9a228c14784 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Fri, 27 Mar 2026 23:11:04 -0700 Subject: [PATCH] add acf and psdf nodes --- GWYDDION_FEATURE_GAP.md | 4 +- backend/node_menu.py | 2 + backend/nodes/__init__.py | 2 + backend/nodes/acf.py | 31 ++++++ backend/nodes/fft_2d.py | 88 ++--------------- backend/nodes/psdf.py | 32 ++++++ backend/nodes/spectral_common.py | 165 +++++++++++++++++++++++++++++++ tests/test_nodes.py | 82 ++++++++++++++- 8 files changed, 325 insertions(+), 81 deletions(-) create mode 100644 backend/nodes/acf.py create mode 100644 backend/nodes/psdf.py create mode 100644 backend/nodes/spectral_common.py diff --git a/GWYDDION_FEATURE_GAP.md b/GWYDDION_FEATURE_GAP.md index 4ce3cb1..e56e29b 100644 --- a/GWYDDION_FEATURE_GAP.md +++ b/GWYDDION_FEATURE_GAP.md @@ -14,8 +14,8 @@ Reference for future implementation. Grouped by value to typical SPM workflows. | ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** | | ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** | | ~~6~~ | ~~2D FFT Filter~~ | ~~fft_filter_2d.c~~ | ~~Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.).~~ **DONE** | -| 7 | Autocorrelation (ACF) | acf2d.c | 2D autocorrelation function. Reveals periodic structures and correlation lengths. | -| 8 | PSDF | psdf2d.c | Radial/2D power spectral density function. Complementary to ACF for roughness characterization. | +| ~~7~~ | ~~Autocorrelation (ACF)~~ | ~~acf2d.c~~ | ~~2D autocorrelation function. Reveals periodic structures and correlation lengths.~~ **DONE** | +| ~~8~~ | ~~PSDF~~ | ~~psdf2d.c~~ | ~~Radial/2D power spectral density function. Complementary to ACF for roughness characterization.~~ **DONE** | | 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. | | 10 | Curvature | curvature.c | Local mean/Gaussian curvature maps. Useful for feature identification. | | 11 | Grain Distance Transform | mask_edt.c | Euclidean distance from grain boundaries. Useful for spatial distribution analysis. | diff --git a/backend/node_menu.py b/backend/node_menu.py index 85410a0..631a112 100644 --- a/backend/node_menu.py +++ b/backend/node_menu.py @@ -52,6 +52,7 @@ MENU_LAYOUT: dict[str, list[str]] = { ], "Frequency": [ "FFT2D", + "PSDF", "InverseFFT2D", ], "Flatten": [ @@ -63,6 +64,7 @@ MENU_LAYOUT: dict[str, list[str]] = { "Measure": [ "CrossSection", "Histogram", + "ACF", "Cursors", "Statistics", "Stats", diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index 1e91d83..de5f40d 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -47,8 +47,10 @@ from backend.nodes import ( # Analysis statistics_node, histogram, + acf, cursors, fft_2d, + psdf, inverse_fft_2d, cross_section, stats, diff --git a/backend/nodes/acf.py b/backend/nodes/acf.py new file mode 100644 index 0000000..d447dcb --- /dev/null +++ b/backend/nodes/acf.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from backend.node_registry import register_node +from backend.data_types import DataField +from backend.nodes.spectral_common import acf_field_from_data, preprocess_spectral_data + + +@register_node(display_name="ACF") +class ACF: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "level": (["mean", "plane", "none"], {"default": "mean"}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("acf",) + FUNCTION = "process" + + DESCRIPTION = ( + "Compute the two-dimensional autocorrelation function with Gwyddion-style " + "mean or plane levelling before correlation. The output is centered on zero shift " + "and uses the default half-range extents from acf2d." + ) + + def process(self, field: DataField, level: str) -> tuple: + data = preprocess_spectral_data(field, level=level, windowing="none") + return (acf_field_from_data(field, data),) diff --git a/backend/nodes/fft_2d.py b/backend/nodes/fft_2d.py index f604b0e..e87668f 100644 --- a/backend/nodes/fft_2d.py +++ b/backend/nodes/fft_2d.py @@ -2,6 +2,11 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node from backend.data_types import DataField +from backend.nodes.spectral_common import ( + preprocess_spectral_data, + psdf_field_from_data, + spatial_frequency_field, +) @register_node(display_name="2D FFT") @@ -27,89 +32,16 @@ class FFT2D: ) def process(self, field: DataField, windowing: str, level: str) -> tuple: - data = field.data.copy() - yres, xres = data.shape - - if level == "mean": - data -= data.mean() - elif level == "plane": - yy, xx = np.mgrid[0:yres, 0:xres] - xx_f = xx.ravel().astype(np.float64) - yy_f = yy.ravel().astype(np.float64) - zz_f = data.ravel() - A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f]) - coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None) - plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy) - data -= plane - - if windowing != "none": - t_y = (np.arange(yres) + 0.5) / yres - t_x = (np.arange(xres) + 0.5) / xres - if windowing == "hann": - wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y) - wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x) - elif windowing == "hamming": - wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y) - wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x) - elif windowing == "blackman": - wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y) - wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x) - else: - wy = np.ones(yres) - wx = np.ones(xres) - data *= np.outer(wy, wx) - + data = preprocess_spectral_data(field, level=level, windowing=windowing) F = np.fft.fftshift(np.fft.fft2(data)) - n = xres * yres magnitude = np.abs(F) log_magnitude = np.log1p(magnitude) phase = np.angle(F) - dx = field.xreal / xres - dy = field.yreal / yres - psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2) - - spatial_freq_xreal = xres / field.xreal - spatial_freq_yreal = yres / field.yreal - angular_freq_xreal = 2.0 * np.pi * xres / field.xreal - angular_freq_yreal = 2.0 * np.pi * yres / field.yreal - return ( - DataField( - data=log_magnitude, - xreal=spatial_freq_xreal, - yreal=spatial_freq_yreal, - si_unit_xy="1/m", - si_unit_z=field.si_unit_z, - domain="frequency", - colormap=field.colormap, - ), - DataField( - data=magnitude, - xreal=spatial_freq_xreal, - yreal=spatial_freq_yreal, - si_unit_xy="1/m", - si_unit_z=field.si_unit_z, - domain="frequency", - colormap=field.colormap, - ), - DataField( - data=phase, - xreal=spatial_freq_xreal, - yreal=spatial_freq_yreal, - si_unit_xy="1/m", - si_unit_z=field.si_unit_z, - domain="frequency", - colormap=field.colormap, - ), - DataField( - data=psdf, - xreal=angular_freq_xreal, - yreal=angular_freq_yreal, - si_unit_xy="1/m", - si_unit_z=f"({field.si_unit_z})^2 m^2", - domain="frequency", - colormap=field.colormap, - ), + spatial_frequency_field(field, log_magnitude), + spatial_frequency_field(field, magnitude), + spatial_frequency_field(field, phase), + psdf_field_from_data(field, data), ) diff --git a/backend/nodes/psdf.py b/backend/nodes/psdf.py new file mode 100644 index 0000000..0fb7f4c --- /dev/null +++ b/backend/nodes/psdf.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from backend.node_registry import register_node +from backend.data_types import DataField +from backend.nodes.spectral_common import preprocess_spectral_data, psdf_field_from_data + + +@register_node(display_name="PSDF") +class PSDF: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "windowing": (["hann", "hamming", "blackman", "none"], {"default": "hann"}), + "level": (["mean", "plane", "none"], {"default": "mean"}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("psdf",) + FUNCTION = "process" + + DESCRIPTION = ( + "Compute the two-dimensional power spectral density function with Gwyddion-style " + "window RMS compensation and centered zero frequency. Equivalent to psdf2d / " + "gwy_data_field_2dpsdf." + ) + + def process(self, field: DataField, windowing: str, level: str) -> tuple: + data = preprocess_spectral_data(field, level=level, windowing=windowing) + return (psdf_field_from_data(field, data),) diff --git a/backend/nodes/spectral_common.py b/backend/nodes/spectral_common.py new file mode 100644 index 0000000..cb84e19 --- /dev/null +++ b/backend/nodes/spectral_common.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import numpy as np + +from backend.data_types import DataField + + +def _level_data(data: np.ndarray, level: str) -> np.ndarray: + leveled = np.asarray(data, dtype=np.float64).copy() + yres, xres = leveled.shape + + if level == "none": + return leveled + + if level == "mean": + leveled -= float(np.mean(leveled)) + return leveled + + if level == "plane": + yy, xx = np.mgrid[0:yres, 0:xres] + design = np.column_stack([ + np.ones(xres * yres, dtype=np.float64), + xx.ravel().astype(np.float64), + yy.ravel().astype(np.float64), + ]) + coeffs, _, _, _ = np.linalg.lstsq(design, leveled.ravel(), rcond=None) + plane = coeffs[0] + coeffs[1] * xx + coeffs[2] * yy + leveled -= plane + return leveled + + raise ValueError(f"Unsupported levelling mode: {level}") + + +def _window_vector(size: int, windowing: str) -> np.ndarray: + if size <= 0: + return np.ones(0, dtype=np.float64) + + t = (np.arange(size, dtype=np.float64) + 0.5) / float(size) + + if windowing == "none": + return np.ones(size, dtype=np.float64) + if windowing == "hann": + return 0.5 - 0.5 * np.cos(2.0 * np.pi * t) + if windowing == "hamming": + return 0.54 - 0.46 * np.cos(2.0 * np.pi * t) + if windowing == "blackman": + return 0.42 - 0.5 * np.cos(2.0 * np.pi * t) + 0.08 * np.cos(4.0 * np.pi * t) + + raise ValueError(f"Unsupported windowing mode: {windowing}") + + +def _apply_window_with_rms_compensation(data: np.ndarray, windowing: str) -> np.ndarray: + windowed = np.asarray(data, dtype=np.float64).copy() + if windowing == "none": + return windowed + + rms = float(np.sqrt(np.mean(windowed**2))) + wy = _window_vector(windowed.shape[0], windowing) + wx = _window_vector(windowed.shape[1], windowing) + windowed *= np.outer(wy, wx) + + new_rms = float(np.sqrt(np.mean(windowed**2))) + if rms > 0.0 and new_rms > 0.0: + windowed *= rms / new_rms + + return windowed + + +def preprocess_spectral_data(field: DataField, *, level: str, windowing: str = "none") -> np.ndarray: + leveled = _level_data(field.data, level) + return _apply_window_with_rms_compensation(leveled, windowing) + + +def _inverse_unit(unit: str) -> str: + text = str(unit or "").strip() + if not text: + return "" + return f"1/{text}" + + +def _square_unit(unit: str) -> str: + text = str(unit or "").strip() + if not text: + return "" + if text.isalnum() or text in {"m", "nm", "um", "pm", "V", "A", "Hz", "px"}: + return f"{text}^2" + return f"({text})^2" + + +def _product_unit(*units: str) -> str: + parts = [str(unit).strip() for unit in units if str(unit or "").strip()] + return " ".join(parts) + + +def spatial_frequency_field(field: DataField, data: np.ndarray) -> DataField: + return DataField( + data=np.asarray(data, dtype=np.float64), + xreal=float(field.xres / field.xreal), + yreal=float(field.yres / field.yreal), + xoff=float(-0.5 * field.xres / field.xreal), + yoff=float(-0.5 * field.yres / field.yreal), + si_unit_xy=_inverse_unit(field.si_unit_xy), + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ) + + +def psdf_field_from_data(field: DataField, data: np.ndarray) -> DataField: + transformed = np.fft.fftshift(np.fft.fft2(np.asarray(data, dtype=np.float64))) + magnitude = np.abs(transformed) + n = field.xres * field.yres + psdf = (magnitude**2) * field.dx * field.dy / (float(n) * 4.0 * np.pi**2) + xreal = float(2.0 * np.pi / field.dx) + yreal = float(2.0 * np.pi / field.dy) + + return DataField( + data=psdf, + xreal=xreal, + yreal=yreal, + xoff=float(-0.5 * xreal), + yoff=float(-0.5 * yreal), + si_unit_xy=_inverse_unit(field.si_unit_xy), + si_unit_z=_product_unit(_square_unit(field.si_unit_z), _square_unit(field.si_unit_xy)), + domain="frequency", + colormap=field.colormap, + ) + + +def acf_field_from_data(field: DataField, data: np.ndarray, *, xrange: int = 0, yrange: int = 0) -> DataField: + from scipy.signal import fftconvolve + + source = np.asarray(data, dtype=np.float64) + yres, xres = source.shape + xrange = int(xrange) if xrange else max(1, xres // 2) + yrange = int(yrange) if yrange else max(1, yres // 2) + xrange = max(1, min(xrange, xres)) + yrange = max(1, min(yrange, yres)) + + corr_full = fftconvolve(source, source[::-1, ::-1], mode="full") + cy = yres - 1 + cx = xres - 1 + corr = corr_full[cy - (yrange - 1):cy + yrange, cx - (xrange - 1):cx + xrange] + + count_x = np.array([xres - abs(dx) for dx in range(-(xrange - 1), xrange)], dtype=np.float64) + count_y = np.array([yres - abs(dy) for dy in range(-(yrange - 1), yrange)], dtype=np.float64) + counts = np.outer(count_y, count_x) + acf = corr / counts + + txres = 2 * xrange - 1 + tyres = 2 * yrange - 1 + xreal = float(field.xreal * txres / field.xres) + yreal = float(field.yreal * tyres / field.yres) + + return DataField( + data=np.asarray(acf, dtype=np.float64), + xreal=xreal, + yreal=yreal, + xoff=float(-0.5 * xreal), + yoff=float(-0.5 * yreal), + si_unit_xy=field.si_unit_xy, + si_unit_z=_square_unit(field.si_unit_z), + domain="spatial", + colormap=field.colormap, + ) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index a3c17c5..e86e527 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -2316,7 +2316,7 @@ def test_line_cursors(): # ========================================================================= -# Analysis — FFT2D +# Analysis — FFT2D / ACF / PSDF # ========================================================================= def test_fft2d(): @@ -2374,6 +2374,86 @@ def test_fft2d(): print(" PASS\n") +def test_acf(): + print("=== Test: ACF ===") + from backend.nodes.acf import ACF + + node = ACF() + data = np.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [2.0, 1.0, 0.0, -1.0], + [0.0, 1.0, 2.0, 3.0], + ], dtype=np.float64) + field = DataField(data=data, xreal=8.0, yreal=4.0, si_unit_xy="nm", si_unit_z="V") + + acf, = node.process(field, level="none") + assert acf.data.shape == (3, 3) + assert acf.domain == "spatial" + assert acf.si_unit_xy == "nm" + assert acf.si_unit_z == "V^2" + assert np.isclose(acf.xreal, 6.0) + assert np.isclose(acf.yreal, 3.0) + assert np.isclose(acf.xoff, -3.0) + assert np.isclose(acf.yoff, -1.5) + + expected = np.zeros((3, 3), dtype=np.float64) + for iy, dy in enumerate(range(-1, 2)): + for ix, dx in enumerate(range(-1, 2)): + y0a = max(0, dy) + y1a = min(data.shape[0], data.shape[0] + dy) + x0a = max(0, dx) + x1a = min(data.shape[1], data.shape[1] + dx) + lhs = data[y0a:y1a, x0a:x1a] + rhs = data[y0a - dy:y1a - dy, x0a - dx:x1a - dx] + expected[iy, ix] = float(np.mean(lhs * rhs)) + + assert np.allclose(acf.data, expected) + assert np.allclose(acf.data, acf.data[::-1, ::-1]) + print(" PASS\n") + + +def test_psdf_node(): + print("=== Test: PSDF ===") + from backend.nodes.fft_2d import FFT2D + from backend.nodes.psdf import PSDF + + field = DataField( + data=np.random.default_rng(17).standard_normal((64, 64)), + xreal=2.0e-6, + yreal=1.0e-6, + si_unit_xy="m", + si_unit_z="nm", + ) + + fft_node = FFT2D() + psdf_node = PSDF() + + fft_psdf = fft_node.process(field, windowing="hann", level="plane")[3] + psdf, = psdf_node.process(field, windowing="hann", level="plane") + assert np.allclose(psdf.data, fft_psdf.data) + assert psdf.data.shape == field.data.shape + assert psdf.domain == "frequency" + assert psdf.si_unit_xy == "1/m" + assert psdf.si_unit_z == "nm^2 m^2" + assert np.all(psdf.data >= 0.0) + + white = DataField( + data=np.random.default_rng(123).standard_normal((128, 128)), + xreal=1.0e-6, + yreal=1.0e-6, + si_unit_xy="m", + si_unit_z="m", + ) + psdf_white, = psdf_node.process(white, windowing="none", level="none") + variance = float(np.var(white.data)) + dk_x = psdf_white.xreal / psdf_white.xres + dk_y = psdf_white.yreal / psdf_white.yres + integral = float(np.sum(psdf_white.data) * dk_x * dk_y) + assert 0.8 < integral / variance < 1.2 + print(" PASS\n") + + # ========================================================================= # Analysis — Stats # =========================================================================