489 lines
16 KiB
Python
489 lines
16 KiB
Python
"""
|
|
Tests for all argonode backend nodes (excluding FFT2D which has its own test file).
|
|
|
|
Run from project root:
|
|
.venv/bin/python -m tests.test_nodes
|
|
"""
|
|
import sys
|
|
import os
|
|
import tempfile
|
|
import numpy as np
|
|
|
|
# Make the project root importable regardless of the directory the tests are
# launched from. This file lives in <root>/tests/ (it is run as
# ``python -m tests.test_nodes``), so the root is two levels above its
# absolute path. A bare "." would only work when the CWD is the root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from backend.data_types import DataField
|
|
|
|
|
|
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
    """Build a test DataField whose xy and z units are both metres.

    When ``data`` is omitted, a reproducible standard-normal array of the
    requested ``shape`` is generated from a seeded RNG (seed 42), so every
    test sees identical input. ``xreal`` / ``yreal`` are forwarded unchanged
    to the DataField constructor.
    """
    values = (
        np.random.default_rng(42).standard_normal(shape) if data is None else data
    )
    return DataField(
        data=values,
        xreal=xreal,
        yreal=yreal,
        si_unit_xy="m",
        si_unit_z="m",
    )
|
|
|
|
|
|
# =========================================================================
|
|
# Filters
|
|
# =========================================================================
|
|
|
|
def test_gaussian_filter():
    """GaussianFilter keeps geometry/units, smooths noise, and is ~identity at tiny sigma."""
    print("=== Test: GaussianFilter ===")
    from backend.nodes.filters import GaussianFilter

    blur = GaussianFilter()
    field = make_field()

    (smoothed,) = blur.process(field, sigma=2.0)
    # The output must keep the input's shape, lateral size and z unit.
    assert smoothed.data.shape == field.data.shape
    assert smoothed.xreal == field.xreal
    assert smoothed.si_unit_z == field.si_unit_z
    # Blurring averages neighbouring noise samples, so the spread must drop.
    assert smoothed.data.std() < field.data.std()

    # A vanishingly small kernel should leave the data (nearly) untouched.
    (barely_blurred,) = blur.process(field, sigma=0.01)
    assert np.allclose(barely_blurred.data, field.data, atol=1e-6)

    print(" PASS\n")
|
|
|
|
|
|
def test_median_filter():
    """MedianFilter removes isolated impulse noise and is the identity for size=1."""
    print("=== Test: MedianFilter ===")
    from backend.nodes.filters import MedianFilter

    node = MedianFilter()

    # Salt-and-pepper style input: 100 isolated unit impulses on a zero background.
    data = np.zeros((64, 64))
    rng = np.random.default_rng(7)
    noise_idx = rng.choice(64 * 64, size=100, replace=False)
    data.ravel()[noise_idx] = 1.0  # ravel() of a fresh C-contiguous array is a view
    field = make_field(data=data)

    result, = node.process(field, size=3)
    assert result.data.shape == field.data.shape
    # Assuming size=3 means a 3x3 window, the median keeps a 1 only where at
    # least 5 of the 9 window pixels are 1. At ~2.4 % noise density that is
    # vanishingly rare, so "most" of the impulse energy must be gone. Checking
    # only that the sum decreased would accept a filter that removes a single
    # impulse.
    remaining = result.data.sum()
    original = field.data.sum()
    assert remaining < 0.1 * original, (
        f"Median filter left too much impulse noise: {remaining} of {original}"
    )

    # A 1x1 window is the identity.
    result_1, = node.process(field, size=1)
    assert np.array_equal(result_1.data, field.data)

    print(" PASS\n")
|
|
|
|
|
|
def test_edge_detect():
    """Every EdgeDetect method should respond most strongly at a vertical step edge."""
    print("=== Test: EdgeDetect ===")
    from backend.nodes.filters import EdgeDetect

    node = EdgeDetect()

    # Step image: zeros on the left, ones from column 32 onwards.
    step = np.zeros((64, 64))
    step[:, 32:] = 1.0
    field = make_field(data=step)

    methods = ("sobel", "prewitt", "laplacian", "log")
    for method in methods:
        (edges,) = node.process(field, method=method, sigma=1.0)
        assert edges.data.shape == field.data.shape

        # Sum |response| down each column; the peak column locates the edge.
        peak_col = np.argmax(np.abs(edges.data).sum(axis=0))
        assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32"

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Level
|
|
# =========================================================================
|
|
|
|
def test_plane_level():
    """PlaneLevelField removes a tilted plane while preserving the superimposed sine."""
    print("=== Test: PlaneLevelField ===")
    from backend.nodes.level import PlaneLevelField

    node = PlaneLevelField()

    # Normalised pixel coordinates in [0, 1): rows give y, columns give x.
    n = 64
    y, x = np.mgrid[0:n, 0:n] / n
    signal = np.sin(2 * np.pi * 5 * x)   # 5 periods across the image width
    tilted = 100 * x + 50 * y + signal   # strong plane plus a weak sine
    field = make_field(data=tilted)

    (leveled,) = node.process(field)
    assert leveled.data.shape == field.data.shape
    # Removing the plane should centre the data on zero...
    assert abs(leveled.data.mean()) < 1e-10
    # ...while leaving the sine essentially intact.
    corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1]
    assert corr > 0.98, f"Signal correlation after leveling: {corr}"

    print(" PASS\n")
|
|
|
|
|
|
def test_poly_level():
    """PolyLevelField splits data into leveled + background and keeps the signal."""
    print("=== Test: PolyLevelField ===")
    from backend.nodes.level import PolyLevelField

    node = PolyLevelField()

    N = 64
    y, x = np.mgrid[0:N, 0:N] / N
    # Quadratic background + signal
    background = 50 * x**2 + 30 * y**2 + 10 * x * y
    signal = np.sin(2 * np.pi * 8 * x)
    data = background + signal
    field = make_field(data=data)

    leveled, bg = node.process(field, degree_x=2, degree_y=2)
    assert leveled.data.shape == field.data.shape
    assert bg.data.shape == field.data.shape
    # leveled + bg should reconstruct original
    assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10)
    # Signal should be preserved after leveling
    corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1]
    assert corr > 0.95, f"Signal correlation after poly leveling: {corr}"

    # Degree 0 should just subtract the mean. Verify that directly: the
    # background must be the constant data mean, and the two outputs must
    # still sum back to the input. Checking only the leveled mean would leave
    # the returned background untested.
    leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0)
    assert abs(leveled_0.data.mean()) < 1e-10
    assert np.allclose(bg_0.data, field.data.mean(), atol=1e-10), (
        "Degree-0 background is not the constant data mean"
    )
    assert np.allclose(leveled_0.data + bg_0.data, field.data, atol=1e-10)

    print(" PASS\n")
|
|
|
|
|
|
def test_fix_zero():
    """FixZero shifts the data so its min / mean / median lands on zero."""
    print("=== Test: FixZero ===")
    from backend.nodes.level import FixZero

    node = FixZero()
    field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))

    # "min": values 10..40 become 0..30.
    (shifted,) = node.process(field, method="min")
    assert shifted.data.min() == 0.0
    assert shifted.data.max() == 30.0

    # "mean" and "median" should each zero their own statistic.
    (shifted,) = node.process(field, method="mean")
    assert abs(shifted.data.mean()) < 1e-10

    (shifted,) = node.process(field, method="median")
    assert abs(np.median(shifted.data)) < 1e-10

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis (non-FFT)
|
|
# =========================================================================
|
|
|
|
def test_statistics():
    """StatisticsNode reports exact summary statistics, and zeros for flat data."""
    print("=== Test: StatisticsNode ===")
    from backend.nodes.analysis import StatisticsNode

    node = StatisticsNode()

    def as_dict(rows):
        # Index the output table rows by their "quantity" label.
        return {row["quantity"]: row["value"] for row in rows}

    values = np.array([[1, 2], [3, 4]], dtype=np.float64)
    (table,) = node.process(make_field(data=values))
    stats = as_dict(table)

    expected_exact = {"min": 1.0, "max": 4.0, "mean": 2.5, "median": 2.5, "range": 3.0}
    for name, expected in expected_exact.items():
        assert stats[name] == expected
    # RMS is the root-mean-square deviation from the mean (2.5 here).
    expected_rms = np.sqrt(np.mean((values - 2.5) ** 2))
    assert abs(stats["RMS"] - expected_rms) < 1e-10

    # A flat field has no spread, so RMS, skewness and kurtosis are all zero.
    (table_const,) = node.process(make_field(data=np.ones((4, 4)) * 5.0))
    const_stats = as_dict(table_const)
    for name in ("RMS", "skewness", "kurtosis"):
        assert const_stats[name] == 0.0

    print(" PASS\n")
|
|
|
|
|
|
def test_height_histogram():
    """HeightHistogram returns n_bins float counts over ordered, range-spanning bin centres."""
    print("=== Test: HeightHistogram ===")
    from backend.nodes.analysis import HeightHistogram

    node = HeightHistogram()

    # Uniform data should give a roughly flat histogram
    data = np.linspace(0, 1, 1000).reshape(25, 40)
    field = make_field(data=data)

    counts, bin_centers = node.process(field, n_bins=10)
    assert len(counts) == 10
    assert len(bin_centers) == 10
    assert counts.dtype == np.float64
    # Total counts should equal number of pixels
    assert counts.sum() == data.size
    # For uniform data, each bin should have ~100 counts
    assert np.std(counts) < 10, f"Histogram not flat enough: std={np.std(counts)}"
    # Bin centres must be strictly increasing and span the data range [0, 1]:
    # the first and last centres belong in the first and last tenth of it.
    # Checking only "> 0" and "< 1" would accept all centres collapsed to 0.5.
    assert np.all(np.diff(bin_centers) > 0), f"Bin centers not increasing: {bin_centers}"
    assert 0.0 < bin_centers[0] < 0.1, f"First bin center: {bin_centers[0]}"
    assert 0.9 < bin_centers[-1] < 1.0, f"Last bin center: {bin_centers[-1]}"

    print(" PASS\n")
|
|
|
|
|
|
def test_cross_section():
    """CrossSection samples the expected ramp for horizontal, extended and diagonal lines."""
    print("=== Test: CrossSection ===")
    from backend.nodes.analysis import CrossSection

    node = CrossSection()

    # Horizontal ramp: the value grows linearly with the normalised column position.
    size = 100
    _, x = np.mgrid[0:size, 0:size] / size
    ramp = x * 10.0  # value = 10 * x_fraction
    field = make_field(data=ramp, xreal=1e-6, yreal=1e-6)

    def sample(x1, y1, x2, y2, extend, n_samples):
        # Thin wrapper so each case below reads as a single line.
        (profile,) = node.process(
            field, x1=x1, y1=y1, x2=x2, y2=y2,
            extend=extend, n_samples=n_samples,
        )
        return profile

    # Full-width horizontal cut through the middle row: ~0 -> ~10.
    profile = sample(0.0, 0.5, 1.0, 0.5, "none", 100)
    assert len(profile) == 100
    assert profile[0] < 0.5, f"Start of profile: {profile[0]}"
    assert profile[-1] > 9.5, f"End of profile: {profile[-1]}"

    # n_samples=0 leaves the sample count to the node.
    profile_auto = sample(0.0, 0.5, 1.0, 0.5, "none", 0)
    assert len(profile_auto) >= 2

    # A short central segment, extended to the image edges, covers the full ramp.
    profile_ext = sample(0.3, 0.5, 0.7, 0.5, "to_edges", 100)
    assert profile_ext[0] < 0.5
    assert profile_ext[-1] > 9.5

    # Corner-to-corner diagonal cut.
    profile_diag = sample(0.0, 0.0, 1.0, 1.0, "none", 50)
    assert len(profile_diag) == 50

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Grains
|
|
# =========================================================================
|
|
|
|
def test_threshold_mask():
    """ThresholdMask splits a two-level step image correctly for every method."""
    print("=== Test: ThresholdMask ===")
    from backend.nodes.grains import ThresholdMask

    node = ThresholdMask()

    # Clear bimodal data: left half = 0, right half = 1
    data = np.zeros((64, 64))
    data[:, 32:] = 1.0
    field = make_field(data=data)

    # Absolute threshold at 0.5
    mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
    assert mask.dtype == np.uint8
    assert mask.shape == (64, 64)
    assert np.all(mask[:, :32] == 0)
    assert np.all(mask[:, 32:] == 255)

    # Direction "below"
    mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
    assert np.all(mask_below[:, :32] == 255)
    assert np.all(mask_below[:, 32:] == 0)

    # Relative threshold at 0.5 (midpoint of the 0..1 range) must reproduce the
    # absolute split. Check both halves: testing only the bright half would
    # accept a mask that is 255 everywhere.
    mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
    assert np.all(mask_rel[:, :32] == 0)
    assert np.all(mask_rel[:, 32:] == 255)

    # Otsu should find the bimodal threshold
    mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
    assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()

    print(" PASS\n")
|
|
|
|
|
|
def test_grain_analysis():
    """GrainAnalysis measures per-grain area and mean height and honours min_size."""
    print("=== Test: GrainAnalysis ===")
    from backend.nodes.grains import GrainAnalysis

    node = GrainAnalysis()

    # Two flat square grains. The mask marks exactly the same pixels as the heights.
    size = 64
    grains = (
        (slice(5, 15), 5.0),   # 10x10 block at top-left, height 5
        (slice(45, 53), 3.0),  # 8x8 block at bottom-right, height 3
    )
    heights = np.zeros((size, size))
    mask = np.zeros((size, size), dtype=np.uint8)
    for span, height in grains:
        heights[span, span] = height
        mask[span, span] = 255
    field = make_field(data=heights, xreal=1e-6, yreal=1e-6)

    (table,) = node.process(field, mask=mask, min_size=10)
    assert len(table) == 2, f"Expected 2 grains, got {len(table)}"

    # Order largest grain first before comparing per-grain values.
    big, small = sorted(table, key=lambda row: row["area_px"], reverse=True)
    assert big["area_px"] == 100  # 10x10
    assert small["area_px"] == 64  # 8x8
    assert abs(big["mean_height"] - 5.0) < 1e-10
    assert abs(small["mean_height"] - 3.0) < 1e-10

    # With min_size=80 only the 100 px grain survives.
    (table_filtered,) = node.process(field, mask=mask, min_size=80)
    assert len(table_filtered) == 1
    assert table_filtered[0]["area_px"] == 100

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# I/O
|
|
# =========================================================================
|
|
|
|
def test_load_image():
    """LoadImage reads grayscale PNG, RGB PNG and .npy files into an IMAGE plus a DataField."""
    print("=== Test: LoadImage ===")
    from backend.nodes.io import LoadImage
    from PIL import Image

    node = LoadImage()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Grayscale PNG. Image.fromarray infers mode "L" from a 2-D uint8
        # array; passing mode= explicitly is deprecated since Pillow 11.3.
        arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
        path = os.path.join(tmpdir, "test_gray.png")
        Image.fromarray(arr).save(path)

        image, field = node.load(filename=path)
        assert image.shape == (48, 64)
        assert field.data.shape == (48, 64)
        assert field.data.dtype == np.float64

        # RGB PNG (mode "RGB" is inferred from an (H, W, 3) uint8 array). The
        # field should collapse the channels into a single 2-D grayscale plane.
        arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
        path_rgb = os.path.join(tmpdir, "test_rgb.png")
        Image.fromarray(arr_rgb).save(path_rgb)

        image_rgb, field_rgb = node.load(filename=path_rgb)
        assert image_rgb.shape == (32, 32, 3)
        assert field_rgb.data.shape == (32, 32)

        # Raw .npy array: the field data must round-trip the saved values.
        data_npy = np.random.default_rng(3).standard_normal((50, 60))
        path_npy = os.path.join(tmpdir, "test.npy")
        np.save(path_npy, data_npy)

        _, field_npy = node.load(filename=path_npy)
        assert np.allclose(field_npy.data, data_npy)

    print(" PASS\n")
|
|
|
|
|
|
def test_save_image():
    """SaveImage writes PNG and NPY files into a temporarily redirected OUTPUT_DIR."""
    print("=== Test: SaveImage ===")
    from pathlib import Path
    import backend.nodes.io as io_mod
    from backend.nodes.io import SaveImage

    node = SaveImage()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Point the module-level OUTPUT_DIR at the temp dir, and always put it back.
        saved_output_dir = io_mod.OUTPUT_DIR
        io_mod.OUTPUT_DIR = Path(tmpdir)
        try:
            pixels = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)

            node.save(image=pixels, filename_prefix="test", format="PNG")
            written = os.listdir(tmpdir)
            assert any(name.endswith(".png") for name in written), f"No PNG file found in {written}"

            node.save(image=pixels.astype(np.float64), filename_prefix="test", format="NPY")
            written = os.listdir(tmpdir)
            assert any(name.endswith(".npy") for name in written), f"No NPY file found in {written}"
        finally:
            io_mod.OUTPUT_DIR = saved_output_dir

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Display (limited testing — these are output nodes with WS callbacks)
|
|
# =========================================================================
|
|
|
|
def test_preview_image():
    """PreviewImage broadcasts a PNG data URI for both DataField and IMAGE inputs."""
    print("=== Test: PreviewImage ===")
    from backend.nodes.display import PreviewImage

    node = PreviewImage()

    # Capture broadcasts by patching the class-level hooks. The previous values
    # are saved and restored in ``finally``, so a failing assertion cannot leak
    # the capturing stub (or the fake node id) into later tests in the process.
    orig_broadcast_fn = getattr(PreviewImage, "_broadcast_fn", None)
    orig_node_id = getattr(PreviewImage, "_current_node_id", None)
    captured = []
    PreviewImage._broadcast_fn = lambda node_id, data_uri: captured.append(data_uri)
    PreviewImage._current_node_id = "test"
    try:
        # Preview with a DataField
        field = make_field()
        node.preview(colormap="viridis", field=field)
        assert len(captured) == 1
        assert captured[0].startswith("data:image/png;base64,")

        # Preview with an IMAGE array
        captured.clear()
        arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
        node.preview(colormap="gray", image=arr)
        assert len(captured) == 1
    finally:
        PreviewImage._broadcast_fn = orig_broadcast_fn
        PreviewImage._current_node_id = orig_node_id

    print(" PASS\n")
|
|
|
|
|
|
def test_print_table():
    """PrintTable forwards the table rows unchanged to its broadcast hook."""
    print("=== Test: PrintTable ===")
    from backend.nodes.display import PrintTable

    node = PrintTable()

    # Patch the class-level hooks to capture output. Restore the previous
    # values in ``finally`` so a failure cannot leak the stub into later tests.
    orig_broadcast_fn = getattr(PrintTable, "_broadcast_table_fn", None)
    orig_node_id = getattr(PrintTable, "_current_node_id", None)
    captured = []
    PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
    PrintTable._current_node_id = "test"
    try:
        table = [{"quantity": "test", "value": 42.0, "unit": "m"}]
        node.print_table(table=table)
        assert len(captured) == 1
        assert captured[0] == table
    finally:
        PrintTable._broadcast_table_fn = orig_broadcast_fn
        PrintTable._current_node_id = orig_node_id

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Run all tests
|
|
# =========================================================================
|
|
|
|
if __name__ == "__main__":
    # Run every test in a fixed, section-by-section order. The first failing
    # assertion aborts the run before the final success message is printed.
    _SECTIONS = (
        # Filters
        (test_gaussian_filter, test_median_filter, test_edge_detect),
        # Level
        (test_plane_level, test_poly_level, test_fix_zero),
        # Analysis
        (test_statistics, test_height_histogram, test_cross_section),
        # Grains
        (test_threshold_mask, test_grain_analysis),
        # I/O
        (test_load_image, test_save_image),
        # Display
        (test_preview_image, test_print_table),
    )
    for _section in _SECTIONS:
        for _test in _section:
            _test()

    print("All tests passed!")
|