from __future__ import annotations import warnings import numpy as np from backend.data_types import DataField from backend.node_registry import register_node from backend.nodes.helpers import bool_to_mask def _mark_scars_one_sign( data: np.ndarray, threshold_high: float, threshold_low: float, min_length: int, max_width: int, negative: bool, ) -> np.ndarray: yres, xres = data.shape marks = np.zeros_like(data, dtype=np.float64) min_length = max(int(min_length), 1) max_width = min(int(max_width), yres - 2) threshold_high = max(float(threshold_high), float(threshold_low)) threshold_low = float(threshold_low) if min_length > xres or max_width < 1 or threshold_low <= 0.0: return marks vertical_rms = float(np.sqrt(np.sum((data[:-1] - data[1:]) ** 2) / max(xres * yres, 1))) if vertical_rms == 0.0: return marks for i in range(yres - (max_width + 1)): for j in range(xres): if negative: top = data[i, j] bottom = data[i + 1, j] width = 0 for candidate in range(1, max_width + 1): top = min(data[i, j], data[i + candidate + 1, j]) bottom = max(bottom, data[i + candidate, j]) if top - bottom >= threshold_low * vertical_rms: width = candidate break if width: for candidate in range(width, 0, -1): marks[i + candidate, j] = max( marks[i + candidate, j], (top - data[i + candidate, j]) / vertical_rms, ) else: bottom = data[i, j] top = data[i + 1, j] width = 0 for candidate in range(1, max_width + 1): bottom = max(data[i, j], data[i + candidate + 1, j]) top = min(top, data[i + candidate, j]) if top - bottom >= threshold_low * vertical_rms: width = candidate break if width: for candidate in range(width, 0, -1): marks[i + candidate, j] = max( marks[i + candidate, j], (data[i + candidate, j] - bottom) / vertical_rms, ) for i in range(yres): row = marks[i] for j in range(1, xres): if row[j] >= threshold_low and row[j - 1] >= threshold_high: row[j] = threshold_high for j in range(xres - 1, 0, -1): if row[j - 1] >= threshold_low and row[j] >= threshold_high: row[j - 1] = threshold_high for i in range(yres): row = marks[i] run_length = 0 for j in range(xres): if row[j] >= threshold_high: row[j] = 1.0 run_length += 1 continue if 0 < run_length < min_length: row[j - run_length:j] = 0.0 row[j] = 0.0 run_length = 0 if 0 < run_length < min_length: row[xres - run_length:xres] = 0.0 return marks def _mark_scars( data: np.ndarray, scar_type: str, threshold_high: float, threshold_low: float, min_length: int, max_width: int, ) -> np.ndarray: if scar_type == "positive": return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False) if scar_type == "negative": return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True) if scar_type == "both": positive = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False) negative = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True) return np.maximum(positive, negative) raise ValueError(f"Unknown scar type: {scar_type}") def _laplace_inpaint(data: np.ndarray, mask: np.ndarray) -> np.ndarray: mask = np.asarray(mask, dtype=bool) if not np.any(mask): return data.copy() if np.all(mask): return np.zeros_like(data, dtype=np.float64) from scipy.sparse import csr_matrix from scipy.sparse.linalg import MatrixRankWarning, spsolve from skimage.restoration import inpaint_biharmonic yres, xres = data.shape unknown_indices = -np.ones((yres, xres), dtype=np.int64) unknown_coords = np.argwhere(mask) unknown_indices[mask] = np.arange(unknown_coords.shape[0], dtype=np.int64) rows: list[int] = [] cols: list[int] = [] values: list[float] = [] rhs = np.zeros(unknown_coords.shape[0], dtype=np.float64) for row_index, (y, x) in enumerate(unknown_coords): degree = 0 for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)): if ny < 0 or ny >= yres or nx < 0 or nx >= xres: continue degree += 1 if mask[ny, nx]: rows.append(row_index) cols.append(int(unknown_indices[ny, nx])) values.append(-1.0) else: rhs[row_index] += float(data[ny, nx]) rows.append(row_index) cols.append(row_index) values.append(float(degree)) matrix = csr_matrix((values, (rows, cols)), shape=(unknown_coords.shape[0], unknown_coords.shape[0])) restored = data.copy() try: with warnings.catch_warnings(): warnings.filterwarnings("error", category=MatrixRankWarning) solved = spsolve(matrix, rhs) except (MatrixRankWarning, RuntimeError, ValueError): return np.asarray(inpaint_biharmonic(data, mask, channel_axis=None), dtype=np.float64) restored[mask] = np.asarray(solved, dtype=np.float64) return restored @register_node(display_name="Scar Removal") class ScarRemoval: @classmethod def INPUT_TYPES(cls): return { "required": { "field": ("DATA_FIELD",), "scar_type": (["both", "positive", "negative"], {"default": "both"}), "threshold_high": ("FLOAT", {"default": 0.666, "min": 0.0, "max": 2.0, "step": 0.01}), "threshold_low": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 2.0, "step": 0.01}), "min_length": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}), "max_width": ("INT", {"default": 4, "min": 1, "max": 32, "step": 1}), } } OUTPUTS = ( ('DATA_FIELD', 'corrected'), ('IMAGE', 'scar_mask'), ) FUNCTION = "process" DESCRIPTION = ( "Detect and remove horizontal scan scars using Gwyddion-derived scar marking thresholds, " "then interpolate over the detected mask with a Laplace-style inpaint." ) KEYWORDS = ("stripe", "streak", "glitch", "artifact", "inpaint", "destripe") def process( self, field: DataField, scar_type: str, threshold_high: float, threshold_low: float, min_length: int, max_width: int, ) -> tuple: threshold_high = float(max(threshold_high, threshold_low)) threshold_low = float(min(threshold_high, threshold_low)) marks = _mark_scars( np.asarray(field.data, dtype=np.float64), scar_type, threshold_high, threshold_low, int(min_length), int(max_width), ) scar_mask = marks > 0.0 corrected = _laplace_inpaint(np.asarray(field.data, dtype=np.float64), scar_mask) return (field.replace(data=corrected), bool_to_mask(scar_mask))