add line correction and scar removal nodes

This commit is contained in:
2026-03-27 21:57:43 -07:00
parent fa9aeaa4a9
commit e10a30c08f
6 changed files with 742 additions and 2 deletions

View File

@@ -571,6 +571,130 @@ def test_fix_zero():
print(" PASS\n")
def test_line_correction():
print("=== Test: LineCorrection ===")
from backend.node_registry import get_node_info
from backend.nodes.line_correction import LineCorrection
node = LineCorrection()
assert get_node_info("LineCorrection")["category"] == "Flatten"
rows = 96
cols = 128
y = np.linspace(0.0, 1.0, rows, dtype=np.float64)
x = np.linspace(-1.0, 1.0, cols, dtype=np.float64)
signal = (
0.15 * np.sin(8.0 * np.pi * x)[None, :]
+ 0.05 * np.cos(4.0 * np.pi * y)[:, None]
)
row_offsets = 1.5 * np.sin(3.0 * np.pi * y) + 0.25 * np.cos(7.0 * np.pi * y)
field = make_field(
data=signal + row_offsets[:, None],
xreal=2.5e-6,
yreal=1.5e-6,
)
corrected, background, shifts = node.process(
field,
method="median",
direction="horizontal",
masking="ignore",
trim_fraction=0.05,
polynomial_degree=1,
)
expected_shifts = row_offsets - row_offsets.mean()
assert corrected.data.shape == field.data.shape
assert background.data.shape == field.data.shape
assert np.allclose(corrected.data + background.data, field.data)
assert isinstance(shifts, LineData)
assert shifts.x_unit == field.si_unit_xy
assert shifts.y_unit == field.si_unit_z
assert np.isclose(shifts.x_axis[0], 0.0)
assert np.isclose(shifts.x_axis[-1], field.yreal)
assert np.corrcoef(shifts.data, expected_shifts)[0, 1] > 0.999
assert corrected.data.mean(axis=1).std() < field.data.mean(axis=1).std() * 0.03
poly_background = (
row_offsets[:, None]
+ (0.35 * y - 0.15)[:, None] * x[None, :]
+ (0.10 + 0.05 * y)[:, None] * (x[None, :] ** 2)
)
poly_signal = 0.08 * np.sin(10.0 * np.pi * x)[None, :] * (1.0 + 0.15 * np.cos(2.0 * np.pi * y)[:, None])
poly_field = make_field(data=poly_signal + poly_background)
leveled, poly_bg, poly_shifts = node.process(
poly_field,
method="polynomial",
direction="horizontal",
masking="ignore",
trim_fraction=0.05,
polynomial_degree=2,
)
assert np.allclose(leveled.data + poly_bg.data, poly_field.data)
assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995
assert len(poly_shifts) == rows
print(" PASS\n")
def test_scar_removal():
print("=== Test: ScarRemoval ===")
from backend.node_registry import get_node_info
from backend.nodes.scar_removal import ScarRemoval
node = ScarRemoval()
assert get_node_info("ScarRemoval")["category"] == "Filter"
rows = 96
cols = 128
yy, xx = np.mgrid[0:rows, 0:cols]
base = (
0.005 * xx
+ 0.01 * yy
+ 0.12 * np.sin(2.0 * np.pi * xx / cols)
+ 0.07 * np.cos(2.0 * np.pi * yy / rows)
)
scarred = base.copy()
scarred[24, 20:92] += 1.8
scarred[25, 20:92] += 1.6
scarred[60, 12:116] -= 1.7
field = make_field(data=scarred)
corrected, scar_mask = node.process(
field,
scar_type="both",
threshold_high=0.6,
threshold_low=0.2,
min_length=12,
max_width=4,
)
mask_bool = scar_mask > 127
assert scar_mask.dtype == np.uint8
assert scar_mask.shape == field.data.shape
assert np.count_nonzero(mask_bool) > 0
assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0
assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0
assert np.allclose(corrected.data[~mask_bool], field.data[~mask_bool])
before_rmse = np.sqrt(np.mean((field.data[mask_bool] - base[mask_bool]) ** 2))
after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2))
assert after_rmse < before_rmse * 0.35
clean_corrected, clean_mask = node.process(
make_field(data=base),
scar_type="both",
threshold_high=0.6,
threshold_low=0.2,
min_length=12,
max_width=4,
)
assert np.count_nonzero(clean_mask) == 0
assert np.allclose(clean_corrected.data, base)
print(" PASS\n")
# =========================================================================
# Analysis (non-FFT)
# =========================================================================
@@ -2522,6 +2646,8 @@ if __name__ == "__main__":
test_plane_level()
test_poly_level()
test_fix_zero()
test_line_correction()
test_scar_removal()
# Analysis
test_statistics()