Files
tono/backend/nodes/super_resolution.py

129 lines
4.0 KiB
Python

"""Super-resolution -- combine multiple aligned scans for resolution enhancement."""
from __future__ import annotations
import numpy as np
from scipy.ndimage import shift as ndimage_shift, zoom as ndimage_zoom
from backend.node_registry import register_node
from backend.data_types import DataField
def _parabolic_offset(left: float, centre: float, right: float) -> float:
    """Return the vertex offset of the parabola through three equally spaced samples.

    The offset is measured in samples relative to *centre*. 0.0 is returned when
    the samples are (numerically) collinear. When *centre* is the largest of the
    three values the result lies in ``[-0.5, 0.5]``.
    """
    den = left - 2.0 * centre + right
    if abs(den) <= 1e-12:
        return 0.0
    return float(0.5 * (left - right) / den)


def _find_subpixel_shift(ref: np.ndarray, img: np.ndarray) -> tuple[float, float]:
    """Estimate the (dy, dx) sub-pixel shift of *img* relative to *ref* via cross-correlation.

    The sign follows the ``scipy.ndimage.shift`` convention: if
    ``img ~= shift(ref, (dy, dx))`` this returns ``(dy, dx)``, so
    ``shift(img, (-dy, -dx))`` re-aligns *img* onto *ref*.

    The integer peak of the circular FFT cross-correlation is located first.
    It is then refined independently along each axis by fitting a parabola
    through the peak and its two (wrapped) neighbours.

    Parameters
    ----------
    ref, img : array_like
        2-D arrays of identical shape.

    Returns
    -------
    tuple[float, float]
        ``(dy, dx)``. Along an axis of length ``n`` the integer part lies in
        ``[-(n // 2), n - 1 - n // 2]``.

    Raises
    ------
    ValueError
        If the inputs are not 2-D or their shapes differ.
    """
    ref = np.asarray(ref, dtype=np.float64)
    img = np.asarray(img, dtype=np.float64)
    if ref.ndim != 2 or ref.shape != img.shape:
        raise ValueError(
            f"expected two 2-D arrays of equal shape, got {ref.shape} and {img.shape}"
        )

    fa = np.fft.fft2(ref - ref.mean())
    fb = np.fft.fft2(img - img.mean())
    # cc[k] = sum_m img[m + k] * ref[m]; this peaks at k = +shift.
    # Take the real part: the correlation of real inputs is real. abs() would
    # also let strong anti-correlation win the argmax.
    cc = np.fft.ifft2(fb * np.conj(fa)).real

    h, w = cc.shape
    py, px = (int(i) for i in np.unravel_index(np.argmax(cc), cc.shape))

    # Map the (unshifted) peak index to a signed integer shift.
    dy = float((py + h // 2) % h - h // 2)
    dx = float((px + w // 2) % w - w // 2)

    # The correlation is circular, so neighbours of a border peak wrap around.
    peak = cc[py, px]
    if h >= 3:
        dy += _parabolic_offset(cc[(py - 1) % h, px], peak, cc[(py + 1) % h, px])
    if w >= 3:
        dx += _parabolic_offset(cc[py, (px - 1) % w], peak, cc[py, (px + 1) % w])
    return dy, dx
@register_node(display_name="Super Resolution")
class SuperResolution:
    """Node that fuses up to four scans of the same area into one upsampled field.

    Every optional field is registered onto ``field1`` using an estimated
    sub-pixel shift, resampled onto an ``upscale``-times finer grid and
    averaged with the upsampled reference.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "field1": ("DATA_FIELD",),
                "upscale": ("INT", {"default": 2, "min": 2, "max": 4, "step": 1}),
            },
            "optional": {
                "field2": ("DATA_FIELD",),
                "field3": ("DATA_FIELD",),
                "field4": ("DATA_FIELD",),
            },
        }

    OUTPUTS = (
        ('DATA_FIELD', 'result'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Combine multiple aligned scans to produce a super-resolved image with higher "
        "spatial resolution. Sub-pixel shifts between inputs are estimated via FFT "
        "cross-correlation and used to reconstruct a finer grid. When only one field "
        "is provided the image is upsampled using cubic interpolation."
    )
    KEYWORDS = ("upscale", "upsample", "multiframe", "stack", "subpixel", "enhance")

    def process(
        self,
        field1: DataField,
        upscale: int,
        field2: DataField | None = None,
        field3: DataField | None = None,
        field4: DataField | None = None,
    ) -> tuple:
        """Align, upsample and average the supplied fields.

        Parameters
        ----------
        field1 : DataField
            Reference scan; every other field is aligned onto it.
        upscale : int
            Zoom factor applied to both image axes (cubic spline, order 3).
        field2, field3, field4 : DataField, optional
            Additional scans. Each ``.data`` array must have the same shape
            as ``field1.data`` (presumably 2-D; the shift estimator requires it).

        Returns
        -------
        tuple
            One-element tuple holding ``field1.replace(...)`` with the averaged
            upsampled data. ``xreal``/``yreal`` are copied from ``field1``
            unchanged.

        Raises
        ------
        ValueError
            If ``upscale`` is below 1, or an extra field's data shape differs
            from ``field1``'s.
        """
        if upscale < 1:
            raise ValueError(f"upscale must be >= 1, got {upscale!r}")

        ref = np.asarray(field1.data, dtype=np.float64)
        extras = [f for f in (field2, field3, field4) if f is not None]

        # Start the running sum with the cubic-upsampled reference; with no
        # extra fields this is also the final result.
        total = ndimage_zoom(ref, upscale, order=3)

        for extra in extras:
            img = np.asarray(extra.data, dtype=np.float64)
            if img.shape != ref.shape:
                raise ValueError(
                    f"all fields must share field1's shape {ref.shape}, got {img.shape}"
                )
            dy, dx = _find_subpixel_shift(ref, img)
            # Undo the estimated shift in native (low-res) pixel coordinates,
            # then upsample. mode="nearest" replicates edge pixels instead of
            # padding with zeros, which would pull the average toward 0 at the borders.
            aligned = ndimage_shift(img, (-dy, -dx), order=3, mode="nearest")
            total += ndimage_zoom(aligned, upscale, order=3)

        # Every frame contributes at every output pixel, so a scalar count suffices.
        high_res = total / (1 + len(extras))
        return (field1.replace(
            data=high_res,
            xreal=field1.xreal,
            yreal=field1.yreal,
        ),)