add line correction and scar removal nodes
This commit is contained in:
@@ -571,6 +571,130 @@ def test_fix_zero():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_line_correction():
|
||||
print("=== Test: LineCorrection ===")
|
||||
from backend.node_registry import get_node_info
|
||||
from backend.nodes.line_correction import LineCorrection
|
||||
|
||||
node = LineCorrection()
|
||||
assert get_node_info("LineCorrection")["category"] == "Flatten"
|
||||
|
||||
rows = 96
|
||||
cols = 128
|
||||
y = np.linspace(0.0, 1.0, rows, dtype=np.float64)
|
||||
x = np.linspace(-1.0, 1.0, cols, dtype=np.float64)
|
||||
signal = (
|
||||
0.15 * np.sin(8.0 * np.pi * x)[None, :]
|
||||
+ 0.05 * np.cos(4.0 * np.pi * y)[:, None]
|
||||
)
|
||||
row_offsets = 1.5 * np.sin(3.0 * np.pi * y) + 0.25 * np.cos(7.0 * np.pi * y)
|
||||
field = make_field(
|
||||
data=signal + row_offsets[:, None],
|
||||
xreal=2.5e-6,
|
||||
yreal=1.5e-6,
|
||||
)
|
||||
|
||||
corrected, background, shifts = node.process(
|
||||
field,
|
||||
method="median",
|
||||
direction="horizontal",
|
||||
masking="ignore",
|
||||
trim_fraction=0.05,
|
||||
polynomial_degree=1,
|
||||
)
|
||||
expected_shifts = row_offsets - row_offsets.mean()
|
||||
assert corrected.data.shape == field.data.shape
|
||||
assert background.data.shape == field.data.shape
|
||||
assert np.allclose(corrected.data + background.data, field.data)
|
||||
assert isinstance(shifts, LineData)
|
||||
assert shifts.x_unit == field.si_unit_xy
|
||||
assert shifts.y_unit == field.si_unit_z
|
||||
assert np.isclose(shifts.x_axis[0], 0.0)
|
||||
assert np.isclose(shifts.x_axis[-1], field.yreal)
|
||||
assert np.corrcoef(shifts.data, expected_shifts)[0, 1] > 0.999
|
||||
assert corrected.data.mean(axis=1).std() < field.data.mean(axis=1).std() * 0.03
|
||||
|
||||
poly_background = (
|
||||
row_offsets[:, None]
|
||||
+ (0.35 * y - 0.15)[:, None] * x[None, :]
|
||||
+ (0.10 + 0.05 * y)[:, None] * (x[None, :] ** 2)
|
||||
)
|
||||
poly_signal = 0.08 * np.sin(10.0 * np.pi * x)[None, :] * (1.0 + 0.15 * np.cos(2.0 * np.pi * y)[:, None])
|
||||
poly_field = make_field(data=poly_signal + poly_background)
|
||||
|
||||
leveled, poly_bg, poly_shifts = node.process(
|
||||
poly_field,
|
||||
method="polynomial",
|
||||
direction="horizontal",
|
||||
masking="ignore",
|
||||
trim_fraction=0.05,
|
||||
polynomial_degree=2,
|
||||
)
|
||||
assert np.allclose(leveled.data + poly_bg.data, poly_field.data)
|
||||
assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995
|
||||
assert len(poly_shifts) == rows
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_scar_removal():
|
||||
print("=== Test: ScarRemoval ===")
|
||||
from backend.node_registry import get_node_info
|
||||
from backend.nodes.scar_removal import ScarRemoval
|
||||
|
||||
node = ScarRemoval()
|
||||
assert get_node_info("ScarRemoval")["category"] == "Filter"
|
||||
|
||||
rows = 96
|
||||
cols = 128
|
||||
yy, xx = np.mgrid[0:rows, 0:cols]
|
||||
base = (
|
||||
0.005 * xx
|
||||
+ 0.01 * yy
|
||||
+ 0.12 * np.sin(2.0 * np.pi * xx / cols)
|
||||
+ 0.07 * np.cos(2.0 * np.pi * yy / rows)
|
||||
)
|
||||
scarred = base.copy()
|
||||
scarred[24, 20:92] += 1.8
|
||||
scarred[25, 20:92] += 1.6
|
||||
scarred[60, 12:116] -= 1.7
|
||||
|
||||
field = make_field(data=scarred)
|
||||
corrected, scar_mask = node.process(
|
||||
field,
|
||||
scar_type="both",
|
||||
threshold_high=0.6,
|
||||
threshold_low=0.2,
|
||||
min_length=12,
|
||||
max_width=4,
|
||||
)
|
||||
|
||||
mask_bool = scar_mask > 127
|
||||
assert scar_mask.dtype == np.uint8
|
||||
assert scar_mask.shape == field.data.shape
|
||||
assert np.count_nonzero(mask_bool) > 0
|
||||
assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0
|
||||
assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0
|
||||
assert np.allclose(corrected.data[~mask_bool], field.data[~mask_bool])
|
||||
|
||||
before_rmse = np.sqrt(np.mean((field.data[mask_bool] - base[mask_bool]) ** 2))
|
||||
after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2))
|
||||
assert after_rmse < before_rmse * 0.35
|
||||
|
||||
clean_corrected, clean_mask = node.process(
|
||||
make_field(data=base),
|
||||
scar_type="both",
|
||||
threshold_high=0.6,
|
||||
threshold_low=0.2,
|
||||
min_length=12,
|
||||
max_width=4,
|
||||
)
|
||||
assert np.count_nonzero(clean_mask) == 0
|
||||
assert np.allclose(clean_corrected.data, base)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis (non-FFT)
|
||||
# =========================================================================
|
||||
@@ -2522,6 +2646,8 @@ if __name__ == "__main__":
|
||||
test_plane_level()
|
||||
test_poly_level()
|
||||
test_fix_zero()
|
||||
test_line_correction()
|
||||
test_scar_removal()
|
||||
|
||||
# Analysis
|
||||
test_statistics()
|
||||
|
||||
Reference in New Issue
Block a user