add line correction and scar removal nodes

This commit is contained in:
2026-03-27 21:57:43 -07:00
parent fa9aeaa4a9
commit e10a30c08f
6 changed files with 742 additions and 2 deletions

View File

@@ -8,8 +8,8 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
| # | Feature | Gwyddion Source | Description |
|---|---------|---------------|-------------|
| 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. |
| 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). |
| ~~1~~ | ~~Line Correction~~ | ~~linecorrect.c, linematch.c~~ | ~~Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts.~~ **DONE** |
| ~~2~~ | ~~Scar Removal~~ | ~~scars.c~~ | ~~Detect and interpolate scan-line defects (horizontal streaks).~~ **DONE** |
| 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. |
| ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** |
| ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** |
@@ -73,11 +73,13 @@ For reference, these Gwyddion equivalents are already covered:
| Plane Level | level | level.c |
| Polynomial Level | level | polylevel.c |
| Fix Zero | level | level.c (fix_zero) |
| Line Correction | level | linecorrect.c, linematch.c |
| Gaussian Filter | filters | filters.c (gaussian) |
| Median Filter | filters | filters.c (median) |
| Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) |
| 1D FFT Filter | filters | fft_filter_1d.c (lowpass, highpass, bandpass, notch) |
| 2D FFT Filter | filters | fft_filter_2d.c (lowpass, highpass, bandpass, notch) |
| Scar Removal | filters | scars.c |
| Statistics | analysis | stats.c |
| Height Histogram | analysis | linestats.c (dh) |
| 2D FFT | analysis | fft.c |

View File

@@ -47,6 +47,7 @@ MENU_LAYOUT: dict[str, list[str]] = {
"EdgeDetect",
"FFTFilter1D",
"FFTFilter2D",
"ScarRemoval",
],
"Frequency": [
"FFT2D",
@@ -56,6 +57,7 @@ MENU_LAYOUT: dict[str, list[str]] = {
"PlaneLevelField",
"PolyLevelField",
"FixZero",
"LineCorrection",
],
"Measure": [
"CrossSection",

View File

@@ -25,12 +25,15 @@ from backend.nodes import (
plane_level_field,
poly_level_field,
fix_zero,
line_correction,
# Mask
draw_mask,
threshold_mask,
mask_morphology,
mask_invert,
mask_combine,
# Correction
scar_removal,
# Display
color_map,
font_node,

View File

@@ -0,0 +1,388 @@
from __future__ import annotations
import numpy as np
from backend.data_types import DataField, LineData
from backend.node_registry import register_node
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
if mask is None:
return None
mask_array = np.asarray(mask)
if mask_array.shape[:2] != shape:
raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
return mask_array > 127
def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float:
values = np.asarray(values, dtype=np.float64)
if values.size == 0:
return 0.0
sorted_values = np.sort(values, kind="mergesort")
count = sorted_values.size
nlowest = int(np.rint(trim_fraction * count))
nhighest = int(np.rint(trim_fraction * count))
if nlowest + nhighest + 1 >= count:
return float(np.median(sorted_values))
trimmed = sorted_values[nlowest:count - nhighest]
return float(trimmed.mean()) if trimmed.size else float(np.median(sorted_values))
def _masked_values(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray:
if mask is None or masking == "ignore":
return data
if masking == "include":
return data[mask]
if masking == "exclude":
return data[~mask]
raise ValueError(f"Unknown masking mode: {masking}")
def _global_masked_median(data: np.ndarray, mask: np.ndarray | None, masking: str) -> float:
selected = _masked_values(data, mask, masking)
if selected.size == 0:
selected = np.asarray(data, dtype=np.float64).ravel()
return float(np.median(selected))
def _find_row_shifts_trimmed_mean(
data: np.ndarray,
mask: np.ndarray | None,
masking: str,
trim_fraction: float,
mincount: int = 0,
) -> np.ndarray:
yres, xres = data.shape
if yres == 0:
return np.zeros(0, dtype=np.float64)
if mincount <= 0:
mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
total_median = _global_masked_median(data, mask, masking)
shifts = np.empty(yres, dtype=np.float64)
for i in range(yres):
row = data[i]
row_mask = None if mask is None else mask[i]
if row_mask is None or masking == "ignore":
shifts[i] = _trimmed_mean_or_median(row, trim_fraction)
continue
values = _masked_values(row, row_mask, masking)
if values.size >= mincount:
shifts[i] = _trimmed_mean_or_median(values, trim_fraction)
else:
shifts[i] = total_median
shifts -= shifts.mean()
return shifts
def _slope_level_row_shifts(shifts: np.ndarray) -> np.ndarray:
shifts = np.asarray(shifts, dtype=np.float64).copy()
if shifts.size <= 1:
shifts -= shifts.mean() if shifts.size else 0.0
return shifts
x = np.arange(shifts.size, dtype=np.float64)
A = np.column_stack((np.ones_like(x), x))
coeffs, _, _, _ = np.linalg.lstsq(A, shifts, rcond=None)
intercept, slope = coeffs
shifts -= intercept + slope * x
return shifts
def _find_row_shifts_trimmed_diff(
data: np.ndarray,
mask: np.ndarray | None,
masking: str,
trim_fraction: float,
mincount: int = 0,
) -> np.ndarray:
yres, xres = data.shape
shifts = np.zeros(yres, dtype=np.float64)
if yres <= 1:
return shifts
if mincount <= 0:
mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
for i in range(yres - 1):
upper = data[i]
lower = data[i + 1]
if mask is None or masking == "ignore":
diffs = lower - upper
else:
upper_mask = mask[i]
lower_mask = mask[i + 1]
valid = upper_mask & lower_mask if masking == "include" else (~upper_mask & ~lower_mask)
diffs = (lower - upper)[valid]
if diffs.size >= mincount:
shifts[i + 1] = _trimmed_mean_or_median(diffs, trim_fraction)
else:
shifts[i + 1] = 0.0
shifts = np.cumsum(shifts)
return _slope_level_row_shifts(shifts)
def _vandermonde(x: np.ndarray, degree: int) -> np.ndarray:
return np.vander(np.asarray(x, dtype=np.float64), N=degree + 1, increasing=True)
def _row_level_poly(
data: np.ndarray,
mask: np.ndarray | None,
masking: str,
degree: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
yres, xres = data.shape
corrected = data.copy()
background = np.zeros_like(corrected)
shifts = np.zeros(yres, dtype=np.float64)
if yres == 0 or xres == 0:
return corrected, background, shifts
xc = 0.5 * (xres - 1)
avg = float(data.mean())
x_all = np.arange(xres, dtype=np.float64) - xc
design_all = _vandermonde(x_all, degree)
for i in range(yres):
row = data[i]
row_mask = None if mask is None else mask[i]
if row_mask is None or masking == "ignore":
valid = np.ones(xres, dtype=bool)
elif masking == "include":
valid = row_mask
else:
valid = ~row_mask
coeffs = np.zeros(degree + 1, dtype=np.float64)
if np.count_nonzero(valid) > degree:
design = design_all[valid]
coeffs, _, _, _ = np.linalg.lstsq(design, row[valid], rcond=None)
coeffs[0] -= avg
row_background = design_all @ coeffs
corrected[i] = row - row_background
background[i] = row_background
shifts[i] = coeffs[0]
return corrected, background, shifts
def _calculate_segment_correction(upper: np.ndarray, middle: np.ndarray, lower: np.ndarray) -> np.ndarray:
length = upper.size
if length < 4:
return np.zeros(length, dtype=np.float64)
corr = float(np.mean((upper + lower) / 2.0 - middle))
return (3.0 * corr + (upper + lower) / 2.0 - middle) / 4.0
def _line_correct_step_iter(data: np.ndarray) -> np.ndarray:
yres, xres = data.shape
if yres < 3 or xres == 0:
return data.copy()
corrections = np.zeros_like(data)
vertical_diff_energy = float(np.mean((data[1:] - data[:-1]) ** 2))
if vertical_diff_energy <= 0.0:
return data.copy()
threshold = 3.0
for i in range(yres - 2):
upper = data[i]
middle = data[i + 1]
lower = data[i + 2]
marker_row = corrections[i + 1]
candidates = (middle - upper) * (middle - lower) > threshold * vertical_diff_energy
if np.any(candidates):
signs = np.where(2.0 * middle[candidates] - upper[candidates] - lower[candidates] > 0.0, 1.0, -1.0)
marker_row[candidates] = signs
segment_start = 0
while segment_start < xres:
sign = marker_row[segment_start]
if sign == 0.0:
segment_start += 1
continue
segment_end = segment_start + 1
while segment_end < xres and marker_row[segment_end] == sign:
segment_end += 1
marker_row[segment_start:segment_end] = _calculate_segment_correction(
upper[segment_start:segment_end],
middle[segment_start:segment_end],
lower[segment_start:segment_end],
)
segment_start = segment_end
return data + corrections
def _conservative_filter(data: np.ndarray, size: int) -> np.ndarray:
if size <= 1:
return data.copy()
from scipy.ndimage import maximum_filter, minimum_filter
footprint = np.ones((size, size), dtype=bool)
footprint[size // 2, size // 2] = False
min_neighbours = minimum_filter(data, footprint=footprint, mode="nearest")
max_neighbours = maximum_filter(data, footprint=footprint, mode="nearest")
return np.clip(data, min_neighbours, max_neighbours)
def _line_correct_step(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
corrected = data.copy()
avg = float(corrected.mean())
shifts = _find_row_shifts_trimmed_mean(corrected, None, "ignore", 0.5, 0)
corrected -= shifts[:, np.newaxis]
corrected = _line_correct_step_iter(corrected)
corrected = _line_correct_step_iter(corrected)
corrected = _conservative_filter(corrected, 5)
corrected += avg - float(corrected.mean())
background = data - corrected
step_shifts = background.mean(axis=1) if background.size else np.zeros(data.shape[0], dtype=np.float64)
return corrected, step_shifts
def _line_axis(length: int, real_extent: float) -> np.ndarray | None:
if length <= 0:
return None
return np.linspace(0.0, float(real_extent), int(length), dtype=np.float64)
@register_node(display_name="Line Correction")
class LineCorrection:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": ([
"median",
"median_diff",
"trimmed_mean",
"trimmed_diff",
"polynomial",
"step",
], {"default": "median"}),
"direction": (["horizontal", "vertical"], {"default": "horizontal"}),
"masking": (["ignore", "include", "exclude"], {"default": "ignore"}),
"trim_fraction": ("FLOAT", {
"default": 0.05,
"min": 0.0,
"max": 0.5,
"step": 0.01,
"show_when_widget_value": {"method": ["trimmed_mean", "trimmed_diff"]},
}),
"polynomial_degree": ("INT", {
"default": 1,
"min": 0,
"max": 5,
"step": 1,
"show_when_widget_value": {"method": ["polynomial"]},
}),
},
"optional": {
"mask": ("IMAGE",),
},
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "LINE")
RETURN_NAMES = ("corrected", "background", "row shifts")
FUNCTION = "process"
DESCRIPTION = (
"Correct scan-line mismatches using Gwyddion-derived row alignment methods. "
"Supports median and trimmed row alignment, difference-based alignment, polynomial row leveling, "
"and the step-line correction path from Gwyddion's linecorrect/linematch modules."
)
def process(
self,
field: DataField,
method: str,
direction: str,
masking: str,
trim_fraction: float,
polynomial_degree: int,
mask: np.ndarray | None = None,
) -> tuple:
data = np.asarray(field.data, dtype=np.float64)
mask_array = _normalize_mask(mask, data.shape)
if direction not in {"horizontal", "vertical"}:
raise ValueError(f"Unknown direction: {direction}")
working = data.copy()
working_mask = None if mask_array is None else mask_array.copy()
if direction == "vertical":
working = working.T
if working_mask is not None:
working_mask = working_mask.T
if method == "median":
shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, 0.5, 0)
corrected = working - shifts[:, np.newaxis]
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
elif method == "median_diff":
shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, 0.5, 0)
corrected = working - shifts[:, np.newaxis]
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
elif method == "trimmed_mean":
shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, float(trim_fraction), 0)
corrected = working - shifts[:, np.newaxis]
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
elif method == "trimmed_diff":
shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, float(trim_fraction), 0)
corrected = working - shifts[:, np.newaxis]
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
elif method == "polynomial":
corrected, background, shifts = _row_level_poly(
working,
working_mask,
masking,
int(polynomial_degree),
)
elif method == "step":
corrected, shifts = _line_correct_step(working)
background = working - corrected
else:
raise ValueError(f"Unknown line correction method: {method}")
if direction == "vertical":
corrected = corrected.T
background = background.T
line_axis = _line_axis(field.xres, field.xreal)
else:
line_axis = _line_axis(field.yres, field.yreal)
corrected_field = field.replace(data=corrected)
background_field = field.replace(data=background)
shift_line = LineData(
data=np.asarray(shifts, dtype=np.float64),
x_axis=line_axis,
x_unit=field.si_unit_xy,
y_unit=field.si_unit_z,
)
return (corrected_field, background_field, shift_line)

View File

@@ -0,0 +1,219 @@
from __future__ import annotations
import warnings
import numpy as np
from backend.data_types import DataField
from backend.node_registry import register_node
def _mark_scars_one_sign(
data: np.ndarray,
threshold_high: float,
threshold_low: float,
min_length: int,
max_width: int,
negative: bool,
) -> np.ndarray:
yres, xres = data.shape
marks = np.zeros_like(data, dtype=np.float64)
min_length = max(int(min_length), 1)
max_width = min(int(max_width), yres - 2)
threshold_high = max(float(threshold_high), float(threshold_low))
threshold_low = float(threshold_low)
if min_length > xres or max_width < 1 or threshold_low <= 0.0:
return marks
vertical_rms = float(np.sqrt(np.sum((data[:-1] - data[1:]) ** 2) / max(xres * yres, 1)))
if vertical_rms == 0.0:
return marks
for i in range(yres - (max_width + 1)):
for j in range(xres):
if negative:
top = data[i, j]
bottom = data[i + 1, j]
width = 0
for candidate in range(1, max_width + 1):
top = min(data[i, j], data[i + candidate + 1, j])
bottom = max(bottom, data[i + candidate, j])
if top - bottom >= threshold_low * vertical_rms:
width = candidate
break
if width:
for candidate in range(width, 0, -1):
marks[i + candidate, j] = max(
marks[i + candidate, j],
(top - data[i + candidate, j]) / vertical_rms,
)
else:
bottom = data[i, j]
top = data[i + 1, j]
width = 0
for candidate in range(1, max_width + 1):
bottom = max(data[i, j], data[i + candidate + 1, j])
top = min(top, data[i + candidate, j])
if top - bottom >= threshold_low * vertical_rms:
width = candidate
break
if width:
for candidate in range(width, 0, -1):
marks[i + candidate, j] = max(
marks[i + candidate, j],
(data[i + candidate, j] - bottom) / vertical_rms,
)
for i in range(yres):
row = marks[i]
for j in range(1, xres):
if row[j] >= threshold_low and row[j - 1] >= threshold_high:
row[j] = threshold_high
for j in range(xres - 1, 0, -1):
if row[j - 1] >= threshold_low and row[j] >= threshold_high:
row[j - 1] = threshold_high
for i in range(yres):
row = marks[i]
run_length = 0
for j in range(xres):
if row[j] >= threshold_high:
row[j] = 1.0
run_length += 1
continue
if 0 < run_length < min_length:
row[j - run_length:j] = 0.0
row[j] = 0.0
run_length = 0
if 0 < run_length < min_length:
row[xres - run_length:xres] = 0.0
return marks
def _mark_scars(
data: np.ndarray,
scar_type: str,
threshold_high: float,
threshold_low: float,
min_length: int,
max_width: int,
) -> np.ndarray:
if scar_type == "positive":
return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False)
if scar_type == "negative":
return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True)
if scar_type == "both":
positive = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False)
negative = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True)
return np.maximum(positive, negative)
raise ValueError(f"Unknown scar type: {scar_type}")
def _laplace_inpaint(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
mask = np.asarray(mask, dtype=bool)
if not np.any(mask):
return data.copy()
if np.all(mask):
return np.zeros_like(data, dtype=np.float64)
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import MatrixRankWarning, spsolve
from skimage.restoration import inpaint_biharmonic
yres, xres = data.shape
unknown_indices = -np.ones((yres, xres), dtype=np.int64)
unknown_coords = np.argwhere(mask)
unknown_indices[mask] = np.arange(unknown_coords.shape[0], dtype=np.int64)
rows: list[int] = []
cols: list[int] = []
values: list[float] = []
rhs = np.zeros(unknown_coords.shape[0], dtype=np.float64)
for row_index, (y, x) in enumerate(unknown_coords):
degree = 0
for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
if ny < 0 or ny >= yres or nx < 0 or nx >= xres:
continue
degree += 1
if mask[ny, nx]:
rows.append(row_index)
cols.append(int(unknown_indices[ny, nx]))
values.append(-1.0)
else:
rhs[row_index] += float(data[ny, nx])
rows.append(row_index)
cols.append(row_index)
values.append(float(degree))
matrix = csr_matrix((values, (rows, cols)), shape=(unknown_coords.shape[0], unknown_coords.shape[0]))
restored = data.copy()
try:
with warnings.catch_warnings():
warnings.filterwarnings("error", category=MatrixRankWarning)
solved = spsolve(matrix, rhs)
except (MatrixRankWarning, RuntimeError, ValueError):
return np.asarray(inpaint_biharmonic(data, mask, channel_axis=None), dtype=np.float64)
restored[mask] = np.asarray(solved, dtype=np.float64)
return restored
@register_node(display_name="Scar Removal")
class ScarRemoval:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"scar_type": (["both", "positive", "negative"], {"default": "both"}),
"threshold_high": ("FLOAT", {"default": 0.666, "min": 0.0, "max": 2.0, "step": 0.01}),
"threshold_low": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 2.0, "step": 0.01}),
"min_length": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
"max_width": ("INT", {"default": 4, "min": 1, "max": 32, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "IMAGE")
RETURN_NAMES = ("corrected", "scar mask")
FUNCTION = "process"
DESCRIPTION = (
"Detect and remove horizontal scan scars using Gwyddion-derived scar marking thresholds, "
"then interpolate over the detected mask with a Laplace-style inpaint."
)
def process(
self,
field: DataField,
scar_type: str,
threshold_high: float,
threshold_low: float,
min_length: int,
max_width: int,
) -> tuple:
threshold_high = float(max(threshold_high, threshold_low))
threshold_low = float(min(threshold_high, threshold_low))
marks = _mark_scars(
np.asarray(field.data, dtype=np.float64),
scar_type,
threshold_high,
threshold_low,
int(min_length),
int(max_width),
)
scar_mask = marks > 0.0
corrected = _laplace_inpaint(np.asarray(field.data, dtype=np.float64), scar_mask)
return (field.replace(data=corrected), scar_mask.astype(np.uint8) * 255)

View File

@@ -571,6 +571,130 @@ def test_fix_zero():
print(" PASS\n")
def test_line_correction():
print("=== Test: LineCorrection ===")
from backend.node_registry import get_node_info
from backend.nodes.line_correction import LineCorrection
node = LineCorrection()
assert get_node_info("LineCorrection")["category"] == "Flatten"
rows = 96
cols = 128
y = np.linspace(0.0, 1.0, rows, dtype=np.float64)
x = np.linspace(-1.0, 1.0, cols, dtype=np.float64)
signal = (
0.15 * np.sin(8.0 * np.pi * x)[None, :]
+ 0.05 * np.cos(4.0 * np.pi * y)[:, None]
)
row_offsets = 1.5 * np.sin(3.0 * np.pi * y) + 0.25 * np.cos(7.0 * np.pi * y)
field = make_field(
data=signal + row_offsets[:, None],
xreal=2.5e-6,
yreal=1.5e-6,
)
corrected, background, shifts = node.process(
field,
method="median",
direction="horizontal",
masking="ignore",
trim_fraction=0.05,
polynomial_degree=1,
)
expected_shifts = row_offsets - row_offsets.mean()
assert corrected.data.shape == field.data.shape
assert background.data.shape == field.data.shape
assert np.allclose(corrected.data + background.data, field.data)
assert isinstance(shifts, LineData)
assert shifts.x_unit == field.si_unit_xy
assert shifts.y_unit == field.si_unit_z
assert np.isclose(shifts.x_axis[0], 0.0)
assert np.isclose(shifts.x_axis[-1], field.yreal)
assert np.corrcoef(shifts.data, expected_shifts)[0, 1] > 0.999
assert corrected.data.mean(axis=1).std() < field.data.mean(axis=1).std() * 0.03
poly_background = (
row_offsets[:, None]
+ (0.35 * y - 0.15)[:, None] * x[None, :]
+ (0.10 + 0.05 * y)[:, None] * (x[None, :] ** 2)
)
poly_signal = 0.08 * np.sin(10.0 * np.pi * x)[None, :] * (1.0 + 0.15 * np.cos(2.0 * np.pi * y)[:, None])
poly_field = make_field(data=poly_signal + poly_background)
leveled, poly_bg, poly_shifts = node.process(
poly_field,
method="polynomial",
direction="horizontal",
masking="ignore",
trim_fraction=0.05,
polynomial_degree=2,
)
assert np.allclose(leveled.data + poly_bg.data, poly_field.data)
assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995
assert len(poly_shifts) == rows
print(" PASS\n")
def test_scar_removal():
print("=== Test: ScarRemoval ===")
from backend.node_registry import get_node_info
from backend.nodes.scar_removal import ScarRemoval
node = ScarRemoval()
assert get_node_info("ScarRemoval")["category"] == "Filter"
rows = 96
cols = 128
yy, xx = np.mgrid[0:rows, 0:cols]
base = (
0.005 * xx
+ 0.01 * yy
+ 0.12 * np.sin(2.0 * np.pi * xx / cols)
+ 0.07 * np.cos(2.0 * np.pi * yy / rows)
)
scarred = base.copy()
scarred[24, 20:92] += 1.8
scarred[25, 20:92] += 1.6
scarred[60, 12:116] -= 1.7
field = make_field(data=scarred)
corrected, scar_mask = node.process(
field,
scar_type="both",
threshold_high=0.6,
threshold_low=0.2,
min_length=12,
max_width=4,
)
mask_bool = scar_mask > 127
assert scar_mask.dtype == np.uint8
assert scar_mask.shape == field.data.shape
assert np.count_nonzero(mask_bool) > 0
assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0
assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0
assert np.allclose(corrected.data[~mask_bool], field.data[~mask_bool])
before_rmse = np.sqrt(np.mean((field.data[mask_bool] - base[mask_bool]) ** 2))
after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2))
assert after_rmse < before_rmse * 0.35
clean_corrected, clean_mask = node.process(
make_field(data=base),
scar_type="both",
threshold_high=0.6,
threshold_low=0.2,
min_length=12,
max_width=4,
)
assert np.count_nonzero(clean_mask) == 0
assert np.allclose(clean_corrected.data, base)
print(" PASS\n")
# =========================================================================
# Analysis (non-FFT)
# =========================================================================
@@ -2522,6 +2646,8 @@ if __name__ == "__main__":
test_plane_level()
test_poly_level()
test_fix_zero()
test_line_correction()
test_scar_removal()
# Analysis
test_statistics()