low pri features
This commit is contained in:
81
backend/nodes/drift_correction.py
Normal file
81
backend/nodes/drift_correction.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Drift correction — compensate for thermal/piezo drift between scan lines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage import shift as ndimage_shift
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
def _estimate_drift(data: np.ndarray, reference: str) -> np.ndarray:
|
||||
"""Estimate per-row horizontal drift via cross-correlation with a reference.
|
||||
|
||||
Returns an array of shape (yres,) with the estimated x-shift (in pixels)
|
||||
for each row.
|
||||
"""
|
||||
yres, xres = data.shape
|
||||
shifts = np.zeros(yres)
|
||||
|
||||
if reference == "previous_row":
|
||||
for i in range(1, yres):
|
||||
corr = np.correlate(data[i - 1], data[i], mode="full")
|
||||
peak = np.argmax(corr) - (xres - 1)
|
||||
shifts[i] = shifts[i - 1] + peak
|
||||
elif reference == "mean_row":
|
||||
mean_row = data.mean(axis=0)
|
||||
for i in range(yres):
|
||||
corr = np.correlate(mean_row, data[i], mode="full")
|
||||
peak = np.argmax(corr) - (xres - 1)
|
||||
shifts[i] = peak
|
||||
else:
|
||||
raise ValueError(f"Unknown reference: {reference!r}")
|
||||
|
||||
# Remove the overall mean to centre the correction
|
||||
shifts -= shifts.mean()
|
||||
return shifts
|
||||
|
||||
|
||||
@register_node(display_name="Drift Correction")
|
||||
class DriftCorrection:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"reference": (["previous_row", "mean_row"], {"default": "previous_row"}),
|
||||
"direction": (["horizontal", "vertical"], {"default": "horizontal"}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'corrected'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Compensate for thermal or piezo drift between scan lines. "
|
||||
"Cross-correlates each row (or column) against a reference to estimate "
|
||||
"the drift offset, then shifts lines to correct. "
|
||||
"Equivalent to Gwyddion's drift.c module."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, reference: str, direction: str) -> tuple:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
|
||||
if direction == "vertical":
|
||||
data = data.T
|
||||
|
||||
shifts = _estimate_drift(data, reference)
|
||||
corrected = np.empty_like(data)
|
||||
for i, s in enumerate(shifts):
|
||||
if abs(s) < 0.01:
|
||||
corrected[i] = data[i]
|
||||
else:
|
||||
ndimage_shift(data[i], -s, output=corrected[i], mode="reflect")
|
||||
|
||||
if direction == "vertical":
|
||||
corrected = corrected.T
|
||||
|
||||
return (field.replace(data=corrected),)
|
||||
Reference in New Issue
Block a user