combine fft filter into a single node, fix tests
This commit is contained in:
63
tests/node_tests/filter_fft.py
Normal file
63
tests/node_tests/filter_fft.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_fft_filter_line():
|
||||
from backend.nodes.filter_fft import FFTFilter
|
||||
node = FFTFilter()
|
||||
|
||||
n = 256
|
||||
t = np.arange(n, dtype=np.float64) / n
|
||||
low = np.sin(2 * np.pi * 3 * t)
|
||||
high = np.sin(2 * np.pi * 80 * t)
|
||||
line = low + high
|
||||
|
||||
filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||
assert len(filtered_lp) == n
|
||||
corr_low = np.corrcoef(filtered_lp, low)[0, 1]
|
||||
corr_high = np.corrcoef(filtered_lp, high)[0, 1]
|
||||
assert corr_low > 0.95
|
||||
assert abs(corr_high) < 0.3
|
||||
|
||||
filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||
assert abs(np.corrcoef(filtered_hp, low)[0, 1]) < 0.3
|
||||
assert np.corrcoef(filtered_hp, high)[0, 1] > 0.95
|
||||
|
||||
filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||
assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3
|
||||
assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9
|
||||
|
||||
filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||
assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95
|
||||
assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3
|
||||
|
||||
|
||||
def test_fft_filter_field():
|
||||
from backend.nodes.filter_fft import FFTFilter
|
||||
from backend.data_types import DataField
|
||||
node = FFTFilter()
|
||||
|
||||
N = 128
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
|
||||
high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
|
||||
data = low_2d + high_2d
|
||||
field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||
assert isinstance(result_lp, DataField)
|
||||
assert result_lp.data.shape == (N, N)
|
||||
assert result_lp.xreal == field.xreal
|
||||
assert result_lp.si_unit_z == field.si_unit_z
|
||||
corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
|
||||
corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
|
||||
assert corr_low > 0.9
|
||||
assert abs(corr_high) < 0.3
|
||||
|
||||
result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||
assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3
|
||||
assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9
|
||||
|
||||
const = make_field(data=np.ones((32, 32)) * 7.0)
|
||||
result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
|
||||
assert np.allclose(result_const.data, 7.0, atol=1e-10)
|
||||
@@ -1,33 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_fft_filter_1d():
|
||||
from backend.nodes.filter_fft_1d import FFTFilter1D
|
||||
node = FFTFilter1D()
|
||||
|
||||
n = 256
|
||||
t = np.arange(n, dtype=np.float64) / n
|
||||
low = np.sin(2 * np.pi * 3 * t)
|
||||
high = np.sin(2 * np.pi * 80 * t)
|
||||
line = low + high
|
||||
|
||||
filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||
assert len(filtered_lp) == n
|
||||
corr_low = np.corrcoef(filtered_lp, low)[0, 1]
|
||||
corr_high = np.corrcoef(filtered_lp, high)[0, 1]
|
||||
assert corr_low > 0.95
|
||||
assert abs(corr_high) < 0.3
|
||||
|
||||
filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||
corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1]
|
||||
corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1]
|
||||
assert abs(corr_low_hp) < 0.3
|
||||
assert corr_high_hp > 0.95
|
||||
|
||||
filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||
assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3
|
||||
assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9
|
||||
|
||||
filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||
assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95
|
||||
assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3
|
||||
@@ -1,31 +0,0 @@
|
||||
import numpy as np
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_fft_filter_2d():
|
||||
from backend.nodes.filter_fft_2d import FFTFilter2D
|
||||
node = FFTFilter2D()
|
||||
|
||||
N = 128
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
|
||||
high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
|
||||
data = low_2d + high_2d
|
||||
field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||
assert result_lp.data.shape == (N, N)
|
||||
assert result_lp.xreal == field.xreal
|
||||
assert result_lp.si_unit_z == field.si_unit_z
|
||||
corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
|
||||
corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
|
||||
assert corr_low > 0.9
|
||||
assert abs(corr_high) < 0.3
|
||||
|
||||
result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||
assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3
|
||||
assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9
|
||||
|
||||
const = make_field(data=np.ones((32, 32)) * 7.0)
|
||||
result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
|
||||
assert np.allclose(result_const.data, 7.0, atol=1e-10)
|
||||
@@ -36,7 +36,7 @@ def test_threshold_otsu_bimodal():
|
||||
data[70:100, 80:110] = 10.0 # another bright region
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
mask, table = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
bright_pixels = (mask == 255)
|
||||
# Should capture both bright regions
|
||||
assert bright_pixels[40, 40], "Otsu missed bright region 1"
|
||||
@@ -57,7 +57,7 @@ def test_threshold_relative_range():
|
||||
data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||
mask, table = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||
# Only the bright patch (value 8 >= 5) should be masked
|
||||
assert np.all(mask[10:20, 10:20] == 255)
|
||||
assert np.all(mask[0:10, :] == 0)
|
||||
@@ -74,7 +74,7 @@ def test_threshold_empty_mask():
|
||||
data = np.ones((64, 64))
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="absolute", threshold=999.0, direction="above")
|
||||
mask, table = node.process(field, method="absolute", threshold=999.0, direction="above")
|
||||
assert mask.sum() == 0, "Mask should be completely empty"
|
||||
print(" PASS\n")
|
||||
|
||||
@@ -88,7 +88,7 @@ def test_threshold_full_mask():
|
||||
data = np.ones((64, 64)) * 5.0
|
||||
field = make_field(data)
|
||||
|
||||
mask, = node.process(field, method="absolute", threshold=-1.0, direction="above")
|
||||
mask, table = node.process(field, method="absolute", threshold=-1.0, direction="above")
|
||||
assert np.all(mask == 255), "Mask should be all white"
|
||||
print(" PASS\n")
|
||||
|
||||
@@ -345,7 +345,7 @@ def test_pipeline_synthetic():
|
||||
|
||||
# Step 1: threshold
|
||||
thresh = ThresholdMask()
|
||||
mask, = thresh.process(field, method="absolute", threshold=1.0, direction="above")
|
||||
mask, table = thresh.process(field, method="absolute", threshold=1.0, direction="above")
|
||||
|
||||
# Grains are well above noise, so mask should capture all 5
|
||||
assert mask.max() == 255, "No grains detected"
|
||||
@@ -387,7 +387,7 @@ def test_pipeline_demo_image():
|
||||
|
||||
# Threshold to find grains (they are raised above background)
|
||||
thresh = ThresholdMask()
|
||||
mask, = thresh.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
mask, table = thresh.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
|
||||
# Should detect grains
|
||||
assert mask.max() == 255, "No grains found in demo image"
|
||||
|
||||
3377
tests/test_nodes.py
3377
tests/test_nodes.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user