initial commit

This commit is contained in:
2026-03-23 00:35:30 -07:00
parent 5ecc913e28
commit 87b6905fba
48 changed files with 7012 additions and 1 deletions

0
tests/__init__.py Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 981 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

258
tests/test_fft.py Normal file
View File

@@ -0,0 +1,258 @@
"""
Test the FFT2D node against known inputs and Gwyddion-equivalent results.
Run from project root:
python -m tests.test_fft
"""
import sys
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField
from backend.nodes.analysis import FFT2D
def make_field(data, xreal=1e-6, yreal=1e-6):
"""Create a DataField from a 2D array."""
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
def test_dc_removal():
"""A constant image should produce near-zero FFT after mean subtraction."""
print("=== Test: DC removal ===")
data = np.ones((64, 64)) * 42.0
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
peak = result.data.max()
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
assert peak < 1e-10, f"Expected ~0, got {peak}"
print(" PASS\n")
def test_single_frequency():
"""A pure sine wave should produce two peaks at the known frequency."""
print("=== Test: Single frequency detection ===")
N = 128
xreal = 1e-6 # 1 micron
freq_cycles = 10 # 10 cycles across the image
x = np.linspace(0, 1, N, endpoint=False)
data = np.sin(2 * np.pi * freq_cycles * x)[np.newaxis, :] * np.ones((N, 1))
field = make_field(data, xreal=xreal, yreal=xreal)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
# The peak should be at column offset = freq_cycles from center
mag = result.data
cy, cx = N // 2, N // 2 # center (DC)
# Find the peak location (excluding DC which should be ~0 after mean sub)
mag_copy = mag.copy()
mag_copy[cy, cx] = 0
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
peak_col_offset = abs(peak_idx[1] - cx)
print(f" Image: {N}x{N}, {freq_cycles} horizontal cycles")
print(f" Expected peak at column offset {freq_cycles} from center")
print(f" Found peak at {peak_idx} (offset {peak_col_offset})")
print(f" DC value: {mag[cy, cx]:.2e}")
print(f" Peak value: {mag[peak_idx]:.2e}")
assert peak_col_offset == freq_cycles, f"Expected offset {freq_cycles}, got {peak_col_offset}"
assert peak_idx[0] == cy, f"Expected peak on center row, got row {peak_idx[0]}"
print(" PASS\n")
def test_2d_frequency():
"""A 2D sine should produce peaks at the correct (kx, ky) position."""
print("=== Test: 2D frequency detection ===")
N = 128
fx, fy = 8, 5 # cycles in x and y
y, x = np.mgrid[0:N, 0:N] / N
data = np.sin(2 * np.pi * (fx * x + fy * y))
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
mag = result.data
cy, cx = N // 2, N // 2
mag_copy = mag.copy()
mag_copy[cy, cx] = 0
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
dx = abs(peak_idx[1] - cx)
dy = abs(peak_idx[0] - cy)
print(f" Input: sin(2π({fx}x + {fy}y))")
print(f" Expected peak offset: ({fy}, {fx}) from center")
print(f" Found peak at {peak_idx} (offset dy={dy}, dx={dx})")
assert dx == fx and dy == fy, f"Expected ({fy},{fx}), got ({dy},{dx})"
print(" PASS\n")
def test_psdf_normalization():
"""
PSDF of white noise should integrate to the variance.
Parseval's theorem: sum of PSDF * dk_x * dk_y ≈ variance of the signal.
"""
print("=== Test: PSDF normalization (Parseval) ===")
N = 256
xreal = 1e-6
rng = np.random.default_rng(42)
data = rng.standard_normal((N, N))
variance = data.var()
field = make_field(data, xreal=xreal, yreal=xreal)
node = FFT2D()
result, = node.process(field, windowing="none", level="none", output="psdf")
psdf = result.data
# Integrate: sum of PSDF * dk_x * dk_y
# Our output field has xreal = 2π*N/xreal (angular freq range)
dk_x = result.xreal / N
dk_y = result.yreal / N
integral = psdf.sum() * dk_x * dk_y
ratio = integral / variance
print(f" Signal variance: {variance:.6f}")
print(f" PSDF integral: {integral:.6f}")
print(f" Ratio (should be ~1.0): {ratio:.4f}")
# Allow 20% tolerance for finite-size effects
assert 0.8 < ratio < 1.2, f"Parseval's theorem violated: ratio = {ratio}"
print(" PASS\n")
def test_windowing_reduces_leakage():
"""Windowing should reduce spectral leakage from a non-integer frequency."""
print("=== Test: Windowing reduces leakage ===")
N = 128
freq = 10.5 # non-integer → spectral leakage without windowing
x = np.linspace(0, 1, N, endpoint=False)
data = np.sin(2 * np.pi * freq * x)[np.newaxis, :] * np.ones((N, 1))
field = make_field(data)
node = FFT2D()
# Without windowing
r_none, = node.process(field, windowing="none", level="mean", output="magnitude")
mag_none = r_none.data[N // 2, :] # center row
# With Hann windowing
r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude")
mag_hann = r_hann.data[N // 2, :]
# Measure leakage: ratio of energy far from peak vs total
peak_col = np.argmax(mag_none)
far_mask = np.ones(N, dtype=bool)
far_mask[max(0, peak_col - 3):peak_col + 4] = False
# Also mask the symmetric peak
sym_col = N - peak_col
far_mask[max(0, sym_col - 3):sym_col + 4] = False
leakage_none = mag_none[far_mask].sum() / mag_none.sum()
leakage_hann = mag_hann[far_mask].sum() / mag_hann.sum()
print(f" Non-integer frequency: {freq}")
print(f" Leakage without windowing: {leakage_none:.4f}")
print(f" Leakage with Hann window: {leakage_hann:.4f}")
assert leakage_hann < leakage_none, "Hann window should reduce leakage"
print(" PASS\n")
def test_plane_subtraction():
"""Plane subtraction should remove linear gradients."""
print("=== Test: Plane subtraction ===")
N = 64
y, x = np.mgrid[0:N, 0:N] / N
# Tilted plane + sine wave
data = 100 * x + 50 * y + np.sin(2 * np.pi * 8 * x)
field = make_field(data)
node = FFT2D()
# Without leveling — huge DC and low-freq energy
r_none, = node.process(field, windowing="none", level="none", output="magnitude")
dc_none = r_none.data[N // 2, N // 2]
# With mean subtraction — DC removed but gradient leaks
r_mean, = node.process(field, windowing="none", level="mean", output="magnitude")
dc_mean = r_mean.data[N // 2, N // 2]
# With plane subtraction — gradient removed
r_plane, = node.process(field, windowing="none", level="plane", output="magnitude")
dc_plane = r_plane.data[N // 2, N // 2]
# With plane subtraction, check the low-freq energy near DC is reduced
# (plane subtraction removes gradients that leak into low frequencies)
r = 3 # radius around DC to check
cy, cx = N // 2, N // 2
lowfreq_none = r_none.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
lowfreq_plane = r_plane.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
print(f" DC magnitude (no leveling): {dc_none:.2e}")
print(f" DC magnitude (mean subtract): {dc_mean:.2e}")
print(f" DC magnitude (plane subtract): {dc_plane:.2e}")
print(f" Low-freq energy (no level): {lowfreq_none:.2e}")
print(f" Low-freq energy (plane sub): {lowfreq_plane:.2e}")
assert dc_mean < dc_none, "Mean subtraction should reduce DC"
assert lowfreq_plane < lowfreq_none * 0.01, "Plane subtraction should reduce low-freq energy"
print(" PASS\n")
def test_non_square():
"""FFT should work on non-square, non-power-of-2 images."""
print("=== Test: Non-square image ===")
data = np.random.default_rng(99).standard_normal((100, 150))
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
node = FFT2D()
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
print(f" Shape: {result.data.shape}")
print(f" Output range: [{result.data.min():.4f}, {result.data.max():.4f}]")
print(" PASS\n")
def test_log_magnitude_visual_range():
"""Log magnitude should produce a reasonable dynamic range for display."""
print("=== Test: Log magnitude visual range ===")
N = 128
x = np.linspace(0, 1, N, endpoint=False)
# Multi-frequency test image
y, x = np.mgrid[0:N, 0:N] / N
data = (np.sin(2 * np.pi * 5 * x) +
0.5 * np.sin(2 * np.pi * 15 * x + 2 * np.pi * 10 * y) +
0.1 * np.random.default_rng(7).standard_normal((N, N)))
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
vmin, vmax = result.data.min(), result.data.max()
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
print(f" Log magnitude range: [{vmin:.4f}, {vmax:.4f}]")
print(f" Dynamic range: {dynamic_range:.2f}")
assert vmax > vmin, "Log magnitude should have nonzero range"
assert np.all(np.isfinite(result.data)), "Non-finite values in log magnitude"
print(" PASS\n")
if __name__ == "__main__":
test_dc_removal()
test_single_frequency()
test_2d_frequency()
test_psdf_normalization()
test_windowing_reduces_leakage()
test_plane_subtraction()
test_non_square()
test_log_magnitude_visual_range()
print("All tests passed!")

97
tests/test_fft_visual.py Normal file
View File

@@ -0,0 +1,97 @@
"""
Generate test images and their FFT outputs for visual comparison with Gwyddion.
Saves PNG files to tests/output/.
Run: .venv/bin/python -m tests.test_fft_visual
"""
import sys
import os
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.analysis import FFT2D
OUT_DIR = os.path.join(os.path.dirname(__file__), "output")
os.makedirs(OUT_DIR, exist_ok=True)
def save_field(field, name, colormap="viridis"):
"""Save a DataField as a PNG for visual inspection."""
from PIL import Image
arr = datafield_to_uint8(field, colormap)
img = Image.fromarray(arr)
path = os.path.join(OUT_DIR, f"{name}.png")
img.save(path)
print(f" Saved {path} (range: [{field.data.min():.4g}, {field.data.max():.4g}])")
def make_field(data, xreal=1e-6, yreal=1e-6):
return DataField(data=data, xreal=xreal, yreal=yreal)
def main():
node = FFT2D()
N = 256
# --- Test 1: Multi-frequency sine waves ---
print("Test 1: Multi-frequency sine waves")
y, x = np.mgrid[0:N, 0:N] / N
data = (np.sin(2 * np.pi * 10 * x)
+ 0.7 * np.sin(2 * np.pi * 25 * y)
+ 0.3 * np.sin(2 * np.pi * (15 * x + 8 * y)))
field = make_field(data)
save_field(field, "01_sines_input")
for output_mode in ["log_magnitude", "magnitude", "psdf"]:
result, = node.process(field, windowing="hann", level="mean", output=output_mode)
save_field(result, f"01_sines_{output_mode}")
# --- Test 2: Real-world-like surface with noise + tilt ---
print("\nTest 2: Tilted surface with features")
rng = np.random.default_rng(42)
data = (50 * x + 30 * y # tilt
+ np.sin(2 * np.pi * 20 * x) # periodic feature
+ 0.5 * rng.standard_normal((N, N))) # noise
field = make_field(data)
save_field(field, "02_surface_input")
for level_mode in ["none", "mean", "plane"]:
result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude")
save_field(result, f"02_surface_fft_level_{level_mode}")
# --- Test 3: Checkerboard pattern ---
print("\nTest 3: Checkerboard")
freq = 16
data = np.sign(np.sin(2 * np.pi * freq * x) * np.sin(2 * np.pi * freq * y))
field = make_field(data)
save_field(field, "03_checker_input")
result, = node.process(field, windowing="none", level="mean", output="log_magnitude")
save_field(result, "03_checker_fft")
# --- Test 4: Concentric rings (radial frequency) ---
print("\nTest 4: Concentric rings")
r = np.sqrt((x - 0.5)**2 + (y - 0.5)**2)
data = np.sin(2 * np.pi * 30 * r)
field = make_field(data)
save_field(field, "04_rings_input")
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
save_field(result, "04_rings_fft")
# --- Test 5: Compare windowing effects ---
print("\nTest 5: Windowing comparison")
data = np.sin(2 * np.pi * 10.5 * x) + 0.5 * np.sin(2 * np.pi * 30.3 * y)
field = make_field(data)
save_field(field, "05_window_input")
for win in ["none", "hann", "hamming", "blackman"]:
result, = node.process(field, windowing=win, level="mean", output="log_magnitude")
save_field(result, f"05_window_{win}")
print(f"\nAll outputs saved to {OUT_DIR}/")
if __name__ == "__main__":
main()

488
tests/test_nodes.py Normal file
View File

@@ -0,0 +1,488 @@
"""
Tests for all argonode backend nodes (excluding FFT2D which has its own test file).
Run from project root:
.venv/bin/python -m tests.test_nodes
"""
import sys
import os
import tempfile
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
"""Create a DataField, optionally from given data or a random field."""
if data is None:
data = np.random.default_rng(42).standard_normal(shape)
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
# =========================================================================
# Filters
# =========================================================================
def test_gaussian_filter():
print("=== Test: GaussianFilter ===")
from backend.nodes.filters import GaussianFilter
node = GaussianFilter()
field = make_field()
result, = node.process(field, sigma=2.0)
assert result.data.shape == field.data.shape
assert result.xreal == field.xreal
assert result.si_unit_z == field.si_unit_z
# Gaussian blur should reduce variance
assert result.data.std() < field.data.std()
# With very small sigma, output should be nearly unchanged
result_tiny, = node.process(field, sigma=0.01)
assert np.allclose(result_tiny.data, field.data, atol=1e-6)
print(" PASS\n")
def test_median_filter():
print("=== Test: MedianFilter ===")
from backend.nodes.filters import MedianFilter
node = MedianFilter()
# Median filter should remove salt-and-pepper noise
data = np.zeros((64, 64))
rng = np.random.default_rng(7)
noise_idx = rng.choice(64 * 64, size=100, replace=False)
data.ravel()[noise_idx] = 1.0
field = make_field(data=data)
result, = node.process(field, size=3)
assert result.data.shape == field.data.shape
# Should remove most impulse noise
assert result.data.sum() < field.data.sum()
# Size=1 should be identity
result_1, = node.process(field, size=1)
assert np.array_equal(result_1.data, field.data)
print(" PASS\n")
def test_edge_detect():
print("=== Test: EdgeDetect ===")
from backend.nodes.filters import EdgeDetect
node = EdgeDetect()
# Create an image with a sharp vertical edge
data = np.zeros((64, 64))
data[:, 32:] = 1.0
field = make_field(data=data)
for method in ["sobel", "prewitt", "laplacian", "log"]:
result, = node.process(field, method=method, sigma=1.0)
assert result.data.shape == field.data.shape
# Edge response should be strongest near column 32
col_energy = np.abs(result.data).sum(axis=0)
peak_col = np.argmax(col_energy)
assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32"
print(" PASS\n")
# =========================================================================
# Level
# =========================================================================
def test_plane_level():
print("=== Test: PlaneLevelField ===")
from backend.nodes.level import PlaneLevelField
node = PlaneLevelField()
# Create a tilted plane + small signal
N = 64
y, x = np.mgrid[0:N, 0:N] / N
signal = np.sin(2 * np.pi * 5 * x)
data = 100 * x + 50 * y + signal
field = make_field(data=data)
result, = node.process(field)
assert result.data.shape == field.data.shape
# After plane leveling, mean should be near zero
assert abs(result.data.mean()) < 1e-10
# The signal should remain (correlation with original sine)
corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1]
assert corr > 0.98, f"Signal correlation after leveling: {corr}"
print(" PASS\n")
def test_poly_level():
print("=== Test: PolyLevelField ===")
from backend.nodes.level import PolyLevelField
node = PolyLevelField()
N = 64
y, x = np.mgrid[0:N, 0:N] / N
# Quadratic background + signal
background = 50 * x**2 + 30 * y**2 + 10 * x * y
signal = np.sin(2 * np.pi * 8 * x)
data = background + signal
field = make_field(data=data)
leveled, bg = node.process(field, degree_x=2, degree_y=2)
assert leveled.data.shape == field.data.shape
assert bg.data.shape == field.data.shape
# leveled + bg should reconstruct original
assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10)
# Signal should be preserved after leveling
corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1]
assert corr > 0.95, f"Signal correlation after poly leveling: {corr}"
# Degree 0 should just subtract the mean
leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0)
assert abs(leveled_0.data.mean()) < 1e-10
print(" PASS\n")
def test_fix_zero():
print("=== Test: FixZero ===")
from backend.nodes.level import FixZero
node = FixZero()
field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))
result_min, = node.process(field, method="min")
assert result_min.data.min() == 0.0
assert result_min.data.max() == 30.0
result_mean, = node.process(field, method="mean")
assert abs(result_mean.data.mean()) < 1e-10
result_median, = node.process(field, method="median")
assert abs(np.median(result_median.data)) < 1e-10
print(" PASS\n")
# =========================================================================
# Analysis (non-FFT)
# =========================================================================
def test_statistics():
print("=== Test: StatisticsNode ===")
from backend.nodes.analysis import StatisticsNode
node = StatisticsNode()
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
field = make_field(data=data)
table, = node.process(field)
stats = {row["quantity"]: row["value"] for row in table}
assert stats["min"] == 1.0
assert stats["max"] == 4.0
assert stats["mean"] == 2.5
assert stats["median"] == 2.5
assert stats["range"] == 3.0
# RMS = sqrt(mean((x - mean)^2))
expected_rms = np.sqrt(np.mean((data - 2.5) ** 2))
assert abs(stats["RMS"] - expected_rms) < 1e-10
# Constant data should have RMS=0, skewness=0, kurtosis=0
const_field = make_field(data=np.ones((4, 4)) * 5.0)
table_const, = node.process(const_field)
const_stats = {row["quantity"]: row["value"] for row in table_const}
assert const_stats["RMS"] == 0.0
assert const_stats["skewness"] == 0.0
assert const_stats["kurtosis"] == 0.0
print(" PASS\n")
def test_height_histogram():
print("=== Test: HeightHistogram ===")
from backend.nodes.analysis import HeightHistogram
node = HeightHistogram()
# Uniform data should give a roughly flat histogram
data = np.linspace(0, 1, 1000).reshape(25, 40)
field = make_field(data=data)
counts, bin_centers = node.process(field, n_bins=10)
assert len(counts) == 10
assert len(bin_centers) == 10
assert counts.dtype == np.float64
# Total counts should equal number of pixels
assert counts.sum() == 1000
# For uniform data, each bin should have ~100 counts
assert np.std(counts) < 10, f"Histogram not flat enough: std={np.std(counts)}"
# Bin centers should span the data range
assert bin_centers[0] > 0.0
assert bin_centers[-1] < 1.0
print(" PASS\n")
def test_cross_section():
print("=== Test: CrossSection ===")
from backend.nodes.analysis import CrossSection
node = CrossSection()
# Create a field with a known horizontal gradient
N = 100
y, x = np.mgrid[0:N, 0:N] / N
data = x * 10.0 # value = 10 * x_fraction
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# Horizontal cross section at y=0.5
(profile,) = node.process(
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
extend="none", n_samples=100,
)
assert len(profile) == 100
# Profile should be a linear ramp from ~0 to ~10
assert profile[0] < 0.5, f"Start of profile: {profile[0]}"
assert profile[-1] > 9.5, f"End of profile: {profile[-1]}"
# n_samples=0 should auto-calculate
(profile_auto,) = node.process(
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
extend="none", n_samples=0,
)
assert len(profile_auto) >= 2
# Test extend to edges — a short segment should be extended
(profile_ext,) = node.process(
field, x1=0.3, y1=0.5, x2=0.7, y2=0.5,
extend="to_edges", n_samples=100,
)
# Extended profile should start near 0 and end near 10
assert profile_ext[0] < 0.5
assert profile_ext[-1] > 9.5
# Diagonal cross section
(profile_diag,) = node.process(
field, x1=0.0, y1=0.0, x2=1.0, y2=1.0,
extend="none", n_samples=50,
)
assert len(profile_diag) == 50
print(" PASS\n")
# =========================================================================
# Grains
# =========================================================================
def test_threshold_mask():
print("=== Test: ThresholdMask ===")
from backend.nodes.grains import ThresholdMask
node = ThresholdMask()
# Clear bimodal data: left half = 0, right half = 1
data = np.zeros((64, 64))
data[:, 32:] = 1.0
field = make_field(data=data)
# Absolute threshold at 0.5
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
assert mask.dtype == np.uint8
assert mask.shape == (64, 64)
assert np.all(mask[:, :32] == 0)
assert np.all(mask[:, 32:] == 255)
# Direction "below"
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
assert np.all(mask_below[:, :32] == 255)
assert np.all(mask_below[:, 32:] == 0)
# Relative threshold at 0.5 (midpoint of range)
mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
assert np.all(mask_rel[:, 32:] == 255)
# Otsu should find the bimodal threshold
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
print(" PASS\n")
def test_grain_analysis():
print("=== Test: GrainAnalysis ===")
from backend.nodes.grains import GrainAnalysis
node = GrainAnalysis()
# Create a field with two distinct "grains"
N = 64
data = np.zeros((N, N))
# Grain 1: 10x10 block at top-left with height 5
data[5:15, 5:15] = 5.0
# Grain 2: 8x8 block at bottom-right with height 3
data[45:53, 45:53] = 3.0
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# Create matching mask
mask = np.zeros((N, N), dtype=np.uint8)
mask[5:15, 5:15] = 255
mask[45:53, 45:53] = 255
table, = node.process(field, mask=mask, min_size=10)
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
# Sort by area descending
table.sort(key=lambda r: r["area_px"], reverse=True)
assert table[0]["area_px"] == 100 # 10x10
assert table[1]["area_px"] == 64 # 8x8
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
# min_size filtering: only keep grains >= 80 px
table_filtered, = node.process(field, mask=mask, min_size=80)
assert len(table_filtered) == 1
assert table_filtered[0]["area_px"] == 100
print(" PASS\n")
# =========================================================================
# I/O
# =========================================================================
def test_load_image():
print("=== Test: LoadImage ===")
from backend.nodes.io import LoadImage
from PIL import Image
node = LoadImage()
with tempfile.TemporaryDirectory() as tmpdir:
# Test loading a grayscale PNG
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
img = Image.fromarray(arr, mode="L")
path = os.path.join(tmpdir, "test_gray.png")
img.save(path)
image, field = node.load(filename=path)
assert image.shape == (48, 64)
assert field.data.shape == (48, 64)
assert field.data.dtype == np.float64
# Test loading an RGB PNG (should average to grayscale for field)
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
path_rgb = os.path.join(tmpdir, "test_rgb.png")
img_rgb.save(path_rgb)
image_rgb, field_rgb = node.load(filename=path_rgb)
assert image_rgb.shape == (32, 32, 3)
assert field_rgb.data.shape == (32, 32)
# Test loading a .npy file
data_npy = np.random.default_rng(3).standard_normal((50, 60))
path_npy = os.path.join(tmpdir, "test.npy")
np.save(path_npy, data_npy)
image_npy, field_npy = node.load(filename=path_npy)
assert np.allclose(field_npy.data, data_npy)
print(" PASS\n")
def test_save_image():
print("=== Test: SaveImage ===")
from backend.nodes.io import SaveImage
node = SaveImage()
with tempfile.TemporaryDirectory() as tmpdir:
# Monkey-patch OUTPUT_DIR for testing
from pathlib import Path
import backend.nodes.io as io_mod
orig_dir = io_mod.OUTPUT_DIR
io_mod.OUTPUT_DIR = Path(tmpdir)
try:
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)
# Save as PNG
node.save(image=arr, filename_prefix="test", format="PNG")
saved = os.listdir(tmpdir)
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}"
# Save as NPY
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
saved = os.listdir(tmpdir)
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
finally:
io_mod.OUTPUT_DIR = orig_dir
print(" PASS\n")
# =========================================================================
# Display (limited testing — these are output nodes with WS callbacks)
# =========================================================================
def test_preview_image():
print("=== Test: PreviewImage ===")
from backend.nodes.display import PreviewImage
node = PreviewImage()
# Set up a capture for the broadcast
captured = []
PreviewImage._broadcast_fn = lambda node_id, data_uri: captured.append(data_uri)
PreviewImage._current_node_id = "test"
# Preview with a DataField
field = make_field()
node.preview(colormap="viridis", field=field)
assert len(captured) == 1
assert captured[0].startswith("data:image/png;base64,")
# Preview with an IMAGE array
captured.clear()
arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
node.preview(colormap="gray", image=arr)
assert len(captured) == 1
# Clean up
PreviewImage._broadcast_fn = None
print(" PASS\n")
def test_print_table():
print("=== Test: PrintTable ===")
from backend.nodes.display import PrintTable
node = PrintTable()
captured = []
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
PrintTable._current_node_id = "test"
table = [{"quantity": "test", "value": 42.0, "unit": "m"}]
node.print_table(table=table)
assert len(captured) == 1
assert captured[0] == table
PrintTable._broadcast_table_fn = None
print(" PASS\n")
# =========================================================================
# Run all tests
# =========================================================================
if __name__ == "__main__":
# Filters
test_gaussian_filter()
test_median_filter()
test_edge_detect()
# Level
test_plane_level()
test_poly_level()
test_fix_zero()
# Analysis
test_statistics()
test_height_histogram()
test_cross_section()
# Grains
test_threshold_mask()
test_grain_analysis()
# I/O
test_load_image()
test_save_image()
# Display
test_preview_image()
test_print_table()
print("All tests passed!")