# File listing metadata: 76 lines, 2.3 KiB, Python
import numpy as np
|
|
import pytest
|
|
from tests.node_tests._shared import make_field
|
|
|
|
|
|
def test_field_arithmetic_basic():
    """Check the element-wise operations of ``FieldArithmetic`` on 2x2 fields.

    Each operation name is passed to ``process`` together with the same two
    input fields, and the ``data`` of the single returned field is compared
    against a hand-computed expectation.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    lhs = make_field(data=np.array([[1.0, 2.0], [3.0, 4.0]]))
    rhs = make_field(data=np.array([[1.0, 1.0], [1.0, 1.0]]))

    # (operation name, expected result data), evaluated in this order.
    # The right operand is all ones, so multiply, divide and max all
    # reproduce the left operand, and min reproduces the right operand.
    expectations = (
        ("add", [[2.0, 3.0], [4.0, 5.0]]),
        ("subtract", [[0.0, 1.0], [2.0, 3.0]]),
        ("multiply", [[1.0, 2.0], [3.0, 4.0]]),
        ("divide", [[1.0, 2.0], [3.0, 4.0]]),
        ("min", [[1.0, 1.0], [1.0, 1.0]]),
        ("max", [[1.0, 2.0], [3.0, 4.0]]),
    )

    for operation, expected in expectations:
        # process() returns a one-element sequence; unpack its only item.
        (outcome,) = arithmetic.process(lhs, rhs, operation)
        assert np.allclose(outcome.data, expected)
|
|
|
|
|
|
def test_field_arithmetic_hypot():
    """Check the ``"hypot"`` operation against known right-triangle values.

    The inputs pair up as (3, 4), (0, 5), (0, 3) and (5, 12), whose
    hypotenuses are 5, 5, 3 and 13 respectively.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    legs_a = np.array([[3.0, 0.0], [0.0, 5.0]])
    lhs = make_field(data=legs_a)
    legs_b = np.array([[4.0, 5.0], [3.0, 12.0]])
    rhs = make_field(data=legs_b)

    expected = [[5.0, 5.0], [3.0, 13.0]]

    # process() returns a one-element sequence; unpack its only item.
    (outcome,) = arithmetic.process(lhs, rhs, "hypot")
    assert np.allclose(outcome.data, expected)
|
|
|
|
|
|
def test_field_arithmetic_metadata_inherited():
    """Check that the result carries the first input's metadata.

    After an ``"add"`` of two 4x4 fields, ``xreal``, ``si_unit_xy`` and
    ``si_unit_z`` on the result must equal those of the first operand.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    lhs = make_field(data=np.ones((4, 4)))
    rhs = make_field(data=np.ones((4, 4)))

    # process() returns a one-element sequence; unpack its only item.
    (outcome,) = arithmetic.process(lhs, rhs, "add")

    # Compared in this order against the first operand.
    for attribute in ("xreal", "si_unit_xy", "si_unit_z"):
        assert getattr(outcome, attribute) == getattr(lhs, attribute)
|
|
|
|
|
|
def test_field_arithmetic_shape_mismatch():
    """Check that differently shaped inputs are rejected.

    Combining a 4x4 field with a 3x3 field must raise ``ValueError`` whose
    message matches ``"resolution"``.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    shapes = ((4, 4), (3, 3))
    # Build the operands in order: first the 4x4 field, then the 3x3 one.
    lhs, rhs = [make_field(data=np.ones(shape)) for shape in shapes]

    with pytest.raises(ValueError, match="resolution"):
        arithmetic.process(lhs, rhs, "add")
|
|
|
|
|
|
def test_field_arithmetic_unknown_operation():
    """Check that an unsupported operation name is rejected.

    Passing ``"power"`` as the operation for two 4x4 fields must raise
    ``ValueError``.
    """
    from backend.nodes.field_arithmetic import FieldArithmetic

    arithmetic = FieldArithmetic()
    shape = (4, 4)
    lhs = make_field(data=np.ones(shape))
    rhs = make_field(data=np.ones(shape))

    unsupported = "power"
    with pytest.raises(ValueError):
        arithmetic.process(lhs, rhs, unsupported)
|