129 lines
4.0 KiB
Python
129 lines
4.0 KiB
Python
"""Super-resolution -- combine multiple aligned scans for resolution enhancement."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
from scipy.ndimage import shift as ndimage_shift, zoom as ndimage_zoom
|
|
|
|
from backend.node_registry import register_node
|
|
from backend.data_types import DataField
|
|
|
|
|
|
def _find_subpixel_shift(ref: np.ndarray, img: np.ndarray) -> tuple[float, float]:
    """Estimate the (dy, dx) sub-pixel shift of *img* relative to *ref* via cross-correlation.

    The result is the displacement of *img* with respect to *ref*: when
    ``img[y, x] == ref[y - dy, x - dx]`` (the content of *ref* moved towards
    larger indices by ``(dy, dx)``), the function returns approximately
    ``(dy, dx)``. Shifting *img* by ``(-dy, -dx)`` therefore aligns it with
    *ref*.

    Uses FFT-based (circular) cross-correlation of the mean-subtracted
    inputs, followed by a three-point parabolic fit around the correlation
    peak along each axis independently.

    Args:
        ref: Reference image (2-D).
        img: Image to register against *ref*; must have the same shape.

    Returns:
        ``(dy, dx)`` in pixels along axis 0 and axis 1. The integer lag of
        each component is taken in ``[-(n // 2), n - 1 - n // 2]`` for an
        axis of length ``n``; larger displacements alias because the
        correlation is circular.

    Raises:
        ValueError: If the inputs are not 2-D or their shapes differ.
    """
    ref = np.asarray(ref, dtype=np.float64)
    img = np.asarray(img, dtype=np.float64)
    if ref.ndim != 2 or ref.shape != img.shape:
        raise ValueError(
            f"expected two 2-D arrays of equal shape, got {ref.shape} and {img.shape}"
        )

    fa = np.fft.fft2(ref - ref.mean())
    fb = np.fft.fft2(img - img.mean())
    # cc[t] = sum_x ref[x + t] * img[x] (circular). For img[x] = ref[x - d]
    # this peaks at t = -d. The inputs are real, so the imaginary part is only
    # rounding noise; taking the real part (not abs) keeps a strong negative
    # (anti-correlation) lobe from winning the argmax.
    cc = np.fft.ifft2(fa * np.conj(fb)).real

    h, w = cc.shape
    py, px = (int(i) for i in np.unravel_index(int(np.argmax(cc)), cc.shape))
    centre = cc[py, px]

    def _parabolic_offset(before: float, peak: float, after: float) -> float:
        """Vertex offset, clipped to [-0.5, 0.5], of the parabola through three equally spaced samples."""
        den = before - 2.0 * peak + after
        # Relative tolerance: an absolute one would silently disable the
        # refinement for small-magnitude data (e.g. values of order 1e-9).
        # Written as `not <` so a NaN denominator also falls back to 0.
        if not den < -1e-12 * abs(peak):
            return 0.0
        return float(np.clip(0.5 * (before - after) / den, -0.5, 0.5))

    # The correlation is circular, so neighbours of a border peak wrap around.
    off_y = _parabolic_offset(cc[(py - 1) % h, px], centre, cc[(py + 1) % h, px])
    off_x = _parabolic_offset(cc[py, (px - 1) % w], centre, cc[py, (px + 1) % w])

    # Map the array index of the peak to a signed lag (same range fftshift gives).
    lag_y = (py + h // 2) % h - h // 2
    lag_x = (px + w // 2) % w - w // 2

    # The peak sits at minus the displacement; negate to return the displacement.
    return float(-lag_y - off_y), float(-lag_x - off_x)
|
|
|
|
|
|
@register_node(display_name="Super Resolution")
class SuperResolution:
    """Fuse up to four scans into one upsampled field.

    Each optional field is registered to ``field1`` with a sub-pixel
    translation estimate, resampled onto ``field1``'s grid, upsampled by
    ``upscale`` with cubic splines and averaged with the upsampled
    ``field1``. With a single input the node reduces to cubic upsampling.
    """

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        """Declare one required field plus the zoom factor, and up to three optional fields."""
        return {
            "required": {
                "field1": ("DATA_FIELD",),
                "upscale": ("INT", {"default": 2, "min": 2, "max": 4, "step": 1}),
            },
            "optional": {
                "field2": ("DATA_FIELD",),
                "field3": ("DATA_FIELD",),
                "field4": ("DATA_FIELD",),
            },
        }

    OUTPUTS = (
        ('DATA_FIELD', 'result'),
    )

    FUNCTION = "process"

    DESCRIPTION = (
        "Combine multiple aligned scans to produce a super-resolved image with higher "
        "spatial resolution. Sub-pixel shifts between inputs are estimated via FFT "
        "cross-correlation and used to reconstruct a finer grid. When only one field "
        "is provided the image is upsampled using cubic interpolation."
    )

    KEYWORDS = ("upscale", "upsample", "multiframe", "stack", "subpixel", "enhance")

    def process(
        self,
        field1: DataField,
        upscale: int,
        field2: DataField | None = None,
        field3: DataField | None = None,
        field4: DataField | None = None,
    ) -> tuple:
        """Upsample *field1* and average in any extra fields after aligning them.

        Args:
            field1: Reference field. Its grid is the alignment target, and its
                ``xreal``/``yreal`` are passed through to the output unchanged.
            upscale: Zoom factor handed to ``scipy.ndimage.zoom`` (cubic).
            field2, field3, field4: Optional extra fields. Their ``data`` must
                have the same shape as ``field1.data``.

        Returns:
            One-element tuple holding ``field1.replace(...)`` with the mean of
            the upsampled reference and the aligned, upsampled extra fields.

        Raises:
            ValueError: If an extra field's data shape differs from *field1*'s.
        """
        extras = [f for f in (field2, field3, field4) if f is not None]
        ref = np.asarray(field1.data, dtype=np.float64)

        # Running sum of upsampled frames, starting with the reference itself.
        # With no extra fields this is already the result (divided by 1 below).
        total = ndimage_zoom(ref, upscale, order=3)

        for extra in extras:
            img = np.asarray(extra.data, dtype=np.float64)
            if img.shape != ref.shape:
                raise ValueError(
                    "SuperResolution: every field must have the same data shape as "
                    f"field1 {ref.shape}, got {img.shape}"
                )

            dy, dx = _find_subpixel_shift(ref, img)

            # _find_subpixel_shift documents its result as the shift of img
            # relative to ref, so shifting by the negated estimate brings img
            # onto ref's grid. The shift is applied on the original (coarse)
            # grid, before upsampling. mode="nearest" replicates edge values
            # into the uncovered border; the default zero fill would drag the
            # averaged border pixels towards 0.
            aligned = ndimage_shift(img, (-dy, -dx), order=3, mode="nearest")
            total += ndimage_zoom(aligned, upscale, order=3)

        # Every frame carries equal weight: plain mean over reference + extras.
        total /= 1 + len(extras)

        return (field1.replace(
            data=total,
            xreal=field1.xreal,
            yreal=field1.yreal,
        ),)
|