clean tests

This commit is contained in:
2026-03-28 00:21:37 -07:00
parent 240a2529eb
commit 4baadd4c3e
14 changed files with 330 additions and 211 deletions

View File

@@ -523,6 +523,31 @@ def test_plane_level():
# The signal should remain (correlation with original sine)
corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1]
assert corr > 0.98, f"Signal correlation after leveling: {corr}"
yy_px, xx_px = np.mgrid[0:N, 0:N]
def fit_pixel_plane(data_in: np.ndarray, region: np.ndarray) -> tuple[float, float, float]:
A = np.column_stack([
np.ones(int(np.count_nonzero(region)), dtype=np.float64),
xx_px[region].astype(np.float64),
yy_px[region].astype(np.float64),
])
coeffs, _, _, _ = np.linalg.lstsq(A, data_in[region].ravel().astype(np.float64), rcond=None)
return float(coeffs[0]), float(coeffs[1]), float(coeffs[2])
mask = np.zeros((N, N), dtype=np.uint8)
mask[20:44, 22:46] = 255
feature = np.zeros((N, N), dtype=np.float64)
feature[mask > 0] = 35.0
masked_field = make_field(data=100 * x + 50 * y + feature)
unmasked, = node.process(masked_field)
masked, = node.process(masked_field, masking="exclude", mask=mask)
outside = mask == 0
_, unmasked_bx, unmasked_by = fit_pixel_plane(unmasked.data, outside)
_, masked_bx, masked_by = fit_pixel_plane(masked.data, outside)
assert np.hypot(masked_bx, masked_by) < np.hypot(unmasked_bx, unmasked_by) * 1e-3
print(" PASS\n")
@@ -1261,10 +1286,10 @@ def test_mask_invert():
print(" PASS\n")
def test_mask_combine():
print("=== Test: MaskCombine ===")
from backend.nodes.mask_combine import MaskCombine
node = MaskCombine()
def test_mask_operations():
print("=== Test: MaskOperations ===")
from backend.nodes.mask_operations import MaskOperations
node = MaskOperations()
# Two overlapping squares
a = np.zeros((64, 64), dtype=np.uint8)
@@ -1291,12 +1316,26 @@ def test_mask_combine():
assert result_xor[35, 35] == 255 # b-only
assert result_xor[25, 25] == 0 # overlap excluded
# Subtract — a minus b
result_sub, = node.process(a, b, operation="subtract")
# A minus B
result_sub, = node.process(a, b, operation="a_minus_b")
assert result_sub[15, 15] == 255 # a-only kept
assert result_sub[25, 25] == 0 # overlap removed
assert result_sub[35, 35] == 0 # b-only not included
# NAND — everything except overlap
result_nand, = node.process(a, b, operation="nand")
assert result_nand[15, 15] == 255
assert result_nand[35, 35] == 255
assert result_nand[25, 25] == 0
assert result_nand[5, 5] == 255
# XNOR — overlap plus shared background
result_xnor, = node.process(a, b, operation="xnor")
assert result_xnor[25, 25] == 255
assert result_xnor[5, 5] == 255
assert result_xnor[15, 15] == 0
assert result_xnor[35, 35] == 0
print(" PASS\n")
@@ -1347,17 +1386,17 @@ def test_draw_mask():
print(" PASS\n")
def test_particle_analysis():
print("=== Test: ParticleAnalysis ===")
from backend.nodes.particle_analysis import ParticleAnalysis
node = ParticleAnalysis()
def test_grain_analysis():
print("=== Test: GrainAnalysis ===")
from backend.nodes.grain_analysis import GrainAnalysis
node = GrainAnalysis()
# Create a field with two distinct particles
# Create a field with two distinct grains
N = 64
data = np.zeros((N, N))
# Particle 1: 10x10 block at top-left with height 5
# Grain 1: 10x10 block at top-left with height 5
data[5:15, 5:15] = 5.0
# Particle 2: 8x8 block at bottom-right with height 3
# Grain 2: 8x8 block at bottom-right with height 3
data[45:53, 45:53] = 3.0
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
@@ -1367,7 +1406,7 @@ def test_particle_analysis():
mask[45:53, 45:53] = 255
table, = node.process(field, mask=mask, min_size=10)
assert len(table) == 2, f"Expected 2 particles, got {len(table)}"
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
# Sort by area descending
table.sort(key=lambda r: r["area_px"], reverse=True)
@@ -1381,7 +1420,7 @@ def test_particle_analysis():
assert table[0]["mean_height_unit"] == "m"
assert table[0]["max_height_unit"] == "m"
# min_size filtering: only keep particles >= 80 px
# min_size filtering: only keep grains >= 80 px
table_filtered, = node.process(field, mask=mask, min_size=80)
assert len(table_filtered) == 1
assert table_filtered[0]["area_px"] == 100
@@ -3140,11 +3179,11 @@ if __name__ == "__main__":
test_threshold_mask()
test_mask_morphology()
test_mask_invert()
test_mask_combine()
test_mask_operations()
test_draw_mask()
# Grains
test_particle_analysis()
test_grain_analysis()
# I/O
test_load_file()