# File: tono/tests/node_tests/field_arithmetic.py (76 lines, 2.3 KiB, Python)

import numpy as np
import pytest
from tests.node_tests._shared import make_field
def test_field_arithmetic_basic():
    """Each basic binary operation combines the two fields element-wise.

    ``b`` deliberately avoids all-ones data: with ``b == 1`` the "multiply",
    "divide" and "max" operations would all return ``a`` unchanged, so a node
    that confused them (or divided in the wrong order) could still pass.
    The values below give a distinct result for every operation, and
    min/max pick elements from both operands, which checks that they work
    per element rather than over the whole field.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    node = FieldArithmetic()
    a = make_field(data=np.array([[1.0, 2.0], [3.0, 4.0]]))
    b = make_field(data=np.array([[2.0, 0.5], [1.0, 8.0]]))

    expected = {
        "add": [[3.0, 2.5], [4.0, 12.0]],
        "subtract": [[-1.0, 1.5], [2.0, -4.0]],  # a - b: operand order matters
        "multiply": [[2.0, 1.0], [3.0, 32.0]],
        "divide": [[0.5, 4.0], [3.0, 0.5]],  # a / b: operand order matters
        "min": [[1.0, 0.5], [1.0, 4.0]],
        "max": [[2.0, 2.0], [3.0, 8.0]],
    }
    for operation, want in expected.items():
        # process() returns a 1-tuple; unpacking also checks there is exactly one output.
        result, = node.process(a, b, operation)
        assert np.allclose(result.data, want), operation
def test_field_arithmetic_hypot():
    """The "hypot" operation yields sqrt(a**2 + b**2) for each element.

    The inputs are Pythagorean pairs (3-4, 5-12) and pairs containing a
    zero, so every expected magnitude is an exact value.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    lhs = make_field(data=np.array([[3.0, 0.0], [0.0, 5.0]]))
    rhs = make_field(data=np.array([[4.0, 5.0], [3.0, 12.0]]))
    expected = [[5.0, 5.0], [3.0, 13.0]]

    (combined,) = arithmetic.process(lhs, rhs, "hypot")
    assert np.allclose(combined.data, expected)
def test_field_arithmetic_metadata_inherited():
    """The output field carries the first operand's metadata attributes.

    Compares ``xreal``, ``si_unit_xy`` and ``si_unit_z`` of the result
    against those of ``first``.
    """
    # NOTE(review): both operands come from make_field with identical
    # defaults, so this cannot tell whether metadata is copied from the
    # first or the second operand. Consider giving `second` distinct
    # metadata if make_field supports it (not visible here).
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    first = make_field(data=np.ones((4, 4)))
    second = make_field(data=np.ones((4, 4)))
    (out,) = arithmetic.process(first, second, "add")

    for attr in ("xreal", "si_unit_xy", "si_unit_z"):
        assert getattr(out, attr) == getattr(first, attr), attr
def test_field_arithmetic_shape_mismatch():
    """Operands whose data arrays differ in shape are rejected.

    The node must raise ValueError, and the message must contain
    "resolution" (``pytest.raises`` matches it with ``re.search``).
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    larger = make_field(data=np.ones((4, 4)))
    smaller = make_field(data=np.ones((3, 3)))

    with pytest.raises(ValueError, match="resolution"):
        arithmetic.process(larger, smaller, "add")
def test_field_arithmetic_unknown_operation():
    """An unsupported operation name ("power") raises ValueError."""
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    grid = (4, 4)
    lhs = make_field(data=np.ones(grid))
    rhs = make_field(data=np.ones(grid))

    # Only the process() call is inside the raises-block, so an error
    # during setup is not mistaken for the expected rejection.
    with pytest.raises(ValueError):
        arithmetic.process(lhs, rhs, "power")