remaining med value features
This commit is contained in:
85
tests/node_tests/wavelet_denoise.py
Normal file
85
tests/node_tests/wavelet_denoise.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_wavelet_denoise_shape_preserved():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field(shape=(64, 64))
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == (64, 64)
|
||||
|
||||
|
||||
def test_wavelet_denoise_reduces_noise():
|
||||
"""Denoising noisy data should reduce standard deviation."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
clean = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||
noisy = clean + rng.normal(0, 0.1, clean.shape)
|
||||
field = make_field(data=noisy)
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
# Denoised should be closer to clean than the noisy input
|
||||
denoised_err = np.std(result.data - clean)
|
||||
noisy_err = np.std(noisy - clean)
|
||||
assert denoised_err < noisy_err
|
||||
|
||||
|
||||
def test_wavelet_denoise_uniform_field_unchanged():
|
||||
"""A flat field (no variation) is returned as-is."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field(data=np.full((32, 32), 5.0))
|
||||
result, = node.process(field, wavelet="db1", method="VisuShrink", sigma=0.0, mode="hard")
|
||||
# The short-circuit path returns the original field object
|
||||
assert result is field
|
||||
|
||||
|
||||
def test_wavelet_denoise_preserves_range():
|
||||
"""Output values should stay within the input data range (approx)."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
# The normalisation ensures output is within [data.min(), data.max()]
|
||||
assert result.data.min() >= data.min() - 1e-10
|
||||
assert result.data.max() <= data.max() + 1e-10
|
||||
|
||||
|
||||
def test_wavelet_denoise_all_wavelets():
|
||||
"""All supported wavelets should run without error."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = WaveletDenoise()
|
||||
for wavelet in ("db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"):
|
||||
result, = node.process(field, wavelet=wavelet, method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_wavelet_denoise_visu_shrink():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="VisuShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_wavelet_denoise_preserves_metadata():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
Reference in New Issue
Block a user