81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
"""Drift correction — compensate for thermal/piezo drift between scan lines."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
from scipy.ndimage import shift as ndimage_shift
|
|
|
|
from backend.node_registry import register_node
|
|
from backend.data_types import DataField
|
|
|
|
|
|
def _estimate_drift(data: np.ndarray, reference: str) -> np.ndarray:
|
|
"""Estimate per-row horizontal drift via cross-correlation with a reference.
|
|
|
|
Returns an array of shape (yres,) with the estimated x-shift (in pixels)
|
|
for each row.
|
|
"""
|
|
yres, xres = data.shape
|
|
shifts = np.zeros(yres)
|
|
|
|
if reference == "previous_row":
|
|
for i in range(1, yres):
|
|
corr = np.correlate(data[i - 1], data[i], mode="full")
|
|
peak = np.argmax(corr) - (xres - 1)
|
|
shifts[i] = shifts[i - 1] + peak
|
|
elif reference == "mean_row":
|
|
mean_row = data.mean(axis=0)
|
|
for i in range(yres):
|
|
corr = np.correlate(mean_row, data[i], mode="full")
|
|
peak = np.argmax(corr) - (xres - 1)
|
|
shifts[i] = peak
|
|
else:
|
|
raise ValueError(f"Unknown reference: {reference!r}")
|
|
|
|
# Remove the overall mean to centre the correction
|
|
shifts -= shifts.mean()
|
|
return shifts
|
|
|
|
|
|
@register_node(display_name="Drift Correction")
class DriftCorrection:
    """Node that re-aligns the scan lines of a DATA_FIELD.

    Per-line offsets come from ``_estimate_drift``; each line is then moved
    with ``scipy.ndimage.shift`` (reflect boundary mode).
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "field": ("DATA_FIELD",),
            "reference": (["previous_row", "mean_row"], {"default": "previous_row"}),
            "direction": (["horizontal", "vertical"], {"default": "horizontal"}),
        }
        return {"required": required}

    OUTPUTS = (("DATA_FIELD", "corrected"),)
    FUNCTION = "process"

    DESCRIPTION = (
        "Compensate for thermal or piezo drift between scan lines. "
        "Cross-correlates each row (or column) against a reference to estimate "
        "the drift offset, then shifts lines to correct. "
    )

    def process(self, field: DataField, reference: str, direction: str) -> tuple:
        """Return ``(field.replace(data=corrected),)`` with lines shifted.

        Args:
            field: Input whose ``data`` is converted to a float64 array.
            reference: Forwarded unchanged to ``_estimate_drift``.
            direction: ``"vertical"`` transposes the array so columns are
                treated as lines; the result is transposed back afterwards.
        """
        transpose = direction == "vertical"
        lines = np.asarray(field.data, dtype=np.float64)
        if transpose:
            lines = lines.T

        offsets = _estimate_drift(lines, reference)
        aligned = np.empty_like(lines)
        for row_index, offset in enumerate(offsets):
            # Offsets below 0.01 px are copied verbatim to skip interpolation;
            # otherwise the line is moved by the negated estimate.
            if abs(offset) < 0.01:
                aligned[row_index] = lines[row_index]
                continue
            ndimage_shift(lines[row_index], -offset, output=aligned[row_index], mode="reflect")

        result = aligned.T if transpose else aligned
        return (field.replace(data=result),)
|