"""Image stitching — combine two overlapping scans into one image.""" from __future__ import annotations import numpy as np from backend.node_registry import register_node from backend.data_types import DataField def _find_overlap_shift(a: np.ndarray, b: np.ndarray) -> tuple[int, int]: """Estimate the (dy, dx) pixel shift of b relative to a via cross-correlation.""" fa = np.fft.fft2(a - a.mean()) fb = np.fft.fft2(b - b.mean()) cross = np.fft.ifft2(fa * np.conj(fb)) cc = np.abs(np.fft.fftshift(cross)) cy, cx = np.array(cc.shape) // 2 peak = np.unravel_index(np.argmax(cc), cc.shape) dy = peak[0] - cy dx = peak[1] - cx return int(dy), int(dx) def _blend_overlap(a: np.ndarray, b: np.ndarray, axis: int) -> np.ndarray: """Linear blend along the overlap axis.""" length = a.shape[axis] weight = np.linspace(1.0, 0.0, length) if axis == 0: weight = weight[:, np.newaxis] return a * weight + b * (1.0 - weight) @register_node(display_name="Image Stitch") class ImageStitch: @classmethod def INPUT_TYPES(cls): return { "required": { "field_a": ("DATA_FIELD",), "field_b": ("DATA_FIELD",), "direction": (["right", "below", "auto"], {"default": "auto"}), "blend": (["linear", "none"], {"default": "linear"}), } } OUTPUTS = ( ('DATA_FIELD', 'stitched'), ) FUNCTION = "process" DESCRIPTION = ( "Combine two overlapping scans into a single image. " "Uses cross-correlation to align the images and blends the overlap region. " "Direction specifies how field_b is positioned relative to field_a. " "'auto' uses cross-correlation to determine the best placement. " "Equivalent to Gwyddion's merge.c / stitch.c modules." ) def process(self, field_a: DataField, field_b: DataField, direction: str, blend: str) -> tuple: a = np.asarray(field_a.data, dtype=np.float64) b = np.asarray(field_b.data, dtype=np.float64) if direction == "auto": # Pad b to match a's shape for cross-correlation shape = (max(a.shape[0], b.shape[0]), max(a.shape[1], b.shape[1])) a_pad = np.zeros(shape) b_pad = np.zeros(shape) a_pad[:a.shape[0], :a.shape[1]] = a b_pad[:b.shape[0], :b.shape[1]] = b dy, dx = _find_overlap_shift(a_pad, b_pad) direction = "right" if abs(dx) >= abs(dy) else "below" if direction == "right": # b is to the right of a dy, dx = _find_overlap_shift( a[:, -min(a.shape[1], b.shape[1]):], b[:, :min(a.shape[1], b.shape[1])], ) overlap = max(0, min(a.shape[1], b.shape[1]) - abs(dx)) if overlap <= 0: # No overlap, just concatenate out_h = max(a.shape[0], b.shape[0]) out = np.zeros((out_h, a.shape[1] + b.shape[1])) out[:a.shape[0], :a.shape[1]] = a out[:b.shape[0], a.shape[1]:] = b else: out_h = max(a.shape[0], b.shape[0]) out_w = a.shape[1] + b.shape[1] - overlap out = np.zeros((out_h, out_w)) out[:a.shape[0], :a.shape[1]] = a b_start = a.shape[1] - overlap if blend == "linear" and overlap > 1: ov_a = a[:min(a.shape[0], b.shape[0]), a.shape[1] - overlap:] ov_b = b[:min(a.shape[0], b.shape[0]), :overlap] blended = _blend_overlap(ov_a, ov_b, axis=1) out[:blended.shape[0], b_start:b_start + overlap] = blended out[:b.shape[0], b_start + overlap:b_start + b.shape[1]] = b[:, overlap:] else: out[:b.shape[0], b_start:b_start + b.shape[1]] = b elif direction == "below": dy, dx = _find_overlap_shift( a[-min(a.shape[0], b.shape[0]):, :], b[:min(a.shape[0], b.shape[0]), :], ) overlap = max(0, min(a.shape[0], b.shape[0]) - abs(dy)) if overlap <= 0: out_w = max(a.shape[1], b.shape[1]) out = np.zeros((a.shape[0] + b.shape[0], out_w)) out[:a.shape[0], :a.shape[1]] = a out[a.shape[0]:, :b.shape[1]] = b else: out_w = max(a.shape[1], b.shape[1]) out_h = a.shape[0] + b.shape[0] - overlap out = np.zeros((out_h, out_w)) out[:a.shape[0], :a.shape[1]] = a b_start = a.shape[0] - overlap if blend == "linear" and overlap > 1: ov_a = a[a.shape[0] - overlap:, :min(a.shape[1], b.shape[1])] ov_b = b[:overlap, :min(a.shape[1], b.shape[1])] blended = _blend_overlap(ov_a, ov_b, axis=0) out[b_start:b_start + overlap, :blended.shape[1]] = blended out[b_start + overlap:b_start + b.shape[0], :b.shape[1]] = b[overlap:, :] else: out[b_start:b_start + b.shape[0], :b.shape[1]] = b else: raise ValueError(f"Unknown direction: {direction!r}") xreal = out.shape[1] * field_a.dx yreal = out.shape[0] * field_a.dy return (field_a.replace(data=out, xreal=xreal, yreal=yreal),)