Files
tono/backend/nodes/line_correction.py

389 lines
13 KiB
Python

from __future__ import annotations
import numpy as np
from backend.data_types import DataField, LineData
from backend.node_registry import register_node
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
    """Convert an optional mask image into a 2-D boolean selection array.

    Args:
        mask: Mask image or ``None``. Its first two dimensions must equal
            ``shape``; trailing dimensions (for example colour channels) are
            accepted and collapsed.
        shape: Expected ``(rows, columns)`` of the field the mask applies to.

    Returns:
        ``None`` when no mask was supplied, otherwise a freshly allocated
        boolean array with the same leading two dimensions as ``shape``.

    Raises:
        ValueError: If the leading two dimensions of the mask differ from
            ``shape``.
    """
    if mask is None:
        return None
    mask_array = np.asarray(mask)
    if mask_array.shape[:2] != shape:
        raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
    if mask_array.dtype == np.bool_:
        # A boolean mask is already a selection; comparing it with 127 would
        # turn every pixel into False.
        selected = mask_array.copy()
    else:
        # Numeric masks are thresholded at mid-scale of a 0..255 range.
        selected = mask_array > 127
    if selected.ndim > 2:
        # The shape check only constrains the first two axes, so a
        # multi-channel image can get here. Collapse the extra axes so the
        # result can index 2-D data: a pixel is selected when any channel is.
        selected = selected.any(axis=tuple(range(2, selected.ndim)))
    return selected
def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float:
    """Return the symmetrically trimmed mean of ``values``.

    The same number of samples, ``round(trim_fraction * count)``, is dropped
    from the low and the high end of the sorted data. When that would leave
    at most one sample, the median is returned instead.

    Args:
        values: Samples to average; any shape, flattened before use.
        trim_fraction: Fraction of samples to discard at each end. Negative
            values are treated as ``0`` (plain mean).

    Returns:
        The trimmed mean or median as a Python float, or ``0.0`` for empty
        input.
    """
    flat = np.asarray(values, dtype=np.float64).ravel()
    count = flat.size
    if count == 0:
        return 0.0
    # Clamp at zero: a negative count would turn into negative slice bounds
    # and silently select the wrong samples.
    trim = max(int(np.rint(trim_fraction * count)), 0)
    if 2 * trim + 1 >= count:
        return float(np.median(flat))
    ordered = np.sort(flat, kind="mergesort")
    # At least two samples remain here, so the slice is never empty.
    return float(ordered[trim:count - trim].mean())
def _masked_values(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray:
    """Pick the samples of ``data`` that the masking mode selects.

    Args:
        data: Samples to select from.
        mask: Boolean selector matching ``data``, or ``None``.
        masking: ``"ignore"``, ``"include"`` (keep masked samples) or
            ``"exclude"`` (keep unmasked samples).

    Returns:
        ``data`` itself when there is no mask or the mode is ``"ignore"``,
        otherwise the selected samples.

    Raises:
        ValueError: If a mask is given and ``masking`` is not a known mode.
    """
    mask_is_inactive = mask is None or masking == "ignore"
    if mask_is_inactive:
        return data
    if masking not in ("include", "exclude"):
        raise ValueError(f"Unknown masking mode: {masking}")
    selector = mask if masking == "include" else ~mask
    return data[selector]
def _global_masked_median(data: np.ndarray, mask: np.ndarray | None, masking: str) -> float:
    """Return the median of the mask-selected samples of ``data``.

    If the masking mode selects no samples at all, the median of every
    sample in ``data`` is returned instead.
    """
    candidates = _masked_values(data, mask, masking)
    if candidates.size != 0:
        return float(np.median(candidates))
    # Nothing selected: fall back to the whole data set.
    everything = np.asarray(data, dtype=np.float64).ravel()
    return float(np.median(everything))
def _find_row_shifts_trimmed_mean(
    data: np.ndarray,
    mask: np.ndarray | None,
    masking: str,
    trim_fraction: float,
    mincount: int = 0,
) -> np.ndarray:
    """Estimate one offset per row from the row's trimmed mean.

    Args:
        data: 2-D array; one offset is produced per first-axis entry.
        mask: Boolean array matching ``data``, or ``None``.
        masking: ``"ignore"``, ``"include"`` or ``"exclude"``.
        trim_fraction: Fraction trimmed from each end of a row's sorted
            samples (``0.5`` yields the median).
        mincount: Minimum number of selected samples a row needs for its own
            estimate; values ``<= 0`` select ``round(ln(xres) + 1)``. Rows
            with fewer selected samples use the global masked median. Only
            relevant when a mask is active.

    Returns:
        Float64 array of row offsets with their mean subtracted.

    Raises:
        ValueError: If a mask is active and ``masking`` is not a known mode.
    """
    yres, xres = data.shape
    if yres == 0:
        return np.zeros(0, dtype=np.float64)

    shifts = np.empty(yres, dtype=np.float64)
    if mask is None or masking == "ignore":
        # No mask in play: every row is estimated from all of its samples,
        # so neither the fallback median nor mincount is needed.
        for i in range(yres):
            shifts[i] = _trimmed_mean_or_median(data[i], trim_fraction)
    else:
        if mincount <= 0:
            mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
        # Fallback level for rows with too few selected samples.
        total_median = _global_masked_median(data, mask, masking)
        for i in range(yres):
            values = _masked_values(data[i], mask[i], masking)
            if values.size >= mincount:
                shifts[i] = _trimmed_mean_or_median(values, trim_fraction)
            else:
                shifts[i] = total_median

    # Offsets are relative: remove their mean.
    shifts -= shifts.mean()
    return shifts
def _slope_level_row_shifts(shifts: np.ndarray) -> np.ndarray:
    """Remove the least-squares straight line from a sequence of row shifts.

    Args:
        shifts: Row offsets; the input is not modified.

    Returns:
        A float64 copy with the fitted intercept and slope (against the row
        index) subtracted. Inputs with at most one element only have their
        mean removed.
    """
    leveled = np.array(shifts, dtype=np.float64)
    n = leveled.size
    if n <= 1:
        if n:
            leveled -= leveled.mean()
        else:
            leveled -= 0.0
        return leveled

    index = np.arange(n, dtype=np.float64)
    design = np.column_stack((np.ones_like(index), index))
    solution = np.linalg.lstsq(design, leveled, rcond=None)[0]
    intercept, slope = solution
    leveled -= intercept + slope * index
    return leveled
def _find_row_shifts_trimmed_diff(
    data: np.ndarray,
    mask: np.ndarray | None,
    masking: str,
    trim_fraction: float,
    mincount: int = 0,
) -> np.ndarray:
    """Estimate row offsets from trimmed means of adjacent-row differences.

    The trimmed mean of ``data[i + 1] - data[i]`` gives the step between
    neighbouring rows; the steps are accumulated and the linear trend of the
    accumulated offsets is removed.

    Args:
        data: 2-D array; one offset is produced per first-axis entry.
        mask: Boolean array matching ``data``, or ``None``.
        masking: ``"ignore"``, ``"include"`` or ``"exclude"``. With a mask, a
            sample pair is used only when both rows select that column.
        trim_fraction: Fraction trimmed from each end of the sorted
            differences (``0.5`` yields the median).
        mincount: Minimum number of usable differences for a row pair; values
            ``<= 0`` select ``round(ln(xres) + 1)``. Pairs with fewer usable
            differences contribute a zero step.

    Returns:
        Float64 array of slope-levelled row offsets.

    Raises:
        ValueError: If a mask is active and ``masking`` is not a known mode.
    """
    use_mask = mask is not None and masking != "ignore"
    if use_mask and masking not in ("include", "exclude"):
        # Reject unknown modes instead of silently treating them as "exclude".
        raise ValueError(f"Unknown masking mode: {masking}")

    yres, xres = data.shape
    shifts = np.zeros(yres, dtype=np.float64)
    if yres <= 1:
        return shifts
    if mincount <= 0:
        mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))

    for i in range(yres - 1):
        diffs = data[i + 1] - data[i]
        if use_mask:
            if masking == "include":
                valid = mask[i] & mask[i + 1]
            else:
                valid = ~mask[i] & ~mask[i + 1]
            diffs = diffs[valid]
        # Pairs with too few usable differences keep their initial zero step.
        if diffs.size >= mincount:
            shifts[i + 1] = _trimmed_mean_or_median(diffs, trim_fraction)

    # Turn per-pair steps into absolute offsets, then drop the linear trend.
    return _slope_level_row_shifts(np.cumsum(shifts))
def _vandermonde(x: np.ndarray, degree: int) -> np.ndarray:
    """Build the polynomial design matrix for the sample positions ``x``.

    Returns a float64 matrix with ``degree + 1`` columns holding increasing
    powers of ``x`` (column ``k`` is ``x ** k``).
    """
    positions = np.asarray(x, dtype=np.float64)
    n_terms = degree + 1
    return np.vander(positions, N=n_terms, increasing=True)
def _row_level_poly(
    data: np.ndarray,
    mask: np.ndarray | None,
    masking: str,
    degree: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Level every row by subtracting a least-squares polynomial.

    Each row is fitted independently, against column positions centred on the
    middle of the row, using only the columns the masking mode selects. The
    overall data mean is removed from the constant term before subtraction,
    so the subtracted background excludes the global mean level.

    Args:
        data: 2-D array to level.
        mask: Boolean array matching ``data``, or ``None``.
        masking: ``"ignore"`` or ``"include"``; any other value uses the
            unmasked columns.
        degree: Polynomial degree.

    Returns:
        ``(corrected, background, shifts)`` where ``background`` is the
        per-row polynomial that was subtracted and ``shifts`` holds each
        row's constant term after the mean was removed.
    """
    n_rows, n_cols = data.shape
    leveled = data.copy()
    fitted = np.zeros_like(leveled)
    offsets = np.zeros(n_rows, dtype=np.float64)
    if n_rows == 0 or n_cols == 0:
        return leveled, fitted, offsets

    mean_level = float(data.mean())
    centre = 0.5 * (n_cols - 1)
    basis = _vandermonde(np.arange(n_cols, dtype=np.float64) - centre, degree)
    mask_active = mask is not None and masking != "ignore"

    for idx, line in enumerate(data):
        if not mask_active:
            chosen = np.ones(n_cols, dtype=bool)
        elif masking == "include":
            chosen = mask[idx]
        else:
            chosen = ~mask[idx]

        if np.count_nonzero(chosen) > degree:
            params = np.linalg.lstsq(basis[chosen], line[chosen], rcond=None)[0]
        else:
            # Too few points to determine the polynomial: use zero coefficients.
            # NOTE(review): with the mean removal below, such rows get a
            # constant background of -mean_level - confirm this is intended.
            params = np.zeros(degree + 1, dtype=np.float64)

        params[0] -= mean_level
        trend = basis @ params
        leveled[idx] = line - trend
        fitted[idx] = trend
        offsets[idx] = params[0]

    return leveled, fitted, offsets
def _calculate_segment_correction(upper: np.ndarray, middle: np.ndarray, lower: np.ndarray) -> np.ndarray:
    """Compute the additive correction for one marked segment of a row.

    Args:
        upper: Samples of the row above the segment.
        middle: Samples of the segment itself.
        lower: Samples of the row below the segment.

    Returns:
        Per-sample correction: a 3:1 blend of the segment-average and the
        per-sample deviation of ``middle`` from the mean of its vertical
        neighbours. Segments shorter than four samples get all zeros.
    """
    n_samples = upper.size
    if n_samples < 4:
        return np.zeros(n_samples, dtype=np.float64)
    neighbour_mean = (upper + lower) / 2.0
    average_gap = float(np.mean(neighbour_mean - middle))
    return (3.0 * average_gap + neighbour_mean - middle) / 4.0
def _line_correct_step_iter(data: np.ndarray) -> np.ndarray:
    """Run one pass of the step-line correction.

    For every interior row, samples are marked where the row departs from
    both vertical neighbours in the same direction by more than the
    threshold, i.e. where ``(middle - upper) * (middle - lower)`` exceeds
    three times the mean squared difference between adjacent rows. Runs of
    consecutive marks with the same sign form segments, and each segment
    receives the correction from ``_calculate_segment_correction``.

    Args:
        data: 2-D array; not modified.

    Returns:
        A new array equal to ``data`` plus the corrections. Inputs with fewer
        than three rows, no columns, or no row-to-row variation are returned
        as an unchanged copy.
    """
    n_rows, n_cols = data.shape
    if n_rows < 3 or n_cols == 0:
        return data.copy()

    energy = float(np.mean((data[1:] - data[:-1]) ** 2))
    if energy <= 0.0:
        return data.copy()

    threshold = 3.0
    delta = np.zeros_like(data)
    for mid in range(1, n_rows - 1):
        above = data[mid - 1]
        centre = data[mid]
        below = data[mid + 1]
        # View into ``delta``: first holds the +1/-1 marks, then is
        # overwritten segment by segment with the actual corrections.
        marks = delta[mid]

        flagged = (centre - above) * (centre - below) > threshold * energy
        if not np.any(flagged):
            continue
        curvature = 2.0 * centre[flagged] - above[flagged] - below[flagged]
        marks[flagged] = np.where(curvature > 0.0, 1.0, -1.0)

        # Run boundaries are taken before any mark is overwritten, so the
        # corrections written below cannot influence the segmentation.
        cuts = (np.flatnonzero(marks[1:] != marks[:-1]) + 1).tolist()
        edges = [0, *cuts, n_cols]
        for start, stop in zip(edges[:-1], edges[1:]):
            if marks[start] == 0.0:
                continue
            marks[start:stop] = _calculate_segment_correction(
                above[start:stop],
                centre[start:stop],
                below[start:stop],
            )

    return data + delta
def _conservative_filter(data: np.ndarray, size: int) -> np.ndarray:
    """Clamp every sample to the value range of its neighbours.

    Args:
        data: Array to filter; not modified.
        size: Side length of the square neighbourhood. The neighbourhood
            excludes its own centre element; edges are extended with the
            nearest sample.

    Returns:
        A new array in which each sample is limited to the minimum/maximum
        of its neighbourhood. For ``size <= 1`` a plain copy is returned.
    """
    if size <= 1:
        return data.copy()

    from scipy.ndimage import maximum_filter, minimum_filter

    centre = size // 2
    neighbourhood = np.ones((size, size), dtype=bool)
    neighbourhood[centre, centre] = False
    floor = minimum_filter(data, footprint=neighbourhood, mode="nearest")
    ceiling = maximum_filter(data, footprint=neighbourhood, mode="nearest")
    return np.clip(data, floor, ceiling)
def _line_correct_step(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Apply the full step-line correction to ``data``.

    The pipeline is: align rows by their medians, run two passes of
    ``_line_correct_step_iter``, apply a 5x5 conservative filter, and finally
    restore the original mean value.

    Args:
        data: 2-D array; not modified.

    Returns:
        ``(corrected, step_shifts)`` where ``step_shifts`` is the per-row
        mean of the removed background ``data - corrected``.
    """
    aligned = data.copy()
    original_mean = float(aligned.mean())

    # Median row alignment (trim fraction 0.5), without any mask.
    row_offsets = _find_row_shifts_trimmed_mean(aligned, None, "ignore", 0.5, 0)
    aligned -= row_offsets[:, np.newaxis]

    for _ in range(2):
        aligned = _line_correct_step_iter(aligned)
    aligned = _conservative_filter(aligned, 5)

    # Put the data back on its original mean level.
    aligned += original_mean - float(aligned.mean())

    removed = data - aligned
    if removed.size:
        step_shifts = removed.mean(axis=1)
    else:
        step_shifts = np.zeros(data.shape[0], dtype=np.float64)
    return aligned, step_shifts
def _line_axis(length: int, real_extent: float) -> np.ndarray | None:
    """Build evenly spaced axis coordinates from ``0`` to ``real_extent``.

    Returns ``None`` for a non-positive ``length``; otherwise a float64 array
    of ``length`` values with both end points included.
    """
    if length > 0:
        n_points = int(length)
        end = float(real_extent)
        return np.linspace(0.0, end, n_points, dtype=np.float64)
    return None
@register_node(display_name="Line Correction")
class LineCorrection:
    """Node that removes scan-line mismatches from a data field.

    The selected method estimates a per-line background, which is subtracted
    from the field. Lines are rows for ``"horizontal"`` and columns for
    ``"vertical"`` direction.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs and widgets."""
        methods = [
            "median",
            "median_diff",
            "trimmed_mean",
            "trimmed_diff",
            "polynomial",
            "step",
        ]
        trim_widget = {
            "default": 0.05,
            "min": 0.0,
            "max": 0.5,
            "step": 0.01,
            "show_when_widget_value": {"method": ["trimmed_mean", "trimmed_diff"]},
        }
        degree_widget = {
            "default": 1,
            "min": 0,
            "max": 5,
            "step": 1,
            "show_when_widget_value": {"method": ["polynomial"]},
        }
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "method": (methods, {"default": "median"}),
                "direction": (["horizontal", "vertical"], {"default": "horizontal"}),
                "masking": (["ignore", "include", "exclude"], {"default": "ignore"}),
                "trim_fraction": ("FLOAT", trim_widget),
                "polynomial_degree": ("INT", degree_widget),
            },
            "optional": {
                "mask": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "LINE")
    RETURN_NAMES = ("corrected", "background", "row shifts")
    FUNCTION = "process"
    DESCRIPTION = (
        "Correct scan-line mismatches using Gwyddion-derived row alignment methods. "
        "Supports median and trimmed row alignment, difference-based alignment, polynomial row leveling, "
        "and the step-line correction path from Gwyddion's linecorrect/linematch modules."
    )

    def process(
        self,
        field: DataField,
        method: str,
        direction: str,
        masking: str,
        trim_fraction: float,
        polynomial_degree: int,
        mask: np.ndarray | None = None,
    ) -> tuple:
        """Run the selected line-correction method on ``field``.

        Args:
            field: Input data field.
            method: One of the method names offered by ``INPUT_TYPES``.
            direction: ``"horizontal"`` or ``"vertical"``; vertical works on
                the transposed data and transposes the results back.
            masking: Masking mode passed to the mask-aware methods.
            trim_fraction: Trim fraction, used by ``"trimmed_mean"`` and
                ``"trimmed_diff"`` only.
            polynomial_degree: Degree, used by ``"polynomial"`` only.
            mask: Optional mask image; the ``"step"`` method does not use it.

        Returns:
            ``(corrected field, background field, per-line shifts)``.

        Raises:
            ValueError: For a mask of the wrong shape, an unknown direction
                or an unknown method.
        """
        data = np.asarray(field.data, dtype=np.float64)
        mask_array = _normalize_mask(mask, data.shape)
        if direction not in {"horizontal", "vertical"}:
            raise ValueError(f"Unknown direction: {direction}")

        # All helpers work row-wise, so columns are handled by transposing.
        transposed = direction == "vertical"
        working = data.copy()
        working_mask = None if mask_array is None else mask_array.copy()
        if transposed:
            working = working.T
            if working_mask is not None:
                working_mask = working_mask.T

        # Methods that subtract one constant per row: (shift finder, whether
        # the user-supplied trim fraction applies; otherwise the median, 0.5).
        constant_shift_methods = {
            "median": (_find_row_shifts_trimmed_mean, False),
            "median_diff": (_find_row_shifts_trimmed_diff, False),
            "trimmed_mean": (_find_row_shifts_trimmed_mean, True),
            "trimmed_diff": (_find_row_shifts_trimmed_diff, True),
        }

        if method in constant_shift_methods:
            find_shifts, uses_trim = constant_shift_methods[method]
            fraction = float(trim_fraction) if uses_trim else 0.5
            shifts = find_shifts(working, working_mask, masking, fraction, 0)
            per_row = shifts[:, np.newaxis]
            corrected = working - per_row
            background = np.broadcast_to(per_row, working.shape).copy()
        elif method == "polynomial":
            corrected, background, shifts = _row_level_poly(
                working,
                working_mask,
                masking,
                int(polynomial_degree),
            )
        elif method == "step":
            corrected, shifts = _line_correct_step(working)
            background = working - corrected
        else:
            raise ValueError(f"Unknown line correction method: {method}")

        if transposed:
            corrected, background = corrected.T, background.T
            # One shift per original column, so the axis runs along x.
            line_axis = _line_axis(field.xres, field.xreal)
        else:
            line_axis = _line_axis(field.yres, field.yreal)

        return (
            field.replace(data=corrected),
            field.replace(data=background),
            LineData(
                data=np.asarray(shifts, dtype=np.float64),
                x_axis=line_axis,
                x_unit=field.si_unit_xy,
                y_unit=field.si_unit_z,
            ),
        )