Files
tono/backend/nodes/image_stitch.py
2026-04-03 22:09:19 -07:00

132 lines
5.4 KiB
Python

"""Image stitching — combine two overlapping scans into one image."""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
def _find_overlap_shift(a: np.ndarray, b: np.ndarray) -> tuple[int, int]:
    """Estimate the (dy, dx) pixel offset between two 2-D images.

    The circular cross-correlation of the mean-subtracted inputs is computed
    with FFTs, and the position of its largest magnitude, measured from the
    centre of the shifted correlation map, is returned.  A result of
    ``(dy, dx)`` means that ``a[y + dy, x + dx]`` lines up best with
    ``b[y, x]`` (indices wrap around).

    Inputs of different shapes are zero-padded (after mean subtraction) to
    their common bounding shape, so the element-wise product of the two
    spectra is always well defined.  For equally shaped inputs the padding is
    a no-op.

    Args:
        a: Reference image, 2-D.
        b: Image to compare against ``a``, 2-D.

    Returns:
        ``(dy, dx)`` as Python ints.  Because the offset is measured from the
        centre of the correlation map, each component's magnitude is at most
        half of the (padded) size along its axis.
    """
    # Common shape for both transforms; fft2 zero-pads up to ``s``.
    shape = (max(a.shape[0], b.shape[0]), max(a.shape[1], b.shape[1]))
    fa = np.fft.fft2(a - a.mean(), s=shape)
    fb = np.fft.fft2(b - b.mean(), s=shape)
    # Correlation theorem: ifft(F_a * conj(F_b)) is the cross-correlation.
    cross = np.fft.ifft2(fa * np.conj(fb))
    # fftshift moves zero lag to index ``n // 2`` on each axis.
    cc = np.abs(np.fft.fftshift(cross))
    cy, cx = np.array(cc.shape) // 2
    peak = np.unravel_index(np.argmax(cc), cc.shape)
    return int(peak[0] - cy), int(peak[1] - cx)
def _blend_overlap(a: np.ndarray, b: np.ndarray, axis: int) -> np.ndarray:
    """Cross-fade from ``a`` to ``b`` along ``axis`` with a linear ramp.

    The weight applied to ``a`` falls linearly from 1 at the first index
    along ``axis`` to 0 at the last one; ``b`` receives the complementary
    weight.  The ramp is reshaped so that it broadcasts along ``axis`` only,
    which makes negative axes and arrays with more than two dimensions work
    the same way as the plain 2-D ``axis=0`` / ``axis=1`` cases.

    Args:
        a: Array that dominates at the start of ``axis``.
        b: Array that dominates at the end of ``axis``; expected to have the
            same shape as ``a``.
        axis: Axis along which the blend ramp runs.

    Returns:
        ``a * w + b * (1 - w)`` with ``w`` the ramp described above.
    """
    length = a.shape[axis]
    # Shape of all ones except along the blend axis, e.g. (1, L) or (L, 1).
    ramp_shape = [1] * a.ndim
    ramp_shape[axis] = length
    weight = np.linspace(1.0, 0.0, length).reshape(ramp_shape)
    return a * weight + b * (1.0 - weight)
@register_node(display_name="Image Stitch")
class ImageStitch:
    """Node that merges two overlapping 2-D scans into one image.

    ``field_b`` is attached either to the right of or below ``field_a``.  The
    size of the overlap is estimated by cross-correlating the adjoining edge
    strips of the two images, and the overlapping band is optionally
    cross-faded with a linear ramp.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's required inputs."""
        return {
            "required": {
                "field_a": ("DATA_FIELD",),
                "field_b": ("DATA_FIELD",),
                "direction": (["right", "below", "auto"], {"default": "auto"}),
                "blend": (["linear", "none"], {"default": "linear"}),
            }
        }

    OUTPUTS = (
        ('DATA_FIELD', 'stitched'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Combine two overlapping scans into a single image. "
        "Uses cross-correlation to align the images and blends the overlap region. "
        "Direction specifies how field_b is positioned relative to field_a. "
        "'auto' uses cross-correlation to determine the best placement. "
        "Equivalent to Gwyddion's merge.c / stitch.c modules."
    )

    def process(self, field_a: DataField, field_b: DataField, direction: str, blend: str) -> tuple:
        """Stitch ``field_b`` onto ``field_a`` and return the combined field.

        Args:
            field_a: Reference scan; its metadata is carried over to the
                result via ``field_a.replace``.
            field_b: Scan placed to the right of / below ``field_a``.
            direction: ``"right"``, ``"below"`` or ``"auto"``.  With
                ``"auto"`` the whole images are cross-correlated and the axis
                with the larger absolute shift decides the placement.
            blend: ``"linear"`` cross-fades the overlap band; any other value
                lets ``field_b`` overwrite ``field_a`` there.

        Returns:
            A 1-tuple holding the stitched field.  Areas covered by neither
            input are zero.  ``xreal`` / ``yreal`` are the output pixel counts
            multiplied by ``field_a.dx`` / ``field_a.dy`` (presumably the
            pixel size of ``field_a`` -- confirm against ``DataField``).

        Raises:
            ValueError: If either input is not a non-empty 2-D array, or if
                ``direction`` is not one of the supported values.
        """
        a = np.asarray(field_a.data, dtype=np.float64)
        b = np.asarray(field_b.data, dtype=np.float64)
        if a.ndim != 2 or b.ndim != 2:
            raise ValueError(
                f"Image stitching needs 2-D data, got shapes {a.shape} and {b.shape}"
            )
        if a.size == 0 or b.size == 0:
            raise ValueError(
                f"Image stitching needs non-empty data, got shapes {a.shape} and {b.shape}"
            )

        a_h, a_w = a.shape
        b_h, b_w = b.shape
        rows = min(a_h, b_h)  # rows present in both images
        cols = min(a_w, b_w)  # columns present in both images
        use_blend = blend == "linear"

        if direction == "auto":
            # Zero-pad both images to a common shape so they can be correlated.
            padded_shape = (max(a_h, b_h), max(a_w, b_w))
            a_pad = np.zeros(padded_shape)
            b_pad = np.zeros(padded_shape)
            a_pad[:a_h, :a_w] = a
            b_pad[:b_h, :b_w] = b
            dy, dx = _find_overlap_shift(a_pad, b_pad)
            direction = "right" if abs(dx) >= abs(dy) else "below"

        # NOTE(review): the shift returned by _find_overlap_shift is bounded by
        # half the strip size, so the estimated overlap never drops below about
        # half of the shared extent; the overlap == 0 case is handled anyway.
        if direction == "right":
            # Correlate a's right-hand strip with b's left-hand strip.  Both
            # strips are cropped to the shared rows so their shapes match even
            # when the images differ in height.
            _, dx = _find_overlap_shift(a[:rows, a_w - cols:], b[:rows, :cols])
            overlap = max(0, cols - abs(dx))
            b_start = a_w - overlap  # output column where b begins
            out = np.zeros((max(a_h, b_h), b_start + b_w))
            out[:a_h, :a_w] = a
            # Write b in full first so none of its pixels are lost; by default
            # b therefore wins inside the overlap band.
            out[:b_h, b_start:] = b
            if use_blend and overlap > 1:
                # Cross-fade only where both images actually have data.
                out[:rows, b_start:a_w] = _blend_overlap(
                    a[:rows, b_start:], b[:rows, :overlap], axis=1
                )
        elif direction == "below":
            # Same procedure with the roles of rows and columns exchanged:
            # a's bottom strip against b's top strip, cropped to shared columns.
            dy, _ = _find_overlap_shift(a[a_h - rows:, :cols], b[:rows, :cols])
            overlap = max(0, rows - abs(dy))
            b_start = a_h - overlap  # output row where b begins
            out = np.zeros((b_start + b_h, max(a_w, b_w)))
            out[:a_h, :a_w] = a
            out[b_start:, :b_w] = b
            if use_blend and overlap > 1:
                out[b_start:a_h, :cols] = _blend_overlap(
                    a[b_start:, :cols], b[:overlap, :cols], axis=0
                )
        else:
            raise ValueError(f"Unknown direction: {direction!r}")

        xreal = out.shape[1] * field_a.dx
        yreal = out.shape[0] * field_a.dy
        return (field_a.replace(data=out, xreal=xreal, yreal=yreal),)