remaining med value features
This commit is contained in:
89
tests/node_tests/cross_correlate.py
Normal file
89
tests/node_tests/cross_correlate.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_cross_correlate_same_field_peak_at_center():
|
||||
"""Correlating a field with itself in 'same' mode peaks at the centre."""
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=True)
|
||||
|
||||
peak_y, peak_x = np.unravel_index(np.argmax(result.data), result.data.shape)
|
||||
cy, cx = result.data.shape[0] // 2, result.data.shape[1] // 2
|
||||
# Peak should be within a few pixels of centre
|
||||
assert abs(peak_y - cy) <= 2
|
||||
assert abs(peak_x - cx) <= 2
|
||||
|
||||
|
||||
def test_cross_correlate_same_mode_shape_equals_a():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
a = make_field(data=rng.standard_normal((32, 48)))
|
||||
b = make_field(data=rng.standard_normal((32, 48)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(a, b, mode="same", normalize=True)
|
||||
assert result.data.shape == a.data.shape
|
||||
|
||||
|
||||
def test_cross_correlate_full_mode_shape():
|
||||
"""Full mode output shape should be Na+Nb-1 × Ma+Mb-1."""
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
a = make_field(data=rng.standard_normal((20, 30)))
|
||||
b = make_field(data=rng.standard_normal((20, 30)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(a, b, mode="full", normalize=True)
|
||||
assert result.data.shape == (20 + 20 - 1, 30 + 30 - 1)
|
||||
|
||||
|
||||
def test_cross_correlate_normalized_peak_is_one():
|
||||
"""Self-correlation normalised should give peak = 1."""
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=True)
|
||||
assert result.data.max() == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_cross_correlate_unnormalized_runs():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(4)
|
||||
data = rng.standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=False)
|
||||
assert result.data.shape == (16, 16)
|
||||
|
||||
|
||||
def test_cross_correlate_valid_mode():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(5)
|
||||
a = make_field(data=rng.standard_normal((16, 16)))
|
||||
b = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(a, b, mode="valid", normalize=True)
|
||||
# Valid mode output: (16-8+1, 16-8+1) = (9, 9)
|
||||
assert result.data.shape == (9, 9)
|
||||
|
||||
|
||||
def test_cross_correlate_preserves_metadata_same_mode():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(6)
|
||||
field = make_field(data=rng.standard_normal((16, 16)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=True)
|
||||
assert result.xreal == field.xreal
|
||||
assert result.yreal == field.yreal
|
||||
75
tests/node_tests/entropy.py
Normal file
75
tests/node_tests/entropy.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_entropy_uniform_field_low():
|
||||
"""A field with a single unique value has zero entropy."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
node = Entropy()
|
||||
field = make_field(data=np.full((32, 32), 3.14))
|
||||
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||
# All values fall in one bin → p=1 → H = 0
|
||||
assert h == pytest.approx(0.0, abs=1e-10)
|
||||
assert h_norm == pytest.approx(0.0, abs=1e-10)
|
||||
|
||||
|
||||
def test_entropy_random_field_positive():
|
||||
"""A random field should have positive entropy."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
field = make_field(data=rng.standard_normal((64, 64)))
|
||||
node = Entropy()
|
||||
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||
assert h > 0.0
|
||||
assert 0.0 < h_norm <= 1.0
|
||||
|
||||
|
||||
def test_entropy_normalised_leq_one():
|
||||
"""Normalised entropy should never exceed 1."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
field = make_field(data=rng.uniform(0, 1, (64, 64)))
|
||||
node = Entropy()
|
||||
_, h_norm = node.process(field, mode="height values", n_bins=64)
|
||||
assert h_norm <= 1.0 + 1e-12
|
||||
|
||||
|
||||
def test_entropy_slope_mode():
|
||||
"""Slope mode should work and return valid entropy values."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = Entropy()
|
||||
h, h_norm = node.process(field, mode="slope magnitude", n_bins=128)
|
||||
assert h > 0.0
|
||||
assert 0.0 <= h_norm <= 1.0
|
||||
|
||||
|
||||
def test_entropy_more_uniform_is_higher():
|
||||
"""Uniformly distributed values have higher entropy than a spiked distribution."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(4)
|
||||
uniform = rng.uniform(0, 1, (64, 64))
|
||||
spiked = np.zeros((64, 64))
|
||||
spiked[0, 0] = 1.0
|
||||
|
||||
node = Entropy()
|
||||
h_uniform, _ = node.process(make_field(data=uniform), mode="height values", n_bins=64)
|
||||
h_spiked, _ = node.process(make_field(data=spiked), mode="height values", n_bins=64)
|
||||
assert h_uniform > h_spiked
|
||||
|
||||
|
||||
def test_entropy_returns_floats():
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
field = make_field()
|
||||
node = Entropy()
|
||||
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||
assert isinstance(h, float)
|
||||
assert isinstance(h_norm, float)
|
||||
95
tests/node_tests/filter_custom.py
Normal file
95
tests/node_tests/filter_custom.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_custom_convolution_identity_kernel():
|
||||
"""[[1]] with normalize=True (abs_sum=1) should return input unchanged."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(0).standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_custom_convolution_uniform_kernel_normalized():
|
||||
"""An all-ones kernel with normalize=True is a box filter (mean filter)."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(1).standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
# 3x3 all-ones kernel, normalized → each pixel becomes mean of its neighbourhood
|
||||
kernel = "1 1 1\n1 1 1\n1 1 1"
|
||||
result, = node.process(field, kernel=kernel, normalize=True, boundary="reflect")
|
||||
# Output std should be less than input std (smoothing)
|
||||
assert result.data.std() < data.std()
|
||||
|
||||
|
||||
def test_custom_convolution_sharpen_increases_variation():
|
||||
"""A sharpening kernel should increase local variation on a smooth field."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
# Smooth ramp field — very low frequency content
|
||||
data = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||
field = make_field(data=data)
|
||||
sharpen = "0 -1 0\n-1 5 -1\n0 -1 0"
|
||||
result, = node.process(field, kernel=sharpen, normalize=False, boundary="reflect")
|
||||
# Sharpening without normalisation keeps the ramp intact plus adds edges
|
||||
# The std of the sharpened field should differ from input
|
||||
assert result.data.std() != pytest.approx(data.std(), rel=0.0)
|
||||
|
||||
|
||||
def test_custom_convolution_shape_preserved():
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
field = make_field(shape=(48, 64))
|
||||
result, = node.process(field, kernel="0 1 0\n1 1 1\n0 1 0", normalize=True, boundary="reflect")
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_custom_convolution_invalid_kernel_fallback():
|
||||
"""An invalid kernel string should return the input field unchanged."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(2).standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel="", normalize=True, boundary="reflect")
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_custom_convolution_ragged_kernel_fallback():
|
||||
"""A ragged (non-rectangular) kernel should be rejected gracefully."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(3).standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel="1 2\n1 2 3", normalize=True, boundary="reflect")
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_custom_convolution_boundary_modes():
|
||||
"""All boundary modes should produce valid output without error."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
field = make_field()
|
||||
for mode in ("reflect", "nearest", "wrap"):
|
||||
result, = node.process(field, kernel="1 1 1\n1 1 1\n1 1 1", normalize=True, boundary=mode)
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_custom_convolution_preserves_metadata():
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
field = make_field()
|
||||
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
75
tests/node_tests/filter_kuwahara.py
Normal file
75
tests/node_tests/filter_kuwahara.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_kuwahara_shape_preserved():
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(shape=(48, 64))
|
||||
result, = node.process(field, iterations=1)
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_kuwahara_flat_field_unchanged():
|
||||
"""A constant field should pass through the Kuwahara filter unchanged."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=np.full((32, 32), 7.5))
|
||||
result, = node.process(field, iterations=1)
|
||||
assert np.allclose(result.data, 7.5)
|
||||
|
||||
|
||||
def test_kuwahara_reduces_noise():
|
||||
"""Applying the filter to a noisy field should reduce standard deviation."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
noisy = rng.standard_normal((64, 64))
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=noisy)
|
||||
result, = node.process(field, iterations=1)
|
||||
assert result.data.std() < noisy.std()
|
||||
|
||||
|
||||
def test_kuwahara_preserves_step_edge():
|
||||
"""The Kuwahara filter should preserve a sharp step edge better than a blur."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
# Left half = 0, right half = 1
|
||||
data = np.zeros((32, 64))
|
||||
data[:, 32:] = 1.0
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, iterations=1)
|
||||
|
||||
# The edge column should have a large jump (edge preserved)
|
||||
col_before = result.data[:, 30].mean()
|
||||
col_after = result.data[:, 34].mean()
|
||||
assert col_after - col_before > 0.5
|
||||
|
||||
|
||||
def test_kuwahara_multiple_iterations():
|
||||
"""Running multiple iterations should further reduce noise."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
noisy = rng.standard_normal((32, 32))
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=noisy)
|
||||
result1, = node.process(field, iterations=1)
|
||||
result3, = node.process(field, iterations=3)
|
||||
assert result3.data.std() <= result1.data.std()
|
||||
|
||||
|
||||
def test_kuwahara_preserves_metadata():
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
node = KuwaharaFilter()
|
||||
field = make_field()
|
||||
result, = node.process(field, iterations=1)
|
||||
assert result.xreal == field.xreal
|
||||
assert result.yreal == field.yreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
67
tests/node_tests/local_contrast.py
Normal file
67
tests/node_tests/local_contrast.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_local_contrast_shape_preserved():
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field(shape=(48, 64))
|
||||
result, = node.process(field, kernel_size=10, weight=0.5)
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_local_contrast_weight_zero_unchanged():
|
||||
"""weight=0 blends 100% original → result equals input."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
data = np.random.default_rng(0).standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel_size=5, weight=0.0)
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_local_contrast_uniform_field_unchanged():
|
||||
"""A flat field has nothing to enhance; it should be returned as-is."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field(data=np.full((32, 32), 2.0))
|
||||
result, = node.process(field, kernel_size=5, weight=1.0)
|
||||
assert np.allclose(result.data, 2.0)
|
||||
|
||||
|
||||
def test_local_contrast_increases_dynamic_range():
|
||||
"""Weight=1 full enhancement should not compress global range beyond input."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
data = rng.standard_normal((64, 64))
|
||||
field = make_field(data=data)
|
||||
node = LocalContrast()
|
||||
result, = node.process(field, kernel_size=8, weight=1.0)
|
||||
# Global min/max should be preserved (by construction of the algorithm)
|
||||
assert np.isclose(result.data.min(), data.min(), atol=1e-6)
|
||||
assert np.isclose(result.data.max(), data.max(), atol=1e-6)
|
||||
|
||||
|
||||
def test_local_contrast_preserves_metadata():
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field()
|
||||
result, = node.process(field, kernel_size=10, weight=0.5)
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
|
||||
|
||||
def test_local_contrast_weight_clipped():
|
||||
"""Values outside [0,1] should be clipped without error."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field()
|
||||
result, = node.process(field, kernel_size=5, weight=2.0)
|
||||
assert result.data.shape == field.data.shape
|
||||
@@ -156,10 +156,16 @@ def test_save_generic():
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# LINE as plot image (PNG / TIFF)
|
||||
node.save(filename="line_plot_png", directory_path=tmpdir, format="PNG", value=line)
|
||||
assert Path(tmpdir, "line_plot_png.png").exists()
|
||||
node.save(filename="line_plot_tiff", directory_path=tmpdir, format="TIFF", value=line)
|
||||
assert Path(tmpdir, "line_plot_tiff.tiff").exists()
|
||||
|
||||
# Unsupported LINE format
|
||||
try:
|
||||
node.save(filename="line_bad", directory_path=tmpdir, format="TIFF", value=line)
|
||||
assert False, "Expected ValueError for LINE + TIFF"
|
||||
node.save(filename="line_bad", directory_path=tmpdir, format="OBJ", value=line)
|
||||
assert False, "Expected ValueError for LINE + OBJ"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
110
tests/node_tests/spot_removal.py
Normal file
110
tests/node_tests/spot_removal.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def _make_mask(shape, defect_positions):
|
||||
"""Create a uint8 mask array with 255 at given (row, col) positions."""
|
||||
mask = np.zeros(shape, dtype=np.uint8)
|
||||
for r, c in defect_positions:
|
||||
mask[r, c] = 255
|
||||
return mask
|
||||
|
||||
|
||||
def test_spot_removal_no_mask_returns_field_unchanged():
|
||||
"""Without a mask input the field should be returned as-is."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
field = make_field()
|
||||
result, = node.process(field, method="laplace", max_iter=50)
|
||||
# Should be the identical object (short-circuit path)
|
||||
assert result is field
|
||||
|
||||
|
||||
def test_spot_removal_zero_fill():
|
||||
"""method='zero' sets defect pixels to exactly 0."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.ones((16, 16)) * 5.0
|
||||
field = make_field(data=data)
|
||||
mask = _make_mask((16, 16), [(4, 4), (8, 8)])
|
||||
result, = node.process(field, method="zero", max_iter=1, mask=mask)
|
||||
assert result.data[4, 4] == pytest.approx(0.0)
|
||||
assert result.data[8, 8] == pytest.approx(0.0)
|
||||
# Non-defect pixels should stay 5.0
|
||||
assert result.data[0, 0] == pytest.approx(5.0)
|
||||
|
||||
|
||||
def test_spot_removal_mean_fill_surrounded_by_constant():
|
||||
"""On a constant field, mean fill should give back the constant."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.full((16, 16), 3.0)
|
||||
field = make_field(data=data)
|
||||
mask = _make_mask((16, 16), [(7, 7)])
|
||||
result, = node.process(field, method="mean", max_iter=1, mask=mask)
|
||||
assert result.data[7, 7] == pytest.approx(3.0, abs=1e-10)
|
||||
|
||||
|
||||
def test_spot_removal_laplace_fill_surrounded_by_constant():
|
||||
"""Laplace fill on a constant field should recover the constant at the defect."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.full((16, 16), 2.5)
|
||||
field = make_field(data=data)
|
||||
mask = _make_mask((16, 16), [(8, 8)])
|
||||
result, = node.process(field, method="laplace", max_iter=200, mask=mask)
|
||||
assert result.data[8, 8] == pytest.approx(2.5, abs=1e-3)
|
||||
|
||||
|
||||
def test_spot_removal_laplace_smooth_interpolation():
|
||||
"""Laplace should interpolate between boundary values smoothly."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
# Left half = 0, right half = 10; single defect in the middle
|
||||
data = np.zeros((16, 16))
|
||||
data[:, 8:] = 10.0
|
||||
field = make_field(data=data)
|
||||
# Defect at the boundary column
|
||||
mask = _make_mask((16, 16), [(8, 7)])
|
||||
result, = node.process(field, method="laplace", max_iter=500, mask=mask)
|
||||
# The filled value should be between 0 and 10
|
||||
filled = result.data[8, 7]
|
||||
assert 0.0 <= filled <= 10.0
|
||||
|
||||
|
||||
def test_spot_removal_shape_preserved():
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
field = make_field(shape=(48, 64))
|
||||
mask = _make_mask((48, 64), [(10, 20)])
|
||||
result, = node.process(field, method="mean", max_iter=10, mask=mask)
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_spot_removal_mask_shape_mismatch_raises():
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
field = make_field(shape=(16, 16))
|
||||
bad_mask = np.ones((32, 32), dtype=np.uint8)
|
||||
with pytest.raises(ValueError, match="Mask shape"):
|
||||
node.process(field, method="zero", max_iter=1, mask=bad_mask)
|
||||
|
||||
|
||||
def test_spot_removal_empty_mask_unchanged():
|
||||
"""An all-zero mask means no defects — field returned unchanged."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.random.default_rng(0).standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
mask = np.zeros((16, 16), dtype=np.uint8)
|
||||
result, = node.process(field, method="laplace", max_iter=50, mask=mask)
|
||||
assert result is field
|
||||
93
tests/node_tests/template_match.py
Normal file
93
tests/node_tests/template_match.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_template_match_exact_match_score_one():
|
||||
"""When template equals the image, the peak score should be 1."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
data = rng.standard_normal((32, 32))
|
||||
image_field = make_field(data=data)
|
||||
# Template is the full image → perfect correlation everywhere → peak = 1
|
||||
template_field = make_field(data=data)
|
||||
node = TemplateMatch()
|
||||
score_field, detections = node.process(image_field, template_field, threshold=0.9)
|
||||
assert score_field.data.max() == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_template_match_output_shape_matches_image():
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
score_field, detections = node.process(image_field, template_field, threshold=0.5)
|
||||
assert score_field.data.shape == image_field.data.shape
|
||||
assert detections.shape == image_field.data.shape
|
||||
|
||||
|
||||
def test_template_match_score_in_range():
|
||||
"""Score values should be clipped to [0, 1]."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((6, 6)))
|
||||
node = TemplateMatch()
|
||||
score_field, _ = node.process(image_field, template_field, threshold=0.5)
|
||||
assert score_field.data.min() >= 0.0 - 1e-10
|
||||
assert score_field.data.max() <= 1.0 + 1e-10
|
||||
|
||||
|
||||
def test_template_match_detections_binary():
|
||||
"""Detection mask values should be 0 or 255 only."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
_, detections = node.process(image_field, template_field, threshold=0.5)
|
||||
unique_values = set(np.unique(detections))
|
||||
assert unique_values <= {0, 255}
|
||||
|
||||
|
||||
def test_template_match_threshold_zero_all_detected():
|
||||
"""threshold=0 should mark all pixels as detections (score always >= 0)."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(4)
|
||||
image_field = make_field(data=rng.standard_normal((16, 16)))
|
||||
template_field = make_field(data=rng.standard_normal((4, 4)))
|
||||
node = TemplateMatch()
|
||||
_, detections = node.process(image_field, template_field, threshold=0.0)
|
||||
assert np.all(detections == 255)
|
||||
|
||||
|
||||
def test_template_match_threshold_one_sparse_detections():
|
||||
"""threshold=1.0 should detect very few (or no) positions."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(5)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
_, detections = node.process(image_field, template_field, threshold=1.0)
|
||||
# At threshold=1.0, only perfect matches count (rare for random data)
|
||||
detected_count = int((detections == 255).sum())
|
||||
assert detected_count < 10 # very few or none
|
||||
|
||||
|
||||
def test_template_match_preserves_metadata():
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(6)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
score_field, _ = node.process(image_field, template_field, threshold=0.5)
|
||||
assert score_field.xreal == image_field.xreal
|
||||
assert score_field.yreal == image_field.yreal
|
||||
85
tests/node_tests/wavelet_denoise.py
Normal file
85
tests/node_tests/wavelet_denoise.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_wavelet_denoise_shape_preserved():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field(shape=(64, 64))
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == (64, 64)
|
||||
|
||||
|
||||
def test_wavelet_denoise_reduces_noise():
|
||||
"""Denoising noisy data should reduce standard deviation."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
clean = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||
noisy = clean + rng.normal(0, 0.1, clean.shape)
|
||||
field = make_field(data=noisy)
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
# Denoised should be closer to clean than the noisy input
|
||||
denoised_err = np.std(result.data - clean)
|
||||
noisy_err = np.std(noisy - clean)
|
||||
assert denoised_err < noisy_err
|
||||
|
||||
|
||||
def test_wavelet_denoise_uniform_field_unchanged():
|
||||
"""A flat field (no variation) is returned as-is."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field(data=np.full((32, 32), 5.0))
|
||||
result, = node.process(field, wavelet="db1", method="VisuShrink", sigma=0.0, mode="hard")
|
||||
# The short-circuit path returns the original field object
|
||||
assert result is field
|
||||
|
||||
|
||||
def test_wavelet_denoise_preserves_range():
|
||||
"""Output values should stay within the input data range (approx)."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
# The normalisation ensures output is within [data.min(), data.max()]
|
||||
assert result.data.min() >= data.min() - 1e-10
|
||||
assert result.data.max() <= data.max() + 1e-10
|
||||
|
||||
|
||||
def test_wavelet_denoise_all_wavelets():
|
||||
"""All supported wavelets should run without error."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = WaveletDenoise()
|
||||
for wavelet in ("db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"):
|
||||
result, = node.process(field, wavelet=wavelet, method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_wavelet_denoise_visu_shrink():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="VisuShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_wavelet_denoise_preserves_metadata():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
Reference in New Issue
Block a user