refactor socket types

This commit is contained in:
2026-03-28 13:56:22 -07:00
parent 4368aeb4a0
commit 1b831cda5d
20 changed files with 366 additions and 79 deletions

View File

@@ -38,7 +38,11 @@ import {
import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js';
import {
DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS,
getSpecTypeAndOptions,
socketSpecAcceptsType,
TYPE_COLORS,
CAT_COLORS,
CANVAS_COLORS,
} from './constants';
const NODE_TYPES = { custom: CustomNode };
@@ -428,10 +432,54 @@ function compareMenuCategories(a, b) {
return String(a?.name || '').localeCompare(String(b?.name || ''));
}
function socketTypesCompatible(sourceType, targetType) {
if (sourceType === targetType) return true;
const accepted = SOCKET_COMPATIBILITY[targetType];
return !!accepted?.has(sourceType);
function getResolvedHandleRef(nodeId, handleId) {
const proxy = parseGroupProxyHandle(handleId);
return {
nodeId: proxy?.nodeId || nodeId,
handleId: proxy?.realHandle || handleId,
type: proxy?.type || getHandleType(handleId),
};
}
function getNodeInputSpecForHandle(node, handleId) {
const definition = node?.data?.definition;
if (!definition?.input) return null;
const inputName = getInputName(handleId);
return definition.input.required?.[inputName]
|| definition.input.optional?.[inputName]
|| null;
}
function socketTypesCompatible(sourceType, targetSpecOrType) {
return socketSpecAcceptsType(sourceType, targetSpecOrType);
}
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
if (socketTypesCompatible(outputType, targetSpecOrType)) {
return true;
}
return outputType === 'ANNOTATION_SOURCE'
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
&& (
socketTypesCompatible('DATA_FIELD', targetSpecOrType)
|| socketTypesCompatible('IMAGE', targetSpecOrType)
);
}
function resolveOutputTypeForTarget(outputType, targetSpecOrType) {
if (outputType !== 'ANNOTATION_SOURCE') {
return outputType;
}
if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) {
return 'ANNOTATION_SOURCE';
}
if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) {
return 'DATA_FIELD';
}
if (socketTypesCompatible('IMAGE', targetSpecOrType)) {
return 'IMAGE';
}
return 'ANNOTATION_SOURCE';
}
function getRenderedNodeBounds(nodes) {
@@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) {
// ── Context menu component ────────────────────────────────────────────
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) {
function ContextMenu({
x,
y,
nodeDefs,
onAdd,
onClose,
filterType,
filterSpec = null,
filterDirection,
selectedNodeCount = 0,
onCreateGroup = null,
}) {
const [openCat, setOpenCat] = useState(null);
const [search, setSearch] = useState('');
const menuRef = useRef(null);
@@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
const opt = def.input.optional || {};
const allInputs = { ...req, ...opt };
const hasMatch = Object.values(allInputs).some((spec) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return socketTypesCompatible(filterType, type);
return socketTypesCompatible(filterType, spec);
});
if (!hasMatch) continue;
} else {
const hasMatch = def.output.some((type) =>
socketTypesCompatible(type, filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
outputTypeCanConnectToTarget(type, filterSpec || filterType)
);
if (!hasMatch) continue;
}
@@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
items: [...category.items].sort(compareMenuNodes),
}))
.sort(compareMenuCategories);
}, [nodeDefs, filterType, filterDirection]);
}, [nodeDefs, filterDirection, filterSpec, filterType]);
// Flat filtered list for search
const searchResults = useMemo(() => {
@@ -1262,7 +1319,10 @@ function Flow() {
setEdges((prev) => prev.filter((edge) => {
if (edge.source !== nodeId) return true;
return socketTypesCompatible(outputType, getHandleType(edge.targetHandle));
const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(outputType, targetSpec);
}));
}, [reactFlow, setEdges, setNodeOutputs]);
@@ -1328,9 +1388,11 @@ function Flow() {
const isValidConnection = useCallback((connection) => {
const srcType = getConnectionHandleType(connection.sourceHandle);
const tgtType = getConnectionHandleType(connection.targetHandle);
return socketTypesCompatible(srcType, tgtType);
}, []);
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(srcType, targetSpec);
}, [reactFlow]);
const onConnect = useCallback((params) => {
const sourceProxy = parseGroupProxyHandle(params.sourceHandle);
@@ -1497,17 +1559,23 @@ function Flow() {
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
const handleType = getConnectionHandleType(fromHandle.id);
const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id);
const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId);
const filterSpec = fromHandle.type === 'target'
? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType)
: handleType;
setContextMenu({
x: clientX,
y: clientY,
filterType: handleType,
filterSpec,
filterDirection: fromHandle.type,
pendingNodeId: fromHandle.nodeId,
pendingHandleId: fromHandle.id,
pendingHandleType: fromHandle.type,
});
}, []);
}, [reactFlow]);
// ── Widget change callback ──────────────────────────────────────────
@@ -1670,18 +1738,18 @@ function Flow() {
// Auto-connect if this was triggered by dropping a connection on blank space
if (contextMenu.pendingHandleId) {
const filterType = contextMenu.filterType;
const filterSpec = contextMenu.filterSpec || filterType;
if (contextMenu.pendingHandleType === 'source') {
// Dragged from an output → connect to the first matching input on the new node
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
const inputName = Object.entries(allInputs).find(([, spec]) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return socketTypesCompatible(filterType, type);
return socketTypesCompatible(filterType, spec);
})?.[0];
if (inputName) {
const targetType = (() => {
const spec = allInputs[inputName];
const [type] = Array.isArray(spec) ? spec : [spec];
const [type] = getSpecTypeAndOptions(spec);
return type;
})();
const targetHandle = `input::${inputName}::${targetType}`;
@@ -1697,11 +1765,10 @@ function Flow() {
} else {
// Dragged from an input → connect from the first matching output on the new node
const outputIdx = def.output.findIndex((type) =>
socketTypesCompatible(type, filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
outputTypeCanConnectToTarget(type, filterSpec)
);
if (outputIdx !== -1) {
const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx];
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
const sourceHandle = `output::${outputIdx}::${outputType}`;
const color = TYPE_COLORS[outputType] || 'var(--fallback-type)';
setEdges((eds) => addEdge({
@@ -2848,6 +2915,7 @@ function Flow() {
onCreateGroup={createGroupFromSelection}
onClose={() => setContextMenu(null)}
filterType={contextMenu.filterType}
filterSpec={contextMenu.filterSpec}
filterDirection={contextMenu.filterDirection}
selectedNodeCount={selectedNodeCount}
/>