diff --git a/backend/execution.py b/backend/execution.py index 7aa582b..4d39526 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -182,7 +182,7 @@ class ExecutionEngine: ) -> None: """Wire up broadcast callbacks on display node classes.""" from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay - from backend.nodes.analysis import CrossSection, LineCursors, TableMath + from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram from backend.nodes.modify import CropResizeField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine from backend.nodes.io import SaveImage, LoadFile @@ -196,6 +196,8 @@ class ExecutionEngine: PrintTable._broadcast_table_fn = on_table ValueDisplay._broadcast_value_fn = on_value TableMath._broadcast_value_fn = on_value + Stats._broadcast_value_fn = on_value + HeightHistogram._broadcast_overlay_fn = on_overlay CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay CropResizeField._broadcast_overlay_fn = on_overlay @@ -205,11 +207,11 @@ class ExecutionEngine: def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay - from backend.nodes.analysis import CrossSection, LineCursors, TableMath + from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram from backend.nodes.modify import CropResizeField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine from backend.nodes.io import LoadFile, SaveImage - if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, CrossSection, LineCursors, CropResizeField, + if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile, SaveImage): cls._current_node_id = node_id diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 05ba976..1835849 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -71,26 +71,91 @@ class HeightHistogram: "field": ("DATA_FIELD",), "n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}), "y_scale": (["linear", "log"],), + "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), } } - RETURN_TYPES = ("LINE", "LINE") - RETURN_NAMES = ("counts", "bin_centers") + RETURN_TYPES = ("TABLE",) + RETURN_NAMES = ("measurements",) FUNCTION = "process" CATEGORY = "analysis" DESCRIPTION = ( "Compute the height distribution histogram (DH). " "Use log scale to reveal small peaks next to a dominant background. " + "Outputs marker measurements while showing the histogram interactively in-node. " "Equivalent to gwy_data_field_dh." ) - def process(self, field: DataField, n_bins: int, y_scale: str = "linear") -> tuple: - counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins)) + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process( + self, + field: DataField, + n_bins: int, + y_scale: str = "linear", + x1: float = 0.25, + y1: float = 0.5, + x2: float = 0.75, + y2: float = 0.5, + ) -> tuple: + raw_counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins)) bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) - counts = counts.astype(np.float64) + counts = raw_counts.astype(np.float64) if y_scale == "log": counts = np.log10(1.0 + counts) - return (counts, bin_centers) + + x1 = float(np.clip(x1, 0.0, 1.0)) + x2 = float(np.clip(x2, 0.0, 0.0 + 1.0)) + + xmin = float(np.min(bin_centers)) if len(bin_centers) else 0.0 + xmax = float(np.max(bin_centers)) if len(bin_centers) else 1.0 + + def x_frac_to_idx(frac): + if len(bin_centers) <= 1: + return 0 + if xmax == xmin: + return 0 + target_x = xmin + frac * (xmax - xmin) + return int(np.argmin(np.abs(bin_centers - target_x))) + + idx_a = x_frac_to_idx(x1) + idx_b = x_frac_to_idx(x2) + xa = float(bin_centers[idx_a]) if len(bin_centers) else 0.0 + xb = float(bin_centers[idx_b]) if len(bin_centers) else 0.0 + ya = float(counts[idx_a]) if len(counts) else 0.0 + yb = float(counts[idx_b]) if len(counts) else 0.0 + count_unit = "count" if y_scale == "linear" else "log10(1+count)" + + if HeightHistogram._broadcast_overlay_fn is not None: + HeightHistogram._broadcast_overlay_fn( + HeightHistogram._current_node_id, + { + "kind": "line_plot", + "section_title": "Histogram", + "line": counts.tolist(), + "x_axis": bin_centers.astype(np.float64).tolist(), + "x1": float(np.clip(x1, 0.0, 1.0)), + "x2": float(np.clip(x2, 0.0, 1.0)), + "y1": float(y1), + "y2": float(y2), + "a_locked": False, + "b_locked": False, + }, + ) + + table = [ + {"quantity": "A position", "value": xa, "unit": field.si_unit_z}, + {"quantity": "A count", "value": ya, "unit": count_unit}, + {"quantity": "B position", "value": xb, "unit": field.si_unit_z}, + {"quantity": "B count", "value": yb, "unit": count_unit}, + {"quantity": "delta X", "value": xb - xa, "unit": field.si_unit_z}, + {"quantity": "delta Y", "value": yb - ya, "unit": count_unit}, + ] + return (table,) # --------------------------------------------------------------------------- @@ -164,6 +229,7 @@ class LineCursors: LineCursors._current_node_id, { "kind": "line_plot", + "section_title": "Line Cursors", "line": y.tolist(), "x_axis": x.tolist(), "x1": x1, @@ -582,6 +648,20 @@ TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = { "count": lambda values: float(len(values)), } +ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = { + "min": lambda values: float(np.min(values)), + "max": lambda values: float(np.max(values)), + "avg": lambda values: float(np.mean(values)), + "mean": lambda values: float(np.mean(values)), + "median": lambda values: float(np.median(values)), + "sum": lambda values: float(np.sum(values)), + "range": lambda values: float(np.max(values) - np.min(values)), + "std": lambda values: float(np.std(values)), + "variance": lambda values: float(np.var(values)), + "rms": lambda values: float(np.sqrt(np.mean(values * values))), + "count": lambda values: float(values.size), +} + @register_node(display_name="Table Math") class TableMath: @@ -616,8 +696,8 @@ class TableMath: if not isinstance(table, list) or not table: raise ValueError("Table Math requires a non-empty TABLE input.") - column_name = self._resolve_column_name(table, column) - values = self._extract_numeric_values(table, column_name) + column_name = resolve_table_column_name(table, column) + values = extract_numeric_table_values(table, column_name) if not values: raise ValueError(f"Column '{column_name}' has no numeric values.") @@ -630,46 +710,134 @@ class TableMath: TableMath._broadcast_value_fn(TableMath._current_node_id, result) return (result,) - def _resolve_column_name(self, table: list, column: str) -> str: - requested = str(column or "").strip() - if requested: - return requested - if self._extract_numeric_values(table, "value"): - return "value" +def extract_numeric_table_values(table: list, column: str) -> list[float]: + values = [] + for row in table: + if not isinstance(row, dict) or column not in row: + continue + value = row[column] + if isinstance(value, bool): + continue + try: + numeric = float(value) + except (TypeError, ValueError): + continue + if np.isfinite(numeric): + values.append(numeric) + return values - numeric_columns = [] - seen = set() - for row in table: - if not isinstance(row, dict): - continue - for key in row.keys(): - if key in seen: - continue - seen.add(key) - if self._extract_numeric_values(table, key): - numeric_columns.append(key) - if len(numeric_columns) == 1: - return numeric_columns[0] - if not numeric_columns: - raise ValueError("Table Math could not find any numeric columns in the input table.") - raise ValueError( - "Table Math found multiple numeric columns; set the column name explicitly." - ) +def resolve_table_column_name(table: list, column: str) -> str: + requested = str(column or "").strip() + if requested: + return requested - def _extract_numeric_values(self, table: list, column: str) -> list[float]: - values = [] - for row in table: - if not isinstance(row, dict) or column not in row: + if extract_numeric_table_values(table, "value"): + return "value" + + numeric_columns = [] + seen = set() + for row in table: + if not isinstance(row, dict): + continue + for key in row.keys(): + if key in seen: continue - value = row[column] - if isinstance(value, bool): - continue - try: - numeric = float(value) - except (TypeError, ValueError): - continue - if np.isfinite(numeric): - values.append(numeric) - return values + seen.add(key) + if extract_numeric_table_values(table, key): + numeric_columns.append(key) + + if len(numeric_columns) == 1: + return numeric_columns[0] + if not numeric_columns: + raise ValueError("Table Math could not find any numeric columns in the input table.") + raise ValueError( + "Table Math found multiple numeric columns; set the column name explicitly." + ) + + +@register_node(display_name="Stats") +class Stats: + """Polymorphic scalar stats node for LINE, TABLE, DATA_FIELD, or IMAGE inputs.""" + + _broadcast_value_fn = None + _current_node_id: str = "" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input": ("STATS_SOURCE",), + "column": ("STRING", { + "default": "value", + "choices_from_table_input": "input", + "show_when_source_type": { + "input": ["TABLE"], + }, + }), + "operation": ("STRING", { + "default": "mean", + "choices_by_source_type": { + "LINE": list(LINE_OPS.keys()), + "TABLE": list(TABLE_OPS.keys()), + "DATA_FIELD": list(ARRAY_OPS.keys()), + "IMAGE": list(ARRAY_OPS.keys()), + }, + "source_type_input": "input", + }), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "analysis" + DESCRIPTION = ( + "Compute a contextual scalar statistic from a LINE, TABLE, DATA_FIELD, or IMAGE. " + "The available operations adapt to the connected input type." + ) + + def process(self, input, operation: str, column: str = "value") -> tuple: + source_type, values = self._resolve_input_values(input, column) + + if source_type == "TABLE": + ops = TABLE_OPS + elif source_type == "LINE": + ops = LINE_OPS + else: + ops = ARRAY_OPS + + if operation not in ops: + raise ValueError(f"Operation '{operation}' is not valid for {source_type} input.") + + op_entry = ops[operation] + fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry + result = fn(values) + if Stats._broadcast_value_fn is not None: + Stats._broadcast_value_fn(Stats._current_node_id, result) + return (result,) + + def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray]: + if isinstance(input_value, DataField): + values = np.asarray(input_value.data, dtype=np.float64) + return ("DATA_FIELD", values.ravel()) + + if isinstance(input_value, list): + if not input_value: + raise ValueError("Stats requires a non-empty TABLE input.") + column_name = resolve_table_column_name(input_value, column) + values = extract_numeric_table_values(input_value, column_name) + if not values: + raise ValueError(f"Column '{column_name}' has no numeric values.") + return ("TABLE", np.asarray(values, dtype=np.float64)) + + if isinstance(input_value, np.ndarray): + values = np.asarray(input_value, dtype=np.float64) + if values.size == 0: + raise ValueError("Stats requires a non-empty input.") + if values.ndim == 1: + return ("LINE", values.ravel()) + return ("IMAGE", values.ravel()) + + raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}") diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index c0ead17..5914c4a 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -18,7 +18,11 @@ import { serializeWorkflowState } from './workflowSerialization'; // ── Constants ───────────────────────────────────────────────────────── -const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']); +const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD', 'STATS_SOURCE']); + +const SOCKET_COMPATIBILITY = { + STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE']), +}; const TYPE_COLORS = { DATA_FIELD: '#ff002f', @@ -27,6 +31,7 @@ const TYPE_COLORS = { TABLE: '#35e2fd', COORD: '#e91ed1', FLOAT: '#7dd3fc', + STATS_SOURCE:'#c084fc', }; const NODE_TYPES = { custom: CustomNode }; @@ -45,6 +50,12 @@ function getOutputSlot(handleId) { return parseInt(handleId.split('::')[1], 10); } +function socketTypesCompatible(sourceType, targetType) { + if (sourceType === targetType) return true; + const accepted = SOCKET_COMPATIBILITY[targetType]; + return !!accepted?.has(sourceType); +} + async function waitForImageElement(img) { if (img.complete && img.naturalWidth > 0) return; if (typeof img.decode === 'function') { @@ -220,11 +231,11 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti const allInputs = { ...req, ...opt }; const hasMatch = Object.values(allInputs).some((spec) => { const [type] = Array.isArray(spec) ? spec : [spec]; - return type === filterType; + return socketTypesCompatible(filterType, type); }); if (!hasMatch) continue; } else { - if (!def.output.includes(filterType)) continue; + if (!def.output.some((type) => socketTypesCompatible(type, filterType))) continue; } } const cat = def.category || 'uncategorized'; @@ -474,7 +485,7 @@ function Flow() { const isValidConnection = useCallback((connection) => { const srcType = getHandleType(connection.sourceHandle); const tgtType = getHandleType(connection.targetHandle); - return srcType === tgtType; + return socketTypesCompatible(srcType, tgtType); }, []); const onConnect = useCallback((params) => { @@ -667,10 +678,15 @@ function Flow() { const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) }; const inputName = Object.entries(allInputs).find(([, spec]) => { const [type] = Array.isArray(spec) ? spec : [spec]; - return type === filterType; + return socketTypesCompatible(filterType, type); })?.[0]; if (inputName) { - const targetHandle = `input::${inputName}::${filterType}`; + const targetType = (() => { + const spec = allInputs[inputName]; + const [type] = Array.isArray(spec) ? spec : [spec]; + return type; + })(); + const targetHandle = `input::${inputName}::${targetType}`; const color = TYPE_COLORS[filterType] || '#999'; setEdges((eds) => addEdge({ source: contextMenu.pendingNodeId, @@ -682,10 +698,11 @@ function Flow() { } } else { // Dragged from an input → connect from the first matching output on the new node - const outputIdx = def.output.indexOf(filterType); + const outputIdx = def.output.findIndex((type) => socketTypesCompatible(type, filterType)); if (outputIdx !== -1) { - const sourceHandle = `output::${outputIdx}::${filterType}`; - const color = TYPE_COLORS[filterType] || '#999'; + const outputType = def.output[outputIdx]; + const sourceHandle = `output::${outputIdx}::${outputType}`; + const color = TYPE_COLORS[outputType] || '#999'; setEdges((eds) => addEdge({ source: newNodeId, sourceHandle, diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 8deee74..afbb001 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -8,7 +8,7 @@ const CropBoxOverlay = lazy(() => import('./CropBoxOverlay')); // ── Constants ───────────────────────────────────────────────────────── -const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']); +const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD', 'STATS_SOURCE']); const SOCKET_WIDGET_TYPES = new Set(['FLOAT']); const TYPE_COLORS = { @@ -18,6 +18,7 @@ const TYPE_COLORS = { TABLE: '#fdd835', COORD: '#e91e63', FLOAT: '#7dd3fc', + STATS_SOURCE:'#c084fc', }; const CAT_COLORS = { @@ -205,6 +206,30 @@ function formatScalarValue(value) { return numeric.toFixed(abs >= 100 ? 2 : 4).replace(/\.?0+$/, ''); } +function getSourceTypeForInput(store, nodeId, inputName) { + const targetHandle = `input::${inputName}::`; + const edge = store.edges?.find((e) => e.target === nodeId && e.targetHandle?.startsWith(targetHandle)); + if (!edge?.sourceHandle) return null; + const parts = edge.sourceHandle.split('::'); + return parts[2] || null; +} + +function getSourceNodeForInput(store, nodeId, inputName) { + const targetHandle = `input::${inputName}::`; + const edge = store.edges?.find((e) => e.target === nodeId && e.targetHandle?.startsWith(targetHandle)); + if (!edge) return null; + return store.nodeLookup?.get(edge.source) || store.nodes?.find((n) => n.id === edge.source) || null; +} + +function widgetVisibleForSourceType(widget, sourceType) { + const rules = widget?.opts?.show_when_source_type; + if (!rules || typeof rules !== 'object') return true; + const inputName = Object.keys(rules)[0]; + const allowed = Array.isArray(rules[inputName]) ? rules[inputName] : []; + if (allowed.length === 0) return true; + return allowed.includes(sourceType); +} + function NodeTable({ rows }) { const columns = getTableColumns(rows); if (columns.length === 0) return null; @@ -290,6 +315,20 @@ function CustomNode({ id, data }) { ), ); + const connectedSourceTypes = useStore( + useCallback( + (s) => { + const sourceTypes = {}; + const allInputs = { ...required, ...optional }; + for (const name of Object.keys(allInputs)) { + sourceTypes[name] = getSourceTypeForInput(s, id, name); + } + return sourceTypes; + }, + [id, required, optional], + ), + ); + for (const [name, spec] of Object.entries(optional)) { const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; if (isProgressive && DATA_TYPES.has(type)) { @@ -320,6 +359,13 @@ function CustomNode({ id, data }) { const catColor = CAT_COLORS[def.category] || '#333'; const maxIORows = Math.max(dataInputs.length, outputs.length); + const hasInteractiveLineOverlay = data.overlay?.kind === 'line_plot' && hiddenWidgets.has('x1'); + const overlayTitle = data.overlay?.section_title + || (data.overlay?.kind === 'crop_box' + ? 'Crop' + : data.overlay?.kind === 'line_plot' + ? 'Line Plot' + : 'Cross Section'); return (
@@ -380,7 +426,7 @@ function CustomNode({ id, data }) { )} {/* Widget rows */} - {widgets.map((w) => ( + {widgets.filter((w) => widgetVisibleForSourceType(w, connectedSourceTypes?.[w.opts?.source_type_input || w.opts?.choices_from_table_input || Object.keys(w.opts?.show_when_source_type || {})[0]])).map((w) => (
{w.socketType && ( + Loading...
}> {data.overlay.kind === 'line_plot' ? ( { + const inputName = opts?.source_type_input + || opts?.choices_from_table_input + || Object.keys(opts?.show_when_source_type || {})[0]; + if (!inputName) return null; + return getSourceTypeForInput(s, nodeId, inputName); + }, + [nodeId, opts], + ), + ); const dynamicTableColumns = useStore( useCallback( (s) => { const tableInputName = opts?.choices_from_table_input; if (!tableInputName) return []; - const targetHandle = `input::${tableInputName}::TABLE`; - const edge = s.edges?.find((e) => e.target === nodeId && e.targetHandle === targetHandle); - if (!edge) return []; - const sourceNode = s.nodeLookup?.get(edge.source) || s.nodes?.find((n) => n.id === edge.source); + const sourceType = getSourceTypeForInput(s, nodeId, tableInputName); + if (sourceType !== 'TABLE') return []; + const sourceNode = getSourceNodeForInput(s, nodeId, tableInputName); const rows = sourceNode?.data?.tableRows; return Array.isArray(rows) ? getTableColumns(rows) : []; }, [nodeId, opts?.choices_from_table_input], ), ); + const dynamicTypeChoices = (() => { + const byType = opts?.choices_by_source_type; + if (!byType) return []; + if (dynamicSourceType) { + return Array.isArray(byType[dynamicSourceType]) ? byType[dynamicSourceType] : []; + } + const merged = []; + for (const choices of Object.values(byType)) { + if (!Array.isArray(choices)) continue; + for (const choice of choices) { + if (!merged.includes(choice)) merged.push(choice); + } + } + return merged; + })(); useEffect(() => { if (!opts?.choices_from_table_input || dynamicTableColumns.length === 0) return; @@ -528,6 +600,13 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile if (preferred != null) onChange(nodeId, name, preferred); }, [dynamicTableColumns, name, nodeId, onChange, opts?.choices_from_table_input, val]); + useEffect(() => { + if (dynamicTypeChoices.length === 0) return; + const current = String(val ?? ''); + if (dynamicTypeChoices.includes(current)) return; + onChange(nodeId, name, dynamicTypeChoices[0]); + }, [dynamicTypeChoices, name, nodeId, onChange, val]); + // Combo / enum — type itself is the array of options if (Array.isArray(type)) { return ( @@ -546,6 +625,24 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile ); } + if (type === 'STRING' && dynamicTypeChoices.length > 0) { + const selected = dynamicTypeChoices.includes(String(val)) ? String(val) : dynamicTypeChoices[0]; + return ( + <> + + + + ); + } + if (type === 'STRING' && opts?.choices_from_table_input && dynamicTableColumns.length > 0) { const selected = dynamicTableColumns.includes(String(val)) ? String(val) : dynamicTableColumns[0]; return ( diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 8694124..70c614f 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -481,17 +481,42 @@ def test_height_histogram(): data = np.linspace(0, 1, 1000).reshape(25, 40) field = make_field(data=data) - counts, bin_centers = node.process(field, n_bins=10, y_scale="linear") - assert len(counts) == 10 - assert len(bin_centers) == 10 - assert counts.dtype == np.float64 - # Total counts should equal number of pixels - assert counts.sum() == 1000 - # For uniform data, each bin should have ~100 counts - assert np.std(counts) < 10, f"Histogram not flat enough: std={np.std(counts)}" - # Bin centers should span the data range - assert bin_centers[0] > 0.0 - assert bin_centers[-1] < 1.0 + overlays = [] + HeightHistogram._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + HeightHistogram._current_node_id = "test" + + table, = node.process( + field, + n_bins=10, + y_scale="linear", + x1=0.2, + y1=0.5, + x2=0.8, + y2=0.5, + ) + measurements = {row["quantity"]: row for row in table} + assert "A position" in measurements + assert "A count" in measurements + assert "B position" in measurements + assert "B count" in measurements + assert "delta X" in measurements + assert "delta Y" in measurements + assert measurements["A count"]["unit"] == "count" + assert measurements["B count"]["unit"] == "count" + assert measurements["B position"]["value"] > measurements["A position"]["value"] + assert len(overlays) == 1 + assert overlays[0]["kind"] == "line_plot" + assert overlays[0]["section_title"] == "Histogram" + assert len(overlays[0]["line"]) == 10 + assert len(overlays[0]["x_axis"]) == 10 + assert np.isclose(overlays[0]["x1"], 0.2) + assert np.isclose(overlays[0]["x2"], 0.8) + assert np.isclose( + measurements["delta Y"]["value"], + measurements["B count"]["value"] - measurements["A count"]["value"], + ) + + HeightHistogram._broadcast_overlay_fn = None print(" PASS\n") @@ -1380,6 +1405,49 @@ def test_table_math(): print(" PASS\n") +# ========================================================================= +# Analysis — Stats +# ========================================================================= + +def test_stats(): + print("=== Test: Stats ===") + from backend.nodes.analysis import Stats + + node = Stats() + captured = [] + Stats._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value)) + Stats._current_node_id = "test" + + line = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64) + result, = node.process(line, operation="mean", column="value") + assert np.isclose(result, 2.5) + assert captured[-1] == ("test", result) + + table = [ + {"name": "a", "value": 3.0, "other": 10.0}, + {"name": "b", "value": 7.0, "other": 20.0}, + ] + result, = node.process(table, operation="max", column="value") + assert result == 7.0 + + field = make_field(data=np.array([[1.0, 5.0], [2.0, 4.0]], dtype=np.float64)) + result, = node.process(field, operation="range", column="value") + assert result == 4.0 + + image = np.array([[0, 10], [20, 30]], dtype=np.uint8) + result, = node.process(image, operation="avg", column="value") + assert np.isclose(result, 15.0) + + try: + node.process(table, operation="Rq", column="value") + raise AssertionError("Expected invalid TABLE operation to raise ValueError") + except ValueError: + pass + + Stats._broadcast_value_fn = None + print(" PASS\n") + + # ========================================================================= # Display — View3D # ========================================================================= @@ -1457,6 +1525,7 @@ if __name__ == "__main__": test_fft2d() test_line_math() test_table_math() + test_stats() # Mask test_threshold_mask()