remaining med value features

This commit is contained in:
2026-03-30 22:31:04 -07:00
parent ea749938bb
commit ced43bec4f
21 changed files with 4257 additions and 9 deletions

View File

@@ -0,0 +1,89 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_cross_correlate_same_field_peak_at_center():
"""Correlating a field with itself in 'same' mode peaks at the centre."""
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(0)
data = rng.standard_normal((32, 32))
field = make_field(data=data)
node = CrossCorrelate()
result, = node.process(field, field, mode="same", normalize=True)
peak_y, peak_x = np.unravel_index(np.argmax(result.data), result.data.shape)
cy, cx = result.data.shape[0] // 2, result.data.shape[1] // 2
# Peak should be within a few pixels of centre
assert abs(peak_y - cy) <= 2
assert abs(peak_x - cx) <= 2
def test_cross_correlate_same_mode_shape_equals_a():
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(1)
a = make_field(data=rng.standard_normal((32, 48)))
b = make_field(data=rng.standard_normal((32, 48)))
node = CrossCorrelate()
result, = node.process(a, b, mode="same", normalize=True)
assert result.data.shape == a.data.shape
def test_cross_correlate_full_mode_shape():
"""Full mode output shape should be Na+Nb-1 × Ma+Mb-1."""
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(2)
a = make_field(data=rng.standard_normal((20, 30)))
b = make_field(data=rng.standard_normal((20, 30)))
node = CrossCorrelate()
result, = node.process(a, b, mode="full", normalize=True)
assert result.data.shape == (20 + 20 - 1, 30 + 30 - 1)
def test_cross_correlate_normalized_peak_is_one():
"""Self-correlation normalised should give peak = 1."""
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(3)
data = rng.standard_normal((32, 32))
field = make_field(data=data)
node = CrossCorrelate()
result, = node.process(field, field, mode="same", normalize=True)
assert result.data.max() == pytest.approx(1.0, abs=1e-6)
def test_cross_correlate_unnormalized_runs():
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(4)
data = rng.standard_normal((16, 16))
field = make_field(data=data)
node = CrossCorrelate()
result, = node.process(field, field, mode="same", normalize=False)
assert result.data.shape == (16, 16)
def test_cross_correlate_valid_mode():
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(5)
a = make_field(data=rng.standard_normal((16, 16)))
b = make_field(data=rng.standard_normal((8, 8)))
node = CrossCorrelate()
result, = node.process(a, b, mode="valid", normalize=True)
# Valid mode output: (16-8+1, 16-8+1) = (9, 9)
assert result.data.shape == (9, 9)
def test_cross_correlate_preserves_metadata_same_mode():
from backend.nodes.cross_correlate import CrossCorrelate
rng = np.random.default_rng(6)
field = make_field(data=rng.standard_normal((16, 16)))
node = CrossCorrelate()
result, = node.process(field, field, mode="same", normalize=True)
assert result.xreal == field.xreal
assert result.yreal == field.yreal

View File

@@ -0,0 +1,75 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_entropy_uniform_field_low():
"""A field with a single unique value has zero entropy."""
from backend.nodes.entropy import Entropy
node = Entropy()
field = make_field(data=np.full((32, 32), 3.14))
h, h_norm = node.process(field, mode="height values", n_bins=256)
# All values fall in one bin → p=1 → H = 0
assert h == pytest.approx(0.0, abs=1e-10)
assert h_norm == pytest.approx(0.0, abs=1e-10)
def test_entropy_random_field_positive():
"""A random field should have positive entropy."""
from backend.nodes.entropy import Entropy
rng = np.random.default_rng(0)
field = make_field(data=rng.standard_normal((64, 64)))
node = Entropy()
h, h_norm = node.process(field, mode="height values", n_bins=256)
assert h > 0.0
assert 0.0 < h_norm <= 1.0
def test_entropy_normalised_leq_one():
"""Normalised entropy should never exceed 1."""
from backend.nodes.entropy import Entropy
rng = np.random.default_rng(2)
field = make_field(data=rng.uniform(0, 1, (64, 64)))
node = Entropy()
_, h_norm = node.process(field, mode="height values", n_bins=64)
assert h_norm <= 1.0 + 1e-12
def test_entropy_slope_mode():
"""Slope mode should work and return valid entropy values."""
from backend.nodes.entropy import Entropy
rng = np.random.default_rng(3)
field = make_field(data=rng.standard_normal((32, 32)))
node = Entropy()
h, h_norm = node.process(field, mode="slope magnitude", n_bins=128)
assert h > 0.0
assert 0.0 <= h_norm <= 1.0
def test_entropy_more_uniform_is_higher():
"""Uniformly distributed values have higher entropy than a spiked distribution."""
from backend.nodes.entropy import Entropy
rng = np.random.default_rng(4)
uniform = rng.uniform(0, 1, (64, 64))
spiked = np.zeros((64, 64))
spiked[0, 0] = 1.0
node = Entropy()
h_uniform, _ = node.process(make_field(data=uniform), mode="height values", n_bins=64)
h_spiked, _ = node.process(make_field(data=spiked), mode="height values", n_bins=64)
assert h_uniform > h_spiked
def test_entropy_returns_floats():
from backend.nodes.entropy import Entropy
field = make_field()
node = Entropy()
h, h_norm = node.process(field, mode="height values", n_bins=256)
assert isinstance(h, float)
assert isinstance(h_norm, float)

View File

@@ -0,0 +1,95 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_custom_convolution_identity_kernel():
"""[[1]] with normalize=True (abs_sum=1) should return input unchanged."""
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
data = np.random.default_rng(0).standard_normal((32, 32))
field = make_field(data=data)
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
assert np.allclose(result.data, data)
def test_custom_convolution_uniform_kernel_normalized():
"""An all-ones kernel with normalize=True is a box filter (mean filter)."""
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
data = np.random.default_rng(1).standard_normal((32, 32))
field = make_field(data=data)
# 3x3 all-ones kernel, normalized → each pixel becomes mean of its neighbourhood
kernel = "1 1 1\n1 1 1\n1 1 1"
result, = node.process(field, kernel=kernel, normalize=True, boundary="reflect")
# Output std should be less than input std (smoothing)
assert result.data.std() < data.std()
def test_custom_convolution_sharpen_increases_variation():
"""A sharpening kernel should increase local variation on a smooth field."""
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
# Smooth ramp field — very low frequency content
data = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
field = make_field(data=data)
sharpen = "0 -1 0\n-1 5 -1\n0 -1 0"
result, = node.process(field, kernel=sharpen, normalize=False, boundary="reflect")
# Sharpening without normalisation keeps the ramp intact plus adds edges
# The std of the sharpened field should differ from input
assert result.data.std() != pytest.approx(data.std(), rel=0.0)
def test_custom_convolution_shape_preserved():
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
field = make_field(shape=(48, 64))
result, = node.process(field, kernel="0 1 0\n1 1 1\n0 1 0", normalize=True, boundary="reflect")
assert result.data.shape == (48, 64)
def test_custom_convolution_invalid_kernel_fallback():
"""An invalid kernel string should return the input field unchanged."""
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
data = np.random.default_rng(2).standard_normal((16, 16))
field = make_field(data=data)
result, = node.process(field, kernel="", normalize=True, boundary="reflect")
assert np.allclose(result.data, data)
def test_custom_convolution_ragged_kernel_fallback():
"""A ragged (non-rectangular) kernel should be rejected gracefully."""
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
data = np.random.default_rng(3).standard_normal((16, 16))
field = make_field(data=data)
result, = node.process(field, kernel="1 2\n1 2 3", normalize=True, boundary="reflect")
assert np.allclose(result.data, data)
def test_custom_convolution_boundary_modes():
"""All boundary modes should produce valid output without error."""
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
field = make_field()
for mode in ("reflect", "nearest", "wrap"):
result, = node.process(field, kernel="1 1 1\n1 1 1\n1 1 1", normalize=True, boundary=mode)
assert result.data.shape == field.data.shape
def test_custom_convolution_preserves_metadata():
from backend.nodes.filter_custom import CustomConvolution
node = CustomConvolution()
field = make_field()
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
assert result.xreal == field.xreal
assert result.si_unit_z == field.si_unit_z

View File

@@ -0,0 +1,75 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_kuwahara_shape_preserved():
from backend.nodes.filter_kuwahara import KuwaharaFilter
node = KuwaharaFilter()
field = make_field(shape=(48, 64))
result, = node.process(field, iterations=1)
assert result.data.shape == (48, 64)
def test_kuwahara_flat_field_unchanged():
"""A constant field should pass through the Kuwahara filter unchanged."""
from backend.nodes.filter_kuwahara import KuwaharaFilter
node = KuwaharaFilter()
field = make_field(data=np.full((32, 32), 7.5))
result, = node.process(field, iterations=1)
assert np.allclose(result.data, 7.5)
def test_kuwahara_reduces_noise():
"""Applying the filter to a noisy field should reduce standard deviation."""
from backend.nodes.filter_kuwahara import KuwaharaFilter
rng = np.random.default_rng(0)
noisy = rng.standard_normal((64, 64))
node = KuwaharaFilter()
field = make_field(data=noisy)
result, = node.process(field, iterations=1)
assert result.data.std() < noisy.std()
def test_kuwahara_preserves_step_edge():
"""The Kuwahara filter should preserve a sharp step edge better than a blur."""
from backend.nodes.filter_kuwahara import KuwaharaFilter
# Left half = 0, right half = 1
data = np.zeros((32, 64))
data[:, 32:] = 1.0
node = KuwaharaFilter()
field = make_field(data=data)
result, = node.process(field, iterations=1)
# The edge column should have a large jump (edge preserved)
col_before = result.data[:, 30].mean()
col_after = result.data[:, 34].mean()
assert col_after - col_before > 0.5
def test_kuwahara_multiple_iterations():
"""Running multiple iterations should further reduce noise."""
from backend.nodes.filter_kuwahara import KuwaharaFilter
rng = np.random.default_rng(1)
noisy = rng.standard_normal((32, 32))
node = KuwaharaFilter()
field = make_field(data=noisy)
result1, = node.process(field, iterations=1)
result3, = node.process(field, iterations=3)
assert result3.data.std() <= result1.data.std()
def test_kuwahara_preserves_metadata():
from backend.nodes.filter_kuwahara import KuwaharaFilter
node = KuwaharaFilter()
field = make_field()
result, = node.process(field, iterations=1)
assert result.xreal == field.xreal
assert result.yreal == field.yreal
assert result.si_unit_z == field.si_unit_z

View File

@@ -0,0 +1,67 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_local_contrast_shape_preserved():
from backend.nodes.local_contrast import LocalContrast
node = LocalContrast()
field = make_field(shape=(48, 64))
result, = node.process(field, kernel_size=10, weight=0.5)
assert result.data.shape == (48, 64)
def test_local_contrast_weight_zero_unchanged():
"""weight=0 blends 100% original → result equals input."""
from backend.nodes.local_contrast import LocalContrast
node = LocalContrast()
data = np.random.default_rng(0).standard_normal((32, 32))
field = make_field(data=data)
result, = node.process(field, kernel_size=5, weight=0.0)
assert np.allclose(result.data, data)
def test_local_contrast_uniform_field_unchanged():
"""A flat field has nothing to enhance; it should be returned as-is."""
from backend.nodes.local_contrast import LocalContrast
node = LocalContrast()
field = make_field(data=np.full((32, 32), 2.0))
result, = node.process(field, kernel_size=5, weight=1.0)
assert np.allclose(result.data, 2.0)
def test_local_contrast_increases_dynamic_range():
"""Weight=1 full enhancement should not compress global range beyond input."""
from backend.nodes.local_contrast import LocalContrast
rng = np.random.default_rng(1)
data = rng.standard_normal((64, 64))
field = make_field(data=data)
node = LocalContrast()
result, = node.process(field, kernel_size=8, weight=1.0)
# Global min/max should be preserved (by construction of the algorithm)
assert np.isclose(result.data.min(), data.min(), atol=1e-6)
assert np.isclose(result.data.max(), data.max(), atol=1e-6)
def test_local_contrast_preserves_metadata():
from backend.nodes.local_contrast import LocalContrast
node = LocalContrast()
field = make_field()
result, = node.process(field, kernel_size=10, weight=0.5)
assert result.xreal == field.xreal
assert result.si_unit_z == field.si_unit_z
def test_local_contrast_weight_clipped():
"""Values outside [0,1] should be clipped without error."""
from backend.nodes.local_contrast import LocalContrast
node = LocalContrast()
field = make_field()
result, = node.process(field, kernel_size=5, weight=2.0)
assert result.data.shape == field.data.shape

View File

@@ -156,10 +156,16 @@ def test_save_generic():
except ValueError:
pass
# LINE as plot image (PNG / TIFF)
node.save(filename="line_plot_png", directory_path=tmpdir, format="PNG", value=line)
assert Path(tmpdir, "line_plot_png.png").exists()
node.save(filename="line_plot_tiff", directory_path=tmpdir, format="TIFF", value=line)
assert Path(tmpdir, "line_plot_tiff.tiff").exists()
# Unsupported LINE format
try:
node.save(filename="line_bad", directory_path=tmpdir, format="TIFF", value=line)
assert False, "Expected ValueError for LINE + TIFF"
node.save(filename="line_bad", directory_path=tmpdir, format="OBJ", value=line)
assert False, "Expected ValueError for LINE + OBJ"
except ValueError:
pass

View File

@@ -0,0 +1,110 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def _make_mask(shape, defect_positions):
"""Create a uint8 mask array with 255 at given (row, col) positions."""
mask = np.zeros(shape, dtype=np.uint8)
for r, c in defect_positions:
mask[r, c] = 255
return mask
def test_spot_removal_no_mask_returns_field_unchanged():
"""Without a mask input the field should be returned as-is."""
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
field = make_field()
result, = node.process(field, method="laplace", max_iter=50)
# Should be the identical object (short-circuit path)
assert result is field
def test_spot_removal_zero_fill():
"""method='zero' sets defect pixels to exactly 0."""
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
data = np.ones((16, 16)) * 5.0
field = make_field(data=data)
mask = _make_mask((16, 16), [(4, 4), (8, 8)])
result, = node.process(field, method="zero", max_iter=1, mask=mask)
assert result.data[4, 4] == pytest.approx(0.0)
assert result.data[8, 8] == pytest.approx(0.0)
# Non-defect pixels should stay 5.0
assert result.data[0, 0] == pytest.approx(5.0)
def test_spot_removal_mean_fill_surrounded_by_constant():
"""On a constant field, mean fill should give back the constant."""
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
data = np.full((16, 16), 3.0)
field = make_field(data=data)
mask = _make_mask((16, 16), [(7, 7)])
result, = node.process(field, method="mean", max_iter=1, mask=mask)
assert result.data[7, 7] == pytest.approx(3.0, abs=1e-10)
def test_spot_removal_laplace_fill_surrounded_by_constant():
"""Laplace fill on a constant field should recover the constant at the defect."""
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
data = np.full((16, 16), 2.5)
field = make_field(data=data)
mask = _make_mask((16, 16), [(8, 8)])
result, = node.process(field, method="laplace", max_iter=200, mask=mask)
assert result.data[8, 8] == pytest.approx(2.5, abs=1e-3)
def test_spot_removal_laplace_smooth_interpolation():
"""Laplace should interpolate between boundary values smoothly."""
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
# Left half = 0, right half = 10; single defect in the middle
data = np.zeros((16, 16))
data[:, 8:] = 10.0
field = make_field(data=data)
# Defect at the boundary column
mask = _make_mask((16, 16), [(8, 7)])
result, = node.process(field, method="laplace", max_iter=500, mask=mask)
# The filled value should be between 0 and 10
filled = result.data[8, 7]
assert 0.0 <= filled <= 10.0
def test_spot_removal_shape_preserved():
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
field = make_field(shape=(48, 64))
mask = _make_mask((48, 64), [(10, 20)])
result, = node.process(field, method="mean", max_iter=10, mask=mask)
assert result.data.shape == (48, 64)
def test_spot_removal_mask_shape_mismatch_raises():
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
field = make_field(shape=(16, 16))
bad_mask = np.ones((32, 32), dtype=np.uint8)
with pytest.raises(ValueError, match="Mask shape"):
node.process(field, method="zero", max_iter=1, mask=bad_mask)
def test_spot_removal_empty_mask_unchanged():
"""An all-zero mask means no defects — field returned unchanged."""
from backend.nodes.spot_removal import SpotRemoval
node = SpotRemoval()
data = np.random.default_rng(0).standard_normal((16, 16))
field = make_field(data=data)
mask = np.zeros((16, 16), dtype=np.uint8)
result, = node.process(field, method="laplace", max_iter=50, mask=mask)
assert result is field

View File

@@ -0,0 +1,93 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_template_match_exact_match_score_one():
"""When template equals the image, the peak score should be 1."""
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(0)
data = rng.standard_normal((32, 32))
image_field = make_field(data=data)
# Template is the full image → perfect correlation everywhere → peak = 1
template_field = make_field(data=data)
node = TemplateMatch()
score_field, detections = node.process(image_field, template_field, threshold=0.9)
assert score_field.data.max() == pytest.approx(1.0, abs=1e-6)
def test_template_match_output_shape_matches_image():
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(1)
image_field = make_field(data=rng.standard_normal((32, 32)))
template_field = make_field(data=rng.standard_normal((8, 8)))
node = TemplateMatch()
score_field, detections = node.process(image_field, template_field, threshold=0.5)
assert score_field.data.shape == image_field.data.shape
assert detections.shape == image_field.data.shape
def test_template_match_score_in_range():
"""Score values should be clipped to [0, 1]."""
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(2)
image_field = make_field(data=rng.standard_normal((32, 32)))
template_field = make_field(data=rng.standard_normal((6, 6)))
node = TemplateMatch()
score_field, _ = node.process(image_field, template_field, threshold=0.5)
assert score_field.data.min() >= 0.0 - 1e-10
assert score_field.data.max() <= 1.0 + 1e-10
def test_template_match_detections_binary():
"""Detection mask values should be 0 or 255 only."""
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(3)
image_field = make_field(data=rng.standard_normal((32, 32)))
template_field = make_field(data=rng.standard_normal((8, 8)))
node = TemplateMatch()
_, detections = node.process(image_field, template_field, threshold=0.5)
unique_values = set(np.unique(detections))
assert unique_values <= {0, 255}
def test_template_match_threshold_zero_all_detected():
"""threshold=0 should mark all pixels as detections (score always >= 0)."""
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(4)
image_field = make_field(data=rng.standard_normal((16, 16)))
template_field = make_field(data=rng.standard_normal((4, 4)))
node = TemplateMatch()
_, detections = node.process(image_field, template_field, threshold=0.0)
assert np.all(detections == 255)
def test_template_match_threshold_one_sparse_detections():
"""threshold=1.0 should detect very few (or no) positions."""
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(5)
image_field = make_field(data=rng.standard_normal((32, 32)))
template_field = make_field(data=rng.standard_normal((8, 8)))
node = TemplateMatch()
_, detections = node.process(image_field, template_field, threshold=1.0)
# At threshold=1.0, only perfect matches count (rare for random data)
detected_count = int((detections == 255).sum())
assert detected_count < 10 # very few or none
def test_template_match_preserves_metadata():
from backend.nodes.template_match import TemplateMatch
rng = np.random.default_rng(6)
image_field = make_field(data=rng.standard_normal((32, 32)))
template_field = make_field(data=rng.standard_normal((8, 8)))
node = TemplateMatch()
score_field, _ = node.process(image_field, template_field, threshold=0.5)
assert score_field.xreal == image_field.xreal
assert score_field.yreal == image_field.yreal

View File

@@ -0,0 +1,85 @@
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_wavelet_denoise_shape_preserved():
from backend.nodes.wavelet_denoise import WaveletDenoise
node = WaveletDenoise()
field = make_field(shape=(64, 64))
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
assert result.data.shape == (64, 64)
def test_wavelet_denoise_reduces_noise():
"""Denoising noisy data should reduce standard deviation."""
from backend.nodes.wavelet_denoise import WaveletDenoise
rng = np.random.default_rng(0)
clean = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
noisy = clean + rng.normal(0, 0.1, clean.shape)
field = make_field(data=noisy)
node = WaveletDenoise()
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
# Denoised should be closer to clean than the noisy input
denoised_err = np.std(result.data - clean)
noisy_err = np.std(noisy - clean)
assert denoised_err < noisy_err
def test_wavelet_denoise_uniform_field_unchanged():
"""A flat field (no variation) is returned as-is."""
from backend.nodes.wavelet_denoise import WaveletDenoise
node = WaveletDenoise()
field = make_field(data=np.full((32, 32), 5.0))
result, = node.process(field, wavelet="db1", method="VisuShrink", sigma=0.0, mode="hard")
# The short-circuit path returns the original field object
assert result is field
def test_wavelet_denoise_preserves_range():
"""Output values should stay within the input data range (approx)."""
from backend.nodes.wavelet_denoise import WaveletDenoise
rng = np.random.default_rng(1)
data = rng.standard_normal((32, 32))
field = make_field(data=data)
node = WaveletDenoise()
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
# The normalisation ensures output is within [data.min(), data.max()]
assert result.data.min() >= data.min() - 1e-10
assert result.data.max() <= data.max() + 1e-10
def test_wavelet_denoise_all_wavelets():
"""All supported wavelets should run without error."""
from backend.nodes.wavelet_denoise import WaveletDenoise
rng = np.random.default_rng(2)
field = make_field(data=rng.standard_normal((32, 32)))
node = WaveletDenoise()
for wavelet in ("db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"):
result, = node.process(field, wavelet=wavelet, method="BayesShrink", sigma=0.0, mode="soft")
assert result.data.shape == field.data.shape
def test_wavelet_denoise_visu_shrink():
from backend.nodes.wavelet_denoise import WaveletDenoise
rng = np.random.default_rng(3)
field = make_field(data=rng.standard_normal((32, 32)))
node = WaveletDenoise()
result, = node.process(field, wavelet="db4", method="VisuShrink", sigma=0.0, mode="soft")
assert result.data.shape == field.data.shape
def test_wavelet_denoise_preserves_metadata():
from backend.nodes.wavelet_denoise import WaveletDenoise
node = WaveletDenoise()
field = make_field()
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
assert result.xreal == field.xreal
assert result.si_unit_z == field.si_unit_z