fft multi channel output

This commit is contained in:
matei jordache
2026-03-25 14:30:28 -07:00
parent 7b896777fc
commit bce11590c7
4 changed files with 275 additions and 32 deletions

View File

@@ -9,7 +9,7 @@ import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField
from backend.nodes.analysis import FFT2D
from backend.nodes.analysis import FFT2D, InverseFFT2D
def make_field(data, xreal=1e-6, yreal=1e-6):
@@ -24,7 +24,7 @@ def test_dc_removal():
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
_, result, _, _ = node.process(field, windowing="none", level="mean")
peak = result.data.max()
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
assert peak < 1e-10, f"Expected ~0, got {peak}"
@@ -43,7 +43,7 @@ def test_single_frequency():
field = make_field(data, xreal=xreal, yreal=xreal)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
_, result, _, _ = node.process(field, windowing="none", level="mean")
# The peak should be at column offset = freq_cycles from center
mag = result.data
@@ -76,7 +76,7 @@ def test_2d_frequency():
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
_, result, _, _ = node.process(field, windowing="none", level="mean")
mag = result.data
cy, cx = N // 2, N // 2
@@ -110,7 +110,7 @@ def test_psdf_normalization():
field = make_field(data, xreal=xreal, yreal=xreal)
node = FFT2D()
result, = node.process(field, windowing="none", level="none", output="psdf")
_, _, _, result = node.process(field, windowing="none", level="none")
psdf = result.data
# Integrate: sum of PSDF * dk_x * dk_y
@@ -141,11 +141,11 @@ def test_windowing_reduces_leakage():
node = FFT2D()
# Without windowing
r_none, = node.process(field, windowing="none", level="mean", output="magnitude")
_, r_none, _, _ = node.process(field, windowing="none", level="mean")
mag_none = r_none.data[N // 2, :] # center row
# With Hann windowing
r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude")
_, r_hann, _, _ = node.process(field, windowing="hann", level="mean")
mag_hann = r_hann.data[N // 2, :]
# Measure leakage: ratio of energy far from peak vs total
@@ -178,15 +178,15 @@ def test_plane_subtraction():
node = FFT2D()
# Without leveling — huge DC and low-freq energy
r_none, = node.process(field, windowing="none", level="none", output="magnitude")
_, r_none, _, _ = node.process(field, windowing="none", level="none")
dc_none = r_none.data[N // 2, N // 2]
# With mean subtraction — DC removed but gradient leaks
r_mean, = node.process(field, windowing="none", level="mean", output="magnitude")
_, r_mean, _, _ = node.process(field, windowing="none", level="mean")
dc_mean = r_mean.data[N // 2, N // 2]
# With plane subtraction — gradient removed
r_plane, = node.process(field, windowing="none", level="plane", output="magnitude")
_, r_plane, _, _ = node.process(field, windowing="none", level="plane")
dc_plane = r_plane.data[N // 2, N // 2]
# With plane subtraction, check the low-freq energy near DC is reduced
@@ -213,7 +213,7 @@ def test_non_square():
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
node = FFT2D()
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
result, _, _, _ = node.process(field, windowing="hann", level="mean")
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
print(f" Shape: {result.data.shape}")
@@ -234,7 +234,7 @@ def test_log_magnitude_visual_range():
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
result, _, _, _ = node.process(field, windowing="hann", level="mean")
vmin, vmax = result.data.min(), result.data.max()
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
@@ -246,6 +246,91 @@ def test_log_magnitude_visual_range():
print(" PASS\n")
def test_inverse_fft_reconstructs_from_magnitude_and_phase():
"""Magnitude + phase from FFT2D should reconstruct the original image."""
print("=== Test: Inverse FFT from magnitude + phase ===")
rng = np.random.default_rng(123)
data = rng.standard_normal((64, 96))
field = make_field(data, xreal=2.4e-6, yreal=1.6e-6)
fft_node = FFT2D()
ifft_node = InverseFFT2D()
_, magnitude, phase, _ = fft_node.process(field, windowing="none", level="none")
reconstructed, = ifft_node.process(magnitude, representation="magnitude", phase=phase)
max_err = np.max(np.abs(reconstructed.data - field.data))
print(f" Reconstruction max error: {max_err:.3e}")
assert reconstructed.domain == "spatial"
assert reconstructed.data.shape == field.data.shape
assert np.isclose(reconstructed.xreal, field.xreal)
assert np.isclose(reconstructed.yreal, field.yreal)
assert max_err < 1e-9, f"Expected near-exact reconstruction, got {max_err}"
print(" PASS\n")
def test_inverse_fft_reconstructs_from_log_magnitude_and_phase():
"""log(|F|) + phase should also reconstruct after expm1 inversion."""
print("=== Test: Inverse FFT from log magnitude + phase ===")
y, x = np.mgrid[0:72, 0:80] / 80.0
data = (
0.8 * np.sin(2 * np.pi * 6 * x)
+ 0.35 * np.cos(2 * np.pi * 9 * y)
+ 0.15 * np.sin(2 * np.pi * (4 * x + 3 * y))
)
field = make_field(data, xreal=1.6e-6, yreal=1.44e-6)
fft_node = FFT2D()
ifft_node = InverseFFT2D()
log_magnitude, _, phase, _ = fft_node.process(field, windowing="none", level="none")
reconstructed, = ifft_node.process(log_magnitude, representation="log_magnitude", phase=phase)
rms_err = np.sqrt(np.mean((reconstructed.data - field.data) ** 2))
print(f" Reconstruction RMS error: {rms_err:.3e}")
assert rms_err < 1e-9, f"Expected near-exact reconstruction, got {rms_err}"
print(" PASS\n")
def test_inverse_fft_reconstructs_from_psdf_and_phase():
"""PSDF + phase should reconstruct after undoing PSDF scaling."""
print("=== Test: Inverse FFT from PSDF + phase ===")
rng = np.random.default_rng(321)
data = rng.standard_normal((48, 64))
field = make_field(data, xreal=3.2e-6, yreal=2.4e-6)
fft_node = FFT2D()
ifft_node = InverseFFT2D()
_, _, phase, psdf = fft_node.process(field, windowing="none", level="none")
reconstructed, = ifft_node.process(psdf, representation="psdf", phase=phase)
max_err = np.max(np.abs(reconstructed.data - field.data))
print(f" Reconstruction max error: {max_err:.3e}")
assert reconstructed.si_unit_z == field.si_unit_z
assert max_err < 1e-8, f"Expected near-exact reconstruction, got {max_err}"
print(" PASS\n")
def test_inverse_fft_zero_phase_mode_returns_valid_image():
"""Spectrum-only inversion should return a finite spatial image with the right shape."""
print("=== Test: Inverse FFT zero-phase mode ===")
data = np.sin(2 * np.pi * 5 * np.mgrid[0:64, 0:64][1] / 64.0)
field = make_field(data, xreal=1e-6, yreal=1e-6)
fft_node = FFT2D()
ifft_node = InverseFFT2D()
_, magnitude, _, _ = fft_node.process(field, windowing="none", level="none")
reconstructed, = ifft_node.process(magnitude, representation="magnitude")
print(f" Output shape: {reconstructed.data.shape}")
assert reconstructed.domain == "spatial"
assert reconstructed.data.shape == field.data.shape
assert np.all(np.isfinite(reconstructed.data))
print(" PASS\n")
if __name__ == "__main__":
test_dc_removal()
test_single_frequency()
@@ -255,4 +340,8 @@ if __name__ == "__main__":
test_plane_subtraction()
test_non_square()
test_log_magnitude_visual_range()
test_inverse_fft_reconstructs_from_magnitude_and_phase()
test_inverse_fft_reconstructs_from_log_magnitude_and_phase()
test_inverse_fft_reconstructs_from_psdf_and_phase()
test_inverse_fft_zero_phase_mode_returns_valid_image()
print("All tests passed!")

View File

@@ -43,9 +43,10 @@ def main():
field = make_field(data)
save_field(field, "01_sines_input")
for output_mode in ["log_magnitude", "magnitude", "psdf"]:
result, = node.process(field, windowing="hann", level="mean", output=output_mode)
save_field(result, f"01_sines_{output_mode}")
log_magnitude, magnitude, _, psdf = node.process(field, windowing="hann", level="mean")
save_field(log_magnitude, "01_sines_log_magnitude")
save_field(magnitude, "01_sines_magnitude")
save_field(psdf, "01_sines_psdf")
# --- Test 2: Real-world-like surface with noise + tilt ---
print("\nTest 2: Tilted surface with features")
@@ -57,7 +58,7 @@ def main():
save_field(field, "02_surface_input")
for level_mode in ["none", "mean", "plane"]:
result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude")
result, _, _, _ = node.process(field, windowing="hann", level=level_mode)
save_field(result, f"02_surface_fft_level_{level_mode}")
# --- Test 3: Checkerboard pattern ---
@@ -67,7 +68,7 @@ def main():
field = make_field(data)
save_field(field, "03_checker_input")
result, = node.process(field, windowing="none", level="mean", output="log_magnitude")
result, _, _, _ = node.process(field, windowing="none", level="mean")
save_field(result, "03_checker_fft")
# --- Test 4: Concentric rings (radial frequency) ---
@@ -77,7 +78,7 @@ def main():
field = make_field(data)
save_field(field, "04_rings_input")
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
result, _, _, _ = node.process(field, windowing="hann", level="mean")
save_field(result, "04_rings_fft")
# --- Test 5: Compare windowing effects ---
@@ -87,7 +88,7 @@ def main():
save_field(field, "05_window_input")
for win in ["none", "hann", "hamming", "blackman"]:
result, = node.process(field, windowing=win, level="mean", output="log_magnitude")
result, _, _, _ = node.process(field, windowing=win, level="mean")
save_field(result, f"05_window_{win}")
print(f"\nAll outputs saved to {OUT_DIR}/")

View File

@@ -1229,7 +1229,7 @@ def test_fft2d():
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# log_magnitude
spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude")
spectrum, spec_mag, spec_phase, spec_psdf = node.process(field, windowing="none", level="none")
assert spectrum.data.shape == (N, N)
assert spectrum.domain == "frequency"
assert spectrum.si_unit_xy == "1/m"
@@ -1240,31 +1240,31 @@ def test_fft2d():
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
# magnitude output
spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude")
_, spec_mag, _, _ = node.process(field, windowing="hann", level="mean")
assert spec_mag.data.shape == (N, N)
assert np.all(spec_mag.data >= 0)
# phase output
spec_phase, = node.process(field, windowing="none", level="none", output="phase")
_, _, spec_phase, _ = node.process(field, windowing="none", level="none")
assert spec_phase.data.shape == (N, N)
assert spec_phase.data.min() >= -np.pi - 0.01
assert spec_phase.data.max() <= np.pi + 0.01
# psdf output — units should reflect PSDF calibration
spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf")
_, _, _, spec_psdf = node.process(field, windowing="hamming", level="plane")
assert spec_psdf.data.shape == (N, N)
assert np.all(spec_psdf.data >= 0)
assert "^2" in spec_psdf.si_unit_z
# Constant field should have all energy at DC
const_field = make_field(data=np.ones((32, 32)) * 3.0)
spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude")
_, spec_const, _, _ = node.process(const_field, windowing="none", level="none")
centre32 = 16
dc_val = spec_const.data[centre32, centre32]
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
# Blackman windowing should also work without error
spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude")
spec_bk, _, _, _ = node.process(field, windowing="blackman", level="none")
assert spec_bk.data.shape == (N, N)
print(" PASS\n")