fix table math column picker
This commit is contained in:
@@ -50,6 +50,7 @@ class ExecutionEngine:
|
||||
on_table: Callable[[str, list], None] | None = None,
|
||||
on_mesh: Callable[[str, dict], None] | None = None,
|
||||
on_overlay: Callable[[str, str], None] | None = None,
|
||||
on_value: Callable[[str, float], None] | None = None,
|
||||
on_warning: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, tuple]:
|
||||
"""
|
||||
@@ -73,7 +74,7 @@ class ExecutionEngine:
|
||||
node_outputs: dict[str, tuple] = {}
|
||||
|
||||
# Inject display callbacks before execution
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning)
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
|
||||
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
@@ -176,11 +177,12 @@ class ExecutionEngine:
|
||||
on_table: Callable | None,
|
||||
on_mesh: Callable | None = None,
|
||||
on_overlay: Callable | None = None,
|
||||
on_value: Callable | None = None,
|
||||
on_warning: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection, LineCursors
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.io import SaveImage, LoadFile
|
||||
@@ -192,6 +194,8 @@ class ExecutionEngine:
|
||||
MaskCombine._broadcast_fn = on_preview
|
||||
View3D._broadcast_mesh_fn = on_mesh
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
ValueDisplay._broadcast_value_fn = on_value
|
||||
TableMath._broadcast_value_fn = on_value
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
LineCursors._broadcast_overlay_fn = on_overlay
|
||||
CropResizeField._broadcast_overlay_fn = on_overlay
|
||||
@@ -200,12 +204,12 @@ class ExecutionEngine:
|
||||
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection, LineCursors
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.io import LoadFile, SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, CropResizeField,
|
||||
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, CrossSection, LineCursors, CropResizeField,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
|
||||
LoadFile, SaveImage):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
@@ -587,12 +587,18 @@ TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = {
|
||||
class TableMath:
|
||||
"""Compute a scalar reduction over one numeric column in a TABLE."""
|
||||
|
||||
_broadcast_value_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"table": ("TABLE",),
|
||||
"column": ("STRING", {"default": "value"}),
|
||||
"column": ("STRING", {
|
||||
"default": "value",
|
||||
"choices_from_table_input": "table",
|
||||
}),
|
||||
"operation": (list(TABLE_OPS.keys()),),
|
||||
}
|
||||
}
|
||||
@@ -618,7 +624,11 @@ class TableMath:
|
||||
op = TABLE_OPS.get(operation)
|
||||
if op is None:
|
||||
raise ValueError(f"Unsupported table operation: {operation}")
|
||||
return (op(np.asarray(values, dtype=np.float64)),)
|
||||
|
||||
result = op(np.asarray(values, dtype=np.float64))
|
||||
if TableMath._broadcast_value_fn is not None:
|
||||
TableMath._broadcast_value_fn(TableMath._current_node_id, result)
|
||||
return (result,)
|
||||
|
||||
def _resolve_column_name(self, table: list, column: str) -> str:
|
||||
requested = str(column or "").strip()
|
||||
|
||||
@@ -173,3 +173,29 @@ class PrintTable:
|
||||
if PrintTable._broadcast_table_fn is not None:
|
||||
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
|
||||
return ()
|
||||
|
||||
|
||||
@register_node(display_name="Value Display")
|
||||
class ValueDisplay:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
RETURN_NAMES = ("value",)
|
||||
FUNCTION = "display_value"
|
||||
CATEGORY = "display"
|
||||
DESCRIPTION = "Display a FLOAT in the graph and pass the same value through unchanged."
|
||||
|
||||
_broadcast_value_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def display_value(self, value: float) -> tuple:
|
||||
numeric = float(value)
|
||||
if ValueDisplay._broadcast_value_fn is not None:
|
||||
ValueDisplay._broadcast_value_fn(ValueDisplay._current_node_id, numeric)
|
||||
return (numeric,)
|
||||
|
||||
@@ -16,6 +16,7 @@ WebSocket message types sent to clients
|
||||
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
|
||||
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
|
||||
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
|
||||
{"type": "scalar", "data": {"node_id": "...", "value": 1.23}}
|
||||
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
|
||||
{"type": "execution_complete", "data": {"prompt_id": "..."}}
|
||||
"""
|
||||
@@ -114,6 +115,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
def on_overlay(node_id: str, overlay_data) -> None:
|
||||
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
|
||||
|
||||
def on_value(node_id: str, value: float) -> None:
|
||||
broadcast({"type": "scalar", "data": {"node_id": node_id, "value": value}})
|
||||
|
||||
def on_warning(node_id: str, message: str) -> None:
|
||||
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}})
|
||||
|
||||
@@ -260,6 +264,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
on_table=on_table,
|
||||
on_mesh=on_mesh,
|
||||
on_overlay=on_overlay,
|
||||
on_value=on_value,
|
||||
on_warning=on_warning,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -451,6 +451,9 @@ function Flow() {
|
||||
case 'table':
|
||||
updateNodeData(msg.data.node_id, { tableRows: msg.data.rows });
|
||||
break;
|
||||
case 'scalar':
|
||||
updateNodeData(msg.data.node_id, { scalarValue: msg.data.value });
|
||||
break;
|
||||
case 'mesh3d':
|
||||
updateNodeData(msg.data.node_id, { meshData: msg.data.mesh });
|
||||
break;
|
||||
@@ -628,6 +631,7 @@ function Flow() {
|
||||
tableRows: null,
|
||||
meshData: null,
|
||||
overlay: null,
|
||||
scalarValue: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
|
||||
import React, { useContext, useRef, useCallback, useState, useEffect, memo, lazy, Suspense } from 'react';
|
||||
import { Handle, Position, useStore } from '@xyflow/react';
|
||||
import LinePlotOverlay from './LinePlotOverlay';
|
||||
|
||||
@@ -195,6 +195,16 @@ function formatTableCell(value) {
|
||||
return String(value);
|
||||
}
|
||||
|
||||
function formatScalarValue(value) {
|
||||
if (value == null || Number.isNaN(Number(value))) return '—';
|
||||
const numeric = Number(value);
|
||||
if (!Number.isFinite(numeric)) return String(numeric);
|
||||
const abs = Math.abs(numeric);
|
||||
if (abs === 0) return '0';
|
||||
if ((abs > 0 && abs < 1e-3) || abs >= 1e5) return numeric.toExponential(4);
|
||||
return numeric.toFixed(abs >= 100 ? 2 : 4).replace(/\.?0+$/, '');
|
||||
}
|
||||
|
||||
function NodeTable({ rows }) {
|
||||
const columns = getTableColumns(rows);
|
||||
if (columns.length === 0) return null;
|
||||
@@ -362,6 +372,13 @@ function CustomNode({ id, data }) {
|
||||
<div className="node-warning">{data.warning}</div>
|
||||
)}
|
||||
|
||||
{typeof data.scalarValue === 'number' && (
|
||||
<div className="node-value-display">
|
||||
<div className="node-value-label">Value</div>
|
||||
<div className="node-value-box">{formatScalarValue(data.scalarValue)}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Widget rows */}
|
||||
{widgets.map((w) => (
|
||||
<div className={`widget-row${w.socketType ? ' widget-row-socket' : ''}`} key={w.name}>
|
||||
@@ -487,6 +504,29 @@ function CustomNode({ id, data }) {
|
||||
function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFileBrowser }) {
|
||||
const { name, type, opts } = widget;
|
||||
const val = value ?? opts?.default ?? '';
|
||||
const dynamicTableColumns = useStore(
|
||||
useCallback(
|
||||
(s) => {
|
||||
const tableInputName = opts?.choices_from_table_input;
|
||||
if (!tableInputName) return [];
|
||||
const targetHandle = `input::${tableInputName}::TABLE`;
|
||||
const edge = s.edges?.find((e) => e.target === nodeId && e.targetHandle === targetHandle);
|
||||
if (!edge) return [];
|
||||
const sourceNode = s.nodeLookup?.get(edge.source) || s.nodes?.find((n) => n.id === edge.source);
|
||||
const rows = sourceNode?.data?.tableRows;
|
||||
return Array.isArray(rows) ? getTableColumns(rows) : [];
|
||||
},
|
||||
[nodeId, opts?.choices_from_table_input],
|
||||
),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!opts?.choices_from_table_input || dynamicTableColumns.length === 0) return;
|
||||
const current = String(val ?? '');
|
||||
if (dynamicTableColumns.includes(current)) return;
|
||||
const preferred = dynamicTableColumns.includes('value') ? 'value' : dynamicTableColumns[0];
|
||||
if (preferred != null) onChange(nodeId, name, preferred);
|
||||
}, [dynamicTableColumns, name, nodeId, onChange, opts?.choices_from_table_input, val]);
|
||||
|
||||
// Combo / enum — type itself is the array of options
|
||||
if (Array.isArray(type)) {
|
||||
@@ -506,6 +546,24 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'STRING' && opts?.choices_from_table_input && dynamicTableColumns.length > 0) {
|
||||
const selected = dynamicTableColumns.includes(String(val)) ? String(val) : dynamicTableColumns[0];
|
||||
return (
|
||||
<>
|
||||
<label>{name}</label>
|
||||
<select
|
||||
className="nodrag"
|
||||
value={selected}
|
||||
onChange={(e) => onChange(nodeId, name, e.target.value)}
|
||||
>
|
||||
{dynamicTableColumns.map((column) => (
|
||||
<option key={column} value={column}>{column}</option>
|
||||
))}
|
||||
</select>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'FILE_PICKER') {
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -161,6 +161,39 @@ html, body, #root {
|
||||
border-bottom: 1px solid rgba(251, 191, 36, 0.2);
|
||||
}
|
||||
|
||||
.node-value-display {
|
||||
padding: 8px 10px 4px;
|
||||
}
|
||||
|
||||
.node-value-label {
|
||||
font-size: 9px;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.12em;
|
||||
text-transform: uppercase;
|
||||
color: #7dd3fc;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.node-value-box {
|
||||
padding: 10px 12px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid rgba(125, 211, 252, 0.45);
|
||||
background:
|
||||
linear-gradient(180deg, rgba(14, 116, 144, 0.2), rgba(8, 47, 73, 0.45)),
|
||||
linear-gradient(135deg, rgba(125, 211, 252, 0.08), rgba(56, 189, 248, 0.02));
|
||||
box-shadow:
|
||||
inset 0 1px 0 rgba(255, 255, 255, 0.05),
|
||||
0 8px 20px rgba(2, 132, 199, 0.14);
|
||||
color: #e0f2fe;
|
||||
font-size: 22px;
|
||||
font-weight: 700;
|
||||
line-height: 1.1;
|
||||
letter-spacing: 0.02em;
|
||||
text-align: center;
|
||||
font-variant-numeric: tabular-nums lining-nums;
|
||||
overflow-wrap: anywhere;
|
||||
}
|
||||
|
||||
/* ── I/O rows ──────────────────────────────────────────────────────── */
|
||||
.io-row {
|
||||
display: flex;
|
||||
|
||||
@@ -39,6 +39,7 @@ export function hydrateWorkflowState(data, defs = {}) {
|
||||
tableRows: null,
|
||||
meshData: null,
|
||||
overlay: null,
|
||||
scalarValue: null,
|
||||
},
|
||||
}));
|
||||
|
||||
|
||||
@@ -868,6 +868,23 @@ def test_print_table():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_value_display():
|
||||
print("=== Test: ValueDisplay ===")
|
||||
from backend.nodes.display import ValueDisplay
|
||||
|
||||
node = ValueDisplay()
|
||||
captured = []
|
||||
ValueDisplay._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value))
|
||||
ValueDisplay._current_node_id = "test"
|
||||
|
||||
result = node.display_value(3.25)
|
||||
assert result == (3.25,)
|
||||
assert captured == [("test", 3.25)]
|
||||
|
||||
ValueDisplay._broadcast_value_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — IBW multi-channel loading
|
||||
# =========================================================================
|
||||
@@ -1313,6 +1330,9 @@ def test_table_math():
|
||||
from backend.nodes.analysis import TableMath
|
||||
|
||||
node = TableMath()
|
||||
captured = []
|
||||
TableMath._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value))
|
||||
TableMath._current_node_id = "test"
|
||||
table = [
|
||||
{"label": "a", "value": 1.0, "other": 10},
|
||||
{"label": "b", "value": 5.0, "other": 20},
|
||||
@@ -1322,6 +1342,7 @@ def test_table_math():
|
||||
|
||||
result, = node.process(table, column="value", operation="max")
|
||||
assert result == 5.0
|
||||
assert captured[-1] == ("test", 5.0)
|
||||
|
||||
result, = node.process(table, column="value", operation="min")
|
||||
assert result == 1.0
|
||||
@@ -1354,6 +1375,8 @@ def test_table_math():
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
TableMath._broadcast_value_fn = None
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
@@ -1460,6 +1483,7 @@ if __name__ == "__main__":
|
||||
# Display
|
||||
test_preview_image()
|
||||
test_print_table()
|
||||
test_value_display()
|
||||
test_view3d()
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
Reference in New Issue
Block a user