tip modelling and deconvolution
This commit is contained in:
75
tests/node_tests/field_arithmetic.py
Normal file
75
tests/node_tests/field_arithmetic.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_field_arithmetic_basic():
|
||||
from backend.nodes.field_arithmetic import FieldArithmetic
|
||||
|
||||
node = FieldArithmetic()
|
||||
a = make_field(data=np.array([[1.0, 2.0], [3.0, 4.0]]))
|
||||
b = make_field(data=np.array([[1.0, 1.0], [1.0, 1.0]]))
|
||||
|
||||
result, = node.process(a, b, "add")
|
||||
assert np.allclose(result.data, [[2.0, 3.0], [4.0, 5.0]])
|
||||
|
||||
result, = node.process(a, b, "subtract")
|
||||
assert np.allclose(result.data, [[0.0, 1.0], [2.0, 3.0]])
|
||||
|
||||
result, = node.process(a, b, "multiply")
|
||||
assert np.allclose(result.data, [[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
result, = node.process(a, b, "divide")
|
||||
assert np.allclose(result.data, [[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
result, = node.process(a, b, "min")
|
||||
assert np.allclose(result.data, [[1.0, 1.0], [1.0, 1.0]])
|
||||
|
||||
result, = node.process(a, b, "max")
|
||||
assert np.allclose(result.data, [[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
|
||||
def test_field_arithmetic_hypot():
|
||||
from backend.nodes.field_arithmetic import FieldArithmetic
|
||||
|
||||
node = FieldArithmetic()
|
||||
a = make_field(data=np.array([[3.0, 0.0], [0.0, 5.0]]))
|
||||
b = make_field(data=np.array([[4.0, 5.0], [3.0, 12.0]]))
|
||||
|
||||
result, = node.process(a, b, "hypot")
|
||||
assert np.allclose(result.data, [[5.0, 5.0], [3.0, 13.0]])
|
||||
|
||||
|
||||
def test_field_arithmetic_metadata_inherited():
|
||||
from backend.nodes.field_arithmetic import FieldArithmetic
|
||||
|
||||
node = FieldArithmetic()
|
||||
a = make_field(data=np.ones((4, 4)))
|
||||
b = make_field(data=np.ones((4, 4)))
|
||||
|
||||
result, = node.process(a, b, "add")
|
||||
assert result.xreal == a.xreal
|
||||
assert result.si_unit_xy == a.si_unit_xy
|
||||
assert result.si_unit_z == a.si_unit_z
|
||||
|
||||
|
||||
def test_field_arithmetic_shape_mismatch():
|
||||
from backend.nodes.field_arithmetic import FieldArithmetic
|
||||
|
||||
node = FieldArithmetic()
|
||||
a = make_field(data=np.ones((4, 4)))
|
||||
b = make_field(data=np.ones((3, 3)))
|
||||
|
||||
with pytest.raises(ValueError, match="resolution"):
|
||||
node.process(a, b, "add")
|
||||
|
||||
|
||||
def test_field_arithmetic_unknown_operation():
|
||||
from backend.nodes.field_arithmetic import FieldArithmetic
|
||||
|
||||
node = FieldArithmetic()
|
||||
a = make_field(data=np.ones((4, 4)))
|
||||
b = make_field(data=np.ones((4, 4)))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
node.process(a, b, "power")
|
||||
Reference in New Issue
Block a user