add line correction and scar removal nodes
This commit is contained in:
388
backend/nodes/line_correction.py
Normal file
388
backend/nodes/line_correction.py
Normal file
@@ -0,0 +1,388 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField, LineData
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
|
||||
if mask is None:
|
||||
return None
|
||||
|
||||
mask_array = np.asarray(mask)
|
||||
if mask_array.shape[:2] != shape:
|
||||
raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
|
||||
return mask_array > 127
|
||||
|
||||
|
||||
def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float:
|
||||
values = np.asarray(values, dtype=np.float64)
|
||||
if values.size == 0:
|
||||
return 0.0
|
||||
|
||||
sorted_values = np.sort(values, kind="mergesort")
|
||||
count = sorted_values.size
|
||||
nlowest = int(np.rint(trim_fraction * count))
|
||||
nhighest = int(np.rint(trim_fraction * count))
|
||||
|
||||
if nlowest + nhighest + 1 >= count:
|
||||
return float(np.median(sorted_values))
|
||||
|
||||
trimmed = sorted_values[nlowest:count - nhighest]
|
||||
return float(trimmed.mean()) if trimmed.size else float(np.median(sorted_values))
|
||||
|
||||
|
||||
def _masked_values(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray:
|
||||
if mask is None or masking == "ignore":
|
||||
return data
|
||||
if masking == "include":
|
||||
return data[mask]
|
||||
if masking == "exclude":
|
||||
return data[~mask]
|
||||
raise ValueError(f"Unknown masking mode: {masking}")
|
||||
|
||||
|
||||
def _global_masked_median(data: np.ndarray, mask: np.ndarray | None, masking: str) -> float:
|
||||
selected = _masked_values(data, mask, masking)
|
||||
if selected.size == 0:
|
||||
selected = np.asarray(data, dtype=np.float64).ravel()
|
||||
return float(np.median(selected))
|
||||
|
||||
|
||||
def _find_row_shifts_trimmed_mean(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
trim_fraction: float,
|
||||
mincount: int = 0,
|
||||
) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
if yres == 0:
|
||||
return np.zeros(0, dtype=np.float64)
|
||||
|
||||
if mincount <= 0:
|
||||
mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
|
||||
|
||||
total_median = _global_masked_median(data, mask, masking)
|
||||
shifts = np.empty(yres, dtype=np.float64)
|
||||
|
||||
for i in range(yres):
|
||||
row = data[i]
|
||||
row_mask = None if mask is None else mask[i]
|
||||
|
||||
if row_mask is None or masking == "ignore":
|
||||
shifts[i] = _trimmed_mean_or_median(row, trim_fraction)
|
||||
continue
|
||||
|
||||
values = _masked_values(row, row_mask, masking)
|
||||
if values.size >= mincount:
|
||||
shifts[i] = _trimmed_mean_or_median(values, trim_fraction)
|
||||
else:
|
||||
shifts[i] = total_median
|
||||
|
||||
shifts -= shifts.mean()
|
||||
return shifts
|
||||
|
||||
|
||||
def _slope_level_row_shifts(shifts: np.ndarray) -> np.ndarray:
|
||||
shifts = np.asarray(shifts, dtype=np.float64).copy()
|
||||
if shifts.size <= 1:
|
||||
shifts -= shifts.mean() if shifts.size else 0.0
|
||||
return shifts
|
||||
|
||||
x = np.arange(shifts.size, dtype=np.float64)
|
||||
A = np.column_stack((np.ones_like(x), x))
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, shifts, rcond=None)
|
||||
intercept, slope = coeffs
|
||||
shifts -= intercept + slope * x
|
||||
return shifts
|
||||
|
||||
|
||||
def _find_row_shifts_trimmed_diff(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
trim_fraction: float,
|
||||
mincount: int = 0,
|
||||
) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
shifts = np.zeros(yres, dtype=np.float64)
|
||||
if yres <= 1:
|
||||
return shifts
|
||||
|
||||
if mincount <= 0:
|
||||
mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
|
||||
|
||||
for i in range(yres - 1):
|
||||
upper = data[i]
|
||||
lower = data[i + 1]
|
||||
|
||||
if mask is None or masking == "ignore":
|
||||
diffs = lower - upper
|
||||
else:
|
||||
upper_mask = mask[i]
|
||||
lower_mask = mask[i + 1]
|
||||
valid = upper_mask & lower_mask if masking == "include" else (~upper_mask & ~lower_mask)
|
||||
diffs = (lower - upper)[valid]
|
||||
|
||||
if diffs.size >= mincount:
|
||||
shifts[i + 1] = _trimmed_mean_or_median(diffs, trim_fraction)
|
||||
else:
|
||||
shifts[i + 1] = 0.0
|
||||
|
||||
shifts = np.cumsum(shifts)
|
||||
return _slope_level_row_shifts(shifts)
|
||||
|
||||
|
||||
def _vandermonde(x: np.ndarray, degree: int) -> np.ndarray:
|
||||
return np.vander(np.asarray(x, dtype=np.float64), N=degree + 1, increasing=True)
|
||||
|
||||
|
||||
def _row_level_poly(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
degree: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
yres, xres = data.shape
|
||||
corrected = data.copy()
|
||||
background = np.zeros_like(corrected)
|
||||
shifts = np.zeros(yres, dtype=np.float64)
|
||||
|
||||
if yres == 0 or xres == 0:
|
||||
return corrected, background, shifts
|
||||
|
||||
xc = 0.5 * (xres - 1)
|
||||
avg = float(data.mean())
|
||||
x_all = np.arange(xres, dtype=np.float64) - xc
|
||||
design_all = _vandermonde(x_all, degree)
|
||||
|
||||
for i in range(yres):
|
||||
row = data[i]
|
||||
row_mask = None if mask is None else mask[i]
|
||||
|
||||
if row_mask is None or masking == "ignore":
|
||||
valid = np.ones(xres, dtype=bool)
|
||||
elif masking == "include":
|
||||
valid = row_mask
|
||||
else:
|
||||
valid = ~row_mask
|
||||
|
||||
coeffs = np.zeros(degree + 1, dtype=np.float64)
|
||||
if np.count_nonzero(valid) > degree:
|
||||
design = design_all[valid]
|
||||
coeffs, _, _, _ = np.linalg.lstsq(design, row[valid], rcond=None)
|
||||
|
||||
coeffs[0] -= avg
|
||||
row_background = design_all @ coeffs
|
||||
corrected[i] = row - row_background
|
||||
background[i] = row_background
|
||||
shifts[i] = coeffs[0]
|
||||
|
||||
return corrected, background, shifts
|
||||
|
||||
|
||||
def _calculate_segment_correction(upper: np.ndarray, middle: np.ndarray, lower: np.ndarray) -> np.ndarray:
|
||||
length = upper.size
|
||||
if length < 4:
|
||||
return np.zeros(length, dtype=np.float64)
|
||||
|
||||
corr = float(np.mean((upper + lower) / 2.0 - middle))
|
||||
return (3.0 * corr + (upper + lower) / 2.0 - middle) / 4.0
|
||||
|
||||
|
||||
def _line_correct_step_iter(data: np.ndarray) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
if yres < 3 or xres == 0:
|
||||
return data.copy()
|
||||
|
||||
corrections = np.zeros_like(data)
|
||||
vertical_diff_energy = float(np.mean((data[1:] - data[:-1]) ** 2))
|
||||
if vertical_diff_energy <= 0.0:
|
||||
return data.copy()
|
||||
|
||||
threshold = 3.0
|
||||
for i in range(yres - 2):
|
||||
upper = data[i]
|
||||
middle = data[i + 1]
|
||||
lower = data[i + 2]
|
||||
marker_row = corrections[i + 1]
|
||||
|
||||
candidates = (middle - upper) * (middle - lower) > threshold * vertical_diff_energy
|
||||
if np.any(candidates):
|
||||
signs = np.where(2.0 * middle[candidates] - upper[candidates] - lower[candidates] > 0.0, 1.0, -1.0)
|
||||
marker_row[candidates] = signs
|
||||
|
||||
segment_start = 0
|
||||
while segment_start < xres:
|
||||
sign = marker_row[segment_start]
|
||||
if sign == 0.0:
|
||||
segment_start += 1
|
||||
continue
|
||||
|
||||
segment_end = segment_start + 1
|
||||
while segment_end < xres and marker_row[segment_end] == sign:
|
||||
segment_end += 1
|
||||
|
||||
marker_row[segment_start:segment_end] = _calculate_segment_correction(
|
||||
upper[segment_start:segment_end],
|
||||
middle[segment_start:segment_end],
|
||||
lower[segment_start:segment_end],
|
||||
)
|
||||
segment_start = segment_end
|
||||
|
||||
return data + corrections
|
||||
|
||||
|
||||
def _conservative_filter(data: np.ndarray, size: int) -> np.ndarray:
|
||||
if size <= 1:
|
||||
return data.copy()
|
||||
|
||||
from scipy.ndimage import maximum_filter, minimum_filter
|
||||
|
||||
footprint = np.ones((size, size), dtype=bool)
|
||||
footprint[size // 2, size // 2] = False
|
||||
min_neighbours = minimum_filter(data, footprint=footprint, mode="nearest")
|
||||
max_neighbours = maximum_filter(data, footprint=footprint, mode="nearest")
|
||||
return np.clip(data, min_neighbours, max_neighbours)
|
||||
|
||||
|
||||
def _line_correct_step(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
corrected = data.copy()
|
||||
avg = float(corrected.mean())
|
||||
|
||||
shifts = _find_row_shifts_trimmed_mean(corrected, None, "ignore", 0.5, 0)
|
||||
corrected -= shifts[:, np.newaxis]
|
||||
corrected = _line_correct_step_iter(corrected)
|
||||
corrected = _line_correct_step_iter(corrected)
|
||||
corrected = _conservative_filter(corrected, 5)
|
||||
corrected += avg - float(corrected.mean())
|
||||
|
||||
background = data - corrected
|
||||
step_shifts = background.mean(axis=1) if background.size else np.zeros(data.shape[0], dtype=np.float64)
|
||||
return corrected, step_shifts
|
||||
|
||||
|
||||
def _line_axis(length: int, real_extent: float) -> np.ndarray | None:
|
||||
if length <= 0:
|
||||
return None
|
||||
return np.linspace(0.0, float(real_extent), int(length), dtype=np.float64)
|
||||
|
||||
|
||||
@register_node(display_name="Line Correction")
|
||||
class LineCorrection:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": ([
|
||||
"median",
|
||||
"median_diff",
|
||||
"trimmed_mean",
|
||||
"trimmed_diff",
|
||||
"polynomial",
|
||||
"step",
|
||||
], {"default": "median"}),
|
||||
"direction": (["horizontal", "vertical"], {"default": "horizontal"}),
|
||||
"masking": (["ignore", "include", "exclude"], {"default": "ignore"}),
|
||||
"trim_fraction": ("FLOAT", {
|
||||
"default": 0.05,
|
||||
"min": 0.0,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"show_when_widget_value": {"method": ["trimmed_mean", "trimmed_diff"]},
|
||||
}),
|
||||
"polynomial_degree": ("INT", {
|
||||
"default": 1,
|
||||
"min": 0,
|
||||
"max": 5,
|
||||
"step": 1,
|
||||
"show_when_widget_value": {"method": ["polynomial"]},
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "LINE")
|
||||
RETURN_NAMES = ("corrected", "background", "row shifts")
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Correct scan-line mismatches using Gwyddion-derived row alignment methods. "
|
||||
"Supports median and trimmed row alignment, difference-based alignment, polynomial row leveling, "
|
||||
"and the step-line correction path from Gwyddion's linecorrect/linematch modules."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
method: str,
|
||||
direction: str,
|
||||
masking: str,
|
||||
trim_fraction: float,
|
||||
polynomial_degree: int,
|
||||
mask: np.ndarray | None = None,
|
||||
) -> tuple:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
mask_array = _normalize_mask(mask, data.shape)
|
||||
|
||||
if direction not in {"horizontal", "vertical"}:
|
||||
raise ValueError(f"Unknown direction: {direction}")
|
||||
|
||||
working = data.copy()
|
||||
working_mask = None if mask_array is None else mask_array.copy()
|
||||
if direction == "vertical":
|
||||
working = working.T
|
||||
if working_mask is not None:
|
||||
working_mask = working_mask.T
|
||||
|
||||
if method == "median":
|
||||
shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, 0.5, 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "median_diff":
|
||||
shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, 0.5, 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "trimmed_mean":
|
||||
shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, float(trim_fraction), 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "trimmed_diff":
|
||||
shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, float(trim_fraction), 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "polynomial":
|
||||
corrected, background, shifts = _row_level_poly(
|
||||
working,
|
||||
working_mask,
|
||||
masking,
|
||||
int(polynomial_degree),
|
||||
)
|
||||
elif method == "step":
|
||||
corrected, shifts = _line_correct_step(working)
|
||||
background = working - corrected
|
||||
else:
|
||||
raise ValueError(f"Unknown line correction method: {method}")
|
||||
|
||||
if direction == "vertical":
|
||||
corrected = corrected.T
|
||||
background = background.T
|
||||
line_axis = _line_axis(field.xres, field.xreal)
|
||||
else:
|
||||
line_axis = _line_axis(field.yres, field.yreal)
|
||||
|
||||
corrected_field = field.replace(data=corrected)
|
||||
background_field = field.replace(data=background)
|
||||
shift_line = LineData(
|
||||
data=np.asarray(shifts, dtype=np.float64),
|
||||
x_axis=line_axis,
|
||||
x_unit=field.si_unit_xy,
|
||||
y_unit=field.si_unit_z,
|
||||
)
|
||||
|
||||
return (corrected_field, background_field, shift_line)
|
||||
Reference in New Issue
Block a user