initial commit
0
tests/__init__.py
Normal file
BIN
tests/output/01_sines_input.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
tests/output/01_sines_log_magnitude.png
Normal file
|
After Width: | Height: | Size: 981 B |
BIN
tests/output/01_sines_magnitude.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
BIN
tests/output/01_sines_psdf.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
BIN
tests/output/02_surface_fft_level_mean.png
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
tests/output/02_surface_fft_level_none.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
tests/output/02_surface_fft_level_plane.png
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
tests/output/02_surface_input.png
Normal file
|
After Width: | Height: | Size: 70 KiB |
BIN
tests/output/03_checker_fft.png
Normal file
|
After Width: | Height: | Size: 8.1 KiB |
BIN
tests/output/03_checker_input.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
BIN
tests/output/04_rings_fft.png
Normal file
|
After Width: | Height: | Size: 23 KiB |
BIN
tests/output/04_rings_input.png
Normal file
|
After Width: | Height: | Size: 144 KiB |
BIN
tests/output/05_window_blackman.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
tests/output/05_window_hamming.png
Normal file
|
After Width: | Height: | Size: 2.4 KiB |
BIN
tests/output/05_window_hann.png
Normal file
|
After Width: | Height: | Size: 2.0 KiB |
BIN
tests/output/05_window_input.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
tests/output/05_window_none.png
Normal file
|
After Width: | Height: | Size: 1.7 KiB |
258
tests/test_fft.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Test the FFT2D node against known inputs and Gwyddion-equivalent results.
|
||||
|
||||
Run from project root:
|
||||
python -m tests.test_fft
|
||||
"""
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
from backend.data_types import DataField
|
||||
from backend.nodes.analysis import FFT2D
|
||||
|
||||
|
||||
def make_field(data, xreal=1e-6, yreal=1e-6):
|
||||
"""Create a DataField from a 2D array."""
|
||||
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
|
||||
|
||||
|
||||
def test_dc_removal():
|
||||
"""A constant image should produce near-zero FFT after mean subtraction."""
|
||||
print("=== Test: DC removal ===")
|
||||
data = np.ones((64, 64)) * 42.0
|
||||
field = make_field(data)
|
||||
node = FFT2D()
|
||||
|
||||
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||
peak = result.data.max()
|
||||
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
|
||||
assert peak < 1e-10, f"Expected ~0, got {peak}"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_single_frequency():
|
||||
"""A pure sine wave should produce two peaks at the known frequency."""
|
||||
print("=== Test: Single frequency detection ===")
|
||||
N = 128
|
||||
xreal = 1e-6 # 1 micron
|
||||
freq_cycles = 10 # 10 cycles across the image
|
||||
|
||||
x = np.linspace(0, 1, N, endpoint=False)
|
||||
data = np.sin(2 * np.pi * freq_cycles * x)[np.newaxis, :] * np.ones((N, 1))
|
||||
field = make_field(data, xreal=xreal, yreal=xreal)
|
||||
|
||||
node = FFT2D()
|
||||
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||
|
||||
# The peak should be at column offset = freq_cycles from center
|
||||
mag = result.data
|
||||
cy, cx = N // 2, N // 2 # center (DC)
|
||||
|
||||
# Find the peak location (excluding DC which should be ~0 after mean sub)
|
||||
mag_copy = mag.copy()
|
||||
mag_copy[cy, cx] = 0
|
||||
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
|
||||
peak_col_offset = abs(peak_idx[1] - cx)
|
||||
|
||||
print(f" Image: {N}x{N}, {freq_cycles} horizontal cycles")
|
||||
print(f" Expected peak at column offset {freq_cycles} from center")
|
||||
print(f" Found peak at {peak_idx} (offset {peak_col_offset})")
|
||||
print(f" DC value: {mag[cy, cx]:.2e}")
|
||||
print(f" Peak value: {mag[peak_idx]:.2e}")
|
||||
assert peak_col_offset == freq_cycles, f"Expected offset {freq_cycles}, got {peak_col_offset}"
|
||||
assert peak_idx[0] == cy, f"Expected peak on center row, got row {peak_idx[0]}"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_2d_frequency():
|
||||
"""A 2D sine should produce peaks at the correct (kx, ky) position."""
|
||||
print("=== Test: 2D frequency detection ===")
|
||||
N = 128
|
||||
fx, fy = 8, 5 # cycles in x and y
|
||||
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
data = np.sin(2 * np.pi * (fx * x + fy * y))
|
||||
field = make_field(data)
|
||||
|
||||
node = FFT2D()
|
||||
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||
mag = result.data
|
||||
|
||||
cy, cx = N // 2, N // 2
|
||||
mag_copy = mag.copy()
|
||||
mag_copy[cy, cx] = 0
|
||||
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
|
||||
|
||||
dx = abs(peak_idx[1] - cx)
|
||||
dy = abs(peak_idx[0] - cy)
|
||||
|
||||
print(f" Input: sin(2π({fx}x + {fy}y))")
|
||||
print(f" Expected peak offset: ({fy}, {fx}) from center")
|
||||
print(f" Found peak at {peak_idx} (offset dy={dy}, dx={dx})")
|
||||
assert dx == fx and dy == fy, f"Expected ({fy},{fx}), got ({dy},{dx})"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_psdf_normalization():
|
||||
"""
|
||||
PSDF of white noise should integrate to the variance.
|
||||
|
||||
Parseval's theorem: sum of PSDF * dk_x * dk_y ≈ variance of the signal.
|
||||
"""
|
||||
print("=== Test: PSDF normalization (Parseval) ===")
|
||||
N = 256
|
||||
xreal = 1e-6
|
||||
rng = np.random.default_rng(42)
|
||||
data = rng.standard_normal((N, N))
|
||||
variance = data.var()
|
||||
|
||||
field = make_field(data, xreal=xreal, yreal=xreal)
|
||||
node = FFT2D()
|
||||
|
||||
result, = node.process(field, windowing="none", level="none", output="psdf")
|
||||
psdf = result.data
|
||||
|
||||
# Integrate: sum of PSDF * dk_x * dk_y
|
||||
# Our output field has xreal = 2π*N/xreal (angular freq range)
|
||||
dk_x = result.xreal / N
|
||||
dk_y = result.yreal / N
|
||||
integral = psdf.sum() * dk_x * dk_y
|
||||
|
||||
ratio = integral / variance
|
||||
print(f" Signal variance: {variance:.6f}")
|
||||
print(f" PSDF integral: {integral:.6f}")
|
||||
print(f" Ratio (should be ~1.0): {ratio:.4f}")
|
||||
# Allow 20% tolerance for finite-size effects
|
||||
assert 0.8 < ratio < 1.2, f"Parseval's theorem violated: ratio = {ratio}"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_windowing_reduces_leakage():
|
||||
"""Windowing should reduce spectral leakage from a non-integer frequency."""
|
||||
print("=== Test: Windowing reduces leakage ===")
|
||||
N = 128
|
||||
freq = 10.5 # non-integer → spectral leakage without windowing
|
||||
|
||||
x = np.linspace(0, 1, N, endpoint=False)
|
||||
data = np.sin(2 * np.pi * freq * x)[np.newaxis, :] * np.ones((N, 1))
|
||||
field = make_field(data)
|
||||
|
||||
node = FFT2D()
|
||||
|
||||
# Without windowing
|
||||
r_none, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||
mag_none = r_none.data[N // 2, :] # center row
|
||||
|
||||
# With Hann windowing
|
||||
r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude")
|
||||
mag_hann = r_hann.data[N // 2, :]
|
||||
|
||||
# Measure leakage: ratio of energy far from peak vs total
|
||||
peak_col = np.argmax(mag_none)
|
||||
far_mask = np.ones(N, dtype=bool)
|
||||
far_mask[max(0, peak_col - 3):peak_col + 4] = False
|
||||
# Also mask the symmetric peak
|
||||
sym_col = N - peak_col
|
||||
far_mask[max(0, sym_col - 3):sym_col + 4] = False
|
||||
|
||||
leakage_none = mag_none[far_mask].sum() / mag_none.sum()
|
||||
leakage_hann = mag_hann[far_mask].sum() / mag_hann.sum()
|
||||
|
||||
print(f" Non-integer frequency: {freq}")
|
||||
print(f" Leakage without windowing: {leakage_none:.4f}")
|
||||
print(f" Leakage with Hann window: {leakage_hann:.4f}")
|
||||
assert leakage_hann < leakage_none, "Hann window should reduce leakage"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_plane_subtraction():
|
||||
"""Plane subtraction should remove linear gradients."""
|
||||
print("=== Test: Plane subtraction ===")
|
||||
N = 64
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
# Tilted plane + sine wave
|
||||
data = 100 * x + 50 * y + np.sin(2 * np.pi * 8 * x)
|
||||
field = make_field(data)
|
||||
|
||||
node = FFT2D()
|
||||
|
||||
# Without leveling — huge DC and low-freq energy
|
||||
r_none, = node.process(field, windowing="none", level="none", output="magnitude")
|
||||
dc_none = r_none.data[N // 2, N // 2]
|
||||
|
||||
# With mean subtraction — DC removed but gradient leaks
|
||||
r_mean, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||
dc_mean = r_mean.data[N // 2, N // 2]
|
||||
|
||||
# With plane subtraction — gradient removed
|
||||
r_plane, = node.process(field, windowing="none", level="plane", output="magnitude")
|
||||
dc_plane = r_plane.data[N // 2, N // 2]
|
||||
|
||||
# With plane subtraction, check the low-freq energy near DC is reduced
|
||||
# (plane subtraction removes gradients that leak into low frequencies)
|
||||
r = 3 # radius around DC to check
|
||||
cy, cx = N // 2, N // 2
|
||||
lowfreq_none = r_none.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
|
||||
lowfreq_plane = r_plane.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
|
||||
|
||||
print(f" DC magnitude (no leveling): {dc_none:.2e}")
|
||||
print(f" DC magnitude (mean subtract): {dc_mean:.2e}")
|
||||
print(f" DC magnitude (plane subtract): {dc_plane:.2e}")
|
||||
print(f" Low-freq energy (no level): {lowfreq_none:.2e}")
|
||||
print(f" Low-freq energy (plane sub): {lowfreq_plane:.2e}")
|
||||
assert dc_mean < dc_none, "Mean subtraction should reduce DC"
|
||||
assert lowfreq_plane < lowfreq_none * 0.01, "Plane subtraction should reduce low-freq energy"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_non_square():
|
||||
"""FFT should work on non-square, non-power-of-2 images."""
|
||||
print("=== Test: Non-square image ===")
|
||||
data = np.random.default_rng(99).standard_normal((100, 150))
|
||||
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
|
||||
node = FFT2D()
|
||||
|
||||
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
||||
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
|
||||
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
|
||||
print(f" Shape: {result.data.shape}")
|
||||
print(f" Output range: [{result.data.min():.4f}, {result.data.max():.4f}]")
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_log_magnitude_visual_range():
|
||||
"""Log magnitude should produce a reasonable dynamic range for display."""
|
||||
print("=== Test: Log magnitude visual range ===")
|
||||
N = 128
|
||||
x = np.linspace(0, 1, N, endpoint=False)
|
||||
# Multi-frequency test image
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
data = (np.sin(2 * np.pi * 5 * x) +
|
||||
0.5 * np.sin(2 * np.pi * 15 * x + 2 * np.pi * 10 * y) +
|
||||
0.1 * np.random.default_rng(7).standard_normal((N, N)))
|
||||
field = make_field(data)
|
||||
|
||||
node = FFT2D()
|
||||
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
||||
|
||||
vmin, vmax = result.data.min(), result.data.max()
|
||||
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
|
||||
|
||||
print(f" Log magnitude range: [{vmin:.4f}, {vmax:.4f}]")
|
||||
print(f" Dynamic range: {dynamic_range:.2f}")
|
||||
assert vmax > vmin, "Log magnitude should have nonzero range"
|
||||
assert np.all(np.isfinite(result.data)), "Non-finite values in log magnitude"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dc_removal()
|
||||
test_single_frequency()
|
||||
test_2d_frequency()
|
||||
test_psdf_normalization()
|
||||
test_windowing_reduces_leakage()
|
||||
test_plane_subtraction()
|
||||
test_non_square()
|
||||
test_log_magnitude_visual_range()
|
||||
print("All tests passed!")
|
||||
97
tests/test_fft_visual.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Generate test images and their FFT outputs for visual comparison with Gwyddion.
|
||||
Saves PNG files to tests/output/.
|
||||
|
||||
Run: .venv/bin/python -m tests.test_fft_visual
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
from backend.data_types import DataField, datafield_to_uint8, encode_preview
|
||||
from backend.nodes.analysis import FFT2D
|
||||
|
||||
OUT_DIR = os.path.join(os.path.dirname(__file__), "output")
|
||||
os.makedirs(OUT_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def save_field(field, name, colormap="viridis"):
|
||||
"""Save a DataField as a PNG for visual inspection."""
|
||||
from PIL import Image
|
||||
arr = datafield_to_uint8(field, colormap)
|
||||
img = Image.fromarray(arr)
|
||||
path = os.path.join(OUT_DIR, f"{name}.png")
|
||||
img.save(path)
|
||||
print(f" Saved {path} (range: [{field.data.min():.4g}, {field.data.max():.4g}])")
|
||||
|
||||
|
||||
def make_field(data, xreal=1e-6, yreal=1e-6):
|
||||
return DataField(data=data, xreal=xreal, yreal=yreal)
|
||||
|
||||
|
||||
def main():
|
||||
node = FFT2D()
|
||||
N = 256
|
||||
|
||||
# --- Test 1: Multi-frequency sine waves ---
|
||||
print("Test 1: Multi-frequency sine waves")
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
data = (np.sin(2 * np.pi * 10 * x)
|
||||
+ 0.7 * np.sin(2 * np.pi * 25 * y)
|
||||
+ 0.3 * np.sin(2 * np.pi * (15 * x + 8 * y)))
|
||||
field = make_field(data)
|
||||
save_field(field, "01_sines_input")
|
||||
|
||||
for output_mode in ["log_magnitude", "magnitude", "psdf"]:
|
||||
result, = node.process(field, windowing="hann", level="mean", output=output_mode)
|
||||
save_field(result, f"01_sines_{output_mode}")
|
||||
|
||||
# --- Test 2: Real-world-like surface with noise + tilt ---
|
||||
print("\nTest 2: Tilted surface with features")
|
||||
rng = np.random.default_rng(42)
|
||||
data = (50 * x + 30 * y # tilt
|
||||
+ np.sin(2 * np.pi * 20 * x) # periodic feature
|
||||
+ 0.5 * rng.standard_normal((N, N))) # noise
|
||||
field = make_field(data)
|
||||
save_field(field, "02_surface_input")
|
||||
|
||||
for level_mode in ["none", "mean", "plane"]:
|
||||
result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude")
|
||||
save_field(result, f"02_surface_fft_level_{level_mode}")
|
||||
|
||||
# --- Test 3: Checkerboard pattern ---
|
||||
print("\nTest 3: Checkerboard")
|
||||
freq = 16
|
||||
data = np.sign(np.sin(2 * np.pi * freq * x) * np.sin(2 * np.pi * freq * y))
|
||||
field = make_field(data)
|
||||
save_field(field, "03_checker_input")
|
||||
|
||||
result, = node.process(field, windowing="none", level="mean", output="log_magnitude")
|
||||
save_field(result, "03_checker_fft")
|
||||
|
||||
# --- Test 4: Concentric rings (radial frequency) ---
|
||||
print("\nTest 4: Concentric rings")
|
||||
r = np.sqrt((x - 0.5)**2 + (y - 0.5)**2)
|
||||
data = np.sin(2 * np.pi * 30 * r)
|
||||
field = make_field(data)
|
||||
save_field(field, "04_rings_input")
|
||||
|
||||
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
||||
save_field(result, "04_rings_fft")
|
||||
|
||||
# --- Test 5: Compare windowing effects ---
|
||||
print("\nTest 5: Windowing comparison")
|
||||
data = np.sin(2 * np.pi * 10.5 * x) + 0.5 * np.sin(2 * np.pi * 30.3 * y)
|
||||
field = make_field(data)
|
||||
save_field(field, "05_window_input")
|
||||
|
||||
for win in ["none", "hann", "hamming", "blackman"]:
|
||||
result, = node.process(field, windowing=win, level="mean", output="log_magnitude")
|
||||
save_field(result, f"05_window_{win}")
|
||||
|
||||
print(f"\nAll outputs saved to {OUT_DIR}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
488
tests/test_nodes.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Tests for all argonode backend nodes (excluding FFT2D which has its own test file).
|
||||
|
||||
Run from project root:
|
||||
.venv/bin/python -m tests.test_nodes
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, ".")
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
|
||||
"""Create a DataField, optionally from given data or a random field."""
|
||||
if data is None:
|
||||
data = np.random.default_rng(42).standard_normal(shape)
|
||||
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Filters
|
||||
# =========================================================================
|
||||
|
||||
def test_gaussian_filter():
|
||||
print("=== Test: GaussianFilter ===")
|
||||
from backend.nodes.filters import GaussianFilter
|
||||
node = GaussianFilter()
|
||||
field = make_field()
|
||||
|
||||
result, = node.process(field, sigma=2.0)
|
||||
assert result.data.shape == field.data.shape
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
# Gaussian blur should reduce variance
|
||||
assert result.data.std() < field.data.std()
|
||||
# With very small sigma, output should be nearly unchanged
|
||||
result_tiny, = node.process(field, sigma=0.01)
|
||||
assert np.allclose(result_tiny.data, field.data, atol=1e-6)
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_median_filter():
|
||||
print("=== Test: MedianFilter ===")
|
||||
from backend.nodes.filters import MedianFilter
|
||||
node = MedianFilter()
|
||||
|
||||
# Median filter should remove salt-and-pepper noise
|
||||
data = np.zeros((64, 64))
|
||||
rng = np.random.default_rng(7)
|
||||
noise_idx = rng.choice(64 * 64, size=100, replace=False)
|
||||
data.ravel()[noise_idx] = 1.0
|
||||
field = make_field(data=data)
|
||||
|
||||
result, = node.process(field, size=3)
|
||||
assert result.data.shape == field.data.shape
|
||||
# Should remove most impulse noise
|
||||
assert result.data.sum() < field.data.sum()
|
||||
# Size=1 should be identity
|
||||
result_1, = node.process(field, size=1)
|
||||
assert np.array_equal(result_1.data, field.data)
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_edge_detect():
|
||||
print("=== Test: EdgeDetect ===")
|
||||
from backend.nodes.filters import EdgeDetect
|
||||
node = EdgeDetect()
|
||||
|
||||
# Create an image with a sharp vertical edge
|
||||
data = np.zeros((64, 64))
|
||||
data[:, 32:] = 1.0
|
||||
field = make_field(data=data)
|
||||
|
||||
for method in ["sobel", "prewitt", "laplacian", "log"]:
|
||||
result, = node.process(field, method=method, sigma=1.0)
|
||||
assert result.data.shape == field.data.shape
|
||||
# Edge response should be strongest near column 32
|
||||
col_energy = np.abs(result.data).sum(axis=0)
|
||||
peak_col = np.argmax(col_energy)
|
||||
assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32"
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Level
|
||||
# =========================================================================
|
||||
|
||||
def test_plane_level():
|
||||
print("=== Test: PlaneLevelField ===")
|
||||
from backend.nodes.level import PlaneLevelField
|
||||
node = PlaneLevelField()
|
||||
|
||||
# Create a tilted plane + small signal
|
||||
N = 64
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
signal = np.sin(2 * np.pi * 5 * x)
|
||||
data = 100 * x + 50 * y + signal
|
||||
field = make_field(data=data)
|
||||
|
||||
result, = node.process(field)
|
||||
assert result.data.shape == field.data.shape
|
||||
# After plane leveling, mean should be near zero
|
||||
assert abs(result.data.mean()) < 1e-10
|
||||
# The signal should remain (correlation with original sine)
|
||||
corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1]
|
||||
assert corr > 0.98, f"Signal correlation after leveling: {corr}"
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_poly_level():
|
||||
print("=== Test: PolyLevelField ===")
|
||||
from backend.nodes.level import PolyLevelField
|
||||
node = PolyLevelField()
|
||||
|
||||
N = 64
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
# Quadratic background + signal
|
||||
background = 50 * x**2 + 30 * y**2 + 10 * x * y
|
||||
signal = np.sin(2 * np.pi * 8 * x)
|
||||
data = background + signal
|
||||
field = make_field(data=data)
|
||||
|
||||
leveled, bg = node.process(field, degree_x=2, degree_y=2)
|
||||
assert leveled.data.shape == field.data.shape
|
||||
assert bg.data.shape == field.data.shape
|
||||
# leveled + bg should reconstruct original
|
||||
assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10)
|
||||
# Signal should be preserved after leveling
|
||||
corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1]
|
||||
assert corr > 0.95, f"Signal correlation after poly leveling: {corr}"
|
||||
# Degree 0 should just subtract the mean
|
||||
leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0)
|
||||
assert abs(leveled_0.data.mean()) < 1e-10
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_fix_zero():
|
||||
print("=== Test: FixZero ===")
|
||||
from backend.nodes.level import FixZero
|
||||
node = FixZero()
|
||||
field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))
|
||||
|
||||
result_min, = node.process(field, method="min")
|
||||
assert result_min.data.min() == 0.0
|
||||
assert result_min.data.max() == 30.0
|
||||
|
||||
result_mean, = node.process(field, method="mean")
|
||||
assert abs(result_mean.data.mean()) < 1e-10
|
||||
|
||||
result_median, = node.process(field, method="median")
|
||||
assert abs(np.median(result_median.data)) < 1e-10
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis (non-FFT)
|
||||
# =========================================================================
|
||||
|
||||
def test_statistics():
|
||||
print("=== Test: StatisticsNode ===")
|
||||
from backend.nodes.analysis import StatisticsNode
|
||||
node = StatisticsNode()
|
||||
|
||||
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
|
||||
field = make_field(data=data)
|
||||
|
||||
table, = node.process(field)
|
||||
stats = {row["quantity"]: row["value"] for row in table}
|
||||
|
||||
assert stats["min"] == 1.0
|
||||
assert stats["max"] == 4.0
|
||||
assert stats["mean"] == 2.5
|
||||
assert stats["median"] == 2.5
|
||||
assert stats["range"] == 3.0
|
||||
# RMS = sqrt(mean((x - mean)^2))
|
||||
expected_rms = np.sqrt(np.mean((data - 2.5) ** 2))
|
||||
assert abs(stats["RMS"] - expected_rms) < 1e-10
|
||||
|
||||
# Constant data should have RMS=0, skewness=0, kurtosis=0
|
||||
const_field = make_field(data=np.ones((4, 4)) * 5.0)
|
||||
table_const, = node.process(const_field)
|
||||
const_stats = {row["quantity"]: row["value"] for row in table_const}
|
||||
assert const_stats["RMS"] == 0.0
|
||||
assert const_stats["skewness"] == 0.0
|
||||
assert const_stats["kurtosis"] == 0.0
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_height_histogram():
|
||||
print("=== Test: HeightHistogram ===")
|
||||
from backend.nodes.analysis import HeightHistogram
|
||||
node = HeightHistogram()
|
||||
|
||||
# Uniform data should give a roughly flat histogram
|
||||
data = np.linspace(0, 1, 1000).reshape(25, 40)
|
||||
field = make_field(data=data)
|
||||
|
||||
counts, bin_centers = node.process(field, n_bins=10)
|
||||
assert len(counts) == 10
|
||||
assert len(bin_centers) == 10
|
||||
assert counts.dtype == np.float64
|
||||
# Total counts should equal number of pixels
|
||||
assert counts.sum() == 1000
|
||||
# For uniform data, each bin should have ~100 counts
|
||||
assert np.std(counts) < 10, f"Histogram not flat enough: std={np.std(counts)}"
|
||||
# Bin centers should span the data range
|
||||
assert bin_centers[0] > 0.0
|
||||
assert bin_centers[-1] < 1.0
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_cross_section():
|
||||
print("=== Test: CrossSection ===")
|
||||
from backend.nodes.analysis import CrossSection
|
||||
node = CrossSection()
|
||||
|
||||
# Create a field with a known horizontal gradient
|
||||
N = 100
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
data = x * 10.0 # value = 10 * x_fraction
|
||||
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
# Horizontal cross section at y=0.5
|
||||
(profile,) = node.process(
|
||||
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
|
||||
extend="none", n_samples=100,
|
||||
)
|
||||
assert len(profile) == 100
|
||||
# Profile should be a linear ramp from ~0 to ~10
|
||||
assert profile[0] < 0.5, f"Start of profile: {profile[0]}"
|
||||
assert profile[-1] > 9.5, f"End of profile: {profile[-1]}"
|
||||
|
||||
# n_samples=0 should auto-calculate
|
||||
(profile_auto,) = node.process(
|
||||
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
|
||||
extend="none", n_samples=0,
|
||||
)
|
||||
assert len(profile_auto) >= 2
|
||||
|
||||
# Test extend to edges — a short segment should be extended
|
||||
(profile_ext,) = node.process(
|
||||
field, x1=0.3, y1=0.5, x2=0.7, y2=0.5,
|
||||
extend="to_edges", n_samples=100,
|
||||
)
|
||||
# Extended profile should start near 0 and end near 10
|
||||
assert profile_ext[0] < 0.5
|
||||
assert profile_ext[-1] > 9.5
|
||||
|
||||
# Diagonal cross section
|
||||
(profile_diag,) = node.process(
|
||||
field, x1=0.0, y1=0.0, x2=1.0, y2=1.0,
|
||||
extend="none", n_samples=50,
|
||||
)
|
||||
assert len(profile_diag) == 50
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Grains
|
||||
# =========================================================================
|
||||
|
||||
def test_threshold_mask():
|
||||
print("=== Test: ThresholdMask ===")
|
||||
from backend.nodes.grains import ThresholdMask
|
||||
node = ThresholdMask()
|
||||
|
||||
# Clear bimodal data: left half = 0, right half = 1
|
||||
data = np.zeros((64, 64))
|
||||
data[:, 32:] = 1.0
|
||||
field = make_field(data=data)
|
||||
|
||||
# Absolute threshold at 0.5
|
||||
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
|
||||
assert mask.dtype == np.uint8
|
||||
assert mask.shape == (64, 64)
|
||||
assert np.all(mask[:, :32] == 0)
|
||||
assert np.all(mask[:, 32:] == 255)
|
||||
|
||||
# Direction "below"
|
||||
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
|
||||
assert np.all(mask_below[:, :32] == 255)
|
||||
assert np.all(mask_below[:, 32:] == 0)
|
||||
|
||||
# Relative threshold at 0.5 (midpoint of range)
|
||||
mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||
assert np.all(mask_rel[:, 32:] == 255)
|
||||
|
||||
# Otsu should find the bimodal threshold
|
||||
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_grain_analysis():
|
||||
print("=== Test: GrainAnalysis ===")
|
||||
from backend.nodes.grains import GrainAnalysis
|
||||
node = GrainAnalysis()
|
||||
|
||||
# Create a field with two distinct "grains"
|
||||
N = 64
|
||||
data = np.zeros((N, N))
|
||||
# Grain 1: 10x10 block at top-left with height 5
|
||||
data[5:15, 5:15] = 5.0
|
||||
# Grain 2: 8x8 block at bottom-right with height 3
|
||||
data[45:53, 45:53] = 3.0
|
||||
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
# Create matching mask
|
||||
mask = np.zeros((N, N), dtype=np.uint8)
|
||||
mask[5:15, 5:15] = 255
|
||||
mask[45:53, 45:53] = 255
|
||||
|
||||
table, = node.process(field, mask=mask, min_size=10)
|
||||
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
|
||||
|
||||
# Sort by area descending
|
||||
table.sort(key=lambda r: r["area_px"], reverse=True)
|
||||
assert table[0]["area_px"] == 100 # 10x10
|
||||
assert table[1]["area_px"] == 64 # 8x8
|
||||
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
|
||||
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
|
||||
|
||||
# min_size filtering: only keep grains >= 80 px
|
||||
table_filtered, = node.process(field, mask=mask, min_size=80)
|
||||
assert len(table_filtered) == 1
|
||||
assert table_filtered[0]["area_px"] == 100
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O
|
||||
# =========================================================================
|
||||
|
||||
def test_load_image():
|
||||
print("=== Test: LoadImage ===")
|
||||
from backend.nodes.io import LoadImage
|
||||
from PIL import Image
|
||||
node = LoadImage()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Test loading a grayscale PNG
|
||||
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
|
||||
img = Image.fromarray(arr, mode="L")
|
||||
path = os.path.join(tmpdir, "test_gray.png")
|
||||
img.save(path)
|
||||
|
||||
image, field = node.load(filename=path)
|
||||
assert image.shape == (48, 64)
|
||||
assert field.data.shape == (48, 64)
|
||||
assert field.data.dtype == np.float64
|
||||
|
||||
# Test loading an RGB PNG (should average to grayscale for field)
|
||||
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
|
||||
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
|
||||
path_rgb = os.path.join(tmpdir, "test_rgb.png")
|
||||
img_rgb.save(path_rgb)
|
||||
|
||||
image_rgb, field_rgb = node.load(filename=path_rgb)
|
||||
assert image_rgb.shape == (32, 32, 3)
|
||||
assert field_rgb.data.shape == (32, 32)
|
||||
|
||||
# Test loading a .npy file
|
||||
data_npy = np.random.default_rng(3).standard_normal((50, 60))
|
||||
path_npy = os.path.join(tmpdir, "test.npy")
|
||||
np.save(path_npy, data_npy)
|
||||
|
||||
image_npy, field_npy = node.load(filename=path_npy)
|
||||
assert np.allclose(field_npy.data, data_npy)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_save_image():
|
||||
print("=== Test: SaveImage ===")
|
||||
from backend.nodes.io import SaveImage
|
||||
node = SaveImage()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Monkey-patch OUTPUT_DIR for testing
|
||||
from pathlib import Path
|
||||
import backend.nodes.io as io_mod
|
||||
orig_dir = io_mod.OUTPUT_DIR
|
||||
io_mod.OUTPUT_DIR = Path(tmpdir)
|
||||
|
||||
try:
|
||||
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)
|
||||
|
||||
# Save as PNG
|
||||
node.save(image=arr, filename_prefix="test", format="PNG")
|
||||
saved = os.listdir(tmpdir)
|
||||
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}"
|
||||
|
||||
# Save as NPY
|
||||
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
|
||||
saved = os.listdir(tmpdir)
|
||||
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
|
||||
finally:
|
||||
io_mod.OUTPUT_DIR = orig_dir
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Display (limited testing — these are output nodes with WS callbacks)
|
||||
# =========================================================================
|
||||
|
||||
def test_preview_image():
|
||||
print("=== Test: PreviewImage ===")
|
||||
from backend.nodes.display import PreviewImage
|
||||
node = PreviewImage()
|
||||
|
||||
# Set up a capture for the broadcast
|
||||
captured = []
|
||||
PreviewImage._broadcast_fn = lambda node_id, data_uri: captured.append(data_uri)
|
||||
PreviewImage._current_node_id = "test"
|
||||
|
||||
# Preview with a DataField
|
||||
field = make_field()
|
||||
node.preview(colormap="viridis", field=field)
|
||||
assert len(captured) == 1
|
||||
assert captured[0].startswith("data:image/png;base64,")
|
||||
|
||||
# Preview with an IMAGE array
|
||||
captured.clear()
|
||||
arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
|
||||
node.preview(colormap="gray", image=arr)
|
||||
assert len(captured) == 1
|
||||
|
||||
# Clean up
|
||||
PreviewImage._broadcast_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_print_table():
|
||||
print("=== Test: PrintTable ===")
|
||||
from backend.nodes.display import PrintTable
|
||||
node = PrintTable()
|
||||
|
||||
captured = []
|
||||
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
|
||||
PrintTable._current_node_id = "test"
|
||||
|
||||
table = [{"quantity": "test", "value": 42.0, "unit": "m"}]
|
||||
node.print_table(table=table)
|
||||
assert len(captured) == 1
|
||||
assert captured[0] == table
|
||||
|
||||
PrintTable._broadcast_table_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Run all tests
|
||||
# =========================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Filters
|
||||
test_gaussian_filter()
|
||||
test_median_filter()
|
||||
test_edge_detect()
|
||||
|
||||
# Level
|
||||
test_plane_level()
|
||||
test_poly_level()
|
||||
test_fix_zero()
|
||||
|
||||
# Analysis
|
||||
test_statistics()
|
||||
test_height_histogram()
|
||||
test_cross_section()
|
||||
|
||||
# Grains
|
||||
test_threshold_mask()
|
||||
test_grain_analysis()
|
||||
|
||||
# I/O
|
||||
test_load_image()
|
||||
test_save_image()
|
||||
|
||||
# Display
|
||||
test_preview_image()
|
||||
test_print_table()
|
||||
|
||||
print("All tests passed!")
|
||||