Files
tono/tests/node_tests/cross_correlate.py

90 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_cross_correlate_same_field_peak_at_center():
    """Auto-correlation in 'same' mode should peak near the array centre.

    A field correlated with itself responds most strongly at zero lag, which
    in 'same' mode is expected at ``shape // 2`` along each axis. A two-pixel
    tolerance is allowed so the test does not depend on the exact centring
    convention of the implementation.
    """
    from backend.nodes.cross_correlate import CrossCorrelate

    generator = np.random.default_rng(0)
    field = make_field(data=generator.standard_normal((32, 32)))

    (output,) = CrossCorrelate().process(
        field, field, mode="same", normalize=True
    )

    corr = output.data
    row, col = np.unravel_index(np.argmax(corr), corr.shape)
    centre_row = corr.shape[0] // 2
    centre_col = corr.shape[1] // 2
    # Peak should be within a few pixels of centre
    assert abs(row - centre_row) <= 2
    assert abs(col - centre_col) <= 2
def test_cross_correlate_same_mode_shape_equals_a():
    """'same' mode output takes the shape of the first input (``a``).

    ``b`` is deliberately given a different, smaller shape than ``a``. With
    equal shapes the assertion would pass whether the output followed ``a``
    or ``b``, so it could not verify what this test's name claims.
    """
    from backend.nodes.cross_correlate import CrossCorrelate

    rng = np.random.default_rng(1)
    a = make_field(data=rng.standard_normal((32, 48)))
    b = make_field(data=rng.standard_normal((16, 24)))
    node = CrossCorrelate()
    result, = node.process(a, b, mode="same", normalize=True)
    assert result.data.shape == a.data.shape
def test_cross_correlate_full_mode_shape():
    """Full mode output shape should be (Na + Nb - 1, Ma + Mb - 1).

    The two inputs have different shapes so the general ``Na + Nb - 1``
    formula is exercised, not just its equal-shape special case ``2N - 1``.
    """
    from backend.nodes.cross_correlate import CrossCorrelate

    rng = np.random.default_rng(2)
    a = make_field(data=rng.standard_normal((20, 30)))
    b = make_field(data=rng.standard_normal((12, 18)))
    node = CrossCorrelate()
    result, = node.process(a, b, mode="full", normalize=True)
    assert result.data.shape == (20 + 12 - 1, 30 + 18 - 1)
def test_cross_correlate_normalized_peak_is_one():
    """With ``normalize=True`` the auto-correlation maximum should be 1."""
    from backend.nodes.cross_correlate import CrossCorrelate

    samples = np.random.default_rng(3).standard_normal((32, 32))
    image = make_field(data=samples)

    correlator = CrossCorrelate()
    (normalized,) = correlator.process(
        image, image, mode="same", normalize=True
    )

    peak_value = normalized.data.max()
    assert peak_value == pytest.approx(1.0, abs=1e-6)
def test_cross_correlate_unnormalized_runs():
    """Smoke test: ``normalize=False`` produces a finite 'same'-shaped output.

    Checking the shape alone would also accept an output full of NaN or inf,
    so the values are required to be finite as well.
    """
    from backend.nodes.cross_correlate import CrossCorrelate

    rng = np.random.default_rng(4)
    data = rng.standard_normal((16, 16))
    field = make_field(data=data)
    node = CrossCorrelate()
    result, = node.process(field, field, mode="same", normalize=False)
    assert result.data.shape == (16, 16)
    assert np.isfinite(result.data).all()
def test_cross_correlate_valid_mode():
    """'valid' mode output shape should be (Na - Nb + 1, Ma - Mb + 1)."""
    from backend.nodes.cross_correlate import CrossCorrelate

    rng = np.random.default_rng(5)
    big = make_field(data=rng.standard_normal((16, 16)))
    small = make_field(data=rng.standard_normal((8, 8)))

    (result,) = CrossCorrelate().process(
        big, small, mode="valid", normalize=True
    )

    # 16x16 input against an 8x8 template leaves (9, 9) fully-overlapping lags.
    assert result.data.shape == (16 - 8 + 1, 16 - 8 + 1)
def test_cross_correlate_preserves_metadata_same_mode():
    """'same' mode output keeps the input's ``xreal``/``yreal`` attributes."""
    from backend.nodes.cross_correlate import CrossCorrelate

    source = make_field(
        data=np.random.default_rng(6).standard_normal((16, 16))
    )
    (result,) = CrossCorrelate().process(
        source, source, mode="same", normalize=True
    )

    for attr in ("xreal", "yreal"):
        assert getattr(result, attr) == getattr(source, attr)