From e749d24cfe592961b1ae595707d154cd66477ce1 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Wed, 25 Mar 2026 01:18:32 -0700 Subject: [PATCH] split table into measurements and records, add units to value display --- backend/data_types.py | 8 ++ backend/execution.py | 9 +- backend/nodes/analysis.py | 135 ++++++++++++++---- backend/nodes/display.py | 74 ++++++++-- backend/nodes/particle.py | 6 +- backend/server.py | 16 ++- frontend/src/App.jsx | 68 ++++++++- frontend/src/CustomNode.jsx | 224 ++++++++++++++++++++++++++++-- frontend/src/styles.css | 30 ++++ frontend/src/workflowHydration.js | 47 ++++++- tests/test_nodes.py | 53 +++++-- 11 files changed, 590 insertions(+), 80 deletions(-) diff --git a/backend/data_types.py b/backend/data_types.py index a24fb2c..ec29b3f 100644 --- a/backend/data_types.py +++ b/backend/data_types.py @@ -19,6 +19,14 @@ import numpy as np COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain", "cividis", "magma", "copper", "afmhot") + +class RecordTable(list): + """Tabular rows with a shared schema, e.g. particle statistics.""" + + +class MeasureTable(list): + """Named scalar measurements, typically rows of quantity/value/unit.""" + @dataclass class DataField: data: np.ndarray # shape (yres, xres), dtype float64 diff --git a/backend/execution.py b/backend/execution.py index 4d39526..718fc40 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -50,7 +50,7 @@ class ExecutionEngine: on_table: Callable[[str, list], None] | None = None, on_mesh: Callable[[str, dict], None] | None = None, on_overlay: Callable[[str, str], None] | None = None, - on_value: Callable[[str, float], None] | None = None, + on_value: Callable[[str, Any], None] | None = None, on_warning: Callable[[str, str], None] | None = None, ) -> dict[str, tuple]: """ @@ -64,6 +64,7 @@ class ExecutionEngine: on_preview : called with (node_id, data_uri) when a display node runs on_table : called with (node_id, table_list) when PrintTable runs on_overlay : called with (node_id, data_uri) for interactive overlays + on_value : called with (node_id, scalar-payload) for scalar displays on_warning : called with (node_id, message) for node warnings Returns @@ -104,7 +105,7 @@ class ExecutionEngine: node_outputs[node_id] = result # Auto-preview: broadcast a thumbnail for any DATA_FIELD, - # IMAGE, or TABLE output so every node shows its result. + # IMAGE, or table-like output so every node shows its result. if on_preview or on_table: self._auto_preview(cls, node_id, result, on_preview, on_table) @@ -226,7 +227,7 @@ class ExecutionEngine: ) -> None: """ After every node executes, inspect its outputs and broadcast - a preview for the first DATA_FIELD, IMAGE, or TABLE found. + a preview for the first DATA_FIELD, IMAGE, or table-like output found. Skip nodes that broadcast their own custom preview. """ import numpy as np @@ -260,7 +261,7 @@ class ExecutionEngine: on_preview(node_id, preview) return - if type_name == "TABLE" and isinstance(value, list) and on_table: + if type_name in ("TABLE", "MEASURE_TABLE", "RECORD_TABLE") and isinstance(value, list) and on_table: on_table(node_id, value) return diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 1835849..f0e9721 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -12,7 +12,7 @@ from __future__ import annotations import numpy as np from typing import Callable from backend.node_registry import register_node -from backend.data_types import DataField, datafield_to_uint8, encode_preview +from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8, encode_preview # --------------------------------------------------------------------------- @@ -29,7 +29,7 @@ class StatisticsNode: } } - RETURN_TYPES = ("TABLE",) + RETURN_TYPES = ("MEASURE_TABLE",) RETURN_NAMES = ("stats",) FUNCTION = "process" CATEGORY = "analysis" @@ -45,7 +45,7 @@ class StatisticsNode: skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0 kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0 - table = [ + table = MeasureTable([ {"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z}, {"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z}, {"quantity": "mean", "value": mean, "unit": field.si_unit_z}, @@ -54,7 +54,7 @@ class StatisticsNode: {"quantity": "skewness", "value": skewness, "unit": ""}, {"quantity": "kurtosis", "value": kurtosis, "unit": ""}, {"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z}, - ] + ]) return (table,) @@ -78,7 +78,7 @@ class HeightHistogram: } } - RETURN_TYPES = ("TABLE",) + RETURN_TYPES = ("MEASURE_TABLE",) RETURN_NAMES = ("measurements",) FUNCTION = "process" CATEGORY = "analysis" @@ -147,14 +147,14 @@ class HeightHistogram: }, ) - table = [ + table = MeasureTable([ {"quantity": "A position", "value": xa, "unit": field.si_unit_z}, {"quantity": "A count", "value": ya, "unit": count_unit}, {"quantity": "B position", "value": xb, "unit": field.si_unit_z}, {"quantity": "B count", "value": yb, "unit": count_unit}, {"quantity": "delta X", "value": xb - xa, "unit": field.si_unit_z}, {"quantity": "delta Y", "value": yb - ya, "unit": count_unit}, - ] + ]) return (table,) @@ -181,7 +181,7 @@ class LineCursors: }, } - RETURN_TYPES = ("TABLE",) + RETURN_TYPES = ("MEASURE_TABLE",) RETURN_NAMES = ("measurement",) FUNCTION = "process" CATEGORY = "analysis" @@ -242,14 +242,14 @@ class LineCursors: ) # --- Output table --- - table = [ + table = MeasureTable([ {"quantity": "A position", "value": xa, "unit": ""}, {"quantity": "A value", "value": ya, "unit": ""}, {"quantity": "B position", "value": xb, "unit": ""}, {"quantity": "B value", "value": yb, "unit": ""}, {"quantity": "delta X", "value": xb - xa, "unit": ""}, {"quantity": "delta Y", "value": yb - ya, "unit": ""}, - ] + ]) return (table,) @@ -614,7 +614,7 @@ class LineMath: } } - RETURN_TYPES = ("TABLE",) + RETURN_TYPES = ("MEASURE_TABLE",) RETURN_NAMES = ("result",) FUNCTION = "process" CATEGORY = "analysis" @@ -627,12 +627,12 @@ class LineMath: z = np.asarray(line, dtype=np.float64).ravel() fn, unit = LINE_OPS[operation] value = fn(z) - table = [{"quantity": operation, "value": value, "unit": unit}] + table = MeasureTable([{"quantity": operation, "value": value, "unit": unit}]) return (table,) # --------------------------------------------------------------------------- -# TableMath — scalar measurement from a numeric TABLE column +# TableMath — scalar measurement from a numeric record-table column # --------------------------------------------------------------------------- TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = { @@ -663,9 +663,62 @@ ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = { } +def _square_unit(unit: str) -> str: + unit = str(unit or "").strip() + if not unit: + return "" + if any(token in unit for token in ("^", "(", ")", "/", "*", " ")): + return f"({unit})^2" + return f"{unit}^2" + + +def _apply_scalar_unit(base_unit: str, operation: str) -> str: + unit = str(base_unit or "").strip() + if operation == "count": + return "count" + if not unit: + return "" + if operation == "variance": + return _square_unit(unit) + return unit + + +def _common_table_unit(table: list, column: str) -> str: + candidates = [] + seen = set() + unit_key = f"{column}_unit" + + for row in table: + if not isinstance(row, dict): + continue + unit = None + if unit_key in row and isinstance(row.get(unit_key), str): + unit = row.get(unit_key) + elif column == "value" and isinstance(row.get("unit"), str): + unit = row.get("unit") + if unit is None: + continue + unit = unit.strip() + if not unit or unit in seen: + continue + seen.add(unit) + candidates.append(unit) + + if len(candidates) == 1: + return candidates[0] + return "" + + +def _scalar_payload(value: float, unit: str = "") -> dict: + payload = {"value": float(value)} + if isinstance(unit, str) and unit.strip(): + payload["unit"] = unit.strip() + return payload + + @register_node(display_name="Table Math") class TableMath: - """Compute a scalar reduction over one numeric column in a TABLE.""" + """Compute a scalar reduction over one numeric column in a record table.""" _broadcast_value_fn = None _current_node_id: str = "" @@ -674,7 +727,7 @@ class TableMath: def INPUT_TYPES(cls): return { "required": { - "table": ("TABLE",), + "table": ("RECORD_TABLE",), "column": ("STRING", { "default": "value", "choices_from_table_input": "table", @@ -688,13 +741,15 @@ class TableMath: FUNCTION = "process" CATEGORY = "analysis" DESCRIPTION = ( - "Compute a scalar reduction over one numeric TABLE column. " + "Compute a scalar reduction over one numeric record-table column. " "Useful for max, min, avg, median, sum, range, std, variance, and count." ) def process(self, table: list, column: str, operation: str) -> tuple: + if isinstance(table, MeasureTable): + raise ValueError("Table Math only accepts record tables, not measurement tables.") if not isinstance(table, list) or not table: - raise ValueError("Table Math requires a non-empty TABLE input.") + raise ValueError("Table Math requires a non-empty record table input.") column_name = resolve_table_column_name(table, column) values = extract_numeric_table_values(table, column_name) @@ -759,7 +814,7 @@ def resolve_table_column_name(table: list, column: str) -> str: @register_node(display_name="Stats") class Stats: - """Polymorphic scalar stats node for LINE, TABLE, DATA_FIELD, or IMAGE inputs.""" + """Polymorphic scalar stats node for LINE, RECORD_TABLE, DATA_FIELD, or IMAGE inputs.""" _broadcast_value_fn = None _current_node_id: str = "" @@ -773,14 +828,14 @@ class Stats: "default": "value", "choices_from_table_input": "input", "show_when_source_type": { - "input": ["TABLE"], + "input": ["RECORD_TABLE"], }, }), "operation": ("STRING", { "default": "mean", "choices_by_source_type": { "LINE": list(LINE_OPS.keys()), - "TABLE": list(TABLE_OPS.keys()), + "RECORD_TABLE": list(TABLE_OPS.keys()), "DATA_FIELD": list(ARRAY_OPS.keys()), "IMAGE": list(ARRAY_OPS.keys()), }, @@ -794,14 +849,14 @@ class Stats: FUNCTION = "process" CATEGORY = "analysis" DESCRIPTION = ( - "Compute a contextual scalar statistic from a LINE, TABLE, DATA_FIELD, or IMAGE. " + "Compute a contextual scalar statistic from a LINE, record table, DATA_FIELD, or IMAGE. " "The available operations adapt to the connected input type." ) def process(self, input, operation: str, column: str = "value") -> tuple: - source_type, values = self._resolve_input_values(input, column) + source_type, values, resolved_column = self._resolve_input_values(input, column) - if source_type == "TABLE": + if source_type == "RECORD_TABLE": ops = TABLE_OPS elif source_type == "LINE": ops = LINE_OPS @@ -815,29 +870,49 @@ class Stats: fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry result = fn(values) if Stats._broadcast_value_fn is not None: - Stats._broadcast_value_fn(Stats._current_node_id, result) + Stats._broadcast_value_fn( + Stats._current_node_id, + _scalar_payload(result, self._resolve_output_unit(input, source_type, resolved_column, operation)), + ) return (result,) - def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray]: + def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str: + if source_type == "DATA_FIELD" and isinstance(input_value, DataField): + return _apply_scalar_unit(input_value.si_unit_z, operation) + + if source_type == "LINE": + line_entry = LINE_OPS.get(operation) + explicit_unit = line_entry[1] if isinstance(line_entry, tuple) and len(line_entry) > 1 else "" + return _apply_scalar_unit(explicit_unit, operation) + + if source_type == "RECORD_TABLE" and isinstance(input_value, list) and column: + return _apply_scalar_unit(_common_table_unit(input_value, column), operation) + + return "" + + def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray, str | None]: if isinstance(input_value, DataField): values = np.asarray(input_value.data, dtype=np.float64) - return ("DATA_FIELD", values.ravel()) + return ("DATA_FIELD", values.ravel(), None) + + if isinstance(input_value, MeasureTable): + raise ValueError("Stats only accepts record tables, not measurement tables.") if isinstance(input_value, list): if not input_value: - raise ValueError("Stats requires a non-empty TABLE input.") + raise ValueError("Stats requires a non-empty record table input.") column_name = resolve_table_column_name(input_value, column) values = extract_numeric_table_values(input_value, column_name) if not values: raise ValueError(f"Column '{column_name}' has no numeric values.") - return ("TABLE", np.asarray(values, dtype=np.float64)) + return ("RECORD_TABLE", np.asarray(values, dtype=np.float64), column_name) if isinstance(input_value, np.ndarray): values = np.asarray(input_value, dtype=np.float64) if values.size == 0: raise ValueError("Stats requires a non-empty input.") if values.ndim == 1: - return ("LINE", values.ravel()) - return ("IMAGE", values.ravel()) + return ("LINE", values.ravel(), None) + return ("IMAGE", values.ravel(), None) raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}") diff --git a/backend/nodes/display.py b/backend/nodes/display.py index e3e4776..8d7f7bf 100644 --- a/backend/nodes/display.py +++ b/backend/nodes/display.py @@ -10,10 +10,55 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node from backend.data_types import ( - DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview, normalize_for_colormap, + DataField, MeasureTable, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview, normalize_for_colormap, ) +def _measurement_names(table: list) -> list[str]: + names = [] + for row in table: + if not isinstance(row, dict): + continue + quantity = row.get("quantity") + if isinstance(quantity, str) and quantity and quantity not in names: + names.append(quantity) + return names + + +def _measurement_entry(table: list, selection: str) -> dict: + names = _measurement_names(table) + if not names: + raise ValueError("Measurement table has no selectable rows.") + + target = selection if selection in names else names[0] + for row in table: + if isinstance(row, dict) and row.get("quantity") == target: + return row + + raise ValueError(f"Measurement '{target}' was not found.") + + +def _measurement_value(table: list, selection: str) -> float: + row = _measurement_entry(table, selection) + value = row.get("value") + if isinstance(value, bool): + raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") + try: + numeric = float(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") from exc + if np.isfinite(numeric): + return numeric + raise ValueError(f"Measurement '{row.get('quantity', selection)}' does not have a numeric value.") + + +def _scalar_payload(value: float, unit: str = "") -> dict: + payload = {"value": float(value)} + if isinstance(unit, str) and unit.strip(): + payload["unit"] = unit.strip() + return payload + + @register_node(display_name="Preview") class PreviewImage: @classmethod @@ -156,7 +201,7 @@ class PrintTable: def INPUT_TYPES(cls): return { "required": { - "table": ("TABLE",), + "table": ("ANY_TABLE",), } } @@ -164,7 +209,7 @@ class PrintTable: FUNCTION = "print_table" CATEGORY = "display" OUTPUT_NODE = True - DESCRIPTION = "Send a TABLE to the browser as a WebSocket message for display." + DESCRIPTION = "Send a measurement or record table to the browser as a WebSocket message for display." _broadcast_table_fn = None _current_node_id: str = "" @@ -181,7 +226,14 @@ class ValueDisplay: def INPUT_TYPES(cls): return { "required": { - "value": ("FLOAT",), + "value": ("VALUE_SOURCE",), + "measurement": ("STRING", { + "default": "", + "choices_from_measure_input": "value", + "show_when_source_type": { + "value": ["MEASURE_TABLE"], + }, + }), } } @@ -189,13 +241,19 @@ class ValueDisplay: RETURN_NAMES = ("value",) FUNCTION = "display_value" CATEGORY = "display" - DESCRIPTION = "Display a FLOAT in the graph and pass the same value through unchanged." + DESCRIPTION = "Display a FLOAT, or a selected numeric row from a measurement table, and pass the value through unchanged." _broadcast_value_fn = None _current_node_id: str = "" - def display_value(self, value: float) -> tuple: - numeric = float(value) + def display_value(self, value, measurement: str = "") -> tuple: + unit = "" + if isinstance(value, MeasureTable): + row = _measurement_entry(value, measurement) + numeric = _measurement_value(value, measurement) + unit = row.get("unit", "") if isinstance(row.get("unit"), str) else "" + else: + numeric = float(value) if ValueDisplay._broadcast_value_fn is not None: - ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, numeric) + ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, _scalar_payload(numeric, unit)) return (numeric,) diff --git a/backend/nodes/particle.py b/backend/nodes/particle.py index a346565..48fe1ad 100644 --- a/backend/nodes/particle.py +++ b/backend/nodes/particle.py @@ -8,7 +8,7 @@ Gwyddion equivalents: from __future__ import annotations import numpy as np from backend.node_registry import register_node -from backend.data_types import DataField +from backend.data_types import DataField, RecordTable # --------------------------------------------------------------------------- @@ -27,7 +27,7 @@ class ParticleAnalysis: } } - RETURN_TYPES = ("TABLE",) + RETURN_TYPES = ("RECORD_TABLE",) RETURN_NAMES = ("particle_stats",) FUNCTION = "process" CATEGORY = "particles" @@ -45,7 +45,7 @@ class ParticleAnalysis: pixel_area = field.dx * field.dy # m^2 per pixel - rows = [] + rows = RecordTable() for pid in range(1, n_particles + 1): particle_pixels = labeled == pid area_px = int(particle_pixels.sum()) diff --git a/backend/server.py b/backend/server.py index dee1192..1c69c7f 100644 --- a/backend/server.py +++ b/backend/server.py @@ -16,7 +16,7 @@ WebSocket message types sent to clients {"type": "executing", "data": {"node": "...", "prompt_id": "..."}} {"type": "preview", "data": {"node_id": "...", "image": "data:..."}} {"type": "table", "data": {"node_id": "...", "rows": [...]}} -{"type": "scalar", "data": {"node_id": "...", "value": 1.23}} +{"type": "scalar", "data": {"node_id": "...", "value": 1.23, "unit": "nm"}} {"type": "execution_error", "data": {"node_id": "...", "message": "..."}} {"type": "execution_complete", "data": {"prompt_id": "..."}} """ @@ -115,8 +115,18 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: def on_overlay(node_id: str, overlay_data) -> None: broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) - def on_value(node_id: str, value: float) -> None: - broadcast({"type": "scalar", "data": {"node_id": node_id, "value": value}}) + def on_value(node_id: str, payload) -> None: + if isinstance(payload, dict): + value = payload.get("value") + unit = payload.get("unit", "") + else: + value = payload + unit = "" + + data = {"node_id": node_id, "value": value} + if isinstance(unit, str) and unit.strip(): + data["unit"] = unit.strip() + broadcast({"type": "scalar", "data": data}) def on_warning(node_id: str, message: str) -> None: broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 5914c4a..dd7033f 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -4,7 +4,7 @@ import React, { import { ReactFlow, Background, Controls, MiniMap, useNodesState, useEdgesState, addEdge, useReactFlow, - ReactFlowProvider, getNodesBounds, getViewportForBounds, + ReactFlowProvider, getViewportForBounds, } from '@xyflow/react'; import '@xyflow/react/dist/style.css'; @@ -18,20 +18,28 @@ import { serializeWorkflowState } from './workflowSerialization'; // ── Constants ───────────────────────────────────────────────────────── -const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD', 'STATS_SOURCE']); +const DATA_TYPES = new Set([ + 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', + 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', +]); const SOCKET_COMPATIBILITY = { - STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE']), + STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']), + ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']), + VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']), }; const TYPE_COLORS = { DATA_FIELD: '#ff002f', IMAGE: '#00ff08a0', LINE: '#ffbe5c', - TABLE: '#35e2fd', + MEASURE_TABLE:'#35e2fd', + RECORD_TABLE:'#fbbf24', + ANY_TABLE: '#67e8f9', COORD: '#e91ed1', FLOAT: '#7dd3fc', STATS_SOURCE:'#c084fc', + VALUE_SOURCE:'#60a5fa', }; const NODE_TYPES = { custom: CustomNode }; @@ -56,6 +64,46 @@ function socketTypesCompatible(sourceType, targetType) { return !!accepted?.has(sourceType); } +function getRenderedNodeBounds(nodes) { + let minX = Infinity; + let minY = Infinity; + let maxX = -Infinity; + let maxY = -Infinity; + let found = false; + + for (const node of nodes) { + const selectorId = typeof CSS !== 'undefined' && typeof CSS.escape === 'function' + ? CSS.escape(String(node.id)) + : String(node.id); + const el = document.querySelector(`.react-flow__node[data-id="${selectorId}"]`); + const width = el?.offsetWidth || node.measured?.width || node.width || 0; + const height = el?.offsetHeight || node.measured?.height || node.height || 0; + const x = node.positionAbsolute?.x ?? node.position?.x ?? 0; + const y = node.positionAbsolute?.y ?? node.position?.y ?? 0; + + if (!Number.isFinite(width) || !Number.isFinite(height) || width <= 0 || height <= 0) { + continue; + } + + minX = Math.min(minX, x); + minY = Math.min(minY, y); + maxX = Math.max(maxX, x + width); + maxY = Math.max(maxY, y + height); + found = true; + } + + if (!found) { + return null; + } + + return { + x: minX, + y: minY, + width: Math.max(1, maxX - minX), + height: Math.max(1, maxY - minY), + }; +} + async function waitForImageElement(img) { if (img.complete && img.naturalWidth > 0) return; if (typeof img.decode === 'function') { @@ -463,7 +511,12 @@ function Flow() { updateNodeData(msg.data.node_id, { tableRows: msg.data.rows }); break; case 'scalar': - updateNodeData(msg.data.node_id, { scalarValue: msg.data.value }); + updateNodeData(msg.data.node_id, { + scalarValue: { + value: msg.data.value, + unit: typeof msg.data.unit === 'string' ? msg.data.unit : '', + }, + }); break; case 'mesh3d': updateNodeData(msg.data.node_id, { meshData: msg.data.mesh }); @@ -797,7 +850,10 @@ function Flow() { const allNodes = reactFlow.getNodes(); if (allNodes.length === 0) throw new Error('No nodes to capture'); - const bounds = getNodesBounds(allNodes); + const bounds = getRenderedNodeBounds(allNodes); + if (!bounds) { + throw new Error('Could not determine rendered node bounds'); + } const pad = 0.1; // 10% margin on each side const imageWidth = Math.ceil(bounds.width * (1 + pad * 2)); const imageHeight = Math.ceil(bounds.height * (1 + pad * 2)); diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index afbb001..d4038bb 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -8,17 +8,23 @@ const CropBoxOverlay = lazy(() => import('./CropBoxOverlay')); // ── Constants ───────────────────────────────────────────────────────── -const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD', 'STATS_SOURCE']); +const DATA_TYPES = new Set([ + 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', + 'COORD', 'STATS_SOURCE', 'VALUE_SOURCE', +]); const SOCKET_WIDGET_TYPES = new Set(['FLOAT']); const TYPE_COLORS = { DATA_FIELD: '#3a7abf', IMAGE: '#4caf50', LINE: '#ff9800', - TABLE: '#fdd835', + MEASURE_TABLE:'#35e2fd', + RECORD_TABLE:'#fbbf24', + ANY_TABLE: '#67e8f9', COORD: '#e91e63', FLOAT: '#7dd3fc', STATS_SOURCE:'#c084fc', + VALUE_SOURCE:'#60a5fa', }; const CAT_COLORS = { @@ -183,7 +189,42 @@ function getTableColumns(rows) { return columns; } -function formatTableCell(value) { +function getMeasurementChoices(rows) { + const names = []; + for (const row of rows || []) { + const quantity = row?.quantity; + if (typeof quantity === 'string' && quantity && !names.includes(quantity)) { + names.push(quantity); + } + } + return names; +} + +const SI_PREFIXES = [ + { exp: -24, prefix: 'y' }, + { exp: -21, prefix: 'z' }, + { exp: -18, prefix: 'a' }, + { exp: -15, prefix: 'f' }, + { exp: -12, prefix: 'p' }, + { exp: -9, prefix: 'n' }, + { exp: -6, prefix: 'u' }, + { exp: -3, prefix: 'm' }, + { exp: 0, prefix: '' }, + { exp: 3, prefix: 'k' }, + { exp: 6, prefix: 'M' }, + { exp: 9, prefix: 'G' }, + { exp: 12, prefix: 'T' }, + { exp: 15, prefix: 'P' }, + { exp: 18, prefix: 'E' }, + { exp: 21, prefix: 'Z' }, + { exp: 24, prefix: 'Y' }, +]; + +const PREFIXABLE_UNITS = new Set([ + 'm', 's', 'A', 'V', 'W', 'Hz', 'F', 'C', 'J', 'N', 'Pa', 'T', 'H', 'S', 'g', 'K', 'Ohm', 'ohm', 'Ω', +]); + +function formatNumericCell(value) { if (value == null) return ''; if (typeof value === 'number') { if (!Number.isFinite(value)) return String(value); @@ -196,6 +237,48 @@ function formatTableCell(value) { return String(value); } +function applySIPrefix(value, unit) { + if (typeof value !== 'number' || !Number.isFinite(value)) { + return { valueText: formatNumericCell(value), unitText: unit }; + } + if (typeof unit !== 'string' || !PREFIXABLE_UNITS.has(unit)) { + return { valueText: formatNumericCell(value), unitText: unit }; + } + if (value === 0) { + return { valueText: '0', unitText: unit }; + } + + const abs = Math.abs(value); + let exp = Math.floor(Math.log10(abs) / 3) * 3; + exp = Math.max(-24, Math.min(24, exp)); + + let scaled = value / (10 ** exp); + if (Math.abs(scaled) >= 999.5 && exp < 24) { + exp += 3; + scaled = value / (10 ** exp); + } + + const prefix = SI_PREFIXES.find((entry) => entry.exp === exp)?.prefix ?? ''; + return { + valueText: formatNumericCell(scaled), + unitText: `${prefix}${unit}`, + }; +} + +function formatTableCell(value) { + return formatNumericCell(value); +} + +function formatTableRowCell(row, column) { + if (column === 'value' && typeof row?.unit === 'string') { + return applySIPrefix(row?.value, row.unit).valueText; + } + if (column === 'unit' && typeof row?.unit === 'string') { + return applySIPrefix(row?.value, row.unit).unitText; + } + return formatTableCell(row?.[column]); +} + function formatScalarValue(value) { if (value == null || Number.isNaN(Number(value))) return '—'; const numeric = Number(value); @@ -206,6 +289,43 @@ function formatScalarValue(value) { return numeric.toFixed(abs >= 100 ? 2 : 4).replace(/\.?0+$/, ''); } +function getScalarPayload(scalarValue) { + if (typeof scalarValue === 'number') { + return Number.isFinite(scalarValue) ? { value: scalarValue, unit: '' } : null; + } + if (!scalarValue || typeof scalarValue !== 'object') return null; + const numeric = Number(scalarValue.value); + if (!Number.isFinite(numeric)) return null; + return { + value: numeric, + unit: typeof scalarValue.unit === 'string' ? scalarValue.unit : '', + }; +} + +function formatScalarDisplay(scalarValue) { + const payload = getScalarPayload(scalarValue); + if (!payload) return null; + + if (payload.unit) { + if (PREFIXABLE_UNITS.has(payload.unit)) { + const prefixed = applySIPrefix(payload.value, payload.unit); + return { + valueText: prefixed.valueText, + unitText: prefixed.unitText, + }; + } + return { + valueText: formatScalarValue(payload.value), + unitText: payload.unit, + }; + } + + return { + valueText: formatScalarValue(payload.value), + unitText: '', + }; +} + function getSourceTypeForInput(store, nodeId, inputName) { const targetHandle = `input::${inputName}::`; const edge = store.edges?.find((e) => e.target === nodeId && e.targetHandle?.startsWith(targetHandle)); @@ -221,6 +341,13 @@ function getSourceNodeForInput(store, nodeId, inputName) { return store.nodeLookup?.get(edge.source) || store.nodes?.find((n) => n.id === edge.source) || null; } +function getWidgetSourceInputName(opts) { + return opts?.source_type_input + || opts?.choices_from_table_input + || opts?.choices_from_measure_input + || Object.keys(opts?.show_when_source_type || {})[0]; +} + function widgetVisibleForSourceType(widget, sourceType) { const rules = widget?.opts?.show_when_source_type; if (!rules || typeof rules !== 'object') return true; @@ -233,15 +360,37 @@ function widgetVisibleForSourceType(widget, sourceType) { function NodeTable({ rows }) { const columns = getTableColumns(rows); if (columns.length === 0) return null; + const lowerColumns = columns.map((column) => String(column).toLowerCase()); + const hasMeasurementLayout = ( + lowerColumns.length === 3 + && lowerColumns[0] === 'quantity' + && lowerColumns[1] === 'value' + && lowerColumns[2] === 'unit' + ); + + const getColumnClass = (column) => { + const lower = String(column).toLowerCase(); + if (lower === 'value') return 'node-table-col-value'; + if (lower === 'unit') return 'node-table-col-unit'; + if (lower === 'quantity') return 'node-table-col-quantity'; + return ''; + }; return (
+ {hasMeasurementLayout && ( + + + + + + )} {columns.map((column) => ( - + ))} @@ -250,13 +399,17 @@ function NodeTable({ rows }) { {columns.map((column) => { const value = row?.[column]; + const displayValue = formatTableRowCell(row, column); return ( ); })} @@ -274,6 +427,7 @@ function NodeTable({ rows }) { function CustomNode({ id, data }) { const ctx = useContext(NodeContext); const def = data.definition; + const scalarDisplay = formatScalarDisplay(data.scalarValue); // Parse inputs into data handles and widgets const required = def.input.required || {}; @@ -418,15 +572,20 @@ function CustomNode({ id, data }) {
{data.warning}
)} - {typeof data.scalarValue === 'number' && ( + {scalarDisplay && (
Value
-
{formatScalarValue(data.scalarValue)}
+
+ {scalarDisplay.valueText} + {scalarDisplay.unitText && ( + {scalarDisplay.unitText} + )} +
)} {/* Widget rows */} - {widgets.filter((w) => widgetVisibleForSourceType(w, connectedSourceTypes?.[w.opts?.source_type_input || w.opts?.choices_from_table_input || Object.keys(w.opts?.show_when_source_type || {})[0]])).map((w) => ( + {widgets.filter((w) => widgetVisibleForSourceType(w, connectedSourceTypes?.[getWidgetSourceInputName(w.opts)])).map((w) => (
{w.socketType && ( { - const inputName = opts?.source_type_input - || opts?.choices_from_table_input - || Object.keys(opts?.show_when_source_type || {})[0]; + const inputName = getWidgetSourceInputName(opts); if (!inputName) return null; return getSourceTypeForInput(s, nodeId, inputName); }, @@ -568,7 +725,7 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile const tableInputName = opts?.choices_from_table_input; if (!tableInputName) return []; const sourceType = getSourceTypeForInput(s, nodeId, tableInputName); - if (sourceType !== 'TABLE') return []; + if (sourceType !== 'RECORD_TABLE') return []; const sourceNode = getSourceNodeForInput(s, nodeId, tableInputName); const rows = sourceNode?.data?.tableRows; return Array.isArray(rows) ? getTableColumns(rows) : []; @@ -576,6 +733,20 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile [nodeId, opts?.choices_from_table_input], ), ); + const dynamicMeasurementChoices = useStore( + useCallback( + (s) => { + const measurementInputName = opts?.choices_from_measure_input; + if (!measurementInputName) return []; + const sourceType = getSourceTypeForInput(s, nodeId, measurementInputName); + if (sourceType !== 'MEASURE_TABLE') return []; + const sourceNode = getSourceNodeForInput(s, nodeId, measurementInputName); + const rows = sourceNode?.data?.tableRows; + return Array.isArray(rows) ? getMeasurementChoices(rows) : []; + }, + [nodeId, opts?.choices_from_measure_input], + ), + ); const dynamicTypeChoices = (() => { const byType = opts?.choices_by_source_type; if (!byType) return []; @@ -600,6 +771,13 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile if (preferred != null) onChange(nodeId, name, preferred); }, [dynamicTableColumns, name, nodeId, onChange, opts?.choices_from_table_input, val]); + useEffect(() => { + if (!opts?.choices_from_measure_input || dynamicMeasurementChoices.length === 0) return; + const current = String(val ?? ''); + if (dynamicMeasurementChoices.includes(current)) return; + if (dynamicMeasurementChoices[0] != null) onChange(nodeId, name, dynamicMeasurementChoices[0]); + }, [dynamicMeasurementChoices, name, nodeId, onChange, opts?.choices_from_measure_input, val]); + useEffect(() => { if (dynamicTypeChoices.length === 0) return; const current = String(val ?? ''); @@ -661,6 +839,24 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile ); } + if (type === 'STRING' && opts?.choices_from_measure_input && dynamicMeasurementChoices.length > 0) { + const selected = dynamicMeasurementChoices.includes(String(val)) ? String(val) : dynamicMeasurementChoices[0]; + return ( + <> + + + + ); + } + if (type === 'FILE_PICKER') { return ( <> diff --git a/frontend/src/styles.css b/frontend/src/styles.css index 7171d8d..64c42aa 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -194,6 +194,20 @@ html, body, #root { overflow-wrap: anywhere; } +.node-value-box-number { + display: inline-block; +} + +.node-value-box-unit { + display: inline-block; + margin-left: 0.35em; + font-size: 0.58em; + font-weight: 600; + letter-spacing: 0.03em; + color: rgba(224, 242, 254, 0.82); + vertical-align: baseline; +} + /* ── I/O rows ──────────────────────────────────────────────────────── */ .io-row { display: flex; @@ -564,6 +578,8 @@ html, body, #root { font-family: "SF Mono", "Fira Code", monospace; font-size: 10px; color: #cbd5e1; + table-layout: auto; + font-variant-numeric: tabular-nums lining-nums; } .node-table-grid th, @@ -594,6 +610,20 @@ html, body, #root { border-bottom: none; } +.node-table-col-quantity { + width: 46%; +} + +.node-table-col-value { + width: 32%; + text-align: right !important; +} + +.node-table-col-unit { + width: 22%; + text-align: left; +} + .node-table-num { text-align: right !important; } diff --git a/frontend/src/workflowHydration.js b/frontend/src/workflowHydration.js index 92bdecb..46039de 100644 --- a/frontend/src/workflowHydration.js +++ b/frontend/src/workflowHydration.js @@ -22,6 +22,43 @@ function mergeDefinition(nodeData, defs) { }; } +function getSocketType(inputDef) { + if (!inputDef) return null; + const [type] = Array.isArray(inputDef) ? inputDef : [inputDef]; + return Array.isArray(type) ? type[0] : type; +} + +function getInputType(definition, inputName) { + const required = definition?.input?.required || {}; + const optional = definition?.input?.optional || {}; + return getSocketType(required[inputName] ?? optional[inputName]); +} + +function remapLegacyHandle(handleId, kind, nodeData) { + if (typeof handleId !== 'string') return handleId; + + const parts = handleId.split('::'); + if (parts.length !== 3 || parts[2] !== 'TABLE') return handleId; + + if (kind === 'source' && parts[0] === 'output') { + const outputSlot = Number.parseInt(parts[1], 10); + const outputType = nodeData?.definition?.output?.[outputSlot]; + if (typeof outputType === 'string' && outputType !== 'TABLE') { + return `output::${outputSlot}::${outputType}`; + } + return handleId; + } + + if (kind === 'target' && parts[0] === 'input') { + const inputType = getInputType(nodeData?.definition, parts[1]); + if (typeof inputType === 'string' && inputType !== 'TABLE') { + return `input::${parts[1]}::${inputType}`; + } + } + + return handleId; +} + export function hydrateWorkflowState(data, defs = {}) { const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : []; const loadedEdges = Array.isArray(data?.edges) ? data.edges : []; @@ -43,11 +80,19 @@ export function hydrateWorkflowState(data, defs = {}) { }, })); + const nodeById = new Map(nodes.map((node) => [String(node.id), node.data])); + + const edges = loadedEdges.map((edge) => ({ + ...edge, + sourceHandle: remapLegacyHandle(edge.sourceHandle, 'source', nodeById.get(String(edge.source))), + targetHandle: remapLegacyHandle(edge.targetHandle, 'target', nodeById.get(String(edge.target))), + })); + const nextNodeId = Math.max(0, ...loadedNodes.map((node) => parseInt(node.id, 10) || 0)) + 1; return { nodes, - edges: loadedEdges, + edges, nextNodeId, }; } diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 70c614f..f85ca45 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -10,7 +10,7 @@ import tempfile import numpy as np sys.path.insert(0, ".") -from backend.data_types import DataField, datafield_to_uint8 +from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8 def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): @@ -899,12 +899,20 @@ def test_value_display(): node = ValueDisplay() captured = [] - ValueDisplay._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value)) + ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) ValueDisplay._current_node_id = "test" result = node.display_value(3.25) assert result == (3.25,) - assert captured == [("test", 3.25)] + assert captured == [("test", {"value": 3.25})] + + measurements = MeasureTable([ + {"quantity": "delta X", "value": 1.7e-7, "unit": "m"}, + {"quantity": "delta Y", "value": 463, "unit": "count"}, + ]) + result = node.display_value(measurements, measurement="delta X") + assert result == (1.7e-7,) + assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"}) ValueDisplay._broadcast_value_fn = None print(" PASS\n") @@ -1358,12 +1366,12 @@ def test_table_math(): captured = [] TableMath._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value)) TableMath._current_node_id = "test" - table = [ + table = RecordTable([ {"label": "a", "value": 1.0, "other": 10}, {"label": "b", "value": 5.0, "other": 20}, {"label": "c", "value": "3.0", "other": 30}, {"label": "d", "value": "bad", "other": 40}, - ] + ]) result, = node.process(table, column="value", operation="max") assert result == 5.0 @@ -1400,6 +1408,16 @@ def test_table_math(): except ValueError: pass + try: + node.process( + MeasureTable([{"quantity": "A position", "value": 1.0, "unit": "m"}]), + column="value", + operation="max", + ) + raise AssertionError("Expected measurement table input to raise ValueError") + except ValueError: + pass + TableMath._broadcast_value_fn = None print(" PASS\n") @@ -1415,28 +1433,31 @@ def test_stats(): node = Stats() captured = [] - Stats._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value)) + Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) Stats._current_node_id = "test" line = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64) result, = node.process(line, operation="mean", column="value") assert np.isclose(result, 2.5) - assert captured[-1] == ("test", result) + assert captured[-1] == ("test", {"value": result}) - table = [ - {"name": "a", "value": 3.0, "other": 10.0}, - {"name": "b", "value": 7.0, "other": 20.0}, - ] + table = RecordTable([ + {"name": "a", "value": 3.0, "unit": "m", "other": 10.0}, + {"name": "b", "value": 7.0, "unit": "m", "other": 20.0}, + ]) result, = node.process(table, operation="max", column="value") assert result == 7.0 + assert captured[-1] == ("test", {"value": 7.0, "unit": "m"}) field = make_field(data=np.array([[1.0, 5.0], [2.0, 4.0]], dtype=np.float64)) result, = node.process(field, operation="range", column="value") assert result == 4.0 + assert captured[-1] == ("test", {"value": 4.0, "unit": "m"}) image = np.array([[0, 10], [20, 30]], dtype=np.uint8) result, = node.process(image, operation="avg", column="value") assert np.isclose(result, 15.0) + assert captured[-1] == ("test", {"value": 15.0}) try: node.process(table, operation="Rq", column="value") @@ -1444,6 +1465,16 @@ def test_stats(): except ValueError: pass + try: + node.process( + MeasureTable([{"quantity": "min", "value": 1.0, "unit": "m"}]), + operation="max", + column="value", + ) + raise AssertionError("Expected measurement table input to raise ValueError") + except ValueError: + pass + Stats._broadcast_value_fn = None print(" PASS\n")
{column}{column}
- {formatTableCell(value)} + {displayValue}