From c3bb34d248b11e69584cbece13657bfd2361be65 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Sun, 29 Mar 2026 16:39:37 -0700 Subject: [PATCH] split node tests into standalone files --- tests/node_tests/__init__.py | 0 tests/node_tests/_shared.py | 9 + tests/node_tests/test_acf_2d.py | 39 ++++ tests/node_tests/test_angle_measure.py | 60 ++++++ tests/node_tests/test_annotations.py | 81 ++++++++ tests/node_tests/test_colormap.py | 25 +++ tests/node_tests/test_colormap_adjust.py | 35 ++++ tests/node_tests/test_coordinate.py | 14 ++ tests/node_tests/test_crop_resize.py | 60 ++++++ tests/node_tests/test_cross_section.py | 52 ++++++ tests/node_tests/test_cursors.py | 63 +++++++ tests/node_tests/test_curvature.py | 176 ++++++++++++++++++ tests/node_tests/test_edge_detect.py | 18 ++ tests/node_tests/test_execution.py | 138 ++++++++++++++ tests/node_tests/test_fft_2d.py | 46 +++++ tests/node_tests/test_filter_fft_1d.py | 33 ++++ tests/node_tests/test_filter_fft_2d.py | 31 +++ tests/node_tests/test_filter_gaussian.py | 16 ++ tests/node_tests/test_filter_median.py | 19 ++ tests/node_tests/test_fix_zero.py | 18 ++ tests/node_tests/test_flip.py | 52 ++++++ tests/node_tests/test_font.py | 14 ++ tests/node_tests/test_fractal_dimension.py | 60 ++++++ tests/node_tests/test_grain_analysis.py | 35 ++++ .../test_grain_distance_transform.py | 34 ++++ tests/node_tests/test_helpers.py | 65 +++++++ tests/node_tests/test_histogram.py | 40 ++++ tests/node_tests/test_image.py | 167 +++++++++++++++++ tests/node_tests/test_image_demo.py | 70 +++++++ tests/node_tests/test_level_facet.py | 68 +++++++ tests/node_tests/test_level_plane.py | 40 ++++ tests/node_tests/test_level_poly.py | 24 +++ tests/node_tests/test_line_correction.py | 54 ++++++ tests/node_tests/test_markup.py | 60 ++++++ tests/node_tests/test_mask_draw.py | 40 ++++ tests/node_tests/test_mask_invert.py | 17 ++ tests/node_tests/test_mask_morphology.py | 29 +++ tests/node_tests/test_mask_operations.py | 44 +++++ tests/node_tests/test_mask_threshold.py | 36 ++++ tests/node_tests/test_number.py | 10 + tests/node_tests/test_preview_image.py | 58 ++++++ tests/node_tests/test_print_table.py | 18 ++ tests/node_tests/test_psdf.py | 35 ++++ tests/node_tests/test_range_slider.py | 16 ++ tests/node_tests/test_rotate.py | 60 ++++++ tests/node_tests/test_save.py | 138 ++++++++++++++ tests/node_tests/test_save_layers.py | 77 ++++++++ tests/node_tests/test_scar_removal.py | 48 +++++ tests/node_tests/test_statistics.py | 28 +++ tests/node_tests/test_stats.py | 68 +++++++ tests/node_tests/test_value_io.py | 28 +++ tests/node_tests/test_view_3d.py | 106 +++++++++++ .../node_tests/test_watershed_segmentation.py | 53 ++++++ 53 files changed, 2625 insertions(+) create mode 100644 tests/node_tests/__init__.py create mode 100644 tests/node_tests/_shared.py create mode 100644 tests/node_tests/test_acf_2d.py create mode 100644 tests/node_tests/test_angle_measure.py create mode 100644 tests/node_tests/test_annotations.py create mode 100644 tests/node_tests/test_colormap.py create mode 100644 tests/node_tests/test_colormap_adjust.py create mode 100644 tests/node_tests/test_coordinate.py create mode 100644 tests/node_tests/test_crop_resize.py create mode 100644 tests/node_tests/test_cross_section.py create mode 100644 tests/node_tests/test_cursors.py create mode 100644 tests/node_tests/test_curvature.py create mode 100644 tests/node_tests/test_edge_detect.py create mode 100644 tests/node_tests/test_execution.py create mode 100644 tests/node_tests/test_fft_2d.py create mode 100644 tests/node_tests/test_filter_fft_1d.py create mode 100644 tests/node_tests/test_filter_fft_2d.py create mode 100644 tests/node_tests/test_filter_gaussian.py create mode 100644 tests/node_tests/test_filter_median.py create mode 100644 tests/node_tests/test_fix_zero.py create mode 100644 tests/node_tests/test_flip.py create mode 100644 tests/node_tests/test_font.py create mode 100644 tests/node_tests/test_fractal_dimension.py create mode 100644 tests/node_tests/test_grain_analysis.py create mode 100644 tests/node_tests/test_grain_distance_transform.py create mode 100644 tests/node_tests/test_helpers.py create mode 100644 tests/node_tests/test_histogram.py create mode 100644 tests/node_tests/test_image.py create mode 100644 tests/node_tests/test_image_demo.py create mode 100644 tests/node_tests/test_level_facet.py create mode 100644 tests/node_tests/test_level_plane.py create mode 100644 tests/node_tests/test_level_poly.py create mode 100644 tests/node_tests/test_line_correction.py create mode 100644 tests/node_tests/test_markup.py create mode 100644 tests/node_tests/test_mask_draw.py create mode 100644 tests/node_tests/test_mask_invert.py create mode 100644 tests/node_tests/test_mask_morphology.py create mode 100644 tests/node_tests/test_mask_operations.py create mode 100644 tests/node_tests/test_mask_threshold.py create mode 100644 tests/node_tests/test_number.py create mode 100644 tests/node_tests/test_preview_image.py create mode 100644 tests/node_tests/test_print_table.py create mode 100644 tests/node_tests/test_psdf.py create mode 100644 tests/node_tests/test_range_slider.py create mode 100644 tests/node_tests/test_rotate.py create mode 100644 tests/node_tests/test_save.py create mode 100644 tests/node_tests/test_save_layers.py create mode 100644 tests/node_tests/test_scar_removal.py create mode 100644 tests/node_tests/test_statistics.py create mode 100644 tests/node_tests/test_stats.py create mode 100644 tests/node_tests/test_value_io.py create mode 100644 tests/node_tests/test_view_3d.py create mode 100644 tests/node_tests/test_watershed_segmentation.py diff --git a/tests/node_tests/__init__.py b/tests/node_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/node_tests/_shared.py b/tests/node_tests/_shared.py new file mode 100644 index 0000000..206aeb8 --- /dev/null +++ b/tests/node_tests/_shared.py @@ -0,0 +1,9 @@ +import numpy as np +from backend.data_types import DataField + + +def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): + """Create a DataField, optionally from given data or a random field.""" + if data is None: + data = np.random.default_rng(42).standard_normal(shape) + return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m") diff --git a/tests/node_tests/test_acf_2d.py b/tests/node_tests/test_acf_2d.py new file mode 100644 index 0000000..8217866 --- /dev/null +++ b/tests/node_tests/test_acf_2d.py @@ -0,0 +1,39 @@ +import numpy as np +from backend.data_types import DataField + + +def test_acf(): + from backend.nodes.acf_2d import ACF2D + + node = ACF2D() + data = np.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [2.0, 1.0, 0.0, -1.0], + [0.0, 1.0, 2.0, 3.0], + ], dtype=np.float64) + field = DataField(data=data, xreal=8.0, yreal=4.0, si_unit_xy="nm", si_unit_z="V") + + acf, = node.process(field, level="none") + assert acf.data.shape == (3, 3) + assert acf.domain == "spatial" + assert acf.si_unit_xy == "nm" + assert acf.si_unit_z == "V^2" + assert np.isclose(acf.xreal, 6.0) + assert np.isclose(acf.yreal, 3.0) + assert np.isclose(acf.xoff, -3.0) + assert np.isclose(acf.yoff, -1.5) + + expected = np.zeros((3, 3), dtype=np.float64) + for iy, dy in enumerate(range(-1, 2)): + for ix, dx in enumerate(range(-1, 2)): + y0a = max(0, dy) + y1a = min(data.shape[0], data.shape[0] + dy) + x0a = max(0, dx) + x1a = min(data.shape[1], data.shape[1] + dx) + lhs = data[y0a:y1a, x0a:x1a] + rhs = data[y0a - dy:y1a - dy, x0a - dx:x1a - dx] + expected[iy, ix] = float(np.mean(lhs * rhs)) + + assert np.allclose(acf.data, expected) + assert np.allclose(acf.data, acf.data[::-1, ::-1]) diff --git a/tests/node_tests/test_angle_measure.py b/tests/node_tests/test_angle_measure.py new file mode 100644 index 0000000..576c995 --- /dev/null +++ b/tests/node_tests/test_angle_measure.py @@ -0,0 +1,60 @@ +import numpy as np +from backend.data_types import DataField +from backend.node_registry import get_node_info +from tests.node_tests._shared import make_field + + +def test_angle_measure(): + from backend.nodes.angle_measure import AngleMeasure + from backend.data_types import ImageData + + node = AngleMeasure() + info = get_node_info("AngleMeasure") + assert info["category"] == "Overlay" + assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"} + required_inputs = AngleMeasure.INPUT_TYPES()["required"] + optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {}) + assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] + assert required_inputs["color"][1]["default"] == "#ff9800" + assert required_inputs["stroke_width"][1]["default"] == 1.35 + assert optional_inputs["line_thickness"][1]["hidden"] is True + assert optional_inputs["line_thickness_input"][1]["hidden"] is True + + field = make_field(data=np.zeros((32, 64), dtype=np.float64), xreal=4.0, yreal=2.0) + output, table = node.process( + field, color="#c62828", stroke_width=1.8, + x1=0.2, y1=0.5, xm=0.5, ym=0.5, x2=0.5, y2=0.2, label_dx=0.0, label_dy=0.0, + ) + rows = {row["quantity"]: row for row in table} + assert isinstance(output, DataField) + assert output is not field + assert len(output.overlays) == len(field.overlays) + 1 + assert output.overlays[-1]["kind"] == "angle_measure" + assert output.overlays[-1]["color"] == "#c62828" + assert np.isclose(output.overlays[-1]["stroke_width"], 1.8) + assert np.isclose(rows["Arm A length"]["value"], 1.2) + assert np.isclose(rows["Arm B length"]["value"], 0.6) + assert np.isclose(rows["Angle"]["value"], 90.0) + assert rows["Angle"]["unit"] == "deg" + assert rows["Vertex x"]["unit"] == field.si_unit_xy + + sanitized_output, _ = node.process( + field, color="not-a-color", stroke_width=-0.7, + x1=0.2, y1=0.5, xm=0.5, ym=0.5, x2=0.5, y2=0.2, label_dx=0.0, label_dy=0.0, + ) + assert sanitized_output.overlays[-1]["color"] == "#ff9800" + assert np.isclose(sanitized_output.overlays[-1]["stroke_width"], 0.35) + + image = np.zeros((50, 100, 3), dtype=np.uint8) + image_output, image_table = node.process( + image, color="#ff9800", stroke_width=1.25, + x1=0.25, y1=0.5, xm=0.5, ym=0.5, x2=0.5, y2=0.25, label_dx=0.0, label_dy=0.0, + ) + image_rows = {row["quantity"]: row for row in image_table} + assert isinstance(image_output, ImageData) + assert image_output.shape == image.shape + assert np.count_nonzero(np.asarray(image_output)) > 0 + assert np.isclose(image_rows["Arm A length"]["value"], 24.75) + assert np.isclose(image_rows["Arm B length"]["value"], 12.25) + assert np.isclose(image_rows["Angle"]["value"], 90.0) + assert image_rows["Arm A length"]["unit"] == "px" diff --git a/tests/node_tests/test_annotations.py b/tests/node_tests/test_annotations.py new file mode 100644 index 0000000..51bb0b6 --- /dev/null +++ b/tests/node_tests/test_annotations.py @@ -0,0 +1,81 @@ +import numpy as np +from backend.data_types import DataField, ImageData, datafield_to_uint8, render_datafield_preview +from backend.execution_context import active_node, execution_callbacks + + +def test_annotations(): + from backend.nodes.annotations import Annotations + from backend.nodes.font import Font + + node = Annotations() + font_node = Font() + annotation_input = Annotations.INPUT_TYPES()["required"]["input"] + assert annotation_input[0] == "ANNOTATION_SOURCE" + assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] + + warnings = [] + field = DataField( + data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64), + xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="V", colormap="viridis", + ) + + base = datafield_to_uint8(field, "viridis") + plain_preview = render_datafield_preview(field, "viridis") + assert np.array_equal(plain_preview, base) + + with execution_callbacks(warning=lambda nid, msg: warnings.append(msg)), active_node("test"): + plain_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=False) + assert isinstance(plain_field, DataField) + assert np.array_equal(plain_field.data, field.data) + assert plain_field.colormap == "viridis" + assert plain_field.overlays[-1]["kind"] == "annotation" + plain = render_datafield_preview(plain_field, plain_field.colormap) + assert plain.shape == base.shape + assert np.array_equal(plain, base) + + with_scale_field, = node.render(input=field, colormap="auto", show_scale_bar=True, show_color_map=False) + with_scale = render_datafield_preview(with_scale_field, with_scale_field.colormap) + assert with_scale.shape == base.shape + assert not np.array_equal(with_scale, base) + + with_legend_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=True) + with_legend = render_datafield_preview(with_legend_field, with_legend_field.colormap) + assert with_legend.shape[0] == base.shape[0] + assert with_legend.shape[1] > base.shape[1] + assert with_legend.shape[2] == 3 + + larger_legend_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=True, text_size=28.0) + larger_legend_text = render_datafield_preview(larger_legend_field, larger_legend_field.colormap) + assert larger_legend_text.shape[0] == with_legend.shape[0] + assert larger_legend_text.shape[1] > with_legend.shape[1] + assert not np.array_equal(larger_legend_text, with_legend) + + annotation_font, = font_node.build("Arial") + with_font_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=True, text_size=28.0, font=annotation_font) + assert with_font_field.overlays[-1]["font"] == {"family": "Arial", "path": ""} + with_font = render_datafield_preview(with_font_field, with_font_field.colormap) + assert with_font.shape[1] > with_legend.shape[1] + + with_both_field, = node.render(input=field, colormap="auto", show_scale_bar=True, show_color_map=True) + with_both = render_datafield_preview(with_both_field, with_both_field.colormap) + assert with_both.shape == with_legend.shape + assert not np.array_equal(with_both[:, :base.shape[1]], base) + + viewport_image = ImageData( + np.zeros((48, 64, 3), dtype=np.uint8), + metadata={"annotation_context": {"xreal": 2e-6, "si_unit_xy": "m", "legend_min": -1.5, "legend_mid": 0.0, "legend_max": 1.5, "legend_unit": "V", "colormap": "viridis"}}, + ) + annotated_image, = node.render(input=viewport_image, colormap="auto", show_scale_bar=True, show_color_map=True, text_size=18.0) + assert isinstance(annotated_image, ImageData) + assert annotated_image.shape[0] == viewport_image.shape[0] + assert annotated_image.shape[1] > viewport_image.shape[1] + assert annotated_image.metadata["annotation_context"]["legend_unit"] == "V" + assert warnings == [] + + plain_image = ImageData(np.zeros((32, 40, 3), dtype=np.uint8)) + passthrough_image, = node.render(input=plain_image, colormap="auto", show_scale_bar=True, show_color_map=True, text_size=18.0) + assert isinstance(passthrough_image, ImageData) + assert passthrough_image.shape == plain_image.shape + assert np.array_equal(np.asarray(passthrough_image), np.asarray(plain_image)) + assert len(warnings) == 1 + assert "no scale metadata" in warnings[0] diff --git a/tests/node_tests/test_colormap.py b/tests/node_tests/test_colormap.py new file mode 100644 index 0000000..c9c0612 --- /dev/null +++ b/tests/node_tests/test_colormap.py @@ -0,0 +1,25 @@ +import json + + +def test_color_map_node(): + from backend.nodes.colormap import ColorMap + + node = ColorMap() + + preset, = node.build(mode="preset", preset="magma", stops_json="[]") + assert preset["mode"] == "preset" + assert preset["preset"] == "magma" + + custom, = node.build( + mode="custom", + preset="viridis", + stops_json=json.dumps([ + {"position": 0.0, "color": "#000000"}, + {"position": 0.4, "color": "#00ff00"}, + {"position": 1.0, "color": "#ffffff"}, + ]), + ) + assert custom["mode"] == "custom" + assert custom["stops"][0]["position"] == 0.0 + assert custom["stops"][-1]["position"] == 1.0 + assert len(custom["stops"]) == 3 diff --git a/tests/node_tests/test_colormap_adjust.py b/tests/node_tests/test_colormap_adjust.py new file mode 100644 index 0000000..0345a9f --- /dev/null +++ b/tests/node_tests/test_colormap_adjust.py @@ -0,0 +1,35 @@ +import numpy as np +from backend.data_types import DataField, datafield_to_uint8 + + +def test_colormap_adjust(): + from backend.nodes.colormap_adjust import ColormapAdjust + + node = ColormapAdjust() + field = DataField(data=np.array([[0.0, 0.25, 0.5, 0.75, 1.0]], dtype=np.float64), xreal=5.0, yreal=1.0, colormap="gray") + + adjusted, = node.process(field, offset=0.25, scale=0.5) + assert np.array_equal(adjusted.data, field.data) + assert adjusted.display_offset == 0.25 + assert adjusted.display_scale == 0.5 + assert adjusted.colormap == field.colormap + + rgb = datafield_to_uint8(adjusted, "gray") + intensities = rgb[0, :, 0] + assert intensities[0] == 0 + assert intensities[1] == 0 + assert 110 <= intensities[2] <= 145 + assert intensities[3] == 255 + assert intensities[4] == 255 + + auto_like, = node.process(field, offset=0.0, scale=1.0) + auto_rgb = datafield_to_uint8(auto_like, "gray") + auto_intensities = auto_rgb[0, :, 0] + assert auto_intensities[0] == 0 + assert auto_intensities[-1] == 255 + + try: + node.process(field, offset=0.0, scale=0.0) + raise AssertionError("Expected non-positive scale to raise ValueError") + except ValueError: + pass diff --git a/tests/node_tests/test_coordinate.py b/tests/node_tests/test_coordinate.py new file mode 100644 index 0000000..3aa021a --- /dev/null +++ b/tests/node_tests/test_coordinate.py @@ -0,0 +1,14 @@ +def test_coordinate(): + from backend.nodes.coordinate import Coordinate + + node = Coordinate() + + result = node.process(x=0.3, y=0.7) + assert len(result) == 1 + assert result[0] == (0.3, 0.7) + + result_zero = node.process(x=0.0, y=0.0) + assert result_zero[0] == (0.0, 0.0) + + result_one = node.process(x=1.0, y=1.0) + assert result_one[0] == (1.0, 1.0) diff --git a/tests/node_tests/test_crop_resize.py b/tests/node_tests/test_crop_resize.py new file mode 100644 index 0000000..b9a9580 --- /dev/null +++ b/tests/node_tests/test_crop_resize.py @@ -0,0 +1,60 @@ +import numpy as np +from backend.data_types import DataField + + +def test_crop_resize_field(): + from backend.nodes.crop_resize import CropResizeField + node = CropResizeField() + + data = np.arange(32, dtype=np.float64).reshape(4, 8) + field = DataField( + data=data, + xreal=8.0, + yreal=4.0, + xoff=10.0, + yoff=20.0, + si_unit_xy="nm", + si_unit_z="nm", + overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}], + ) + + overlays = [] + CropResizeField._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + CropResizeField._current_node_id = "test" + + cropped, = node.process(field, x1=0.25, y1=0.25, x2=0.75, y2=1.0, target_width=0, target_height=0, interpolation="bilinear") + assert cropped.data.shape == (3, 4) + assert np.array_equal(cropped.data, data[1:4, 2:6]) + assert cropped.xreal == 4.0 + assert cropped.yreal == 3.0 + assert cropped.xoff == 12.0 + assert cropped.yoff == 21.0 + assert cropped.si_unit_xy == field.si_unit_xy + assert cropped.si_unit_z == field.si_unit_z + assert cropped.overlays == [] + assert len(overlays) == 1 + assert overlays[0]["kind"] == "crop_box" + assert overlays[0]["image"].startswith("data:image/png;base64,") + assert overlays[0]["a_locked"] is False + assert overlays[0]["b_locked"] is False + + resized, = node.process(field, x1=0.0, y1=0.0, x2=1.0, y2=1.0, target_width=8, target_height=0, interpolation="bilinear", corner_a=(0.25, 0.25), corner_b=(0.75, 1.0)) + assert resized.data.shape == (6, 8) + assert resized.xreal == cropped.xreal + assert resized.yreal == cropped.yreal + assert resized.xoff == cropped.xoff + assert resized.yoff == cropped.yoff + assert resized.domain == field.domain + assert overlays[-1]["a_locked"] is True + assert overlays[-1]["b_locked"] is True + + reversed_crop, = node.process(field, x1=0.75, y1=1.0, x2=0.25, y2=0.25, target_width=0, target_height=0, interpolation="nearest") + assert np.array_equal(reversed_crop.data, cropped.data) + + try: + node.process(field, x1=0.9, y1=0.0, x2=0.9, y2=1.0, target_width=0, target_height=0, interpolation="nearest") + raise AssertionError("Expected invalid crop bounds to raise ValueError") + except ValueError: + pass + + CropResizeField._broadcast_overlay_fn = None diff --git a/tests/node_tests/test_cross_section.py b/tests/node_tests/test_cross_section.py new file mode 100644 index 0000000..a7bb7f6 --- /dev/null +++ b/tests/node_tests/test_cross_section.py @@ -0,0 +1,52 @@ +import numpy as np +from backend.data_types import LineData +from tests.node_tests._shared import make_field + + +def test_cross_section(): + from backend.nodes.cross_section import CrossSection + node = CrossSection() + + N = 100 + y, x = np.mgrid[0:N, 0:N] / N + data = x * 10.0 + field = make_field(data=data, xreal=1e-6, yreal=1e-6) + + profile, marker_pair = node.process(field, x1=0.0, y1=0.5, x2=1.0, y2=0.5, extend="none", n_samples=100) + assert isinstance(marker_pair, tuple) and len(marker_pair) == 2 + assert isinstance(profile, LineData) + assert len(profile) == 100 + assert profile.x_unit == field.si_unit_xy + assert profile.y_unit == field.si_unit_z + assert np.isclose(profile.x_axis[0], 0.0) + assert np.isclose(profile.x_axis[-1], field.xreal) + assert profile[0] < 0.5 + assert profile[-1] > 9.5 + + profile_auto, _ = node.process(field, x1=0.0, y1=0.5, x2=1.0, y2=0.5, extend="none", n_samples=0) + assert len(profile_auto) >= 2 + + profile_ext, _ = node.process(field, x1=0.3, y1=0.5, x2=0.7, y2=0.5, extend="to_edges", n_samples=100) + assert profile_ext[0] < 0.5 + assert profile_ext[-1] > 9.5 + + profile_diag, _ = node.process(field, x1=0.0, y1=0.0, x2=1.0, y2=1.0, extend="none", n_samples=50) + assert len(profile_diag) == 50 + + from backend.nodes.cursors import Cursors + from backend.nodes.stats import Stats + + cursors = Cursors() + table, _ = cursors.process(profile, x1=0.25, y1=0.5, x2=0.75, y2=0.5) + rows = {row["quantity"]: row for row in table} + assert rows["dx"]["unit"] == field.si_unit_xy + assert rows["dy"]["unit"] == field.si_unit_z + + captured = [] + Stats._broadcast_value_fn = lambda nid, payload: captured.append(payload) + Stats._current_node_id = "test" + stats = Stats() + mean_value, = stats.process(profile, operation="mean", column="value") + assert mean_value > 0 + assert captured[-1]["unit"] == field.si_unit_z + Stats._broadcast_value_fn = None diff --git a/tests/node_tests/test_cursors.py b/tests/node_tests/test_cursors.py new file mode 100644 index 0000000..218149c --- /dev/null +++ b/tests/node_tests/test_cursors.py @@ -0,0 +1,63 @@ +import numpy as np +from backend.data_types import DataField, LineData + + +def test_line_cursors(): + from backend.nodes.cursors import Cursors + + node = Cursors() + line_spec = Cursors.INPUT_TYPES()["required"]["line"] + assert line_spec[0] == "LINE" + assert line_spec[1]["accepted_types"] == ["DATA_FIELD"] + + line = np.linspace(0, 10, 100).astype(np.float64) + + overlays = [] + Cursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + Cursors._current_node_id = "test" + + table, coord_pair = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5) + assert isinstance(coord_pair, tuple) and len(coord_pair) == 2 + assert len(table) == 6 + quantities = {row["quantity"] for row in table} + assert "A x" in quantities + assert "B x" in quantities + assert "dx" in quantities + assert "dy" in quantities + + a_pos = next(r["value"] for r in table if r["quantity"] == "A x") + b_pos = next(r["value"] for r in table if r["quantity"] == "B x") + assert b_pos > a_pos + + dy = next(r["value"] for r in table if r["quantity"] == "dy") + assert dy > 0 + + assert len(overlays) == 1 + assert overlays[0]["kind"] == "line_plot" + assert len(overlays[0]["line"]) == len(line) + assert len(overlays[0]["x_axis"]) == len(line) + assert 0.0 <= overlays[0]["x1"] <= 1.0 + assert 0.0 <= overlays[0]["x2"] <= 1.0 + + line_data = LineData(data=line, x_axis=np.linspace(0, 1, 100)) + table2, _ = node.process(line_data, x1=0.25, y1=0.5, x2=0.75, y2=0.5) + assert len(table2) == 6 + + field = DataField( + data=np.arange(100, dtype=np.float64).reshape(10, 10), + xreal=2.0, yreal=4.0, si_unit_xy="um", si_unit_z="nm", + ) + overlays.clear() + table3, _ = node.process(field, x1=0.2, y1=0.25, x2=0.7, y2=0.75) + assert len(table3) == 9 + field_rows = {row["quantity"]: row for row in table3} + assert field_rows["dx"]["unit"] == "um" + assert field_rows["dy"]["unit"] == "um" + assert field_rows["dz"]["unit"] == "nm" + assert np.isclose(field_rows["dx"]["value"], 1.0) + assert np.isclose(field_rows["dy"]["value"], 2.0) + assert len(overlays) == 1 + assert overlays[0]["kind"] == "cursor_points" + assert overlays[0]["image"].startswith("data:image/png;base64,") + + Cursors._broadcast_overlay_fn = None diff --git a/tests/node_tests/test_curvature.py b/tests/node_tests/test_curvature.py new file mode 100644 index 0000000..1490787 --- /dev/null +++ b/tests/node_tests/test_curvature.py @@ -0,0 +1,176 @@ +import json +import numpy as np +from backend.data_types import DataField, LineData + + +def test_curvature(): + from backend.node_registry import get_node_info + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + assert get_node_info("Curvature")["category"] == "Measure" + + xres, yres = 121, 101 + xreal, yreal = 8.0e-6, 6.0e-6 + xoff, yoff = 1.0e-6, -0.5e-6 + x = np.linspace(xoff, xoff + xreal, xres, dtype=np.float64) + y = np.linspace(yoff, yoff + yreal, yres, dtype=np.float64) + yy, xx = np.meshgrid(y, x, indexing="ij") + + x0 = xoff + 0.45 * xreal + y0 = yoff + 0.60 * yreal + rx = 1.2e-6 + ry = 2.4e-6 + z0 = 3.0e-9 + data = z0 + (xx - x0) ** 2 / (2.0 * rx) + (yy - y0) ** 2 / (2.0 * ry) + field = DataField(data=data, xreal=xreal, yreal=yreal, xoff=xoff, yoff=yoff, si_unit_xy="m", si_unit_z="m") + + previews = [] + tables = [] + with execution_callbacks(preview=lambda nid, uri: previews.append(uri), table=lambda nid, rows: tables.append(rows)), active_node("test"): + output, table, profile1, profile2 = node.process(field, masking="ignore") + + rows = {row["quantity"]: row for row in table} + recovered_radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) + expected_radii = sorted([rx, ry]) + assert len(previews) == 1 + assert isinstance(previews[0], dict) and previews[0].get("kind") == "panels" + assert len(tables) == 1 + assert abs(rows["Center x position"]["value"] - x0) < xreal * 0.02 + assert abs(rows["Center y position"]["value"] - y0) < yreal * 0.02 + assert abs(rows["Center value"]["value"] - z0) < 5e-11 + assert np.allclose(recovered_radii, expected_radii, rtol=0.08, atol=5e-8) + assert output.overlays[-1]["kind"] == "markup" + assert len(output.overlays[-1]["shapes"]) == 3 + assert isinstance(profile1, LineData) + assert isinstance(profile2, LineData) + assert profile1.x_unit == field.si_unit_xy + assert profile1.y_unit == field.si_unit_z + assert len(profile1) > 10 + assert len(profile2) > 10 + + mask = np.zeros((yres, xres), dtype=np.uint8) + mask[:, :xres // 2] = 255 + left = 1.0e-9 + (xx - (xoff + 0.25 * xreal)) ** 2 / (2.0 * 0.9e-6) + (yy - (yoff + 0.5 * yreal)) ** 2 / (2.0 * 1.8e-6) + right = 2.0e-9 + (xx - (xoff + 0.75 * xreal)) ** 2 / (2.0 * 1.6e-6) + (yy - (yoff + 0.5 * yreal)) ** 2 / (2.0 * 3.2e-6) + split_field = DataField(data=np.where(mask > 0, left, right), xreal=xreal, yreal=yreal, xoff=xoff, yoff=yoff, si_unit_xy="m", si_unit_z="m") + _, include_table, _, _ = node.process(split_field, masking="include", mask=mask) + _, exclude_table, _, _ = node.process(split_field, masking="exclude", mask=mask) + include_radii = sorted([row["value"] for row in include_table if row["quantity"].startswith("Curvature radius")]) + exclude_radii = sorted([row["value"] for row in exclude_table if row["quantity"].startswith("Curvature radius")]) + assert np.allclose(include_radii, [0.9e-6, 1.8e-6], rtol=0.12, atol=5e-8) + assert np.allclose(exclude_radii, [1.6e-6, 3.2e-6], rtol=0.12, atol=5e-8) + + bad_units = DataField(data=np.zeros((16, 16), dtype=np.float64), xreal=1e-6, yreal=1e-6, si_unit_xy="nm", si_unit_z="V") + try: + node.process(bad_units, masking="ignore") + except ValueError as exc: + assert "compatible XY and Z units" in str(exc) + else: + assert False, "Curvature should reject incompatible XY/Z units." + + +def test_curvature_flat_surface(): + """A perfectly flat surface has zero curvature — both radii must be float('inf').""" + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + data = np.zeros((64, 64), dtype=np.float64) + field = DataField(data=data, xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m") + + warnings = [] + tables = [] + with execution_callbacks( + preview=lambda nid, v: None, + table=lambda nid, rows: tables.append(rows), + warning=lambda nid, msg: warnings.append(msg), + ), active_node("test"): + _, table, _, _ = node.process(field, masking="ignore") + + rows = {row["quantity"]: row for row in table} + assert rows["Curvature radius 1"]["value"] == float("inf") + assert rows["Curvature radius 2"]["value"] == float("inf") + assert len(warnings) == 0 + + +def test_curvature_cylindrical(): + """A cylindrical surface is curved in one direction only — one radius finite, one inf.""" + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + N = 64 + xreal = yreal = 1e-6 + x = np.linspace(-xreal / 2, xreal / 2, N, dtype=np.float64) + xx = np.broadcast_to(x, (N, N)) + r_x = 0.8e-6 + data = xx**2 / (2.0 * r_x) + field = DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m") + + tables = [] + with execution_callbacks( + preview=lambda nid, v: None, + table=lambda nid, rows: tables.append(rows), + ), active_node("test"): + _, table, _, _ = node.process(field, masking="ignore") + + rows = {row["quantity"]: row for row in table} + radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) + finite = [r for r in radii if np.isfinite(r)] + infinite = [r for r in radii if not np.isfinite(r)] + assert len(finite) == 1, f"Expected 1 finite radius, got {radii}" + assert len(infinite) == 1, f"Expected 1 inf radius, got {radii}" + assert abs(finite[0] - r_x) < r_x * 0.1, f"Finite radius {finite[0]} far from expected {r_x}" + + +def test_curvature_too_few_pixels(): + """Curvature with fewer than 6 valid pixels emits a warning and returns an empty table.""" + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + N = 16 + data = np.random.default_rng(0).standard_normal((N, N)) + field = DataField(data=data, xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m") + + mask = np.zeros((N, N), dtype=np.uint8) + mask[N // 2, N // 2:N // 2 + 4] = 255 + + warnings = [] + tables = [] + with execution_callbacks( + preview=lambda nid, v: None, + table=lambda nid, rows: tables.append(rows), + warning=lambda nid, msg: warnings.append(msg), + ), active_node("test"): + _, table, profile1, profile2 = node.process(field, masking="include", mask=mask) + + assert len(warnings) == 1 + assert "six" in warnings[0].lower() or "6" in warnings[0] + assert len(list(table)) == 0 + assert len(profile1.data) == 0 + assert len(profile2.data) == 0 + + +def test_curvature_inf_json_safe(): + """inf radii from curvature must not produce invalid JSON when sent over the wire.""" + from backend.server import _sanitize_non_finite, _dumps + + rows = [ + {"quantity": "Curvature radius 1", "value": float("inf"), "unit": "m"}, + {"quantity": "Curvature radius 2", "value": float("-inf"), "unit": "m"}, + {"quantity": "Center value", "value": float("nan"), "unit": "m"}, + {"quantity": "Center x position", "value": 1.5e-7, "unit": "m"}, + ] + + sanitized = _sanitize_non_finite(rows) + assert sanitized[0]["value"] == "∞" + assert sanitized[1]["value"] == "-∞" + assert sanitized[2]["value"] == "NaN" + assert sanitized[3]["value"] == 1.5e-7 + + payload = _dumps({"type": "table", "data": {"node_id": "n1", "rows": sanitized}}) + decoded = json.loads(payload) + assert decoded["data"]["rows"][0]["value"] == "∞" diff --git a/tests/node_tests/test_edge_detect.py b/tests/node_tests/test_edge_detect.py new file mode 100644 index 0000000..8f97727 --- /dev/null +++ b/tests/node_tests/test_edge_detect.py @@ -0,0 +1,18 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_edge_detect(): + from backend.nodes.edge_detect import EdgeDetect + node = EdgeDetect() + + data = np.zeros((64, 64)) + data[:, 32:] = 1.0 + field = make_field(data=data) + + for method in ["sobel", "prewitt", "laplacian", "log"]: + result, = node.process(field, method=method, sigma=1.0) + assert result.data.shape == field.data.shape + col_energy = np.abs(result.data).sum(axis=0) + peak_col = np.argmax(col_energy) + assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32" diff --git a/tests/node_tests/test_execution.py b/tests/node_tests/test_execution.py new file mode 100644 index 0000000..babdb00 --- /dev/null +++ b/tests/node_tests/test_execution.py @@ -0,0 +1,138 @@ +import backend.nodes # noqa: F401 — registers all nodes +from backend.execution import ExecutionEngine +from backend.node_registry import register_node + + +def test_execution_engine_numeric_socket_coercion(): + @register_node(display_name="Test Echo Int") + class TestEchoInt: + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("INT",)}} + OUTPUTS = (('INT', 'value'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self, value): + return (value,) + + @register_node(display_name="Test Echo Float") + class TestEchoFloat: + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + OUTPUTS = (('FLOAT', 'value'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self, value): + return (value,) + + engine = ExecutionEngine() + prompt = { + "1": {"class_type": "Number", "inputs": {"value": 3.6}}, + "2": {"class_type": "TestEchoInt", "inputs": {"value": ["1", 0]}}, + "3": {"class_type": "TestEchoFloat", "inputs": {"value": ["1", 0]}}, + } + + outputs = engine.execute(prompt) + assert outputs["2"] == (4,) + assert outputs["3"] == (3.6,) + + +def test_execution_engine_caches_unchanged_nodes(): + @register_node(display_name="Test Cache Source") + class TestCacheSource: + calls = 0 + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + OUTPUTS = (('FLOAT', 'value'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self, value): + TestCacheSource.calls += 1 + return (float(value),) + + @register_node(display_name="Test Cache Downstream") + class TestCacheDownstream: + calls = 0 + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + OUTPUTS = (('FLOAT', 'value'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self, value): + TestCacheDownstream.calls += 1 + return (float(value) * 2.0,) + + TestCacheSource.calls = 0 + TestCacheDownstream.calls = 0 + + engine = ExecutionEngine() + prompt = { + "1": {"class_type": "TestCacheSource", "inputs": {"value": 2.5}}, + "2": {"class_type": "TestCacheDownstream", "inputs": {"value": ["1", 0]}}, + } + + first_timings = [] + first_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms))) + second_timings = [] + second_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms))) + + assert first_outputs["2"] == (5.0,) + assert second_outputs["2"] == (5.0,) + assert TestCacheSource.calls == 1 + assert TestCacheDownstream.calls == 1 + assert {node_id for node_id, _ in second_timings} == {"1", "2"} + assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings) + + +def test_execution_engine_only_propagates_real_output_changes(): + @register_node(display_name="Test Quantized Source") + class TestQuantizedSource: + calls = 0 + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + OUTPUTS = (('INT', 'value'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self, value): + TestQuantizedSource.calls += 1 + return (int(round(float(value))),) + + @register_node(display_name="Test Quantized Downstream") + class TestQuantizedDownstream: + calls = 0 + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("INT",)}} + OUTPUTS = (('FLOAT', 'value'),) + FUNCTION = "process" + CATEGORY = "tests" + def process(self, value): + TestQuantizedDownstream.calls += 1 + return (float(value) + 0.5,) + + TestQuantizedSource.calls = 0 + TestQuantizedDownstream.calls = 0 + + engine = ExecutionEngine() + prompt = { + "1": {"class_type": "TestQuantizedSource", "inputs": {"value": 1.2}}, + "2": {"class_type": "TestQuantizedDownstream", "inputs": {"value": ["1", 0]}}, + } + + outputs_first = engine.execute(prompt) + assert outputs_first["2"] == (1.5,) + + prompt["1"]["inputs"]["value"] = 1.3 + outputs_second = engine.execute(prompt) + assert outputs_second["2"] == (1.5,) + + prompt["1"]["inputs"]["value"] = 2.2 + outputs_third = engine.execute(prompt) + assert outputs_third["2"] == (2.5,) + + assert TestQuantizedSource.calls == 3 + assert TestQuantizedDownstream.calls == 2 diff --git a/tests/node_tests/test_fft_2d.py b/tests/node_tests/test_fft_2d.py new file mode 100644 index 0000000..6886e5c --- /dev/null +++ b/tests/node_tests/test_fft_2d.py @@ -0,0 +1,46 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_fft2d(): + from backend.nodes.fft_2d import FFT2D + + node = FFT2D() + + N = 64 + y, x = np.mgrid[0:N, 0:N] / N + freq = 5 + data = np.sin(2 * np.pi * freq * x) + field = make_field(data=data, xreal=1e-6, yreal=1e-6) + + spectrum, spec_mag, spec_phase, spec_psdf = node.process(field, windowing="none", level="none") + assert spectrum.data.shape == (N, N) + assert spectrum.domain == "frequency" + assert spectrum.si_unit_xy == "1/m" + centre = N // 2 + row = spectrum.data[centre, :] + peak_idx = np.argmax(row[centre + 1:]) + centre + 1 + assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}" + + _, spec_mag, _, _ = node.process(field, windowing="hann", level="mean") + assert spec_mag.data.shape == (N, N) + assert np.all(spec_mag.data >= 0) + + _, _, spec_phase, _ = node.process(field, windowing="none", level="none") + assert spec_phase.data.shape == (N, N) + assert spec_phase.data.min() >= -np.pi - 0.01 + assert spec_phase.data.max() <= np.pi + 0.01 + + _, _, _, spec_psdf = node.process(field, windowing="hamming", level="plane") + assert spec_psdf.data.shape == (N, N) + assert np.all(spec_psdf.data >= 0) + assert "^2" in spec_psdf.si_unit_z + + const_field = make_field(data=np.ones((32, 32)) * 3.0) + _, spec_const, _, _ = node.process(const_field, windowing="none", level="none") + centre32 = 16 + dc_val = spec_const.data[centre32, centre32] + assert dc_val == spec_const.data.max() + + spec_bk, _, _, _ = node.process(field, windowing="blackman", level="none") + assert spec_bk.data.shape == (N, N) diff --git a/tests/node_tests/test_filter_fft_1d.py b/tests/node_tests/test_filter_fft_1d.py new file mode 100644 index 0000000..8b67a2c --- /dev/null +++ b/tests/node_tests/test_filter_fft_1d.py @@ -0,0 +1,33 @@ +import numpy as np + + +def test_fft_filter_1d(): + from backend.nodes.filter_fft_1d import FFTFilter1D + node = FFTFilter1D() + + n = 256 + t = np.arange(n, dtype=np.float64) / n + low = np.sin(2 * np.pi * 3 * t) + high = np.sin(2 * np.pi * 80 * t) + line = low + high + + filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) + assert len(filtered_lp) == n + corr_low = np.corrcoef(filtered_lp, low)[0, 1] + corr_high = np.corrcoef(filtered_lp, high)[0, 1] + assert corr_low > 0.95 + assert abs(corr_high) < 0.3 + + filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) + corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1] + corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1] + assert abs(corr_low_hp) < 0.3 + assert corr_high_hp > 0.95 + + filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4) + assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3 + assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9 + + filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4) + assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95 + assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3 diff --git a/tests/node_tests/test_filter_fft_2d.py b/tests/node_tests/test_filter_fft_2d.py new file mode 100644 index 0000000..ada9e2d --- /dev/null +++ b/tests/node_tests/test_filter_fft_2d.py @@ -0,0 +1,31 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_fft_filter_2d(): + from backend.nodes.filter_fft_2d import FFTFilter2D + node = FFTFilter2D() + + N = 128 + y, x = np.mgrid[0:N, 0:N] / N + low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y) + high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y) + data = low_2d + high_2d + field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6) + + result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) + assert result_lp.data.shape == (N, N) + assert result_lp.xreal == field.xreal + assert result_lp.si_unit_z == field.si_unit_z + corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1] + corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1] + assert corr_low > 0.9 + assert abs(corr_high) < 0.3 + + result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) + assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3 + assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9 + + const = make_field(data=np.ones((32, 32)) * 7.0) + result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2) + assert np.allclose(result_const.data, 7.0, atol=1e-10) diff --git a/tests/node_tests/test_filter_gaussian.py b/tests/node_tests/test_filter_gaussian.py new file mode 100644 index 0000000..5e17690 --- /dev/null +++ b/tests/node_tests/test_filter_gaussian.py @@ -0,0 +1,16 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_gaussian_filter(): + from backend.nodes.filter_gaussian import GaussianFilter + node = GaussianFilter() + field = make_field() + + result, = node.process(field, sigma=2.0) + assert result.data.shape == field.data.shape + assert result.xreal == field.xreal + assert result.si_unit_z == field.si_unit_z + assert result.data.std() < field.data.std() + result_tiny, = node.process(field, sigma=0.01) + assert np.allclose(result_tiny.data, field.data, atol=1e-6) diff --git a/tests/node_tests/test_filter_median.py b/tests/node_tests/test_filter_median.py new file mode 100644 index 0000000..50b0aa6 --- /dev/null +++ b/tests/node_tests/test_filter_median.py @@ -0,0 +1,19 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_median_filter(): + from backend.nodes.filter_median import MedianFilter + node = MedianFilter() + + data = np.zeros((64, 64)) + rng = np.random.default_rng(7) + noise_idx = rng.choice(64 * 64, size=100, replace=False) + data.ravel()[noise_idx] = 1.0 + field = make_field(data=data) + + result, = node.process(field, size=3) + assert result.data.shape == field.data.shape + assert result.data.sum() < field.data.sum() + result_1, = node.process(field, size=1) + assert np.array_equal(result_1.data, field.data) diff --git a/tests/node_tests/test_fix_zero.py b/tests/node_tests/test_fix_zero.py new file mode 100644 index 0000000..17435a8 --- /dev/null +++ b/tests/node_tests/test_fix_zero.py @@ -0,0 +1,18 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_fix_zero(): + from backend.nodes.fix_zero import FixZero + node = FixZero() + field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64)) + + result_min, = node.process(field, method="min") + assert result_min.data.min() == 0.0 + assert result_min.data.max() == 30.0 + + result_mean, = node.process(field, method="mean") + assert abs(result_mean.data.mean()) < 1e-10 + + result_median, = node.process(field, method="median") + assert abs(np.median(result_median.data)) < 1e-10 diff --git a/tests/node_tests/test_flip.py b/tests/node_tests/test_flip.py new file mode 100644 index 0000000..24eaef6 --- /dev/null +++ b/tests/node_tests/test_flip.py @@ -0,0 +1,52 @@ +import numpy as np +from backend.data_types import DataField + + +def test_flip_field(): + from backend.nodes.flip import FlipField + from backend.node_registry import get_node_info + + node = FlipField() + data = np.arange(1, 10, dtype=np.float64).reshape(3, 3) + markup_overlay = { + "kind": "markup", + "shapes": [ + {"kind": "line", "x1": 0.1, "y1": 0.2, "x2": 0.9, "y2": 0.8, "width": 2, "color": "#ffffff"}, + {"kind": "rectangle", "x1": 0.15, "y1": 0.1, "x2": 0.45, "y2": 0.6, "width": 3, "color": "#ff0000"}, + ], + } + annotation_overlay = {"kind": "annotation", "show_scale_bar": True, "show_color_map": False, "text_size": 14.0} + field = DataField(data=data, xreal=3.0, yreal=4.0, xoff=10.0, yoff=20.0, si_unit_xy="nm", si_unit_z="nm", overlays=[markup_overlay, annotation_overlay]) + + assert get_node_info("FlipField")["category"] == "Geometry" + + flipped_x, = node.process(field, axis="x") + assert np.array_equal(flipped_x.data, np.flipud(data)) + assert flipped_x.xreal == field.xreal + assert flipped_x.yreal == field.yreal + assert flipped_x.xoff == field.xoff + assert flipped_x.yoff == field.yoff + assert np.isclose(flipped_x.overlays[0]["shapes"][0]["x1"], 0.1) + assert np.isclose(flipped_x.overlays[0]["shapes"][0]["y1"], 0.8) + assert np.isclose(flipped_x.overlays[0]["shapes"][0]["x2"], 0.9) + assert np.isclose(flipped_x.overlays[0]["shapes"][0]["y2"], 0.2) + assert np.isclose(flipped_x.overlays[0]["shapes"][1]["x1"], 0.15) + assert np.isclose(flipped_x.overlays[0]["shapes"][1]["y1"], 0.4) + assert np.isclose(flipped_x.overlays[0]["shapes"][1]["x2"], 0.45) + assert np.isclose(flipped_x.overlays[0]["shapes"][1]["y2"], 0.9) + assert flipped_x.overlays[1] == annotation_overlay + + flipped_y, = node.process(field, axis="y") + assert np.array_equal(flipped_y.data, np.fliplr(data)) + assert np.isclose(flipped_y.overlays[0]["shapes"][0]["x1"], 0.9) + assert np.isclose(flipped_y.overlays[0]["shapes"][0]["y1"], 0.2) + assert np.isclose(flipped_y.overlays[0]["shapes"][0]["x2"], 0.1) + assert np.isclose(flipped_y.overlays[0]["shapes"][0]["y2"], 0.8) + + assert field.overlays[0]["shapes"][0]["x1"] == markup_overlay["shapes"][0]["x1"] + + try: + node.process(field, axis="diagonal") + raise AssertionError("Expected invalid flip axis to raise ValueError") + except ValueError: + pass diff --git a/tests/node_tests/test_font.py b/tests/node_tests/test_font.py new file mode 100644 index 0000000..ccca273 --- /dev/null +++ b/tests/node_tests/test_font.py @@ -0,0 +1,14 @@ +def test_font_node(): + from backend.nodes.font import Font + from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT + + node = Font() + + system_default, = node.build(SYSTEM_DEFAULT_FONT) + assert system_default is None + + named, = node.build("Arial") + assert named == {"family": "Arial", "path": ""} + + custom, = node.build(CUSTOM_FILE_FONT, "/tmp/example-font.ttf") + assert custom == {"family": "", "path": "/tmp/example-font.ttf"} diff --git a/tests/node_tests/test_fractal_dimension.py b/tests/node_tests/test_fractal_dimension.py new file mode 100644 index 0000000..841f0f6 --- /dev/null +++ b/tests/node_tests/test_fractal_dimension.py @@ -0,0 +1,60 @@ +import numpy as np +from backend.data_types import LineData +from backend.node_registry import get_node_info +from backend.execution_context import active_node, execution_callbacks +from tests.node_tests._shared import make_field + + +def test_fractal_dimension(): + from backend.nodes.fractal_dimension import FractalDimension + + node = FractalDimension() + assert get_node_info("FractalDimension")["category"] == "Measure" + + N = 129 + yy, xx = np.mgrid[0:N, 0:N] / (N - 1) + data = 0.25 * xx + 0.12 * yy + 0.03 * np.sin(6.0 * np.pi * xx) + 0.02 * np.cos(4.0 * np.pi * yy) + field = make_field(data=data, xreal=4.0e-6, yreal=4.0e-6) + + overlays = [] + tables = [] + with execution_callbacks( + overlay=lambda nid, payload: overlays.append(payload), + table=lambda nid, rows: tables.append(rows), + ), active_node("test"): + dimension, curve, table = node.process( + field, method="partitioning", interpolation="linear", x1=0.0, y1=0.5, x2=1.0, y2=0.5, + ) + + assert np.isfinite(dimension) + assert 1.5 < dimension < 2.5 + assert isinstance(curve, LineData) + assert len(curve) > 3 + assert curve.x_axis is not None + assert np.all(np.diff(curve.x_axis) > 0.0) + assert len(overlays) == 1 + assert overlays[0]["kind"] == "line_plot" + assert len(tables) == 1 + assert table[0]["quantity"] == "Dimension" + + methods = ["partitioning", "cube_counting", "triangulation", "psdf", "hhcf"] + for method in methods: + dim, line, measurements = node.process( + field, method=method, interpolation="linear", x1=0.0, y1=0.5, x2=1.0, y2=0.5, + ) + assert np.isfinite(dim), f"{method} should produce a finite fractal dimension" + if method == "psdf": + assert -1.0 < dim < 3.2 + else: + assert 1.2 < dim < 3.2 + assert isinstance(line, LineData) + assert len(line) >= 2 + assert measurements[0]["quantity"] == "Dimension" + + narrowed_dim, _, narrowed_table = node.process( + field, method="partitioning", interpolation="linear", x1=0.15, y1=0.5, x2=0.55, y2=0.5, + ) + assert np.isfinite(narrowed_dim) + fit_from = next(row["value"] for row in narrowed_table if row["quantity"] == "Fit from") + fit_to = next(row["value"] for row in narrowed_table if row["quantity"] == "Fit to") + assert fit_to > fit_from diff --git a/tests/node_tests/test_grain_analysis.py b/tests/node_tests/test_grain_analysis.py new file mode 100644 index 0000000..6d00e7f --- /dev/null +++ b/tests/node_tests/test_grain_analysis.py @@ -0,0 +1,35 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_grain_analysis(): + from backend.nodes.grain_analysis import GrainAnalysis + node = GrainAnalysis() + + N = 64 + data = np.zeros((N, N)) + data[5:15, 5:15] = 5.0 + data[45:53, 45:53] = 3.0 + field = make_field(data=data, xreal=1e-6, yreal=1e-6) + + mask = np.zeros((N, N), dtype=np.uint8) + mask[5:15, 5:15] = 255 + mask[45:53, 45:53] = 255 + + table, = node.process(field, mask=mask, min_size=10) + assert len(table) == 2, f"Expected 2 grains, got {len(table)}" + + table.sort(key=lambda r: r["area_px"], reverse=True) + assert table[0]["area_px"] == 100 + assert table[1]["area_px"] == 64 + assert abs(table[0]["mean_height"] - 5.0) < 1e-10 + assert abs(table[1]["mean_height"] - 3.0) < 1e-10 + assert table[0]["area_px_unit"] == "px^2" + assert table[0]["area_m2_unit"] == "m^2" + assert table[0]["equiv_diam_m_unit"] == "m" + assert table[0]["mean_height_unit"] == "m" + assert table[0]["max_height_unit"] == "m" + + table_filtered, = node.process(field, mask=mask, min_size=80) + assert len(table_filtered) == 1 + assert table_filtered[0]["area_px"] == 100 diff --git a/tests/node_tests/test_grain_distance_transform.py b/tests/node_tests/test_grain_distance_transform.py new file mode 100644 index 0000000..d444f29 --- /dev/null +++ b/tests/node_tests/test_grain_distance_transform.py @@ -0,0 +1,34 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_grain_distance_transform(): + from backend.nodes.grain_distance_transform import GrainDistanceTransform + + node = GrainDistanceTransform() + field = make_field(data=np.zeros((7, 7), dtype=np.float64), xreal=7.0, yreal=7.0) + mask = np.zeros((7, 7), dtype=np.uint8) + mask[2:5, 2:5] = 255 + + interior, = node.process(field, mask, distance_type="euclidean", output_type="interior", from_border=True) + assert interior.data.shape == field.data.shape + assert interior.si_unit_z == field.si_unit_xy + assert np.isclose(interior.data[3, 3], 2.0) + assert np.isclose(interior.data[2, 2], 1.0) + assert np.isclose(interior.data[0, 0], 0.0) + + exterior, = node.process(field, mask, distance_type="cityblock", output_type="exterior", from_border=True) + assert np.isclose(exterior.data[1, 1], 2.0) + assert np.isclose(exterior.data[2, 1], 1.0) + assert np.isclose(exterior.data[3, 3], 0.0) + + signed, = node.process(field, mask, distance_type="chess", output_type="signed", from_border=True) + assert signed.data[3, 3] > 0.0 + assert signed.data[0, 0] < 0.0 + + edge_field = make_field(data=np.zeros((5, 5), dtype=np.float64), xreal=5.0, yreal=5.0) + edge_mask = np.zeros((5, 5), dtype=np.uint8) + edge_mask[:, :2] = 255 + from_edge, = node.process(edge_field, edge_mask, distance_type="euclidean", output_type="interior", from_border=True) + not_from_edge, = node.process(edge_field, edge_mask, distance_type="euclidean", output_type="interior", from_border=False) + assert not_from_edge.data[2, 0] > from_edge.data[2, 0] diff --git a/tests/node_tests/test_helpers.py b/tests/node_tests/test_helpers.py new file mode 100644 index 0000000..db8c929 --- /dev/null +++ b/tests/node_tests/test_helpers.py @@ -0,0 +1,65 @@ +import os +import tempfile +from pathlib import Path + +import numpy as np +from PIL import Image + + +def test_list_channels(): + from backend.nodes.helpers import list_channels, list_folder_paths + from backend.nodes.folder import Folder + + ch = list_channels("/nonexistent/file.ibw") + assert len(ch) == 1 + assert ch[0]["name"] == "field" + + ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "demo", "BR_New20012.ibw")) + if os.path.exists(ibw_path): + ch = list_channels(ibw_path) + assert len(ch) == 4 + names = [c["name"] for c in ch] + assert "HeightRetrace" in names + assert "AmplitudeRetrace" in names + assert all(c["type"] == "DATA_FIELD" for c in ch) + + with tempfile.TemporaryDirectory() as tmpdir: + img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) + path = os.path.join(tmpdir, "test.png") + img.save(path) + + ch = list_channels(path) + assert len(ch) == 1 + assert ch[0]["name"] == "field" + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test.npy") + np.save(path, np.zeros((4, 4))) + ch = list_channels(path) + assert len(ch) == 1 + + with tempfile.TemporaryDirectory() as tmpdir: + img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) + png_path = os.path.join(tmpdir, "a.png") + npy_path = os.path.join(tmpdir, "b.npy") + gwy_path = os.path.join(tmpdir, "c.gwy") + sxm_path = os.path.join(tmpdir, "d.sxm") + ibw_path2 = os.path.join(tmpdir, "e.ibw") + txt_path = os.path.join(tmpdir, "notes.txt") + img.save(png_path) + np.save(npy_path, np.zeros((4, 4))) + Path(gwy_path).write_bytes(b"gwy") + Path(sxm_path).write_bytes(b"sxm") + Path(ibw_path2).write_bytes(b"ibw") + with open(txt_path, "w", encoding="utf-8") as fh: + fh.write("ignore me") + + paths = list_folder_paths(tmpdir) + assert [entry["name"] for entry in paths] == ["directory", "a.png", "b.npy", "c.gwy", "d.sxm", "e.ibw"] + assert Path(paths[0]["path"]).resolve() == Path(tmpdir).resolve() + assert paths[0]["type"] == "DIRECTORY" + assert all(entry["type"] == "FILE_PATH" for entry in paths[1:]) + + folder_node = Folder() + folder_result = folder_node.list_files(tmpdir) + assert folder_result == tuple(entry["path"] for entry in paths) diff --git a/tests/node_tests/test_histogram.py b/tests/node_tests/test_histogram.py new file mode 100644 index 0000000..4539c5c --- /dev/null +++ b/tests/node_tests/test_histogram.py @@ -0,0 +1,40 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_height_histogram(): + from backend.nodes.histogram import Histogram + node = Histogram() + + data = np.linspace(0, 1, 1000).reshape(25, 40) + field = make_field(data=data) + + overlays = [] + Histogram._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + Histogram._current_node_id = "test" + + table, coord_pair = node.process(field, n_bins=10, y_scale="linear", x1=0.2, y1=0.5, x2=0.8, y2=0.5) + assert isinstance(coord_pair, tuple) and len(coord_pair) == 2 + measurements = {row["quantity"]: row for row in table} + assert "A position" in measurements + assert "A count" in measurements + assert "B position" in measurements + assert "B count" in measurements + assert "delta X" in measurements + assert "delta Y" in measurements + assert measurements["A count"]["unit"] == "count" + assert measurements["B count"]["unit"] == "count" + assert measurements["B position"]["value"] > measurements["A position"]["value"] + assert len(overlays) == 1 + assert overlays[0]["kind"] == "line_plot" + assert overlays[0]["section_title"] == "Histogram" + assert len(overlays[0]["line"]) == 10 + assert len(overlays[0]["x_axis"]) == 10 + assert np.isclose(overlays[0]["x1"], 0.2) + assert np.isclose(overlays[0]["x2"], 0.8) + assert np.isclose( + measurements["delta Y"]["value"], + measurements["B count"]["value"] - measurements["A count"]["value"], + ) + + Histogram._broadcast_overlay_fn = None diff --git a/tests/node_tests/test_image.py b/tests/node_tests/test_image.py new file mode 100644 index 0000000..70b7e12 --- /dev/null +++ b/tests/node_tests/test_image.py @@ -0,0 +1,167 @@ +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import numpy as np +from PIL import Image as PILImage + +from backend.data_types import DataField + + +def test_load_file(): + from backend.nodes.image import Image as ImageNode + node = ImageNode() + + with tempfile.TemporaryDirectory() as tmpdir: + arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8) + img = PILImage.fromarray(arr, mode="L") + path = os.path.join(tmpdir, "test_gray.png") + img.save(path) + + result = node.load(filename=path) + assert len(result) == 1 + field = result[0] + assert field.data.shape == (48, 64) + assert field.data.dtype == np.float64 + + arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8) + img_rgb = PILImage.fromarray(arr_rgb, mode="RGB") + path_rgb = os.path.join(tmpdir, "test_rgb.png") + img_rgb.save(path_rgb) + + result_rgb = node.load(filename=path_rgb) + assert len(result_rgb) == 1 + assert result_rgb[0].data.shape == (32, 32) + + data_npy = np.random.default_rng(3).standard_normal((50, 60)) + path_npy = os.path.join(tmpdir, "test.npy") + np.save(path_npy, data_npy) + + result_npy = node.load(filename=path_npy) + assert np.allclose(result_npy[0].data, data_npy) + + custom_colormap = { + "mode": "custom", + "stops": [ + {"position": 0.0, "color": "#000000"}, + {"position": 0.5, "color": "#ff0000"}, + {"position": 1.0, "color": "#ffffff"}, + ], + } + result_custom = node.load(filename=path, colormap_map=custom_colormap) + assert isinstance(result_custom[0].colormap, dict) + assert result_custom[0].colormap["mode"] == "custom" + assert len(result_custom[0].colormap["stops"]) == 3 + + result_from_path = node.load(filename="", path=path) + assert len(result_from_path) == 1 + assert result_from_path[0].data.shape == (48, 64) + + +def test_load_file_npz(): + from backend.nodes.image import Image + node = Image() + with tempfile.TemporaryDirectory() as tmpdir: + data = np.random.default_rng(99).standard_normal((30, 40)) + path = os.path.join(tmpdir, "test.npz") + np.savez(path, my_array=data) + + result = node.load(filename=path) + assert len(result) == 1 + assert np.allclose(result[0].data, data) + + +def test_load_file_cache(): + from backend.nodes.image import Image + node = Image() + Image._load_fields_cached.cache_clear() + + with tempfile.TemporaryDirectory() as tmpdir: + data = np.arange(16, dtype=np.float64).reshape(4, 4) + path = os.path.join(tmpdir, "cached.npy") + np.save(path, data) + + with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: + first, = node.load(filename=path) + second, = node.load(filename=path) + assert loader.call_count == 1 + + assert np.allclose(first.data, data) + assert np.allclose(second.data, data) + assert first is not second + first.data[0, 0] = -999.0 + + third, = node.load(filename=path) + assert third.data[0, 0] == data[0, 0] + + Image._load_fields_cached.cache_clear() + + +def test_load_file_not_found(): + from backend.nodes.image import Image + node = Image() + try: + node.load(filename="/nonexistent/path/file.png") + assert False, "Should have raised FileNotFoundError" + except FileNotFoundError: + pass + + +def test_load_file_unsupported(): + from backend.nodes.image import Image + node = Image() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test.xyz") + with open(path, "w") as f: + f.write("hello") + try: + node.load(filename=path) + assert False, "Should have raised an error for .xyz" + except Exception: + pass + + +def test_load_file_warning(): + from backend.nodes.image import Image as ImageNode + node = ImageNode() + warnings = [] + ImageNode._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) + ImageNode._current_node_id = "test" + + with tempfile.TemporaryDirectory() as tmpdir: + arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8) + img = PILImage.fromarray(arr) + path = os.path.join(tmpdir, "test.png") + img.save(path) + + result = node.load(filename=path) + assert len(result) == 1 + assert len(warnings) == 1 + assert "Uncalibrated" in warnings[0] + + ImageNode._broadcast_warning_fn = None + + +def test_load_file_ibw(): + from backend.nodes.image import Image + node = Image() + ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "demo", "BR_New20012.ibw")) + if not os.path.exists(ibw_path): + return + + result = node.load(filename=ibw_path) + assert len(result) == 4 + + for i, field in enumerate(result): + assert isinstance(field, DataField) + assert field.data.shape == (512, 1024) + assert field.data.dtype == np.float64 + assert field.xreal > 1e-8 + assert field.yreal > 1e-8 + assert field.si_unit_xy == "m" + assert field.si_unit_z == "m" + + assert result[0].xreal == result[1].xreal + assert result[0].yreal == result[1].yreal + assert not np.array_equal(result[0].data, result[1].data) diff --git a/tests/node_tests/test_image_demo.py b/tests/node_tests/test_image_demo.py new file mode 100644 index 0000000..079a853 --- /dev/null +++ b/tests/node_tests/test_image_demo.py @@ -0,0 +1,70 @@ +import os +from unittest.mock import patch + +import numpy as np +from backend.data_types import DataField +from backend.execution import ExecutionEngine +import backend.nodes # noqa: F401 + + +def test_load_demo(): + from backend.nodes.image_demo import ImageDemo + node = ImageDemo() + + result = node.load(name="nanoparticles.npy") + assert len(result) >= 1 + assert isinstance(result[0], DataField) + assert result[0].data.ndim == 2 + + result_ibw = node.load(name="whiskers.ibw") + assert len(result_ibw) == 4 + for field in result_ibw: + assert isinstance(field, DataField) + + try: + node.load(name="nonexistent_file.png") + assert False, "Should have raised FileNotFoundError" + except FileNotFoundError: + pass + + +def test_load_demo_cache(): + from backend.nodes.image import Image + from backend.nodes.image_demo import ImageDemo + + node = ImageDemo() + Image._load_fields_cached.cache_clear() + + with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: + first, = node.load(name="nanoparticles.npy") + second, = node.load(name="nanoparticles.npy") + assert loader.call_count == 1 + + assert np.allclose(first.data, second.data) + assert first is not second + first.data[0, 0] = -999.0 + + third, = node.load(name="nanoparticles.npy") + assert third.data[0, 0] != -999.0 + + Image._load_fields_cached.cache_clear() + + +def test_load_demo_multi_layer_preview_payload(): + previews = [] + prompt = { + "1": { + "class_type": "ImageDemo", + "inputs": {"name": "whiskers.ibw", "colormap": "viridis"}, + }, + } + + ExecutionEngine().execute(prompt, on_preview=lambda node_id, payload: previews.append((node_id, payload))) + + assert len(previews) == 1 + node_id, payload = previews[0] + assert node_id == "1" + assert payload["kind"] == "layer_gallery" + assert len(payload["layers"]) == 4 + assert all(isinstance(layer["name"], str) and layer["name"] for layer in payload["layers"]) + assert all(layer["image"].startswith("data:image/png;base64,") for layer in payload["layers"]) diff --git a/tests/node_tests/test_level_facet.py b/tests/node_tests/test_level_facet.py new file mode 100644 index 0000000..9d8cbe4 --- /dev/null +++ b/tests/node_tests/test_level_facet.py @@ -0,0 +1,68 @@ +import numpy as np +from backend.data_types import DataField +from tests.node_tests._shared import make_field + + +def test_facet_level(): + from backend.node_registry import get_node_info + from backend.nodes.level_facet import FacetLevelField + from backend.nodes.level_plane import PlaneLevelField + + def fit_pixel_plane(data, region): + yy, xx = np.mgrid[0:data.shape[0], 0:data.shape[1]] + A = np.column_stack([np.ones(int(np.count_nonzero(region))), xx[region].astype(np.float64), yy[region].astype(np.float64)]) + coeffs, _, _, _ = np.linalg.lstsq(A, data[region].ravel().astype(np.float64), rcond=None) + return float(coeffs[0]), float(coeffs[1]), float(coeffs[2]) + + node = FacetLevelField() + plane_node = PlaneLevelField() + assert get_node_info("FacetLevelField")["category"] == "Level & Correct" + + N = 96 + yy, xx = np.mgrid[0:N, 0:N] + base = 0.055 * xx + 0.028 * yy + terraces = np.zeros((N, N), dtype=np.float64) + terraces[:, 54:] += 6.0 + terraces[18:70, 68:88] += 3.5 + field = make_field(data=base + terraces) + + plane_leveled, = plane_node.process(field) + facet_leveled, = node.process(field, masking="ignore") + + left_region = xx < 48 + right_region = (xx > 60) & ~((yy >= 18) & (yy < 70) & (xx >= 68) & (xx < 88)) + _, plane_left_bx, plane_left_by = fit_pixel_plane(plane_leveled.data, left_region) + _, plane_right_bx, plane_right_by = fit_pixel_plane(plane_leveled.data, right_region) + _, facet_left_bx, facet_left_by = fit_pixel_plane(facet_leveled.data, left_region) + _, facet_right_bx, facet_right_by = fit_pixel_plane(facet_leveled.data, right_region) + plane_slope = float(max(np.hypot(plane_left_bx, plane_left_by), np.hypot(plane_right_bx, plane_right_by))) + facet_slope = float(max(np.hypot(facet_left_bx, facet_left_by), np.hypot(facet_right_bx, facet_right_by))) + assert facet_slope < plane_slope * 1e-6 + + mask = np.zeros((N, N), dtype=np.uint8) + mask[24:72, 24:72] = 255 + base_only = 0.035 * xx + 0.014 * yy + masked_facet = 5.0 - 0.065 * xx + 0.045 * yy + competing = np.where(mask > 0, masked_facet, base_only) + competing_field = make_field(data=competing) + + excluded, = node.process(competing_field, masking="exclude", mask=mask) + included, = node.process(competing_field, masking="include", mask=mask) + + outer_region = (mask == 0) & (xx > 4) & (xx < N - 4) & (yy > 4) & (yy < N - 4) + inner_region = (mask > 0) & (xx > 28) & (xx < 68) & (yy > 28) & (yy < 68) + _, excl_outer_bx, excl_outer_by = fit_pixel_plane(excluded.data, outer_region) + _, excl_inner_bx, excl_inner_by = fit_pixel_plane(excluded.data, inner_region) + _, incl_outer_bx, incl_outer_by = fit_pixel_plane(included.data, outer_region) + _, incl_inner_bx, incl_inner_by = fit_pixel_plane(included.data, inner_region) + + assert float(np.hypot(excl_outer_bx, excl_outer_by)) < float(np.hypot(incl_outer_bx, incl_outer_by)) * 0.2 + assert float(np.hypot(incl_inner_bx, incl_inner_by)) < float(np.hypot(excl_inner_bx, excl_inner_by)) * 0.2 + + bad_units = DataField(data=np.zeros((16, 16), dtype=np.float64), xreal=1e-6, yreal=1e-6, si_unit_xy="nm", si_unit_z="V") + try: + node.process(bad_units, masking="ignore") + except ValueError as exc: + assert "compatible XY and Z units" in str(exc) + else: + assert False, "Facet level should reject incompatible XY/Z units." diff --git a/tests/node_tests/test_level_plane.py b/tests/node_tests/test_level_plane.py new file mode 100644 index 0000000..7f36aa9 --- /dev/null +++ b/tests/node_tests/test_level_plane.py @@ -0,0 +1,40 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_plane_level(): + from backend.nodes.level_plane import PlaneLevelField + node = PlaneLevelField() + + N = 64 + y, x = np.mgrid[0:N, 0:N] / N + signal = np.sin(2 * np.pi * 5 * x) + data = 100 * x + 50 * y + signal + field = make_field(data=data) + + result, = node.process(field) + assert result.data.shape == field.data.shape + assert abs(result.data.mean()) < 1e-10 + corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1] + assert corr > 0.98 + + yy_px, xx_px = np.mgrid[0:N, 0:N] + + def fit_pixel_plane(data_in, region): + A = np.column_stack([np.ones(int(np.count_nonzero(region))), xx_px[region].astype(np.float64), yy_px[region].astype(np.float64)]) + coeffs, _, _, _ = np.linalg.lstsq(A, data_in[region].ravel().astype(np.float64), rcond=None) + return float(coeffs[0]), float(coeffs[1]), float(coeffs[2]) + + mask = np.zeros((N, N), dtype=np.uint8) + mask[20:44, 22:46] = 255 + feature = np.zeros((N, N), dtype=np.float64) + feature[mask > 0] = 35.0 + masked_field = make_field(data=100 * x + 50 * y + feature) + + unmasked, = node.process(masked_field) + masked, = node.process(masked_field, masking="exclude", mask=mask) + + outside = mask == 0 + _, unmasked_bx, unmasked_by = fit_pixel_plane(unmasked.data, outside) + _, masked_bx, masked_by = fit_pixel_plane(masked.data, outside) + assert np.hypot(masked_bx, masked_by) < np.hypot(unmasked_bx, unmasked_by) * 1e-3 diff --git a/tests/node_tests/test_level_poly.py b/tests/node_tests/test_level_poly.py new file mode 100644 index 0000000..71a9892 --- /dev/null +++ b/tests/node_tests/test_level_poly.py @@ -0,0 +1,24 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_poly_level(): + from backend.nodes.level_poly import PolyLevelField + node = PolyLevelField() + + N = 64 + y, x = np.mgrid[0:N, 0:N] / N + background = 50 * x**2 + 30 * y**2 + 10 * x * y + signal = np.sin(2 * np.pi * 8 * x) + data = background + signal + field = make_field(data=data) + + leveled, bg = node.process(field, degree_x=2, degree_y=2) + assert leveled.data.shape == field.data.shape + assert bg.data.shape == field.data.shape + assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10) + corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1] + assert corr > 0.95 + + leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0) + assert abs(leveled_0.data.mean()) < 1e-10 diff --git a/tests/node_tests/test_line_correction.py b/tests/node_tests/test_line_correction.py new file mode 100644 index 0000000..6b2e791 --- /dev/null +++ b/tests/node_tests/test_line_correction.py @@ -0,0 +1,54 @@ +import numpy as np +from backend.data_types import LineData +from backend.node_registry import get_node_info +from tests.node_tests._shared import make_field + + +def test_line_correction(): + from backend.nodes.line_correction import LineCorrection + + node = LineCorrection() + assert get_node_info("LineCorrection")["category"] == "Level & Correct" + + rows = 96 + cols = 128 + y = np.linspace(0.0, 1.0, rows, dtype=np.float64) + x = np.linspace(-1.0, 1.0, cols, dtype=np.float64) + signal = ( + 0.15 * np.sin(8.0 * np.pi * x)[None, :] + + 0.05 * np.cos(4.0 * np.pi * y)[:, None] + ) + row_offsets = 1.5 * np.sin(3.0 * np.pi * y) + 0.25 * np.cos(7.0 * np.pi * y) + field = make_field(data=signal + row_offsets[:, None], xreal=2.5e-6, yreal=1.5e-6) + + corrected, background, shifts = node.process( + field, method="median", direction="horizontal", masking="ignore", + trim_fraction=0.05, polynomial_degree=1, + ) + expected_shifts = row_offsets - row_offsets.mean() + assert corrected.data.shape == field.data.shape + assert background.data.shape == field.data.shape + assert np.allclose(corrected.data + background.data, field.data) + assert isinstance(shifts, LineData) + assert shifts.x_unit == field.si_unit_xy + assert shifts.y_unit == field.si_unit_z + assert np.isclose(shifts.x_axis[0], 0.0) + assert np.isclose(shifts.x_axis[-1], field.yreal) + assert np.corrcoef(shifts.data, expected_shifts)[0, 1] > 0.999 + assert corrected.data.mean(axis=1).std() < field.data.mean(axis=1).std() * 0.03 + + poly_background = ( + row_offsets[:, None] + + (0.35 * y - 0.15)[:, None] * x[None, :] + + (0.10 + 0.05 * y)[:, None] * (x[None, :] ** 2) + ) + poly_signal = 0.08 * np.sin(10.0 * np.pi * x)[None, :] * (1.0 + 0.15 * np.cos(2.0 * np.pi * y)[:, None]) + poly_field = make_field(data=poly_signal + poly_background) + + leveled, poly_bg, poly_shifts = node.process( + poly_field, method="polynomial", direction="horizontal", masking="ignore", + trim_fraction=0.05, polynomial_degree=2, + ) + assert np.allclose(leveled.data + poly_bg.data, poly_field.data) + assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995 + assert len(poly_shifts) == rows diff --git a/tests/node_tests/test_markup.py b/tests/node_tests/test_markup.py new file mode 100644 index 0000000..b4f9a13 --- /dev/null +++ b/tests/node_tests/test_markup.py @@ -0,0 +1,60 @@ +import json +import numpy as np +from backend.data_types import DataField, ImageData, render_datafield_preview +from backend.execution_context import active_node, execution_callbacks +from tests.node_tests._shared import make_field + + +def test_markup(): + from backend.nodes.markup import Markup + from backend.data_types import _preview_markup_stroke_width + + node = Markup() + field = make_field(data=np.linspace(0.0, 1.0, 48 * 48, dtype=np.float64).reshape(48, 48)) + base = render_datafield_preview(field, field.colormap) + required_inputs = Markup.INPUT_TYPES()["required"] + + assert _preview_markup_stroke_width(5, 128, 128) == 5 + assert _preview_markup_stroke_width(5, 2048, 2048) > 5 + assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] + assert required_inputs["shape"][1]["default"] == "arrow" + assert required_inputs["stroke_color"][1]["default"] == "#ff0000" + + overlays = [] + with execution_callbacks(overlay=lambda nid, data: overlays.append(data)), active_node("test"): + plain_field, = node.process(input=field, shape="line", stroke_color="#ffd54f", stroke_width=3, markup_shapes="[]") + assert isinstance(plain_field, DataField) + assert plain_field.overlays[-1]["kind"] == "markup" + plain = render_datafield_preview(plain_field, plain_field.colormap) + assert np.array_equal(plain, base) + assert overlays[-1]["kind"] == "markup" + assert overlays[-1]["shape"] == "line" + assert overlays[-1]["stroke_color"] == "#ffd54f" + assert overlays[-1]["stroke_width"] == 3 + assert overlays[-1]["image"].startswith("data:image/png;base64,") + + shapes = json.dumps([ + {"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 3, "color": "#ff0000"}, + {"kind": "rectangle", "x1": 0.2, "y1": 0.2, "x2": 0.8, "y2": 0.5, "width": 2, "color": "#00ff00"}, + {"kind": "circle", "x1": 0.25, "y1": 0.55, "x2": 0.55, "y2": 0.85, "width": 2, "color": "#4fc3f7"}, + {"kind": "arrow", "x1": 0.15, "y1": 0.85, "x2": 0.85, "y2": 0.2, "width": 4, "color": "#ffffff"}, + ]) + marked_field, = node.process(input=field, shape="arrow", stroke_color="#ffffff", stroke_width=4, markup_shapes=shapes) + marked = render_datafield_preview(marked_field, marked_field.colormap) + assert marked.shape == base.shape + assert not np.array_equal(marked, base) + assert overlays[-1]["shape"] == "arrow" + assert overlays[-1]["stroke_color"] == "#ffffff" + assert overlays[-1]["stroke_width"] == 4 + + viewport_image = ImageData( + np.zeros((48, 48, 3), dtype=np.uint8), + metadata={"annotation_context": {"xreal": 1e-6, "si_unit_xy": "m"}}, + ) + image_markup, = node.process( + input=viewport_image, shape="line", stroke_color="#ff0000", stroke_width=4, + markup_shapes=json.dumps([{"kind": "line", "x1": 0.1, "y1": 0.2, "x2": 0.9, "y2": 0.8, "width": 4, "color": "#ff0000"}]), + ) + assert isinstance(image_markup, ImageData) + assert image_markup.metadata["annotation_context"]["si_unit_xy"] == "m" + assert not np.array_equal(np.asarray(image_markup), np.asarray(viewport_image)) diff --git a/tests/node_tests/test_mask_draw.py b/tests/node_tests/test_mask_draw.py new file mode 100644 index 0000000..0e812f0 --- /dev/null +++ b/tests/node_tests/test_mask_draw.py @@ -0,0 +1,40 @@ +import json +import numpy as np +from tests.node_tests._shared import make_field + + +def test_draw_mask(): + from backend.nodes.mask_draw import DrawMask + node = DrawMask() + + field = make_field(data=np.zeros((32, 32), dtype=np.float64)) + overlays = [] + DrawMask._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + DrawMask._current_node_id = "test" + + mask_paths = [{"size": 5, "points": [{"x": 0.2, "y": 0.5}, {"x": 0.8, "y": 0.5}]}] + + mask, = node.process(field, pen_size=2, invert=False, mask_paths=json.dumps(mask_paths)) + assert mask.dtype == np.uint8 + assert mask.shape == (32, 32) + assert mask[16, 16] == 255 + assert mask[14, 16] == 255 + assert mask[0, 0] == 0 + + assert len(overlays) == 1 + assert overlays[0]["kind"] == "mask_paint" + assert overlays[0]["section_title"] == "Mask" + assert overlays[0]["image"].startswith("data:image/png;base64,") + assert overlays[0]["image_width"] == field.xres + assert overlays[0]["image_height"] == field.yres + assert overlays[0]["invert"] is False + + inverted, = node.process(field, pen_size=2, invert=True, mask_paths=json.dumps(mask_paths)) + assert inverted[16, 16] == 0 + assert inverted[0, 0] == 255 + assert overlays[-1]["invert"] is True + + cleared, = node.process(field, pen_size=12, invert=False, mask_paths="[]") + assert np.count_nonzero(cleared) == 0 + + DrawMask._broadcast_overlay_fn = None diff --git a/tests/node_tests/test_mask_invert.py b/tests/node_tests/test_mask_invert.py new file mode 100644 index 0000000..de83e0f --- /dev/null +++ b/tests/node_tests/test_mask_invert.py @@ -0,0 +1,17 @@ +import numpy as np + + +def test_mask_invert(): + from backend.nodes.mask_invert import MaskInvert + node = MaskInvert() + + mask = np.zeros((64, 64), dtype=np.uint8) + mask[10:20, 10:20] = 255 + + inverted, = node.process(mask) + assert inverted.dtype == np.uint8 + assert np.all(inverted[10:20, 10:20] == 0) + assert np.all(inverted[0:10, 0:10] == 255) + + double, = node.process(inverted) + assert np.array_equal(double, mask) diff --git a/tests/node_tests/test_mask_morphology.py b/tests/node_tests/test_mask_morphology.py new file mode 100644 index 0000000..b7b8b5b --- /dev/null +++ b/tests/node_tests/test_mask_morphology.py @@ -0,0 +1,29 @@ +import numpy as np + + +def test_mask_morphology(): + from backend.nodes.mask_morphology import MaskMorphology + node = MaskMorphology() + + mask = np.zeros((64, 64), dtype=np.uint8) + mask[28:36, 28:36] = 255 + orig_count = np.count_nonzero(mask) + + dilated, = node.process(mask, operation="dilate", radius=1, shape="square") + assert dilated.dtype == np.uint8 + assert np.count_nonzero(dilated) > orig_count + + eroded, = node.process(mask, operation="erode", radius=1, shape="square") + assert np.count_nonzero(eroded) < orig_count + + opened, = node.process(mask, operation="open", radius=1, shape="square") + assert np.count_nonzero(opened) <= orig_count + + mask_hole = mask.copy() + mask_hole[32, 32] = 0 + assert np.count_nonzero(mask_hole) == orig_count - 1 + closed, = node.process(mask_hole, operation="close", radius=1, shape="square") + assert closed[32, 32] == 255 + + dilated_disk, = node.process(mask, operation="dilate", radius=2, shape="disk") + assert np.count_nonzero(dilated_disk) > orig_count diff --git a/tests/node_tests/test_mask_operations.py b/tests/node_tests/test_mask_operations.py new file mode 100644 index 0000000..623537b --- /dev/null +++ b/tests/node_tests/test_mask_operations.py @@ -0,0 +1,44 @@ +import numpy as np + + +def test_mask_operations(): + from backend.nodes.mask_operations import MaskOperations + node = MaskOperations() + + a = np.zeros((64, 64), dtype=np.uint8) + a[10:30, 10:30] = 255 + b = np.zeros((64, 64), dtype=np.uint8) + b[20:40, 20:40] = 255 + + result_and, = node.process(a, b, operation="and") + assert np.all(result_and[20:30, 20:30] == 255) + assert result_and[15, 15] == 0 + assert result_and[35, 35] == 0 + + result_or, = node.process(a, b, operation="or") + assert result_or[15, 15] == 255 + assert result_or[35, 35] == 255 + assert result_or[25, 25] == 255 + assert result_or[5, 5] == 0 + + result_xor, = node.process(a, b, operation="xor") + assert result_xor[15, 15] == 255 + assert result_xor[35, 35] == 255 + assert result_xor[25, 25] == 0 + + result_sub, = node.process(a, b, operation="a_minus_b") + assert result_sub[15, 15] == 255 + assert result_sub[25, 25] == 0 + assert result_sub[35, 35] == 0 + + result_nand, = node.process(a, b, operation="nand") + assert result_nand[15, 15] == 255 + assert result_nand[35, 35] == 255 + assert result_nand[25, 25] == 0 + assert result_nand[5, 5] == 255 + + result_xnor, = node.process(a, b, operation="xnor") + assert result_xnor[25, 25] == 255 + assert result_xnor[5, 5] == 255 + assert result_xnor[15, 15] == 0 + assert result_xnor[35, 35] == 0 diff --git a/tests/node_tests/test_mask_threshold.py b/tests/node_tests/test_mask_threshold.py new file mode 100644 index 0000000..fa6bdca --- /dev/null +++ b/tests/node_tests/test_mask_threshold.py @@ -0,0 +1,36 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_threshold_mask(): + from backend.nodes.mask_threshold import ThresholdMask + node = ThresholdMask() + + data = np.zeros((64, 64)) + data[:, 32:] = 1.0 + field = make_field(data=data) + + previews = [] + ThresholdMask._broadcast_fn = lambda nid, uri: previews.append(uri) + ThresholdMask._current_node_id = "test" + + mask, = node.process(field, method="absolute", threshold=0.5, direction="above") + assert mask.dtype == np.uint8 + assert mask.shape == (64, 64) + assert np.all(mask[:, :32] == 0) + assert np.all(mask[:, 32:] == 255) + + assert len(previews) == 1 + assert previews[0].startswith("data:image/png;base64,") + + mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below") + assert np.all(mask_below[:, :32] == 255) + assert np.all(mask_below[:, 32:] == 0) + + mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above") + assert np.all(mask_rel[:, 32:] == 255) + + mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above") + assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum() + + ThresholdMask._broadcast_fn = None diff --git a/tests/node_tests/test_number.py b/tests/node_tests/test_number.py new file mode 100644 index 0000000..e3b4f91 --- /dev/null +++ b/tests/node_tests/test_number.py @@ -0,0 +1,10 @@ +def test_number(): + from backend.nodes.number import Number + + node = Number() + + result = node.process(value=1.25) + assert result == (1.25,) + + result_neg = node.process(value=-3.5) + assert result_neg == (-3.5,) diff --git a/tests/node_tests/test_preview_image.py b/tests/node_tests/test_preview_image.py new file mode 100644 index 0000000..1baca09 --- /dev/null +++ b/tests/node_tests/test_preview_image.py @@ -0,0 +1,58 @@ +import numpy as np +from backend.data_types import DataField, ImageData +from backend.execution_context import active_node, execution_callbacks +from tests.node_tests._shared import make_field + + +def test_preview_image(): + from backend.nodes.preview_image import PreviewImage + + node = PreviewImage() + preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"] + assert preview_input[0] == "ANNOTATION_SOURCE" + assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] + + captured = [] + with execution_callbacks(preview=lambda nid, data_uri: captured.append(data_uri)), active_node("test"): + field = make_field() + node.preview(colormap="viridis", input=field) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") + + captured.clear() + field_with_overlay = field.replace(overlays=[{"kind": "annotation", "show_scale_bar": True, "show_color_map": False, "text_size": 14.0}]) + node.preview(colormap="viridis", input=field_with_overlay) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") + + captured.clear() + custom_colormap = { + "mode": "custom", + "stops": [ + {"position": 0.0, "color": "#000000"}, + {"position": 0.5, "color": "#ff0000"}, + {"position": 1.0, "color": "#ffffff"}, + ], + } + node.preview(colormap="auto", input=field, colormap_map=custom_colormap) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") + + captured.clear() + arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8) + node.preview(colormap="gray", input=arr) + assert len(captured) == 1 + + captured.clear() + node.preview(colormap="auto", input=field_with_overlay) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") + + captured.clear() + annotated_image = ImageData( + np.zeros((24, 24, 3), dtype=np.uint8), + metadata={"annotation_context": {"xreal": 1e-6, "si_unit_xy": "m"}}, + ) + node.preview(colormap="auto", input=annotated_image) + assert len(captured) == 1 + assert captured[0].startswith("data:image/png;base64,") diff --git a/tests/node_tests/test_print_table.py b/tests/node_tests/test_print_table.py new file mode 100644 index 0000000..e6b1f23 --- /dev/null +++ b/tests/node_tests/test_print_table.py @@ -0,0 +1,18 @@ +def test_print_table(): + from backend.nodes.print_table import PrintTable + node = PrintTable() + + table_spec = PrintTable.INPUT_TYPES()["required"]["table"] + assert table_spec[0] == "RECORD_TABLE" + assert table_spec[1]["accepted_types"] == ["DATA_TABLE"] + + captured = [] + PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows) + PrintTable._current_node_id = "test" + + table = [{"quantity": "test", "value": 42.0, "unit": "m"}] + node.print_table(table=table) + assert len(captured) == 1 + assert captured[0] == table + + PrintTable._broadcast_table_fn = None diff --git a/tests/node_tests/test_psdf.py b/tests/node_tests/test_psdf.py new file mode 100644 index 0000000..332e147 --- /dev/null +++ b/tests/node_tests/test_psdf.py @@ -0,0 +1,35 @@ +import numpy as np +from backend.data_types import DataField + + +def test_psdf_node(): + from backend.nodes.fft_2d import FFT2D + from backend.nodes.psdf import PSDF + + field = DataField( + data=np.random.default_rng(17).standard_normal((64, 64)), + xreal=2.0e-6, yreal=1.0e-6, si_unit_xy="m", si_unit_z="nm", + ) + + fft_node = FFT2D() + psdf_node = PSDF() + + fft_psdf = fft_node.process(field, windowing="hann", level="plane")[3] + psdf, = psdf_node.process(field, windowing="hann", level="plane") + assert np.allclose(psdf.data, fft_psdf.data) + assert psdf.data.shape == field.data.shape + assert psdf.domain == "frequency" + assert psdf.si_unit_xy == "1/m" + assert psdf.si_unit_z == "nm^2 m^2" + assert np.all(psdf.data >= 0.0) + + white = DataField( + data=np.random.default_rng(123).standard_normal((128, 128)), + xreal=1.0e-6, yreal=1.0e-6, si_unit_xy="m", si_unit_z="m", + ) + psdf_white, = psdf_node.process(white, windowing="none", level="none") + variance = float(np.var(white.data)) + dk_x = psdf_white.xreal / psdf_white.xres + dk_y = psdf_white.yreal / psdf_white.yres + integral = float(np.sum(psdf_white.data) * dk_x * dk_y) + assert 0.8 < integral / variance < 1.2 diff --git a/tests/node_tests/test_range_slider.py b/tests/node_tests/test_range_slider.py new file mode 100644 index 0000000..a9aaa96 --- /dev/null +++ b/tests/node_tests/test_range_slider.py @@ -0,0 +1,16 @@ +def test_range_slider(): + from backend.nodes.range_slider import RangeSlider + + node = RangeSlider() + + result = node.process(min_value=0.0, max_value=10.0, value=3.25) + assert result == (3.25,) + + result_high = node.process(min_value=0.0, max_value=10.0, value=12.0) + assert result_high == (10.0,) + + result_reversed = node.process(min_value=5.0, max_value=-1.0, value=4.0) + assert result_reversed == (4.0,) + + result_fixed = node.process(min_value=2.5, max_value=2.5, value=99.0) + assert result_fixed == (2.5,) diff --git a/tests/node_tests/test_rotate.py b/tests/node_tests/test_rotate.py new file mode 100644 index 0000000..871db6b --- /dev/null +++ b/tests/node_tests/test_rotate.py @@ -0,0 +1,60 @@ +import numpy as np +from backend.data_types import DataField + + +def test_rotate_field(): + from backend.nodes.rotate import RotateField + node = RotateField() + + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) + field = DataField(data=data, xreal=6.0, yreal=4.0, xoff=10.0, yoff=20.0, si_unit_xy="nm", si_unit_z="nm") + + rotated_90, = node.process(field, angle=90.0, interpolation="nearest", expand_canvas=True) + assert np.array_equal(rotated_90.data, np.rot90(data)) + assert rotated_90.data.shape == (3, 2) + assert rotated_90.xreal == 4.0 + assert rotated_90.yreal == 6.0 + assert rotated_90.xoff == 11.0 + assert rotated_90.yoff == 19.0 + assert rotated_90.si_unit_xy == field.si_unit_xy + assert rotated_90.si_unit_z == field.si_unit_z + assert rotated_90.overlays == [] + + rotated_180, = node.process(field, angle=180.0, interpolation="nearest", expand_canvas=False) + assert np.array_equal(rotated_180.data, np.rot90(data, 2)) + assert rotated_180.data.shape == data.shape + assert rotated_180.xreal == field.xreal + assert rotated_180.yreal == field.yreal + assert rotated_180.xoff == field.xoff + assert rotated_180.yoff == field.yoff + + rotated_45, = node.process(field, angle=45.0, interpolation="bilinear", expand_canvas=True) + expected_xreal = abs(field.xreal * np.cos(np.deg2rad(45.0))) + abs(field.yreal * np.sin(np.deg2rad(45.0))) + expected_yreal = abs(field.xreal * np.sin(np.deg2rad(45.0))) + abs(field.yreal * np.cos(np.deg2rad(45.0))) + assert rotated_45.data.shape[0] > field.data.shape[0] + assert rotated_45.data.shape[1] > field.data.shape[1] + assert np.isclose(rotated_45.xreal, expected_xreal) + assert np.isclose(rotated_45.yreal, expected_yreal) + assert np.isclose(rotated_45.xoff + rotated_45.xreal / 2.0, field.xoff + field.xreal / 2.0) + assert np.isclose(rotated_45.yoff + rotated_45.yreal / 2.0, field.yoff + field.yreal / 2.0) + + +def test_rotate_field_overlay_warning(): + from backend.nodes.rotate import RotateField + + node = RotateField() + warnings = [] + RotateField._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) + RotateField._current_node_id = "test" + + field = DataField( + data=np.arange(16, dtype=np.float64).reshape(4, 4), + overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}], + ) + + rotated, = node.process(field, angle=30.0, interpolation="bilinear", expand_canvas=True) + assert rotated.overlays == [] + assert len(warnings) == 1 + assert "clears annotation/markup overlays" in warnings[0] + + RotateField._broadcast_warning_fn = None diff --git a/tests/node_tests/test_save.py b/tests/node_tests/test_save.py new file mode 100644 index 0000000..241dd58 --- /dev/null +++ b/tests/node_tests/test_save.py @@ -0,0 +1,138 @@ +import json +import os +import tempfile +from pathlib import Path + +import numpy as np +import tifffile +from PIL import Image as PILImage + +from backend.data_types import DataField, ImageData, LineData, RecordTable, MeshModel, DataTable + + +def test_save_generic(): + from backend.nodes.save import Save + + node = Save() + value_spec = node.INPUT_TYPES()["required"]["value"] + assert value_spec[0] == "DATA_FIELD" + assert value_spec[1]["accepted_types"] == [ + "IMAGE", "ANNOTATION_SOURCE", "LINE", "RECORD_TABLE", "DATA_TABLE", "MESH_MODEL", "FLOAT", + ] + format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"] + assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"] + + with tempfile.TemporaryDirectory() as tmpdir: + node.save(filename="scalar", directory_path=tmpdir, format="TXT", value=3.5) + assert Path(tmpdir, "scalar.txt").read_text(encoding="utf-8").strip() == "3.5" + node.save(filename="scalar_json", directory_path=tmpdir, format="JSON", value=3.5) + assert json.loads(Path(tmpdir, "scalar_json.json").read_text(encoding="utf-8")) == {"value": 3.5} + + line = LineData(data=np.array([1.0, 2.0, 3.0]), x_axis=np.array([0.0, 0.5, 1.0]), x_unit="um", y_unit="nm") + node.save(filename="profile", directory_path=tmpdir, format="CSV", value=line) + csv_text = Path(tmpdir, "profile.csv").read_text(encoding="utf-8") + assert "x,y,x_unit,y_unit" in csv_text + assert "um" in csv_text and "nm" in csv_text + node.save(filename="profile_npz", directory_path=tmpdir, format="NPZ", value=line) + line_npz = np.load(Path(tmpdir, "profile_npz.npz")) + assert np.allclose(line_npz["x"], line.x_axis) + assert np.allclose(line_npz["y"], line.data) + node.save(filename="profile_json", directory_path=tmpdir, format="JSON", value=line) + line_json = json.loads(Path(tmpdir, "profile_json.json").read_text(encoding="utf-8")) + assert line_json["x_unit"] == "um" + assert line_json["y_unit"] == "nm" + assert line_json["x"] == [0.0, 0.5, 1.0] + assert line_json["y"] == [1.0, 2.0, 3.0] + + field = DataField( + data=np.array([[1.0, 2.0], [3.0, 4.5]], dtype=np.float64), + xreal=2e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m", colormap="viridis", + ) + node.save(filename="field_tiff", directory_path=tmpdir, format="TIFF", value=field) + field_tiff = tifffile.imread(Path(tmpdir, "field_tiff.tiff")) + assert field_tiff.shape == field.data.shape + assert field_tiff.dtype == np.float32 + assert np.allclose(field_tiff, field.data.astype(np.float32)) + + node.save(filename="field_png", directory_path=tmpdir, format="PNG", value=field) + field_png = np.asarray(PILImage.open(Path(tmpdir, "field_png.png"))) + assert field_png.shape == (2, 2, 3) + assert field_png.dtype == np.uint8 + + node.save(filename="field_npz", directory_path=tmpdir, format="NPZ", value=field) + field_npz = np.load(Path(tmpdir, "field_npz.npz")) + assert np.allclose(field_npz["field"], field.data) + + image = np.array([[[255, 0, 0], [0, 255, 0]], [[0, 0, 255], [255, 255, 0]]], dtype=np.uint8) + node.save(filename="image_png", directory_path=tmpdir, format="PNG", value=image) + image_png = np.asarray(PILImage.open(Path(tmpdir, "image_png.png"))) + assert image_png.shape == image.shape + assert np.array_equal(image_png, image) + + node.save(filename="image_tiff", directory_path=tmpdir, format="TIFF", value=image) + image_tiff = tifffile.imread(Path(tmpdir, "image_tiff.tiff")) + assert image_tiff.shape == image.shape + assert image_tiff.dtype == np.uint8 + assert np.array_equal(image_tiff, image) + + node.save(filename="image_npz", directory_path=tmpdir, format="NPZ", value=image) + image_npz = np.load(Path(tmpdir, "image_npz.npz")) + assert np.array_equal(image_npz["image"], image) + + annotation_image = ImageData(image, metadata={"annotation_context": {"si_unit_xy": "um", "si_unit_z": "nm"}}) + node.save(filename="annotation_png", directory_path=tmpdir, format="PNG", value=annotation_image) + assert np.array_equal(np.asarray(PILImage.open(Path(tmpdir, "annotation_png.png"))), image) + + node.save(filename="annotation_tiff", directory_path=tmpdir, format="TIFF", value=annotation_image) + assert np.array_equal(tifffile.imread(Path(tmpdir, "annotation_tiff.tiff")), image) + + node.save(filename="annotation_npz", directory_path=tmpdir, format="NPZ", value=annotation_image) + assert np.array_equal(np.load(Path(tmpdir, "annotation_npz.npz"))["image"], image) + + measure_table = RecordTable([ + {"quantity": "Rq", "value": 1.23, "unit": "nm"}, + {"quantity": "Ra", "value": 0.98, "unit": "nm"}, + ]) + node.save(filename="measurements_csv", directory_path=tmpdir, format="CSV", value=measure_table) + measure_csv = Path(tmpdir, "measurements_csv.csv").read_text(encoding="utf-8") + assert "quantity,value,unit" in measure_csv + assert "Rq,1.23,nm" in measure_csv + node.save(filename="measurements_json", directory_path=tmpdir, format="JSON", value=measure_table) + assert json.loads(Path(tmpdir, "measurements_json.json").read_text(encoding="utf-8")) == list(measure_table) + + record_table = DataTable([ + {"label": "particle-1", "height": 12.0, "area": 44.0}, + {"label": "particle-2", "height": 8.0, "area": 21.0}, + ]) + node.save(filename="records_csv", directory_path=tmpdir, format="CSV", value=record_table) + record_csv = Path(tmpdir, "records_csv.csv").read_text(encoding="utf-8") + assert "label,height,area" in record_csv + assert "particle-1,12.0,44.0" in record_csv + node.save(filename="records_json", directory_path=tmpdir, format="JSON", value=record_table) + assert json.loads(Path(tmpdir, "records_json.json").read_text(encoding="utf-8")) == list(record_table) + + mesh = MeshModel( + vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32), + faces=np.array([[0, 1, 2]], dtype=np.int32), + ) + node.save(filename="triangle", directory_path=tmpdir, format="OBJ", value=mesh) + obj_text = Path(tmpdir, "triangle.obj").read_text(encoding="utf-8") + assert "v 0.0 0.0 0.0" in obj_text + assert "f 1 2 3" in obj_text + + node.save(filename="triangle", directory_path=tmpdir, format="STL", value=mesh) + stl_text = Path(tmpdir, "triangle.stl").read_text(encoding="utf-8") + assert stl_text.startswith("solid argonode") + assert "facet normal" in stl_text + + try: + node.save(filename="triangle", directory_path=tmpdir, format="PNG", value=mesh) + assert False, "Mesh should only be saveable as OBJ or STL" + except ValueError: + pass + + try: + node.save(filename="field_bad", directory_path=tmpdir, format="CSV", value=field) + assert False, "DATA_FIELD should reject unsupported save formats" + except ValueError: + pass diff --git a/tests/node_tests/test_save_layers.py b/tests/node_tests/test_save_layers.py new file mode 100644 index 0000000..e5130b7 --- /dev/null +++ b/tests/node_tests/test_save_layers.py @@ -0,0 +1,77 @@ +import os +import tempfile + +import numpy as np +import tifffile +from PIL import Image + +from tests.node_tests._shared import make_field + + +def test_save_image(): + from backend.nodes.save_layers import SaveImage + + node = SaveImage() + input_types = SaveImage.INPUT_TYPES() + field_spec = input_types["optional"]["field_0"] + assert field_spec[0] == "DATA_FIELD" + assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"] + + field_a = make_field(data=np.random.default_rng(4).random((32, 32))) + field_b = make_field(data=np.random.default_rng(5).random((32, 32))) + annotated = np.zeros((24, 24, 3), dtype=np.uint8) + annotated[..., 0] = 255 + + with tempfile.TemporaryDirectory() as tmpdir: + tiff_path = os.path.join(tmpdir, "out.tiff") + node.save(filename=tiff_path, format="TIFF", field_0=field_a) + assert os.path.exists(tiff_path) + im = Image.open(tiff_path) + assert im.n_frames == 1 + assert np.array(im).shape == (32, 32) + + tiff_path2 = os.path.join(tmpdir, "multi.tiff") + node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b) + im2 = Image.open(tiff_path2) + assert im2.n_frames == 2 + + annotated_tiff = os.path.join(tmpdir, "annotated.tiff") + node.save(filename=annotated_tiff, format="TIFF", field_0=annotated, layer_name_0="annotated overview") + with tifffile.TiffFile(annotated_tiff) as tif: + assert len(tif.pages) == 1 + assert tif.pages[0].description == "annotated overview" + assert tif.pages[0].asarray().shape == annotated.shape + + npz_path = os.path.join(tmpdir, "out.npz") + node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=annotated, layer_name_0="height map", layer_name_1="annotated-overview") + assert os.path.exists(npz_path) + npz = np.load(npz_path) + assert len(npz.files) == 2 + assert np.allclose(npz["height_map"], field_a.data) + assert np.array_equal(npz["annotated_overview"], annotated) + + wrong_ext = os.path.join(tmpdir, "output.png") + node.save(filename=wrong_ext, format="TIFF", field_0=field_a) + assert os.path.exists(os.path.join(tmpdir, "output.tiff")) + + driven_dir = os.path.join(tmpdir, "nested-output") + node.save(filename="driven_name", directory=driven_dir, format="NPZ", field_0=field_a) + assert os.path.exists(os.path.join(driven_dir, "driven_name.npz")) + + try: + node.save(filename="bad", directory=os.path.join(tmpdir, "looks_like_file.txt"), format="TIFF", field_0=field_a) + assert False, "Should have raised ValueError for file-like directory path" + except ValueError: + pass + + try: + node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF") + assert False, "Should have raised ValueError" + except ValueError: + pass + + try: + node.save(filename="", format="TIFF", field_0=field_a) + assert False, "Should have raised ValueError" + except ValueError: + pass diff --git a/tests/node_tests/test_scar_removal.py b/tests/node_tests/test_scar_removal.py new file mode 100644 index 0000000..be85bd5 --- /dev/null +++ b/tests/node_tests/test_scar_removal.py @@ -0,0 +1,48 @@ +import numpy as np +from backend.node_registry import get_node_info +from tests.node_tests._shared import make_field + + +def test_scar_removal(): + from backend.nodes.scar_removal import ScarRemoval + + node = ScarRemoval() + info = get_node_info("ScarRemoval") + assert info["category"] == "Filter" + assert {entry["category"] for entry in info["menu_categories"]} == {"Filter", "Level & Correct"} + + rows = 96 + cols = 128 + yy, xx = np.mgrid[0:rows, 0:cols] + base = ( + 0.005 * xx + 0.01 * yy + + 0.12 * np.sin(2.0 * np.pi * xx / cols) + + 0.07 * np.cos(2.0 * np.pi * yy / rows) + ) + scarred = base.copy() + scarred[24, 20:92] += 1.8 + scarred[25, 20:92] += 1.6 + scarred[60, 12:116] -= 1.7 + + field = make_field(data=scarred) + corrected, scar_mask = node.process( + field, scar_type="both", threshold_high=0.6, threshold_low=0.2, min_length=12, max_width=4, + ) + + mask_bool = scar_mask > 127 + assert scar_mask.dtype == np.uint8 + assert scar_mask.shape == field.data.shape + assert np.count_nonzero(mask_bool) > 0 + assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0 + assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0 + assert np.allclose(corrected.data[~mask_bool], field.data[~mask_bool]) + + before_rmse = np.sqrt(np.mean((field.data[mask_bool] - base[mask_bool]) ** 2)) + after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2)) + assert after_rmse < before_rmse * 0.35 + + clean_corrected, clean_mask = node.process( + make_field(data=base), scar_type="both", threshold_high=0.6, threshold_low=0.2, min_length=12, max_width=4, + ) + assert np.count_nonzero(clean_mask) == 0 + assert np.allclose(clean_corrected.data, base) diff --git a/tests/node_tests/test_statistics.py b/tests/node_tests/test_statistics.py new file mode 100644 index 0000000..f0ed4d5 --- /dev/null +++ b/tests/node_tests/test_statistics.py @@ -0,0 +1,28 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_statistics(): + from backend.nodes.statistics import Statistics + node = Statistics() + + data = np.array([[1, 2], [3, 4]], dtype=np.float64) + field = make_field(data=data) + + table, = node.process(field) + stats = {row["quantity"]: row["value"] for row in table} + + assert stats["min"] == 1.0 + assert stats["max"] == 4.0 + assert stats["mean"] == 2.5 + assert stats["median"] == 2.5 + assert stats["range"] == 3.0 + expected_rms = np.sqrt(np.mean((data - 2.5) ** 2)) + assert abs(stats["RMS"] - expected_rms) < 1e-10 + + const_field = make_field(data=np.ones((4, 4)) * 5.0) + table_const, = node.process(const_field) + const_stats = {row["quantity"]: row["value"] for row in table_const} + assert const_stats["RMS"] == 0.0 + assert const_stats["skewness"] == 0.0 + assert const_stats["kurtosis"] == 0.0 diff --git a/tests/node_tests/test_stats.py b/tests/node_tests/test_stats.py new file mode 100644 index 0000000..f32a291 --- /dev/null +++ b/tests/node_tests/test_stats.py @@ -0,0 +1,68 @@ +import numpy as np +from backend.data_types import DataTable, RecordTable +from tests.node_tests._shared import make_field + + +def test_stats(): + from backend.nodes.stats import Stats + + node = Stats() + input_spec = Stats.INPUT_TYPES()["required"]["input"] + assert input_spec[0] == "DATA_FIELD" + assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "DATA_TABLE"] + + captured = [] + Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) + Stats._current_node_id = "test" + + line = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64) + result, = node.process(line, operation="mean", column="value") + assert np.isclose(result, 2.5) + assert captured[-1] == ("test", {"value": result}) + roughness, = node.process(line, operation="Rq", column="value") + assert np.isclose(roughness, np.sqrt(np.mean((line - line.mean()) ** 2))) + + table = DataTable([ + {"name": "a", "value": 3.0, "unit": "m", "other": 10.0}, + {"name": "b", "value": 7.0, "unit": "m", "other": 20.0}, + ]) + result, = node.process(table, operation="max", column="value") + assert result == 7.0 + assert captured[-1] == ("test", {"value": 7.0, "unit": "m"}) + count, = node.process(table, operation="count", column="other") + assert count == 2.0 + auto_column_range, = node.process(table, operation="range", column="") + assert auto_column_range == 4.0 + + field = make_field(data=np.array([[1.0, 5.0], [2.0, 4.0]], dtype=np.float64)) + result, = node.process(field, operation="range", column="value") + assert result == 4.0 + assert captured[-1] == ("test", {"value": 4.0, "unit": "m"}) + + image = np.array([[0, 10], [20, 30]], dtype=np.uint8) + result, = node.process(image, operation="avg", column="value") + assert np.isclose(result, 15.0) + assert captured[-1] == ("test", {"value": 15.0}) + + try: + node.process(table, operation="Rq", column="value") + raise AssertionError("Expected invalid TABLE operation to raise ValueError") + except ValueError: + pass + + try: + node.process([{"label": "only text"}], operation="max", column="label") + raise AssertionError("Expected non-numeric record-table input to raise ValueError") + except ValueError: + pass + + try: + node.process( + RecordTable([{"quantity": "min", "value": 1.0, "unit": "m"}]), + operation="max", column="value", + ) + raise AssertionError("Expected measurement table input to raise ValueError") + except ValueError: + pass + + Stats._broadcast_value_fn = None diff --git a/tests/node_tests/test_value_io.py b/tests/node_tests/test_value_io.py new file mode 100644 index 0000000..e7faefa --- /dev/null +++ b/tests/node_tests/test_value_io.py @@ -0,0 +1,28 @@ +from backend.data_types import RecordTable + + +def test_value_display(): + from backend.nodes.value_io import ValueIO + + node = ValueIO() + value_spec = ValueIO.INPUT_TYPES()["required"]["value"] + assert value_spec[0] == "FLOAT" + assert value_spec[1]["accepted_types"] == ["RECORD_TABLE"] + + captured = [] + ValueIO._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) + ValueIO._current_node_id = "test" + + result = node.display_value(3.25) + assert result == (3.25,) + assert captured == [("test", {"value": 3.25})] + + measurements = RecordTable([ + {"quantity": "delta X", "value": 1.7e-7, "unit": "m"}, + {"quantity": "delta Y", "value": 463, "unit": "count"}, + ]) + result = node.display_value(measurements, measurement="delta X") + assert result == (1.7e-7,) + assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"}) + + ValueIO._broadcast_value_fn = None diff --git a/tests/node_tests/test_view_3d.py b/tests/node_tests/test_view_3d.py new file mode 100644 index 0000000..49bdd08 --- /dev/null +++ b/tests/node_tests/test_view_3d.py @@ -0,0 +1,106 @@ +import base64 +import io + +import numpy as np +from PIL import Image +from backend.data_types import DataField, ImageData, MeshModel +from backend.execution_context import active_node, execution_callbacks +from tests.node_tests._shared import make_field + + +def test_view3d_normalizes_small_physical_extents_for_display(): + from backend.nodes.view_3d import View3D + + data = np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64) + field = DataField(data=data, xreal=1.0e-5, yreal=1.0e-5, si_unit_xy="m", si_unit_z="m") + + node = View3D() + mesh, _ = node.render(field, colormap="auto", z_scale=1.0, resolution=64, make_solid=False) + + vertices = np.asarray(mesh.vertices, dtype=np.float64) + spans = vertices.max(axis=0) - vertices.min(axis=0) + + assert np.isclose(spans[0], 1.0, atol=1e-6) + assert np.isclose(spans[2], 1.0, atol=1e-6) + assert spans[1] > 0.09 + + +def test_view3d(): + from backend.nodes.view_3d import View3D + + node = View3D() + field = make_field() + + captured = [] + mesh_callback = lambda nid, mesh: captured.append(mesh) + + preview_image = Image.new("RGB", (12, 10), (255, 0, 0)) + preview_buffer = io.BytesIO() + preview_image.save(preview_buffer, format="PNG") + viewport_snapshot = "data:image/png;base64," + base64.b64encode(preview_buffer.getvalue()).decode() + + with execution_callbacks(mesh=mesh_callback), active_node("test"): + result = node.render( + field, colormap="viridis", z_scale=2.0, resolution=64, make_solid=False, + camera_target_x=0.1, camera_target_y=-0.2, camera_target_z=0.3, + viewport_snapshot=viewport_snapshot, + ) + assert len(result) == 2 + assert isinstance(result[0], MeshModel) + assert isinstance(result[1], ImageData) + assert result[1].shape == (10, 12, 3) + assert np.all(result[1][0, 0] == np.array([255, 0, 0], dtype=np.uint8)) + assert result[1].metadata["annotation_context"]["si_unit_xy"] == field.si_unit_xy + assert result[1].metadata["viewport_camera"]["target_x"] == 0.1 + assert result[1].metadata["viewport_camera"]["target_y"] == -0.2 + assert result[1].metadata["viewport_camera"]["target_z"] == 0.3 + assert len(captured) == 1 + + mesh = captured[0] + assert "width" in mesh and "height" in mesh and "z_data" in mesh and "colors" in mesh + assert mesh["z_scale"] == 0.2 + assert mesh["width"] <= 64 + assert mesh["height"] <= 64 + assert mesh["camera_target_x"] == 0.1 + assert mesh["z_min"] < mesh["z_max"] + + z_bytes = base64.b64decode(mesh["z_data"]) + assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 + colors_bytes = base64.b64decode(mesh["colors"]) + assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 + + big_field = make_field(shape=(256, 256)) + captured.clear() + with execution_callbacks(mesh=mesh_callback), active_node("test"): + node.render(big_field, colormap="hot", z_scale=1.0, resolution=64, make_solid=False) + assert captured[0]["width"] <= 64 + assert captured[0]["height"] <= 64 + + mesh_field = make_field(data=np.zeros((64, 64), dtype=np.float64), xreal=2.0, yreal=3.0) + map_field = make_field(data=np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float64), (64, 1)), xreal=2.0, yreal=3.0) + captured.clear() + with execution_callbacks(mesh=mesh_callback), active_node("test"): + mapped_result = node.render(mesh_field, map_field=map_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) + mapped_mesh = captured[0] + assert mapped_mesh["x_range"] == [float(mesh_field.xoff), float(mesh_field.xoff + mesh_field.xreal)] + assert np.isclose(mapped_mesh["surface_extent_x"] / mapped_mesh["surface_extent_y"], mesh_field.xreal / mesh_field.yreal) + mapped_z = np.frombuffer(base64.b64decode(mapped_mesh["z_data"]), dtype=np.float32) + assert np.allclose(mapped_z, 0.0) + mapped_colors = np.frombuffer(base64.b64decode(mapped_mesh["colors"]), dtype=np.uint8) + + captured.clear() + with execution_callbacks(mesh=mesh_callback), active_node("test"): + node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) + mesh_only_colors = np.frombuffer(base64.b64decode(captured[0]["colors"]), dtype=np.uint8) + assert not np.array_equal(mapped_colors, mesh_only_colors) + + captured.clear() + with execution_callbacks(mesh=mesh_callback), active_node("test"): + solid_result = node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=16, make_solid=True) + assert len(solid_result[0].vertices) > 16 * 16 + assert len(solid_result[0].faces) > (15 * 15 * 2) + solid_payload = captured[0] + assert solid_payload["make_solid"] is True + assert "positions" in solid_payload + assert "indices" in solid_payload + assert "vertex_colors" in solid_payload diff --git a/tests/node_tests/test_watershed_segmentation.py b/tests/node_tests/test_watershed_segmentation.py new file mode 100644 index 0000000..9b55dc0 --- /dev/null +++ b/tests/node_tests/test_watershed_segmentation.py @@ -0,0 +1,53 @@ +import numpy as np +from backend.execution_context import active_node, execution_callbacks +from tests.node_tests._shared import make_field + + +def test_watershed_segmentation(): + from scipy.ndimage import label + from backend.nodes.watershed_segmentation import WatershedSegmentation + + node = WatershedSegmentation() + y, x = np.mgrid[-1:1:64j, -1:1:64j] + data = ( + 2.0 * np.exp(-((x + 0.45) ** 2 + y**2) / 0.05) + + 2.0 * np.exp(-((x - 0.45) ** 2 + y**2) / 0.05) + - 0.3 * np.exp(-(x**2 + y**2) / 0.12) + ) + field = make_field(data=data, xreal=2.0e-6, yreal=2.0e-6) + + previews = [] + with execution_callbacks(preview=lambda nid, uri: previews.append(uri)), active_node("test"): + mask, = node.process( + field, + invert_height=False, + locate_steps=10, + locate_threshold=8, + locate_drop_size=0.1, + watershed_steps=20, + watershed_drop_size=0.1, + combine_mode="replace", + ) + assert mask.dtype == np.uint8 + assert mask.shape == field.data.shape + assert len(previews) == 1 + assert previews[0].startswith("data:image/png;base64,") + + _, ngrains = label(mask > 127) + assert ngrains >= 2 + + seed_mask = np.zeros_like(mask) + seed_mask[:, :32] = 255 + intersected, = node.process( + field, + invert_height=False, + locate_steps=10, + locate_threshold=8, + locate_drop_size=0.1, + watershed_steps=20, + watershed_drop_size=0.1, + combine_mode="intersection", + mask=seed_mask, + ) + assert np.count_nonzero(intersected) < np.count_nonzero(mask) + assert np.all(intersected[:, 40:] == 0)