From 1b831cda5d8be115b5e1ecc53e4c2f5b8330c5f5 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Sat, 28 Mar 2026 13:56:22 -0700 Subject: [PATCH] refactor socket types --- backend/nodes/angle_measure.py | 5 +- backend/nodes/annotations.py | 5 +- backend/nodes/cursors.py | 5 +- backend/nodes/markup.py | 5 +- backend/nodes/preview_image.py | 5 +- backend/nodes/print_table.py | 4 +- backend/nodes/save.py | 13 ++- backend/nodes/save_image.py | 5 +- backend/nodes/stats.py | 4 +- backend/nodes/value_display.py | 4 +- frontend/src/App.jsx | 112 +++++++++++++++++---- frontend/src/CustomNode.jsx | 12 +-- frontend/src/constants.js | 74 ++++++++++---- frontend/src/executionGraph.js | 28 +++--- frontend/src/nodeWidgetDefaults.js | 6 +- frontend/tests/constants.test.mjs | 29 +++++- frontend/tests/executionGraph.test.mjs | 86 ++++++++++++++++ frontend/tests/nodeClipboard.test.mjs | 4 +- frontend/tests/nodeWidgetDefaults.test.mjs | 1 + tests/test_nodes.py | 38 +++++++ 20 files changed, 366 insertions(+), 79 deletions(-) diff --git a/backend/nodes/angle_measure.py b/backend/nodes/angle_measure.py index 93df65a..1cccad4 100644 --- a/backend/nodes/angle_measure.py +++ b/backend/nodes/angle_measure.py @@ -85,7 +85,10 @@ class AngleMeasure: def INPUT_TYPES(cls): return { "required": { - "input": ("ANNOTATION_SOURCE", {"label": "Input"}), + "input": ("ANNOTATION_SOURCE", { + "label": "Input", + "accepted_types": ["DATA_FIELD", "IMAGE"], + }), "color": ("STRING", {"default": ANGLE_DEFAULT_COLOR, "color_picker": True}), "stroke_width": ("FLOAT", { "default": 1.35, diff --git a/backend/nodes/annotations.py b/backend/nodes/annotations.py index 4deb199..6930c23 100644 --- a/backend/nodes/annotations.py +++ b/backend/nodes/annotations.py @@ -23,7 +23,10 @@ class Annotations: def INPUT_TYPES(cls): return { "required": { - "input": ("ANNOTATION_SOURCE", {"label": "Input"}), + "input": ("ANNOTATION_SOURCE", { + "label": "Input", + "accepted_types": ["DATA_FIELD", "IMAGE"], + }), "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), "show_scale_bar": ("BOOLEAN", {"default": True}), "show_color_map": ("BOOLEAN", {"default": True}), diff --git a/backend/nodes/cursors.py b/backend/nodes/cursors.py index faa15b7..aaf52dd 100644 --- a/backend/nodes/cursors.py +++ b/backend/nodes/cursors.py @@ -13,7 +13,10 @@ class Cursors: def INPUT_TYPES(cls): return { "required": { - "line": ("CURSOR_SOURCE", {"label": "input"}), + "line": ("LINE", { + "label": "input", + "accepted_types": ["DATA_FIELD"], + }), "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), diff --git a/backend/nodes/markup.py b/backend/nodes/markup.py index 3813571..5cf8a69 100644 --- a/backend/nodes/markup.py +++ b/backend/nodes/markup.py @@ -21,7 +21,10 @@ class Markup: def INPUT_TYPES(cls): return { "required": { - "input": ("ANNOTATION_SOURCE", {"label": "Input"}), + "input": ("ANNOTATION_SOURCE", { + "label": "Input", + "accepted_types": ["DATA_FIELD", "IMAGE"], + }), "shape": (["line", "rectangle", "circle", "arrow"], {"default": "arrow"}), "stroke_color": ("STRING", {"default": "#ff0000", "color_picker": True}), "stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}), diff --git a/backend/nodes/preview_image.py b/backend/nodes/preview_image.py index f860177..d39f207 100644 --- a/backend/nodes/preview_image.py +++ b/backend/nodes/preview_image.py @@ -22,7 +22,10 @@ class PreviewImage: "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), }, "optional": { - "input": ("ANNOTATION_SOURCE", {"label": "Input"}), + "input": ("ANNOTATION_SOURCE", { + "label": "Input", + "accepted_types": ["DATA_FIELD", "IMAGE"], + }), "colormap_map": ("COLORMAP", {"label": "colormap"}), } } diff --git a/backend/nodes/print_table.py b/backend/nodes/print_table.py index 72f1c44..30cc5ce 100644 --- a/backend/nodes/print_table.py +++ b/backend/nodes/print_table.py @@ -9,7 +9,9 @@ class PrintTable: def INPUT_TYPES(cls): return { "required": { - "table": ("ANY_TABLE",), + "table": ("MEASURE_TABLE", { + "accepted_types": ["RECORD_TABLE"], + }), } } diff --git a/backend/nodes/save.py b/backend/nodes/save.py index b7fce1c..ae01f32 100644 --- a/backend/nodes/save.py +++ b/backend/nodes/save.py @@ -29,7 +29,18 @@ class Save: "hide_when_input_connected": "directory", "top_socket_input": "directory", }), - "value": ("SAVE_VALUE", {"label": "value"}), + "value": ("DATA_FIELD", { + "label": "value", + "accepted_types": [ + "IMAGE", + "ANNOTATION_SOURCE", + "LINE", + "MEASURE_TABLE", + "RECORD_TABLE", + "MESH_MODEL", + "FLOAT", + ], + }), "format": ("STRING", { "default": "TIFF", "choices_by_source_type": { diff --git a/backend/nodes/save_image.py b/backend/nodes/save_image.py index f9365c1..072f198 100644 --- a/backend/nodes/save_image.py +++ b/backend/nodes/save_image.py @@ -17,7 +17,10 @@ class SaveImage: "directory": ("DIRECTORY", {"label": "directory"}), } for i in range(_MAX_SAVE_FIELDS): - optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"}) + optional[f"field_{i}"] = ("DATA_FIELD", { + "label": f"layer {i + 1}", + "accepted_types": ["IMAGE", "ANNOTATION_SOURCE"], + }) optional[f"layer_name_{i}"] = ("STRING", { "default": "", "placeholder": "name", diff --git a/backend/nodes/stats.py b/backend/nodes/stats.py index bcbf780..f11fba0 100644 --- a/backend/nodes/stats.py +++ b/backend/nodes/stats.py @@ -26,7 +26,9 @@ class Stats: def INPUT_TYPES(cls): return { "required": { - "input": ("STATS_SOURCE",), + "input": ("DATA_FIELD", { + "accepted_types": ["IMAGE", "LINE", "RECORD_TABLE"], + }), "column": ("STRING", { "default": "value", "choices_from_table_input": "input", diff --git a/backend/nodes/value_display.py b/backend/nodes/value_display.py index e21cabf..eb6cdca 100644 --- a/backend/nodes/value_display.py +++ b/backend/nodes/value_display.py @@ -11,7 +11,9 @@ class ValueDisplay: def INPUT_TYPES(cls): return { "required": { - "value": ("VALUE_SOURCE",), + "value": ("FLOAT", { + "accepted_types": ["MEASURE_TABLE"], + }), "measurement": ("STRING", { "default": "", "choices_from_measure_input": "value", diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 661f932..75f1a48 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -38,7 +38,11 @@ import { import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js'; import { - DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS, + getSpecTypeAndOptions, + socketSpecAcceptsType, + TYPE_COLORS, + CAT_COLORS, + CANVAS_COLORS, } from './constants'; const NODE_TYPES = { custom: CustomNode }; @@ -428,10 +432,54 @@ function compareMenuCategories(a, b) { return String(a?.name || '').localeCompare(String(b?.name || '')); } -function socketTypesCompatible(sourceType, targetType) { - if (sourceType === targetType) return true; - const accepted = SOCKET_COMPATIBILITY[targetType]; - return !!accepted?.has(sourceType); +function getResolvedHandleRef(nodeId, handleId) { + const proxy = parseGroupProxyHandle(handleId); + return { + nodeId: proxy?.nodeId || nodeId, + handleId: proxy?.realHandle || handleId, + type: proxy?.type || getHandleType(handleId), + }; +} + +function getNodeInputSpecForHandle(node, handleId) { + const definition = node?.data?.definition; + if (!definition?.input) return null; + const inputName = getInputName(handleId); + return definition.input.required?.[inputName] + || definition.input.optional?.[inputName] + || null; +} + +function socketTypesCompatible(sourceType, targetSpecOrType) { + return socketSpecAcceptsType(sourceType, targetSpecOrType); +} + +function outputTypeCanConnectToTarget(outputType, targetSpecOrType) { + if (socketTypesCompatible(outputType, targetSpecOrType)) { + return true; + } + return outputType === 'ANNOTATION_SOURCE' + && !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType) + && ( + socketTypesCompatible('DATA_FIELD', targetSpecOrType) + || socketTypesCompatible('IMAGE', targetSpecOrType) + ); +} + +function resolveOutputTypeForTarget(outputType, targetSpecOrType) { + if (outputType !== 'ANNOTATION_SOURCE') { + return outputType; + } + if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) { + return 'ANNOTATION_SOURCE'; + } + if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) { + return 'DATA_FIELD'; + } + if (socketTypesCompatible('IMAGE', targetSpecOrType)) { + return 'IMAGE'; + } + return 'ANNOTATION_SOURCE'; } function getRenderedNodeBounds(nodes) { @@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) { // ── Context menu component ──────────────────────────────────────────── -function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) { +function ContextMenu({ + x, + y, + nodeDefs, + onAdd, + onClose, + filterType, + filterSpec = null, + filterDirection, + selectedNodeCount = 0, + onCreateGroup = null, +}) { const [openCat, setOpenCat] = useState(null); const [search, setSearch] = useState(''); const menuRef = useRef(null); @@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti const opt = def.input.optional || {}; const allInputs = { ...req, ...opt }; const hasMatch = Object.values(allInputs).some((spec) => { - const [type] = Array.isArray(spec) ? spec : [spec]; - return socketTypesCompatible(filterType, type); + return socketTypesCompatible(filterType, spec); }); if (!hasMatch) continue; } else { const hasMatch = def.output.some((type) => - socketTypesCompatible(type, filterType) - || (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE')) + outputTypeCanConnectToTarget(type, filterSpec || filterType) ); if (!hasMatch) continue; } @@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti items: [...category.items].sort(compareMenuNodes), })) .sort(compareMenuCategories); - }, [nodeDefs, filterType, filterDirection]); + }, [nodeDefs, filterDirection, filterSpec, filterType]); // Flat filtered list for search const searchResults = useMemo(() => { @@ -1262,7 +1319,10 @@ function Flow() { setEdges((prev) => prev.filter((edge) => { if (edge.source !== nodeId) return true; - return socketTypesCompatible(outputType, getHandleType(edge.targetHandle)); + const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle); + const targetNode = reactFlow.getNode(resolvedTarget.nodeId); + const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type; + return socketTypesCompatible(outputType, targetSpec); })); }, [reactFlow, setEdges, setNodeOutputs]); @@ -1328,9 +1388,11 @@ function Flow() { const isValidConnection = useCallback((connection) => { const srcType = getConnectionHandleType(connection.sourceHandle); - const tgtType = getConnectionHandleType(connection.targetHandle); - return socketTypesCompatible(srcType, tgtType); - }, []); + const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle); + const targetNode = reactFlow.getNode(resolvedTarget.nodeId); + const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type; + return socketTypesCompatible(srcType, targetSpec); + }, [reactFlow]); const onConnect = useCallback((params) => { const sourceProxy = parseGroupProxyHandle(params.sourceHandle); @@ -1497,17 +1559,23 @@ function Flow() { const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event; const handleType = getConnectionHandleType(fromHandle.id); + const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id); + const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId); + const filterSpec = fromHandle.type === 'target' + ? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType) + : handleType; setContextMenu({ x: clientX, y: clientY, filterType: handleType, + filterSpec, filterDirection: fromHandle.type, pendingNodeId: fromHandle.nodeId, pendingHandleId: fromHandle.id, pendingHandleType: fromHandle.type, }); - }, []); + }, [reactFlow]); // ── Widget change callback ────────────────────────────────────────── @@ -1670,18 +1738,18 @@ function Flow() { // Auto-connect if this was triggered by dropping a connection on blank space if (contextMenu.pendingHandleId) { const filterType = contextMenu.filterType; + const filterSpec = contextMenu.filterSpec || filterType; if (contextMenu.pendingHandleType === 'source') { // Dragged from an output → connect to the first matching input on the new node const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) }; const inputName = Object.entries(allInputs).find(([, spec]) => { - const [type] = Array.isArray(spec) ? spec : [spec]; - return socketTypesCompatible(filterType, type); + return socketTypesCompatible(filterType, spec); })?.[0]; if (inputName) { const targetType = (() => { const spec = allInputs[inputName]; - const [type] = Array.isArray(spec) ? spec : [spec]; + const [type] = getSpecTypeAndOptions(spec); return type; })(); const targetHandle = `input::${inputName}::${targetType}`; @@ -1697,11 +1765,10 @@ function Flow() { } else { // Dragged from an input → connect from the first matching output on the new node const outputIdx = def.output.findIndex((type) => - socketTypesCompatible(type, filterType) - || (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE')) + outputTypeCanConnectToTarget(type, filterSpec) ); if (outputIdx !== -1) { - const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx]; + const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec); const sourceHandle = `output::${outputIdx}::${outputType}`; const color = TYPE_COLORS[outputType] || 'var(--fallback-type)'; setEdges((eds) => addEdge({ @@ -2848,6 +2915,7 @@ function Flow() { onCreateGroup={createGroupFromSelection} onClose={() => setContextMenu(null)} filterType={contextMenu.filterType} + filterSpec={contextMenu.filterSpec} filterDirection={contextMenu.filterDirection} selectedNodeCount={selectedNodeCount} /> diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 897373a..8e8ecf9 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -10,7 +10,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay')); const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay')); import { - DATA_TYPES, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS, + getSpecTypeAndOptions, isDataSocketSpec, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS, } from './constants'; import { getGroupMinimumSize } from './groupSizing.js'; import { buildCombinedInputNameByWidgetName, formatUiLabel } from './nodeWidgetLayout.js'; @@ -898,8 +898,8 @@ function CustomNode({ id, data }) { const hiddenWidgets = new Set(); for (const [name, spec] of Object.entries(required)) { - const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; - if (DATA_TYPES.has(type)) { + const [type, opts] = getSpecTypeAndOptions(spec); + if (isDataSocketSpec(spec)) { dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) }); visibleInputNames.add(name); } else if (opts?.hidden) { @@ -943,8 +943,8 @@ function CustomNode({ id, data }) { ); for (const [name, spec] of Object.entries(optional)) { - const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; - if (isProgressive && DATA_TYPES.has(type)) { + const [type, opts] = getSpecTypeAndOptions(spec); + if (isProgressive && isDataSocketSpec(spec)) { // Progressive: show this slot only if it's the first or the previous is connected const match = name.match(/^field_(\d+)$/); if (match) { @@ -958,7 +958,7 @@ function CustomNode({ id, data }) { } if (opts?.hidden) { hiddenWidgets.add(name); - } else if (DATA_TYPES.has(type)) { + } else if (isDataSocketSpec(spec)) { dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) }); visibleInputNames.add(name); } else { diff --git a/frontend/src/constants.js b/frontend/src/constants.js index 3ec208a..fc1c0ae 100644 --- a/frontend/src/constants.js +++ b/frontend/src/constants.js @@ -1,9 +1,9 @@ // ── Shared type & color constants ───────────────────────────────────── export const DATA_TYPES = new Set([ - 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', - 'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'ANNOTATION_SOURCE', 'COLORMAP', - 'SAVE_LAYER', 'SAVE_VALUE', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR', + 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', + 'COORD', 'ANNOTATION_SOURCE', 'COLORMAP', + 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR', ]); export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']); @@ -13,19 +13,13 @@ export const TYPE_COLORS = { IMAGE: '#00ff08a0', LINE: '#ffbe5c', MEASURE_TABLE: '#35e2fd', - RECORD_TABLE: '#fbbf24', - ANY_TABLE: '#67e8f9', + RECORD_TABLE: '#ff7474', COORD: '#e91ed1', - COORDPAIR: '#5c7cb8', + COORDPAIR: '#5cb861', FLOAT: '#ab3197', - INT: '#38bdf8', - STATS_SOURCE: '#c084fc', - CURSOR_SOURCE: '#a78bfa', - VALUE_SOURCE: '#60a5fa', + INT: '#ffffff', ANNOTATION_SOURCE: '#06b6d4', COLORMAP: '#f472b6', - SAVE_LAYER: '#22c55e', - SAVE_VALUE: '#4ade80', MESH_MODEL: '#14b8a6', FONT: '#fb7185', FILE_PATH: '#f59e0b', @@ -46,18 +40,60 @@ export const CAT_COLORS = { }; export const SOCKET_COMPATIBILITY = { - STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']), - CURSOR_SOURCE: new Set(['DATA_FIELD', 'LINE']), - ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']), - VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']), - ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']), - SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']), - SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'ANNOTATION_SOURCE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']), FLOAT: new Set(['INT']), INT: new Set(['FLOAT']), LINE: new Set(['COORDPAIR']), }; +const EMPTY_SOCKET_TYPE_SET = new Set(); + +export function getSpecTypeAndOptions(spec) { + if (Array.isArray(spec)) { + return [spec[0], spec[1] || {}]; + } + return [spec, {}]; +} + +export function isDataSocketType(type) { + return typeof type === 'string' && DATA_TYPES.has(type); +} + +export function isDataSocketSpec(spec) { + const [type] = getSpecTypeAndOptions(spec); + return isDataSocketType(type); +} + +export function getAcceptedSocketTypes(specOrType) { + const [type, opts] = Array.isArray(specOrType) + ? getSpecTypeAndOptions(specOrType) + : [specOrType, {}]; + if (typeof type !== 'string') { + return EMPTY_SOCKET_TYPE_SET; + } + + const accepted = new Set([type]); + const explicitAccepted = Array.isArray(opts?.accepted_types) ? opts.accepted_types : []; + for (const acceptedType of explicitAccepted) { + if (typeof acceptedType === 'string' && acceptedType) { + accepted.add(acceptedType); + } + } + + const fallbackAccepted = SOCKET_COMPATIBILITY[type]; + if (fallbackAccepted) { + for (const acceptedType of fallbackAccepted) { + accepted.add(acceptedType); + } + } + + return accepted; +} + +export function socketSpecAcceptsType(sourceType, targetSpecOrType) { + if (typeof sourceType !== 'string' || !sourceType) return false; + return getAcceptedSocketTypes(targetSpecOrType).has(sourceType); +} + // Colors used in Canvas 2D / toBlob contexts where CSS var() is unavailable. export const CANVAS_COLORS = { bgDeep: '#0f172a', diff --git a/frontend/src/executionGraph.js b/frontend/src/executionGraph.js index 996ea37..dd795d6 100644 --- a/frontend/src/executionGraph.js +++ b/frontend/src/executionGraph.js @@ -1,4 +1,4 @@ -import { DATA_TYPES } from './constants.js'; +import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js'; const OMITTED_WIDGET_INPUTS_BY_CLASS = { View3D: new Set([ @@ -91,8 +91,8 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f }; for (const [name, spec] of Object.entries(allWidgets)) { if (omittedInputs?.has(name)) continue; - const [type] = Array.isArray(spec) ? spec : [spec]; - if (DATA_TYPES.has(type)) continue; + const [type] = getSpecTypeAndOptions(spec); + if (isDataSocketSpec(spec)) continue; if (type === 'BUTTON') continue; if (valueBag[name] !== undefined) { inputs[name] = valueBag[name]; @@ -125,16 +125,16 @@ export function hasBlockingAutoRunInput(node, edges) { const required = def.input.required || {}; for (const [name, spec] of Object.entries(required)) { - const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; - const hiddenByConnectedInput = (() => { - const raw = opts?.hide_when_input_connected; - if (!raw) return false; - const inputs = Array.isArray(raw) ? raw : [raw]; - return inputs.some((inputName) => edges.some( - (edge) => { - const resolved = resolveExecutionEdge(edge); - return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName); - } + const [type, opts] = getSpecTypeAndOptions(spec); + const hiddenByConnectedInput = (() => { + const raw = opts?.hide_when_input_connected; + if (!raw) return false; + const inputs = Array.isArray(raw) ? raw : [raw]; + return inputs.some((inputName) => edges.some( + (edge) => { + const resolved = resolveExecutionEdge(edge); + return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName); + } )); })(); @@ -144,7 +144,7 @@ export function hasBlockingAutoRunInput(node, edges) { if (!node.data.widgetValues?.[name]) return true; continue; } - if (!DATA_TYPES.has(type)) continue; + if (!isDataSocketSpec(spec)) continue; const hasEdge = edges.some( (edge) => { const resolved = resolveExecutionEdge(edge); diff --git a/frontend/src/nodeWidgetDefaults.js b/frontend/src/nodeWidgetDefaults.js index c06458c..c198130 100644 --- a/frontend/src/nodeWidgetDefaults.js +++ b/frontend/src/nodeWidgetDefaults.js @@ -1,8 +1,8 @@ -import { DATA_TYPES } from './constants.js'; +import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js'; export function getDefaultWidgetValue(spec) { - const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; - if (DATA_TYPES.has(type)) return undefined; + const [type, opts] = getSpecTypeAndOptions(spec); + if (isDataSocketSpec(spec)) return undefined; if (type === 'BUTTON') return undefined; if (Array.isArray(type)) { if (typeof opts?.default === 'string' && type.includes(opts.default)) { diff --git a/frontend/tests/constants.test.mjs b/frontend/tests/constants.test.mjs index 987b9c6..ce4fb6b 100644 --- a/frontend/tests/constants.test.mjs +++ b/frontend/tests/constants.test.mjs @@ -1,8 +1,31 @@ import test from 'node:test'; import assert from 'node:assert/strict'; -import { SOCKET_COMPATIBILITY } from '../src/constants.js'; +import { + DATA_TYPES, + getAcceptedSocketTypes, + isDataSocketSpec, + socketSpecAcceptsType, +} from '../src/constants.js'; -test('SAVE_VALUE accepts ANNOTATION_SOURCE inputs', () => { - assert.equal(SOCKET_COMPATIBILITY.SAVE_VALUE.has('ANNOTATION_SOURCE'), true); +test('intrinsic socket compatibility still allows INT to connect to FLOAT sockets', () => { + assert.equal(socketSpecAcceptsType('INT', 'FLOAT'), true); + assert.equal(socketSpecAcceptsType('FLOAT', 'INT'), true); +}); + +test('retired save alias types are no longer first-class socket types', () => { + assert.equal(DATA_TYPES.has('SAVE_VALUE'), false); + assert.equal(DATA_TYPES.has('SAVE_LAYER'), false); +}); + +test('accepted_types extend canonical socket compatibility without reintroducing alias types', () => { + const spec = ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }]; + + assert.equal(isDataSocketSpec(spec), true); + assert.deepEqual( + Array.from(getAcceptedSocketTypes(spec)).sort(), + ['MEASURE_TABLE', 'RECORD_TABLE'], + ); + assert.equal(socketSpecAcceptsType('RECORD_TABLE', spec), true); + assert.equal(socketSpecAcceptsType('LINE', spec), false); }); diff --git a/frontend/tests/executionGraph.test.mjs b/frontend/tests/executionGraph.test.mjs index ae1061d..258bf9e 100644 --- a/frontend/tests/executionGraph.test.mjs +++ b/frontend/tests/executionGraph.test.mjs @@ -478,3 +478,89 @@ test('hasBlockingAutoRunInput skips required file widgets when a connected socke assert.equal(hasBlockingAutoRunInput(node, edges), false); }); + +test('serializeExecutionGraph treats accepted_types inputs as sockets, not widgets', () => { + const nodes = [ + { + id: '1', + data: { + className: 'TableSource', + definition: { + input: { required: {}, optional: {} }, + output: ['RECORD_TABLE'], + output_name: ['rows'], + manual_trigger: false, + }, + widgetValues: {}, + }, + }, + { + id: '2', + data: { + className: 'PrintTable', + definition: { + input: { + required: { + table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }], + }, + optional: {}, + }, + manual_trigger: false, + }, + widgetValues: { table: 'should-not-serialize' }, + }, + }, + ]; + const edges = [ + { + source: '1', + sourceHandle: 'output::0::RECORD_TABLE', + target: '2', + targetHandle: 'input::table::MEASURE_TABLE', + }, + ]; + + const prompt = serializeExecutionGraph(nodes, edges); + + assert.deepEqual(prompt, { + '1': { + class_type: 'TableSource', + inputs: {}, + }, + '2': { + class_type: 'PrintTable', + inputs: { table: ['1', 0] }, + }, + }); +}); + +test('hasBlockingAutoRunInput still blocks unconnected accepted_types sockets', () => { + const node = { + id: '2', + data: { + definition: { + manual_trigger: false, + input: { + required: { + input: ['DATA_FIELD', { accepted_types: ['IMAGE', 'LINE', 'RECORD_TABLE'] }], + }, + optional: {}, + }, + }, + widgetValues: {}, + }, + }; + + assert.equal(hasBlockingAutoRunInput(node, []), true); + assert.equal( + hasBlockingAutoRunInput(node, [ + { + source: '1', + sourceHandle: 'output::0::RECORD_TABLE', + target: '2', + targetHandle: 'input::input::DATA_FIELD', + }, + ]), + false, + ); +}); diff --git a/frontend/tests/nodeClipboard.test.mjs b/frontend/tests/nodeClipboard.test.mjs index ab961bc..6027f9b 100644 --- a/frontend/tests/nodeClipboard.test.mjs +++ b/frontend/tests/nodeClipboard.test.mjs @@ -58,7 +58,7 @@ test('buildNodeClipboardPayload keeps only selected nodes and internal edges', ( source: '2', sourceHandle: 'output::0::IMAGE', target: '3', - targetHandle: 'input::value::SAVE_VALUE', + targetHandle: 'input::value::DATA_FIELD', }, ]; @@ -166,7 +166,7 @@ test('buildNodeClipboardPayloadForIds can include upstream external edges for du source: '2', sourceHandle: 'output::0::IMAGE', target: '3', - targetHandle: 'input::value::SAVE_VALUE', + targetHandle: 'input::value::DATA_FIELD', }, ]; diff --git a/frontend/tests/nodeWidgetDefaults.test.mjs b/frontend/tests/nodeWidgetDefaults.test.mjs index 5a48117..8ae06ee 100644 --- a/frontend/tests/nodeWidgetDefaults.test.mjs +++ b/frontend/tests/nodeWidgetDefaults.test.mjs @@ -16,6 +16,7 @@ test('buildDefaultWidgetValues keeps non-data required widget defaults', () => { input: { required: { input: ['ANNOTATION_SOURCE', { label: 'Input' }], + table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }], shape: [['line', 'rectangle', 'circle', 'arrow'], { default: 'arrow' }], stroke_color: ['STRING', { default: '#ff0000', color_picker: true }], stroke_width: ['INT', { default: 3 }], diff --git a/tests/test_nodes.py b/tests/test_nodes.py index ca15e78..9366785 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -881,6 +881,7 @@ def test_angle_measure(): assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"} required_inputs = AngleMeasure.INPUT_TYPES()["required"] optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {}) + assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] assert required_inputs["color"][1]["default"] == "#ff9800" assert required_inputs["stroke_width"][1]["default"] == 1.35 assert optional_inputs["line_thickness"][1]["hidden"] is True @@ -1584,6 +1585,10 @@ def test_save_image(): from backend.nodes.save_image import SaveImage import tifffile node = SaveImage() + input_types = SaveImage.INPUT_TYPES() + field_spec = input_types["optional"]["field_0"] + assert field_spec[0] == "DATA_FIELD" + assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"] field_a = make_field(data=np.random.default_rng(4).random((32, 32))) field_b = make_field(data=np.random.default_rng(5).random((32, 32))) @@ -1729,6 +1734,9 @@ def test_preview_image(): from backend.data_types import ImageData from backend.execution_context import active_node, execution_callbacks node = PreviewImage() + preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"] + assert preview_input[0] == "ANNOTATION_SOURCE" + assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] # Set up a capture for the broadcast captured = [] @@ -1794,6 +1802,9 @@ def test_annotations(): node = Annotations() font_node = Font() + annotation_input = Annotations.INPUT_TYPES()["required"]["input"] + assert annotation_input[0] == "ANNOTATION_SOURCE" + assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] warnings = [] field = DataField( data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64), @@ -1920,6 +1931,7 @@ def test_markup(): assert _preview_markup_stroke_width(5, 128, 128) == 5 assert _preview_markup_stroke_width(5, 2048, 2048) > 5 + assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"] assert required_inputs["shape"][1]["default"] == "arrow" assert required_inputs["stroke_color"][1]["default"] == "#ff0000" @@ -1987,6 +1999,10 @@ def test_print_table(): from backend.nodes.print_table import PrintTable node = PrintTable() + table_spec = PrintTable.INPUT_TYPES()["required"]["table"] + assert table_spec[0] == "MEASURE_TABLE" + assert table_spec[1]["accepted_types"] == ["RECORD_TABLE"] + captured = [] PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows) PrintTable._current_node_id = "test" @@ -2005,6 +2021,10 @@ def test_value_display(): from backend.nodes.value_display import ValueDisplay node = ValueDisplay() + value_spec = ValueDisplay.INPUT_TYPES()["required"]["value"] + assert value_spec[0] == "FLOAT" + assert value_spec[1]["accepted_types"] == ["MEASURE_TABLE"] + captured = [] ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) ValueDisplay._current_node_id = "test" @@ -2599,6 +2619,9 @@ def test_line_cursors(): from backend.nodes.cursors import Cursors node = Cursors() + line_spec = Cursors.INPUT_TYPES()["required"]["line"] + assert line_spec[0] == "LINE" + assert line_spec[1]["accepted_types"] == ["DATA_FIELD"] # Create a simple linear ramp line = np.linspace(0, 10, 100).astype(np.float64) @@ -2814,6 +2837,10 @@ def test_stats(): from backend.nodes.stats import Stats node = Stats() + input_spec = Stats.INPUT_TYPES()["required"]["input"] + assert input_spec[0] == "DATA_FIELD" + assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "RECORD_TABLE"] + captured = [] Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) Stats._current_node_id = "test" @@ -2998,6 +3025,17 @@ def test_save_generic(): from PIL import Image as PILImage node = Save() + value_spec = node.INPUT_TYPES()["required"]["value"] + assert value_spec[0] == "DATA_FIELD" + assert value_spec[1]["accepted_types"] == [ + "IMAGE", + "ANNOTATION_SOURCE", + "LINE", + "MEASURE_TABLE", + "RECORD_TABLE", + "MESH_MODEL", + "FLOAT", + ] format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"] assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"]