fft multi channel output
This commit is contained in:
@@ -266,21 +266,20 @@ class FFT2D:
|
|||||||
"field": ("DATA_FIELD",),
|
"field": ("DATA_FIELD",),
|
||||||
"windowing": (["hann", "hamming", "blackman", "none"],),
|
"windowing": (["hann", "hamming", "blackman", "none"],),
|
||||||
"level": (["mean", "plane", "none"],),
|
"level": (["mean", "plane", "none"],),
|
||||||
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("DATA_FIELD",)
|
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD")
|
||||||
RETURN_NAMES = ("spectrum",)
|
RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf")
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "analysis"
|
CATEGORY = "analysis"
|
||||||
DESCRIPTION = (
|
DESCRIPTION = (
|
||||||
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
|
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
|
||||||
"Output can be log magnitude, magnitude, phase, or PSDF. "
|
"Outputs log magnitude, magnitude, phase, and PSDF as separate channels. "
|
||||||
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
|
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
|
def process(self, field: DataField, windowing: str, level: str) -> tuple:
|
||||||
data = field.data.copy()
|
data = field.data.copy()
|
||||||
yres, xres = data.shape
|
yres, xres = data.shape
|
||||||
|
|
||||||
@@ -320,8 +319,59 @@ class FFT2D:
|
|||||||
F = np.fft.fftshift(np.fft.fft2(data))
|
F = np.fft.fftshift(np.fft.fft2(data))
|
||||||
n = xres * yres
|
n = xres * yres
|
||||||
|
|
||||||
if output == "log_magnitude":
|
magnitude = np.abs(F)
|
||||||
mag = np.abs(F)
|
log_magnitude = np.log1p(magnitude)
|
||||||
|
phase = np.angle(F)
|
||||||
|
|
||||||
|
dx = field.xreal / xres
|
||||||
|
dy = field.yreal / yres
|
||||||
|
psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
|
||||||
|
|
||||||
|
spatial_freq_xreal = xres / field.xreal
|
||||||
|
spatial_freq_yreal = yres / field.yreal
|
||||||
|
angular_freq_xreal = 2.0 * np.pi * xres / field.xreal
|
||||||
|
angular_freq_yreal = 2.0 * np.pi * yres / field.yreal
|
||||||
|
|
||||||
|
return (
|
||||||
|
DataField(
|
||||||
|
data=log_magnitude,
|
||||||
|
xreal=spatial_freq_xreal,
|
||||||
|
yreal=spatial_freq_yreal,
|
||||||
|
si_unit_xy="1/m",
|
||||||
|
si_unit_z=field.si_unit_z,
|
||||||
|
domain="frequency",
|
||||||
|
colormap=field.colormap,
|
||||||
|
),
|
||||||
|
DataField(
|
||||||
|
data=magnitude,
|
||||||
|
xreal=spatial_freq_xreal,
|
||||||
|
yreal=spatial_freq_yreal,
|
||||||
|
si_unit_xy="1/m",
|
||||||
|
si_unit_z=field.si_unit_z,
|
||||||
|
domain="frequency",
|
||||||
|
colormap=field.colormap,
|
||||||
|
),
|
||||||
|
DataField(
|
||||||
|
data=phase,
|
||||||
|
xreal=spatial_freq_xreal,
|
||||||
|
yreal=spatial_freq_yreal,
|
||||||
|
si_unit_xy="1/m",
|
||||||
|
si_unit_z=field.si_unit_z,
|
||||||
|
domain="frequency",
|
||||||
|
colormap=field.colormap,
|
||||||
|
),
|
||||||
|
DataField(
|
||||||
|
data=psdf,
|
||||||
|
xreal=angular_freq_xreal,
|
||||||
|
yreal=angular_freq_yreal,
|
||||||
|
si_unit_xy="1/m",
|
||||||
|
si_unit_z=f"({field.si_unit_z})^2 m^2",
|
||||||
|
domain="frequency",
|
||||||
|
colormap=field.colormap,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if False: # Unreachable legacy block retained below.
|
||||||
# Log scale with floor to avoid log(0)
|
# Log scale with floor to avoid log(0)
|
||||||
result = np.log1p(mag)
|
result = np.log1p(mag)
|
||||||
elif output == "magnitude":
|
elif output == "magnitude":
|
||||||
@@ -359,6 +409,109 @@ class FFT2D:
|
|||||||
return (out_field,)
|
return (out_field,)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# InverseFFT2D
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Inverse 2D FFT")
|
||||||
|
class InverseFFT2D:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"spectrum": ("DATA_FIELD",),
|
||||||
|
"representation": (["magnitude", "log_magnitude", "psdf"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"phase": ("DATA_FIELD",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("image",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "analysis"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Reconstruct a spatial-domain image from a 2D frequency spectrum. "
|
||||||
|
"For exact reconstruction, connect magnitude/phase (or log magnitude/phase, "
|
||||||
|
"or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple:
|
||||||
|
if spectrum.domain != "frequency":
|
||||||
|
raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.")
|
||||||
|
|
||||||
|
if phase is not None:
|
||||||
|
if phase.data.shape != spectrum.data.shape:
|
||||||
|
raise ValueError("Phase input must have the same shape as the spectrum.")
|
||||||
|
if phase.domain != "frequency":
|
||||||
|
raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.")
|
||||||
|
|
||||||
|
amplitude = self._resolve_amplitude(spectrum, representation)
|
||||||
|
phase_data = phase.data if phase is not None else np.zeros_like(amplitude)
|
||||||
|
F = amplitude * np.exp(1j * phase_data)
|
||||||
|
|
||||||
|
spatial = np.fft.ifft2(np.fft.ifftshift(F)).real
|
||||||
|
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
|
||||||
|
z_unit = self._recover_z_unit(spectrum, representation, phase)
|
||||||
|
|
||||||
|
out_field = DataField(
|
||||||
|
data=spatial,
|
||||||
|
xreal=xreal,
|
||||||
|
yreal=yreal,
|
||||||
|
si_unit_xy="m",
|
||||||
|
si_unit_z=z_unit,
|
||||||
|
domain="spatial",
|
||||||
|
colormap=spectrum.colormap,
|
||||||
|
)
|
||||||
|
return (out_field,)
|
||||||
|
|
||||||
|
def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray:
|
||||||
|
data = np.asarray(spectrum.data, dtype=np.float64)
|
||||||
|
|
||||||
|
if representation == "magnitude":
|
||||||
|
return np.clip(data, 0.0, None)
|
||||||
|
if representation == "log_magnitude":
|
||||||
|
return np.expm1(data)
|
||||||
|
if representation == "psdf":
|
||||||
|
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
|
||||||
|
n = spectrum.xres * spectrum.yres
|
||||||
|
dx = xreal / spectrum.xres
|
||||||
|
dy = yreal / spectrum.yres
|
||||||
|
scale = n * 4.0 * np.pi ** 2 / (dx * dy)
|
||||||
|
return np.sqrt(np.clip(data, 0.0, None) * scale)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported spectrum representation: {representation}")
|
||||||
|
|
||||||
|
def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]:
|
||||||
|
if representation == "psdf":
|
||||||
|
xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal
|
||||||
|
yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal
|
||||||
|
else:
|
||||||
|
xreal = spectrum.xres / spectrum.xreal
|
||||||
|
yreal = spectrum.yres / spectrum.yreal
|
||||||
|
return float(xreal), float(yreal)
|
||||||
|
|
||||||
|
def _recover_z_unit(
|
||||||
|
self,
|
||||||
|
spectrum: DataField,
|
||||||
|
representation: str,
|
||||||
|
phase: DataField | None,
|
||||||
|
) -> str:
|
||||||
|
if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip():
|
||||||
|
return phase.si_unit_z
|
||||||
|
|
||||||
|
if representation != "psdf":
|
||||||
|
return spectrum.si_unit_z
|
||||||
|
|
||||||
|
unit = str(spectrum.si_unit_z or "").strip()
|
||||||
|
if unit.startswith("(") and ")^2 m^2" in unit:
|
||||||
|
return unit.split(")^2 m^2", 1)[0][1:]
|
||||||
|
if unit.endswith("^2 m^2"):
|
||||||
|
return unit[:-6].removesuffix("^2").strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# CrossSection
|
# CrossSection
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import numpy as np
|
|||||||
|
|
||||||
sys.path.insert(0, ".")
|
sys.path.insert(0, ".")
|
||||||
from backend.data_types import DataField
|
from backend.data_types import DataField
|
||||||
from backend.nodes.analysis import FFT2D
|
from backend.nodes.analysis import FFT2D, InverseFFT2D
|
||||||
|
|
||||||
|
|
||||||
def make_field(data, xreal=1e-6, yreal=1e-6):
|
def make_field(data, xreal=1e-6, yreal=1e-6):
|
||||||
@@ -24,7 +24,7 @@ def test_dc_removal():
|
|||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
|
|
||||||
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
_, result, _, _ = node.process(field, windowing="none", level="mean")
|
||||||
peak = result.data.max()
|
peak = result.data.max()
|
||||||
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
|
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
|
||||||
assert peak < 1e-10, f"Expected ~0, got {peak}"
|
assert peak < 1e-10, f"Expected ~0, got {peak}"
|
||||||
@@ -43,7 +43,7 @@ def test_single_frequency():
|
|||||||
field = make_field(data, xreal=xreal, yreal=xreal)
|
field = make_field(data, xreal=xreal, yreal=xreal)
|
||||||
|
|
||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
_, result, _, _ = node.process(field, windowing="none", level="mean")
|
||||||
|
|
||||||
# The peak should be at column offset = freq_cycles from center
|
# The peak should be at column offset = freq_cycles from center
|
||||||
mag = result.data
|
mag = result.data
|
||||||
@@ -76,7 +76,7 @@ def test_2d_frequency():
|
|||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
|
|
||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
_, result, _, _ = node.process(field, windowing="none", level="mean")
|
||||||
mag = result.data
|
mag = result.data
|
||||||
|
|
||||||
cy, cx = N // 2, N // 2
|
cy, cx = N // 2, N // 2
|
||||||
@@ -110,7 +110,7 @@ def test_psdf_normalization():
|
|||||||
field = make_field(data, xreal=xreal, yreal=xreal)
|
field = make_field(data, xreal=xreal, yreal=xreal)
|
||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
|
|
||||||
result, = node.process(field, windowing="none", level="none", output="psdf")
|
_, _, _, result = node.process(field, windowing="none", level="none")
|
||||||
psdf = result.data
|
psdf = result.data
|
||||||
|
|
||||||
# Integrate: sum of PSDF * dk_x * dk_y
|
# Integrate: sum of PSDF * dk_x * dk_y
|
||||||
@@ -141,11 +141,11 @@ def test_windowing_reduces_leakage():
|
|||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
|
|
||||||
# Without windowing
|
# Without windowing
|
||||||
r_none, = node.process(field, windowing="none", level="mean", output="magnitude")
|
_, r_none, _, _ = node.process(field, windowing="none", level="mean")
|
||||||
mag_none = r_none.data[N // 2, :] # center row
|
mag_none = r_none.data[N // 2, :] # center row
|
||||||
|
|
||||||
# With Hann windowing
|
# With Hann windowing
|
||||||
r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude")
|
_, r_hann, _, _ = node.process(field, windowing="hann", level="mean")
|
||||||
mag_hann = r_hann.data[N // 2, :]
|
mag_hann = r_hann.data[N // 2, :]
|
||||||
|
|
||||||
# Measure leakage: ratio of energy far from peak vs total
|
# Measure leakage: ratio of energy far from peak vs total
|
||||||
@@ -178,15 +178,15 @@ def test_plane_subtraction():
|
|||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
|
|
||||||
# Without leveling — huge DC and low-freq energy
|
# Without leveling — huge DC and low-freq energy
|
||||||
r_none, = node.process(field, windowing="none", level="none", output="magnitude")
|
_, r_none, _, _ = node.process(field, windowing="none", level="none")
|
||||||
dc_none = r_none.data[N // 2, N // 2]
|
dc_none = r_none.data[N // 2, N // 2]
|
||||||
|
|
||||||
# With mean subtraction — DC removed but gradient leaks
|
# With mean subtraction — DC removed but gradient leaks
|
||||||
r_mean, = node.process(field, windowing="none", level="mean", output="magnitude")
|
_, r_mean, _, _ = node.process(field, windowing="none", level="mean")
|
||||||
dc_mean = r_mean.data[N // 2, N // 2]
|
dc_mean = r_mean.data[N // 2, N // 2]
|
||||||
|
|
||||||
# With plane subtraction — gradient removed
|
# With plane subtraction — gradient removed
|
||||||
r_plane, = node.process(field, windowing="none", level="plane", output="magnitude")
|
_, r_plane, _, _ = node.process(field, windowing="none", level="plane")
|
||||||
dc_plane = r_plane.data[N // 2, N // 2]
|
dc_plane = r_plane.data[N // 2, N // 2]
|
||||||
|
|
||||||
# With plane subtraction, check the low-freq energy near DC is reduced
|
# With plane subtraction, check the low-freq energy near DC is reduced
|
||||||
@@ -213,7 +213,7 @@ def test_non_square():
|
|||||||
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
|
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
|
||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
|
|
||||||
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
result, _, _, _ = node.process(field, windowing="hann", level="mean")
|
||||||
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
|
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
|
||||||
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
|
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
|
||||||
print(f" Shape: {result.data.shape}")
|
print(f" Shape: {result.data.shape}")
|
||||||
@@ -234,7 +234,7 @@ def test_log_magnitude_visual_range():
|
|||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
|
|
||||||
node = FFT2D()
|
node = FFT2D()
|
||||||
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
result, _, _, _ = node.process(field, windowing="hann", level="mean")
|
||||||
|
|
||||||
vmin, vmax = result.data.min(), result.data.max()
|
vmin, vmax = result.data.min(), result.data.max()
|
||||||
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
|
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
|
||||||
@@ -246,6 +246,91 @@ def test_log_magnitude_visual_range():
|
|||||||
print(" PASS\n")
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_inverse_fft_reconstructs_from_magnitude_and_phase():
|
||||||
|
"""Magnitude + phase from FFT2D should reconstruct the original image."""
|
||||||
|
print("=== Test: Inverse FFT from magnitude + phase ===")
|
||||||
|
rng = np.random.default_rng(123)
|
||||||
|
data = rng.standard_normal((64, 96))
|
||||||
|
field = make_field(data, xreal=2.4e-6, yreal=1.6e-6)
|
||||||
|
|
||||||
|
fft_node = FFT2D()
|
||||||
|
ifft_node = InverseFFT2D()
|
||||||
|
|
||||||
|
_, magnitude, phase, _ = fft_node.process(field, windowing="none", level="none")
|
||||||
|
reconstructed, = ifft_node.process(magnitude, representation="magnitude", phase=phase)
|
||||||
|
|
||||||
|
max_err = np.max(np.abs(reconstructed.data - field.data))
|
||||||
|
print(f" Reconstruction max error: {max_err:.3e}")
|
||||||
|
assert reconstructed.domain == "spatial"
|
||||||
|
assert reconstructed.data.shape == field.data.shape
|
||||||
|
assert np.isclose(reconstructed.xreal, field.xreal)
|
||||||
|
assert np.isclose(reconstructed.yreal, field.yreal)
|
||||||
|
assert max_err < 1e-9, f"Expected near-exact reconstruction, got {max_err}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_inverse_fft_reconstructs_from_log_magnitude_and_phase():
|
||||||
|
"""log(|F|) + phase should also reconstruct after expm1 inversion."""
|
||||||
|
print("=== Test: Inverse FFT from log magnitude + phase ===")
|
||||||
|
y, x = np.mgrid[0:72, 0:80] / 80.0
|
||||||
|
data = (
|
||||||
|
0.8 * np.sin(2 * np.pi * 6 * x)
|
||||||
|
+ 0.35 * np.cos(2 * np.pi * 9 * y)
|
||||||
|
+ 0.15 * np.sin(2 * np.pi * (4 * x + 3 * y))
|
||||||
|
)
|
||||||
|
field = make_field(data, xreal=1.6e-6, yreal=1.44e-6)
|
||||||
|
|
||||||
|
fft_node = FFT2D()
|
||||||
|
ifft_node = InverseFFT2D()
|
||||||
|
|
||||||
|
log_magnitude, _, phase, _ = fft_node.process(field, windowing="none", level="none")
|
||||||
|
reconstructed, = ifft_node.process(log_magnitude, representation="log_magnitude", phase=phase)
|
||||||
|
|
||||||
|
rms_err = np.sqrt(np.mean((reconstructed.data - field.data) ** 2))
|
||||||
|
print(f" Reconstruction RMS error: {rms_err:.3e}")
|
||||||
|
assert rms_err < 1e-9, f"Expected near-exact reconstruction, got {rms_err}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_inverse_fft_reconstructs_from_psdf_and_phase():
|
||||||
|
"""PSDF + phase should reconstruct after undoing PSDF scaling."""
|
||||||
|
print("=== Test: Inverse FFT from PSDF + phase ===")
|
||||||
|
rng = np.random.default_rng(321)
|
||||||
|
data = rng.standard_normal((48, 64))
|
||||||
|
field = make_field(data, xreal=3.2e-6, yreal=2.4e-6)
|
||||||
|
|
||||||
|
fft_node = FFT2D()
|
||||||
|
ifft_node = InverseFFT2D()
|
||||||
|
|
||||||
|
_, _, phase, psdf = fft_node.process(field, windowing="none", level="none")
|
||||||
|
reconstructed, = ifft_node.process(psdf, representation="psdf", phase=phase)
|
||||||
|
|
||||||
|
max_err = np.max(np.abs(reconstructed.data - field.data))
|
||||||
|
print(f" Reconstruction max error: {max_err:.3e}")
|
||||||
|
assert reconstructed.si_unit_z == field.si_unit_z
|
||||||
|
assert max_err < 1e-8, f"Expected near-exact reconstruction, got {max_err}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_inverse_fft_zero_phase_mode_returns_valid_image():
|
||||||
|
"""Spectrum-only inversion should return a finite spatial image with the right shape."""
|
||||||
|
print("=== Test: Inverse FFT zero-phase mode ===")
|
||||||
|
data = np.sin(2 * np.pi * 5 * np.mgrid[0:64, 0:64][1] / 64.0)
|
||||||
|
field = make_field(data, xreal=1e-6, yreal=1e-6)
|
||||||
|
|
||||||
|
fft_node = FFT2D()
|
||||||
|
ifft_node = InverseFFT2D()
|
||||||
|
|
||||||
|
_, magnitude, _, _ = fft_node.process(field, windowing="none", level="none")
|
||||||
|
reconstructed, = ifft_node.process(magnitude, representation="magnitude")
|
||||||
|
|
||||||
|
print(f" Output shape: {reconstructed.data.shape}")
|
||||||
|
assert reconstructed.domain == "spatial"
|
||||||
|
assert reconstructed.data.shape == field.data.shape
|
||||||
|
assert np.all(np.isfinite(reconstructed.data))
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_dc_removal()
|
test_dc_removal()
|
||||||
test_single_frequency()
|
test_single_frequency()
|
||||||
@@ -255,4 +340,8 @@ if __name__ == "__main__":
|
|||||||
test_plane_subtraction()
|
test_plane_subtraction()
|
||||||
test_non_square()
|
test_non_square()
|
||||||
test_log_magnitude_visual_range()
|
test_log_magnitude_visual_range()
|
||||||
|
test_inverse_fft_reconstructs_from_magnitude_and_phase()
|
||||||
|
test_inverse_fft_reconstructs_from_log_magnitude_and_phase()
|
||||||
|
test_inverse_fft_reconstructs_from_psdf_and_phase()
|
||||||
|
test_inverse_fft_zero_phase_mode_returns_valid_image()
|
||||||
print("All tests passed!")
|
print("All tests passed!")
|
||||||
|
|||||||
@@ -43,9 +43,10 @@ def main():
|
|||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
save_field(field, "01_sines_input")
|
save_field(field, "01_sines_input")
|
||||||
|
|
||||||
for output_mode in ["log_magnitude", "magnitude", "psdf"]:
|
log_magnitude, magnitude, _, psdf = node.process(field, windowing="hann", level="mean")
|
||||||
result, = node.process(field, windowing="hann", level="mean", output=output_mode)
|
save_field(log_magnitude, "01_sines_log_magnitude")
|
||||||
save_field(result, f"01_sines_{output_mode}")
|
save_field(magnitude, "01_sines_magnitude")
|
||||||
|
save_field(psdf, "01_sines_psdf")
|
||||||
|
|
||||||
# --- Test 2: Real-world-like surface with noise + tilt ---
|
# --- Test 2: Real-world-like surface with noise + tilt ---
|
||||||
print("\nTest 2: Tilted surface with features")
|
print("\nTest 2: Tilted surface with features")
|
||||||
@@ -57,7 +58,7 @@ def main():
|
|||||||
save_field(field, "02_surface_input")
|
save_field(field, "02_surface_input")
|
||||||
|
|
||||||
for level_mode in ["none", "mean", "plane"]:
|
for level_mode in ["none", "mean", "plane"]:
|
||||||
result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude")
|
result, _, _, _ = node.process(field, windowing="hann", level=level_mode)
|
||||||
save_field(result, f"02_surface_fft_level_{level_mode}")
|
save_field(result, f"02_surface_fft_level_{level_mode}")
|
||||||
|
|
||||||
# --- Test 3: Checkerboard pattern ---
|
# --- Test 3: Checkerboard pattern ---
|
||||||
@@ -67,7 +68,7 @@ def main():
|
|||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
save_field(field, "03_checker_input")
|
save_field(field, "03_checker_input")
|
||||||
|
|
||||||
result, = node.process(field, windowing="none", level="mean", output="log_magnitude")
|
result, _, _, _ = node.process(field, windowing="none", level="mean")
|
||||||
save_field(result, "03_checker_fft")
|
save_field(result, "03_checker_fft")
|
||||||
|
|
||||||
# --- Test 4: Concentric rings (radial frequency) ---
|
# --- Test 4: Concentric rings (radial frequency) ---
|
||||||
@@ -77,7 +78,7 @@ def main():
|
|||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
save_field(field, "04_rings_input")
|
save_field(field, "04_rings_input")
|
||||||
|
|
||||||
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
result, _, _, _ = node.process(field, windowing="hann", level="mean")
|
||||||
save_field(result, "04_rings_fft")
|
save_field(result, "04_rings_fft")
|
||||||
|
|
||||||
# --- Test 5: Compare windowing effects ---
|
# --- Test 5: Compare windowing effects ---
|
||||||
@@ -87,7 +88,7 @@ def main():
|
|||||||
save_field(field, "05_window_input")
|
save_field(field, "05_window_input")
|
||||||
|
|
||||||
for win in ["none", "hann", "hamming", "blackman"]:
|
for win in ["none", "hann", "hamming", "blackman"]:
|
||||||
result, = node.process(field, windowing=win, level="mean", output="log_magnitude")
|
result, _, _, _ = node.process(field, windowing=win, level="mean")
|
||||||
save_field(result, f"05_window_{win}")
|
save_field(result, f"05_window_{win}")
|
||||||
|
|
||||||
print(f"\nAll outputs saved to {OUT_DIR}/")
|
print(f"\nAll outputs saved to {OUT_DIR}/")
|
||||||
|
|||||||
@@ -1229,7 +1229,7 @@ def test_fft2d():
|
|||||||
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||||
|
|
||||||
# log_magnitude
|
# log_magnitude
|
||||||
spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude")
|
spectrum, spec_mag, spec_phase, spec_psdf = node.process(field, windowing="none", level="none")
|
||||||
assert spectrum.data.shape == (N, N)
|
assert spectrum.data.shape == (N, N)
|
||||||
assert spectrum.domain == "frequency"
|
assert spectrum.domain == "frequency"
|
||||||
assert spectrum.si_unit_xy == "1/m"
|
assert spectrum.si_unit_xy == "1/m"
|
||||||
@@ -1240,31 +1240,31 @@ def test_fft2d():
|
|||||||
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
|
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
|
||||||
|
|
||||||
# magnitude output
|
# magnitude output
|
||||||
spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude")
|
_, spec_mag, _, _ = node.process(field, windowing="hann", level="mean")
|
||||||
assert spec_mag.data.shape == (N, N)
|
assert spec_mag.data.shape == (N, N)
|
||||||
assert np.all(spec_mag.data >= 0)
|
assert np.all(spec_mag.data >= 0)
|
||||||
|
|
||||||
# phase output
|
# phase output
|
||||||
spec_phase, = node.process(field, windowing="none", level="none", output="phase")
|
_, _, spec_phase, _ = node.process(field, windowing="none", level="none")
|
||||||
assert spec_phase.data.shape == (N, N)
|
assert spec_phase.data.shape == (N, N)
|
||||||
assert spec_phase.data.min() >= -np.pi - 0.01
|
assert spec_phase.data.min() >= -np.pi - 0.01
|
||||||
assert spec_phase.data.max() <= np.pi + 0.01
|
assert spec_phase.data.max() <= np.pi + 0.01
|
||||||
|
|
||||||
# psdf output — units should reflect PSDF calibration
|
# psdf output — units should reflect PSDF calibration
|
||||||
spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf")
|
_, _, _, spec_psdf = node.process(field, windowing="hamming", level="plane")
|
||||||
assert spec_psdf.data.shape == (N, N)
|
assert spec_psdf.data.shape == (N, N)
|
||||||
assert np.all(spec_psdf.data >= 0)
|
assert np.all(spec_psdf.data >= 0)
|
||||||
assert "^2" in spec_psdf.si_unit_z
|
assert "^2" in spec_psdf.si_unit_z
|
||||||
|
|
||||||
# Constant field should have all energy at DC
|
# Constant field should have all energy at DC
|
||||||
const_field = make_field(data=np.ones((32, 32)) * 3.0)
|
const_field = make_field(data=np.ones((32, 32)) * 3.0)
|
||||||
spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude")
|
_, spec_const, _, _ = node.process(const_field, windowing="none", level="none")
|
||||||
centre32 = 16
|
centre32 = 16
|
||||||
dc_val = spec_const.data[centre32, centre32]
|
dc_val = spec_const.data[centre32, centre32]
|
||||||
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
|
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
|
||||||
|
|
||||||
# Blackman windowing should also work without error
|
# Blackman windowing should also work without error
|
||||||
spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude")
|
spec_bk, _, _, _ = node.process(field, windowing="blackman", level="none")
|
||||||
assert spec_bk.data.shape == (N, N)
|
assert spec_bk.data.shape == (N, N)
|
||||||
|
|
||||||
print(" PASS\n")
|
print(" PASS\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user