improve coverage
This commit is contained in:
72
tests/node_tests/acf_1d.py
Normal file
72
tests/node_tests/acf_1d.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
from backend.data_types import LineData, RecordTable
|
||||
|
||||
|
||||
def test_acf_1d():
|
||||
from backend.nodes.acf_1d import ACF1D
|
||||
|
||||
node = ACF1D()
|
||||
|
||||
# Periodic signal — ACF should show a peak at the period
|
||||
n = 256
|
||||
period = 32
|
||||
t = np.arange(n, dtype=np.float64)
|
||||
signal = np.sin(2 * np.pi * t / period)
|
||||
profile = LineData(
|
||||
data=signal,
|
||||
x_axis=t * 1e-9,
|
||||
x_unit="m",
|
||||
y_unit="V",
|
||||
)
|
||||
|
||||
acf, measurement = node.process(profile, level="mean")
|
||||
|
||||
assert isinstance(acf, LineData)
|
||||
assert isinstance(measurement, RecordTable)
|
||||
|
||||
# ACF should be symmetric about zero lag
|
||||
center = len(acf) // 2
|
||||
assert np.allclose(acf.data, acf.data[::-1], atol=1e-10)
|
||||
|
||||
# Peak period should be close to the input period in metres
|
||||
expected_period_m = period * 1e-9
|
||||
assert len(measurement) == 1
|
||||
assert measurement[0]["quantity"] == "Peak period"
|
||||
assert abs(measurement[0]["value"] - expected_period_m) / expected_period_m < 0.1
|
||||
assert measurement[0]["unit"] == "m"
|
||||
|
||||
# x_axis should be centred on zero
|
||||
assert acf.x_axis is not None
|
||||
assert acf.x_axis[center] == 0.0 or abs(acf.x_axis[center]) < 1e-15
|
||||
|
||||
# ACF at zero lag should equal variance (signal is mean-subtracted)
|
||||
assert acf.data[center] > 0
|
||||
|
||||
|
||||
def test_acf_1d_no_peak():
|
||||
from backend.nodes.acf_1d import ACF1D
|
||||
|
||||
node = ACF1D()
|
||||
|
||||
# White noise — ACF should have no reliable peak, measurement table may be empty
|
||||
rng = np.random.default_rng(0)
|
||||
noise = rng.standard_normal(64)
|
||||
profile = LineData(data=noise, x_axis=np.arange(64, dtype=np.float64), x_unit="m")
|
||||
|
||||
acf, measurement = node.process(profile, level="none")
|
||||
assert isinstance(acf, LineData)
|
||||
# measurement is either empty or has one row — no assertion on content
|
||||
|
||||
|
||||
def test_acf_1d_level_none():
|
||||
from backend.nodes.acf_1d import ACF1D
|
||||
|
||||
node = ACF1D()
|
||||
|
||||
# With level="none", a DC offset should not be removed
|
||||
data = np.ones(32, dtype=np.float64) * 5.0
|
||||
profile = LineData(data=data, x_axis=np.arange(32, dtype=np.float64))
|
||||
|
||||
acf, _ = node.process(profile, level="none")
|
||||
# ACF of a constant is a constant
|
||||
assert acf.data[len(acf) // 2] > 0
|
||||
@@ -18,8 +18,9 @@ def test_line_cursors():
|
||||
|
||||
table, coord_pair = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
|
||||
assert isinstance(coord_pair, tuple) and len(coord_pair) == 2
|
||||
assert len(table) == 6
|
||||
assert len(table) == 7
|
||||
quantities = {row["quantity"] for row in table}
|
||||
assert "Length" in quantities
|
||||
assert "A x" in quantities
|
||||
assert "B x" in quantities
|
||||
assert "dx" in quantities
|
||||
@@ -41,7 +42,7 @@ def test_line_cursors():
|
||||
|
||||
line_data = LineData(data=line, x_axis=np.linspace(0, 1, 100))
|
||||
table2, _ = node.process(line_data, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
|
||||
assert len(table2) == 6
|
||||
assert len(table2) == 7
|
||||
|
||||
field = DataField(
|
||||
data=np.arange(100, dtype=np.float64).reshape(10, 10),
|
||||
|
||||
136
tests/node_tests/execution_preview.py
Normal file
136
tests/node_tests/execution_preview.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Tests for ExecutionEngine._auto_preview and _render_line_preview."""
|
||||
import backend.nodes # noqa: F401
|
||||
import numpy as np
|
||||
from backend.execution import ExecutionEngine
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, LineData
|
||||
|
||||
|
||||
def test_auto_preview_data_field():
|
||||
"""A node that outputs DATA_FIELD should trigger on_preview."""
|
||||
engine = ExecutionEngine()
|
||||
previews = []
|
||||
prompt = {
|
||||
"1": {"class_type": "Number", "inputs": {"value": 1.0}},
|
||||
}
|
||||
# Number outputs FLOAT, not DATA_FIELD — use GaussianFilter which outputs DATA_FIELD
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
@register_node(display_name="Test Preview Field Source")
|
||||
class TestPreviewFieldSource:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
OUTPUTS = (('DATA_FIELD', 'out'),)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
def process(self):
|
||||
return (make_field(),)
|
||||
|
||||
engine = ExecutionEngine()
|
||||
previews = []
|
||||
prompt = {"1": {"class_type": "TestPreviewFieldSource", "inputs": {}}}
|
||||
engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p)))
|
||||
assert len(previews) == 1
|
||||
nid, payload = previews[0]
|
||||
assert nid == "1"
|
||||
assert isinstance(payload, str) and payload.startswith("data:image/png;base64,")
|
||||
|
||||
|
||||
def test_auto_preview_line():
|
||||
"""A node that outputs LINE should trigger on_preview with a line_plot dict."""
|
||||
@register_node(display_name="Test Preview Line Source")
|
||||
class TestPreviewLineSource:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
OUTPUTS = (('LINE', 'out'),)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
def process(self):
|
||||
return (LineData(
|
||||
data=np.sin(np.linspace(0, 2 * np.pi, 64)),
|
||||
x_axis=np.linspace(0, 1e-6, 64),
|
||||
x_unit="m",
|
||||
),)
|
||||
|
||||
engine = ExecutionEngine()
|
||||
previews = []
|
||||
prompt = {"1": {"class_type": "TestPreviewLineSource", "inputs": {}}}
|
||||
engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p)))
|
||||
assert len(previews) == 1
|
||||
_, payload = previews[0]
|
||||
assert isinstance(payload, dict)
|
||||
assert payload["kind"] == "line_plot"
|
||||
assert "line" in payload and "x_axis" in payload
|
||||
assert payload["x_unit"] == "m"
|
||||
|
||||
|
||||
def test_auto_preview_table():
|
||||
"""A node that outputs RECORD_TABLE should trigger on_table."""
|
||||
from backend.data_types import RecordTable
|
||||
|
||||
@register_node(display_name="Test Preview Table Source")
|
||||
class TestPreviewTableSource:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
OUTPUTS = (('RECORD_TABLE', 'out'),)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
def process(self):
|
||||
return (RecordTable([{"quantity": "x", "value": 1.0, "unit": "m"}]),)
|
||||
|
||||
engine = ExecutionEngine()
|
||||
tables = []
|
||||
prompt = {"1": {"class_type": "TestPreviewTableSource", "inputs": {}}}
|
||||
engine.execute(prompt, on_table=lambda nid, t: tables.append((nid, t)))
|
||||
assert len(tables) == 1
|
||||
nid, rows = tables[0]
|
||||
assert nid == "1"
|
||||
assert rows[0]["quantity"] == "x"
|
||||
|
||||
|
||||
def test_auto_preview_polymorphic_field_output():
|
||||
"""A polymorphic output (declared LINE, actual DataField) should preview as a field."""
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
@register_node(display_name="Test Polymorphic Field Out")
|
||||
class TestPolymorphicFieldOut:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
OUTPUTS = (('LINE', 'out', {"accepted_types": ["DATA_FIELD"]}),)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
def process(self):
|
||||
return (make_field(),)
|
||||
|
||||
engine = ExecutionEngine()
|
||||
previews = []
|
||||
prompt = {"1": {"class_type": "TestPolymorphicFieldOut", "inputs": {}}}
|
||||
engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p)))
|
||||
assert len(previews) == 1
|
||||
_, payload = previews[0]
|
||||
# Should render as field preview (data URI), not line_plot dict
|
||||
assert isinstance(payload, str) and payload.startswith("data:image/png;base64,")
|
||||
|
||||
|
||||
def test_on_node_start_called():
|
||||
"""on_node_start callback fires before each node executes."""
|
||||
@register_node(display_name="Test Start Callback")
|
||||
class TestStartCallback:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
OUTPUTS = (('FLOAT', 'v'),)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
def process(self):
|
||||
return (1.0,)
|
||||
|
||||
started = []
|
||||
engine = ExecutionEngine()
|
||||
prompt = {"1": {"class_type": "TestStartCallback", "inputs": {}}}
|
||||
engine.execute(prompt, on_node_start=lambda nid: started.append(nid))
|
||||
assert started == ["1"]
|
||||
68
tests/node_tests/fft_1d.py
Normal file
68
tests/node_tests/fft_1d.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import numpy as np
|
||||
from backend.data_types import LineData, RecordTable
|
||||
|
||||
|
||||
def test_fft_1d_peak_period():
|
||||
from backend.nodes.fft_1d import FFT1D
|
||||
|
||||
node = FFT1D()
|
||||
|
||||
n = 256
|
||||
period = 32 # pixels
|
||||
dx = 1e-9 # 1 nm per pixel
|
||||
t = np.arange(n, dtype=np.float64)
|
||||
signal = np.sin(2 * np.pi * t / period)
|
||||
profile = LineData(
|
||||
data=signal,
|
||||
x_axis=t * dx,
|
||||
x_unit="m",
|
||||
y_unit="V",
|
||||
)
|
||||
|
||||
freq_line, table = node.process(profile)
|
||||
|
||||
assert isinstance(freq_line, LineData)
|
||||
assert isinstance(table, RecordTable)
|
||||
assert len(table) == 1
|
||||
assert table[0]["quantity"] == "Peak period"
|
||||
assert table[0]["unit"] == "m"
|
||||
|
||||
# Peak period should be close to 32 nm
|
||||
expected = period * dx
|
||||
assert abs(table[0]["value"] - expected) / expected < 0.1
|
||||
|
||||
# Output axis is in metres (spatial units)
|
||||
assert freq_line.x_unit == "m"
|
||||
# Spectrum values are non-negative magnitudes
|
||||
assert np.all(freq_line.data >= 0)
|
||||
# Highest spectral value corresponds to peak period
|
||||
peak_idx = np.argmax(freq_line.data)
|
||||
assert abs(freq_line.x_axis[peak_idx] - expected) / expected < 0.1
|
||||
|
||||
|
||||
def test_fft_1d_no_x_axis():
|
||||
from backend.nodes.fft_1d import FFT1D
|
||||
|
||||
node = FFT1D()
|
||||
|
||||
# Plain numpy array without calibration — should fall back to d=1, unit="m"
|
||||
signal = np.sin(2 * np.pi * np.arange(64) / 8)
|
||||
freq_line, table = node.process(signal)
|
||||
|
||||
assert isinstance(freq_line, LineData)
|
||||
assert len(freq_line.data) > 0
|
||||
assert np.all(freq_line.data >= 0)
|
||||
assert len(table) == 1
|
||||
|
||||
|
||||
def test_fft_1d_output_length():
|
||||
from backend.nodes.fft_1d import FFT1D
|
||||
|
||||
node = FFT1D()
|
||||
|
||||
for n in (32, 64, 128):
|
||||
data = np.random.default_rng(n).standard_normal(n)
|
||||
profile = LineData(data=data, x_axis=np.arange(n, dtype=np.float64) * 1e-9, x_unit="m")
|
||||
freq_line, _ = node.process(profile)
|
||||
# rfft gives n//2+1 bins; DC (index 0) is removed, leaving n//2 points
|
||||
assert len(freq_line.data) == n // 2
|
||||
@@ -63,3 +63,129 @@ def test_list_channels():
|
||||
folder_node = Folder()
|
||||
folder_result = folder_node.list_files(tmpdir)
|
||||
assert folder_result == tuple(entry["path"] for entry in paths)
|
||||
|
||||
|
||||
def test_measurement_helpers():
|
||||
from backend.nodes.helpers import _measurement_names, _measurement_entry, _measurement_value
|
||||
from backend.data_types import RecordTable
|
||||
|
||||
table = RecordTable([
|
||||
{"quantity": "Rq", "value": 0.5, "unit": "nm"},
|
||||
{"quantity": "Ra", "value": 0.3, "unit": "nm"},
|
||||
{"quantity": "Rq", "value": 0.5, "unit": "nm"}, # duplicate — deduplicated in names
|
||||
])
|
||||
|
||||
names = _measurement_names(table)
|
||||
assert names == ["Rq", "Ra"]
|
||||
|
||||
row = _measurement_entry(table, "Ra")
|
||||
assert row["value"] == 0.3
|
||||
|
||||
# falls back to first when selection not found
|
||||
row_fallback = _measurement_entry(table, "nonexistent")
|
||||
assert row_fallback["quantity"] == "Rq"
|
||||
|
||||
val = _measurement_value(table, "Ra")
|
||||
assert val == 0.3
|
||||
|
||||
|
||||
def test_measurement_value_errors():
|
||||
from backend.nodes.helpers import _measurement_value
|
||||
from backend.data_types import RecordTable
|
||||
|
||||
empty = RecordTable([])
|
||||
try:
|
||||
_measurement_value(empty, "anything")
|
||||
assert False, "should raise"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
bool_table = RecordTable([{"quantity": "flag", "value": True}])
|
||||
try:
|
||||
_measurement_value(bool_table, "flag")
|
||||
assert False, "should raise"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_format_with_unit():
|
||||
from backend.nodes.helpers import _format_with_unit, _format_numeric
|
||||
|
||||
assert _format_numeric(0.0) == "0"
|
||||
assert not np.isfinite(float('inf')) or _format_numeric(float('inf')) is not None
|
||||
|
||||
# plain number no unit
|
||||
result = _format_with_unit(1.5, "")
|
||||
assert "1.5" in result
|
||||
|
||||
# prefixable unit gets SI prefix
|
||||
result_nm = _format_with_unit(1e-9, "m")
|
||||
assert "n" in result_nm or "1e" in result_nm
|
||||
|
||||
# non-prefixable unit is left as-is
|
||||
result_bare = _format_with_unit(3.14, "rad")
|
||||
assert "3.14" in result_bare and "rad" in result_bare
|
||||
|
||||
# zero value
|
||||
result_zero = _format_with_unit(0.0, "m")
|
||||
assert "0" in result_zero
|
||||
|
||||
|
||||
def test_table_and_array_ops():
|
||||
from backend.nodes.helpers import (
|
||||
TABLE_OPS, ARRAY_OPS, extract_numeric_table_values,
|
||||
resolve_table_column_name, _common_table_unit,
|
||||
)
|
||||
from backend.data_types import RecordTable
|
||||
|
||||
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
assert TABLE_OPS["min"](values) == 1.0
|
||||
assert TABLE_OPS["max"](values) == 5.0
|
||||
assert TABLE_OPS["mean"](values) == 3.0
|
||||
assert TABLE_OPS["sum"](values) == 15.0
|
||||
assert TABLE_OPS["range"](values) == 4.0
|
||||
assert TABLE_OPS["count"](values) == 5.0
|
||||
assert TABLE_OPS["median"](values) == 3.0
|
||||
assert TABLE_OPS["std"](values) > 0
|
||||
assert TABLE_OPS["variance"](values) > 0
|
||||
|
||||
assert ARRAY_OPS["rms"](values) > 0
|
||||
assert ARRAY_OPS["std"](values) > 0
|
||||
|
||||
table = RecordTable([
|
||||
{"quantity": "A", "value": 1.0, "unit": "m"},
|
||||
{"quantity": "B", "value": 2.0, "unit": "m"},
|
||||
{"not_a_dict": True},
|
||||
{"quantity": "C", "value": "not_a_number"},
|
||||
])
|
||||
nums = extract_numeric_table_values(table, "value")
|
||||
assert nums == [1.0, 2.0]
|
||||
|
||||
col = resolve_table_column_name(table, "")
|
||||
assert col == "value"
|
||||
|
||||
unit = _common_table_unit(table, "value")
|
||||
assert unit == "m"
|
||||
|
||||
|
||||
def test_square_unit_and_apply():
|
||||
from backend.nodes.helpers import _square_unit, _apply_scalar_unit
|
||||
|
||||
assert _square_unit("m") == "m^2"
|
||||
assert _square_unit("m/s") == "(m/s)^2"
|
||||
assert _square_unit("") == ""
|
||||
|
||||
assert _apply_scalar_unit("m", "variance") == "m^2"
|
||||
assert _apply_scalar_unit("m", "count") == "count"
|
||||
assert _apply_scalar_unit("m", "mean") == "m"
|
||||
assert _apply_scalar_unit("", "mean") == ""
|
||||
|
||||
|
||||
def test_nice_length():
|
||||
from backend.nodes.helpers import _nice_length
|
||||
|
||||
assert _nice_length(0.0) == 0.0
|
||||
assert _nice_length(float('inf')) == 0.0
|
||||
assert _nice_length(7.3) == 5.0
|
||||
assert _nice_length(1500.0) == 1000.0
|
||||
assert _nice_length(0.003) == 0.002
|
||||
|
||||
@@ -20,8 +20,9 @@ def test_load_file():
|
||||
img.save(path)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
field = result[0]
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], str)
|
||||
field = result[1]
|
||||
assert field.data.shape == (48, 64)
|
||||
assert field.data.dtype == np.float64
|
||||
|
||||
@@ -31,15 +32,15 @@ def test_load_file():
|
||||
img_rgb.save(path_rgb)
|
||||
|
||||
result_rgb = node.load(filename=path_rgb)
|
||||
assert len(result_rgb) == 1
|
||||
assert result_rgb[0].data.shape == (32, 32)
|
||||
assert len(result_rgb) == 2
|
||||
assert result_rgb[1].data.shape == (32, 32)
|
||||
|
||||
data_npy = np.random.default_rng(3).standard_normal((50, 60))
|
||||
path_npy = os.path.join(tmpdir, "test.npy")
|
||||
np.save(path_npy, data_npy)
|
||||
|
||||
result_npy = node.load(filename=path_npy)
|
||||
assert np.allclose(result_npy[0].data, data_npy)
|
||||
assert np.allclose(result_npy[1].data, data_npy)
|
||||
|
||||
custom_colormap = {
|
||||
"mode": "custom",
|
||||
@@ -50,13 +51,13 @@ def test_load_file():
|
||||
],
|
||||
}
|
||||
result_custom = node.load(filename=path, colormap_map=custom_colormap)
|
||||
assert isinstance(result_custom[0].colormap, dict)
|
||||
assert result_custom[0].colormap["mode"] == "custom"
|
||||
assert len(result_custom[0].colormap["stops"]) == 3
|
||||
assert isinstance(result_custom[1].colormap, dict)
|
||||
assert result_custom[1].colormap["mode"] == "custom"
|
||||
assert len(result_custom[1].colormap["stops"]) == 3
|
||||
|
||||
result_from_path = node.load(filename="", path=path)
|
||||
assert len(result_from_path) == 1
|
||||
assert result_from_path[0].data.shape == (48, 64)
|
||||
assert len(result_from_path) == 2
|
||||
assert result_from_path[1].data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_load_file_npz():
|
||||
@@ -68,8 +69,8 @@ def test_load_file_npz():
|
||||
np.savez(path, my_array=data)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
assert np.allclose(result[0].data, data)
|
||||
assert len(result) == 2
|
||||
assert np.allclose(result[1].data, data)
|
||||
|
||||
|
||||
def test_load_file_cache():
|
||||
@@ -83,8 +84,8 @@ def test_load_file_cache():
|
||||
np.save(path, data)
|
||||
|
||||
with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader:
|
||||
first, = node.load(filename=path)
|
||||
second, = node.load(filename=path)
|
||||
_, first = node.load(filename=path)
|
||||
_, second = node.load(filename=path)
|
||||
assert loader.call_count == 1
|
||||
|
||||
assert np.allclose(first.data, data)
|
||||
@@ -92,7 +93,7 @@ def test_load_file_cache():
|
||||
assert first is not second
|
||||
first.data[0, 0] = -999.0
|
||||
|
||||
third, = node.load(filename=path)
|
||||
_, third = node.load(filename=path)
|
||||
assert third.data[0, 0] == data[0, 0]
|
||||
|
||||
Image._load_fields_cached.cache_clear()
|
||||
@@ -136,7 +137,7 @@ def test_load_file_warning():
|
||||
img.save(path)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
assert len(result) == 2
|
||||
assert len(warnings) == 1
|
||||
assert "Uncalibrated" in warnings[0]
|
||||
|
||||
|
||||
@@ -9,17 +9,23 @@ import backend.nodes # noqa: F401
|
||||
|
||||
def test_load_demo():
|
||||
from backend.nodes.image_demo import ImageDemo
|
||||
from backend.nodes.helpers import DEMO_DIR
|
||||
node = ImageDemo()
|
||||
|
||||
result = node.load(name="nanoparticles.npy")
|
||||
assert len(result) >= 1
|
||||
assert isinstance(result[0], DataField)
|
||||
assert result[0].data.ndim == 2
|
||||
# result[0] is the FILE_PATH string, fields follow
|
||||
assert len(result) >= 2
|
||||
assert isinstance(result[0], str)
|
||||
assert isinstance(result[1], DataField)
|
||||
assert result[1].data.ndim == 2
|
||||
|
||||
result_ibw = node.load(name="whiskers.ibw")
|
||||
assert len(result_ibw) == 4
|
||||
for field in result_ibw:
|
||||
assert isinstance(field, DataField)
|
||||
ibw_path = DEMO_DIR / "whiskers.ibw"
|
||||
if ibw_path.exists():
|
||||
result_ibw = node.load(name="whiskers.ibw")
|
||||
fields = [v for v in result_ibw if isinstance(v, DataField)]
|
||||
assert len(fields) == 4
|
||||
for field in fields:
|
||||
assert isinstance(field, DataField)
|
||||
|
||||
try:
|
||||
node.load(name="nonexistent_file.png")
|
||||
@@ -36,21 +42,26 @@ def test_load_demo_cache():
|
||||
Image._load_fields_cached.cache_clear()
|
||||
|
||||
with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader:
|
||||
first, = node.load(name="nanoparticles.npy")
|
||||
second, = node.load(name="nanoparticles.npy")
|
||||
_, first = node.load(name="nanoparticles.npy")
|
||||
_, second = node.load(name="nanoparticles.npy")
|
||||
assert loader.call_count == 1
|
||||
|
||||
assert np.allclose(first.data, second.data)
|
||||
assert first is not second
|
||||
first.data[0, 0] = -999.0
|
||||
|
||||
third, = node.load(name="nanoparticles.npy")
|
||||
_, third = node.load(name="nanoparticles.npy")
|
||||
assert third.data[0, 0] != -999.0
|
||||
|
||||
Image._load_fields_cached.cache_clear()
|
||||
|
||||
|
||||
def test_load_demo_multi_layer_preview_payload():
|
||||
from backend.nodes.helpers import DEMO_DIR
|
||||
ibw_path = DEMO_DIR / "whiskers.ibw"
|
||||
if not ibw_path.exists():
|
||||
return
|
||||
|
||||
previews = []
|
||||
prompt = {
|
||||
"1": {
|
||||
|
||||
@@ -52,3 +52,87 @@ def test_line_correction():
|
||||
assert np.allclose(leveled.data + poly_bg.data, poly_field.data)
|
||||
assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995
|
||||
assert len(poly_shifts) == rows
|
||||
|
||||
|
||||
def test_line_correction_methods():
|
||||
from backend.nodes.line_correction import LineCorrection
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
node = LineCorrection()
|
||||
|
||||
rows, cols = 64, 80
|
||||
rng = np.random.default_rng(7)
|
||||
signal = rng.standard_normal((rows, cols)) * 0.1
|
||||
row_offsets = rng.standard_normal(rows) * 2.0
|
||||
data = signal + row_offsets[:, None]
|
||||
field = make_field(data=data)
|
||||
|
||||
# median_diff
|
||||
c, b, s = node.process(field, method="median_diff", direction="horizontal",
|
||||
masking="ignore", trim_fraction=0.05, polynomial_degree=1)
|
||||
assert np.allclose(c.data + b.data, field.data)
|
||||
assert len(s) == rows
|
||||
|
||||
# trimmed_mean
|
||||
c, b, s = node.process(field, method="trimmed_mean", direction="horizontal",
|
||||
masking="ignore", trim_fraction=0.2, polynomial_degree=1)
|
||||
assert np.allclose(c.data + b.data, field.data)
|
||||
|
||||
# trimmed_diff
|
||||
c, b, s = node.process(field, method="trimmed_diff", direction="horizontal",
|
||||
masking="ignore", trim_fraction=0.2, polynomial_degree=1)
|
||||
assert np.allclose(c.data + b.data, field.data)
|
||||
|
||||
# step
|
||||
c, b, s = node.process(field, method="step", direction="horizontal",
|
||||
masking="ignore", trim_fraction=0.05, polynomial_degree=1)
|
||||
assert np.allclose(c.data + b.data, field.data)
|
||||
assert len(s) == rows
|
||||
|
||||
|
||||
def test_line_correction_vertical():
|
||||
from backend.nodes.line_correction import LineCorrection
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
node = LineCorrection()
|
||||
|
||||
rows, cols = 48, 64
|
||||
col_offsets = np.random.default_rng(3).standard_normal(cols) * 1.5
|
||||
data = np.random.default_rng(3).standard_normal((rows, cols)) * 0.1 + col_offsets[None, :]
|
||||
field = make_field(data=data)
|
||||
|
||||
c, b, s = node.process(field, method="median", direction="vertical",
|
||||
masking="ignore", trim_fraction=0.05, polynomial_degree=1)
|
||||
assert c.data.shape == field.data.shape
|
||||
assert np.allclose(c.data + b.data, field.data)
|
||||
# vertical shift line length = number of columns
|
||||
assert len(s) == cols
|
||||
assert s.x_axis is not None
|
||||
assert np.isclose(s.x_axis[-1], field.xreal)
|
||||
|
||||
|
||||
def test_line_correction_with_mask():
|
||||
from backend.nodes.line_correction import LineCorrection
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
node = LineCorrection()
|
||||
|
||||
rows, cols = 32, 48
|
||||
data = np.random.default_rng(9).standard_normal((rows, cols)) * 0.1
|
||||
row_offsets = np.linspace(0, 3.0, rows)
|
||||
data += row_offsets[:, None]
|
||||
field = make_field(data=data)
|
||||
|
||||
# mask covers right half
|
||||
mask = np.zeros((rows, cols), dtype=np.uint8)
|
||||
mask[:, cols // 2:] = 255
|
||||
|
||||
c_excl, b_excl, _ = node.process(field, method="median", direction="horizontal",
|
||||
masking="exclude", trim_fraction=0.05,
|
||||
polynomial_degree=1, mask=mask)
|
||||
assert np.allclose(c_excl.data + b_excl.data, field.data)
|
||||
|
||||
c_incl, b_incl, _ = node.process(field, method="median", direction="horizontal",
|
||||
masking="include", trim_fraction=0.05,
|
||||
polynomial_degree=1, mask=mask)
|
||||
assert np.allclose(c_incl.data + b_incl.data, field.data)
|
||||
|
||||
@@ -23,14 +23,14 @@ def test_threshold_mask():
|
||||
assert len(previews) == 1
|
||||
assert previews[0].startswith("data:image/png;base64,")
|
||||
|
||||
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
|
||||
mask_below, _ = node.process(field, method="absolute", threshold=0.5, direction="below")
|
||||
assert np.all(mask_below[:, :32] == 255)
|
||||
assert np.all(mask_below[:, 32:] == 0)
|
||||
|
||||
mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||
mask_rel, _ = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||
assert np.all(mask_rel[:, 32:] == 255)
|
||||
|
||||
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
mask_otsu, _ = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
|
||||
|
||||
ThresholdMask._broadcast_fn = None
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import math
|
||||
from backend.data_types import RecordTable
|
||||
from backend.execution_context import active_node, execution_callbacks
|
||||
|
||||
|
||||
def test_value_display():
|
||||
from backend.nodes.value_io import ValueIO
|
||||
|
||||
node = ValueIO()
|
||||
value_spec = ValueIO.INPUT_TYPES()["required"]["value"]
|
||||
value_spec = ValueIO.INPUT_TYPES()["optional"]["value"]
|
||||
assert value_spec[0] == "FLOAT"
|
||||
assert value_spec[1]["accepted_types"] == ["RECORD_TABLE"]
|
||||
|
||||
@@ -13,7 +15,7 @@ def test_value_display():
|
||||
ValueIO._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
|
||||
ValueIO._current_node_id = "test"
|
||||
|
||||
result = node.display_value(3.25)
|
||||
result = node.display_value(value=3.25)
|
||||
assert result == (3.25,)
|
||||
assert captured == [("test", {"value": 3.25})]
|
||||
|
||||
@@ -21,8 +23,42 @@ def test_value_display():
|
||||
{"quantity": "delta X", "value": 1.7e-7, "unit": "m"},
|
||||
{"quantity": "delta Y", "value": 463, "unit": "count"},
|
||||
])
|
||||
result = node.display_value(measurements, measurement="delta X")
|
||||
result = node.display_value(value=measurements, measurement="delta X")
|
||||
assert result == (1.7e-7,)
|
||||
assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"})
|
||||
|
||||
ValueIO._broadcast_value_fn = None
|
||||
|
||||
|
||||
def test_value_display_string_input():
|
||||
from backend.nodes.value_io import ValueIO
|
||||
|
||||
node = ValueIO()
|
||||
values = []
|
||||
with execution_callbacks(value=lambda nid, v: values.append(v)), active_node("n1"):
|
||||
# plain number
|
||||
result = node.display_value(number_input="42")
|
||||
assert result == (42.0,)
|
||||
assert values[-1]["value"] == 42.0
|
||||
|
||||
values.clear()
|
||||
with execution_callbacks(value=lambda nid, v: values.append(v)), active_node("n1"):
|
||||
# negative number
|
||||
result = node.display_value(number_input="-3.14")
|
||||
assert math.isclose(result[0], -3.14)
|
||||
assert math.isclose(values[-1]["value"], -3.14)
|
||||
|
||||
|
||||
def test_value_display_table_emits_table():
|
||||
from backend.nodes.value_io import ValueIO
|
||||
|
||||
node = ValueIO()
|
||||
tables = []
|
||||
measurements = RecordTable([
|
||||
{"quantity": "Rq", "value": 0.42, "unit": "nm"},
|
||||
])
|
||||
with execution_callbacks(table=lambda nid, t: tables.append(t)), active_node("n1"):
|
||||
result = node.display_value(value=measurements, measurement="Rq")
|
||||
assert result == (0.42,)
|
||||
assert len(tables) == 1
|
||||
assert tables[0][0]["quantity"] == "Rq"
|
||||
|
||||
Reference in New Issue
Block a user