improve coverage

This commit is contained in:
2026-03-29 19:13:59 -07:00
parent 2b17a2594f
commit 29eee8a42c
12 changed files with 570 additions and 35 deletions

BIN
.coverage

Binary file not shown.

2
demo

Submodule demo updated: 0e24a1eb54...7621b48a68

View File

@@ -0,0 +1,72 @@
import numpy as np
from backend.data_types import LineData, RecordTable
def test_acf_1d():
from backend.nodes.acf_1d import ACF1D
node = ACF1D()
# Periodic signal — ACF should show a peak at the period
n = 256
period = 32
t = np.arange(n, dtype=np.float64)
signal = np.sin(2 * np.pi * t / period)
profile = LineData(
data=signal,
x_axis=t * 1e-9,
x_unit="m",
y_unit="V",
)
acf, measurement = node.process(profile, level="mean")
assert isinstance(acf, LineData)
assert isinstance(measurement, RecordTable)
# ACF should be symmetric about zero lag
center = len(acf) // 2
assert np.allclose(acf.data, acf.data[::-1], atol=1e-10)
# Peak period should be close to the input period in metres
expected_period_m = period * 1e-9
assert len(measurement) == 1
assert measurement[0]["quantity"] == "Peak period"
assert abs(measurement[0]["value"] - expected_period_m) / expected_period_m < 0.1
assert measurement[0]["unit"] == "m"
# x_axis should be centred on zero
assert acf.x_axis is not None
assert acf.x_axis[center] == 0.0 or abs(acf.x_axis[center]) < 1e-15
# ACF at zero lag should equal variance (signal is mean-subtracted)
assert acf.data[center] > 0
def test_acf_1d_no_peak():
from backend.nodes.acf_1d import ACF1D
node = ACF1D()
# White noise — ACF should have no reliable peak, measurement table may be empty
rng = np.random.default_rng(0)
noise = rng.standard_normal(64)
profile = LineData(data=noise, x_axis=np.arange(64, dtype=np.float64), x_unit="m")
acf, measurement = node.process(profile, level="none")
assert isinstance(acf, LineData)
# measurement is either empty or has one row — no assertion on content
def test_acf_1d_level_none():
from backend.nodes.acf_1d import ACF1D
node = ACF1D()
# With level="none", a DC offset should not be removed
data = np.ones(32, dtype=np.float64) * 5.0
profile = LineData(data=data, x_axis=np.arange(32, dtype=np.float64))
acf, _ = node.process(profile, level="none")
# ACF of a constant is a constant
assert acf.data[len(acf) // 2] > 0

View File

@@ -18,8 +18,9 @@ def test_line_cursors():
table, coord_pair = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5) table, coord_pair = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
assert isinstance(coord_pair, tuple) and len(coord_pair) == 2 assert isinstance(coord_pair, tuple) and len(coord_pair) == 2
assert len(table) == 6 assert len(table) == 7
quantities = {row["quantity"] for row in table} quantities = {row["quantity"] for row in table}
assert "Length" in quantities
assert "A x" in quantities assert "A x" in quantities
assert "B x" in quantities assert "B x" in quantities
assert "dx" in quantities assert "dx" in quantities
@@ -41,7 +42,7 @@ def test_line_cursors():
line_data = LineData(data=line, x_axis=np.linspace(0, 1, 100)) line_data = LineData(data=line, x_axis=np.linspace(0, 1, 100))
table2, _ = node.process(line_data, x1=0.25, y1=0.5, x2=0.75, y2=0.5) table2, _ = node.process(line_data, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
assert len(table2) == 6 assert len(table2) == 7
field = DataField( field = DataField(
data=np.arange(100, dtype=np.float64).reshape(10, 10), data=np.arange(100, dtype=np.float64).reshape(10, 10),

View File

@@ -0,0 +1,136 @@
"""Tests for ExecutionEngine._auto_preview and _render_line_preview."""
import backend.nodes # noqa: F401
import numpy as np
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
from backend.data_types import DataField, LineData
def test_auto_preview_data_field():
"""A node that outputs DATA_FIELD should trigger on_preview."""
engine = ExecutionEngine()
previews = []
prompt = {
"1": {"class_type": "Number", "inputs": {"value": 1.0}},
}
# Number outputs FLOAT, not DATA_FIELD — use GaussianFilter which outputs DATA_FIELD
from tests.node_tests._shared import make_field
@register_node(display_name="Test Preview Field Source")
class TestPreviewFieldSource:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
OUTPUTS = (('DATA_FIELD', 'out'),)
FUNCTION = "process"
CATEGORY = "tests"
def process(self):
return (make_field(),)
engine = ExecutionEngine()
previews = []
prompt = {"1": {"class_type": "TestPreviewFieldSource", "inputs": {}}}
engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p)))
assert len(previews) == 1
nid, payload = previews[0]
assert nid == "1"
assert isinstance(payload, str) and payload.startswith("data:image/png;base64,")
def test_auto_preview_line():
"""A node that outputs LINE should trigger on_preview with a line_plot dict."""
@register_node(display_name="Test Preview Line Source")
class TestPreviewLineSource:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
OUTPUTS = (('LINE', 'out'),)
FUNCTION = "process"
CATEGORY = "tests"
def process(self):
return (LineData(
data=np.sin(np.linspace(0, 2 * np.pi, 64)),
x_axis=np.linspace(0, 1e-6, 64),
x_unit="m",
),)
engine = ExecutionEngine()
previews = []
prompt = {"1": {"class_type": "TestPreviewLineSource", "inputs": {}}}
engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p)))
assert len(previews) == 1
_, payload = previews[0]
assert isinstance(payload, dict)
assert payload["kind"] == "line_plot"
assert "line" in payload and "x_axis" in payload
assert payload["x_unit"] == "m"
def test_auto_preview_table():
"""A node that outputs RECORD_TABLE should trigger on_table."""
from backend.data_types import RecordTable
@register_node(display_name="Test Preview Table Source")
class TestPreviewTableSource:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
OUTPUTS = (('RECORD_TABLE', 'out'),)
FUNCTION = "process"
CATEGORY = "tests"
def process(self):
return (RecordTable([{"quantity": "x", "value": 1.0, "unit": "m"}]),)
engine = ExecutionEngine()
tables = []
prompt = {"1": {"class_type": "TestPreviewTableSource", "inputs": {}}}
engine.execute(prompt, on_table=lambda nid, t: tables.append((nid, t)))
assert len(tables) == 1
nid, rows = tables[0]
assert nid == "1"
assert rows[0]["quantity"] == "x"
def test_auto_preview_polymorphic_field_output():
"""A polymorphic output (declared LINE, actual DataField) should preview as a field."""
from tests.node_tests._shared import make_field
@register_node(display_name="Test Polymorphic Field Out")
class TestPolymorphicFieldOut:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
OUTPUTS = (('LINE', 'out', {"accepted_types": ["DATA_FIELD"]}),)
FUNCTION = "process"
CATEGORY = "tests"
def process(self):
return (make_field(),)
engine = ExecutionEngine()
previews = []
prompt = {"1": {"class_type": "TestPolymorphicFieldOut", "inputs": {}}}
engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p)))
assert len(previews) == 1
_, payload = previews[0]
# Should render as field preview (data URI), not line_plot dict
assert isinstance(payload, str) and payload.startswith("data:image/png;base64,")
def test_on_node_start_called():
"""on_node_start callback fires before each node executes."""
@register_node(display_name="Test Start Callback")
class TestStartCallback:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
OUTPUTS = (('FLOAT', 'v'),)
FUNCTION = "process"
CATEGORY = "tests"
def process(self):
return (1.0,)
started = []
engine = ExecutionEngine()
prompt = {"1": {"class_type": "TestStartCallback", "inputs": {}}}
engine.execute(prompt, on_node_start=lambda nid: started.append(nid))
assert started == ["1"]

View File

@@ -0,0 +1,68 @@
import numpy as np
from backend.data_types import LineData, RecordTable
def test_fft_1d_peak_period():
from backend.nodes.fft_1d import FFT1D
node = FFT1D()
n = 256
period = 32 # pixels
dx = 1e-9 # 1 nm per pixel
t = np.arange(n, dtype=np.float64)
signal = np.sin(2 * np.pi * t / period)
profile = LineData(
data=signal,
x_axis=t * dx,
x_unit="m",
y_unit="V",
)
freq_line, table = node.process(profile)
assert isinstance(freq_line, LineData)
assert isinstance(table, RecordTable)
assert len(table) == 1
assert table[0]["quantity"] == "Peak period"
assert table[0]["unit"] == "m"
# Peak period should be close to 32 nm
expected = period * dx
assert abs(table[0]["value"] - expected) / expected < 0.1
# Output axis is in metres (spatial units)
assert freq_line.x_unit == "m"
# Spectrum values are non-negative magnitudes
assert np.all(freq_line.data >= 0)
# Highest spectral value corresponds to peak period
peak_idx = np.argmax(freq_line.data)
assert abs(freq_line.x_axis[peak_idx] - expected) / expected < 0.1
def test_fft_1d_no_x_axis():
from backend.nodes.fft_1d import FFT1D
node = FFT1D()
# Plain numpy array without calibration — should fall back to d=1, unit="m"
signal = np.sin(2 * np.pi * np.arange(64) / 8)
freq_line, table = node.process(signal)
assert isinstance(freq_line, LineData)
assert len(freq_line.data) > 0
assert np.all(freq_line.data >= 0)
assert len(table) == 1
def test_fft_1d_output_length():
from backend.nodes.fft_1d import FFT1D
node = FFT1D()
for n in (32, 64, 128):
data = np.random.default_rng(n).standard_normal(n)
profile = LineData(data=data, x_axis=np.arange(n, dtype=np.float64) * 1e-9, x_unit="m")
freq_line, _ = node.process(profile)
# rfft gives n//2+1 bins; DC (index 0) is removed, leaving n//2 points
assert len(freq_line.data) == n // 2

View File

@@ -63,3 +63,129 @@ def test_list_channels():
folder_node = Folder() folder_node = Folder()
folder_result = folder_node.list_files(tmpdir) folder_result = folder_node.list_files(tmpdir)
assert folder_result == tuple(entry["path"] for entry in paths) assert folder_result == tuple(entry["path"] for entry in paths)
def test_measurement_helpers():
from backend.nodes.helpers import _measurement_names, _measurement_entry, _measurement_value
from backend.data_types import RecordTable
table = RecordTable([
{"quantity": "Rq", "value": 0.5, "unit": "nm"},
{"quantity": "Ra", "value": 0.3, "unit": "nm"},
{"quantity": "Rq", "value": 0.5, "unit": "nm"}, # duplicate — deduplicated in names
])
names = _measurement_names(table)
assert names == ["Rq", "Ra"]
row = _measurement_entry(table, "Ra")
assert row["value"] == 0.3
# falls back to first when selection not found
row_fallback = _measurement_entry(table, "nonexistent")
assert row_fallback["quantity"] == "Rq"
val = _measurement_value(table, "Ra")
assert val == 0.3
def test_measurement_value_errors():
from backend.nodes.helpers import _measurement_value
from backend.data_types import RecordTable
empty = RecordTable([])
try:
_measurement_value(empty, "anything")
assert False, "should raise"
except ValueError:
pass
bool_table = RecordTable([{"quantity": "flag", "value": True}])
try:
_measurement_value(bool_table, "flag")
assert False, "should raise"
except ValueError:
pass
def test_format_with_unit():
from backend.nodes.helpers import _format_with_unit, _format_numeric
assert _format_numeric(0.0) == "0"
assert not np.isfinite(float('inf')) or _format_numeric(float('inf')) is not None
# plain number no unit
result = _format_with_unit(1.5, "")
assert "1.5" in result
# prefixable unit gets SI prefix
result_nm = _format_with_unit(1e-9, "m")
assert "n" in result_nm or "1e" in result_nm
# non-prefixable unit is left as-is
result_bare = _format_with_unit(3.14, "rad")
assert "3.14" in result_bare and "rad" in result_bare
# zero value
result_zero = _format_with_unit(0.0, "m")
assert "0" in result_zero
def test_table_and_array_ops():
from backend.nodes.helpers import (
TABLE_OPS, ARRAY_OPS, extract_numeric_table_values,
resolve_table_column_name, _common_table_unit,
)
from backend.data_types import RecordTable
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
assert TABLE_OPS["min"](values) == 1.0
assert TABLE_OPS["max"](values) == 5.0
assert TABLE_OPS["mean"](values) == 3.0
assert TABLE_OPS["sum"](values) == 15.0
assert TABLE_OPS["range"](values) == 4.0
assert TABLE_OPS["count"](values) == 5.0
assert TABLE_OPS["median"](values) == 3.0
assert TABLE_OPS["std"](values) > 0
assert TABLE_OPS["variance"](values) > 0
assert ARRAY_OPS["rms"](values) > 0
assert ARRAY_OPS["std"](values) > 0
table = RecordTable([
{"quantity": "A", "value": 1.0, "unit": "m"},
{"quantity": "B", "value": 2.0, "unit": "m"},
{"not_a_dict": True},
{"quantity": "C", "value": "not_a_number"},
])
nums = extract_numeric_table_values(table, "value")
assert nums == [1.0, 2.0]
col = resolve_table_column_name(table, "")
assert col == "value"
unit = _common_table_unit(table, "value")
assert unit == "m"
def test_square_unit_and_apply():
from backend.nodes.helpers import _square_unit, _apply_scalar_unit
assert _square_unit("m") == "m^2"
assert _square_unit("m/s") == "(m/s)^2"
assert _square_unit("") == ""
assert _apply_scalar_unit("m", "variance") == "m^2"
assert _apply_scalar_unit("m", "count") == "count"
assert _apply_scalar_unit("m", "mean") == "m"
assert _apply_scalar_unit("", "mean") == ""
def test_nice_length():
from backend.nodes.helpers import _nice_length
assert _nice_length(0.0) == 0.0
assert _nice_length(float('inf')) == 0.0
assert _nice_length(7.3) == 5.0
assert _nice_length(1500.0) == 1000.0
assert _nice_length(0.003) == 0.002

View File

@@ -20,8 +20,9 @@ def test_load_file():
img.save(path) img.save(path)
result = node.load(filename=path) result = node.load(filename=path)
assert len(result) == 1 assert len(result) == 2
field = result[0] assert isinstance(result[0], str)
field = result[1]
assert field.data.shape == (48, 64) assert field.data.shape == (48, 64)
assert field.data.dtype == np.float64 assert field.data.dtype == np.float64
@@ -31,15 +32,15 @@ def test_load_file():
img_rgb.save(path_rgb) img_rgb.save(path_rgb)
result_rgb = node.load(filename=path_rgb) result_rgb = node.load(filename=path_rgb)
assert len(result_rgb) == 1 assert len(result_rgb) == 2
assert result_rgb[0].data.shape == (32, 32) assert result_rgb[1].data.shape == (32, 32)
data_npy = np.random.default_rng(3).standard_normal((50, 60)) data_npy = np.random.default_rng(3).standard_normal((50, 60))
path_npy = os.path.join(tmpdir, "test.npy") path_npy = os.path.join(tmpdir, "test.npy")
np.save(path_npy, data_npy) np.save(path_npy, data_npy)
result_npy = node.load(filename=path_npy) result_npy = node.load(filename=path_npy)
assert np.allclose(result_npy[0].data, data_npy) assert np.allclose(result_npy[1].data, data_npy)
custom_colormap = { custom_colormap = {
"mode": "custom", "mode": "custom",
@@ -50,13 +51,13 @@ def test_load_file():
], ],
} }
result_custom = node.load(filename=path, colormap_map=custom_colormap) result_custom = node.load(filename=path, colormap_map=custom_colormap)
assert isinstance(result_custom[0].colormap, dict) assert isinstance(result_custom[1].colormap, dict)
assert result_custom[0].colormap["mode"] == "custom" assert result_custom[1].colormap["mode"] == "custom"
assert len(result_custom[0].colormap["stops"]) == 3 assert len(result_custom[1].colormap["stops"]) == 3
result_from_path = node.load(filename="", path=path) result_from_path = node.load(filename="", path=path)
assert len(result_from_path) == 1 assert len(result_from_path) == 2
assert result_from_path[0].data.shape == (48, 64) assert result_from_path[1].data.shape == (48, 64)
def test_load_file_npz(): def test_load_file_npz():
@@ -68,8 +69,8 @@ def test_load_file_npz():
np.savez(path, my_array=data) np.savez(path, my_array=data)
result = node.load(filename=path) result = node.load(filename=path)
assert len(result) == 1 assert len(result) == 2
assert np.allclose(result[0].data, data) assert np.allclose(result[1].data, data)
def test_load_file_cache(): def test_load_file_cache():
@@ -83,8 +84,8 @@ def test_load_file_cache():
np.save(path, data) np.save(path, data)
with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader:
first, = node.load(filename=path) _, first = node.load(filename=path)
second, = node.load(filename=path) _, second = node.load(filename=path)
assert loader.call_count == 1 assert loader.call_count == 1
assert np.allclose(first.data, data) assert np.allclose(first.data, data)
@@ -92,7 +93,7 @@ def test_load_file_cache():
assert first is not second assert first is not second
first.data[0, 0] = -999.0 first.data[0, 0] = -999.0
third, = node.load(filename=path) _, third = node.load(filename=path)
assert third.data[0, 0] == data[0, 0] assert third.data[0, 0] == data[0, 0]
Image._load_fields_cached.cache_clear() Image._load_fields_cached.cache_clear()
@@ -136,7 +137,7 @@ def test_load_file_warning():
img.save(path) img.save(path)
result = node.load(filename=path) result = node.load(filename=path)
assert len(result) == 1 assert len(result) == 2
assert len(warnings) == 1 assert len(warnings) == 1
assert "Uncalibrated" in warnings[0] assert "Uncalibrated" in warnings[0]

View File

@@ -9,17 +9,23 @@ import backend.nodes # noqa: F401
def test_load_demo(): def test_load_demo():
from backend.nodes.image_demo import ImageDemo from backend.nodes.image_demo import ImageDemo
from backend.nodes.helpers import DEMO_DIR
node = ImageDemo() node = ImageDemo()
result = node.load(name="nanoparticles.npy") result = node.load(name="nanoparticles.npy")
assert len(result) >= 1 # result[0] is the FILE_PATH string, fields follow
assert isinstance(result[0], DataField) assert len(result) >= 2
assert result[0].data.ndim == 2 assert isinstance(result[0], str)
assert isinstance(result[1], DataField)
assert result[1].data.ndim == 2
result_ibw = node.load(name="whiskers.ibw") ibw_path = DEMO_DIR / "whiskers.ibw"
assert len(result_ibw) == 4 if ibw_path.exists():
for field in result_ibw: result_ibw = node.load(name="whiskers.ibw")
assert isinstance(field, DataField) fields = [v for v in result_ibw if isinstance(v, DataField)]
assert len(fields) == 4
for field in fields:
assert isinstance(field, DataField)
try: try:
node.load(name="nonexistent_file.png") node.load(name="nonexistent_file.png")
@@ -36,21 +42,26 @@ def test_load_demo_cache():
Image._load_fields_cached.cache_clear() Image._load_fields_cached.cache_clear()
with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader:
first, = node.load(name="nanoparticles.npy") _, first = node.load(name="nanoparticles.npy")
second, = node.load(name="nanoparticles.npy") _, second = node.load(name="nanoparticles.npy")
assert loader.call_count == 1 assert loader.call_count == 1
assert np.allclose(first.data, second.data) assert np.allclose(first.data, second.data)
assert first is not second assert first is not second
first.data[0, 0] = -999.0 first.data[0, 0] = -999.0
third, = node.load(name="nanoparticles.npy") _, third = node.load(name="nanoparticles.npy")
assert third.data[0, 0] != -999.0 assert third.data[0, 0] != -999.0
Image._load_fields_cached.cache_clear() Image._load_fields_cached.cache_clear()
def test_load_demo_multi_layer_preview_payload(): def test_load_demo_multi_layer_preview_payload():
from backend.nodes.helpers import DEMO_DIR
ibw_path = DEMO_DIR / "whiskers.ibw"
if not ibw_path.exists():
return
previews = [] previews = []
prompt = { prompt = {
"1": { "1": {

View File

@@ -52,3 +52,87 @@ def test_line_correction():
assert np.allclose(leveled.data + poly_bg.data, poly_field.data) assert np.allclose(leveled.data + poly_bg.data, poly_field.data)
assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995 assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995
assert len(poly_shifts) == rows assert len(poly_shifts) == rows
def test_line_correction_methods():
from backend.nodes.line_correction import LineCorrection
from tests.node_tests._shared import make_field
node = LineCorrection()
rows, cols = 64, 80
rng = np.random.default_rng(7)
signal = rng.standard_normal((rows, cols)) * 0.1
row_offsets = rng.standard_normal(rows) * 2.0
data = signal + row_offsets[:, None]
field = make_field(data=data)
# median_diff
c, b, s = node.process(field, method="median_diff", direction="horizontal",
masking="ignore", trim_fraction=0.05, polynomial_degree=1)
assert np.allclose(c.data + b.data, field.data)
assert len(s) == rows
# trimmed_mean
c, b, s = node.process(field, method="trimmed_mean", direction="horizontal",
masking="ignore", trim_fraction=0.2, polynomial_degree=1)
assert np.allclose(c.data + b.data, field.data)
# trimmed_diff
c, b, s = node.process(field, method="trimmed_diff", direction="horizontal",
masking="ignore", trim_fraction=0.2, polynomial_degree=1)
assert np.allclose(c.data + b.data, field.data)
# step
c, b, s = node.process(field, method="step", direction="horizontal",
masking="ignore", trim_fraction=0.05, polynomial_degree=1)
assert np.allclose(c.data + b.data, field.data)
assert len(s) == rows
def test_line_correction_vertical():
from backend.nodes.line_correction import LineCorrection
from tests.node_tests._shared import make_field
node = LineCorrection()
rows, cols = 48, 64
col_offsets = np.random.default_rng(3).standard_normal(cols) * 1.5
data = np.random.default_rng(3).standard_normal((rows, cols)) * 0.1 + col_offsets[None, :]
field = make_field(data=data)
c, b, s = node.process(field, method="median", direction="vertical",
masking="ignore", trim_fraction=0.05, polynomial_degree=1)
assert c.data.shape == field.data.shape
assert np.allclose(c.data + b.data, field.data)
# vertical shift line length = number of columns
assert len(s) == cols
assert s.x_axis is not None
assert np.isclose(s.x_axis[-1], field.xreal)
def test_line_correction_with_mask():
from backend.nodes.line_correction import LineCorrection
from tests.node_tests._shared import make_field
node = LineCorrection()
rows, cols = 32, 48
data = np.random.default_rng(9).standard_normal((rows, cols)) * 0.1
row_offsets = np.linspace(0, 3.0, rows)
data += row_offsets[:, None]
field = make_field(data=data)
# mask covers right half
mask = np.zeros((rows, cols), dtype=np.uint8)
mask[:, cols // 2:] = 255
c_excl, b_excl, _ = node.process(field, method="median", direction="horizontal",
masking="exclude", trim_fraction=0.05,
polynomial_degree=1, mask=mask)
assert np.allclose(c_excl.data + b_excl.data, field.data)
c_incl, b_incl, _ = node.process(field, method="median", direction="horizontal",
masking="include", trim_fraction=0.05,
polynomial_degree=1, mask=mask)
assert np.allclose(c_incl.data + b_incl.data, field.data)

View File

@@ -23,14 +23,14 @@ def test_threshold_mask():
assert len(previews) == 1 assert len(previews) == 1
assert previews[0].startswith("data:image/png;base64,") assert previews[0].startswith("data:image/png;base64,")
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below") mask_below, _ = node.process(field, method="absolute", threshold=0.5, direction="below")
assert np.all(mask_below[:, :32] == 255) assert np.all(mask_below[:, :32] == 255)
assert np.all(mask_below[:, 32:] == 0) assert np.all(mask_below[:, 32:] == 0)
mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above") mask_rel, _ = node.process(field, method="relative", threshold=0.5, direction="above")
assert np.all(mask_rel[:, 32:] == 255) assert np.all(mask_rel[:, 32:] == 255)
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above") mask_otsu, _ = node.process(field, method="otsu", threshold=0.0, direction="above")
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum() assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
ThresholdMask._broadcast_fn = None ThresholdMask._broadcast_fn = None

View File

@@ -1,11 +1,13 @@
import math
from backend.data_types import RecordTable from backend.data_types import RecordTable
from backend.execution_context import active_node, execution_callbacks
def test_value_display(): def test_value_display():
from backend.nodes.value_io import ValueIO from backend.nodes.value_io import ValueIO
node = ValueIO() node = ValueIO()
value_spec = ValueIO.INPUT_TYPES()["required"]["value"] value_spec = ValueIO.INPUT_TYPES()["optional"]["value"]
assert value_spec[0] == "FLOAT" assert value_spec[0] == "FLOAT"
assert value_spec[1]["accepted_types"] == ["RECORD_TABLE"] assert value_spec[1]["accepted_types"] == ["RECORD_TABLE"]
@@ -13,7 +15,7 @@ def test_value_display():
ValueIO._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) ValueIO._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
ValueIO._current_node_id = "test" ValueIO._current_node_id = "test"
result = node.display_value(3.25) result = node.display_value(value=3.25)
assert result == (3.25,) assert result == (3.25,)
assert captured == [("test", {"value": 3.25})] assert captured == [("test", {"value": 3.25})]
@@ -21,8 +23,42 @@ def test_value_display():
{"quantity": "delta X", "value": 1.7e-7, "unit": "m"}, {"quantity": "delta X", "value": 1.7e-7, "unit": "m"},
{"quantity": "delta Y", "value": 463, "unit": "count"}, {"quantity": "delta Y", "value": 463, "unit": "count"},
]) ])
result = node.display_value(measurements, measurement="delta X") result = node.display_value(value=measurements, measurement="delta X")
assert result == (1.7e-7,) assert result == (1.7e-7,)
assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"}) assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"})
ValueIO._broadcast_value_fn = None ValueIO._broadcast_value_fn = None
def test_value_display_string_input():
from backend.nodes.value_io import ValueIO
node = ValueIO()
values = []
with execution_callbacks(value=lambda nid, v: values.append(v)), active_node("n1"):
# plain number
result = node.display_value(number_input="42")
assert result == (42.0,)
assert values[-1]["value"] == 42.0
values.clear()
with execution_callbacks(value=lambda nid, v: values.append(v)), active_node("n1"):
# negative number
result = node.display_value(number_input="-3.14")
assert math.isclose(result[0], -3.14)
assert math.isclose(values[-1]["value"], -3.14)
def test_value_display_table_emits_table():
from backend.nodes.value_io import ValueIO
node = ValueIO()
tables = []
measurements = RecordTable([
{"quantity": "Rq", "value": 0.42, "unit": "nm"},
])
with execution_callbacks(table=lambda nid, t: tables.append(t)), active_node("n1"):
result = node.display_value(value=measurements, measurement="Rq")
assert result == (0.42,)
assert len(tables) == 1
assert tables[0][0]["quantity"] == "Rq"