fix table math column picker

This commit is contained in:
2026-03-25 00:01:24 -07:00
parent 44de72d31b
commit a65b7c5642
9 changed files with 174 additions and 9 deletions

View File

@@ -50,6 +50,7 @@ class ExecutionEngine:
on_table: Callable[[str, list], None] | None = None,
on_mesh: Callable[[str, dict], None] | None = None,
on_overlay: Callable[[str, str], None] | None = None,
on_value: Callable[[str, float], None] | None = None,
on_warning: Callable[[str, str], None] | None = None,
) -> dict[str, tuple]:
"""
@@ -73,7 +74,7 @@ class ExecutionEngine:
node_outputs: dict[str, tuple] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning)
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
for node_id in order:
node_def = prompt[node_id]
@@ -176,11 +177,12 @@ class ExecutionEngine:
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
on_value: Callable | None = None,
on_warning: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.analysis import CrossSection, LineCursors, TableMath
from backend.nodes.modify import CropResizeField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import SaveImage, LoadFile
@@ -192,6 +194,8 @@ class ExecutionEngine:
MaskCombine._broadcast_fn = on_preview
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
ValueDisplay._broadcast_value_fn = on_value
TableMath._broadcast_value_fn = on_value
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
CropResizeField._broadcast_overlay_fn = on_overlay
@@ -200,12 +204,12 @@ class ExecutionEngine:
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.analysis import CrossSection, LineCursors, TableMath
from backend.nodes.modify import CropResizeField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, CropResizeField,
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, CrossSection, LineCursors, CropResizeField,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
LoadFile, SaveImage):
cls._current_node_id = node_id

View File

@@ -587,12 +587,18 @@ TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = {
class TableMath:
"""Compute a scalar reduction over one numeric column in a TABLE."""
_broadcast_value_fn = None
_current_node_id: str = ""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("TABLE",),
"column": ("STRING", {"default": "value"}),
"column": ("STRING", {
"default": "value",
"choices_from_table_input": "table",
}),
"operation": (list(TABLE_OPS.keys()),),
}
}
@@ -618,7 +624,11 @@ class TableMath:
op = TABLE_OPS.get(operation)
if op is None:
raise ValueError(f"Unsupported table operation: {operation}")
return (op(np.asarray(values, dtype=np.float64)),)
result = op(np.asarray(values, dtype=np.float64))
if TableMath._broadcast_value_fn is not None:
TableMath._broadcast_value_fn(TableMath._current_node_id, result)
return (result,)
def _resolve_column_name(self, table: list, column: str) -> str:
requested = str(column or "").strip()

View File

@@ -173,3 +173,29 @@ class PrintTable:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()
@register_node(display_name="Value Display")
class ValueDisplay:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT",),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "display_value"
CATEGORY = "display"
DESCRIPTION = "Display a FLOAT in the graph and pass the same value through unchanged."
_broadcast_value_fn = None
_current_node_id: str = ""
def display_value(self, value: float) -> tuple:
numeric = float(value)
if ValueDisplay._broadcast_value_fn is not None:
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, numeric)
return (numeric,)

View File

@@ -16,6 +16,7 @@ WebSocket message types sent to clients
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
{"type": "scalar", "data": {"node_id": "...", "value": 1.23}}
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
{"type": "execution_complete", "data": {"prompt_id": "..."}}
"""
@@ -114,6 +115,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
def on_overlay(node_id: str, overlay_data) -> None:
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
def on_value(node_id: str, value: float) -> None:
broadcast({"type": "scalar", "data": {"node_id": node_id, "value": value}})
def on_warning(node_id: str, message: str) -> None:
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}})
@@ -260,6 +264,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
on_table=on_table,
on_mesh=on_mesh,
on_overlay=on_overlay,
on_value=on_value,
on_warning=on_warning,
),
)

View File

@@ -451,6 +451,9 @@ function Flow() {
case 'table':
updateNodeData(msg.data.node_id, { tableRows: msg.data.rows });
break;
case 'scalar':
updateNodeData(msg.data.node_id, { scalarValue: msg.data.value });
break;
case 'mesh3d':
updateNodeData(msg.data.node_id, { meshData: msg.data.mesh });
break;
@@ -628,6 +631,7 @@ function Flow() {
tableRows: null,
meshData: null,
overlay: null,
scalarValue: null,
},
};

View File

@@ -1,4 +1,4 @@
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
import React, { useContext, useRef, useCallback, useState, useEffect, memo, lazy, Suspense } from 'react';
import { Handle, Position, useStore } from '@xyflow/react';
import LinePlotOverlay from './LinePlotOverlay';
@@ -195,6 +195,16 @@ function formatTableCell(value) {
return String(value);
}
function formatScalarValue(value) {
if (value == null || Number.isNaN(Number(value))) return '—';
const numeric = Number(value);
if (!Number.isFinite(numeric)) return String(numeric);
const abs = Math.abs(numeric);
if (abs === 0) return '0';
if ((abs > 0 && abs < 1e-3) || abs >= 1e5) return numeric.toExponential(4);
return numeric.toFixed(abs >= 100 ? 2 : 4).replace(/\.?0+$/, '');
}
function NodeTable({ rows }) {
const columns = getTableColumns(rows);
if (columns.length === 0) return null;
@@ -362,6 +372,13 @@ function CustomNode({ id, data }) {
<div className="node-warning">{data.warning}</div>
)}
{typeof data.scalarValue === 'number' && (
<div className="node-value-display">
<div className="node-value-label">Value</div>
<div className="node-value-box">{formatScalarValue(data.scalarValue)}</div>
</div>
)}
{/* Widget rows */}
{widgets.map((w) => (
<div className={`widget-row${w.socketType ? ' widget-row-socket' : ''}`} key={w.name}>
@@ -487,6 +504,29 @@ function CustomNode({ id, data }) {
function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFileBrowser }) {
const { name, type, opts } = widget;
const val = value ?? opts?.default ?? '';
const dynamicTableColumns = useStore(
useCallback(
(s) => {
const tableInputName = opts?.choices_from_table_input;
if (!tableInputName) return [];
const targetHandle = `input::${tableInputName}::TABLE`;
const edge = s.edges?.find((e) => e.target === nodeId && e.targetHandle === targetHandle);
if (!edge) return [];
const sourceNode = s.nodeLookup?.get(edge.source) || s.nodes?.find((n) => n.id === edge.source);
const rows = sourceNode?.data?.tableRows;
return Array.isArray(rows) ? getTableColumns(rows) : [];
},
[nodeId, opts?.choices_from_table_input],
),
);
useEffect(() => {
if (!opts?.choices_from_table_input || dynamicTableColumns.length === 0) return;
const current = String(val ?? '');
if (dynamicTableColumns.includes(current)) return;
const preferred = dynamicTableColumns.includes('value') ? 'value' : dynamicTableColumns[0];
if (preferred != null) onChange(nodeId, name, preferred);
}, [dynamicTableColumns, name, nodeId, onChange, opts?.choices_from_table_input, val]);
// Combo / enum — type itself is the array of options
if (Array.isArray(type)) {
@@ -506,6 +546,24 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile
);
}
if (type === 'STRING' && opts?.choices_from_table_input && dynamicTableColumns.length > 0) {
const selected = dynamicTableColumns.includes(String(val)) ? String(val) : dynamicTableColumns[0];
return (
<>
<label>{name}</label>
<select
className="nodrag"
value={selected}
onChange={(e) => onChange(nodeId, name, e.target.value)}
>
{dynamicTableColumns.map((column) => (
<option key={column} value={column}>{column}</option>
))}
</select>
</>
);
}
if (type === 'FILE_PICKER') {
return (
<>

View File

@@ -161,6 +161,39 @@ html, body, #root {
border-bottom: 1px solid rgba(251, 191, 36, 0.2);
}
.node-value-display {
padding: 8px 10px 4px;
}
.node-value-label {
font-size: 9px;
font-weight: 700;
letter-spacing: 0.12em;
text-transform: uppercase;
color: #7dd3fc;
margin-bottom: 5px;
}
.node-value-box {
padding: 10px 12px;
border-radius: 8px;
border: 1px solid rgba(125, 211, 252, 0.45);
background:
linear-gradient(180deg, rgba(14, 116, 144, 0.2), rgba(8, 47, 73, 0.45)),
linear-gradient(135deg, rgba(125, 211, 252, 0.08), rgba(56, 189, 248, 0.02));
box-shadow:
inset 0 1px 0 rgba(255, 255, 255, 0.05),
0 8px 20px rgba(2, 132, 199, 0.14);
color: #e0f2fe;
font-size: 22px;
font-weight: 700;
line-height: 1.1;
letter-spacing: 0.02em;
text-align: center;
font-variant-numeric: tabular-nums lining-nums;
overflow-wrap: anywhere;
}
/* ── I/O rows ──────────────────────────────────────────────────────── */
.io-row {
display: flex;

View File

@@ -39,6 +39,7 @@ export function hydrateWorkflowState(data, defs = {}) {
tableRows: null,
meshData: null,
overlay: null,
scalarValue: null,
},
}));

View File

@@ -868,6 +868,23 @@ def test_print_table():
print(" PASS\n")
def test_value_display():
print("=== Test: ValueDisplay ===")
from backend.nodes.display import ValueDisplay
node = ValueDisplay()
captured = []
ValueDisplay._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value))
ValueDisplay._current_node_id = "test"
result = node.display_value(3.25)
assert result == (3.25,)
assert captured == [("test", 3.25)]
ValueDisplay._broadcast_value_fn = None
print(" PASS\n")
# =========================================================================
# I/O — IBW multi-channel loading
# =========================================================================
@@ -1313,6 +1330,9 @@ def test_table_math():
from backend.nodes.analysis import TableMath
node = TableMath()
captured = []
TableMath._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value))
TableMath._current_node_id = "test"
table = [
{"label": "a", "value": 1.0, "other": 10},
{"label": "b", "value": 5.0, "other": 20},
@@ -1322,6 +1342,7 @@ def test_table_math():
result, = node.process(table, column="value", operation="max")
assert result == 5.0
assert captured[-1] == ("test", 5.0)
result, = node.process(table, column="value", operation="min")
assert result == 1.0
@@ -1354,6 +1375,8 @@ def test_table_math():
except ValueError:
pass
TableMath._broadcast_value_fn = None
print(" PASS\n")
@@ -1460,6 +1483,7 @@ if __name__ == "__main__":
# Display
test_preview_image()
test_print_table()
test_value_display()
test_view3d()
print("All tests passed!")