From a65b7c56425de45e1c3351dd3e8f9fdbbc84f0c4 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Wed, 25 Mar 2026 00:01:24 -0700 Subject: [PATCH] fix table math column picker --- backend/execution.py | 16 +++++---- backend/nodes/analysis.py | 14 ++++++-- backend/nodes/display.py | 26 ++++++++++++++ backend/server.py | 5 +++ frontend/src/App.jsx | 4 +++ frontend/src/CustomNode.jsx | 60 ++++++++++++++++++++++++++++++- frontend/src/styles.css | 33 +++++++++++++++++ frontend/src/workflowHydration.js | 1 + tests/test_nodes.py | 24 +++++++++++++ 9 files changed, 174 insertions(+), 9 deletions(-) diff --git a/backend/execution.py b/backend/execution.py index a0e01b6..7aa582b 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -50,6 +50,7 @@ class ExecutionEngine: on_table: Callable[[str, list], None] | None = None, on_mesh: Callable[[str, dict], None] | None = None, on_overlay: Callable[[str, str], None] | None = None, + on_value: Callable[[str, float], None] | None = None, on_warning: Callable[[str, str], None] | None = None, ) -> dict[str, tuple]: """ @@ -73,7 +74,7 @@ class ExecutionEngine: node_outputs: dict[str, tuple] = {} # Inject display callbacks before execution - self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning) + self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning) for node_id in order: node_def = prompt[node_id] @@ -176,11 +177,12 @@ class ExecutionEngine: on_table: Callable | None, on_mesh: Callable | None = None, on_overlay: Callable | None = None, + on_value: Callable | None = None, on_warning: Callable | None = None, ) -> None: """Wire up broadcast callbacks on display node classes.""" - from backend.nodes.display import PreviewImage, PrintTable, View3D - from backend.nodes.analysis import CrossSection, LineCursors + from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay + from backend.nodes.analysis import CrossSection, LineCursors, TableMath from backend.nodes.modify import CropResizeField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine from backend.nodes.io import SaveImage, LoadFile @@ -192,6 +194,8 @@ class ExecutionEngine: MaskCombine._broadcast_fn = on_preview View3D._broadcast_mesh_fn = on_mesh PrintTable._broadcast_table_fn = on_table + ValueDisplay._broadcast_value_fn = on_value + TableMath._broadcast_value_fn = on_value CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay CropResizeField._broadcast_overlay_fn = on_overlay @@ -200,12 +204,12 @@ class ExecutionEngine: def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" - from backend.nodes.display import PreviewImage, PrintTable, View3D - from backend.nodes.analysis import CrossSection, LineCursors + from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay + from backend.nodes.analysis import CrossSection, LineCursors, TableMath from backend.nodes.modify import CropResizeField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine from backend.nodes.io import LoadFile, SaveImage - if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, CropResizeField, + if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, CrossSection, LineCursors, CropResizeField, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile, SaveImage): cls._current_node_id = node_id diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 37f2dfc..05ba976 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -587,12 +587,18 @@ TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = { class TableMath: """Compute a scalar reduction over one numeric column in a TABLE.""" + _broadcast_value_fn = None + _current_node_id: str = "" + @classmethod def INPUT_TYPES(cls): return { "required": { "table": ("TABLE",), - "column": ("STRING", {"default": "value"}), + "column": ("STRING", { + "default": "value", + "choices_from_table_input": "table", + }), "operation": (list(TABLE_OPS.keys()),), } } @@ -618,7 +624,11 @@ class TableMath: op = TABLE_OPS.get(operation) if op is None: raise ValueError(f"Unsupported table operation: {operation}") - return (op(np.asarray(values, dtype=np.float64)),) + + result = op(np.asarray(values, dtype=np.float64)) + if TableMath._broadcast_value_fn is not None: + TableMath._broadcast_value_fn(TableMath._current_node_id, result) + return (result,) def _resolve_column_name(self, table: list, column: str) -> str: requested = str(column or "").strip() diff --git a/backend/nodes/display.py b/backend/nodes/display.py index e5e06f7..e3e4776 100644 --- a/backend/nodes/display.py +++ b/backend/nodes/display.py @@ -173,3 +173,29 @@ class PrintTable: if PrintTable._broadcast_table_fn is not None: PrintTable._broadcast_table_fn(PrintTable._current_node_id, table) return () + + +@register_node(display_name="Value Display") +class ValueDisplay: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT",), + } + } + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "display_value" + CATEGORY = "display" + DESCRIPTION = "Display a FLOAT in the graph and pass the same value through unchanged." + + _broadcast_value_fn = None + _current_node_id: str = "" + + def display_value(self, value: float) -> tuple: + numeric = float(value) + if ValueDisplay._broadcast_value_fn is not None: + ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, numeric) + return (numeric,) diff --git a/backend/server.py b/backend/server.py index cc453f4..dee1192 100644 --- a/backend/server.py +++ b/backend/server.py @@ -16,6 +16,7 @@ WebSocket message types sent to clients {"type": "executing", "data": {"node": "...", "prompt_id": "..."}} {"type": "preview", "data": {"node_id": "...", "image": "data:..."}} {"type": "table", "data": {"node_id": "...", "rows": [...]}} +{"type": "scalar", "data": {"node_id": "...", "value": 1.23}} {"type": "execution_error", "data": {"node_id": "...", "message": "..."}} {"type": "execution_complete", "data": {"prompt_id": "..."}} """ @@ -114,6 +115,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: def on_overlay(node_id: str, overlay_data) -> None: broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}}) + def on_value(node_id: str, value: float) -> None: + broadcast({"type": "scalar", "data": {"node_id": node_id, "value": value}}) + def on_warning(node_id: str, message: str) -> None: broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}}) @@ -260,6 +264,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: on_table=on_table, on_mesh=on_mesh, on_overlay=on_overlay, + on_value=on_value, on_warning=on_warning, ), ) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index c959a17..c0ead17 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -451,6 +451,9 @@ function Flow() { case 'table': updateNodeData(msg.data.node_id, { tableRows: msg.data.rows }); break; + case 'scalar': + updateNodeData(msg.data.node_id, { scalarValue: msg.data.value }); + break; case 'mesh3d': updateNodeData(msg.data.node_id, { meshData: msg.data.mesh }); break; @@ -628,6 +631,7 @@ function Flow() { tableRows: null, meshData: null, overlay: null, + scalarValue: null, }, }; diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 316f5e9..8deee74 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -1,4 +1,4 @@ -import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react'; +import React, { useContext, useRef, useCallback, useState, useEffect, memo, lazy, Suspense } from 'react'; import { Handle, Position, useStore } from '@xyflow/react'; import LinePlotOverlay from './LinePlotOverlay'; @@ -195,6 +195,16 @@ function formatTableCell(value) { return String(value); } +function formatScalarValue(value) { + if (value == null || Number.isNaN(Number(value))) return '—'; + const numeric = Number(value); + if (!Number.isFinite(numeric)) return String(numeric); + const abs = Math.abs(numeric); + if (abs === 0) return '0'; + if ((abs > 0 && abs < 1e-3) || abs >= 1e5) return numeric.toExponential(4); + return numeric.toFixed(abs >= 100 ? 2 : 4).replace(/\.?0+$/, ''); +} + function NodeTable({ rows }) { const columns = getTableColumns(rows); if (columns.length === 0) return null; @@ -362,6 +372,13 @@ function CustomNode({ id, data }) {
{data.warning}
)} + {typeof data.scalarValue === 'number' && ( +
+
Value
+
{formatScalarValue(data.scalarValue)}
+
+ )} + {/* Widget rows */} {widgets.map((w) => (
@@ -487,6 +504,29 @@ function CustomNode({ id, data }) { function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFileBrowser }) { const { name, type, opts } = widget; const val = value ?? opts?.default ?? ''; + const dynamicTableColumns = useStore( + useCallback( + (s) => { + const tableInputName = opts?.choices_from_table_input; + if (!tableInputName) return []; + const targetHandle = `input::${tableInputName}::TABLE`; + const edge = s.edges?.find((e) => e.target === nodeId && e.targetHandle === targetHandle); + if (!edge) return []; + const sourceNode = s.nodeLookup?.get(edge.source) || s.nodes?.find((n) => n.id === edge.source); + const rows = sourceNode?.data?.tableRows; + return Array.isArray(rows) ? getTableColumns(rows) : []; + }, + [nodeId, opts?.choices_from_table_input], + ), + ); + + useEffect(() => { + if (!opts?.choices_from_table_input || dynamicTableColumns.length === 0) return; + const current = String(val ?? ''); + if (dynamicTableColumns.includes(current)) return; + const preferred = dynamicTableColumns.includes('value') ? 'value' : dynamicTableColumns[0]; + if (preferred != null) onChange(nodeId, name, preferred); + }, [dynamicTableColumns, name, nodeId, onChange, opts?.choices_from_table_input, val]); // Combo / enum — type itself is the array of options if (Array.isArray(type)) { @@ -506,6 +546,24 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile ); } + if (type === 'STRING' && opts?.choices_from_table_input && dynamicTableColumns.length > 0) { + const selected = dynamicTableColumns.includes(String(val)) ? String(val) : dynamicTableColumns[0]; + return ( + <> + + + + ); + } + if (type === 'FILE_PICKER') { return ( <> diff --git a/frontend/src/styles.css b/frontend/src/styles.css index ffe36e7..7171d8d 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -161,6 +161,39 @@ html, body, #root { border-bottom: 1px solid rgba(251, 191, 36, 0.2); } +.node-value-display { + padding: 8px 10px 4px; +} + +.node-value-label { + font-size: 9px; + font-weight: 700; + letter-spacing: 0.12em; + text-transform: uppercase; + color: #7dd3fc; + margin-bottom: 5px; +} + +.node-value-box { + padding: 10px 12px; + border-radius: 8px; + border: 1px solid rgba(125, 211, 252, 0.45); + background: + linear-gradient(180deg, rgba(14, 116, 144, 0.2), rgba(8, 47, 73, 0.45)), + linear-gradient(135deg, rgba(125, 211, 252, 0.08), rgba(56, 189, 248, 0.02)); + box-shadow: + inset 0 1px 0 rgba(255, 255, 255, 0.05), + 0 8px 20px rgba(2, 132, 199, 0.14); + color: #e0f2fe; + font-size: 22px; + font-weight: 700; + line-height: 1.1; + letter-spacing: 0.02em; + text-align: center; + font-variant-numeric: tabular-nums lining-nums; + overflow-wrap: anywhere; +} + /* ── I/O rows ──────────────────────────────────────────────────────── */ .io-row { display: flex; diff --git a/frontend/src/workflowHydration.js b/frontend/src/workflowHydration.js index 342686e..92bdecb 100644 --- a/frontend/src/workflowHydration.js +++ b/frontend/src/workflowHydration.js @@ -39,6 +39,7 @@ export function hydrateWorkflowState(data, defs = {}) { tableRows: null, meshData: null, overlay: null, + scalarValue: null, }, })); diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 24bcab0..8694124 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -868,6 +868,23 @@ def test_print_table(): print(" PASS\n") +def test_value_display(): + print("=== Test: ValueDisplay ===") + from backend.nodes.display import ValueDisplay + + node = ValueDisplay() + captured = [] + ValueDisplay._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value)) + ValueDisplay._current_node_id = "test" + + result = node.display_value(3.25) + assert result == (3.25,) + assert captured == [("test", 3.25)] + + ValueDisplay._broadcast_value_fn = None + print(" PASS\n") + + # ========================================================================= # I/O — IBW multi-channel loading # ========================================================================= @@ -1313,6 +1330,9 @@ def test_table_math(): from backend.nodes.analysis import TableMath node = TableMath() + captured = [] + TableMath._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value)) + TableMath._current_node_id = "test" table = [ {"label": "a", "value": 1.0, "other": 10}, {"label": "b", "value": 5.0, "other": 20}, @@ -1322,6 +1342,7 @@ def test_table_math(): result, = node.process(table, column="value", operation="max") assert result == 5.0 + assert captured[-1] == ("test", 5.0) result, = node.process(table, column="value", operation="min") assert result == 1.0 @@ -1354,6 +1375,8 @@ def test_table_math(): except ValueError: pass + TableMath._broadcast_value_fn = None + print(" PASS\n") @@ -1460,6 +1483,7 @@ if __name__ == "__main__": # Display test_preview_image() test_print_table() + test_value_display() test_view3d() print("All tests passed!")