diff --git a/GWYDDION_FEATURE_GAP.md b/GWYDDION_FEATURE_GAP.md index 5e7f29f..4ce3cb1 100644 --- a/GWYDDION_FEATURE_GAP.md +++ b/GWYDDION_FEATURE_GAP.md @@ -8,8 +8,8 @@ Reference for future implementation. Grouped by value to typical SPM workflows. | # | Feature | Gwyddion Source | Description | |---|---------|---------------|-------------| -| 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. | -| 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). | +| ~~1~~ | ~~Line Correction~~ | ~~linecorrect.c, linematch.c~~ | ~~Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts.~~ **DONE** | +| ~~2~~ | ~~Scar Removal~~ | ~~scars.c~~ | ~~Detect and interpolate scan-line defects (horizontal streaks).~~ **DONE** | | 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. | | ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** | | ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** | @@ -73,11 +73,13 @@ For reference, these Gwyddion equivalents are already covered: | Plane Level | level | level.c | | Polynomial Level | level | polylevel.c | | Fix Zero | level | level.c (fix_zero) | +| Line Correction | level | linecorrect.c, linematch.c | | Gaussian Filter | filters | filters.c (gaussian) | | Median Filter | filters | filters.c (median) | | Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) | | 1D FFT Filter | filters | fft_filter_1d.c (lowpass, highpass, bandpass, notch) | | 2D FFT Filter | filters | fft_filter_2d.c (lowpass, highpass, bandpass, notch) | +| Scar Removal | filters | scars.c | | Statistics | analysis | stats.c | | Height Histogram | analysis | linestats.c (dh) | | 2D FFT | analysis | fft.c | diff --git a/backend/node_menu.py b/backend/node_menu.py index b661186..8a1cd46 100644 --- a/backend/node_menu.py +++ b/backend/node_menu.py @@ -47,6 +47,7 @@ MENU_LAYOUT: dict[str, list[str]] = { "EdgeDetect", "FFTFilter1D", "FFTFilter2D", + "ScarRemoval", ], "Frequency": [ "FFT2D", @@ -56,6 +57,7 @@ MENU_LAYOUT: dict[str, list[str]] = { "PlaneLevelField", "PolyLevelField", "FixZero", + "LineCorrection", ], "Measure": [ "CrossSection", diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index 7cd328d..4fae0ad 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -25,12 +25,15 @@ from backend.nodes import ( plane_level_field, poly_level_field, fix_zero, + line_correction, # Mask draw_mask, threshold_mask, mask_morphology, mask_invert, mask_combine, + # Correction + scar_removal, # Display color_map, font_node, diff --git a/backend/nodes/line_correction.py b/backend/nodes/line_correction.py new file mode 100644 index 0000000..041de89 --- /dev/null +++ b/backend/nodes/line_correction.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import numpy as np + +from backend.data_types import DataField, LineData +from backend.node_registry import register_node + + +def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None: + if mask is None: + return None + + mask_array = np.asarray(mask) + if mask_array.shape[:2] != shape: + raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.") + return mask_array > 127 + + +def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float: + values = np.asarray(values, dtype=np.float64) + if values.size == 0: + return 0.0 + + sorted_values = np.sort(values, kind="mergesort") + count = sorted_values.size + nlowest = int(np.rint(trim_fraction * count)) + nhighest = int(np.rint(trim_fraction * count)) + + if nlowest + nhighest + 1 >= count: + return float(np.median(sorted_values)) + + trimmed = sorted_values[nlowest:count - nhighest] + return float(trimmed.mean()) if trimmed.size else float(np.median(sorted_values)) + + +def _masked_values(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray: + if mask is None or masking == "ignore": + return data + if masking == "include": + return data[mask] + if masking == "exclude": + return data[~mask] + raise ValueError(f"Unknown masking mode: {masking}") + + +def _global_masked_median(data: np.ndarray, mask: np.ndarray | None, masking: str) -> float: + selected = _masked_values(data, mask, masking) + if selected.size == 0: + selected = np.asarray(data, dtype=np.float64).ravel() + return float(np.median(selected)) + + +def _find_row_shifts_trimmed_mean( + data: np.ndarray, + mask: np.ndarray | None, + masking: str, + trim_fraction: float, + mincount: int = 0, +) -> np.ndarray: + yres, xres = data.shape + if yres == 0: + return np.zeros(0, dtype=np.float64) + + if mincount <= 0: + mincount = int(np.rint(np.log(max(xres, 1)) + 1.0)) + + total_median = _global_masked_median(data, mask, masking) + shifts = np.empty(yres, dtype=np.float64) + + for i in range(yres): + row = data[i] + row_mask = None if mask is None else mask[i] + + if row_mask is None or masking == "ignore": + shifts[i] = _trimmed_mean_or_median(row, trim_fraction) + continue + + values = _masked_values(row, row_mask, masking) + if values.size >= mincount: + shifts[i] = _trimmed_mean_or_median(values, trim_fraction) + else: + shifts[i] = total_median + + shifts -= shifts.mean() + return shifts + + +def _slope_level_row_shifts(shifts: np.ndarray) -> np.ndarray: + shifts = np.asarray(shifts, dtype=np.float64).copy() + if shifts.size <= 1: + shifts -= shifts.mean() if shifts.size else 0.0 + return shifts + + x = np.arange(shifts.size, dtype=np.float64) + A = np.column_stack((np.ones_like(x), x)) + coeffs, _, _, _ = np.linalg.lstsq(A, shifts, rcond=None) + intercept, slope = coeffs + shifts -= intercept + slope * x + return shifts + + +def _find_row_shifts_trimmed_diff( + data: np.ndarray, + mask: np.ndarray | None, + masking: str, + trim_fraction: float, + mincount: int = 0, +) -> np.ndarray: + yres, xres = data.shape + shifts = np.zeros(yres, dtype=np.float64) + if yres <= 1: + return shifts + + if mincount <= 0: + mincount = int(np.rint(np.log(max(xres, 1)) + 1.0)) + + for i in range(yres - 1): + upper = data[i] + lower = data[i + 1] + + if mask is None or masking == "ignore": + diffs = lower - upper + else: + upper_mask = mask[i] + lower_mask = mask[i + 1] + valid = upper_mask & lower_mask if masking == "include" else (~upper_mask & ~lower_mask) + diffs = (lower - upper)[valid] + + if diffs.size >= mincount: + shifts[i + 1] = _trimmed_mean_or_median(diffs, trim_fraction) + else: + shifts[i + 1] = 0.0 + + shifts = np.cumsum(shifts) + return _slope_level_row_shifts(shifts) + + +def _vandermonde(x: np.ndarray, degree: int) -> np.ndarray: + return np.vander(np.asarray(x, dtype=np.float64), N=degree + 1, increasing=True) + + +def _row_level_poly( + data: np.ndarray, + mask: np.ndarray | None, + masking: str, + degree: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + yres, xres = data.shape + corrected = data.copy() + background = np.zeros_like(corrected) + shifts = np.zeros(yres, dtype=np.float64) + + if yres == 0 or xres == 0: + return corrected, background, shifts + + xc = 0.5 * (xres - 1) + avg = float(data.mean()) + x_all = np.arange(xres, dtype=np.float64) - xc + design_all = _vandermonde(x_all, degree) + + for i in range(yres): + row = data[i] + row_mask = None if mask is None else mask[i] + + if row_mask is None or masking == "ignore": + valid = np.ones(xres, dtype=bool) + elif masking == "include": + valid = row_mask + else: + valid = ~row_mask + + coeffs = np.zeros(degree + 1, dtype=np.float64) + if np.count_nonzero(valid) > degree: + design = design_all[valid] + coeffs, _, _, _ = np.linalg.lstsq(design, row[valid], rcond=None) + + coeffs[0] -= avg + row_background = design_all @ coeffs + corrected[i] = row - row_background + background[i] = row_background + shifts[i] = coeffs[0] + + return corrected, background, shifts + + +def _calculate_segment_correction(upper: np.ndarray, middle: np.ndarray, lower: np.ndarray) -> np.ndarray: + length = upper.size + if length < 4: + return np.zeros(length, dtype=np.float64) + + corr = float(np.mean((upper + lower) / 2.0 - middle)) + return (3.0 * corr + (upper + lower) / 2.0 - middle) / 4.0 + + +def _line_correct_step_iter(data: np.ndarray) -> np.ndarray: + yres, xres = data.shape + if yres < 3 or xres == 0: + return data.copy() + + corrections = np.zeros_like(data) + vertical_diff_energy = float(np.mean((data[1:] - data[:-1]) ** 2)) + if vertical_diff_energy <= 0.0: + return data.copy() + + threshold = 3.0 + for i in range(yres - 2): + upper = data[i] + middle = data[i + 1] + lower = data[i + 2] + marker_row = corrections[i + 1] + + candidates = (middle - upper) * (middle - lower) > threshold * vertical_diff_energy + if np.any(candidates): + signs = np.where(2.0 * middle[candidates] - upper[candidates] - lower[candidates] > 0.0, 1.0, -1.0) + marker_row[candidates] = signs + + segment_start = 0 + while segment_start < xres: + sign = marker_row[segment_start] + if sign == 0.0: + segment_start += 1 + continue + + segment_end = segment_start + 1 + while segment_end < xres and marker_row[segment_end] == sign: + segment_end += 1 + + marker_row[segment_start:segment_end] = _calculate_segment_correction( + upper[segment_start:segment_end], + middle[segment_start:segment_end], + lower[segment_start:segment_end], + ) + segment_start = segment_end + + return data + corrections + + +def _conservative_filter(data: np.ndarray, size: int) -> np.ndarray: + if size <= 1: + return data.copy() + + from scipy.ndimage import maximum_filter, minimum_filter + + footprint = np.ones((size, size), dtype=bool) + footprint[size // 2, size // 2] = False + min_neighbours = minimum_filter(data, footprint=footprint, mode="nearest") + max_neighbours = maximum_filter(data, footprint=footprint, mode="nearest") + return np.clip(data, min_neighbours, max_neighbours) + + +def _line_correct_step(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + corrected = data.copy() + avg = float(corrected.mean()) + + shifts = _find_row_shifts_trimmed_mean(corrected, None, "ignore", 0.5, 0) + corrected -= shifts[:, np.newaxis] + corrected = _line_correct_step_iter(corrected) + corrected = _line_correct_step_iter(corrected) + corrected = _conservative_filter(corrected, 5) + corrected += avg - float(corrected.mean()) + + background = data - corrected + step_shifts = background.mean(axis=1) if background.size else np.zeros(data.shape[0], dtype=np.float64) + return corrected, step_shifts + + +def _line_axis(length: int, real_extent: float) -> np.ndarray | None: + if length <= 0: + return None + return np.linspace(0.0, float(real_extent), int(length), dtype=np.float64) + + +@register_node(display_name="Line Correction") +class LineCorrection: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "method": ([ + "median", + "median_diff", + "trimmed_mean", + "trimmed_diff", + "polynomial", + "step", + ], {"default": "median"}), + "direction": (["horizontal", "vertical"], {"default": "horizontal"}), + "masking": (["ignore", "include", "exclude"], {"default": "ignore"}), + "trim_fraction": ("FLOAT", { + "default": 0.05, + "min": 0.0, + "max": 0.5, + "step": 0.01, + "show_when_widget_value": {"method": ["trimmed_mean", "trimmed_diff"]}, + }), + "polynomial_degree": ("INT", { + "default": 1, + "min": 0, + "max": 5, + "step": 1, + "show_when_widget_value": {"method": ["polynomial"]}, + }), + }, + "optional": { + "mask": ("IMAGE",), + }, + } + + RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "LINE") + RETURN_NAMES = ("corrected", "background", "row shifts") + FUNCTION = "process" + + DESCRIPTION = ( + "Correct scan-line mismatches using Gwyddion-derived row alignment methods. " + "Supports median and trimmed row alignment, difference-based alignment, polynomial row leveling, " + "and the step-line correction path from Gwyddion's linecorrect/linematch modules." + ) + + def process( + self, + field: DataField, + method: str, + direction: str, + masking: str, + trim_fraction: float, + polynomial_degree: int, + mask: np.ndarray | None = None, + ) -> tuple: + data = np.asarray(field.data, dtype=np.float64) + mask_array = _normalize_mask(mask, data.shape) + + if direction not in {"horizontal", "vertical"}: + raise ValueError(f"Unknown direction: {direction}") + + working = data.copy() + working_mask = None if mask_array is None else mask_array.copy() + if direction == "vertical": + working = working.T + if working_mask is not None: + working_mask = working_mask.T + + if method == "median": + shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, 0.5, 0) + corrected = working - shifts[:, np.newaxis] + background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy() + elif method == "median_diff": + shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, 0.5, 0) + corrected = working - shifts[:, np.newaxis] + background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy() + elif method == "trimmed_mean": + shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, float(trim_fraction), 0) + corrected = working - shifts[:, np.newaxis] + background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy() + elif method == "trimmed_diff": + shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, float(trim_fraction), 0) + corrected = working - shifts[:, np.newaxis] + background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy() + elif method == "polynomial": + corrected, background, shifts = _row_level_poly( + working, + working_mask, + masking, + int(polynomial_degree), + ) + elif method == "step": + corrected, shifts = _line_correct_step(working) + background = working - corrected + else: + raise ValueError(f"Unknown line correction method: {method}") + + if direction == "vertical": + corrected = corrected.T + background = background.T + line_axis = _line_axis(field.xres, field.xreal) + else: + line_axis = _line_axis(field.yres, field.yreal) + + corrected_field = field.replace(data=corrected) + background_field = field.replace(data=background) + shift_line = LineData( + data=np.asarray(shifts, dtype=np.float64), + x_axis=line_axis, + x_unit=field.si_unit_xy, + y_unit=field.si_unit_z, + ) + + return (corrected_field, background_field, shift_line) diff --git a/backend/nodes/scar_removal.py b/backend/nodes/scar_removal.py new file mode 100644 index 0000000..0011dfe --- /dev/null +++ b/backend/nodes/scar_removal.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import warnings + +import numpy as np + +from backend.data_types import DataField +from backend.node_registry import register_node + + +def _mark_scars_one_sign( + data: np.ndarray, + threshold_high: float, + threshold_low: float, + min_length: int, + max_width: int, + negative: bool, +) -> np.ndarray: + yres, xres = data.shape + marks = np.zeros_like(data, dtype=np.float64) + + min_length = max(int(min_length), 1) + max_width = min(int(max_width), yres - 2) + threshold_high = max(float(threshold_high), float(threshold_low)) + threshold_low = float(threshold_low) + + if min_length > xres or max_width < 1 or threshold_low <= 0.0: + return marks + + vertical_rms = float(np.sqrt(np.sum((data[:-1] - data[1:]) ** 2) / max(xres * yres, 1))) + if vertical_rms == 0.0: + return marks + + for i in range(yres - (max_width + 1)): + for j in range(xres): + if negative: + top = data[i, j] + bottom = data[i + 1, j] + width = 0 + for candidate in range(1, max_width + 1): + top = min(data[i, j], data[i + candidate + 1, j]) + bottom = max(bottom, data[i + candidate, j]) + if top - bottom >= threshold_low * vertical_rms: + width = candidate + break + + if width: + for candidate in range(width, 0, -1): + marks[i + candidate, j] = max( + marks[i + candidate, j], + (top - data[i + candidate, j]) / vertical_rms, + ) + else: + bottom = data[i, j] + top = data[i + 1, j] + width = 0 + for candidate in range(1, max_width + 1): + bottom = max(data[i, j], data[i + candidate + 1, j]) + top = min(top, data[i + candidate, j]) + if top - bottom >= threshold_low * vertical_rms: + width = candidate + break + + if width: + for candidate in range(width, 0, -1): + marks[i + candidate, j] = max( + marks[i + candidate, j], + (data[i + candidate, j] - bottom) / vertical_rms, + ) + + for i in range(yres): + row = marks[i] + for j in range(1, xres): + if row[j] >= threshold_low and row[j - 1] >= threshold_high: + row[j] = threshold_high + for j in range(xres - 1, 0, -1): + if row[j - 1] >= threshold_low and row[j] >= threshold_high: + row[j - 1] = threshold_high + + for i in range(yres): + row = marks[i] + run_length = 0 + for j in range(xres): + if row[j] >= threshold_high: + row[j] = 1.0 + run_length += 1 + continue + + if 0 < run_length < min_length: + row[j - run_length:j] = 0.0 + row[j] = 0.0 + run_length = 0 + + if 0 < run_length < min_length: + row[xres - run_length:xres] = 0.0 + + return marks + + +def _mark_scars( + data: np.ndarray, + scar_type: str, + threshold_high: float, + threshold_low: float, + min_length: int, + max_width: int, +) -> np.ndarray: + if scar_type == "positive": + return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False) + if scar_type == "negative": + return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True) + if scar_type == "both": + positive = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False) + negative = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True) + return np.maximum(positive, negative) + raise ValueError(f"Unknown scar type: {scar_type}") + + +def _laplace_inpaint(data: np.ndarray, mask: np.ndarray) -> np.ndarray: + mask = np.asarray(mask, dtype=bool) + if not np.any(mask): + return data.copy() + + if np.all(mask): + return np.zeros_like(data, dtype=np.float64) + + from scipy.sparse import csr_matrix + from scipy.sparse.linalg import MatrixRankWarning, spsolve + from skimage.restoration import inpaint_biharmonic + + yres, xres = data.shape + unknown_indices = -np.ones((yres, xres), dtype=np.int64) + unknown_coords = np.argwhere(mask) + unknown_indices[mask] = np.arange(unknown_coords.shape[0], dtype=np.int64) + + rows: list[int] = [] + cols: list[int] = [] + values: list[float] = [] + rhs = np.zeros(unknown_coords.shape[0], dtype=np.float64) + + for row_index, (y, x) in enumerate(unknown_coords): + degree = 0 + for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)): + if ny < 0 or ny >= yres or nx < 0 or nx >= xres: + continue + + degree += 1 + if mask[ny, nx]: + rows.append(row_index) + cols.append(int(unknown_indices[ny, nx])) + values.append(-1.0) + else: + rhs[row_index] += float(data[ny, nx]) + + rows.append(row_index) + cols.append(row_index) + values.append(float(degree)) + + matrix = csr_matrix((values, (rows, cols)), shape=(unknown_coords.shape[0], unknown_coords.shape[0])) + restored = data.copy() + + try: + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=MatrixRankWarning) + solved = spsolve(matrix, rhs) + except (MatrixRankWarning, RuntimeError, ValueError): + return np.asarray(inpaint_biharmonic(data, mask, channel_axis=None), dtype=np.float64) + + restored[mask] = np.asarray(solved, dtype=np.float64) + return restored + + +@register_node(display_name="Scar Removal") +class ScarRemoval: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "scar_type": (["both", "positive", "negative"], {"default": "both"}), + "threshold_high": ("FLOAT", {"default": 0.666, "min": 0.0, "max": 2.0, "step": 0.01}), + "threshold_low": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 2.0, "step": 0.01}), + "min_length": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}), + "max_width": ("INT", {"default": 4, "min": 1, "max": 32, "step": 1}), + } + } + + RETURN_TYPES = ("DATA_FIELD", "IMAGE") + RETURN_NAMES = ("corrected", "scar mask") + FUNCTION = "process" + + DESCRIPTION = ( + "Detect and remove horizontal scan scars using Gwyddion-derived scar marking thresholds, " + "then interpolate over the detected mask with a Laplace-style inpaint." + ) + + def process( + self, + field: DataField, + scar_type: str, + threshold_high: float, + threshold_low: float, + min_length: int, + max_width: int, + ) -> tuple: + threshold_high = float(max(threshold_high, threshold_low)) + threshold_low = float(min(threshold_high, threshold_low)) + + marks = _mark_scars( + np.asarray(field.data, dtype=np.float64), + scar_type, + threshold_high, + threshold_low, + int(min_length), + int(max_width), + ) + scar_mask = marks > 0.0 + corrected = _laplace_inpaint(np.asarray(field.data, dtype=np.float64), scar_mask) + return (field.replace(data=corrected), scar_mask.astype(np.uint8) * 255) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 33e14c2..414897b 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -571,6 +571,130 @@ def test_fix_zero(): print(" PASS\n") +def test_line_correction(): + print("=== Test: LineCorrection ===") + from backend.node_registry import get_node_info + from backend.nodes.line_correction import LineCorrection + + node = LineCorrection() + assert get_node_info("LineCorrection")["category"] == "Flatten" + + rows = 96 + cols = 128 + y = np.linspace(0.0, 1.0, rows, dtype=np.float64) + x = np.linspace(-1.0, 1.0, cols, dtype=np.float64) + signal = ( + 0.15 * np.sin(8.0 * np.pi * x)[None, :] + + 0.05 * np.cos(4.0 * np.pi * y)[:, None] + ) + row_offsets = 1.5 * np.sin(3.0 * np.pi * y) + 0.25 * np.cos(7.0 * np.pi * y) + field = make_field( + data=signal + row_offsets[:, None], + xreal=2.5e-6, + yreal=1.5e-6, + ) + + corrected, background, shifts = node.process( + field, + method="median", + direction="horizontal", + masking="ignore", + trim_fraction=0.05, + polynomial_degree=1, + ) + expected_shifts = row_offsets - row_offsets.mean() + assert corrected.data.shape == field.data.shape + assert background.data.shape == field.data.shape + assert np.allclose(corrected.data + background.data, field.data) + assert isinstance(shifts, LineData) + assert shifts.x_unit == field.si_unit_xy + assert shifts.y_unit == field.si_unit_z + assert np.isclose(shifts.x_axis[0], 0.0) + assert np.isclose(shifts.x_axis[-1], field.yreal) + assert np.corrcoef(shifts.data, expected_shifts)[0, 1] > 0.999 + assert corrected.data.mean(axis=1).std() < field.data.mean(axis=1).std() * 0.03 + + poly_background = ( + row_offsets[:, None] + + (0.35 * y - 0.15)[:, None] * x[None, :] + + (0.10 + 0.05 * y)[:, None] * (x[None, :] ** 2) + ) + poly_signal = 0.08 * np.sin(10.0 * np.pi * x)[None, :] * (1.0 + 0.15 * np.cos(2.0 * np.pi * y)[:, None]) + poly_field = make_field(data=poly_signal + poly_background) + + leveled, poly_bg, poly_shifts = node.process( + poly_field, + method="polynomial", + direction="horizontal", + masking="ignore", + trim_fraction=0.05, + polynomial_degree=2, + ) + assert np.allclose(leveled.data + poly_bg.data, poly_field.data) + assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995 + assert len(poly_shifts) == rows + + print(" PASS\n") + + +def test_scar_removal(): + print("=== Test: ScarRemoval ===") + from backend.node_registry import get_node_info + from backend.nodes.scar_removal import ScarRemoval + + node = ScarRemoval() + assert get_node_info("ScarRemoval")["category"] == "Filter" + + rows = 96 + cols = 128 + yy, xx = np.mgrid[0:rows, 0:cols] + base = ( + 0.005 * xx + + 0.01 * yy + + 0.12 * np.sin(2.0 * np.pi * xx / cols) + + 0.07 * np.cos(2.0 * np.pi * yy / rows) + ) + scarred = base.copy() + scarred[24, 20:92] += 1.8 + scarred[25, 20:92] += 1.6 + scarred[60, 12:116] -= 1.7 + + field = make_field(data=scarred) + corrected, scar_mask = node.process( + field, + scar_type="both", + threshold_high=0.6, + threshold_low=0.2, + min_length=12, + max_width=4, + ) + + mask_bool = scar_mask > 127 + assert scar_mask.dtype == np.uint8 + assert scar_mask.shape == field.data.shape + assert np.count_nonzero(mask_bool) > 0 + assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0 + assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0 + assert np.allclose(corrected.data[~mask_bool], field.data[~mask_bool]) + + before_rmse = np.sqrt(np.mean((field.data[mask_bool] - base[mask_bool]) ** 2)) + after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2)) + assert after_rmse < before_rmse * 0.35 + + clean_corrected, clean_mask = node.process( + make_field(data=base), + scar_type="both", + threshold_high=0.6, + threshold_low=0.2, + min_length=12, + max_width=4, + ) + assert np.count_nonzero(clean_mask) == 0 + assert np.allclose(clean_corrected.data, base) + + print(" PASS\n") + + # ========================================================================= # Analysis (non-FFT) # ========================================================================= @@ -2522,6 +2646,8 @@ if __name__ == "__main__": test_plane_level() test_poly_level() test_fix_zero() + test_line_correction() + test_scar_removal() # Analysis test_statistics()