diff --git a/.coverage b/.coverage index 26c5629..43e750b 100644 Binary files a/.coverage and b/.coverage differ diff --git a/demo b/demo index 0e24a1e..7621b48 160000 --- a/demo +++ b/demo @@ -1 +1 @@ -Subproject commit 0e24a1eb540283bea7a087bec41b4de411e4d657 +Subproject commit 7621b48a681c41fc54b1e9e5885144fe36cb1177 diff --git a/tests/node_tests/acf_1d.py b/tests/node_tests/acf_1d.py new file mode 100644 index 0000000..63bc8f6 --- /dev/null +++ b/tests/node_tests/acf_1d.py @@ -0,0 +1,72 @@ +import numpy as np +from backend.data_types import LineData, RecordTable + + +def test_acf_1d(): + from backend.nodes.acf_1d import ACF1D + + node = ACF1D() + + # Periodic signal — ACF should show a peak at the period + n = 256 + period = 32 + t = np.arange(n, dtype=np.float64) + signal = np.sin(2 * np.pi * t / period) + profile = LineData( + data=signal, + x_axis=t * 1e-9, + x_unit="m", + y_unit="V", + ) + + acf, measurement = node.process(profile, level="mean") + + assert isinstance(acf, LineData) + assert isinstance(measurement, RecordTable) + + # ACF should be symmetric about zero lag + center = len(acf) // 2 + assert np.allclose(acf.data, acf.data[::-1], atol=1e-10) + + # Peak period should be close to the input period in metres + expected_period_m = period * 1e-9 + assert len(measurement) == 1 + assert measurement[0]["quantity"] == "Peak period" + assert abs(measurement[0]["value"] - expected_period_m) / expected_period_m < 0.1 + assert measurement[0]["unit"] == "m" + + # x_axis should be centred on zero + assert acf.x_axis is not None + assert acf.x_axis[center] == 0.0 or abs(acf.x_axis[center]) < 1e-15 + + # ACF at zero lag should equal variance (signal is mean-subtracted) + assert acf.data[center] > 0 + + +def test_acf_1d_no_peak(): + from backend.nodes.acf_1d import ACF1D + + node = ACF1D() + + # White noise — ACF should have no reliable peak, measurement table may be empty + rng = np.random.default_rng(0) + noise = rng.standard_normal(64) + profile = LineData(data=noise, x_axis=np.arange(64, dtype=np.float64), x_unit="m") + + acf, measurement = node.process(profile, level="none") + assert isinstance(acf, LineData) + # measurement is either empty or has one row — no assertion on content + + +def test_acf_1d_level_none(): + from backend.nodes.acf_1d import ACF1D + + node = ACF1D() + + # With level="none", a DC offset should not be removed + data = np.ones(32, dtype=np.float64) * 5.0 + profile = LineData(data=data, x_axis=np.arange(32, dtype=np.float64)) + + acf, _ = node.process(profile, level="none") + # ACF of a constant is a constant + assert acf.data[len(acf) // 2] > 0 diff --git a/tests/node_tests/cursors.py b/tests/node_tests/cursors.py index 218149c..da0c058 100644 --- a/tests/node_tests/cursors.py +++ b/tests/node_tests/cursors.py @@ -18,8 +18,9 @@ def test_line_cursors(): table, coord_pair = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5) assert isinstance(coord_pair, tuple) and len(coord_pair) == 2 - assert len(table) == 6 + assert len(table) == 7 quantities = {row["quantity"] for row in table} + assert "Length" in quantities assert "A x" in quantities assert "B x" in quantities assert "dx" in quantities @@ -41,7 +42,7 @@ def test_line_cursors(): line_data = LineData(data=line, x_axis=np.linspace(0, 1, 100)) table2, _ = node.process(line_data, x1=0.25, y1=0.5, x2=0.75, y2=0.5) - assert len(table2) == 6 + assert len(table2) == 7 field = DataField( data=np.arange(100, dtype=np.float64).reshape(10, 10), diff --git a/tests/node_tests/execution_preview.py b/tests/node_tests/execution_preview.py new file mode 100644 index 0000000..e2cb2f0 --- /dev/null +++ b/tests/node_tests/execution_preview.py @@ -0,0 +1,136 @@ +"""Tests for ExecutionEngine._auto_preview and _render_line_preview.""" +import backend.nodes # noqa: F401 +import numpy as np +from backend.execution import ExecutionEngine +from backend.node_registry import register_node +from backend.data_types import DataField, LineData + + +def test_auto_preview_data_field(): + """A node that outputs DATA_FIELD should trigger on_preview.""" + engine = ExecutionEngine() + previews = [] + prompt = { + "1": {"class_type": "Number", "inputs": {"value": 1.0}}, + } + # Number outputs FLOAT, not DATA_FIELD — use GaussianFilter which outputs DATA_FIELD + from tests.node_tests._shared import make_field + + @register_node(display_name="Test Preview Field Source") + class TestPreviewFieldSource: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + OUTPUTS = (('DATA_FIELD', 'out'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self): + return (make_field(),) + + engine = ExecutionEngine() + previews = [] + prompt = {"1": {"class_type": "TestPreviewFieldSource", "inputs": {}}} + engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p))) + assert len(previews) == 1 + nid, payload = previews[0] + assert nid == "1" + assert isinstance(payload, str) and payload.startswith("data:image/png;base64,") + + +def test_auto_preview_line(): + """A node that outputs LINE should trigger on_preview with a line_plot dict.""" + @register_node(display_name="Test Preview Line Source") + class TestPreviewLineSource: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + OUTPUTS = (('LINE', 'out'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self): + return (LineData( + data=np.sin(np.linspace(0, 2 * np.pi, 64)), + x_axis=np.linspace(0, 1e-6, 64), + x_unit="m", + ),) + + engine = ExecutionEngine() + previews = [] + prompt = {"1": {"class_type": "TestPreviewLineSource", "inputs": {}}} + engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p))) + assert len(previews) == 1 + _, payload = previews[0] + assert isinstance(payload, dict) + assert payload["kind"] == "line_plot" + assert "line" in payload and "x_axis" in payload + assert payload["x_unit"] == "m" + + +def test_auto_preview_table(): + """A node that outputs RECORD_TABLE should trigger on_table.""" + from backend.data_types import RecordTable + + @register_node(display_name="Test Preview Table Source") + class TestPreviewTableSource: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + OUTPUTS = (('RECORD_TABLE', 'out'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self): + return (RecordTable([{"quantity": "x", "value": 1.0, "unit": "m"}]),) + + engine = ExecutionEngine() + tables = [] + prompt = {"1": {"class_type": "TestPreviewTableSource", "inputs": {}}} + engine.execute(prompt, on_table=lambda nid, t: tables.append((nid, t))) + assert len(tables) == 1 + nid, rows = tables[0] + assert nid == "1" + assert rows[0]["quantity"] == "x" + + +def test_auto_preview_polymorphic_field_output(): + """A polymorphic output (declared LINE, actual DataField) should preview as a field.""" + from tests.node_tests._shared import make_field + + @register_node(display_name="Test Polymorphic Field Out") + class TestPolymorphicFieldOut: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + OUTPUTS = (('LINE', 'out', {"accepted_types": ["DATA_FIELD"]}),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self): + return (make_field(),) + + engine = ExecutionEngine() + previews = [] + prompt = {"1": {"class_type": "TestPolymorphicFieldOut", "inputs": {}}} + engine.execute(prompt, on_preview=lambda nid, p: previews.append((nid, p))) + assert len(previews) == 1 + _, payload = previews[0] + # Should render as field preview (data URI), not line_plot dict + assert isinstance(payload, str) and payload.startswith("data:image/png;base64,") + + +def test_on_node_start_called(): + """on_node_start callback fires before each node executes.""" + @register_node(display_name="Test Start Callback") + class TestStartCallback: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + OUTPUTS = (('FLOAT', 'v'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self): + return (1.0,) + + started = [] + engine = ExecutionEngine() + prompt = {"1": {"class_type": "TestStartCallback", "inputs": {}}} + engine.execute(prompt, on_node_start=lambda nid: started.append(nid)) + assert started == ["1"] diff --git a/tests/node_tests/fft_1d.py b/tests/node_tests/fft_1d.py new file mode 100644 index 0000000..83b37ed --- /dev/null +++ b/tests/node_tests/fft_1d.py @@ -0,0 +1,68 @@ +import numpy as np +from backend.data_types import LineData, RecordTable + + +def test_fft_1d_peak_period(): + from backend.nodes.fft_1d import FFT1D + + node = FFT1D() + + n = 256 + period = 32 # pixels + dx = 1e-9 # 1 nm per pixel + t = np.arange(n, dtype=np.float64) + signal = np.sin(2 * np.pi * t / period) + profile = LineData( + data=signal, + x_axis=t * dx, + x_unit="m", + y_unit="V", + ) + + freq_line, table = node.process(profile) + + assert isinstance(freq_line, LineData) + assert isinstance(table, RecordTable) + assert len(table) == 1 + assert table[0]["quantity"] == "Peak period" + assert table[0]["unit"] == "m" + + # Peak period should be close to 32 nm + expected = period * dx + assert abs(table[0]["value"] - expected) / expected < 0.1 + + # Output axis is in metres (spatial units) + assert freq_line.x_unit == "m" + # Spectrum values are non-negative magnitudes + assert np.all(freq_line.data >= 0) + # Highest spectral value corresponds to peak period + peak_idx = np.argmax(freq_line.data) + assert abs(freq_line.x_axis[peak_idx] - expected) / expected < 0.1 + + +def test_fft_1d_no_x_axis(): + from backend.nodes.fft_1d import FFT1D + + node = FFT1D() + + # Plain numpy array without calibration — should fall back to d=1, unit="m" + signal = np.sin(2 * np.pi * np.arange(64) / 8) + freq_line, table = node.process(signal) + + assert isinstance(freq_line, LineData) + assert len(freq_line.data) > 0 + assert np.all(freq_line.data >= 0) + assert len(table) == 1 + + +def test_fft_1d_output_length(): + from backend.nodes.fft_1d import FFT1D + + node = FFT1D() + + for n in (32, 64, 128): + data = np.random.default_rng(n).standard_normal(n) + profile = LineData(data=data, x_axis=np.arange(n, dtype=np.float64) * 1e-9, x_unit="m") + freq_line, _ = node.process(profile) + # rfft gives n//2+1 bins; DC (index 0) is removed, leaving n//2 points + assert len(freq_line.data) == n // 2 diff --git a/tests/node_tests/helpers.py b/tests/node_tests/helpers.py index db8c929..2cb997b 100644 --- a/tests/node_tests/helpers.py +++ b/tests/node_tests/helpers.py @@ -63,3 +63,129 @@ def test_list_channels(): folder_node = Folder() folder_result = folder_node.list_files(tmpdir) assert folder_result == tuple(entry["path"] for entry in paths) + + +def test_measurement_helpers(): + from backend.nodes.helpers import _measurement_names, _measurement_entry, _measurement_value + from backend.data_types import RecordTable + + table = RecordTable([ + {"quantity": "Rq", "value": 0.5, "unit": "nm"}, + {"quantity": "Ra", "value": 0.3, "unit": "nm"}, + {"quantity": "Rq", "value": 0.5, "unit": "nm"}, # duplicate — deduplicated in names + ]) + + names = _measurement_names(table) + assert names == ["Rq", "Ra"] + + row = _measurement_entry(table, "Ra") + assert row["value"] == 0.3 + + # falls back to first when selection not found + row_fallback = _measurement_entry(table, "nonexistent") + assert row_fallback["quantity"] == "Rq" + + val = _measurement_value(table, "Ra") + assert val == 0.3 + + +def test_measurement_value_errors(): + from backend.nodes.helpers import _measurement_value + from backend.data_types import RecordTable + + empty = RecordTable([]) + try: + _measurement_value(empty, "anything") + assert False, "should raise" + except ValueError: + pass + + bool_table = RecordTable([{"quantity": "flag", "value": True}]) + try: + _measurement_value(bool_table, "flag") + assert False, "should raise" + except ValueError: + pass + + +def test_format_with_unit(): + from backend.nodes.helpers import _format_with_unit, _format_numeric + + assert _format_numeric(0.0) == "0" + assert not np.isfinite(float('inf')) or _format_numeric(float('inf')) is not None + + # plain number no unit + result = _format_with_unit(1.5, "") + assert "1.5" in result + + # prefixable unit gets SI prefix + result_nm = _format_with_unit(1e-9, "m") + assert "n" in result_nm or "1e" in result_nm + + # non-prefixable unit is left as-is + result_bare = _format_with_unit(3.14, "rad") + assert "3.14" in result_bare and "rad" in result_bare + + # zero value + result_zero = _format_with_unit(0.0, "m") + assert "0" in result_zero + + +def test_table_and_array_ops(): + from backend.nodes.helpers import ( + TABLE_OPS, ARRAY_OPS, extract_numeric_table_values, + resolve_table_column_name, _common_table_unit, + ) + from backend.data_types import RecordTable + + values = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + assert TABLE_OPS["min"](values) == 1.0 + assert TABLE_OPS["max"](values) == 5.0 + assert TABLE_OPS["mean"](values) == 3.0 + assert TABLE_OPS["sum"](values) == 15.0 + assert TABLE_OPS["range"](values) == 4.0 + assert TABLE_OPS["count"](values) == 5.0 + assert TABLE_OPS["median"](values) == 3.0 + assert TABLE_OPS["std"](values) > 0 + assert TABLE_OPS["variance"](values) > 0 + + assert ARRAY_OPS["rms"](values) > 0 + assert ARRAY_OPS["std"](values) > 0 + + table = RecordTable([ + {"quantity": "A", "value": 1.0, "unit": "m"}, + {"quantity": "B", "value": 2.0, "unit": "m"}, + {"not_a_dict": True}, + {"quantity": "C", "value": "not_a_number"}, + ]) + nums = extract_numeric_table_values(table, "value") + assert nums == [1.0, 2.0] + + col = resolve_table_column_name(table, "") + assert col == "value" + + unit = _common_table_unit(table, "value") + assert unit == "m" + + +def test_square_unit_and_apply(): + from backend.nodes.helpers import _square_unit, _apply_scalar_unit + + assert _square_unit("m") == "m^2" + assert _square_unit("m/s") == "(m/s)^2" + assert _square_unit("") == "" + + assert _apply_scalar_unit("m", "variance") == "m^2" + assert _apply_scalar_unit("m", "count") == "count" + assert _apply_scalar_unit("m", "mean") == "m" + assert _apply_scalar_unit("", "mean") == "" + + +def test_nice_length(): + from backend.nodes.helpers import _nice_length + + assert _nice_length(0.0) == 0.0 + assert _nice_length(float('inf')) == 0.0 + assert _nice_length(7.3) == 5.0 + assert _nice_length(1500.0) == 1000.0 + assert _nice_length(0.003) == 0.002 diff --git a/tests/node_tests/image.py b/tests/node_tests/image.py index 70b7e12..fd9acb7 100644 --- a/tests/node_tests/image.py +++ b/tests/node_tests/image.py @@ -20,8 +20,9 @@ def test_load_file(): img.save(path) result = node.load(filename=path) - assert len(result) == 1 - field = result[0] + assert len(result) == 2 + assert isinstance(result[0], str) + field = result[1] assert field.data.shape == (48, 64) assert field.data.dtype == np.float64 @@ -31,15 +32,15 @@ def test_load_file(): img_rgb.save(path_rgb) result_rgb = node.load(filename=path_rgb) - assert len(result_rgb) == 1 - assert result_rgb[0].data.shape == (32, 32) + assert len(result_rgb) == 2 + assert result_rgb[1].data.shape == (32, 32) data_npy = np.random.default_rng(3).standard_normal((50, 60)) path_npy = os.path.join(tmpdir, "test.npy") np.save(path_npy, data_npy) result_npy = node.load(filename=path_npy) - assert np.allclose(result_npy[0].data, data_npy) + assert np.allclose(result_npy[1].data, data_npy) custom_colormap = { "mode": "custom", @@ -50,13 +51,13 @@ def test_load_file(): ], } result_custom = node.load(filename=path, colormap_map=custom_colormap) - assert isinstance(result_custom[0].colormap, dict) - assert result_custom[0].colormap["mode"] == "custom" - assert len(result_custom[0].colormap["stops"]) == 3 + assert isinstance(result_custom[1].colormap, dict) + assert result_custom[1].colormap["mode"] == "custom" + assert len(result_custom[1].colormap["stops"]) == 3 result_from_path = node.load(filename="", path=path) - assert len(result_from_path) == 1 - assert result_from_path[0].data.shape == (48, 64) + assert len(result_from_path) == 2 + assert result_from_path[1].data.shape == (48, 64) def test_load_file_npz(): @@ -68,8 +69,8 @@ def test_load_file_npz(): np.savez(path, my_array=data) result = node.load(filename=path) - assert len(result) == 1 - assert np.allclose(result[0].data, data) + assert len(result) == 2 + assert np.allclose(result[1].data, data) def test_load_file_cache(): @@ -83,8 +84,8 @@ def test_load_file_cache(): np.save(path, data) with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: - first, = node.load(filename=path) - second, = node.load(filename=path) + _, first = node.load(filename=path) + _, second = node.load(filename=path) assert loader.call_count == 1 assert np.allclose(first.data, data) @@ -92,7 +93,7 @@ def test_load_file_cache(): assert first is not second first.data[0, 0] = -999.0 - third, = node.load(filename=path) + _, third = node.load(filename=path) assert third.data[0, 0] == data[0, 0] Image._load_fields_cached.cache_clear() @@ -136,7 +137,7 @@ def test_load_file_warning(): img.save(path) result = node.load(filename=path) - assert len(result) == 1 + assert len(result) == 2 assert len(warnings) == 1 assert "Uncalibrated" in warnings[0] diff --git a/tests/node_tests/image_demo.py b/tests/node_tests/image_demo.py index 079a853..60a654f 100644 --- a/tests/node_tests/image_demo.py +++ b/tests/node_tests/image_demo.py @@ -9,17 +9,23 @@ import backend.nodes # noqa: F401 def test_load_demo(): from backend.nodes.image_demo import ImageDemo + from backend.nodes.helpers import DEMO_DIR node = ImageDemo() result = node.load(name="nanoparticles.npy") - assert len(result) >= 1 - assert isinstance(result[0], DataField) - assert result[0].data.ndim == 2 + # result[0] is the FILE_PATH string, fields follow + assert len(result) >= 2 + assert isinstance(result[0], str) + assert isinstance(result[1], DataField) + assert result[1].data.ndim == 2 - result_ibw = node.load(name="whiskers.ibw") - assert len(result_ibw) == 4 - for field in result_ibw: - assert isinstance(field, DataField) + ibw_path = DEMO_DIR / "whiskers.ibw" + if ibw_path.exists(): + result_ibw = node.load(name="whiskers.ibw") + fields = [v for v in result_ibw if isinstance(v, DataField)] + assert len(fields) == 4 + for field in fields: + assert isinstance(field, DataField) try: node.load(name="nonexistent_file.png") @@ -36,21 +42,26 @@ def test_load_demo_cache(): Image._load_fields_cached.cache_clear() with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: - first, = node.load(name="nanoparticles.npy") - second, = node.load(name="nanoparticles.npy") + _, first = node.load(name="nanoparticles.npy") + _, second = node.load(name="nanoparticles.npy") assert loader.call_count == 1 assert np.allclose(first.data, second.data) assert first is not second first.data[0, 0] = -999.0 - third, = node.load(name="nanoparticles.npy") + _, third = node.load(name="nanoparticles.npy") assert third.data[0, 0] != -999.0 Image._load_fields_cached.cache_clear() def test_load_demo_multi_layer_preview_payload(): + from backend.nodes.helpers import DEMO_DIR + ibw_path = DEMO_DIR / "whiskers.ibw" + if not ibw_path.exists(): + return + previews = [] prompt = { "1": { diff --git a/tests/node_tests/line_correction.py b/tests/node_tests/line_correction.py index 6b2e791..1a2ee51 100644 --- a/tests/node_tests/line_correction.py +++ b/tests/node_tests/line_correction.py @@ -52,3 +52,87 @@ def test_line_correction(): assert np.allclose(leveled.data + poly_bg.data, poly_field.data) assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995 assert len(poly_shifts) == rows + + +def test_line_correction_methods(): + from backend.nodes.line_correction import LineCorrection + from tests.node_tests._shared import make_field + + node = LineCorrection() + + rows, cols = 64, 80 + rng = np.random.default_rng(7) + signal = rng.standard_normal((rows, cols)) * 0.1 + row_offsets = rng.standard_normal(rows) * 2.0 + data = signal + row_offsets[:, None] + field = make_field(data=data) + + # median_diff + c, b, s = node.process(field, method="median_diff", direction="horizontal", + masking="ignore", trim_fraction=0.05, polynomial_degree=1) + assert np.allclose(c.data + b.data, field.data) + assert len(s) == rows + + # trimmed_mean + c, b, s = node.process(field, method="trimmed_mean", direction="horizontal", + masking="ignore", trim_fraction=0.2, polynomial_degree=1) + assert np.allclose(c.data + b.data, field.data) + + # trimmed_diff + c, b, s = node.process(field, method="trimmed_diff", direction="horizontal", + masking="ignore", trim_fraction=0.2, polynomial_degree=1) + assert np.allclose(c.data + b.data, field.data) + + # step + c, b, s = node.process(field, method="step", direction="horizontal", + masking="ignore", trim_fraction=0.05, polynomial_degree=1) + assert np.allclose(c.data + b.data, field.data) + assert len(s) == rows + + +def test_line_correction_vertical(): + from backend.nodes.line_correction import LineCorrection + from tests.node_tests._shared import make_field + + node = LineCorrection() + + rows, cols = 48, 64 + col_offsets = np.random.default_rng(3).standard_normal(cols) * 1.5 + data = np.random.default_rng(3).standard_normal((rows, cols)) * 0.1 + col_offsets[None, :] + field = make_field(data=data) + + c, b, s = node.process(field, method="median", direction="vertical", + masking="ignore", trim_fraction=0.05, polynomial_degree=1) + assert c.data.shape == field.data.shape + assert np.allclose(c.data + b.data, field.data) + # vertical shift line length = number of columns + assert len(s) == cols + assert s.x_axis is not None + assert np.isclose(s.x_axis[-1], field.xreal) + + +def test_line_correction_with_mask(): + from backend.nodes.line_correction import LineCorrection + from tests.node_tests._shared import make_field + + node = LineCorrection() + + rows, cols = 32, 48 + data = np.random.default_rng(9).standard_normal((rows, cols)) * 0.1 + row_offsets = np.linspace(0, 3.0, rows) + data += row_offsets[:, None] + field = make_field(data=data) + + # mask covers right half + mask = np.zeros((rows, cols), dtype=np.uint8) + mask[:, cols // 2:] = 255 + + c_excl, b_excl, _ = node.process(field, method="median", direction="horizontal", + masking="exclude", trim_fraction=0.05, + polynomial_degree=1, mask=mask) + assert np.allclose(c_excl.data + b_excl.data, field.data) + + c_incl, b_incl, _ = node.process(field, method="median", direction="horizontal", + masking="include", trim_fraction=0.05, + polynomial_degree=1, mask=mask) + assert np.allclose(c_incl.data + b_incl.data, field.data) diff --git a/tests/node_tests/mask_threshold.py b/tests/node_tests/mask_threshold.py index b319d3c..7c10002 100644 --- a/tests/node_tests/mask_threshold.py +++ b/tests/node_tests/mask_threshold.py @@ -23,14 +23,14 @@ def test_threshold_mask(): assert len(previews) == 1 assert previews[0].startswith("data:image/png;base64,") - mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below") + mask_below, _ = node.process(field, method="absolute", threshold=0.5, direction="below") assert np.all(mask_below[:, :32] == 255) assert np.all(mask_below[:, 32:] == 0) - mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above") + mask_rel, _ = node.process(field, method="relative", threshold=0.5, direction="above") assert np.all(mask_rel[:, 32:] == 255) - mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above") + mask_otsu, _ = node.process(field, method="otsu", threshold=0.0, direction="above") assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum() ThresholdMask._broadcast_fn = None diff --git a/tests/node_tests/value_io.py b/tests/node_tests/value_io.py index e7faefa..f94bde8 100644 --- a/tests/node_tests/value_io.py +++ b/tests/node_tests/value_io.py @@ -1,11 +1,13 @@ +import math from backend.data_types import RecordTable +from backend.execution_context import active_node, execution_callbacks def test_value_display(): from backend.nodes.value_io import ValueIO node = ValueIO() - value_spec = ValueIO.INPUT_TYPES()["required"]["value"] + value_spec = ValueIO.INPUT_TYPES()["optional"]["value"] assert value_spec[0] == "FLOAT" assert value_spec[1]["accepted_types"] == ["RECORD_TABLE"] @@ -13,7 +15,7 @@ def test_value_display(): ValueIO._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) ValueIO._current_node_id = "test" - result = node.display_value(3.25) + result = node.display_value(value=3.25) assert result == (3.25,) assert captured == [("test", {"value": 3.25})] @@ -21,8 +23,42 @@ def test_value_display(): {"quantity": "delta X", "value": 1.7e-7, "unit": "m"}, {"quantity": "delta Y", "value": 463, "unit": "count"}, ]) - result = node.display_value(measurements, measurement="delta X") + result = node.display_value(value=measurements, measurement="delta X") assert result == (1.7e-7,) assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"}) ValueIO._broadcast_value_fn = None + + +def test_value_display_string_input(): + from backend.nodes.value_io import ValueIO + + node = ValueIO() + values = [] + with execution_callbacks(value=lambda nid, v: values.append(v)), active_node("n1"): + # plain number + result = node.display_value(number_input="42") + assert result == (42.0,) + assert values[-1]["value"] == 42.0 + + values.clear() + with execution_callbacks(value=lambda nid, v: values.append(v)), active_node("n1"): + # negative number + result = node.display_value(number_input="-3.14") + assert math.isclose(result[0], -3.14) + assert math.isclose(values[-1]["value"], -3.14) + + +def test_value_display_table_emits_table(): + from backend.nodes.value_io import ValueIO + + node = ValueIO() + tables = [] + measurements = RecordTable([ + {"quantity": "Rq", "value": 0.42, "unit": "nm"}, + ]) + with execution_callbacks(table=lambda nid, t: tables.append(t)), active_node("n1"): + result = node.display_value(value=measurements, measurement="Rq") + assert result == (0.42,) + assert len(tables) == 1 + assert tables[0][0]["quantity"] == "Rq"