diff --git a/backend/execution.py b/backend/execution.py index ba79fca..5be496c 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -31,7 +31,7 @@ from threading import RLock from time import perf_counter from typing import Any, Callable -from backend.node_registry import NODE_CLASS_MAPPINGS, get_node_output_types +from backend.node_registry import NODE_CLASS_MAPPINGS, get_node_output_types, get_node_output_accepted_types from backend.execution_context import active_node, execution_callbacks @@ -462,16 +462,19 @@ class ExecutionEngine: return return_types = get_node_output_types(cls) + output_accepted = get_node_output_accepted_types(cls) for slot, type_name in enumerate(return_types): if slot >= len(result): break value = result[slot] + all_types = {type_name} | set(output_accepted[slot] if slot < len(output_accepted) else []) - if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview: + # For polymorphic outputs, check the actual runtime type first. + if isinstance(value, DataField) and ("DATA_FIELD" in all_types) and on_preview: arr = render_datafield_preview(value, value.colormap) on_preview(node_id, encode_preview(arr)) - return # one preview per node is enough + return if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview: arr = image_to_uint8(value) @@ -488,7 +491,7 @@ class ExecutionEngine: on_preview(node_id, encode_preview(arr)) return - if type_name == "LINE" and isinstance(value, (np.ndarray, LineData)) and on_preview: + if "LINE" in all_types and isinstance(value, (np.ndarray, LineData)) and on_preview: preview = self._render_line_preview(cls, slot, result) if preview: on_preview(node_id, preview) diff --git a/backend/node_registry.py b/backend/node_registry.py index 41841af..c32e6a4 100644 --- a/backend/node_registry.py +++ b/backend/node_registry.py @@ -15,28 +15,38 @@ NODE_CLASS_MAPPINGS: dict[str, type] = {} NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {} -def get_node_output_specs(cls: type) -> tuple[tuple[str, str], ...]: +def get_node_output_specs(cls: type) -> tuple[tuple[str, str, dict], ...]: raw_outputs = getattr(cls, "OUTPUTS", None) if raw_outputs is None: raise AttributeError(f"{cls.__name__} must define OUTPUTS.") - specs: list[tuple[str, str]] = [] + specs: list[tuple[str, str, dict]] = [] for index, output in enumerate(raw_outputs): - if not isinstance(output, (list, tuple)) or len(output) != 2: + if not isinstance(output, (list, tuple)) or len(output) not in (2, 3): raise TypeError( - f"{cls.__name__}.OUTPUTS[{index}] must be a 2-item tuple of (type, name)." + f"{cls.__name__}.OUTPUTS[{index}] must be a 2- or 3-item tuple of (type, name[, meta])." ) - type_name, name = output - specs.append((str(type_name), str(name))) + type_name = output[0] + name = output[1] + meta: dict = output[2] if len(output) == 3 else {} + specs.append((str(type_name), str(name), meta)) return tuple(specs) def get_node_output_types(cls: type) -> tuple[str, ...]: - return tuple(type_name for type_name, _ in get_node_output_specs(cls)) + return tuple(type_name for type_name, _, _meta in get_node_output_specs(cls)) def get_node_output_names(cls: type) -> tuple[str, ...]: - return tuple(name for _, name in get_node_output_specs(cls)) + return tuple(name for _, name, _meta in get_node_output_specs(cls)) + + +def get_node_output_accepted_types(cls: type) -> tuple[list[str], ...]: + """Return per-slot accepted_types lists (empty list means only the declared type).""" + return tuple( + list(meta.get("accepted_types", [])) + for _, _, meta in get_node_output_specs(cls) + ) def register_node(display_name: str | None = None): @@ -77,6 +87,7 @@ def get_node_info(class_name: str) -> dict[str, Any]: "input_order": {k: list(v.keys()) for k, v in input_types.items()}, "output": list(get_node_output_types(cls)), "output_name": list(get_node_output_names(cls)), + "output_accepted_types": list(get_node_output_accepted_types(cls)), "output_node": bool(getattr(cls, "OUTPUT_NODE", False)), "manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)), "description": getattr(cls, "DESCRIPTION", ""), diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index b6d7f1a..4660955 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -4,8 +4,7 @@ from backend.nodes import ( colormap, crop_resize, fft_2d_inverse, - filter_fft_1d, - filter_fft_2d, + filter_fft, filter_gaussian, filter_median, flip, diff --git a/backend/nodes/filter_fft.py b/backend/nodes/filter_fft.py new file mode 100644 index 0000000..3e81f3e --- /dev/null +++ b/backend/nodes/filter_fft.py @@ -0,0 +1,89 @@ +from __future__ import annotations +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, LineData +from backend.nodes.helpers import _cached_1d_transfer, _cached_2d_transfer + + +@register_node(display_name="FFT Filter") +class FFTFilter: + """Frequency-domain filtering of a line profile or 2-D data field. + + Accepts either a LINE or DATA_FIELD and returns a filtered output of the + same type. Uses a Butterworth transfer function with configurable order + for a smooth roll-off. Equivalent to Gwyddion fft_filter_1d / fft_filter_2d. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": ("LINE", { + "label": "input", + "accepted_types": ["DATA_FIELD"], + }), + "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), + "cutoff": ("FLOAT", { + "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "cutoff_high": ("FLOAT", { + "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, + }), + "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), + } + } + + OUTPUTS = ( + ('LINE', 'filtered', {"accepted_types": ["DATA_FIELD"]}), + ) + FUNCTION = "process" + + DESCRIPTION = ( + "Frequency-domain filtering of a line profile or 2-D data field. " + "Connect a LINE for 1-D filtering or a DATA_FIELD for 2-D filtering — " + "the output mirrors the input type. " + "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " + "with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency." + ) + + def process(self, input, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + if isinstance(input, DataField): + return self._process_field(input, filter_type, float(cutoff), float(cutoff_high), int(order)) + return self._process_line(input, filter_type, float(cutoff), float(cutoff_high), int(order)) + + def _process_line(self, line, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + z = np.asarray(line, dtype=np.float64).ravel() + n = len(z) + + Z = np.fft.rfft(z) + H = _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order) + Z *= H + filtered = np.fft.irfft(Z, n=n) + + if isinstance(line, LineData): + return ( + LineData( + data=filtered, + x_axis=line.x_axis.copy() if line.x_axis is not None else None, + x_unit=line.x_unit, + y_unit=line.y_unit, + ), + ) + return (filtered,) + + def _process_field(self, field: DataField, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> tuple: + data = field.data + yres, xres = data.shape + + mean_val = float(data.mean()) + centered = data - mean_val + + spectrum = np.fft.rfft2(centered) + transfer = _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order) + result = np.fft.irfft2(spectrum * transfer, s=(yres, xres)) + result += mean_val + + return (field.replace(data=result),) diff --git a/backend/nodes/filter_fft_1d.py b/backend/nodes/filter_fft_1d.py deleted file mode 100644 index 1ed0754..0000000 --- a/backend/nodes/filter_fft_1d.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations -import numpy as np -from backend.node_registry import register_node -from backend.data_types import LineData -from backend.nodes.helpers import _cached_1d_transfer - - -@register_node(display_name="FFT Filter 1D") -class FFTFilter1D: - """Bandpass / lowpass / highpass / notch filtering of 1-D line profiles. - - Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth - transfer function with configurable order for a smooth roll-off. - """ - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "line": ("LINE",), - "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), - "cutoff": ("FLOAT", { - "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "cutoff_high": ("FLOAT", { - "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), - } - } - - OUTPUTS = ( - ('LINE', 'filtered'), - ) - FUNCTION = "process" - - DESCRIPTION = ( - "Frequency-domain filtering of a 1-D line profile. " - "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " - "with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. " - "Equivalent to Gwyddion fft_filter_1d." - ) - - def process(self, line, filter_type: str, cutoff: float, - cutoff_high: float, order: int) -> tuple: - z = np.asarray(line, dtype=np.float64).ravel() - n = len(z) - - Z = np.fft.rfft(z) - H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order)) - Z *= H - filtered = np.fft.irfft(Z, n=n) - - if isinstance(line, LineData): - return ( - LineData( - data=filtered, - x_axis=line.x_axis.copy() if line.x_axis is not None else None, - x_unit=line.x_unit, - y_unit=line.y_unit, - ), - ) - return (filtered,) diff --git a/backend/nodes/filter_fft_2d.py b/backend/nodes/filter_fft_2d.py deleted file mode 100644 index 336b2d2..0000000 --- a/backend/nodes/filter_fft_2d.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations -import numpy as np -from backend.node_registry import register_node -from backend.data_types import DataField -from backend.nodes.helpers import _cached_2d_transfer - - -@register_node(display_name="FFT Filter 2D") -class FFTFilter2D: - """Frequency-domain filtering of 2-D data fields (images). - - Equivalent to Gwyddion's fft_filter_2d module. Applies a radial - Butterworth transfer function in the frequency domain to remove or - isolate periodic features. - """ - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "field": ("DATA_FIELD",), - "filter_type": (["lowpass", "highpass", "bandpass", "notch"],), - "cutoff": ("FLOAT", { - "default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "cutoff_high": ("FLOAT", { - "default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001, - }), - "order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}), - } - } - - OUTPUTS = ( - ('DATA_FIELD', 'filtered'), - ) - FUNCTION = "process" - - DESCRIPTION = ( - "Frequency-domain filtering of a 2-D data field. " - "Supports lowpass, highpass, bandpass, and notch (band-reject) modes " - "with a radial Butterworth roll-off. Cutoffs are fractions of the " - "Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or " - "bandpass/notch to isolate or remove periodic noise. " - "Equivalent to Gwyddion fft_filter_2d." - ) - - def process(self, field: DataField, filter_type: str, cutoff: float, - cutoff_high: float, order: int) -> tuple: - data = field.data - yres, xres = data.shape - - mean_val = float(data.mean()) - centered = data - mean_val - - spectrum = np.fft.rfft2(centered) - transfer = _cached_2d_transfer( - yres, xres, filter_type, - float(cutoff), float(cutoff_high), int(order), - ) - result = np.fft.irfft2(spectrum * transfer, s=(yres, xres)) - result += mean_val - - return (field.replace(data=result),) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 06a2f87..7d4b27f 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -454,10 +454,15 @@ function socketTypesCompatible(sourceType, targetSpecOrType) { return socketSpecAcceptsType(sourceType, targetSpecOrType); } -function outputTypeCanConnectToTarget(outputType, targetSpecOrType) { +function outputTypeCanConnectToTarget(outputType, targetSpecOrType, outputAcceptedTypes = []) { if (socketTypesCompatible(outputType, targetSpecOrType)) { return true; } + // Polymorphic output: the output socket declares it can also produce the target type + if (outputAcceptedTypes.length > 0) { + const targetType = Array.isArray(targetSpecOrType) ? targetSpecOrType[0] : targetSpecOrType; + if (outputAcceptedTypes.includes(targetType)) return true; + } return outputType === 'ANNOTATION_SOURCE' && !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType) && ( @@ -674,8 +679,8 @@ function ContextMenu({ }); if (!hasMatch) continue; } else { - const hasMatch = def.output.some((type) => - outputTypeCanConnectToTarget(type, filterSpec || filterType) + const hasMatch = def.output.some((type, idx) => + outputTypeCanConnectToTarget(type, filterSpec || filterType, def.output_accepted_types?.[idx] || []) ); if (!hasMatch) continue; } @@ -1392,7 +1397,16 @@ function Flow() { const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle); const targetNode = reactFlow.getNode(resolvedTarget.nodeId); const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type; - return socketTypesCompatible(srcType, targetSpec); + if (socketTypesCompatible(srcType, targetSpec)) return true; + // Polymorphic output: check if the source output declares it can produce the target type + const srcProxy = parseGroupProxyHandle(connection.sourceHandle); + const srcNodeId = srcProxy ? srcProxy.nodeId : connection.source; + const srcHandleId = srcProxy ? srcProxy.realHandle : connection.sourceHandle; + const srcNode = reactFlow.getNode(srcNodeId); + const srcSlot = getOutputSlot(srcHandleId); + const srcAcceptedTypes = srcNode?.data?.definition?.output_accepted_types?.[srcSlot] || []; + const targetType = Array.isArray(targetSpec) ? targetSpec[0] : targetSpec; + return Array.isArray(srcAcceptedTypes) && srcAcceptedTypes.includes(targetType); }, [reactFlow]); const onConnect = useCallback((params) => { @@ -1765,8 +1779,8 @@ function Flow() { } } else { // Dragged from an input → connect from the first matching output on the new node - const outputIdx = def.output.findIndex((type) => - outputTypeCanConnectToTarget(type, filterSpec) + const outputIdx = def.output.findIndex((type, idx) => + outputTypeCanConnectToTarget(type, filterSpec, def.output_accepted_types?.[idx] || []) ); if (outputIdx !== -1) { const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec); diff --git a/tests/node_tests/filter_fft.py b/tests/node_tests/filter_fft.py new file mode 100644 index 0000000..f617b5f --- /dev/null +++ b/tests/node_tests/filter_fft.py @@ -0,0 +1,63 @@ +import numpy as np +from tests.node_tests._shared import make_field + + +def test_fft_filter_line(): + from backend.nodes.filter_fft import FFTFilter + node = FFTFilter() + + n = 256 + t = np.arange(n, dtype=np.float64) / n + low = np.sin(2 * np.pi * 3 * t) + high = np.sin(2 * np.pi * 80 * t) + line = low + high + + filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) + assert len(filtered_lp) == n + corr_low = np.corrcoef(filtered_lp, low)[0, 1] + corr_high = np.corrcoef(filtered_lp, high)[0, 1] + assert corr_low > 0.95 + assert abs(corr_high) < 0.3 + + filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) + assert abs(np.corrcoef(filtered_hp, low)[0, 1]) < 0.3 + assert np.corrcoef(filtered_hp, high)[0, 1] > 0.95 + + filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4) + assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3 + assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9 + + filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4) + assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95 + assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3 + + +def test_fft_filter_field(): + from backend.nodes.filter_fft import FFTFilter + from backend.data_types import DataField + node = FFTFilter() + + N = 128 + y, x = np.mgrid[0:N, 0:N] / N + low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y) + high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y) + data = low_2d + high_2d + field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6) + + result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) + assert isinstance(result_lp, DataField) + assert result_lp.data.shape == (N, N) + assert result_lp.xreal == field.xreal + assert result_lp.si_unit_z == field.si_unit_z + corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1] + corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1] + assert corr_low > 0.9 + assert abs(corr_high) < 0.3 + + result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) + assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3 + assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9 + + const = make_field(data=np.ones((32, 32)) * 7.0) + result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2) + assert np.allclose(result_const.data, 7.0, atol=1e-10) diff --git a/tests/node_tests/filter_fft_1d.py b/tests/node_tests/filter_fft_1d.py deleted file mode 100644 index 8b67a2c..0000000 --- a/tests/node_tests/filter_fft_1d.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np - - -def test_fft_filter_1d(): - from backend.nodes.filter_fft_1d import FFTFilter1D - node = FFTFilter1D() - - n = 256 - t = np.arange(n, dtype=np.float64) / n - low = np.sin(2 * np.pi * 3 * t) - high = np.sin(2 * np.pi * 80 * t) - line = low + high - - filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) - assert len(filtered_lp) == n - corr_low = np.corrcoef(filtered_lp, low)[0, 1] - corr_high = np.corrcoef(filtered_lp, high)[0, 1] - assert corr_low > 0.95 - assert abs(corr_high) < 0.3 - - filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) - corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1] - corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1] - assert abs(corr_low_hp) < 0.3 - assert corr_high_hp > 0.95 - - filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4) - assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3 - assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9 - - filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4) - assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95 - assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3 diff --git a/tests/node_tests/filter_fft_2d.py b/tests/node_tests/filter_fft_2d.py deleted file mode 100644 index ada9e2d..0000000 --- a/tests/node_tests/filter_fft_2d.py +++ /dev/null @@ -1,31 +0,0 @@ -import numpy as np -from tests.node_tests._shared import make_field - - -def test_fft_filter_2d(): - from backend.nodes.filter_fft_2d import FFTFilter2D - node = FFTFilter2D() - - N = 128 - y, x = np.mgrid[0:N, 0:N] / N - low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y) - high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y) - data = low_2d + high_2d - field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6) - - result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) - assert result_lp.data.shape == (N, N) - assert result_lp.xreal == field.xreal - assert result_lp.si_unit_z == field.si_unit_z - corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1] - corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1] - assert corr_low > 0.9 - assert abs(corr_high) < 0.3 - - result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) - assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3 - assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9 - - const = make_field(data=np.ones((32, 32)) * 7.0) - result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2) - assert np.allclose(result_const.data, 7.0, atol=1e-10) diff --git a/tests/test_grains.py b/tests/test_grains.py index 39c3c75..96ef98d 100644 --- a/tests/test_grains.py +++ b/tests/test_grains.py @@ -36,7 +36,7 @@ def test_threshold_otsu_bimodal(): data[70:100, 80:110] = 10.0 # another bright region field = make_field(data) - mask, = node.process(field, method="otsu", threshold=0.0, direction="above") + mask, table = node.process(field, method="otsu", threshold=0.0, direction="above") bright_pixels = (mask == 255) # Should capture both bright regions assert bright_pixels[40, 40], "Otsu missed bright region 1" @@ -57,7 +57,7 @@ def test_threshold_relative_range(): data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5 field = make_field(data) - mask, = node.process(field, method="relative", threshold=0.5, direction="above") + mask, table = node.process(field, method="relative", threshold=0.5, direction="above") # Only the bright patch (value 8 >= 5) should be masked assert np.all(mask[10:20, 10:20] == 255) assert np.all(mask[0:10, :] == 0) @@ -74,7 +74,7 @@ def test_threshold_empty_mask(): data = np.ones((64, 64)) field = make_field(data) - mask, = node.process(field, method="absolute", threshold=999.0, direction="above") + mask, table = node.process(field, method="absolute", threshold=999.0, direction="above") assert mask.sum() == 0, "Mask should be completely empty" print(" PASS\n") @@ -88,7 +88,7 @@ def test_threshold_full_mask(): data = np.ones((64, 64)) * 5.0 field = make_field(data) - mask, = node.process(field, method="absolute", threshold=-1.0, direction="above") + mask, table = node.process(field, method="absolute", threshold=-1.0, direction="above") assert np.all(mask == 255), "Mask should be all white" print(" PASS\n") @@ -345,7 +345,7 @@ def test_pipeline_synthetic(): # Step 1: threshold thresh = ThresholdMask() - mask, = thresh.process(field, method="absolute", threshold=1.0, direction="above") + mask, table = thresh.process(field, method="absolute", threshold=1.0, direction="above") # Grains are well above noise, so mask should capture all 5 assert mask.max() == 255, "No grains detected" @@ -387,7 +387,7 @@ def test_pipeline_demo_image(): # Threshold to find grains (they are raised above background) thresh = ThresholdMask() - mask, = thresh.process(field, method="otsu", threshold=0.0, direction="above") + mask, table = thresh.process(field, method="otsu", threshold=0.0, direction="above") # Should detect grains assert mask.max() == 255, "No grains found in demo image" diff --git a/tests/test_nodes.py b/tests/test_nodes.py deleted file mode 100644 index 26731f7..0000000 --- a/tests/test_nodes.py +++ /dev/null @@ -1,3377 +0,0 @@ -""" -Tests for all argonode backend nodes (excluding FFT2D which has its own test file). - -Run from project root: - .venv/bin/python -m tests.test_nodes -""" -import json -import sys -import os -import tempfile -from pathlib import Path -import numpy as np - -sys.path.insert(0, ".") -from backend.data_types import DataField, LineData, RecordTable, DataTable, datafield_to_uint8, render_datafield_preview - - -def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): - """Create a DataField, optionally from given data or a random field.""" - if data is None: - data = np.random.default_rng(42).standard_normal(shape) - return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m") - - -# ========================================================================= -# Filters -# ========================================================================= - -def test_gaussian_filter(): - print("=== Test: GaussianFilter ===") - from backend.nodes.filter_gaussian import GaussianFilter - node = GaussianFilter() - field = make_field() - - result, = node.process(field, sigma=2.0) - assert result.data.shape == field.data.shape - assert result.xreal == field.xreal - assert result.si_unit_z == field.si_unit_z - # Gaussian blur should reduce variance - assert result.data.std() < field.data.std() - # With very small sigma, output should be nearly unchanged - result_tiny, = node.process(field, sigma=0.01) - assert np.allclose(result_tiny.data, field.data, atol=1e-6) - print(" PASS\n") - - -def test_median_filter(): - print("=== Test: MedianFilter ===") - from backend.nodes.filter_median import MedianFilter - node = MedianFilter() - - # Median filter should remove salt-and-pepper noise - data = np.zeros((64, 64)) - rng = np.random.default_rng(7) - noise_idx = rng.choice(64 * 64, size=100, replace=False) - data.ravel()[noise_idx] = 1.0 - field = make_field(data=data) - - result, = node.process(field, size=3) - assert result.data.shape == field.data.shape - # Should remove most impulse noise - assert result.data.sum() < field.data.sum() - # Size=1 should be identity - result_1, = node.process(field, size=1) - assert np.array_equal(result_1.data, field.data) - print(" PASS\n") - - -def test_crop_resize_field(): - print("=== Test: CropResizeField ===") - from backend.nodes.crop_resize import CropResizeField - node = CropResizeField() - - data = np.arange(32, dtype=np.float64).reshape(4, 8) - field = DataField( - data=data, - xreal=8.0, - yreal=4.0, - xoff=10.0, - yoff=20.0, - si_unit_xy="nm", - si_unit_z="nm", - overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}], - ) - - overlays = [] - CropResizeField._broadcast_overlay_fn = lambda nid, data: overlays.append(data) - CropResizeField._current_node_id = "test" - - cropped, = node.process( - field, - x1=0.25, - y1=0.25, - x2=0.75, - y2=1.0, - target_width=0, - target_height=0, - interpolation="bilinear", - ) - assert cropped.data.shape == (3, 4) - assert np.array_equal(cropped.data, data[1:4, 2:6]) - assert cropped.xreal == 4.0 - assert cropped.yreal == 3.0 - assert cropped.xoff == 12.0 - assert cropped.yoff == 21.0 - assert cropped.si_unit_xy == field.si_unit_xy - assert cropped.si_unit_z == field.si_unit_z - assert cropped.overlays == [] - assert len(overlays) == 1 - assert overlays[0]["kind"] == "crop_box" - assert overlays[0]["image"].startswith("data:image/png;base64,") - assert overlays[0]["a_locked"] is False - assert overlays[0]["b_locked"] is False - - resized, = node.process( - field, - x1=0.0, - y1=0.0, - x2=1.0, - y2=1.0, - target_width=8, - target_height=0, - interpolation="bilinear", - corner_a=(0.25, 0.25), - corner_b=(0.75, 1.0), - ) - assert resized.data.shape == (6, 8) - assert resized.xreal == cropped.xreal - assert resized.yreal == cropped.yreal - assert resized.xoff == cropped.xoff - assert resized.yoff == cropped.yoff - assert resized.domain == field.domain - assert overlays[-1]["a_locked"] is True - assert overlays[-1]["b_locked"] is True - - reversed_crop, = node.process( - field, - x1=0.75, - y1=1.0, - x2=0.25, - y2=0.25, - target_width=0, - target_height=0, - interpolation="nearest", - ) - assert np.array_equal(reversed_crop.data, cropped.data) - - try: - node.process( - field, - x1=0.9, - y1=0.0, - x2=0.9, - y2=1.0, - target_width=0, - target_height=0, - interpolation="nearest", - ) - raise AssertionError("Expected invalid crop bounds to raise ValueError") - except ValueError: - pass - - CropResizeField._broadcast_overlay_fn = None - - print(" PASS\n") - - -def test_rotate_field(): - print("=== Test: RotateField ===") - from backend.nodes.rotate import RotateField - node = RotateField() - - data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) - field = DataField( - data=data, - xreal=6.0, - yreal=4.0, - xoff=10.0, - yoff=20.0, - si_unit_xy="nm", - si_unit_z="nm", - ) - - rotated_90, = node.process( - field, - angle=90.0, - interpolation="nearest", - expand_canvas=True, - ) - assert np.array_equal(rotated_90.data, np.rot90(data)) - assert rotated_90.data.shape == (3, 2) - assert rotated_90.xreal == 4.0 - assert rotated_90.yreal == 6.0 - assert rotated_90.xoff == 11.0 - assert rotated_90.yoff == 19.0 - assert rotated_90.si_unit_xy == field.si_unit_xy - assert rotated_90.si_unit_z == field.si_unit_z - assert rotated_90.overlays == [] - - rotated_180, = node.process( - field, - angle=180.0, - interpolation="nearest", - expand_canvas=False, - ) - assert np.array_equal(rotated_180.data, np.rot90(data, 2)) - assert rotated_180.data.shape == data.shape - assert rotated_180.xreal == field.xreal - assert rotated_180.yreal == field.yreal - assert rotated_180.xoff == field.xoff - assert rotated_180.yoff == field.yoff - - rotated_45, = node.process( - field, - angle=45.0, - interpolation="bilinear", - expand_canvas=True, - ) - expected_xreal = abs(field.xreal * np.cos(np.deg2rad(45.0))) + abs(field.yreal * np.sin(np.deg2rad(45.0))) - expected_yreal = abs(field.xreal * np.sin(np.deg2rad(45.0))) + abs(field.yreal * np.cos(np.deg2rad(45.0))) - assert rotated_45.data.shape[0] > field.data.shape[0] - assert rotated_45.data.shape[1] > field.data.shape[1] - assert np.isclose(rotated_45.xreal, expected_xreal) - assert np.isclose(rotated_45.yreal, expected_yreal) - assert np.isclose(rotated_45.xoff + rotated_45.xreal / 2.0, field.xoff + field.xreal / 2.0) - assert np.isclose(rotated_45.yoff + rotated_45.yreal / 2.0, field.yoff + field.yreal / 2.0) - - print(" PASS\n") - - -def test_rotate_field_overlay_warning(): - print("=== Test: RotateField overlay warning ===") - from backend.nodes.rotate import RotateField - - node = RotateField() - warnings = [] - RotateField._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) - RotateField._current_node_id = "test" - - field = DataField( - data=np.arange(16, dtype=np.float64).reshape(4, 4), - overlays=[{"kind": "markup", "shapes": [{"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 2, "color": "#ffffff"}]}], - ) - - rotated, = node.process( - field, - angle=30.0, - interpolation="bilinear", - expand_canvas=True, - ) - assert rotated.overlays == [] - assert len(warnings) == 1 - assert "clears annotation/markup overlays" in warnings[0] - - RotateField._broadcast_warning_fn = None - print(" PASS\n") - - -def test_flip_field(): - print("=== Test: FlipField ===") - from backend.nodes.flip import FlipField - from backend.node_registry import get_node_info - - node = FlipField() - data = np.arange(1, 10, dtype=np.float64).reshape(3, 3) - markup_overlay = { - "kind": "markup", - "shapes": [ - {"kind": "line", "x1": 0.1, "y1": 0.2, "x2": 0.9, "y2": 0.8, "width": 2, "color": "#ffffff"}, - {"kind": "rectangle", "x1": 0.15, "y1": 0.1, "x2": 0.45, "y2": 0.6, "width": 3, "color": "#ff0000"}, - ], - } - annotation_overlay = { - "kind": "annotation", - "show_scale_bar": True, - "show_color_map": False, - "text_size": 14.0, - } - field = DataField( - data=data, - xreal=3.0, - yreal=4.0, - xoff=10.0, - yoff=20.0, - si_unit_xy="nm", - si_unit_z="nm", - overlays=[markup_overlay, annotation_overlay], - ) - - assert get_node_info("FlipField")["category"] == "Geometry" - - flipped_x, = node.process(field, axis="x") - assert np.array_equal(flipped_x.data, np.flipud(data)) - assert flipped_x.xreal == field.xreal - assert flipped_x.yreal == field.yreal - assert flipped_x.xoff == field.xoff - assert flipped_x.yoff == field.yoff - assert flipped_x.si_unit_xy == field.si_unit_xy - assert flipped_x.si_unit_z == field.si_unit_z - assert np.isclose(flipped_x.overlays[0]["shapes"][0]["x1"], 0.1) - assert np.isclose(flipped_x.overlays[0]["shapes"][0]["y1"], 0.8) - assert np.isclose(flipped_x.overlays[0]["shapes"][0]["x2"], 0.9) - assert np.isclose(flipped_x.overlays[0]["shapes"][0]["y2"], 0.2) - assert np.isclose(flipped_x.overlays[0]["shapes"][1]["x1"], 0.15) - assert np.isclose(flipped_x.overlays[0]["shapes"][1]["y1"], 0.4) - assert np.isclose(flipped_x.overlays[0]["shapes"][1]["x2"], 0.45) - assert np.isclose(flipped_x.overlays[0]["shapes"][1]["y2"], 0.9) - assert flipped_x.overlays[1] == annotation_overlay - - flipped_y, = node.process(field, axis="y") - assert np.array_equal(flipped_y.data, np.fliplr(data)) - assert np.isclose(flipped_y.overlays[0]["shapes"][0]["x1"], 0.9) - assert np.isclose(flipped_y.overlays[0]["shapes"][0]["y1"], 0.2) - assert np.isclose(flipped_y.overlays[0]["shapes"][0]["x2"], 0.1) - assert np.isclose(flipped_y.overlays[0]["shapes"][0]["y2"], 0.8) - assert np.isclose(flipped_y.overlays[0]["shapes"][1]["x1"], 0.55) - assert np.isclose(flipped_y.overlays[0]["shapes"][1]["y1"], 0.1) - assert np.isclose(flipped_y.overlays[0]["shapes"][1]["x2"], 0.85) - assert np.isclose(flipped_y.overlays[0]["shapes"][1]["y2"], 0.6) - assert flipped_y.overlays[1] == annotation_overlay - - assert field.overlays[0]["shapes"][0]["x1"] == markup_overlay["shapes"][0]["x1"] - assert field.overlays[0]["shapes"][0]["y1"] == markup_overlay["shapes"][0]["y1"] - - try: - node.process(field, axis="diagonal") - raise AssertionError("Expected invalid flip axis to raise ValueError") - except ValueError: - pass - - print(" PASS\n") - - -def test_view3d_normalizes_small_physical_extents_for_display(): - print("=== Test: View3D extent normalization ===") - from backend.nodes.view_3d import View3D - - data = np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64) - field = DataField( - data=data, - xreal=1.0e-5, - yreal=1.0e-5, - si_unit_xy="m", - si_unit_z="m", - ) - - node = View3D() - mesh, _ = node.render(field, colormap="auto", z_scale=1.0, resolution=64, make_solid=False) - - vertices = np.asarray(mesh.vertices, dtype=np.float64) - spans = vertices.max(axis=0) - vertices.min(axis=0) - - assert np.isclose(spans[0], 1.0, atol=1e-6) - assert np.isclose(spans[2], 1.0, atol=1e-6) - assert spans[1] > 0.09 - print(" PASS\n") - - -def test_colormap_adjust(): - print("=== Test: ColormapAdjust ===") - from backend.nodes.colormap_adjust import ColormapAdjust - - node = ColormapAdjust() - field = DataField( - data=np.array([[0.0, 0.25, 0.5, 0.75, 1.0]], dtype=np.float64), - xreal=5.0, - yreal=1.0, - colormap="gray", - ) - - adjusted, = node.process(field, offset=0.25, scale=0.5) - assert np.array_equal(adjusted.data, field.data) - assert adjusted.display_offset == 0.25 - assert adjusted.display_scale == 0.5 - assert adjusted.colormap == field.colormap - - rgb = datafield_to_uint8(adjusted, "gray") - intensities = rgb[0, :, 0] - assert intensities[0] == 0 - assert intensities[1] == 0 - assert 110 <= intensities[2] <= 145 - assert intensities[3] == 255 - assert intensities[4] == 255 - - auto_like, = node.process(field, offset=0.0, scale=1.0) - auto_rgb = datafield_to_uint8(auto_like, "gray") - auto_intensities = auto_rgb[0, :, 0] - assert auto_intensities[0] == 0 - assert auto_intensities[-1] == 255 - - try: - node.process(field, offset=0.0, scale=0.0) - raise AssertionError("Expected non-positive scale to raise ValueError") - except ValueError: - pass - - print(" PASS\n") - - -def test_edge_detect(): - print("=== Test: EdgeDetect ===") - from backend.nodes.edge_detect import EdgeDetect - node = EdgeDetect() - - # Create an image with a sharp vertical edge - data = np.zeros((64, 64)) - data[:, 32:] = 1.0 - field = make_field(data=data) - - for method in ["sobel", "prewitt", "laplacian", "log"]: - result, = node.process(field, method=method, sigma=1.0) - assert result.data.shape == field.data.shape - # Edge response should be strongest near column 32 - col_energy = np.abs(result.data).sum(axis=0) - peak_col = np.argmax(col_energy) - assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32" - - print(" PASS\n") - - -def test_fft_filter_1d(): - print("=== Test: FFTFilter1D ===") - from backend.nodes.filter_fft_1d import FFTFilter1D - node = FFTFilter1D() - - # Signal: low-frequency sine + high-frequency sine - n = 256 - t = np.arange(n, dtype=np.float64) / n - low = np.sin(2 * np.pi * 3 * t) # 3 cycles — low freq - high = np.sin(2 * np.pi * 80 * t) # 80 cycles — high freq - line = low + high - - # Lowpass should keep low, suppress high - filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) - assert len(filtered_lp) == n - corr_low = np.corrcoef(filtered_lp, low)[0, 1] - corr_high = np.corrcoef(filtered_lp, high)[0, 1] - assert corr_low > 0.95, f"Lowpass: correlation with low={corr_low}" - assert abs(corr_high) < 0.3, f"Lowpass: correlation with high={corr_high}" - - # Highpass should keep high, suppress low - filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) - corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1] - corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1] - assert abs(corr_low_hp) < 0.3, f"Highpass: correlation with low={corr_low_hp}" - assert corr_high_hp > 0.95, f"Highpass: correlation with high={corr_high_hp}" - - # Bandpass centred on the high frequency - filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4) - corr_low_bp = np.corrcoef(filtered_bp, low)[0, 1] - corr_high_bp = np.corrcoef(filtered_bp, high)[0, 1] - assert abs(corr_low_bp) < 0.3, f"Bandpass: correlation with low={corr_low_bp}" - assert corr_high_bp > 0.9, f"Bandpass: correlation with high={corr_high_bp}" - - # Notch (band-reject) centred on the high frequency — should remove it - filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4) - corr_low_notch = np.corrcoef(filtered_notch, low)[0, 1] - corr_high_notch = np.corrcoef(filtered_notch, high)[0, 1] - assert corr_low_notch > 0.95, f"Notch: correlation with low={corr_low_notch}" - assert abs(corr_high_notch) < 0.3, f"Notch: correlation with high={corr_high_notch}" - - print(" PASS\n") - - -def test_fft_filter_2d(): - print("=== Test: FFTFilter2D ===") - from backend.nodes.filter_fft_2d import FFTFilter2D - node = FFTFilter2D() - - N = 128 - y, x = np.mgrid[0:N, 0:N] / N - # Low-frequency 2D pattern + high-frequency pattern - low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y) - high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y) - data = low_2d + high_2d - field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6) - - # Lowpass — should preserve low, remove high - result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4) - assert result_lp.data.shape == (N, N) - assert result_lp.xreal == field.xreal - assert result_lp.si_unit_z == field.si_unit_z - corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1] - corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1] - assert corr_low > 0.9, f"2D lowpass: correlation with low={corr_low}" - assert abs(corr_high) < 0.3, f"2D lowpass: correlation with high={corr_high}" - - # Highpass — should preserve high, remove low - result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4) - corr_low_hp = np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1] - corr_high_hp = np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] - assert abs(corr_low_hp) < 0.3, f"2D highpass: correlation with low={corr_low_hp}" - assert corr_high_hp > 0.9, f"2D highpass: correlation with high={corr_high_hp}" - - # Constant field should be unchanged by lowpass (DC preservation) - const = make_field(data=np.ones((32, 32)) * 7.0) - result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2) - assert np.allclose(result_const.data, 7.0, atol=1e-10), "Lowpass should preserve constant field" - - print(" PASS\n") - - -# ========================================================================= -# Level -# ========================================================================= - -def test_plane_level(): - print("=== Test: PlaneLevelField ===") - from backend.nodes.level_plane import PlaneLevelField - node = PlaneLevelField() - - # Create a tilted plane + small signal - N = 64 - y, x = np.mgrid[0:N, 0:N] / N - signal = np.sin(2 * np.pi * 5 * x) - data = 100 * x + 50 * y + signal - field = make_field(data=data) - - result, = node.process(field) - assert result.data.shape == field.data.shape - # After plane leveling, mean should be near zero - assert abs(result.data.mean()) < 1e-10 - # The signal should remain (correlation with original sine) - corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1] - assert corr > 0.98, f"Signal correlation after leveling: {corr}" - - yy_px, xx_px = np.mgrid[0:N, 0:N] - - def fit_pixel_plane(data_in: np.ndarray, region: np.ndarray) -> tuple[float, float, float]: - A = np.column_stack([ - np.ones(int(np.count_nonzero(region)), dtype=np.float64), - xx_px[region].astype(np.float64), - yy_px[region].astype(np.float64), - ]) - coeffs, _, _, _ = np.linalg.lstsq(A, data_in[region].ravel().astype(np.float64), rcond=None) - return float(coeffs[0]), float(coeffs[1]), float(coeffs[2]) - - mask = np.zeros((N, N), dtype=np.uint8) - mask[20:44, 22:46] = 255 - feature = np.zeros((N, N), dtype=np.float64) - feature[mask > 0] = 35.0 - masked_field = make_field(data=100 * x + 50 * y + feature) - - unmasked, = node.process(masked_field) - masked, = node.process(masked_field, masking="exclude", mask=mask) - - outside = mask == 0 - _, unmasked_bx, unmasked_by = fit_pixel_plane(unmasked.data, outside) - _, masked_bx, masked_by = fit_pixel_plane(masked.data, outside) - assert np.hypot(masked_bx, masked_by) < np.hypot(unmasked_bx, unmasked_by) * 1e-3 - print(" PASS\n") - - -def test_facet_level(): - print("=== Test: FacetLevelField ===") - from backend.node_registry import get_node_info - from backend.nodes.level_facet import FacetLevelField - from backend.nodes.level_plane import PlaneLevelField - - def fit_pixel_plane(data: np.ndarray, region: np.ndarray) -> tuple[float, float, float]: - yy, xx = np.mgrid[0:data.shape[0], 0:data.shape[1]] - A = np.column_stack([ - np.ones(int(np.count_nonzero(region)), dtype=np.float64), - xx[region].astype(np.float64), - yy[region].astype(np.float64), - ]) - coeffs, _, _, _ = np.linalg.lstsq(A, data[region].ravel().astype(np.float64), rcond=None) - return float(coeffs[0]), float(coeffs[1]), float(coeffs[2]) - - node = FacetLevelField() - plane_node = PlaneLevelField() - assert get_node_info("FacetLevelField")["category"] == "Level & Correct" - - N = 96 - yy, xx = np.mgrid[0:N, 0:N] - base = 0.055 * xx + 0.028 * yy - terraces = np.zeros((N, N), dtype=np.float64) - terraces[:, 54:] += 6.0 - terraces[18:70, 68:88] += 3.5 - field = make_field(data=base + terraces) - - plane_leveled, = plane_node.process(field) - facet_leveled, = node.process(field, masking="ignore") - - left_region = xx < 48 - right_region = (xx > 60) & ~((yy >= 18) & (yy < 70) & (xx >= 68) & (xx < 88)) - _, plane_left_bx, plane_left_by = fit_pixel_plane(plane_leveled.data, left_region) - _, plane_right_bx, plane_right_by = fit_pixel_plane(plane_leveled.data, right_region) - _, facet_left_bx, facet_left_by = fit_pixel_plane(facet_leveled.data, left_region) - _, facet_right_bx, facet_right_by = fit_pixel_plane(facet_leveled.data, right_region) - plane_slope = float(max(np.hypot(plane_left_bx, plane_left_by), np.hypot(plane_right_bx, plane_right_by))) - facet_slope = float(max(np.hypot(facet_left_bx, facet_left_by), np.hypot(facet_right_bx, facet_right_by))) - assert facet_slope < plane_slope * 1e-6 - - mask = np.zeros((N, N), dtype=np.uint8) - mask[24:72, 24:72] = 255 - base_only = 0.035 * xx + 0.014 * yy - masked_facet = 5.0 - 0.065 * xx + 0.045 * yy - competing = np.where(mask > 0, masked_facet, base_only) - competing_field = make_field(data=competing) - - excluded, = node.process(competing_field, masking="exclude", mask=mask) - included, = node.process(competing_field, masking="include", mask=mask) - - outer_region = (mask == 0) & (xx > 4) & (xx < N - 4) & (yy > 4) & (yy < N - 4) - inner_region = (mask > 0) & (xx > 28) & (xx < 68) & (yy > 28) & (yy < 68) - _, excl_outer_bx, excl_outer_by = fit_pixel_plane(excluded.data, outer_region) - _, excl_inner_bx, excl_inner_by = fit_pixel_plane(excluded.data, inner_region) - _, incl_outer_bx, incl_outer_by = fit_pixel_plane(included.data, outer_region) - _, incl_inner_bx, incl_inner_by = fit_pixel_plane(included.data, inner_region) - - excl_outer_slope = float(np.hypot(excl_outer_bx, excl_outer_by)) - excl_inner_slope = float(np.hypot(excl_inner_bx, excl_inner_by)) - incl_outer_slope = float(np.hypot(incl_outer_bx, incl_outer_by)) - incl_inner_slope = float(np.hypot(incl_inner_bx, incl_inner_by)) - assert excl_outer_slope < incl_outer_slope * 0.2 - assert incl_inner_slope < excl_inner_slope * 0.2 - - bad_units = DataField(data=np.zeros((16, 16), dtype=np.float64), xreal=1e-6, yreal=1e-6, si_unit_xy="nm", si_unit_z="V") - try: - node.process(bad_units, masking="ignore") - except ValueError as exc: - assert "compatible XY and Z units" in str(exc) - else: - assert False, "Facet level should reject incompatible XY/Z units." - print(" PASS\n") - - -def test_poly_level(): - print("=== Test: PolyLevelField ===") - from backend.nodes.level_poly import PolyLevelField - node = PolyLevelField() - - N = 64 - y, x = np.mgrid[0:N, 0:N] / N - # Quadratic background + signal - background = 50 * x**2 + 30 * y**2 + 10 * x * y - signal = np.sin(2 * np.pi * 8 * x) - data = background + signal - field = make_field(data=data) - - leveled, bg = node.process(field, degree_x=2, degree_y=2) - assert leveled.data.shape == field.data.shape - assert bg.data.shape == field.data.shape - # leveled + bg should reconstruct original - assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10) - # Signal should be preserved after leveling - corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1] - assert corr > 0.95, f"Signal correlation after poly leveling: {corr}" - # Degree 0 should just subtract the mean - leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0) - assert abs(leveled_0.data.mean()) < 1e-10 - print(" PASS\n") - - -def test_fix_zero(): - print("=== Test: FixZero ===") - from backend.nodes.fix_zero import FixZero - node = FixZero() - field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64)) - - result_min, = node.process(field, method="min") - assert result_min.data.min() == 0.0 - assert result_min.data.max() == 30.0 - - result_mean, = node.process(field, method="mean") - assert abs(result_mean.data.mean()) < 1e-10 - - result_median, = node.process(field, method="median") - assert abs(np.median(result_median.data)) < 1e-10 - print(" PASS\n") - - -def test_curvature(): - print("=== Test: Curvature ===") - from backend.node_registry import get_node_info - from backend.execution_context import active_node, execution_callbacks - from backend.nodes.curvature import Curvature - - node = Curvature() - assert get_node_info("Curvature")["category"] == "Measure" - - xres, yres = 121, 101 - xreal, yreal = 8.0e-6, 6.0e-6 - xoff, yoff = 1.0e-6, -0.5e-6 - x = np.linspace(xoff, xoff + xreal, xres, dtype=np.float64) - y = np.linspace(yoff, yoff + yreal, yres, dtype=np.float64) - yy, xx = np.meshgrid(y, x, indexing="ij") - - x0 = xoff + 0.45 * xreal - y0 = yoff + 0.60 * yreal - rx = 1.2e-6 - ry = 2.4e-6 - z0 = 3.0e-9 - data = z0 + (xx - x0) ** 2 / (2.0 * rx) + (yy - y0) ** 2 / (2.0 * ry) - field = DataField(data=data, xreal=xreal, yreal=yreal, xoff=xoff, yoff=yoff, si_unit_xy="m", si_unit_z="m") - - previews = [] - tables = [] - with execution_callbacks(preview=lambda nid, uri: previews.append(uri), table=lambda nid, rows: tables.append(rows)), active_node("test"): - output, table, profile1, profile2 = node.process(field, masking="ignore") - - rows = {row["quantity"]: row for row in table} - recovered_radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) - expected_radii = sorted([rx, ry]) - assert len(previews) == 1 - assert isinstance(previews[0], dict) and previews[0].get("kind") == "panels" - assert len(tables) == 1 - assert abs(rows["Center x position"]["value"] - x0) < xreal * 0.02 - assert abs(rows["Center y position"]["value"] - y0) < yreal * 0.02 - assert abs(rows["Center value"]["value"] - z0) < 5e-11 - assert np.allclose(recovered_radii, expected_radii, rtol=0.08, atol=5e-8) - assert output.overlays[-1]["kind"] == "markup" - assert len(output.overlays[-1]["shapes"]) == 3 - assert isinstance(profile1, LineData) - assert isinstance(profile2, LineData) - assert profile1.x_unit == field.si_unit_xy - assert profile1.y_unit == field.si_unit_z - assert profile2.x_unit == field.si_unit_xy - assert profile2.y_unit == field.si_unit_z - assert len(profile1) > 10 - assert len(profile2) > 10 - - mask = np.zeros((yres, xres), dtype=np.uint8) - mask[:, :xres // 2] = 255 - left = 1.0e-9 + (xx - (xoff + 0.25 * xreal)) ** 2 / (2.0 * 0.9e-6) + (yy - (yoff + 0.5 * yreal)) ** 2 / (2.0 * 1.8e-6) - right = 2.0e-9 + (xx - (xoff + 0.75 * xreal)) ** 2 / (2.0 * 1.6e-6) + (yy - (yoff + 0.5 * yreal)) ** 2 / (2.0 * 3.2e-6) - split_field = DataField(data=np.where(mask > 0, left, right), xreal=xreal, yreal=yreal, xoff=xoff, yoff=yoff, si_unit_xy="m", si_unit_z="m") - _, include_table, _, _ = node.process(split_field, masking="include", mask=mask) - _, exclude_table, _, _ = node.process(split_field, masking="exclude", mask=mask) - include_radii = sorted([row["value"] for row in include_table if row["quantity"].startswith("Curvature radius")]) - exclude_radii = sorted([row["value"] for row in exclude_table if row["quantity"].startswith("Curvature radius")]) - assert np.allclose(include_radii, [0.9e-6, 1.8e-6], rtol=0.12, atol=5e-8) - assert np.allclose(exclude_radii, [1.6e-6, 3.2e-6], rtol=0.12, atol=5e-8) - - bad_units = DataField(data=np.zeros((16, 16), dtype=np.float64), xreal=1e-6, yreal=1e-6, si_unit_xy="nm", si_unit_z="V") - try: - node.process(bad_units, masking="ignore") - except ValueError as exc: - assert "compatible XY and Z units" in str(exc) - else: - assert False, "Curvature should reject incompatible XY/Z units." - print(" PASS\n") - - -def test_curvature_flat_surface(): - """A perfectly flat surface has zero curvature — both radii must be float('inf').""" - print("=== Test: Curvature (flat surface → inf radii) ===") - from backend.execution_context import active_node, execution_callbacks - from backend.nodes.curvature import Curvature - - node = Curvature() - data = np.zeros((64, 64), dtype=np.float64) - field = DataField(data=data, xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m") - - warnings = [] - tables = [] - with execution_callbacks( - preview=lambda nid, v: None, - table=lambda nid, rows: tables.append(rows), - warning=lambda nid, msg: warnings.append(msg), - ), active_node("test"): - _, table, _, _ = node.process(field, masking="ignore") - - rows = {row["quantity"]: row for row in table} - assert rows["Curvature radius 1"]["value"] == float("inf") - assert rows["Curvature radius 2"]["value"] == float("inf") - # No warnings expected for a valid (flat) surface - assert len(warnings) == 0 - print(" PASS\n") - - -def test_curvature_cylindrical(): - """A cylindrical surface is curved in one direction only — one radius finite, one inf.""" - print("=== Test: Curvature (cylindrical → one inf radius) ===") - from backend.execution_context import active_node, execution_callbacks - from backend.nodes.curvature import Curvature - - node = Curvature() - N = 64 - xreal = yreal = 1e-6 - x = np.linspace(-xreal / 2, xreal / 2, N, dtype=np.float64) - xx = np.broadcast_to(x, (N, N)) - r_x = 0.8e-6 - # Curved parabolically in x, flat in y - data = xx**2 / (2.0 * r_x) - field = DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m") - - tables = [] - with execution_callbacks( - preview=lambda nid, v: None, - table=lambda nid, rows: tables.append(rows), - ), active_node("test"): - _, table, _, _ = node.process(field, masking="ignore") - - rows = {row["quantity"]: row for row in table} - radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) - # One radius should be finite (≈ r_x), the other infinite - finite = [r for r in radii if np.isfinite(r)] - infinite = [r for r in radii if not np.isfinite(r)] - assert len(finite) == 1, f"Expected 1 finite radius, got {radii}" - assert len(infinite) == 1, f"Expected 1 inf radius, got {radii}" - assert abs(finite[0] - r_x) < r_x * 0.1, f"Finite radius {finite[0]} far from expected {r_x}" - print(" PASS\n") - - -def test_curvature_too_few_pixels(): - """Curvature with fewer than 6 valid pixels emits a warning and returns an empty table.""" - print("=== Test: Curvature (too few valid pixels) ===") - from backend.execution_context import active_node, execution_callbacks - from backend.nodes.curvature import Curvature - - node = Curvature() - N = 16 - data = np.random.default_rng(0).standard_normal((N, N)) - field = DataField(data=data, xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m") - - # Mask with only 4 'include' pixels — below the 6-pixel minimum - mask = np.zeros((N, N), dtype=np.uint8) - mask[N // 2, N // 2:N // 2 + 4] = 255 - - warnings = [] - tables = [] - with execution_callbacks( - preview=lambda nid, v: None, - table=lambda nid, rows: tables.append(rows), - warning=lambda nid, msg: warnings.append(msg), - ), active_node("test"): - _, table, profile1, profile2 = node.process(field, masking="include", mask=mask) - - assert len(warnings) == 1 - assert "six" in warnings[0].lower() or "6" in warnings[0] - assert len(list(table)) == 0 - # Empty profiles are returned - assert len(profile1.data) == 0 - assert len(profile2.data) == 0 - print(" PASS\n") - - -def test_curvature_inf_json_safe(): - """inf radii from curvature must not produce invalid JSON when sent over the wire.""" - print("=== Test: Curvature (inf radii → valid JSON via server sanitizer) ===") - import json - from backend.server import _sanitize_non_finite, _dumps - - # Simulate a table row as produced by the curvature node for a flat surface - rows = [ - {"quantity": "Curvature radius 1", "value": float("inf"), "unit": "m"}, - {"quantity": "Curvature radius 2", "value": float("-inf"), "unit": "m"}, - {"quantity": "Center value", "value": float("nan"), "unit": "m"}, - {"quantity": "Center x position", "value": 1.5e-7, "unit": "m"}, - ] - - sanitized = _sanitize_non_finite(rows) - assert sanitized[0]["value"] == "∞" - assert sanitized[1]["value"] == "-∞" - assert sanitized[2]["value"] == "NaN" - assert sanitized[3]["value"] == 1.5e-7 # finite float unchanged - - # Must not raise and must produce parseable JSON - payload = _dumps({"type": "table", "data": {"node_id": "n1", "rows": sanitized}}) - decoded = json.loads(payload) - assert decoded["data"]["rows"][0]["value"] == "∞" - print(" PASS\n") - - -def test_line_correction(): - print("=== Test: LineCorrection ===") - from backend.node_registry import get_node_info - from backend.nodes.line_correction import LineCorrection - - node = LineCorrection() - assert get_node_info("LineCorrection")["category"] == "Level & Correct" - - rows = 96 - cols = 128 - y = np.linspace(0.0, 1.0, rows, dtype=np.float64) - x = np.linspace(-1.0, 1.0, cols, dtype=np.float64) - signal = ( - 0.15 * np.sin(8.0 * np.pi * x)[None, :] - + 0.05 * np.cos(4.0 * np.pi * y)[:, None] - ) - row_offsets = 1.5 * np.sin(3.0 * np.pi * y) + 0.25 * np.cos(7.0 * np.pi * y) - field = make_field( - data=signal + row_offsets[:, None], - xreal=2.5e-6, - yreal=1.5e-6, - ) - - corrected, background, shifts = node.process( - field, - method="median", - direction="horizontal", - masking="ignore", - trim_fraction=0.05, - polynomial_degree=1, - ) - expected_shifts = row_offsets - row_offsets.mean() - assert corrected.data.shape == field.data.shape - assert background.data.shape == field.data.shape - assert np.allclose(corrected.data + background.data, field.data) - assert isinstance(shifts, LineData) - assert shifts.x_unit == field.si_unit_xy - assert shifts.y_unit == field.si_unit_z - assert np.isclose(shifts.x_axis[0], 0.0) - assert np.isclose(shifts.x_axis[-1], field.yreal) - assert np.corrcoef(shifts.data, expected_shifts)[0, 1] > 0.999 - assert corrected.data.mean(axis=1).std() < field.data.mean(axis=1).std() * 0.03 - - poly_background = ( - row_offsets[:, None] - + (0.35 * y - 0.15)[:, None] * x[None, :] - + (0.10 + 0.05 * y)[:, None] * (x[None, :] ** 2) - ) - poly_signal = 0.08 * np.sin(10.0 * np.pi * x)[None, :] * (1.0 + 0.15 * np.cos(2.0 * np.pi * y)[:, None]) - poly_field = make_field(data=poly_signal + poly_background) - - leveled, poly_bg, poly_shifts = node.process( - poly_field, - method="polynomial", - direction="horizontal", - masking="ignore", - trim_fraction=0.05, - polynomial_degree=2, - ) - assert np.allclose(leveled.data + poly_bg.data, poly_field.data) - assert np.corrcoef(leveled.data.ravel(), poly_signal.ravel())[0, 1] > 0.995 - assert len(poly_shifts) == rows - - print(" PASS\n") - - -def test_scar_removal(): - print("=== Test: ScarRemoval ===") - from backend.node_registry import get_node_info - from backend.nodes.scar_removal import ScarRemoval - - node = ScarRemoval() - info = get_node_info("ScarRemoval") - assert info["category"] == "Filter" - assert {entry["category"] for entry in info["menu_categories"]} == {"Filter", "Level & Correct"} - - rows = 96 - cols = 128 - yy, xx = np.mgrid[0:rows, 0:cols] - base = ( - 0.005 * xx - + 0.01 * yy - + 0.12 * np.sin(2.0 * np.pi * xx / cols) - + 0.07 * np.cos(2.0 * np.pi * yy / rows) - ) - scarred = base.copy() - scarred[24, 20:92] += 1.8 - scarred[25, 20:92] += 1.6 - scarred[60, 12:116] -= 1.7 - - field = make_field(data=scarred) - corrected, scar_mask = node.process( - field, - scar_type="both", - threshold_high=0.6, - threshold_low=0.2, - min_length=12, - max_width=4, - ) - - mask_bool = scar_mask > 127 - assert scar_mask.dtype == np.uint8 - assert scar_mask.shape == field.data.shape - assert np.count_nonzero(mask_bool) > 0 - assert np.count_nonzero(mask_bool[24:26, 20:92]) > 0 - assert np.count_nonzero(mask_bool[60:61, 12:116]) > 0 - assert np.allclose(corrected.data[~mask_bool], field.data[~mask_bool]) - - before_rmse = np.sqrt(np.mean((field.data[mask_bool] - base[mask_bool]) ** 2)) - after_rmse = np.sqrt(np.mean((corrected.data[mask_bool] - base[mask_bool]) ** 2)) - assert after_rmse < before_rmse * 0.35 - - clean_corrected, clean_mask = node.process( - make_field(data=base), - scar_type="both", - threshold_high=0.6, - threshold_low=0.2, - min_length=12, - max_width=4, - ) - assert np.count_nonzero(clean_mask) == 0 - assert np.allclose(clean_corrected.data, base) - - print(" PASS\n") - - -def test_angle_measure(): - print("=== Test: AngleMeasure ===") - from backend.node_registry import get_node_info - from backend.nodes.angle_measure import AngleMeasure - from backend.data_types import ImageData - - node = AngleMeasure() - info = get_node_info("AngleMeasure") - assert info["category"] == "Overlay" - assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"} - required_inputs = AngleMeasure.INPUT_TYPES()["required"] - optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {}) - assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] - assert required_inputs["color"][1]["default"] == "#ff9800" - assert required_inputs["stroke_width"][1]["default"] == 1.35 - assert optional_inputs["line_thickness"][1]["hidden"] is True - assert optional_inputs["line_thickness_input"][1]["hidden"] is True - - field = make_field( - data=np.zeros((32, 64), dtype=np.float64), - xreal=4.0, - yreal=2.0, - ) - output, table = node.process( - field, - color="#c62828", - stroke_width=1.8, - x1=0.2, - y1=0.5, - xm=0.5, - ym=0.5, - x2=0.5, - y2=0.2, - label_dx=0.0, - label_dy=0.0, - ) - rows = {row["quantity"]: row for row in table} - assert isinstance(output, DataField) - assert output is not field - assert len(output.overlays) == len(field.overlays) + 1 - assert output.overlays[-1]["kind"] == "angle_measure" - assert output.overlays[-1]["color"] == "#c62828" - assert np.isclose(output.overlays[-1]["stroke_width"], 1.8) - assert np.isclose(rows["Arm A length"]["value"], 1.2) - assert np.isclose(rows["Arm B length"]["value"], 0.6) - assert np.isclose(rows["Angle"]["value"], 90.0) - assert rows["Angle"]["unit"] == "deg" - assert rows["Vertex x"]["unit"] == field.si_unit_xy - - sanitized_output, _ = node.process( - field, - color="not-a-color", - stroke_width=-0.7, - x1=0.2, - y1=0.5, - xm=0.5, - ym=0.5, - x2=0.5, - y2=0.2, - label_dx=0.0, - label_dy=0.0, - ) - assert sanitized_output.overlays[-1]["color"] == "#ff9800" - assert np.isclose(sanitized_output.overlays[-1]["stroke_width"], 0.35) - - image = np.zeros((50, 100, 3), dtype=np.uint8) - image_output, image_table = node.process( - image, - color="#ff9800", - stroke_width=1.25, - x1=0.25, - y1=0.5, - xm=0.5, - ym=0.5, - x2=0.5, - y2=0.25, - label_dx=0.0, - label_dy=0.0, - ) - image_rows = {row["quantity"]: row for row in image_table} - assert isinstance(image_output, ImageData) - assert image_output.shape == image.shape - assert np.count_nonzero(np.asarray(image_output)) > 0 - assert np.isclose(image_rows["Arm A length"]["value"], 24.75) - assert np.isclose(image_rows["Arm B length"]["value"], 12.25) - assert np.isclose(image_rows["Angle"]["value"], 90.0) - assert image_rows["Arm A length"]["unit"] == "px" - - print(" PASS\n") - - -# ========================================================================= -# Analysis (non-FFT) -# ========================================================================= - -def test_statistics(): - print("=== Test: Statistics ===") - from backend.nodes.statistics import Statistics - node = Statistics() - - data = np.array([[1, 2], [3, 4]], dtype=np.float64) - field = make_field(data=data) - - table, = node.process(field) - stats = {row["quantity"]: row["value"] for row in table} - - assert stats["min"] == 1.0 - assert stats["max"] == 4.0 - assert stats["mean"] == 2.5 - assert stats["median"] == 2.5 - assert stats["range"] == 3.0 - # RMS = sqrt(mean((x - mean)^2)) - expected_rms = np.sqrt(np.mean((data - 2.5) ** 2)) - assert abs(stats["RMS"] - expected_rms) < 1e-10 - - # Constant data should have RMS=0, skewness=0, kurtosis=0 - const_field = make_field(data=np.ones((4, 4)) * 5.0) - table_const, = node.process(const_field) - const_stats = {row["quantity"]: row["value"] for row in table_const} - assert const_stats["RMS"] == 0.0 - assert const_stats["skewness"] == 0.0 - assert const_stats["kurtosis"] == 0.0 - print(" PASS\n") - - -def test_height_histogram(): - print("=== Test: Histogram ===") - from backend.nodes.histogram import Histogram - node = Histogram() - - # Uniform data should give a roughly flat histogram - data = np.linspace(0, 1, 1000).reshape(25, 40) - field = make_field(data=data) - - overlays = [] - Histogram._broadcast_overlay_fn = lambda nid, data: overlays.append(data) - Histogram._current_node_id = "test" - - table, coord_pair = node.process( - field, - n_bins=10, - y_scale="linear", - x1=0.2, - y1=0.5, - x2=0.8, - y2=0.5, - ) - assert isinstance(coord_pair, tuple) and len(coord_pair) == 2 - measurements = {row["quantity"]: row for row in table} - assert "A position" in measurements - assert "A count" in measurements - assert "B position" in measurements - assert "B count" in measurements - assert "delta X" in measurements - assert "delta Y" in measurements - assert measurements["A count"]["unit"] == "count" - assert measurements["B count"]["unit"] == "count" - assert measurements["B position"]["value"] > measurements["A position"]["value"] - assert len(overlays) == 1 - assert overlays[0]["kind"] == "line_plot" - assert overlays[0]["section_title"] == "Histogram" - assert len(overlays[0]["line"]) == 10 - assert len(overlays[0]["x_axis"]) == 10 - assert np.isclose(overlays[0]["x1"], 0.2) - assert np.isclose(overlays[0]["x2"], 0.8) - assert np.isclose( - measurements["delta Y"]["value"], - measurements["B count"]["value"] - measurements["A count"]["value"], - ) - - Histogram._broadcast_overlay_fn = None - print(" PASS\n") - - -def test_fractal_dimension(): - print("=== Test: FractalDimension ===") - from backend.node_registry import get_node_info - from backend.execution_context import active_node, execution_callbacks - from backend.nodes.fractal_dimension import FractalDimension - - node = FractalDimension() - assert get_node_info("FractalDimension")["category"] == "Measure" - - N = 129 - yy, xx = np.mgrid[0:N, 0:N] / (N - 1) - data = 0.25 * xx + 0.12 * yy + 0.03 * np.sin(6.0 * np.pi * xx) + 0.02 * np.cos(4.0 * np.pi * yy) - field = make_field(data=data, xreal=4.0e-6, yreal=4.0e-6) - - overlays = [] - tables = [] - with execution_callbacks(overlay=lambda nid, payload: overlays.append(payload), table=lambda nid, rows: tables.append(rows)), active_node("test"): - dimension, curve, table = node.process( - field, - method="partitioning", - interpolation="linear", - x1=0.0, - y1=0.5, - x2=1.0, - y2=0.5, - ) - - assert np.isfinite(dimension) - assert 1.5 < dimension < 2.5 - assert isinstance(curve, LineData) - assert len(curve) > 3 - assert curve.x_axis is not None - assert np.all(np.diff(curve.x_axis) > 0.0) - assert len(overlays) == 1 - assert overlays[0]["kind"] == "line_plot" - assert len(tables) == 1 - assert table[0]["quantity"] == "Dimension" - - methods = ["partitioning", "cube_counting", "triangulation", "psdf", "hhcf"] - for method in methods: - dim, line, measurements = node.process( - field, - method=method, - interpolation="linear", - x1=0.0, - y1=0.5, - x2=1.0, - y2=0.5, - ) - assert np.isfinite(dim), f"{method} should produce a finite fractal dimension" - if method == "psdf": - assert -1.0 < dim < 3.2 - else: - assert 1.2 < dim < 3.2 - assert isinstance(line, LineData) - assert len(line) >= 2 - assert measurements[0]["quantity"] == "Dimension" - - narrowed_dim, _, narrowed_table = node.process( - field, - method="partitioning", - interpolation="linear", - x1=0.15, - y1=0.5, - x2=0.55, - y2=0.5, - ) - assert np.isfinite(narrowed_dim) - fit_from = next(row["value"] for row in narrowed_table if row["quantity"] == "Fit from") - fit_to = next(row["value"] for row in narrowed_table if row["quantity"] == "Fit to") - assert fit_to > fit_from - print(" PASS\n") - - -def test_cross_section(): - print("=== Test: CrossSection ===") - from backend.nodes.cross_section import CrossSection - node = CrossSection() - - # Create a field with a known horizontal gradient - N = 100 - y, x = np.mgrid[0:N, 0:N] / N - data = x * 10.0 # value = 10 * x_fraction - field = make_field(data=data, xreal=1e-6, yreal=1e-6) - - # Horizontal cross section at y=0.5 - profile, marker_pair = node.process( - field, x1=0.0, y1=0.5, x2=1.0, y2=0.5, - extend="none", n_samples=100, - ) - assert isinstance(marker_pair, tuple) and len(marker_pair) == 2 - assert isinstance(profile, LineData) - assert len(profile) == 100 - assert profile.x_unit == field.si_unit_xy - assert profile.y_unit == field.si_unit_z - assert np.isclose(profile.x_axis[0], 0.0) - assert np.isclose(profile.x_axis[-1], field.xreal) - # Profile should be a linear ramp from ~0 to ~10 - assert profile[0] < 0.5, f"Start of profile: {profile[0]}" - assert profile[-1] > 9.5, f"End of profile: {profile[-1]}" - - # n_samples=0 should auto-calculate - profile_auto, _ = node.process( - field, x1=0.0, y1=0.5, x2=1.0, y2=0.5, - extend="none", n_samples=0, - ) - assert len(profile_auto) >= 2 - - # Test extend to edges — a short segment should be extended - profile_ext, _ = node.process( - field, x1=0.3, y1=0.5, x2=0.7, y2=0.5, - extend="to_edges", n_samples=100, - ) - # Extended profile should start near 0 and end near 10 - assert profile_ext[0] < 0.5 - assert profile_ext[-1] > 9.5 - - # Diagonal cross section - profile_diag, _ = node.process( - field, x1=0.0, y1=0.0, x2=1.0, y2=1.0, - extend="none", n_samples=50, - ) - assert len(profile_diag) == 50 - - from backend.nodes.cursors import Cursors - from backend.nodes.stats import Stats - - cursors = Cursors() - table, _ = cursors.process(profile, x1=0.25, y1=0.5, x2=0.75, y2=0.5) - rows = {row["quantity"]: row for row in table} - assert rows["dx"]["unit"] == field.si_unit_xy - assert rows["dy"]["unit"] == field.si_unit_z - - captured = [] - Stats._broadcast_value_fn = lambda nid, payload: captured.append(payload) - Stats._current_node_id = "test" - stats = Stats() - mean_value, = stats.process(profile, operation="mean", column="value") - assert mean_value > 0 - assert captured[-1]["unit"] == field.si_unit_z - Stats._broadcast_value_fn = None - - print(" PASS\n") - - -# ========================================================================= -# Grains -# ========================================================================= - -def test_threshold_mask(): - print("=== Test: ThresholdMask ===") - from backend.nodes.mask_threshold import ThresholdMask - node = ThresholdMask() - - # Clear bimodal data: left half = 0, right half = 1 - data = np.zeros((64, 64)) - data[:, 32:] = 1.0 - field = make_field(data=data) - - # Capture overlay preview - previews = [] - ThresholdMask._broadcast_fn = lambda nid, uri: previews.append(uri) - ThresholdMask._current_node_id = "test" - - # Absolute threshold at 0.5 - mask, = node.process(field, method="absolute", threshold=0.5, direction="above") - assert mask.dtype == np.uint8 - assert mask.shape == (64, 64) - assert np.all(mask[:, :32] == 0) - assert np.all(mask[:, 32:] == 255) - - # Verify overlay preview was broadcast - assert len(previews) == 1 - assert previews[0].startswith("data:image/png;base64,") - - # Direction "below" - mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below") - assert np.all(mask_below[:, :32] == 255) - assert np.all(mask_below[:, 32:] == 0) - - # Relative threshold at 0.5 (midpoint of range) - mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above") - assert np.all(mask_rel[:, 32:] == 255) - - # Otsu should find the bimodal threshold - mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above") - assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum() - - ThresholdMask._broadcast_fn = None - print(" PASS\n") - - -def test_mask_morphology(): - print("=== Test: MaskMorphology ===") - from backend.nodes.mask_morphology import MaskMorphology - node = MaskMorphology() - - # Small square blob in the centre - mask = np.zeros((64, 64), dtype=np.uint8) - mask[28:36, 28:36] = 255 # 8x8 block - orig_count = np.count_nonzero(mask) - - # Dilate should grow the region - dilated, = node.process(mask, operation="dilate", radius=1, shape="square") - assert dilated.dtype == np.uint8 - assert np.count_nonzero(dilated) > orig_count - - # Erode should shrink it - eroded, = node.process(mask, operation="erode", radius=1, shape="square") - assert np.count_nonzero(eroded) < orig_count - - # Open on a clean block should give back roughly the same block - opened, = node.process(mask, operation="open", radius=1, shape="square") - assert np.count_nonzero(opened) <= orig_count - - # Close on a mask with a 1-pixel hole should fill the hole - mask_hole = mask.copy() - mask_hole[32, 32] = 0 # poke a hole - assert np.count_nonzero(mask_hole) == orig_count - 1 - closed, = node.process(mask_hole, operation="close", radius=1, shape="square") - assert closed[32, 32] == 255, "Close should fill the 1-pixel hole" - - # Disk structuring element should also work - dilated_disk, = node.process(mask, operation="dilate", radius=2, shape="disk") - assert np.count_nonzero(dilated_disk) > orig_count - - print(" PASS\n") - - -def test_mask_invert(): - print("=== Test: MaskInvert ===") - from backend.nodes.mask_invert import MaskInvert - node = MaskInvert() - - mask = np.zeros((64, 64), dtype=np.uint8) - mask[10:20, 10:20] = 255 - - inverted, = node.process(mask) - assert inverted.dtype == np.uint8 - assert np.all(inverted[10:20, 10:20] == 0) - assert np.all(inverted[0:10, 0:10] == 255) - # Double-invert should return to original - double, = node.process(inverted) - assert np.array_equal(double, mask) - - print(" PASS\n") - - -def test_mask_operations(): - print("=== Test: MaskOperations ===") - from backend.nodes.mask_operations import MaskOperations - node = MaskOperations() - - # Two overlapping squares - a = np.zeros((64, 64), dtype=np.uint8) - a[10:30, 10:30] = 255 # 20x20 - b = np.zeros((64, 64), dtype=np.uint8) - b[20:40, 20:40] = 255 # 20x20, overlaps 10x10 - - # AND — only the overlap - result_and, = node.process(a, b, operation="and") - assert np.all(result_and[20:30, 20:30] == 255) - assert result_and[15, 15] == 0 # a-only region - assert result_and[35, 35] == 0 # b-only region - - # OR — union - result_or, = node.process(a, b, operation="or") - assert result_or[15, 15] == 255 - assert result_or[35, 35] == 255 - assert result_or[25, 25] == 255 - assert result_or[5, 5] == 0 - - # XOR — symmetric difference - result_xor, = node.process(a, b, operation="xor") - assert result_xor[15, 15] == 255 # a-only - assert result_xor[35, 35] == 255 # b-only - assert result_xor[25, 25] == 0 # overlap excluded - - # A minus B - result_sub, = node.process(a, b, operation="a_minus_b") - assert result_sub[15, 15] == 255 # a-only kept - assert result_sub[25, 25] == 0 # overlap removed - assert result_sub[35, 35] == 0 # b-only not included - - # NAND — everything except overlap - result_nand, = node.process(a, b, operation="nand") - assert result_nand[15, 15] == 255 - assert result_nand[35, 35] == 255 - assert result_nand[25, 25] == 0 - assert result_nand[5, 5] == 255 - - # XNOR — overlap plus shared background - result_xnor, = node.process(a, b, operation="xnor") - assert result_xnor[25, 25] == 255 - assert result_xnor[5, 5] == 255 - assert result_xnor[15, 15] == 0 - assert result_xnor[35, 35] == 0 - - print(" PASS\n") - - -def test_draw_mask(): - print("=== Test: DrawMask ===") - from backend.nodes.mask_draw import DrawMask - node = DrawMask() - - field = make_field(data=np.zeros((32, 32), dtype=np.float64)) - overlays = [] - DrawMask._broadcast_overlay_fn = lambda nid, data: overlays.append(data) - DrawMask._current_node_id = "test" - - mask_paths = [ - { - "size": 5, - "points": [ - {"x": 0.2, "y": 0.5}, - {"x": 0.8, "y": 0.5}, - ], - } - ] - - mask, = node.process(field, pen_size=2, invert=False, mask_paths=json.dumps(mask_paths)) - assert mask.dtype == np.uint8 - assert mask.shape == (32, 32) - assert mask[16, 16] == 255 - assert mask[14, 16] == 255 - assert mask[0, 0] == 0 - - assert len(overlays) == 1 - assert overlays[0]["kind"] == "mask_paint" - assert overlays[0]["section_title"] == "Mask" - assert overlays[0]["image"].startswith("data:image/png;base64,") - assert overlays[0]["image_width"] == field.xres - assert overlays[0]["image_height"] == field.yres - assert overlays[0]["invert"] is False - - inverted, = node.process(field, pen_size=2, invert=True, mask_paths=json.dumps(mask_paths)) - assert inverted[16, 16] == 0 - assert inverted[0, 0] == 255 - assert overlays[-1]["invert"] is True - - cleared, = node.process(field, pen_size=12, invert=False, mask_paths="[]") - assert np.count_nonzero(cleared) == 0 - - DrawMask._broadcast_overlay_fn = None - print(" PASS\n") - - -def test_grain_analysis(): - print("=== Test: GrainAnalysis ===") - from backend.nodes.grain_analysis import GrainAnalysis - node = GrainAnalysis() - - # Create a field with two distinct grains - N = 64 - data = np.zeros((N, N)) - # Grain 1: 10x10 block at top-left with height 5 - data[5:15, 5:15] = 5.0 - # Grain 2: 8x8 block at bottom-right with height 3 - data[45:53, 45:53] = 3.0 - field = make_field(data=data, xreal=1e-6, yreal=1e-6) - - # Create matching mask - mask = np.zeros((N, N), dtype=np.uint8) - mask[5:15, 5:15] = 255 - mask[45:53, 45:53] = 255 - - table, = node.process(field, mask=mask, min_size=10) - assert len(table) == 2, f"Expected 2 grains, got {len(table)}" - - # Sort by area descending - table.sort(key=lambda r: r["area_px"], reverse=True) - assert table[0]["area_px"] == 100 # 10x10 - assert table[1]["area_px"] == 64 # 8x8 - assert abs(table[0]["mean_height"] - 5.0) < 1e-10 - assert abs(table[1]["mean_height"] - 3.0) < 1e-10 - assert table[0]["area_px_unit"] == "px^2" - assert table[0]["area_m2_unit"] == "m^2" - assert table[0]["equiv_diam_m_unit"] == "m" - assert table[0]["mean_height_unit"] == "m" - assert table[0]["max_height_unit"] == "m" - - # min_size filtering: only keep grains >= 80 px - table_filtered, = node.process(field, mask=mask, min_size=80) - assert len(table_filtered) == 1 - assert table_filtered[0]["area_px"] == 100 - print(" PASS\n") - - -def test_grain_distance_transform(): - print("=== Test: GrainDistanceTransform ===") - from backend.nodes.grain_distance_transform import GrainDistanceTransform - - node = GrainDistanceTransform() - field = make_field(data=np.zeros((7, 7), dtype=np.float64), xreal=7.0, yreal=7.0) - mask = np.zeros((7, 7), dtype=np.uint8) - mask[2:5, 2:5] = 255 - - interior, = node.process(field, mask, distance_type="euclidean", output_type="interior", from_border=True) - assert interior.data.shape == field.data.shape - assert interior.si_unit_z == field.si_unit_xy - assert np.isclose(interior.data[3, 3], 2.0) - assert np.isclose(interior.data[2, 2], 1.0) - assert np.isclose(interior.data[0, 0], 0.0) - - exterior, = node.process(field, mask, distance_type="cityblock", output_type="exterior", from_border=True) - assert np.isclose(exterior.data[1, 1], 2.0) - assert np.isclose(exterior.data[2, 1], 1.0) - assert np.isclose(exterior.data[3, 3], 0.0) - - signed, = node.process(field, mask, distance_type="chess", output_type="signed", from_border=True) - assert signed.data[3, 3] > 0.0 - assert signed.data[0, 0] < 0.0 - - edge_field = make_field(data=np.zeros((5, 5), dtype=np.float64), xreal=5.0, yreal=5.0) - edge_mask = np.zeros((5, 5), dtype=np.uint8) - edge_mask[:, :2] = 255 - from_edge, = node.process(edge_field, edge_mask, distance_type="euclidean", output_type="interior", from_border=True) - not_from_edge, = node.process(edge_field, edge_mask, distance_type="euclidean", output_type="interior", from_border=False) - assert not_from_edge.data[2, 0] > from_edge.data[2, 0] - print(" PASS\n") - - -def test_watershed_segmentation(): - print("=== Test: WatershedSegmentation ===") - from scipy.ndimage import label - from backend.execution_context import active_node, execution_callbacks - from backend.nodes.watershed_segmentation import WatershedSegmentation - - node = WatershedSegmentation() - y, x = np.mgrid[-1:1:64j, -1:1:64j] - data = ( - 2.0 * np.exp(-((x + 0.45) ** 2 + y**2) / 0.05) - + 2.0 * np.exp(-((x - 0.45) ** 2 + y**2) / 0.05) - - 0.3 * np.exp(-(x**2 + y**2) / 0.12) - ) - field = make_field(data=data, xreal=2.0e-6, yreal=2.0e-6) - - previews = [] - with execution_callbacks(preview=lambda nid, uri: previews.append(uri)), active_node("test"): - mask, = node.process( - field, - invert_height=False, - locate_steps=10, - locate_threshold=8, - locate_drop_size=0.1, - watershed_steps=20, - watershed_drop_size=0.1, - combine_mode="replace", - ) - assert mask.dtype == np.uint8 - assert mask.shape == field.data.shape - assert len(previews) == 1 - assert previews[0].startswith("data:image/png;base64,") - - _, ngrains = label(mask > 127) - assert ngrains >= 2 - - seed_mask = np.zeros_like(mask) - seed_mask[:, :32] = 255 - intersected, = node.process( - field, - invert_height=False, - locate_steps=10, - locate_threshold=8, - locate_drop_size=0.1, - watershed_steps=20, - watershed_drop_size=0.1, - combine_mode="intersection", - mask=seed_mask, - ) - assert np.count_nonzero(intersected) < np.count_nonzero(mask) - assert np.all(intersected[:, 40:] == 0) - print(" PASS\n") - - -# ========================================================================= -# I/O -# ========================================================================= - -def test_load_file(): - print("=== Test: Image ===") - from backend.nodes.image import Image as ImageNode - from PIL import Image as PILImage - node = ImageNode() - - with tempfile.TemporaryDirectory() as tmpdir: - # Test loading a grayscale PNG → single DataField output - arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8) - img = PILImage.fromarray(arr, mode="L") - path = os.path.join(tmpdir, "test_gray.png") - img.save(path) - - result = node.load(filename=path) - assert len(result) == 1 - field = result[0] - assert field.data.shape == (48, 64) - assert field.data.dtype == np.float64 - - # Test loading an RGB PNG (should average to grayscale) - arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8) - img_rgb = PILImage.fromarray(arr_rgb, mode="RGB") - path_rgb = os.path.join(tmpdir, "test_rgb.png") - img_rgb.save(path_rgb) - - result_rgb = node.load(filename=path_rgb) - assert len(result_rgb) == 1 - assert result_rgb[0].data.shape == (32, 32) - - # Test loading a .npy file - data_npy = np.random.default_rng(3).standard_normal((50, 60)) - path_npy = os.path.join(tmpdir, "test.npy") - np.save(path_npy, data_npy) - - result_npy = node.load(filename=path_npy) - assert np.allclose(result_npy[0].data, data_npy) - - custom_colormap = { - "mode": "custom", - "stops": [ - {"position": 0.0, "color": "#000000"}, - {"position": 0.5, "color": "#ff0000"}, - {"position": 1.0, "color": "#ffffff"}, - ], - } - result_custom = node.load(filename=path, colormap_map=custom_colormap) - assert isinstance(result_custom[0].colormap, dict) - assert result_custom[0].colormap["mode"] == "custom" - assert len(result_custom[0].colormap["stops"]) == 3 - - result_from_path = node.load(filename="", path=path) - assert len(result_from_path) == 1 - assert result_from_path[0].data.shape == (48, 64) - - print(" PASS\n") - - -def test_save_image(): - print("=== Test: SaveImage (Save Layers) ===") - from backend.nodes.save_layers import SaveImage - import tifffile - node = SaveImage() - input_types = SaveImage.INPUT_TYPES() - field_spec = input_types["optional"]["field_0"] - assert field_spec[0] == "DATA_FIELD" - assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"] - - field_a = make_field(data=np.random.default_rng(4).random((32, 32))) - field_b = make_field(data=np.random.default_rng(5).random((32, 32))) - annotated = np.zeros((24, 24, 3), dtype=np.uint8) - annotated[..., 0] = 255 - - with tempfile.TemporaryDirectory() as tmpdir: - # Save single layer as TIFF - tiff_path = os.path.join(tmpdir, "out.tiff") - node.save(filename=tiff_path, format="TIFF", field_0=field_a) - assert os.path.exists(tiff_path), "TIFF file not created" - from PIL import Image - im = Image.open(tiff_path) - assert im.n_frames == 1 - arr_back = np.array(im) - assert arr_back.shape == (32, 32) - - # Save multi-layer as TIFF - tiff_path2 = os.path.join(tmpdir, "multi.tiff") - node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b) - im2 = Image.open(tiff_path2) - assert im2.n_frames == 2 - - # Save annotated image as TIFF with layer name - annotated_tiff = os.path.join(tmpdir, "annotated.tiff") - node.save( - filename=annotated_tiff, - format="TIFF", - field_0=annotated, - layer_name_0="annotated overview", - ) - with tifffile.TiffFile(annotated_tiff) as tif: - assert len(tif.pages) == 1 - assert tif.pages[0].description == "annotated overview" - assert tif.pages[0].asarray().shape == annotated.shape - - # Save as NPZ with layer names - npz_path = os.path.join(tmpdir, "out.npz") - node.save( - filename=npz_path, - format="NPZ", - field_0=field_a, - field_1=annotated, - layer_name_0="height map", - layer_name_1="annotated-overview", - ) - assert os.path.exists(npz_path) - npz = np.load(npz_path) - assert len(npz.files) == 2 - assert np.allclose(npz["height_map"], field_a.data) - assert np.array_equal(npz["annotated_overview"], annotated) - - # Extension is forced to match format - wrong_ext = os.path.join(tmpdir, "output.png") - node.save(filename=wrong_ext, format="TIFF", field_0=field_a) - assert os.path.exists(os.path.join(tmpdir, "output.tiff")) - - # Directory input can drive the destination folder while filename supplies the basename - driven_dir = os.path.join(tmpdir, "nested-output") - node.save(filename="driven_name", directory=driven_dir, format="NPZ", field_0=field_a) - assert os.path.exists(os.path.join(driven_dir, "driven_name.npz")) - - # Directory input rejects file paths - try: - node.save( - filename="bad", - directory=os.path.join(tmpdir, "looks_like_file.txt"), - format="TIFF", - field_0=field_a, - ) - assert False, "Should have raised ValueError for file-like directory path" - except ValueError: - pass - - # No fields connected → error - try: - node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF") - assert False, "Should have raised ValueError" - except ValueError: - pass - - # No filename → error - try: - node.save(filename="", format="TIFF", field_0=field_a) - assert False, "Should have raised ValueError" - except ValueError: - pass - - print(" PASS\n") - - -# ========================================================================= -# Display (limited testing — these are output nodes with WS callbacks) -# ========================================================================= - -def test_color_map_node(): - print("=== Test: ColorMap ===") - from backend.nodes.colormap import ColorMap - - node = ColorMap() - - preset, = node.build(mode="preset", preset="magma", stops_json="[]") - assert preset["mode"] == "preset" - assert preset["preset"] == "magma" - - custom, = node.build( - mode="custom", - preset="viridis", - stops_json=json.dumps([ - {"position": 0.0, "color": "#000000"}, - {"position": 0.4, "color": "#00ff00"}, - {"position": 1.0, "color": "#ffffff"}, - ]), - ) - assert custom["mode"] == "custom" - assert custom["stops"][0]["position"] == 0.0 - assert custom["stops"][-1]["position"] == 1.0 - assert len(custom["stops"]) == 3 - print(" PASS\n") - - -def test_font_node(): - print("=== Test: Font ===") - from backend.nodes.font import Font - from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT - - node = Font() - - system_default, = node.build(SYSTEM_DEFAULT_FONT) - assert system_default is None - - named, = node.build("Arial") - assert named == {"family": "Arial", "path": ""} - - custom, = node.build(CUSTOM_FILE_FONT, "/tmp/example-font.ttf") - assert custom == {"family": "", "path": "/tmp/example-font.ttf"} - print(" PASS\n") - - -def test_preview_image(): - print("=== Test: PreviewImage ===") - from backend.nodes.preview_image import PreviewImage - from backend.data_types import ImageData - from backend.execution_context import active_node, execution_callbacks - node = PreviewImage() - preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"] - assert preview_input[0] == "ANNOTATION_SOURCE" - assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] - - # Set up a capture for the broadcast - captured = [] - with execution_callbacks(preview=lambda nid, data_uri: captured.append(data_uri)), active_node("test"): - # Preview with a DataField - field = make_field() - node.preview(colormap="viridis", input=field) - assert len(captured) == 1 - assert captured[0].startswith("data:image/png;base64,") - - # Preview with field overlay metadata - captured.clear() - field_with_overlay = field.replace(overlays=[{"kind": "annotation", "show_scale_bar": True, "show_color_map": False, "text_size": 14.0}]) - node.preview(colormap="viridis", input=field_with_overlay) - assert len(captured) == 1 - assert captured[0].startswith("data:image/png;base64,") - - # Preview with a custom colormap input - captured.clear() - custom_colormap = { - "mode": "custom", - "stops": [ - {"position": 0.0, "color": "#000000"}, - {"position": 0.5, "color": "#ff0000"}, - {"position": 1.0, "color": "#ffffff"}, - ], - } - node.preview(colormap="auto", input=field, colormap_map=custom_colormap) - assert len(captured) == 1 - assert captured[0].startswith("data:image/png;base64,") - - # Preview with an IMAGE array - captured.clear() - arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8) - node.preview(colormap="gray", input=arr) - assert len(captured) == 1 - - # Preview with an ANNOTATION_SOURCE carrying a DataField - captured.clear() - node.preview(colormap="auto", input=field_with_overlay) - assert len(captured) == 1 - assert captured[0].startswith("data:image/png;base64,") - - # Preview with an ANNOTATION_SOURCE carrying an ImageData - captured.clear() - annotated_image = ImageData( - np.zeros((24, 24, 3), dtype=np.uint8), - metadata={"annotation_context": {"xreal": 1e-6, "si_unit_xy": "m"}}, - ) - node.preview(colormap="auto", input=annotated_image) - assert len(captured) == 1 - assert captured[0].startswith("data:image/png;base64,") - - print(" PASS\n") - - -def test_annotations(): - print("=== Test: Annotations ===") - from backend.nodes.annotations import Annotations - from backend.nodes.font import Font - from backend.data_types import ImageData - from backend.execution_context import active_node, execution_callbacks - - node = Annotations() - font_node = Font() - annotation_input = Annotations.INPUT_TYPES()["required"]["input"] - assert annotation_input[0] == "ANNOTATION_SOURCE" - assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] - warnings = [] - field = DataField( - data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64), - xreal=1e-6, - yreal=1e-6, - si_unit_xy="m", - si_unit_z="V", - colormap="viridis", - ) - - base = datafield_to_uint8(field, "viridis") - plain_preview = render_datafield_preview(field, "viridis") - assert np.array_equal(plain_preview, base) - - with execution_callbacks(warning=lambda nid, msg: warnings.append(msg)), active_node("test"): - plain_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=False) - assert isinstance(plain_field, DataField) - assert np.array_equal(plain_field.data, field.data) - assert plain_field.colormap == "viridis" - assert plain_field.overlays[-1]["kind"] == "annotation" - plain = render_datafield_preview(plain_field, plain_field.colormap) - assert plain.shape == base.shape - assert np.array_equal(plain, base) - - with_scale_field, = node.render(input=field, colormap="auto", show_scale_bar=True, show_color_map=False) - with_scale = render_datafield_preview(with_scale_field, with_scale_field.colormap) - assert with_scale.shape == base.shape - assert not np.array_equal(with_scale, base) - - with_legend_field, = node.render(input=field, colormap="auto", show_scale_bar=False, show_color_map=True) - with_legend = render_datafield_preview(with_legend_field, with_legend_field.colormap) - assert with_legend.shape[0] == base.shape[0] - assert with_legend.shape[1] > base.shape[1] - assert with_legend.shape[2] == 3 - - larger_legend_field, = node.render( - input=field, - colormap="auto", - show_scale_bar=False, - show_color_map=True, - text_size=28.0, - ) - larger_legend_text = render_datafield_preview(larger_legend_field, larger_legend_field.colormap) - assert larger_legend_text.shape[0] == with_legend.shape[0] - assert larger_legend_text.shape[1] > with_legend.shape[1] - assert larger_legend_text.shape[2] == with_legend.shape[2] - assert not np.array_equal(larger_legend_text, with_legend) - - annotation_font, = font_node.build("Arial") - with_font_field, = node.render( - input=field, - colormap="auto", - show_scale_bar=False, - show_color_map=True, - text_size=28.0, - font=annotation_font, - ) - assert with_font_field.overlays[-1]["font"] == {"family": "Arial", "path": ""} - with_font = render_datafield_preview(with_font_field, with_font_field.colormap) - assert with_font.shape[0] == with_legend.shape[0] - assert with_font.shape[1] > with_legend.shape[1] - assert with_font.shape[2] == with_legend.shape[2] - - with_both_field, = node.render(input=field, colormap="auto", show_scale_bar=True, show_color_map=True) - with_both = render_datafield_preview(with_both_field, with_both_field.colormap) - assert with_both.shape == with_legend.shape - assert not np.array_equal(with_both[:, :base.shape[1]], base) - - viewport_image = ImageData( - np.zeros((48, 64, 3), dtype=np.uint8), - metadata={ - "annotation_context": { - "xreal": 2e-6, - "si_unit_xy": "m", - "legend_min": -1.5, - "legend_mid": 0.0, - "legend_max": 1.5, - "legend_unit": "V", - "colormap": "viridis", - }, - }, - ) - annotated_image, = node.render( - input=viewport_image, - colormap="auto", - show_scale_bar=True, - show_color_map=True, - text_size=18.0, - ) - assert isinstance(annotated_image, ImageData) - assert annotated_image.shape[0] == viewport_image.shape[0] - assert annotated_image.shape[1] > viewport_image.shape[1] - assert annotated_image.metadata["annotation_context"]["legend_unit"] == "V" - assert not np.array_equal(np.asarray(annotated_image)[:, :viewport_image.shape[1]], np.asarray(viewport_image)) - assert warnings == [] - - plain_image = ImageData(np.zeros((32, 40, 3), dtype=np.uint8)) - passthrough_image, = node.render( - input=plain_image, - colormap="auto", - show_scale_bar=True, - show_color_map=True, - text_size=18.0, - ) - assert isinstance(passthrough_image, ImageData) - assert passthrough_image.shape == plain_image.shape - assert np.array_equal(np.asarray(passthrough_image), np.asarray(plain_image)) - assert len(warnings) == 1 - assert "no scale metadata" in warnings[0] - - print(" PASS\n") - - -def test_markup(): - print("=== Test: Markup ===") - from backend.nodes.markup import Markup - from backend.data_types import ImageData, _preview_markup_stroke_width - from backend.execution_context import active_node, execution_callbacks - - node = Markup() - field = make_field(data=np.linspace(0.0, 1.0, 48 * 48, dtype=np.float64).reshape(48, 48)) - base = render_datafield_preview(field, field.colormap) - required_inputs = Markup.INPUT_TYPES()["required"] - - assert _preview_markup_stroke_width(5, 128, 128) == 5 - assert _preview_markup_stroke_width(5, 2048, 2048) > 5 - assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] - assert required_inputs["shape"][1]["default"] == "arrow" - assert required_inputs["stroke_color"][1]["default"] == "#ff0000" - - overlays = [] - with execution_callbacks(overlay=lambda nid, data: overlays.append(data)), active_node("test"): - plain_field, = node.process( - input=field, - shape="line", - stroke_color="#ffd54f", - stroke_width=3, - markup_shapes="[]", - ) - assert isinstance(plain_field, DataField) - assert plain_field.overlays[-1]["kind"] == "markup" - plain = render_datafield_preview(plain_field, plain_field.colormap) - assert np.array_equal(plain, base) - assert overlays[-1]["kind"] == "markup" - assert overlays[-1]["shape"] == "line" - assert overlays[-1]["stroke_color"] == "#ffd54f" - assert overlays[-1]["stroke_width"] == 3 - assert overlays[-1]["image"].startswith("data:image/png;base64,") - - shapes = json.dumps([ - {"kind": "line", "x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9, "width": 3, "color": "#ff0000"}, - {"kind": "rectangle", "x1": 0.2, "y1": 0.2, "x2": 0.8, "y2": 0.5, "width": 2, "color": "#00ff00"}, - {"kind": "circle", "x1": 0.25, "y1": 0.55, "x2": 0.55, "y2": 0.85, "width": 2, "color": "#4fc3f7"}, - {"kind": "arrow", "x1": 0.15, "y1": 0.85, "x2": 0.85, "y2": 0.2, "width": 4, "color": "#ffffff"}, - ]) - marked_field, = node.process( - input=field, - shape="arrow", - stroke_color="#ffffff", - stroke_width=4, - markup_shapes=shapes, - ) - marked = render_datafield_preview(marked_field, marked_field.colormap) - assert marked.shape == base.shape - assert not np.array_equal(marked, base) - assert overlays[-1]["shape"] == "arrow" - assert overlays[-1]["stroke_color"] == "#ffffff" - assert overlays[-1]["stroke_width"] == 4 - - viewport_image = ImageData( - np.zeros((48, 48, 3), dtype=np.uint8), - metadata={"annotation_context": {"xreal": 1e-6, "si_unit_xy": "m"}}, - ) - image_markup, = node.process( - input=viewport_image, - shape="line", - stroke_color="#ff0000", - stroke_width=4, - markup_shapes=json.dumps([ - {"kind": "line", "x1": 0.1, "y1": 0.2, "x2": 0.9, "y2": 0.8, "width": 4, "color": "#ff0000"}, - ]), - ) - assert isinstance(image_markup, ImageData) - assert image_markup.metadata["annotation_context"]["si_unit_xy"] == "m" - assert not np.array_equal(np.asarray(image_markup), np.asarray(viewport_image)) - - print(" PASS\n") - - -def test_print_table(): - print("=== Test: PrintTable ===") - from backend.nodes.print_table import PrintTable - node = PrintTable() - - table_spec = PrintTable.INPUT_TYPES()["required"]["table"] - assert table_spec[0] == "RECORD_TABLE" - assert table_spec[1]["accepted_types"] == ["DATA_TABLE"] - - captured = [] - PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows) - PrintTable._current_node_id = "test" - - table = [{"quantity": "test", "value": 42.0, "unit": "m"}] - node.print_table(table=table) - assert len(captured) == 1 - assert captured[0] == table - - PrintTable._broadcast_table_fn = None - print(" PASS\n") - - -def test_value_display(): - print("=== Test: ValueIO ===") - from backend.nodes.value_io import ValueIO - - node = ValueIO() - value_spec = ValueIO.INPUT_TYPES()["required"]["value"] - assert value_spec[0] == "FLOAT" - assert value_spec[1]["accepted_types"] == ["RECORD_TABLE"] - - captured = [] - ValueIO._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) - ValueIO._current_node_id = "test" - - result = node.display_value(3.25) - assert result == (3.25,) - assert captured == [("test", {"value": 3.25})] - - measurements = RecordTable([ - {"quantity": "delta X", "value": 1.7e-7, "unit": "m"}, - {"quantity": "delta Y", "value": 463, "unit": "count"}, - ]) - result = node.display_value(measurements, measurement="delta X") - assert result == (1.7e-7,) - assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"}) - - ValueIO._broadcast_value_fn = None - print(" PASS\n") - - -# ========================================================================= -# I/O — IBW multi-channel loading -# ========================================================================= - -def test_load_file_ibw(): - print("=== Test: Image IBW multi-channel ===") - from backend.nodes.image import Image - - node = Image() - ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw") - ibw_path = os.path.abspath(ibw_path) - if not os.path.exists(ibw_path): - print(" SKIP (demo IBW file not found)\n") - return - - result = node.load(filename=ibw_path) - - # BR_New20012.ibw has 4 channels - assert len(result) == 4, f"Expected 4 channels, got {len(result)}" - - for i, field in enumerate(result): - assert isinstance(field, DataField), f"Channel {i} is not a DataField" - assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}" - assert field.data.dtype == np.float64 - # Physical dimensions should be populated (not default 1e-6) - assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}" - assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}" - assert field.si_unit_xy == "m" - assert field.si_unit_z == "m" - - # All channels should share the same physical dimensions - assert result[0].xreal == result[1].xreal - assert result[0].yreal == result[1].yreal - - # Different channels should have different data - assert not np.array_equal(result[0].data, result[1].data) - - print(" PASS\n") - - -def test_load_file_npz(): - print("=== Test: Image .npz ===") - from backend.nodes.image import Image - - node = Image() - with tempfile.TemporaryDirectory() as tmpdir: - data = np.random.default_rng(99).standard_normal((30, 40)) - path = os.path.join(tmpdir, "test.npz") - np.savez(path, my_array=data) - - result = node.load(filename=path) - assert len(result) == 1 - assert np.allclose(result[0].data, data) - - print(" PASS\n") - - -def test_load_file_cache(): - print("=== Test: Image cache ===") - from unittest.mock import patch - from backend.nodes.image import Image - - node = Image() - Image._load_fields_cached.cache_clear() - - with tempfile.TemporaryDirectory() as tmpdir: - data = np.arange(16, dtype=np.float64).reshape(4, 4) - path = os.path.join(tmpdir, "cached.npy") - np.save(path, data) - - with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: - first, = node.load(filename=path) - second, = node.load(filename=path) - assert loader.call_count == 1 - - assert np.allclose(first.data, data) - assert np.allclose(second.data, data) - assert first is not second - first.data[0, 0] = -999.0 - - third, = node.load(filename=path) - assert third.data[0, 0] == data[0, 0] - - Image._load_fields_cached.cache_clear() - print(" PASS\n") - - -def test_load_file_not_found(): - print("=== Test: Image not found ===") - from backend.nodes.image import Image - - node = Image() - try: - node.load(filename="/nonexistent/path/file.png") - assert False, "Should have raised FileNotFoundError" - except FileNotFoundError: - pass - - print(" PASS\n") - - -def test_load_file_unsupported(): - print("=== Test: Image unsupported format ===") - from backend.nodes.image import Image - - node = Image() - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "test.xyz") - with open(path, "w") as f: - f.write("hello") - try: - node.load(filename=path) - assert False, "Should have raised an error for .xyz" - except Exception: - pass - - print(" PASS\n") - - -def test_load_file_warning(): - print("=== Test: Image warning for uncalibrated data ===") - from backend.nodes.image import Image as ImageNode - from PIL import Image as PILImage - - node = ImageNode() - warnings = [] - ImageNode._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) - ImageNode._current_node_id = "test" - - with tempfile.TemporaryDirectory() as tmpdir: - arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8) - img = PILImage.fromarray(arr) - path = os.path.join(tmpdir, "test.png") - img.save(path) - - result = node.load(filename=path) - assert len(result) == 1 - assert len(warnings) == 1 - assert "Uncalibrated" in warnings[0] - - ImageNode._broadcast_warning_fn = None - print(" PASS\n") - - -# ========================================================================= -# I/O — list_channels helper -# ========================================================================= - -def test_list_channels(): - print("=== Test: list_channels ===") - from backend.nodes.helpers import list_channels, list_folder_paths - from backend.nodes.folder import Folder - from PIL import Image - - # Non-existent file → default - ch = list_channels("/nonexistent/file.ibw") - assert len(ch) == 1 - assert ch[0]["name"] == "field" - - # IBW with channels - ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")) - if os.path.exists(ibw_path): - ch = list_channels(ibw_path) - assert len(ch) == 4 - names = [c["name"] for c in ch] - assert "HeightRetrace" in names - assert "AmplitudeRetrace" in names - assert all(c["type"] == "DATA_FIELD" for c in ch) - - # Plain image → single default channel - with tempfile.TemporaryDirectory() as tmpdir: - img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) - path = os.path.join(tmpdir, "test.png") - img.save(path) - - ch = list_channels(path) - assert len(ch) == 1 - assert ch[0]["name"] == "field" - - # .npy → single default channel - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "test.npy") - np.save(path, np.zeros((4, 4))) - - ch = list_channels(path) - assert len(ch) == 1 - - with tempfile.TemporaryDirectory() as tmpdir: - img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8)) - png_path = os.path.join(tmpdir, "a.png") - npy_path = os.path.join(tmpdir, "b.npy") - gwy_path = os.path.join(tmpdir, "c.gwy") - sxm_path = os.path.join(tmpdir, "d.sxm") - ibw_path = os.path.join(tmpdir, "e.ibw") - txt_path = os.path.join(tmpdir, "notes.txt") - img.save(png_path) - np.save(npy_path, np.zeros((4, 4))) - Path(gwy_path).write_bytes(b"gwy") - Path(sxm_path).write_bytes(b"sxm") - Path(ibw_path).write_bytes(b"ibw") - with open(txt_path, "w", encoding="utf-8") as fh: - fh.write("ignore me") - - paths = list_folder_paths(tmpdir) - assert [entry["name"] for entry in paths] == ["directory", "a.png", "b.npy", "c.gwy", "d.sxm", "e.ibw"] - assert Path(paths[0]["path"]).resolve() == Path(tmpdir).resolve() - assert paths[0]["type"] == "DIRECTORY" - assert all(entry["type"] == "FILE_PATH" for entry in paths[1:]) - - folder_node = Folder() - folder_result = folder_node.list_files(tmpdir) - assert folder_result == tuple(entry["path"] for entry in paths) - - print(" PASS\n") - - -# ========================================================================= -# I/O — ImageDemo -# ========================================================================= - -def test_load_demo(): - print("=== Test: ImageDemo ===") - from backend.nodes.image_demo import ImageDemo - - node = ImageDemo() - - # Should be able to load a demo file by name - result = node.load(name="nanoparticles.npy") - assert len(result) >= 1 - assert isinstance(result[0], DataField) - assert result[0].data.ndim == 2 - - # IBW demo should return multiple channels - result_ibw = node.load(name="whiskers.ibw") - assert len(result_ibw) == 4 - for field in result_ibw: - assert isinstance(field, DataField) - - # Non-existent demo should raise - try: - node.load(name="nonexistent_file.png") - assert False, "Should have raised FileNotFoundError" - except FileNotFoundError: - pass - - print(" PASS\n") - - -def test_load_demo_cache(): - print("=== Test: ImageDemo cache ===") - from unittest.mock import patch - from backend.nodes.image import Image - from backend.nodes.image_demo import ImageDemo - - node = ImageDemo() - Image._load_fields_cached.cache_clear() - - with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: - first, = node.load(name="nanoparticles.npy") - second, = node.load(name="nanoparticles.npy") - assert loader.call_count == 1 - - assert np.allclose(first.data, second.data) - assert first is not second - first.data[0, 0] = -999.0 - - third, = node.load(name="nanoparticles.npy") - assert third.data[0, 0] != -999.0 - - Image._load_fields_cached.cache_clear() - print(" PASS\n") - - -def test_load_demo_multi_layer_preview_payload(): - print("=== Test: ImageDemo multi-layer preview payload ===") - from backend.execution import ExecutionEngine - import backend.nodes # noqa: F401 - - previews = [] - prompt = { - "1": { - "class_type": "ImageDemo", - "inputs": { - "name": "whiskers.ibw", - "colormap": "viridis", - }, - }, - } - - ExecutionEngine().execute(prompt, on_preview=lambda node_id, payload: previews.append((node_id, payload))) - - assert len(previews) == 1 - node_id, payload = previews[0] - assert node_id == "1" - assert payload["kind"] == "layer_gallery" - assert len(payload["layers"]) == 4 - assert all(isinstance(layer["name"], str) and layer["name"] for layer in payload["layers"]) - assert all(layer["image"].startswith("data:image/png;base64,") for layer in payload["layers"]) - - print(" PASS\n") - - -# ========================================================================= -# I/O — Coordinate -# ========================================================================= - -def test_coordinate(): - print("=== Test: Coordinate ===") - from backend.nodes.coordinate import Coordinate - - node = Coordinate() - - result = node.process(x=0.3, y=0.7) - assert len(result) == 1 - assert result[0] == (0.3, 0.7) - - # Edge values - result_zero = node.process(x=0.0, y=0.0) - assert result_zero[0] == (0.0, 0.0) - - result_one = node.process(x=1.0, y=1.0) - assert result_one[0] == (1.0, 1.0) - - print(" PASS\n") - - -# ========================================================================= -# I/O — Number -# ========================================================================= - -def test_number(): - print("=== Test: Number ===") - from backend.nodes.number import Number - - node = Number() - - result = node.process(value=1.25) - assert result == (1.25,) - - result_neg = node.process(value=-3.5) - assert result_neg == (-3.5,) - - print(" PASS\n") - - -def test_range_slider(): - print("=== Test: RangeSlider ===") - from backend.nodes.range_slider import RangeSlider - - node = RangeSlider() - - result = node.process(min_value=0.0, max_value=10.0, value=3.25) - assert result == (3.25,) - - # Clamp above max - result_high = node.process(min_value=0.0, max_value=10.0, value=12.0) - assert result_high == (10.0,) - - # Reversed bounds should still work - result_reversed = node.process(min_value=5.0, max_value=-1.0, value=4.0) - assert result_reversed == (4.0,) - - # Equal bounds collapse to a fixed value - result_fixed = node.process(min_value=2.5, max_value=2.5, value=99.0) - assert result_fixed == (2.5,) - - print(" PASS\n") - - -def test_execution_engine_numeric_socket_coercion(): - print("=== Test: ExecutionEngine numeric socket coercion ===") - from backend.execution import ExecutionEngine - from backend.node_registry import register_node - - @register_node(display_name="Test Echo Int") - class TestEchoInt: - @classmethod - def INPUT_TYPES(cls): - return {"required": {"value": ("INT",)}} - - OUTPUTS = ( - ('INT', 'value'), - ) - FUNCTION = "process" - CATEGORY = "tests" - - def process(self, value): - return (value,) - - @register_node(display_name="Test Echo Float") - class TestEchoFloat: - @classmethod - def INPUT_TYPES(cls): - return {"required": {"value": ("FLOAT",)}} - - OUTPUTS = ( - ('FLOAT', 'value'), - ) - FUNCTION = "process" - CATEGORY = "tests" - - def process(self, value): - return (value,) - - engine = ExecutionEngine() - prompt = { - "1": { - "class_type": "Number", - "inputs": {"value": 3.6}, - }, - "2": { - "class_type": "TestEchoInt", - "inputs": {"value": ["1", 0]}, - }, - "3": { - "class_type": "TestEchoFloat", - "inputs": {"value": ["1", 0]}, - }, - } - - outputs = engine.execute(prompt) - assert outputs["2"] == (4,) - assert outputs["3"] == (3.6,) - - print(" PASS\n") - - -def test_execution_engine_caches_unchanged_nodes(): - print("=== Test: ExecutionEngine caches unchanged nodes ===") - from backend.execution import ExecutionEngine - from backend.node_registry import register_node - - @register_node(display_name="Test Cache Source") - class TestCacheSource: - calls = 0 - - @classmethod - def INPUT_TYPES(cls): - return {"required": {"value": ("FLOAT",)}} - - OUTPUTS = ( - ('FLOAT', 'value'), - ) - FUNCTION = "process" - CATEGORY = "tests" - - def process(self, value): - TestCacheSource.calls += 1 - return (float(value),) - - @register_node(display_name="Test Cache Downstream") - class TestCacheDownstream: - calls = 0 - - @classmethod - def INPUT_TYPES(cls): - return {"required": {"value": ("FLOAT",)}} - - OUTPUTS = ( - ('FLOAT', 'value'), - ) - FUNCTION = "process" - CATEGORY = "tests" - - def process(self, value): - TestCacheDownstream.calls += 1 - return (float(value) * 2.0,) - - TestCacheSource.calls = 0 - TestCacheDownstream.calls = 0 - - engine = ExecutionEngine() - prompt = { - "1": { - "class_type": "TestCacheSource", - "inputs": {"value": 2.5}, - }, - "2": { - "class_type": "TestCacheDownstream", - "inputs": {"value": ["1", 0]}, - }, - } - - first_timings = [] - first_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms))) - second_timings = [] - second_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms))) - - assert first_outputs["2"] == (5.0,) - assert second_outputs["2"] == (5.0,) - assert TestCacheSource.calls == 1 - assert TestCacheDownstream.calls == 1 - assert {node_id for node_id, _ in second_timings} == {"1", "2"} - assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings) - - print(" PASS\n") - - -def test_execution_engine_only_propagates_real_output_changes(): - print("=== Test: ExecutionEngine propagates only real upstream output changes ===") - from backend.execution import ExecutionEngine - from backend.node_registry import register_node - - @register_node(display_name="Test Quantized Source") - class TestQuantizedSource: - calls = 0 - - @classmethod - def INPUT_TYPES(cls): - return {"required": {"value": ("FLOAT",)}} - - OUTPUTS = ( - ('INT', 'value'), - ) - FUNCTION = "process" - CATEGORY = "tests" - - def process(self, value): - TestQuantizedSource.calls += 1 - return (int(round(float(value))),) - - @register_node(display_name="Test Quantized Downstream") - class TestQuantizedDownstream: - calls = 0 - - @classmethod - def INPUT_TYPES(cls): - return {"required": {"value": ("INT",)}} - - OUTPUTS = ( - ('FLOAT', 'value'), - ) - FUNCTION = "process" - CATEGORY = "tests" - - def process(self, value): - TestQuantizedDownstream.calls += 1 - return (float(value) + 0.5,) - - TestQuantizedSource.calls = 0 - TestQuantizedDownstream.calls = 0 - - engine = ExecutionEngine() - prompt = { - "1": { - "class_type": "TestQuantizedSource", - "inputs": {"value": 1.2}, - }, - "2": { - "class_type": "TestQuantizedDownstream", - "inputs": {"value": ["1", 0]}, - }, - } - - outputs_first = engine.execute(prompt) - assert outputs_first["2"] == (1.5,) - - prompt["1"]["inputs"]["value"] = 1.3 - outputs_second = engine.execute(prompt) - assert outputs_second["2"] == (1.5,) - - prompt["1"]["inputs"]["value"] = 2.2 - outputs_third = engine.execute(prompt) - assert outputs_third["2"] == (2.5,) - - assert TestQuantizedSource.calls == 3 - assert TestQuantizedDownstream.calls == 2 - - print(" PASS\n") - - -# ========================================================================= -# Analysis — Cursors -# ========================================================================= - -def test_line_cursors(): - print("=== Test: Cursors ===") - from backend.nodes.cursors import Cursors - - node = Cursors() - line_spec = Cursors.INPUT_TYPES()["required"]["line"] - assert line_spec[0] == "LINE" - assert line_spec[1]["accepted_types"] == ["DATA_FIELD"] - - # Create a simple linear ramp - line = np.linspace(0, 10, 100).astype(np.float64) - - # Capture overlay - overlays = [] - Cursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data) - Cursors._current_node_id = "test" - - table, coord_pair = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5) - assert isinstance(coord_pair, tuple) and len(coord_pair) == 2 - - # Should produce a 6-row table - assert len(table) == 6 - quantities = {row["quantity"] for row in table} - assert "A x" in quantities - assert "B x" in quantities - assert "dx" in quantities - assert "dy" in quantities - - # B should be at a later position than A - a_pos = next(r["value"] for r in table if r["quantity"] == "A x") - b_pos = next(r["value"] for r in table if r["quantity"] == "B x") - assert b_pos > a_pos - - # Delta Y should reflect the height difference along the ramp - dy = next(r["value"] for r in table if r["quantity"] == "dy") - assert dy > 0 # ramp goes upward - - # Overlay should have been broadcast - assert len(overlays) == 1 - assert overlays[0]["kind"] == "line_plot" - assert len(overlays[0]["line"]) == len(line) - assert len(overlays[0]["x_axis"]) == len(line) - assert 0.0 <= overlays[0]["x1"] <= 1.0 - assert 0.0 <= overlays[0]["x2"] <= 1.0 - - # With LineData input (which carries its own x_axis) - line_data = LineData(data=line, x_axis=np.linspace(0, 1, 100)) - table2, _ = node.process(line_data, x1=0.25, y1=0.5, x2=0.75, y2=0.5) - assert len(table2) == 6 - - # Field input should report dx/dy/dz and broadcast an image overlay - field = DataField( - data=np.arange(100, dtype=np.float64).reshape(10, 10), - xreal=2.0, - yreal=4.0, - si_unit_xy="um", - si_unit_z="nm", - ) - overlays.clear() - table3, _ = node.process(field, x1=0.2, y1=0.25, x2=0.7, y2=0.75) - assert len(table3) == 9 - field_rows = {row["quantity"]: row for row in table3} - assert field_rows["dx"]["unit"] == "um" - assert field_rows["dy"]["unit"] == "um" - assert field_rows["dz"]["unit"] == "nm" - assert np.isclose(field_rows["dx"]["value"], 1.0) - assert np.isclose(field_rows["dy"]["value"], 2.0) - assert len(overlays) == 1 - assert overlays[0]["kind"] == "cursor_points" - assert overlays[0]["image"].startswith("data:image/png;base64,") - - Cursors._broadcast_overlay_fn = None - print(" PASS\n") - - -# ========================================================================= -# Analysis — FFT2D / ACF / PSDF -# ========================================================================= - -def test_fft2d(): - print("=== Test: FFT2D ===") - from backend.nodes.fft_2d import FFT2D - - node = FFT2D() - - # Pure single-frequency signal: peak should appear at the right location - N = 64 - y, x = np.mgrid[0:N, 0:N] / N - freq = 5 - data = np.sin(2 * np.pi * freq * x) - field = make_field(data=data, xreal=1e-6, yreal=1e-6) - - # log_magnitude - spectrum, spec_mag, spec_phase, spec_psdf = node.process(field, windowing="none", level="none") - assert spectrum.data.shape == (N, N) - assert spectrum.domain == "frequency" - assert spectrum.si_unit_xy == "1/m" - # Peak should be symmetric about centre - centre = N // 2 - row = spectrum.data[centre, :] - peak_idx = np.argmax(row[centre + 1:]) + centre + 1 - assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}" - - # magnitude output - _, spec_mag, _, _ = node.process(field, windowing="hann", level="mean") - assert spec_mag.data.shape == (N, N) - assert np.all(spec_mag.data >= 0) - - # phase output - _, _, spec_phase, _ = node.process(field, windowing="none", level="none") - assert spec_phase.data.shape == (N, N) - assert spec_phase.data.min() >= -np.pi - 0.01 - assert spec_phase.data.max() <= np.pi + 0.01 - - # psdf output — units should reflect PSDF calibration - _, _, _, spec_psdf = node.process(field, windowing="hamming", level="plane") - assert spec_psdf.data.shape == (N, N) - assert np.all(spec_psdf.data >= 0) - assert "^2" in spec_psdf.si_unit_z - - # Constant field should have all energy at DC - const_field = make_field(data=np.ones((32, 32)) * 3.0) - _, spec_const, _, _ = node.process(const_field, windowing="none", level="none") - centre32 = 16 - dc_val = spec_const.data[centre32, centre32] - assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field" - - # Blackman windowing should also work without error - spec_bk, _, _, _ = node.process(field, windowing="blackman", level="none") - assert spec_bk.data.shape == (N, N) - - print(" PASS\n") - - -def test_acf(): - print("=== Test: ACF ===") - from backend.nodes.acf import ACF - - node = ACF() - data = np.array([ - [1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [2.0, 1.0, 0.0, -1.0], - [0.0, 1.0, 2.0, 3.0], - ], dtype=np.float64) - field = DataField(data=data, xreal=8.0, yreal=4.0, si_unit_xy="nm", si_unit_z="V") - - acf, = node.process(field, level="none") - assert acf.data.shape == (3, 3) - assert acf.domain == "spatial" - assert acf.si_unit_xy == "nm" - assert acf.si_unit_z == "V^2" - assert np.isclose(acf.xreal, 6.0) - assert np.isclose(acf.yreal, 3.0) - assert np.isclose(acf.xoff, -3.0) - assert np.isclose(acf.yoff, -1.5) - - expected = np.zeros((3, 3), dtype=np.float64) - for iy, dy in enumerate(range(-1, 2)): - for ix, dx in enumerate(range(-1, 2)): - y0a = max(0, dy) - y1a = min(data.shape[0], data.shape[0] + dy) - x0a = max(0, dx) - x1a = min(data.shape[1], data.shape[1] + dx) - lhs = data[y0a:y1a, x0a:x1a] - rhs = data[y0a - dy:y1a - dy, x0a - dx:x1a - dx] - expected[iy, ix] = float(np.mean(lhs * rhs)) - - assert np.allclose(acf.data, expected) - assert np.allclose(acf.data, acf.data[::-1, ::-1]) - print(" PASS\n") - - -def test_psdf_node(): - print("=== Test: PSDF ===") - from backend.nodes.fft_2d import FFT2D - from backend.nodes.psdf import PSDF - - field = DataField( - data=np.random.default_rng(17).standard_normal((64, 64)), - xreal=2.0e-6, - yreal=1.0e-6, - si_unit_xy="m", - si_unit_z="nm", - ) - - fft_node = FFT2D() - psdf_node = PSDF() - - fft_psdf = fft_node.process(field, windowing="hann", level="plane")[3] - psdf, = psdf_node.process(field, windowing="hann", level="plane") - assert np.allclose(psdf.data, fft_psdf.data) - assert psdf.data.shape == field.data.shape - assert psdf.domain == "frequency" - assert psdf.si_unit_xy == "1/m" - assert psdf.si_unit_z == "nm^2 m^2" - assert np.all(psdf.data >= 0.0) - - white = DataField( - data=np.random.default_rng(123).standard_normal((128, 128)), - xreal=1.0e-6, - yreal=1.0e-6, - si_unit_xy="m", - si_unit_z="m", - ) - psdf_white, = psdf_node.process(white, windowing="none", level="none") - variance = float(np.var(white.data)) - dk_x = psdf_white.xreal / psdf_white.xres - dk_y = psdf_white.yreal / psdf_white.yres - integral = float(np.sum(psdf_white.data) * dk_x * dk_y) - assert 0.8 < integral / variance < 1.2 - print(" PASS\n") - - -# ========================================================================= -# Analysis — Stats -# ========================================================================= - -def test_stats(): - print("=== Test: Stats ===") - from backend.nodes.stats import Stats - - node = Stats() - input_spec = Stats.INPUT_TYPES()["required"]["input"] - assert input_spec[0] == "DATA_FIELD" - assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "DATA_TABLE"] - - captured = [] - Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) - Stats._current_node_id = "test" - - line = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64) - result, = node.process(line, operation="mean", column="value") - assert np.isclose(result, 2.5) - assert captured[-1] == ("test", {"value": result}) - roughness, = node.process(line, operation="Rq", column="value") - assert np.isclose(roughness, np.sqrt(np.mean((line - line.mean()) ** 2))) - - table = DataTable([ - {"name": "a", "value": 3.0, "unit": "m", "other": 10.0}, - {"name": "b", "value": 7.0, "unit": "m", "other": 20.0}, - ]) - result, = node.process(table, operation="max", column="value") - assert result == 7.0 - assert captured[-1] == ("test", {"value": 7.0, "unit": "m"}) - count, = node.process(table, operation="count", column="other") - assert count == 2.0 - auto_column_range, = node.process(table, operation="range", column="") - assert auto_column_range == 4.0 - - field = make_field(data=np.array([[1.0, 5.0], [2.0, 4.0]], dtype=np.float64)) - result, = node.process(field, operation="range", column="value") - assert result == 4.0 - assert captured[-1] == ("test", {"value": 4.0, "unit": "m"}) - - image = np.array([[0, 10], [20, 30]], dtype=np.uint8) - result, = node.process(image, operation="avg", column="value") - assert np.isclose(result, 15.0) - assert captured[-1] == ("test", {"value": 15.0}) - - try: - node.process(table, operation="Rq", column="value") - raise AssertionError("Expected invalid TABLE operation to raise ValueError") - except ValueError: - pass - - try: - node.process([{"label": "only text"}], operation="max", column="label") - raise AssertionError("Expected non-numeric record-table input to raise ValueError") - except ValueError: - pass - - try: - node.process( - RecordTable([{"quantity": "min", "value": 1.0, "unit": "m"}]), - operation="max", - column="value", - ) - raise AssertionError("Expected measurement table input to raise ValueError") - except ValueError: - pass - - Stats._broadcast_value_fn = None - print(" PASS\n") - - -# ========================================================================= -# Display — View3D -# ========================================================================= - -def test_view3d(): - print("=== Test: View3D ===") - from backend.nodes.view_3d import View3D - from backend.data_types import ImageData, MeshModel - from backend.execution_context import active_node, execution_callbacks - import base64 - import io - from PIL import Image - - node = View3D() - field = make_field() - - captured = [] - mesh_callback = lambda nid, mesh: captured.append(mesh) - - preview_image = Image.new("RGB", (12, 10), (255, 0, 0)) - preview_buffer = io.BytesIO() - preview_image.save(preview_buffer, format="PNG") - viewport_snapshot = "data:image/png;base64," + base64.b64encode(preview_buffer.getvalue()).decode() - - with execution_callbacks(mesh=mesh_callback), active_node("test"): - result = node.render( - field, - colormap="viridis", - z_scale=2.0, - resolution=64, - make_solid=False, - camera_target_x=0.1, - camera_target_y=-0.2, - camera_target_z=0.3, - viewport_snapshot=viewport_snapshot, - ) - assert len(result) == 2 - assert isinstance(result[0], MeshModel) - assert isinstance(result[1], ImageData) - assert result[1].shape == (10, 12, 3) - assert np.all(result[1][0, 0] == np.array([255, 0, 0], dtype=np.uint8)) - assert result[1].metadata["annotation_context"]["si_unit_xy"] == field.si_unit_xy - assert result[1].metadata["viewport_camera"]["target_x"] == 0.1 - assert result[1].metadata["viewport_camera"]["target_y"] == -0.2 - assert result[1].metadata["viewport_camera"]["target_z"] == 0.3 - assert len(captured) == 1 - - mesh = captured[0] - assert "width" in mesh - assert "height" in mesh - assert "z_data" in mesh - assert "colors" in mesh - assert mesh["z_scale"] == 0.2 - assert mesh["width"] <= 64 - assert mesh["height"] <= 64 - assert mesh["camera_target_x"] == 0.1 - assert mesh["camera_target_y"] == -0.2 - assert mesh["camera_target_z"] == 0.3 - # z_min < z_max for non-constant data - assert mesh["z_min"] < mesh["z_max"] - - # Verify base64 data can be decoded - z_bytes = base64.b64decode(mesh["z_data"]) - assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32 - - colors_bytes = base64.b64decode(mesh["colors"]) - assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 # uint8 RGB - - # High-res input should be downsampled - big_field = make_field(shape=(256, 256)) - captured.clear() - with execution_callbacks(mesh=mesh_callback), active_node("test"): - node.render(big_field, colormap="hot", z_scale=1.0, resolution=64, make_solid=False) - assert captured[0]["width"] <= 64 - assert captured[0]["height"] <= 64 - - # Separate map input should affect colors without changing mesh geometry - mesh_field = make_field(data=np.zeros((64, 64), dtype=np.float64), xreal=2.0, yreal=3.0) - map_field = make_field(data=np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float64), (64, 1)), xreal=2.0, yreal=3.0) - captured.clear() - with execution_callbacks(mesh=mesh_callback), active_node("test"): - mapped_result = node.render(mesh_field, map_field=map_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) - mapped_mesh = captured[0] - assert mapped_mesh["x_range"] == [float(mesh_field.xoff), float(mesh_field.xoff + mesh_field.xreal)] - assert mapped_mesh["y_range"] == [float(mesh_field.yoff), float(mesh_field.yoff + mesh_field.yreal)] - assert np.isclose(mapped_mesh["surface_extent_x"] / mapped_mesh["surface_extent_y"], mesh_field.xreal / mesh_field.yreal) - mapped_z = np.frombuffer(base64.b64decode(mapped_mesh["z_data"]), dtype=np.float32) - assert np.allclose(mapped_z, 0.0) - mapped_colors = np.frombuffer(base64.b64decode(mapped_mesh["colors"]), dtype=np.uint8) - top_vertices = np.asarray(mapped_result[0].vertices, dtype=np.float32) - x_span = float(top_vertices[:, 0].max() - top_vertices[:, 0].min()) - y_span = float(top_vertices[:, 2].max() - top_vertices[:, 2].min()) - assert np.isclose(x_span / y_span, mesh_field.xreal / mesh_field.yreal) - - captured.clear() - with execution_callbacks(mesh=mesh_callback), active_node("test"): - node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=32, make_solid=False) - mesh_only = captured[0] - mesh_only_colors = np.frombuffer(base64.b64decode(mesh_only["colors"]), dtype=np.uint8) - assert not np.array_equal(mapped_colors, mesh_only_colors) - - # make_solid should add extra geometry beyond the top surface grid - solid_mesh = mapped_result[0] - assert isinstance(solid_mesh, MeshModel) - captured.clear() - with execution_callbacks(mesh=mesh_callback), active_node("test"): - solid_result = node.render(mesh_field, colormap="viridis", z_scale=1.0, resolution=16, make_solid=True) - assert len(solid_result[0].vertices) > 16 * 16 - assert len(solid_result[0].faces) > (15 * 15 * 2) - solid_payload = captured[0] - assert solid_payload["make_solid"] is True - assert "positions" in solid_payload - assert "indices" in solid_payload - assert "vertex_colors" in solid_payload - print(" PASS\n") - - -def test_save_generic(): - print("=== Test: Save ===") - from backend.nodes.save import Save - from backend.data_types import DataField, ImageData, LineData, RecordTable, MeshModel, DataTable - import tifffile - from PIL import Image as PILImage - - node = Save() - value_spec = node.INPUT_TYPES()["required"]["value"] - assert value_spec[0] == "DATA_FIELD" - assert value_spec[1]["accepted_types"] == [ - "IMAGE", - "ANNOTATION_SOURCE", - "LINE", - "RECORD_TABLE", - "DATA_TABLE", - "MESH_MODEL", - "FLOAT", - ] - format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"] - assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"] - - with tempfile.TemporaryDirectory() as tmpdir: - # Save scalar as TXT and JSON - node.save(filename="scalar", directory_path=tmpdir, format="TXT", value=3.5) - assert Path(tmpdir, "scalar.txt").read_text(encoding="utf-8").strip() == "3.5" - node.save(filename="scalar_json", directory_path=tmpdir, format="JSON", value=3.5) - assert json.loads(Path(tmpdir, "scalar_json.json").read_text(encoding="utf-8")) == {"value": 3.5} - - # Save line as CSV, NPZ, and JSON - line = LineData(data=np.array([1.0, 2.0, 3.0]), x_axis=np.array([0.0, 0.5, 1.0]), x_unit="um", y_unit="nm") - node.save(filename="profile", directory_path=tmpdir, format="CSV", value=line) - csv_text = Path(tmpdir, "profile.csv").read_text(encoding="utf-8") - assert "x,y,x_unit,y_unit" in csv_text - assert "um" in csv_text and "nm" in csv_text - node.save(filename="profile_npz", directory_path=tmpdir, format="NPZ", value=line) - line_npz = np.load(Path(tmpdir, "profile_npz.npz")) - assert np.allclose(line_npz["x"], line.x_axis) - assert np.allclose(line_npz["y"], line.data) - node.save(filename="profile_json", directory_path=tmpdir, format="JSON", value=line) - line_json = json.loads(Path(tmpdir, "profile_json.json").read_text(encoding="utf-8")) - assert line_json["x_unit"] == "um" - assert line_json["y_unit"] == "nm" - assert line_json["x"] == [0.0, 0.5, 1.0] - assert line_json["y"] == [1.0, 2.0, 3.0] - - # Save DATA_FIELD as TIFF, PNG, and NPZ - field = DataField( - data=np.array([[1.0, 2.0], [3.0, 4.5]], dtype=np.float64), - xreal=2e-6, - yreal=1e-6, - si_unit_xy="m", - si_unit_z="m", - colormap="viridis", - ) - node.save(filename="field_tiff", directory_path=tmpdir, format="TIFF", value=field) - field_tiff = tifffile.imread(Path(tmpdir, "field_tiff.tiff")) - assert field_tiff.shape == field.data.shape - assert field_tiff.dtype == np.float32 - assert np.allclose(field_tiff, field.data.astype(np.float32)) - - node.save(filename="field_png", directory_path=tmpdir, format="PNG", value=field) - field_png = np.asarray(PILImage.open(Path(tmpdir, "field_png.png"))) - assert field_png.shape == (2, 2, 3) - assert field_png.dtype == np.uint8 - - node.save(filename="field_npz", directory_path=tmpdir, format="NPZ", value=field) - field_npz = np.load(Path(tmpdir, "field_npz.npz")) - assert np.allclose(field_npz["field"], field.data) - - # Save IMAGE as PNG, TIFF, and NPZ - image = np.array( - [ - [[255, 0, 0], [0, 255, 0]], - [[0, 0, 255], [255, 255, 0]], - ], - dtype=np.uint8, - ) - node.save(filename="image_png", directory_path=tmpdir, format="PNG", value=image) - image_png = np.asarray(PILImage.open(Path(tmpdir, "image_png.png"))) - assert image_png.shape == image.shape - assert np.array_equal(image_png, image) - - node.save(filename="image_tiff", directory_path=tmpdir, format="TIFF", value=image) - image_tiff = tifffile.imread(Path(tmpdir, "image_tiff.tiff")) - assert image_tiff.shape == image.shape - assert image_tiff.dtype == np.uint8 - assert np.array_equal(image_tiff, image) - - node.save(filename="image_npz", directory_path=tmpdir, format="NPZ", value=image) - image_npz = np.load(Path(tmpdir, "image_npz.npz")) - assert np.array_equal(image_npz["image"], image) - - # Save ANNOTATION_SOURCE as PNG, TIFF, and NPZ - annotation_image = ImageData( - image, - metadata={"annotation_context": {"si_unit_xy": "um", "si_unit_z": "nm"}}, - ) - node.save(filename="annotation_png", directory_path=tmpdir, format="PNG", value=annotation_image) - annotation_png = np.asarray(PILImage.open(Path(tmpdir, "annotation_png.png"))) - assert annotation_png.shape == image.shape - assert np.array_equal(annotation_png, image) - - node.save(filename="annotation_tiff", directory_path=tmpdir, format="TIFF", value=annotation_image) - annotation_tiff = tifffile.imread(Path(tmpdir, "annotation_tiff.tiff")) - assert annotation_tiff.shape == image.shape - assert annotation_tiff.dtype == np.uint8 - assert np.array_equal(annotation_tiff, image) - - node.save(filename="annotation_npz", directory_path=tmpdir, format="NPZ", value=annotation_image) - annotation_npz = np.load(Path(tmpdir, "annotation_npz.npz")) - assert np.array_equal(annotation_npz["image"], image) - - # Save tables as CSV and JSON - measure_table = RecordTable([ - {"quantity": "Rq", "value": 1.23, "unit": "nm"}, - {"quantity": "Ra", "value": 0.98, "unit": "nm"}, - ]) - node.save(filename="measurements_csv", directory_path=tmpdir, format="CSV", value=measure_table) - measure_csv = Path(tmpdir, "measurements_csv.csv").read_text(encoding="utf-8") - assert "quantity,value,unit" in measure_csv - assert "Rq,1.23,nm" in measure_csv - node.save(filename="measurements_json", directory_path=tmpdir, format="JSON", value=measure_table) - assert json.loads(Path(tmpdir, "measurements_json.json").read_text(encoding="utf-8")) == list(measure_table) - - record_table = DataTable([ - {"label": "particle-1", "height": 12.0, "area": 44.0}, - {"label": "particle-2", "height": 8.0, "area": 21.0}, - ]) - node.save(filename="records_csv", directory_path=tmpdir, format="CSV", value=record_table) - record_csv = Path(tmpdir, "records_csv.csv").read_text(encoding="utf-8") - assert "label,height,area" in record_csv - assert "particle-1,12.0,44.0" in record_csv - node.save(filename="records_json", directory_path=tmpdir, format="JSON", value=record_table) - assert json.loads(Path(tmpdir, "records_json.json").read_text(encoding="utf-8")) == list(record_table) - - # Save mesh as OBJ and STL - mesh = MeshModel( - vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32), - faces=np.array([[0, 1, 2]], dtype=np.int32), - ) - node.save(filename="triangle", directory_path=tmpdir, format="OBJ", value=mesh) - obj_text = Path(tmpdir, "triangle.obj").read_text(encoding="utf-8") - assert "v 0.0 0.0 0.0" in obj_text - assert "f 1 2 3" in obj_text - - node.save(filename="triangle", directory_path=tmpdir, format="STL", value=mesh) - stl_text = Path(tmpdir, "triangle.stl").read_text(encoding="utf-8") - assert stl_text.startswith("solid argonode") - assert "facet normal" in stl_text - - try: - node.save(filename="triangle", directory_path=tmpdir, format="PNG", value=mesh) - assert False, "Mesh should only be saveable as OBJ or STL" - except ValueError: - pass - - try: - node.save(filename="field_bad", directory_path=tmpdir, format="CSV", value=field) - assert False, "DATA_FIELD should reject unsupported save formats" - except ValueError: - pass - - print(" PASS\n") - - -# ========================================================================= -# Run all tests -# ========================================================================= - -if __name__ == "__main__": - # Filters - test_gaussian_filter() - test_median_filter() - test_crop_resize_field() - test_rotate_field() - test_flip_field() - test_colormap_adjust() - test_edge_detect() - test_fft_filter_1d() - test_fft_filter_2d() - - # Level - test_plane_level() - test_poly_level() - test_fix_zero() - test_curvature() - test_line_correction() - test_scar_removal() - test_angle_measure() - - # Analysis - test_statistics() - test_height_histogram() - test_fractal_dimension() - test_cross_section() - test_line_cursors() - test_fft2d() - test_stats() - - # Mask - test_threshold_mask() - test_mask_morphology() - test_mask_invert() - test_mask_operations() - test_draw_mask() - - # Grains - test_grain_analysis() - - # I/O - test_load_file() - test_load_file_ibw() - test_load_file_npz() - test_load_file_not_found() - test_load_file_unsupported() - test_load_file_warning() - test_list_channels() - test_load_demo() - test_coordinate() - test_range_slider() - test_save_generic() - test_save_image() - - # Display - test_preview_image() - test_print_table() - test_value_display() - test_view3d() - - print("All tests passed!")