Files
tono/tests/node_tests/mask_shift.py
2026-04-04 00:25:53 -07:00

75 lines
2.6 KiB
Python

import numpy as np
import pytest
def _make_mask():
"""Create a simple test mask: 10x10 block of 255 in a 64x64 field."""
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 10:20] = 255
return mask
def test_output_shape():
    """MaskShift preserves the input's shape and uint8 dtype in every border mode.

    Each mode is exercised with a different (shift_x, shift_y) pair,
    including negative shifts, so several shift directions are covered.
    """
    from backend.nodes.mask_shift import MaskShift

    node = MaskShift()
    mask = _make_mask()
    cases = [
        (5, 3, "zero"),
        (-10, 7, "wrap"),
        (2, -4, "mirror"),
    ]
    for shift_x, shift_y, border_mode in cases:
        # process() yields a single-element sequence; unpack the one mask.
        result, = node.process(
            mask, shift_x=shift_x, shift_y=shift_y, border_mode=border_mode
        )
        # Check dtype in every mode, not only "zero": a mode that silently
        # promotes to a wider dtype would otherwise go unnoticed.
        assert result.shape == mask.shape, border_mode
        assert result.dtype == np.uint8, border_mode
def test_zero_shift_unchanged():
    """A (0, 0) shift must return the input mask unchanged in every border mode."""
    from backend.nodes.mask_shift import MaskShift

    shifter = MaskShift()
    original = _make_mask()
    for mode in ("zero", "wrap", "mirror"):
        (shifted,) = shifter.process(original, shift_x=0, shift_y=0, border_mode=mode)
        assert np.array_equal(shifted, original), mode
def test_wrap_mode():
    """Wrap mode shifts cyclically: pixels pushed past one edge re-enter on the opposite edge.

    Positive shift_x moves content right and positive shift_y moves it
    down, the same convention exercised by the zero-mode test.
    """
    from backend.nodes.mask_shift import MaskShift

    node = MaskShift()
    mask = _make_mask()

    # Shift right by 60 on a 64-wide field: cols 10..19 land on 70..79,
    # which wrap modulo 64 to 6..15. The block ends up entirely wrapped at
    # cols 6:16, equivalent to a left shift of 4.
    result, = node.process(mask, shift_x=60, shift_y=0, border_mode="wrap")
    assert result.dtype == np.uint8
    expected = np.zeros_like(mask)
    expected[10:20, 6:16] = 255
    assert np.array_equal(result, expected)
    # Wrapping never discards pixels, so the masked-pixel count is preserved.
    assert np.count_nonzero(result) == np.count_nonzero(mask)

    # Shift right by 50: cols 10..19 land on 60..69, so the block straddles
    # the right edge, with cols 60..63 kept in place and 64..69 wrapping to 0..5.
    result_split_x, = node.process(mask, shift_x=50, shift_y=0, border_mode="wrap")
    expected_split_x = np.zeros_like(mask)
    expected_split_x[10:20, 60:64] = 255
    expected_split_x[10:20, 0:6] = 255
    assert np.array_equal(result_split_x, expected_split_x)

    # Same straddle on the vertical axis: rows 10..19 + 50 -> 60..63 and 0..5.
    result_split_y, = node.process(mask, shift_x=0, shift_y=50, border_mode="wrap")
    expected_split_y = np.zeros_like(mask)
    expected_split_y[60:64, 10:20] = 255
    expected_split_y[0:6, 10:20] = 255
    assert np.array_equal(result_split_y, expected_split_y)
def test_zero_mode_fills():
    """Zero mode shifts content, fills vacated pixels with 0 and drops overflow.

    Each result is compared with the exact expected mask. This also catches
    an implementation that leaves stale pixels at the block's old location
    or that wraps content instead of discarding it.
    """
    from backend.nodes.mask_shift import MaskShift

    node = MaskShift()
    mask = _make_mask()

    # Shift right by 5: the block moves from cols 10:20 to cols 15:25.
    result, = node.process(mask, shift_x=5, shift_y=0, border_mode="zero")
    assert result.dtype == np.uint8
    # The left 5 columns are newly exposed and must be zero-filled.
    assert np.all(result[:, :5] == 0)
    expected = np.zeros_like(mask)
    expected[10:20, 15:25] = 255
    assert np.array_equal(result, expected)

    # Shift down by 5: the block moves from rows 10:20 to rows 15:25.
    result2, = node.process(mask, shift_x=0, shift_y=5, border_mode="zero")
    assert np.all(result2[:5, :] == 0)
    expected2 = np.zeros_like(mask)
    expected2[15:25, 10:20] = 255
    assert np.array_equal(result2, expected2)

    # Shift right by 50: cols 10..19 land on 60..69. Only 60..63 fit in the
    # 64-wide field. Unlike wrap mode, the overflow must be discarded rather
    # than reappearing at the left edge.
    result3, = node.process(mask, shift_x=50, shift_y=0, border_mode="zero")
    expected3 = np.zeros_like(mask)
    expected3[10:20, 60:64] = 255
    assert np.array_equal(result3, expected3)