"""
Test the FFT2D node against known inputs and Gwyddion-equivalent results.

Run from project root:

    python -m tests.test_fft
"""
|
|
import sys
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, ".")
|
|
from backend.data_types import DataField
|
|
from backend.nodes.fft_2d import FFT2D
|
|
|
|
|
|
def make_field(data, xreal=1e-6, yreal=1e-6):
    """Wrap a 2D array in a DataField whose lateral and z units are metres.

    Args:
        data: 2D array of samples.
        xreal: Physical width of the field in metres (default 1e-6).
        yreal: Physical height of the field in metres (default 1e-6).

    Returns:
        A DataField built from ``data`` with ``si_unit_xy`` and
        ``si_unit_z`` both set to "m".
    """
    return DataField(
        data=data,
        xreal=xreal,
        yreal=yreal,
        si_unit_xy="m",
        si_unit_z="m",
    )


def test_dc_removal():
    """Mean levelling of a constant image must leave an essentially empty spectrum.

    After the mean is subtracted, a flat 64x64 image is zero everywhere. Every
    bin of the magnitude output (second value returned by ``process``) should
    therefore vanish up to floating-point round-off.
    """
    print("=== Test: DC removal ===")
    constant_image = np.full((64, 64), 42.0)
    field = make_field(constant_image)
    fft_node = FFT2D()

    _, magnitude, _, _ = fft_node.process(field, windowing="none", level="mean")

    peak = magnitude.data.max()
    print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
    assert peak < 1e-10, f"Expected ~0, got {peak}"
    print(" PASS\n")


def test_single_frequency():
    """A pure horizontal sine should peak ``freq_cycles`` bins away from DC.

    The image contains exactly ``freq_cycles`` periods along x and is constant
    along y. After mean levelling, the strongest magnitude bin other than the
    centre (DC) bin must lie on the centre row, ``freq_cycles`` columns from
    the centre.
    """
    print("=== Test: Single frequency detection ===")
    N = 128
    xreal = 1e-6  # 1 micron
    freq_cycles = 10  # 10 cycles across the image

    positions = np.linspace(0, 1, N, endpoint=False)
    profile = np.sin(2 * np.pi * freq_cycles * positions)
    data = np.tile(profile, (N, 1))  # identical profile on every row
    field = make_field(data, xreal=xreal, yreal=xreal)

    node = FFT2D()
    _, magnitude, _, _ = node.process(field, windowing="none", level="mean")
    spectrum = magnitude.data

    # DC is expected in the centre bin; zero it so argmax locates the tone.
    centre_row = centre_col = N // 2
    search = spectrum.copy()
    search[centre_row, centre_col] = 0
    peak_idx = np.unravel_index(np.argmax(search), spectrum.shape)
    peak_col_offset = abs(peak_idx[1] - centre_col)

    print(f" Image: {N}x{N}, {freq_cycles} horizontal cycles")
    print(f" Expected peak at column offset {freq_cycles} from center")
    print(f" Found peak at {peak_idx} (offset {peak_col_offset})")
    print(f" DC value: {spectrum[centre_row, centre_col]:.2e}")
    print(f" Peak value: {spectrum[peak_idx]:.2e}")
    assert peak_col_offset == freq_cycles, f"Expected offset {freq_cycles}, got {peak_col_offset}"
    assert peak_idx[0] == centre_row, f"Expected peak on center row, got row {peak_idx[0]}"
    print(" PASS\n")


def test_2d_frequency():
    """An oblique 2D sine should peak at a (row, column) offset of (fy, fx).

    ``sin(2π(fx·x + fy·y))`` with x and y sampled on [0, 1) contains a single
    spatial frequency. The strongest non-DC magnitude bin must therefore sit
    ``fy`` rows and ``fx`` columns away from the centre bin.
    """
    print("=== Test: 2D frequency detection ===")
    N = 128
    fx, fy = 8, 5  # cycles in x and y

    rows, cols = np.mgrid[0:N, 0:N] / N
    data = np.sin(2 * np.pi * (fx * cols + fy * rows))
    field = make_field(data)

    node = FFT2D()
    _, magnitude, _, _ = node.process(field, windowing="none", level="mean")
    spectrum = magnitude.data

    centre = N // 2
    search = spectrum.copy()
    search[centre, centre] = 0  # ignore DC while searching for the tone
    peak_idx = np.unravel_index(np.argmax(search), spectrum.shape)

    dx = abs(peak_idx[1] - centre)
    dy = abs(peak_idx[0] - centre)

    print(f" Input: sin(2π({fx}x + {fy}y))")
    print(f" Expected peak offset: ({fy}, {fx}) from center")
    print(f" Found peak at {peak_idx} (offset dy={dy}, dx={dx})")
    assert (dy, dx) == (fy, fx), f"Expected ({fy},{fx}), got ({dy},{dx})"
    print(" PASS\n")


def test_psdf_normalization():
    """
    PSDF of white noise should integrate to the variance.

    Parseval's theorem: if the PSDF is normalized correctly,
    sum(PSDF) * dk_x * dk_y equals the mean square of the signal. With
    ``level="none"`` the mean is not removed, so strictly that is
    variance + mean**2. For the zero-mean Gaussian noise used here the mean
    term is negligible, so the sample variance serves as the reference.
    """
    print("=== Test: PSDF normalization (Parseval) ===")
    N = 256
    xreal = 1e-6
    noise = np.random.default_rng(42).standard_normal((N, N))
    variance = noise.var()

    field = make_field(noise, xreal=xreal, yreal=xreal)
    node = FFT2D()
    _, _, _, psdf_field = node.process(field, windowing="none", level="none")

    # Width of one frequency bin per axis. Assumes the PSDF field's
    # xreal/yreal hold the full angular-frequency extent (2π·N/xreal);
    # NOTE(review): confirm against FFT2D.
    dk_x = psdf_field.xreal / N
    dk_y = psdf_field.yreal / N
    integral = psdf_field.data.sum() * dk_x * dk_y
    ratio = integral / variance

    print(f" Signal variance: {variance:.6f}")
    print(f" PSDF integral: {integral:.6f}")
    print(f" Ratio (should be ~1.0): {ratio:.4f}")
    # Allow 20% tolerance for finite-size effects
    assert 0.8 < ratio < 1.2, f"Parseval's theorem violated: ratio = {ratio}"
    print(" PASS\n")


def test_windowing_reduces_leakage():
    """A Hann window should keep more of an off-bin tone's energy near its peaks.

    With 10.5 cycles across the image, the tone falls between two frequency
    bins, so an unwindowed transform smears energy along the whole row.
    Leakage is the fraction of centre-row magnitude that lies more than three
    bins from either of the two mirrored peaks.
    """
    print("=== Test: Windowing reduces leakage ===")
    N = 128
    freq = 10.5  # non-integer → spectral leakage without windowing

    ramp = np.linspace(0, 1, N, endpoint=False)
    data = np.tile(np.sin(2 * np.pi * freq * ramp), (N, 1))
    field = make_field(data)

    node = FFT2D()
    centre = N // 2

    def centre_row(windowing):
        # Magnitude spectrum along the centre row for the given window.
        _, magnitude, _, _ = node.process(field, windowing=windowing, level="mean")
        return magnitude.data[centre, :]

    mag_none = centre_row("none")  # without windowing
    mag_hann = centre_row("hann")  # with Hann windowing

    # Exclude ±3 bins around the peak and around its mirror image.
    # N - peak_col mirrors about the centre bin because N is even.
    peak_col = np.argmax(mag_none)
    far_mask = np.ones(N, dtype=bool)
    for col in (peak_col, N - peak_col):
        far_mask[max(0, col - 3):col + 4] = False

    def leakage(row):
        # Share of the row's total magnitude that lies far from the peaks.
        return row[far_mask].sum() / row.sum()

    leakage_none = leakage(mag_none)
    leakage_hann = leakage(mag_hann)

    print(f" Non-integer frequency: {freq}")
    print(f" Leakage without windowing: {leakage_none:.4f}")
    print(f" Leakage with Hann window: {leakage_hann:.4f}")
    assert leakage_hann < leakage_none, "Hann window should reduce leakage"
    print(" PASS\n")


def test_plane_subtraction():
    """Plane levelling should strip a linear tilt out of the lowest frequencies.

    The image is a steep plane (100·x + 50·y) plus an 8-cycle sine along x.
    It is transformed with no levelling, mean subtraction and plane
    subtraction. The DC bin and the summed magnitude in a (2r+1)×(2r+1)
    window around it are then compared.
    """
    print("=== Test: Plane subtraction ===")
    N = 64
    rows, cols = np.mgrid[0:N, 0:N] / N
    # Tilted plane + sine wave
    data = 100 * cols + 50 * rows + np.sin(2 * np.pi * 8 * cols)
    field = make_field(data)

    node = FFT2D()
    cy = cx = N // 2
    r = 3  # radius around DC to check

    # "none": huge DC and low-freq energy; "mean": DC removed but the gradient
    # still leaks; "plane": gradient removed.
    spectra = {}
    dc = {}
    for level in ("none", "mean", "plane"):
        _, magnitude, _, _ = node.process(field, windowing="none", level=level)
        spectra[level] = magnitude.data
        dc[level] = spectra[level][cy, cx]

    window = (slice(cy - r, cy + r + 1), slice(cx - r, cx + r + 1))
    lowfreq_none = spectra["none"][window].sum()
    lowfreq_plane = spectra["plane"][window].sum()

    print(f" DC magnitude (no leveling): {dc['none']:.2e}")
    print(f" DC magnitude (mean subtract): {dc['mean']:.2e}")
    print(f" DC magnitude (plane subtract): {dc['plane']:.2e}")
    print(f" Low-freq energy (no level): {lowfreq_none:.2e}")
    print(f" Low-freq energy (plane sub): {lowfreq_plane:.2e}")
    assert dc["mean"] < dc["none"], "Mean subtraction should reduce DC"
    assert lowfreq_plane < lowfreq_none * 0.01, "Plane subtraction should reduce low-freq energy"
    print(" PASS\n")


def test_non_square():
    """A 100x150 (non-square, non-power-of-two) image must be handled cleanly.

    The first output returned by ``process`` must keep the input shape and
    contain only finite values.
    """
    print("=== Test: Non-square image ===")
    noise = np.random.default_rng(99).standard_normal((100, 150))
    field = make_field(noise, xreal=1.5e-6, yreal=1.0e-6)

    first, _, _, _ = FFT2D().process(field, windowing="hann", level="mean")
    out = first.data

    assert out.shape == (100, 150), f"Shape mismatch: {out.shape}"
    assert np.isfinite(out).all(), "Non-finite values in output"
    print(f" Shape: {out.shape}")
    print(f" Output range: [{out.min():.4f}, {out.max():.4f}]")
    print(" PASS\n")


def test_log_magnitude_visual_range():
    """Log magnitude should produce a finite, non-degenerate range for display.

    The test image combines a 5-cycle horizontal sine, a weaker oblique sine
    (15 cycles in x, 10 in y) and low-level Gaussian noise. It is processed
    with a Hann window and mean levelling. The first output returned by
    ``process`` (the log magnitude) is then checked.
    """
    print("=== Test: Log magnitude visual range ===")
    N = 128

    # Multi-frequency test image
    y, x = np.mgrid[0:N, 0:N] / N
    data = (np.sin(2 * np.pi * 5 * x) +
            0.5 * np.sin(2 * np.pi * 15 * x + 2 * np.pi * 10 * y) +
            0.1 * np.random.default_rng(7).standard_normal((N, N)))
    field = make_field(data)

    node = FFT2D()
    result, _, _, _ = node.process(field, windowing="hann", level="mean")
    log_mag = result.data

    vmin, vmax = log_mag.min(), log_mag.max()
    # The values are already logarithmic, so the display dynamic range is
    # their span. A ratio vmax / |vmin| has no meaning in log space and
    # breaks down whenever vmin <= 0.
    dynamic_range = vmax - vmin

    print(f" Log magnitude range: [{vmin:.4f}, {vmax:.4f}]")
    print(f" Dynamic range: {dynamic_range:.2f}")
    # Check finiteness first: NaN makes every comparison False, which would
    # otherwise be misreported as a zero-range failure.
    assert np.all(np.isfinite(log_mag)), "Non-finite values in log magnitude"
    assert vmax > vmin, "Log magnitude should have nonzero range"
    print(" PASS\n")


if __name__ == "__main__":
    # Run every test in declaration order; the first failing assert aborts.
    for run_test in (
        test_dc_removal,
        test_single_frequency,
        test_2d_frequency,
        test_psdf_normalization,
        test_windowing_reduces_leakage,
        test_plane_subtraction,
        test_non_square,
        test_log_magnitude_visual_range,
    ):
        run_test()
    print("All tests passed!")