Files
tono/backend/nodes/mutual_crop.py
2026-04-03 23:11:52 -07:00

80 lines
2.5 KiB
Python

"""Mutual crop — align and crop two images to their overlapping region."""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Mutual Crop")
class MutualCrop:
    """Node that aligns two data fields and crops both to their common area.

    The relative pixel offset is estimated from the peak of the FFT-based
    (circular) cross-correlation of the mean-subtracted images; both fields
    are then cut down to the region where they overlap at that offset.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "field_a": ("DATA_FIELD",),
                "field_b": ("DATA_FIELD",),
            }
        }

    OUTPUTS = (
        ('DATA_FIELD', 'cropped_a'),
        ('DATA_FIELD', 'cropped_b'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Align two images using cross-correlation and crop both to their "
        "overlapping region. Useful for comparing images acquired at "
        "different times or with slight position offsets. "
    )

    @staticmethod
    def _estimate_shift(a: np.ndarray, b: np.ndarray) -> tuple[int, int]:
        """Return the (shift_y, shift_x) pixel offset of *a* relative to *b*.

        A feature at ``(y, x)`` in *b* is expected at ``(y + shift_y,
        x + shift_x)`` in *a*. Per axis the searched range is
        ``[-n // 2, n - n // 2 - 1]`` where ``n`` is the larger extent of the
        two images (shifts outside that range wrap around).
        """
        # A constant image has no structure to correlate: after mean removal
        # the correlation surface is flat (or rounding noise) and argmax
        # would return an arbitrary shift. Treat it as "no offset".
        if a.min() == a.max() or b.min() == b.max():
            return 0, 0

        # Pad both (mean-subtracted) images to a common shape.
        shape = (max(a.shape[0], b.shape[0]), max(a.shape[1], b.shape[1]))
        a_pad = np.zeros(shape)
        b_pad = np.zeros(shape)
        a_pad[:a.shape[0], :a.shape[1]] = a - a.mean()
        b_pad[:b.shape[0], :b.shape[1]] = b - b.mean()

        # Circular cross-correlation via the correlation theorem.
        spectrum = np.fft.fft2(a_pad) * np.conj(np.fft.fft2(b_pad))
        cc = np.abs(np.fft.ifft2(spectrum))

        # Search the *unshifted* surface: flat index 0 is zero shift, so when
        # zero shift ties for the maximum (argmax returns the first maximum)
        # it is preferred over the corner shift.
        peak_y, peak_x = np.unravel_index(np.argmax(cc), shape)
        n_y, n_x = shape
        # Map circular lag k to a signed shift in [-n//2, n - n//2 - 1]; this
        # is the same range an fftshift-centred search would produce.
        shift_y = (int(peak_y) + n_y // 2) % n_y - n_y // 2
        shift_x = (int(peak_x) + n_x // 2) % n_x - n_x // 2
        return shift_y, shift_x

    @staticmethod
    def _overlap_1d(len_a: int, len_b: int, shift: int) -> tuple[slice, slice] | None:
        """Return matching (slice_a, slice_b) along one axis, or None if disjoint.

        *shift* is the offset of *a* relative to *b*: index ``i`` of *b*
        lines up with index ``i + shift`` of *a*.
        """
        start_a = max(0, shift)
        stop_a = min(len_a, len_b + shift)
        if stop_a <= start_a:
            return None
        start_b = max(0, -shift)
        return slice(start_a, stop_a), slice(start_b, start_b + (stop_a - start_a))

    def process(self, field_a: DataField, field_b: DataField) -> tuple:
        """Align *field_b* to *field_a* and crop both to their overlap.

        Returns ``(cropped_a, cropped_b)``, two fields of identical pixel
        shape. If either input is empty, or the estimated offset leaves no
        overlapping region, the inputs are returned unchanged.
        """
        a = np.asarray(field_a.data, dtype=np.float64)
        b = np.asarray(field_b.data, dtype=np.float64)
        if a.size == 0 or b.size == 0:
            # Nothing to correlate (the FFT rejects zero-length axes).
            return (field_a, field_b)

        shift_y, shift_x = self._estimate_shift(a, b)
        rows = self._overlap_1d(a.shape[0], b.shape[0], shift_y)
        cols = self._overlap_1d(a.shape[1], b.shape[1], shift_x)
        if rows is None or cols is None:
            # No overlap found, return originals
            return (field_a, field_b)

        crop_a = a[rows[0], cols[0]]
        crop_b = b[rows[1], cols[1]]
        # Each output keeps its own field's pixel size; the physical extent
        # is pixel count times pixel size.
        # NOTE(review): only data/xreal/yreal are replaced; if DataField also
        # carries a lateral origin offset it is not advanced by the crop
        # start here — confirm whether downstream nodes rely on it.
        return (
            field_a.replace(
                data=crop_a,
                xreal=crop_a.shape[1] * field_a.dx,
                yreal=crop_a.shape[0] * field_a.dy,
            ),
            field_b.replace(
                data=crop_b,
                xreal=crop_b.shape[1] * field_b.dx,
                yreal=crop_b.shape[0] * field_b.dy,
            ),
        )