48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
import numpy as np
|
|
from backend.node_registry import get_node_info
|
|
from tests.node_tests._shared import make_field
|
|
|
|
|
|
def test_scar_removal():
    """Exercise the ScarRemoval node on a synthetic height field with injected scars.

    A smooth background (linear tilt plus low-frequency sine/cosine terms) gets
    two horizontal scars: a positive one spanning rows 24-25 and a negative one
    on row 60. The test checks that:

    * the node is registered under the "Level & Correct" category;
    * the returned mask is uint8, matches the input shape and flags both scars;
    * pixels outside the mask are passed through unchanged;
    * inside the mask, the RMSE against the clean background drops below 35%
      of its pre-correction value;
    * the input field's data is not modified in place;
    * a scar-free field yields an empty mask and an unchanged output.
    """
    from backend.nodes.scar_removal import ScarRemoval

    node = ScarRemoval()
    info = get_node_info("ScarRemoval")
    assert info["category"] == "Level & Correct"

    # One set of detection settings shared by the scarred and clean runs so the
    # two results are directly comparable.
    params = {
        "scar_type": "both",
        "threshold_high": 0.6,
        "threshold_low": 0.2,
        "min_length": 12,
        "max_width": 4,
    }

    rows, cols = 96, 128
    yy, xx = np.mgrid[0:rows, 0:cols]
    base = (
        0.005 * xx
        + 0.01 * yy
        + 0.12 * np.sin(2.0 * np.pi * xx / cols)
        + 0.07 * np.cos(2.0 * np.pi * yy / rows)
    )

    scarred = base.copy()
    scarred[24, 20:92] += 1.8  # positive scar, first of two adjacent rows
    scarred[25, 20:92] += 1.6  # positive scar, second row
    scarred[60, 12:116] -= 1.7  # single-row negative scar

    field = make_field(data=scarred)
    # Snapshot the input before processing: the reference values below must not
    # depend on whether the node (or make_field aliasing) touches field.data.
    original = np.array(field.data, copy=True)

    corrected, scar_mask = node.process(field, **params)

    # The node must not modify its input in place.
    assert np.array_equal(field.data, original)

    # Mask values above 127 are treated as "flagged".
    mask_bool = scar_mask > 127
    assert scar_mask.dtype == np.uint8
    assert scar_mask.shape == original.shape
    assert corrected.data.shape == original.shape
    assert np.count_nonzero(mask_bool) > 0
    assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0
    assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0
    # Outside the detected scars the data must pass through untouched.
    assert np.allclose(corrected.data[~mask_bool], original[~mask_bool])

    before_rmse = np.sqrt(np.mean((original[mask_bool] - base[mask_bool]) ** 2))
    after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2))
    assert after_rmse < before_rmse * 0.35

    # A scar-free field: nothing should be flagged and nothing should change.
    # Pass a copy so `base` stays a trustworthy reference for the comparison.
    clean_corrected, clean_mask = node.process(make_field(data=base.copy()), **params)
    assert np.count_nonzero(clean_mask) == 0
    assert np.allclose(clean_corrected.data, base)