add snapshot tool, masks, and build for mac

This commit is contained in:
2026-03-23 21:52:17 -07:00
parent 080eefbef6
commit a34b1c980d
29 changed files with 2016 additions and 170 deletions

View File

@@ -85,6 +85,88 @@ def test_edge_detect():
print(" PASS\n")
def test_fft_filter_1d():
print("=== Test: FFTFilter1D ===")
from backend.nodes.filters import FFTFilter1D
node = FFTFilter1D()
# Signal: low-frequency sine + high-frequency sine
n = 256
t = np.arange(n, dtype=np.float64) / n
low = np.sin(2 * np.pi * 3 * t) # 3 cycles — low freq
high = np.sin(2 * np.pi * 80 * t) # 80 cycles — high freq
line = low + high
# Lowpass should keep low, suppress high
filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
assert len(filtered_lp) == n
corr_low = np.corrcoef(filtered_lp, low)[0, 1]
corr_high = np.corrcoef(filtered_lp, high)[0, 1]
assert corr_low > 0.95, f"Lowpass: correlation with low={corr_low}"
assert abs(corr_high) < 0.3, f"Lowpass: correlation with high={corr_high}"
# Highpass should keep high, suppress low
filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1]
corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1]
assert abs(corr_low_hp) < 0.3, f"Highpass: correlation with low={corr_low_hp}"
assert corr_high_hp > 0.95, f"Highpass: correlation with high={corr_high_hp}"
# Bandpass centred on the high frequency
filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
corr_low_bp = np.corrcoef(filtered_bp, low)[0, 1]
corr_high_bp = np.corrcoef(filtered_bp, high)[0, 1]
assert abs(corr_low_bp) < 0.3, f"Bandpass: correlation with low={corr_low_bp}"
assert corr_high_bp > 0.9, f"Bandpass: correlation with high={corr_high_bp}"
# Notch (band-reject) centred on the high frequency — should remove it
filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
corr_low_notch = np.corrcoef(filtered_notch, low)[0, 1]
corr_high_notch = np.corrcoef(filtered_notch, high)[0, 1]
assert corr_low_notch > 0.95, f"Notch: correlation with low={corr_low_notch}"
assert abs(corr_high_notch) < 0.3, f"Notch: correlation with high={corr_high_notch}"
print(" PASS\n")
def test_fft_filter_2d():
print("=== Test: FFTFilter2D ===")
from backend.nodes.filters import FFTFilter2D
node = FFTFilter2D()
N = 128
y, x = np.mgrid[0:N, 0:N] / N
# Low-frequency 2D pattern + high-frequency pattern
low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
data = low_2d + high_2d
field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
# Lowpass — should preserve low, remove high
result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
assert result_lp.data.shape == (N, N)
assert result_lp.xreal == field.xreal
assert result_lp.si_unit_z == field.si_unit_z
corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
assert corr_low > 0.9, f"2D lowpass: correlation with low={corr_low}"
assert abs(corr_high) < 0.3, f"2D lowpass: correlation with high={corr_high}"
# Highpass — should preserve high, remove low
result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
corr_low_hp = np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]
corr_high_hp = np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1]
assert abs(corr_low_hp) < 0.3, f"2D highpass: correlation with low={corr_low_hp}"
assert corr_high_hp > 0.9, f"2D highpass: correlation with high={corr_high_hp}"
# Constant field should be unchanged by lowpass (DC preservation)
const = make_field(data=np.ones((32, 32)) * 7.0)
result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
assert np.allclose(result_const.data, 7.0, atol=1e-10), "Lowpass should preserve constant field"
print(" PASS\n")
# =========================================================================
# Level
# =========================================================================
@@ -199,7 +281,7 @@ def test_height_histogram():
data = np.linspace(0, 1, 1000).reshape(25, 40)
field = make_field(data=data)
counts, bin_centers = node.process(field, n_bins=10)
counts, bin_centers = node.process(field, n_bins=10, y_scale="linear")
assert len(counts) == 10
assert len(bin_centers) == 10
assert counts.dtype == np.float64
@@ -265,7 +347,7 @@ def test_cross_section():
def test_threshold_mask():
print("=== Test: ThresholdMask ===")
from backend.nodes.grains import ThresholdMask
from backend.nodes.mask import ThresholdMask
node = ThresholdMask()
# Clear bimodal data: left half = 0, right half = 1
@@ -273,6 +355,11 @@ def test_threshold_mask():
data[:, 32:] = 1.0
field = make_field(data=data)
# Capture overlay preview
previews = []
ThresholdMask._broadcast_fn = lambda nid, uri: previews.append(uri)
ThresholdMask._current_node_id = "test"
# Absolute threshold at 0.5
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
assert mask.dtype == np.uint8
@@ -280,6 +367,10 @@ def test_threshold_mask():
assert np.all(mask[:, :32] == 0)
assert np.all(mask[:, 32:] == 255)
# Verify overlay preview was broadcast
assert len(previews) == 1
assert previews[0].startswith("data:image/png;base64,")
# Direction "below"
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
assert np.all(mask_below[:, :32] == 255)
@@ -292,20 +383,117 @@ def test_threshold_mask():
# Otsu should find the bimodal threshold
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
ThresholdMask._broadcast_fn = None
print(" PASS\n")
def test_grain_analysis():
print("=== Test: GrainAnalysis ===")
from backend.nodes.grains import GrainAnalysis
node = GrainAnalysis()
def test_mask_morphology():
print("=== Test: MaskMorphology ===")
from backend.nodes.mask import MaskMorphology
node = MaskMorphology()
# Create a field with two distinct "grains"
# Small square blob in the centre
mask = np.zeros((64, 64), dtype=np.uint8)
mask[28:36, 28:36] = 255 # 8x8 block
orig_count = np.count_nonzero(mask)
# Dilate should grow the region
dilated, = node.process(mask, operation="dilate", radius=1, shape="square")
assert dilated.dtype == np.uint8
assert np.count_nonzero(dilated) > orig_count
# Erode should shrink it
eroded, = node.process(mask, operation="erode", radius=1, shape="square")
assert np.count_nonzero(eroded) < orig_count
# Open on a clean block should give back roughly the same block
opened, = node.process(mask, operation="open", radius=1, shape="square")
assert np.count_nonzero(opened) <= orig_count
# Close on a mask with a 1-pixel hole should fill the hole
mask_hole = mask.copy()
mask_hole[32, 32] = 0 # poke a hole
assert np.count_nonzero(mask_hole) == orig_count - 1
closed, = node.process(mask_hole, operation="close", radius=1, shape="square")
assert closed[32, 32] == 255, "Close should fill the 1-pixel hole"
# Disk structuring element should also work
dilated_disk, = node.process(mask, operation="dilate", radius=2, shape="disk")
assert np.count_nonzero(dilated_disk) > orig_count
print(" PASS\n")
def test_mask_invert():
print("=== Test: MaskInvert ===")
from backend.nodes.mask import MaskInvert
node = MaskInvert()
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 10:20] = 255
inverted, = node.process(mask)
assert inverted.dtype == np.uint8
assert np.all(inverted[10:20, 10:20] == 0)
assert np.all(inverted[0:10, 0:10] == 255)
# Double-invert should return to original
double, = node.process(inverted)
assert np.array_equal(double, mask)
print(" PASS\n")
def test_mask_combine():
print("=== Test: MaskCombine ===")
from backend.nodes.mask import MaskCombine
node = MaskCombine()
# Two overlapping squares
a = np.zeros((64, 64), dtype=np.uint8)
a[10:30, 10:30] = 255 # 20x20
b = np.zeros((64, 64), dtype=np.uint8)
b[20:40, 20:40] = 255 # 20x20, overlaps 10x10
# AND — only the overlap
result_and, = node.process(a, b, operation="and")
assert np.all(result_and[20:30, 20:30] == 255)
assert result_and[15, 15] == 0 # a-only region
assert result_and[35, 35] == 0 # b-only region
# OR — union
result_or, = node.process(a, b, operation="or")
assert result_or[15, 15] == 255
assert result_or[35, 35] == 255
assert result_or[25, 25] == 255
assert result_or[5, 5] == 0
# XOR — symmetric difference
result_xor, = node.process(a, b, operation="xor")
assert result_xor[15, 15] == 255 # a-only
assert result_xor[35, 35] == 255 # b-only
assert result_xor[25, 25] == 0 # overlap excluded
# Subtract — a minus b
result_sub, = node.process(a, b, operation="subtract")
assert result_sub[15, 15] == 255 # a-only kept
assert result_sub[25, 25] == 0 # overlap removed
assert result_sub[35, 35] == 0 # b-only not included
print(" PASS\n")
def test_particle_analysis():
print("=== Test: ParticleAnalysis ===")
from backend.nodes.grains import ParticleAnalysis
node = ParticleAnalysis()
# Create a field with two distinct particles
N = 64
data = np.zeros((N, N))
# Grain 1: 10x10 block at top-left with height 5
# Particle 1: 10x10 block at top-left with height 5
data[5:15, 5:15] = 5.0
# Grain 2: 8x8 block at bottom-right with height 3
# Particle 2: 8x8 block at bottom-right with height 3
data[45:53, 45:53] = 3.0
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
@@ -315,7 +503,7 @@ def test_grain_analysis():
mask[45:53, 45:53] = 255
table, = node.process(field, mask=mask, min_size=10)
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
assert len(table) == 2, f"Expected 2 particles, got {len(table)}"
# Sort by area descending
table.sort(key=lambda r: r["area_px"], reverse=True)
@@ -324,7 +512,7 @@ def test_grain_analysis():
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
# min_size filtering: only keep grains >= 80 px
# min_size filtering: only keep particles >= 80 px
table_filtered, = node.process(field, mask=mask, min_size=80)
assert len(table_filtered) == 1
assert table_filtered[0]["area_px"] == 100
@@ -462,6 +650,8 @@ if __name__ == "__main__":
test_gaussian_filter()
test_median_filter()
test_edge_detect()
test_fft_filter_1d()
test_fft_filter_2d()
# Level
test_plane_level()
@@ -473,9 +663,14 @@ if __name__ == "__main__":
test_height_histogram()
test_cross_section()
# Grains
# Mask
test_threshold_mask()
test_grain_analysis()
test_mask_morphology()
test_mask_invert()
test_mask_combine()
# Grains
test_particle_analysis()
# I/O
test_load_image()