refactor socket types

This commit is contained in:
2026-03-28 13:56:22 -07:00
parent 4368aeb4a0
commit 1b831cda5d
20 changed files with 366 additions and 79 deletions

View File

@@ -38,7 +38,11 @@ import {
import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js';
import {
DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS,
getSpecTypeAndOptions,
socketSpecAcceptsType,
TYPE_COLORS,
CAT_COLORS,
CANVAS_COLORS,
} from './constants';
const NODE_TYPES = { custom: CustomNode };
@@ -428,10 +432,54 @@ function compareMenuCategories(a, b) {
return String(a?.name || '').localeCompare(String(b?.name || ''));
}
function socketTypesCompatible(sourceType, targetType) {
if (sourceType === targetType) return true;
const accepted = SOCKET_COMPATIBILITY[targetType];
return !!accepted?.has(sourceType);
function getResolvedHandleRef(nodeId, handleId) {
const proxy = parseGroupProxyHandle(handleId);
return {
nodeId: proxy?.nodeId || nodeId,
handleId: proxy?.realHandle || handleId,
type: proxy?.type || getHandleType(handleId),
};
}
function getNodeInputSpecForHandle(node, handleId) {
const definition = node?.data?.definition;
if (!definition?.input) return null;
const inputName = getInputName(handleId);
return definition.input.required?.[inputName]
|| definition.input.optional?.[inputName]
|| null;
}
function socketTypesCompatible(sourceType, targetSpecOrType) {
return socketSpecAcceptsType(sourceType, targetSpecOrType);
}
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
if (socketTypesCompatible(outputType, targetSpecOrType)) {
return true;
}
return outputType === 'ANNOTATION_SOURCE'
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
&& (
socketTypesCompatible('DATA_FIELD', targetSpecOrType)
|| socketTypesCompatible('IMAGE', targetSpecOrType)
);
}
function resolveOutputTypeForTarget(outputType, targetSpecOrType) {
if (outputType !== 'ANNOTATION_SOURCE') {
return outputType;
}
if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) {
return 'ANNOTATION_SOURCE';
}
if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) {
return 'DATA_FIELD';
}
if (socketTypesCompatible('IMAGE', targetSpecOrType)) {
return 'IMAGE';
}
return 'ANNOTATION_SOURCE';
}
function getRenderedNodeBounds(nodes) {
@@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) {
// ── Context menu component ────────────────────────────────────────────
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) {
function ContextMenu({
x,
y,
nodeDefs,
onAdd,
onClose,
filterType,
filterSpec = null,
filterDirection,
selectedNodeCount = 0,
onCreateGroup = null,
}) {
const [openCat, setOpenCat] = useState(null);
const [search, setSearch] = useState('');
const menuRef = useRef(null);
@@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
const opt = def.input.optional || {};
const allInputs = { ...req, ...opt };
const hasMatch = Object.values(allInputs).some((spec) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return socketTypesCompatible(filterType, type);
return socketTypesCompatible(filterType, spec);
});
if (!hasMatch) continue;
} else {
const hasMatch = def.output.some((type) =>
socketTypesCompatible(type, filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
outputTypeCanConnectToTarget(type, filterSpec || filterType)
);
if (!hasMatch) continue;
}
@@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
items: [...category.items].sort(compareMenuNodes),
}))
.sort(compareMenuCategories);
}, [nodeDefs, filterType, filterDirection]);
}, [nodeDefs, filterDirection, filterSpec, filterType]);
// Flat filtered list for search
const searchResults = useMemo(() => {
@@ -1262,7 +1319,10 @@ function Flow() {
setEdges((prev) => prev.filter((edge) => {
if (edge.source !== nodeId) return true;
return socketTypesCompatible(outputType, getHandleType(edge.targetHandle));
const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(outputType, targetSpec);
}));
}, [reactFlow, setEdges, setNodeOutputs]);
@@ -1328,9 +1388,11 @@ function Flow() {
const isValidConnection = useCallback((connection) => {
const srcType = getConnectionHandleType(connection.sourceHandle);
const tgtType = getConnectionHandleType(connection.targetHandle);
return socketTypesCompatible(srcType, tgtType);
}, []);
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(srcType, targetSpec);
}, [reactFlow]);
const onConnect = useCallback((params) => {
const sourceProxy = parseGroupProxyHandle(params.sourceHandle);
@@ -1497,17 +1559,23 @@ function Flow() {
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
const handleType = getConnectionHandleType(fromHandle.id);
const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id);
const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId);
const filterSpec = fromHandle.type === 'target'
? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType)
: handleType;
setContextMenu({
x: clientX,
y: clientY,
filterType: handleType,
filterSpec,
filterDirection: fromHandle.type,
pendingNodeId: fromHandle.nodeId,
pendingHandleId: fromHandle.id,
pendingHandleType: fromHandle.type,
});
}, []);
}, [reactFlow]);
// ── Widget change callback ──────────────────────────────────────────
@@ -1670,18 +1738,18 @@ function Flow() {
// Auto-connect if this was triggered by dropping a connection on blank space
if (contextMenu.pendingHandleId) {
const filterType = contextMenu.filterType;
const filterSpec = contextMenu.filterSpec || filterType;
if (contextMenu.pendingHandleType === 'source') {
// Dragged from an output → connect to the first matching input on the new node
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
const inputName = Object.entries(allInputs).find(([, spec]) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return socketTypesCompatible(filterType, type);
return socketTypesCompatible(filterType, spec);
})?.[0];
if (inputName) {
const targetType = (() => {
const spec = allInputs[inputName];
const [type] = Array.isArray(spec) ? spec : [spec];
const [type] = getSpecTypeAndOptions(spec);
return type;
})();
const targetHandle = `input::${inputName}::${targetType}`;
@@ -1697,11 +1765,10 @@ function Flow() {
} else {
// Dragged from an input → connect from the first matching output on the new node
const outputIdx = def.output.findIndex((type) =>
socketTypesCompatible(type, filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
outputTypeCanConnectToTarget(type, filterSpec)
);
if (outputIdx !== -1) {
const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx];
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
const sourceHandle = `output::${outputIdx}::${outputType}`;
const color = TYPE_COLORS[outputType] || 'var(--fallback-type)';
setEdges((eds) => addEdge({
@@ -2848,6 +2915,7 @@ function Flow() {
onCreateGroup={createGroupFromSelection}
onClose={() => setContextMenu(null)}
filterType={contextMenu.filterType}
filterSpec={contextMenu.filterSpec}
filterDirection={contextMenu.filterDirection}
selectedNodeCount={selectedNodeCount}
/>

View File

@@ -10,7 +10,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay'));
const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay'));
import {
DATA_TYPES, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
getSpecTypeAndOptions, isDataSocketSpec, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
} from './constants';
import { getGroupMinimumSize } from './groupSizing.js';
import { buildCombinedInputNameByWidgetName, formatUiLabel } from './nodeWidgetLayout.js';
@@ -898,8 +898,8 @@ function CustomNode({ id, data }) {
const hiddenWidgets = new Set();
for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) {
const [type, opts] = getSpecTypeAndOptions(spec);
if (isDataSocketSpec(spec)) {
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
visibleInputNames.add(name);
} else if (opts?.hidden) {
@@ -943,8 +943,8 @@ function CustomNode({ id, data }) {
);
for (const [name, spec] of Object.entries(optional)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (isProgressive && DATA_TYPES.has(type)) {
const [type, opts] = getSpecTypeAndOptions(spec);
if (isProgressive && isDataSocketSpec(spec)) {
// Progressive: show this slot only if it's the first or the previous is connected
const match = name.match(/^field_(\d+)$/);
if (match) {
@@ -958,7 +958,7 @@ function CustomNode({ id, data }) {
}
if (opts?.hidden) {
hiddenWidgets.add(name);
} else if (DATA_TYPES.has(type)) {
} else if (isDataSocketSpec(spec)) {
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
visibleInputNames.add(name);
} else {

View File

@@ -1,9 +1,9 @@
// ── Shared type & color constants ─────────────────────────────────────
export const DATA_TYPES = new Set([
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE',
'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'ANNOTATION_SOURCE', 'COLORMAP',
'SAVE_LAYER', 'SAVE_VALUE', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE',
'COORD', 'ANNOTATION_SOURCE', 'COLORMAP',
'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
]);
export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']);
@@ -13,19 +13,13 @@ export const TYPE_COLORS = {
IMAGE: '#00ff08a0',
LINE: '#ffbe5c',
MEASURE_TABLE: '#35e2fd',
RECORD_TABLE: '#fbbf24',
ANY_TABLE: '#67e8f9',
RECORD_TABLE: '#ff7474',
COORD: '#e91ed1',
COORDPAIR: '#5c7cb8',
COORDPAIR: '#5cb861',
FLOAT: '#ab3197',
INT: '#38bdf8',
STATS_SOURCE: '#c084fc',
CURSOR_SOURCE: '#a78bfa',
VALUE_SOURCE: '#60a5fa',
INT: '#ffffff',
ANNOTATION_SOURCE: '#06b6d4',
COLORMAP: '#f472b6',
SAVE_LAYER: '#22c55e',
SAVE_VALUE: '#4ade80',
MESH_MODEL: '#14b8a6',
FONT: '#fb7185',
FILE_PATH: '#f59e0b',
@@ -46,18 +40,60 @@ export const CAT_COLORS = {
};
export const SOCKET_COMPATIBILITY = {
STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']),
CURSOR_SOURCE: new Set(['DATA_FIELD', 'LINE']),
ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']),
VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']),
ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']),
SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']),
SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'ANNOTATION_SOURCE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']),
FLOAT: new Set(['INT']),
INT: new Set(['FLOAT']),
LINE: new Set(['COORDPAIR']),
};
const EMPTY_SOCKET_TYPE_SET = new Set();
export function getSpecTypeAndOptions(spec) {
if (Array.isArray(spec)) {
return [spec[0], spec[1] || {}];
}
return [spec, {}];
}
export function isDataSocketType(type) {
return typeof type === 'string' && DATA_TYPES.has(type);
}
export function isDataSocketSpec(spec) {
const [type] = getSpecTypeAndOptions(spec);
return isDataSocketType(type);
}
export function getAcceptedSocketTypes(specOrType) {
const [type, opts] = Array.isArray(specOrType)
? getSpecTypeAndOptions(specOrType)
: [specOrType, {}];
if (typeof type !== 'string') {
return EMPTY_SOCKET_TYPE_SET;
}
const accepted = new Set([type]);
const explicitAccepted = Array.isArray(opts?.accepted_types) ? opts.accepted_types : [];
for (const acceptedType of explicitAccepted) {
if (typeof acceptedType === 'string' && acceptedType) {
accepted.add(acceptedType);
}
}
const fallbackAccepted = SOCKET_COMPATIBILITY[type];
if (fallbackAccepted) {
for (const acceptedType of fallbackAccepted) {
accepted.add(acceptedType);
}
}
return accepted;
}
export function socketSpecAcceptsType(sourceType, targetSpecOrType) {
if (typeof sourceType !== 'string' || !sourceType) return false;
return getAcceptedSocketTypes(targetSpecOrType).has(sourceType);
}
// Colors used in Canvas 2D / toBlob contexts where CSS var() is unavailable.
export const CANVAS_COLORS = {
bgDeep: '#0f172a',

View File

@@ -1,4 +1,4 @@
import { DATA_TYPES } from './constants.js';
import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
const OMITTED_WIDGET_INPUTS_BY_CLASS = {
View3D: new Set([
@@ -91,8 +91,8 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f
};
for (const [name, spec] of Object.entries(allWidgets)) {
if (omittedInputs?.has(name)) continue;
const [type] = Array.isArray(spec) ? spec : [spec];
if (DATA_TYPES.has(type)) continue;
const [type] = getSpecTypeAndOptions(spec);
if (isDataSocketSpec(spec)) continue;
if (type === 'BUTTON') continue;
if (valueBag[name] !== undefined) {
inputs[name] = valueBag[name];
@@ -125,16 +125,16 @@ export function hasBlockingAutoRunInput(node, edges) {
const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
const hiddenByConnectedInput = (() => {
const raw = opts?.hide_when_input_connected;
if (!raw) return false;
const inputs = Array.isArray(raw) ? raw : [raw];
return inputs.some((inputName) => edges.some(
(edge) => {
const resolved = resolveExecutionEdge(edge);
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
}
const [type, opts] = getSpecTypeAndOptions(spec);
const hiddenByConnectedInput = (() => {
const raw = opts?.hide_when_input_connected;
if (!raw) return false;
const inputs = Array.isArray(raw) ? raw : [raw];
return inputs.some((inputName) => edges.some(
(edge) => {
const resolved = resolveExecutionEdge(edge);
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
}
));
})();
@@ -144,7 +144,7 @@ export function hasBlockingAutoRunInput(node, edges) {
if (!node.data.widgetValues?.[name]) return true;
continue;
}
if (!DATA_TYPES.has(type)) continue;
if (!isDataSocketSpec(spec)) continue;
const hasEdge = edges.some(
(edge) => {
const resolved = resolveExecutionEdge(edge);

View File

@@ -1,8 +1,8 @@
import { DATA_TYPES } from './constants.js';
import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
export function getDefaultWidgetValue(spec) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) return undefined;
const [type, opts] = getSpecTypeAndOptions(spec);
if (isDataSocketSpec(spec)) return undefined;
if (type === 'BUTTON') return undefined;
if (Array.isArray(type)) {
if (typeof opts?.default === 'string' && type.includes(opts.default)) {

View File

@@ -1,8 +1,31 @@
import test from 'node:test';
import assert from 'node:assert/strict';
import { SOCKET_COMPATIBILITY } from '../src/constants.js';
import {
DATA_TYPES,
getAcceptedSocketTypes,
isDataSocketSpec,
socketSpecAcceptsType,
} from '../src/constants.js';
test('SAVE_VALUE accepts ANNOTATION_SOURCE inputs', () => {
assert.equal(SOCKET_COMPATIBILITY.SAVE_VALUE.has('ANNOTATION_SOURCE'), true);
test('intrinsic socket compatibility still allows INT to connect to FLOAT sockets', () => {
assert.equal(socketSpecAcceptsType('INT', 'FLOAT'), true);
assert.equal(socketSpecAcceptsType('FLOAT', 'INT'), true);
});
test('retired save alias types are no longer first-class socket types', () => {
assert.equal(DATA_TYPES.has('SAVE_VALUE'), false);
assert.equal(DATA_TYPES.has('SAVE_LAYER'), false);
});
test('accepted_types extend canonical socket compatibility without reintroducing alias types', () => {
const spec = ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }];
assert.equal(isDataSocketSpec(spec), true);
assert.deepEqual(
Array.from(getAcceptedSocketTypes(spec)).sort(),
['MEASURE_TABLE', 'RECORD_TABLE'],
);
assert.equal(socketSpecAcceptsType('RECORD_TABLE', spec), true);
assert.equal(socketSpecAcceptsType('LINE', spec), false);
});

View File

@@ -478,3 +478,89 @@ test('hasBlockingAutoRunInput skips required file widgets when a connected socke
assert.equal(hasBlockingAutoRunInput(node, edges), false);
});
test('serializeExecutionGraph treats accepted_types inputs as sockets, not widgets', () => {
const nodes = [
{
id: '1',
data: {
className: 'TableSource',
definition: {
input: { required: {}, optional: {} },
output: ['RECORD_TABLE'],
output_name: ['rows'],
manual_trigger: false,
},
widgetValues: {},
},
},
{
id: '2',
data: {
className: 'PrintTable',
definition: {
input: {
required: {
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
},
optional: {},
},
manual_trigger: false,
},
widgetValues: { table: 'should-not-serialize' },
},
},
];
const edges = [
{
source: '1',
sourceHandle: 'output::0::RECORD_TABLE',
target: '2',
targetHandle: 'input::table::MEASURE_TABLE',
},
];
const prompt = serializeExecutionGraph(nodes, edges);
assert.deepEqual(prompt, {
'1': {
class_type: 'TableSource',
inputs: {},
},
'2': {
class_type: 'PrintTable',
inputs: { table: ['1', 0] },
},
});
});
test('hasBlockingAutoRunInput still blocks unconnected accepted_types sockets', () => {
const node = {
id: '2',
data: {
definition: {
manual_trigger: false,
input: {
required: {
input: ['DATA_FIELD', { accepted_types: ['IMAGE', 'LINE', 'RECORD_TABLE'] }],
},
optional: {},
},
},
widgetValues: {},
},
};
assert.equal(hasBlockingAutoRunInput(node, []), true);
assert.equal(
hasBlockingAutoRunInput(node, [
{
source: '1',
sourceHandle: 'output::0::RECORD_TABLE',
target: '2',
targetHandle: 'input::input::DATA_FIELD',
},
]),
false,
);
});

View File

@@ -58,7 +58,7 @@ test('buildNodeClipboardPayload keeps only selected nodes and internal edges', (
source: '2',
sourceHandle: 'output::0::IMAGE',
target: '3',
targetHandle: 'input::value::SAVE_VALUE',
targetHandle: 'input::value::DATA_FIELD',
},
];
@@ -166,7 +166,7 @@ test('buildNodeClipboardPayloadForIds can include upstream external edges for du
source: '2',
sourceHandle: 'output::0::IMAGE',
target: '3',
targetHandle: 'input::value::SAVE_VALUE',
targetHandle: 'input::value::DATA_FIELD',
},
];

View File

@@ -16,6 +16,7 @@ test('buildDefaultWidgetValues keeps non-data required widget defaults', () => {
input: {
required: {
input: ['ANNOTATION_SOURCE', { label: 'Input' }],
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
shape: [['line', 'rectangle', 'circle', 'arrow'], { default: 'arrow' }],
stroke_color: ['STRING', { default: '#ff0000', color_picker: true }],
stroke_width: ['INT', { default: 3 }],