From bce11590c74b99c95389609419eb867452c64636 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Wed, 25 Mar 2026 14:30:28 -0700 Subject: [PATCH] fft multi channel output --- backend/nodes/analysis.py | 167 ++++++++++++++++++++++++++++++++++++-- tests/test_fft.py | 113 +++++++++++++++++++++++--- tests/test_fft_visual.py | 15 ++-- tests/test_nodes.py | 12 +-- 4 files changed, 275 insertions(+), 32 deletions(-) diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index f0e9721..6229226 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -266,21 +266,20 @@ class FFT2D: "field": ("DATA_FIELD",), "windowing": (["hann", "hamming", "blackman", "none"],), "level": (["mean", "plane", "none"],), - "output": (["log_magnitude", "magnitude", "phase", "psdf"],), } } - RETURN_TYPES = ("DATA_FIELD",) - RETURN_NAMES = ("spectrum",) + RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD") + RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf") FUNCTION = "process" CATEGORY = "analysis" DESCRIPTION = ( "Compute the 2D FFT with optional windowing and mean/plane subtraction. " - "Output can be log magnitude, magnitude, phase, or PSDF. " + "Outputs log magnitude, magnitude, phase, and PSDF as separate channels. " "Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf." ) - def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple: + def process(self, field: DataField, windowing: str, level: str) -> tuple: data = field.data.copy() yres, xres = data.shape @@ -320,8 +319,59 @@ class FFT2D: F = np.fft.fftshift(np.fft.fft2(data)) n = xres * yres - if output == "log_magnitude": - mag = np.abs(F) + magnitude = np.abs(F) + log_magnitude = np.log1p(magnitude) + phase = np.angle(F) + + dx = field.xreal / xres + dy = field.yreal / yres + psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2) + + spatial_freq_xreal = xres / field.xreal + spatial_freq_yreal = yres / field.yreal + angular_freq_xreal = 2.0 * np.pi * xres / field.xreal + angular_freq_yreal = 2.0 * np.pi * yres / field.yreal + + return ( + DataField( + data=log_magnitude, + xreal=spatial_freq_xreal, + yreal=spatial_freq_yreal, + si_unit_xy="1/m", + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ), + DataField( + data=magnitude, + xreal=spatial_freq_xreal, + yreal=spatial_freq_yreal, + si_unit_xy="1/m", + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ), + DataField( + data=phase, + xreal=spatial_freq_xreal, + yreal=spatial_freq_yreal, + si_unit_xy="1/m", + si_unit_z=field.si_unit_z, + domain="frequency", + colormap=field.colormap, + ), + DataField( + data=psdf, + xreal=angular_freq_xreal, + yreal=angular_freq_yreal, + si_unit_xy="1/m", + si_unit_z=f"({field.si_unit_z})^2 m^2", + domain="frequency", + colormap=field.colormap, + ), + ) + + if False: # Unreachable legacy block retained below. # Log scale with floor to avoid log(0) result = np.log1p(mag) elif output == "magnitude": @@ -359,6 +409,109 @@ class FFT2D: return (out_field,) +# --------------------------------------------------------------------------- +# InverseFFT2D +# --------------------------------------------------------------------------- + +@register_node(display_name="Inverse 2D FFT") +class InverseFFT2D: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "spectrum": ("DATA_FIELD",), + "representation": (["magnitude", "log_magnitude", "psdf"],), + }, + "optional": { + "phase": ("DATA_FIELD",), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("image",) + FUNCTION = "process" + CATEGORY = "analysis" + DESCRIPTION = ( + "Reconstruct a spatial-domain image from a 2D frequency spectrum. " + "For exact reconstruction, connect magnitude/phase (or log magnitude/phase, " + "or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed." + ) + + def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple: + if spectrum.domain != "frequency": + raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.") + + if phase is not None: + if phase.data.shape != spectrum.data.shape: + raise ValueError("Phase input must have the same shape as the spectrum.") + if phase.domain != "frequency": + raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.") + + amplitude = self._resolve_amplitude(spectrum, representation) + phase_data = phase.data if phase is not None else np.zeros_like(amplitude) + F = amplitude * np.exp(1j * phase_data) + + spatial = np.fft.ifft2(np.fft.ifftshift(F)).real + xreal, yreal = self._recover_spatial_extent(spectrum, representation) + z_unit = self._recover_z_unit(spectrum, representation, phase) + + out_field = DataField( + data=spatial, + xreal=xreal, + yreal=yreal, + si_unit_xy="m", + si_unit_z=z_unit, + domain="spatial", + colormap=spectrum.colormap, + ) + return (out_field,) + + def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray: + data = np.asarray(spectrum.data, dtype=np.float64) + + if representation == "magnitude": + return np.clip(data, 0.0, None) + if representation == "log_magnitude": + return np.expm1(data) + if representation == "psdf": + xreal, yreal = self._recover_spatial_extent(spectrum, representation) + n = spectrum.xres * spectrum.yres + dx = xreal / spectrum.xres + dy = yreal / spectrum.yres + scale = n * 4.0 * np.pi ** 2 / (dx * dy) + return np.sqrt(np.clip(data, 0.0, None) * scale) + + raise ValueError(f"Unsupported spectrum representation: {representation}") + + def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]: + if representation == "psdf": + xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal + yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal + else: + xreal = spectrum.xres / spectrum.xreal + yreal = spectrum.yres / spectrum.yreal + return float(xreal), float(yreal) + + def _recover_z_unit( + self, + spectrum: DataField, + representation: str, + phase: DataField | None, + ) -> str: + if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip(): + return phase.si_unit_z + + if representation != "psdf": + return spectrum.si_unit_z + + unit = str(spectrum.si_unit_z or "").strip() + if unit.startswith("(") and ")^2 m^2" in unit: + return unit.split(")^2 m^2", 1)[0][1:] + if unit.endswith("^2 m^2"): + return unit[:-6].removesuffix("^2").strip() + return "" + + # --------------------------------------------------------------------------- # CrossSection # --------------------------------------------------------------------------- diff --git a/tests/test_fft.py b/tests/test_fft.py index 930daf4..cb2e794 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -9,7 +9,7 @@ import numpy as np sys.path.insert(0, ".") from backend.data_types import DataField -from backend.nodes.analysis import FFT2D +from backend.nodes.analysis import FFT2D, InverseFFT2D def make_field(data, xreal=1e-6, yreal=1e-6): @@ -24,7 +24,7 @@ def test_dc_removal(): field = make_field(data) node = FFT2D() - result, = node.process(field, windowing="none", level="mean", output="magnitude") + _, result, _, _ = node.process(field, windowing="none", level="mean") peak = result.data.max() print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}") assert peak < 1e-10, f"Expected ~0, got {peak}" @@ -43,7 +43,7 @@ def test_single_frequency(): field = make_field(data, xreal=xreal, yreal=xreal) node = FFT2D() - result, = node.process(field, windowing="none", level="mean", output="magnitude") + _, result, _, _ = node.process(field, windowing="none", level="mean") # The peak should be at column offset = freq_cycles from center mag = result.data @@ -76,7 +76,7 @@ def test_2d_frequency(): field = make_field(data) node = FFT2D() - result, = node.process(field, windowing="none", level="mean", output="magnitude") + _, result, _, _ = node.process(field, windowing="none", level="mean") mag = result.data cy, cx = N // 2, N // 2 @@ -110,7 +110,7 @@ def test_psdf_normalization(): field = make_field(data, xreal=xreal, yreal=xreal) node = FFT2D() - result, = node.process(field, windowing="none", level="none", output="psdf") + _, _, _, result = node.process(field, windowing="none", level="none") psdf = result.data # Integrate: sum of PSDF * dk_x * dk_y @@ -141,11 +141,11 @@ def test_windowing_reduces_leakage(): node = FFT2D() # Without windowing - r_none, = node.process(field, windowing="none", level="mean", output="magnitude") + _, r_none, _, _ = node.process(field, windowing="none", level="mean") mag_none = r_none.data[N // 2, :] # center row # With Hann windowing - r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude") + _, r_hann, _, _ = node.process(field, windowing="hann", level="mean") mag_hann = r_hann.data[N // 2, :] # Measure leakage: ratio of energy far from peak vs total @@ -178,15 +178,15 @@ def test_plane_subtraction(): node = FFT2D() # Without leveling — huge DC and low-freq energy - r_none, = node.process(field, windowing="none", level="none", output="magnitude") + _, r_none, _, _ = node.process(field, windowing="none", level="none") dc_none = r_none.data[N // 2, N // 2] # With mean subtraction — DC removed but gradient leaks - r_mean, = node.process(field, windowing="none", level="mean", output="magnitude") + _, r_mean, _, _ = node.process(field, windowing="none", level="mean") dc_mean = r_mean.data[N // 2, N // 2] # With plane subtraction — gradient removed - r_plane, = node.process(field, windowing="none", level="plane", output="magnitude") + _, r_plane, _, _ = node.process(field, windowing="none", level="plane") dc_plane = r_plane.data[N // 2, N // 2] # With plane subtraction, check the low-freq energy near DC is reduced @@ -213,7 +213,7 @@ def test_non_square(): field = make_field(data, xreal=1.5e-6, yreal=1.0e-6) node = FFT2D() - result, = node.process(field, windowing="hann", level="mean", output="log_magnitude") + result, _, _, _ = node.process(field, windowing="hann", level="mean") assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}" assert np.all(np.isfinite(result.data)), "Non-finite values in output" print(f" Shape: {result.data.shape}") @@ -234,7 +234,7 @@ def test_log_magnitude_visual_range(): field = make_field(data) node = FFT2D() - result, = node.process(field, windowing="hann", level="mean", output="log_magnitude") + result, _, _, _ = node.process(field, windowing="hann", level="mean") vmin, vmax = result.data.min(), result.data.max() dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30) @@ -246,6 +246,91 @@ def test_log_magnitude_visual_range(): print(" PASS\n") +def test_inverse_fft_reconstructs_from_magnitude_and_phase(): + """Magnitude + phase from FFT2D should reconstruct the original image.""" + print("=== Test: Inverse FFT from magnitude + phase ===") + rng = np.random.default_rng(123) + data = rng.standard_normal((64, 96)) + field = make_field(data, xreal=2.4e-6, yreal=1.6e-6) + + fft_node = FFT2D() + ifft_node = InverseFFT2D() + + _, magnitude, phase, _ = fft_node.process(field, windowing="none", level="none") + reconstructed, = ifft_node.process(magnitude, representation="magnitude", phase=phase) + + max_err = np.max(np.abs(reconstructed.data - field.data)) + print(f" Reconstruction max error: {max_err:.3e}") + assert reconstructed.domain == "spatial" + assert reconstructed.data.shape == field.data.shape + assert np.isclose(reconstructed.xreal, field.xreal) + assert np.isclose(reconstructed.yreal, field.yreal) + assert max_err < 1e-9, f"Expected near-exact reconstruction, got {max_err}" + print(" PASS\n") + + +def test_inverse_fft_reconstructs_from_log_magnitude_and_phase(): + """log(|F|) + phase should also reconstruct after expm1 inversion.""" + print("=== Test: Inverse FFT from log magnitude + phase ===") + y, x = np.mgrid[0:72, 0:80] / 80.0 + data = ( + 0.8 * np.sin(2 * np.pi * 6 * x) + + 0.35 * np.cos(2 * np.pi * 9 * y) + + 0.15 * np.sin(2 * np.pi * (4 * x + 3 * y)) + ) + field = make_field(data, xreal=1.6e-6, yreal=1.44e-6) + + fft_node = FFT2D() + ifft_node = InverseFFT2D() + + log_magnitude, _, phase, _ = fft_node.process(field, windowing="none", level="none") + reconstructed, = ifft_node.process(log_magnitude, representation="log_magnitude", phase=phase) + + rms_err = np.sqrt(np.mean((reconstructed.data - field.data) ** 2)) + print(f" Reconstruction RMS error: {rms_err:.3e}") + assert rms_err < 1e-9, f"Expected near-exact reconstruction, got {rms_err}" + print(" PASS\n") + + +def test_inverse_fft_reconstructs_from_psdf_and_phase(): + """PSDF + phase should reconstruct after undoing PSDF scaling.""" + print("=== Test: Inverse FFT from PSDF + phase ===") + rng = np.random.default_rng(321) + data = rng.standard_normal((48, 64)) + field = make_field(data, xreal=3.2e-6, yreal=2.4e-6) + + fft_node = FFT2D() + ifft_node = InverseFFT2D() + + _, _, phase, psdf = fft_node.process(field, windowing="none", level="none") + reconstructed, = ifft_node.process(psdf, representation="psdf", phase=phase) + + max_err = np.max(np.abs(reconstructed.data - field.data)) + print(f" Reconstruction max error: {max_err:.3e}") + assert reconstructed.si_unit_z == field.si_unit_z + assert max_err < 1e-8, f"Expected near-exact reconstruction, got {max_err}" + print(" PASS\n") + + +def test_inverse_fft_zero_phase_mode_returns_valid_image(): + """Spectrum-only inversion should return a finite spatial image with the right shape.""" + print("=== Test: Inverse FFT zero-phase mode ===") + data = np.sin(2 * np.pi * 5 * np.mgrid[0:64, 0:64][1] / 64.0) + field = make_field(data, xreal=1e-6, yreal=1e-6) + + fft_node = FFT2D() + ifft_node = InverseFFT2D() + + _, magnitude, _, _ = fft_node.process(field, windowing="none", level="none") + reconstructed, = ifft_node.process(magnitude, representation="magnitude") + + print(f" Output shape: {reconstructed.data.shape}") + assert reconstructed.domain == "spatial" + assert reconstructed.data.shape == field.data.shape + assert np.all(np.isfinite(reconstructed.data)) + print(" PASS\n") + + if __name__ == "__main__": test_dc_removal() test_single_frequency() @@ -255,4 +340,8 @@ if __name__ == "__main__": test_plane_subtraction() test_non_square() test_log_magnitude_visual_range() + test_inverse_fft_reconstructs_from_magnitude_and_phase() + test_inverse_fft_reconstructs_from_log_magnitude_and_phase() + test_inverse_fft_reconstructs_from_psdf_and_phase() + test_inverse_fft_zero_phase_mode_returns_valid_image() print("All tests passed!") diff --git a/tests/test_fft_visual.py b/tests/test_fft_visual.py index e8c2cb9..fde9bfd 100644 --- a/tests/test_fft_visual.py +++ b/tests/test_fft_visual.py @@ -43,9 +43,10 @@ def main(): field = make_field(data) save_field(field, "01_sines_input") - for output_mode in ["log_magnitude", "magnitude", "psdf"]: - result, = node.process(field, windowing="hann", level="mean", output=output_mode) - save_field(result, f"01_sines_{output_mode}") + log_magnitude, magnitude, _, psdf = node.process(field, windowing="hann", level="mean") + save_field(log_magnitude, "01_sines_log_magnitude") + save_field(magnitude, "01_sines_magnitude") + save_field(psdf, "01_sines_psdf") # --- Test 2: Real-world-like surface with noise + tilt --- print("\nTest 2: Tilted surface with features") @@ -57,7 +58,7 @@ def main(): save_field(field, "02_surface_input") for level_mode in ["none", "mean", "plane"]: - result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude") + result, _, _, _ = node.process(field, windowing="hann", level=level_mode) save_field(result, f"02_surface_fft_level_{level_mode}") # --- Test 3: Checkerboard pattern --- @@ -67,7 +68,7 @@ def main(): field = make_field(data) save_field(field, "03_checker_input") - result, = node.process(field, windowing="none", level="mean", output="log_magnitude") + result, _, _, _ = node.process(field, windowing="none", level="mean") save_field(result, "03_checker_fft") # --- Test 4: Concentric rings (radial frequency) --- @@ -77,7 +78,7 @@ def main(): field = make_field(data) save_field(field, "04_rings_input") - result, = node.process(field, windowing="hann", level="mean", output="log_magnitude") + result, _, _, _ = node.process(field, windowing="hann", level="mean") save_field(result, "04_rings_fft") # --- Test 5: Compare windowing effects --- @@ -87,7 +88,7 @@ def main(): save_field(field, "05_window_input") for win in ["none", "hann", "hamming", "blackman"]: - result, = node.process(field, windowing=win, level="mean", output="log_magnitude") + result, _, _, _ = node.process(field, windowing=win, level="mean") save_field(result, f"05_window_{win}") print(f"\nAll outputs saved to {OUT_DIR}/") diff --git a/tests/test_nodes.py b/tests/test_nodes.py index f85ca45..821bb28 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1229,7 +1229,7 @@ def test_fft2d(): field = make_field(data=data, xreal=1e-6, yreal=1e-6) # log_magnitude - spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude") + spectrum, spec_mag, spec_phase, spec_psdf = node.process(field, windowing="none", level="none") assert spectrum.data.shape == (N, N) assert spectrum.domain == "frequency" assert spectrum.si_unit_xy == "1/m" @@ -1240,31 +1240,31 @@ def test_fft2d(): assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}" # magnitude output - spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude") + _, spec_mag, _, _ = node.process(field, windowing="hann", level="mean") assert spec_mag.data.shape == (N, N) assert np.all(spec_mag.data >= 0) # phase output - spec_phase, = node.process(field, windowing="none", level="none", output="phase") + _, _, spec_phase, _ = node.process(field, windowing="none", level="none") assert spec_phase.data.shape == (N, N) assert spec_phase.data.min() >= -np.pi - 0.01 assert spec_phase.data.max() <= np.pi + 0.01 # psdf output — units should reflect PSDF calibration - spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf") + _, _, _, spec_psdf = node.process(field, windowing="hamming", level="plane") assert spec_psdf.data.shape == (N, N) assert np.all(spec_psdf.data >= 0) assert "^2" in spec_psdf.si_unit_z # Constant field should have all energy at DC const_field = make_field(data=np.ones((32, 32)) * 3.0) - spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude") + _, spec_const, _, _ = node.process(const_field, windowing="none", level="none") centre32 = 16 dc_val = spec_const.data[centre32, centre32] assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field" # Blackman windowing should also work without error - spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude") + spec_bk, _, _, _ = node.process(field, windowing="blackman", level="none") assert spec_bk.data.shape == (N, N) print(" PASS\n")