Files
tono/backend/nodes/drift_correction.py
2026-04-03 22:09:19 -07:00

82 lines
2.6 KiB
Python

"""Drift correction — compensate for thermal/piezo drift between scan lines."""
from __future__ import annotations
import numpy as np
from scipy.ndimage import shift as ndimage_shift
from backend.node_registry import register_node
from backend.data_types import DataField
def _row_displacement(ref: np.ndarray, row: np.ndarray) -> int:
    """Return the integer lag (in pixels) that best aligns *row* with *ref*.

    Positive means *row*'s content sits at higher indices than in *ref*.
    Both lines are mean-centred first so that a constant height offset does
    not bias the correlation peak toward zero lag. Returns 0 when either line
    is constant, since a flat line carries no positional information.
    """
    if np.ptp(ref) == 0 or np.ptp(row) == 0:
        return 0
    corr = np.correlate(ref - ref.mean(), row - row.mean(), mode="full")
    # corr[j] = sum_n ref[n + k] * row[n] with k = j - (len(row) - 1).
    # If row[n] == ref[n - d] (row displaced by +d) the peak is at k = -d,
    # so the displacement is the negated lag.
    return (len(row) - 1) - int(np.argmax(corr))


def _estimate_drift(data: np.ndarray, reference: str) -> np.ndarray:
    """Estimate per-row horizontal drift via cross-correlation with a reference.

    Args:
        data: 2-D array indexed as ``(row, column)``.
        reference: ``"previous_row"`` accumulates the displacement of each row
            relative to the row before it; ``"mean_row"`` measures each row
            against the column-wise mean of all rows.

    Returns:
        Float array of shape ``(yres,)`` with each row's estimated displacement
        in pixels, positive when the row's content sits at higher column
        indices than the reference. The overall mean is subtracted, so shifting
        row ``i`` by ``-shifts[i]`` aligns it with the others.

    Raises:
        ValueError: if *reference* is not a supported name.
    """
    if reference not in ("previous_row", "mean_row"):
        raise ValueError(f"Unknown reference: {reference!r}")
    yres, xres = data.shape
    shifts = np.zeros(yres)
    if yres == 0 or xres == 0:
        # Nothing to correlate; np.correlate rejects zero-length inputs.
        return shifts
    if reference == "previous_row":
        for i in range(1, yres):
            # Row-to-row offsets are chained to get each row's absolute drift.
            shifts[i] = shifts[i - 1] + _row_displacement(data[i - 1], data[i])
    else:
        mean_row = data.mean(axis=0)
        for i in range(yres):
            shifts[i] = _row_displacement(mean_row, data[i])
    # Remove the overall mean to centre the correction
    shifts -= shifts.mean()
    return shifts
@register_node(display_name="Drift Correction")
class DriftCorrection:
    """Node that removes line-to-line lateral drift from a ``DataField``.

    Scan lines (rows for ``direction="horizontal"``, columns for
    ``direction="vertical"``) are cross-correlated against a reference line by
    ``_estimate_drift`` and then resampled with ``scipy.ndimage.shift`` to undo
    the estimated offset.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the field plus two choice parameters."""
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "reference": (["previous_row", "mean_row"], {"default": "previous_row"}),
                "direction": (["horizontal", "vertical"], {"default": "horizontal"}),
            }
        }

    OUTPUTS = (
        ('DATA_FIELD', 'corrected'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Compensate for thermal or piezo drift between scan lines. "
        "Cross-correlates each row (or column) against a reference to estimate "
        "the drift offset, then shifts lines to correct. "
        "Equivalent to Gwyddion's drift.c module."
    )

    def process(self, field: DataField, reference: str, direction: str) -> tuple:
        """Return a drift-corrected copy of *field*.

        Args:
            field: Input field; its ``data`` is read as a float64 array.
            reference: Reference line passed to ``_estimate_drift``
                (``"previous_row"`` or ``"mean_row"``).
            direction: ``"horizontal"`` corrects rows; ``"vertical"`` corrects
                columns by working on the transposed array.

        Returns:
            One-element tuple containing ``field.replace(data=corrected)``.

        Raises:
            ValueError: if *direction* or *reference* is not a supported value.
        """
        if direction not in ("horizontal", "vertical"):
            raise ValueError(f"Unknown direction: {direction!r}")
        if reference not in ("previous_row", "mean_row"):
            raise ValueError(f"Unknown reference: {reference!r}")
        data = np.asarray(field.data, dtype=np.float64)
        if direction == "vertical":
            # Treat columns as rows so the same row-wise logic applies.
            data = data.T
        if data.size == 0:
            # Nothing to align; np.correlate would reject zero-length lines.
            corrected = data.copy()
        else:
            shifts = _estimate_drift(data, reference)
            corrected = np.empty_like(data)
            for i, s in enumerate(shifts):
                if abs(s) < 0.01:
                    # Negligible offset: copy the line rather than resample it
                    # through spline interpolation for no benefit.
                    corrected[i] = data[i]
                else:
                    # A positive scipy shift moves content toward higher
                    # indices, so shifting by -s undoes an estimated shift s.
                    # Edges are filled by reflecting the line's own values.
                    ndimage_shift(data[i], -s, output=corrected[i], mode="reflect")
        if direction == "vertical":
            corrected = corrected.T
        return (field.replace(data=corrected),)