refactor socket types
This commit is contained in:
@@ -38,7 +38,11 @@ import {
|
||||
import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js';
|
||||
|
||||
import {
|
||||
DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS,
|
||||
getSpecTypeAndOptions,
|
||||
socketSpecAcceptsType,
|
||||
TYPE_COLORS,
|
||||
CAT_COLORS,
|
||||
CANVAS_COLORS,
|
||||
} from './constants';
|
||||
|
||||
const NODE_TYPES = { custom: CustomNode };
|
||||
@@ -428,10 +432,54 @@ function compareMenuCategories(a, b) {
|
||||
return String(a?.name || '').localeCompare(String(b?.name || ''));
|
||||
}
|
||||
|
||||
function socketTypesCompatible(sourceType, targetType) {
|
||||
if (sourceType === targetType) return true;
|
||||
const accepted = SOCKET_COMPATIBILITY[targetType];
|
||||
return !!accepted?.has(sourceType);
|
||||
function getResolvedHandleRef(nodeId, handleId) {
|
||||
const proxy = parseGroupProxyHandle(handleId);
|
||||
return {
|
||||
nodeId: proxy?.nodeId || nodeId,
|
||||
handleId: proxy?.realHandle || handleId,
|
||||
type: proxy?.type || getHandleType(handleId),
|
||||
};
|
||||
}
|
||||
|
||||
function getNodeInputSpecForHandle(node, handleId) {
|
||||
const definition = node?.data?.definition;
|
||||
if (!definition?.input) return null;
|
||||
const inputName = getInputName(handleId);
|
||||
return definition.input.required?.[inputName]
|
||||
|| definition.input.optional?.[inputName]
|
||||
|| null;
|
||||
}
|
||||
|
||||
function socketTypesCompatible(sourceType, targetSpecOrType) {
|
||||
return socketSpecAcceptsType(sourceType, targetSpecOrType);
|
||||
}
|
||||
|
||||
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
|
||||
if (socketTypesCompatible(outputType, targetSpecOrType)) {
|
||||
return true;
|
||||
}
|
||||
return outputType === 'ANNOTATION_SOURCE'
|
||||
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
|
||||
&& (
|
||||
socketTypesCompatible('DATA_FIELD', targetSpecOrType)
|
||||
|| socketTypesCompatible('IMAGE', targetSpecOrType)
|
||||
);
|
||||
}
|
||||
|
||||
function resolveOutputTypeForTarget(outputType, targetSpecOrType) {
|
||||
if (outputType !== 'ANNOTATION_SOURCE') {
|
||||
return outputType;
|
||||
}
|
||||
if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) {
|
||||
return 'ANNOTATION_SOURCE';
|
||||
}
|
||||
if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) {
|
||||
return 'DATA_FIELD';
|
||||
}
|
||||
if (socketTypesCompatible('IMAGE', targetSpecOrType)) {
|
||||
return 'IMAGE';
|
||||
}
|
||||
return 'ANNOTATION_SOURCE';
|
||||
}
|
||||
|
||||
function getRenderedNodeBounds(nodes) {
|
||||
@@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) {
|
||||
|
||||
// ── Context menu component ────────────────────────────────────────────
|
||||
|
||||
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) {
|
||||
function ContextMenu({
|
||||
x,
|
||||
y,
|
||||
nodeDefs,
|
||||
onAdd,
|
||||
onClose,
|
||||
filterType,
|
||||
filterSpec = null,
|
||||
filterDirection,
|
||||
selectedNodeCount = 0,
|
||||
onCreateGroup = null,
|
||||
}) {
|
||||
const [openCat, setOpenCat] = useState(null);
|
||||
const [search, setSearch] = useState('');
|
||||
const menuRef = useRef(null);
|
||||
@@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
|
||||
const opt = def.input.optional || {};
|
||||
const allInputs = { ...req, ...opt };
|
||||
const hasMatch = Object.values(allInputs).some((spec) => {
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
return socketTypesCompatible(filterType, type);
|
||||
return socketTypesCompatible(filterType, spec);
|
||||
});
|
||||
if (!hasMatch) continue;
|
||||
} else {
|
||||
const hasMatch = def.output.some((type) =>
|
||||
socketTypesCompatible(type, filterType)
|
||||
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
|
||||
outputTypeCanConnectToTarget(type, filterSpec || filterType)
|
||||
);
|
||||
if (!hasMatch) continue;
|
||||
}
|
||||
@@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
|
||||
items: [...category.items].sort(compareMenuNodes),
|
||||
}))
|
||||
.sort(compareMenuCategories);
|
||||
}, [nodeDefs, filterType, filterDirection]);
|
||||
}, [nodeDefs, filterDirection, filterSpec, filterType]);
|
||||
|
||||
// Flat filtered list for search
|
||||
const searchResults = useMemo(() => {
|
||||
@@ -1262,7 +1319,10 @@ function Flow() {
|
||||
|
||||
setEdges((prev) => prev.filter((edge) => {
|
||||
if (edge.source !== nodeId) return true;
|
||||
return socketTypesCompatible(outputType, getHandleType(edge.targetHandle));
|
||||
const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle);
|
||||
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
|
||||
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
|
||||
return socketTypesCompatible(outputType, targetSpec);
|
||||
}));
|
||||
}, [reactFlow, setEdges, setNodeOutputs]);
|
||||
|
||||
@@ -1328,9 +1388,11 @@ function Flow() {
|
||||
|
||||
const isValidConnection = useCallback((connection) => {
|
||||
const srcType = getConnectionHandleType(connection.sourceHandle);
|
||||
const tgtType = getConnectionHandleType(connection.targetHandle);
|
||||
return socketTypesCompatible(srcType, tgtType);
|
||||
}, []);
|
||||
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
|
||||
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
|
||||
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
|
||||
return socketTypesCompatible(srcType, targetSpec);
|
||||
}, [reactFlow]);
|
||||
|
||||
const onConnect = useCallback((params) => {
|
||||
const sourceProxy = parseGroupProxyHandle(params.sourceHandle);
|
||||
@@ -1497,17 +1559,23 @@ function Flow() {
|
||||
|
||||
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
|
||||
const handleType = getConnectionHandleType(fromHandle.id);
|
||||
const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id);
|
||||
const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId);
|
||||
const filterSpec = fromHandle.type === 'target'
|
||||
? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType)
|
||||
: handleType;
|
||||
|
||||
setContextMenu({
|
||||
x: clientX,
|
||||
y: clientY,
|
||||
filterType: handleType,
|
||||
filterSpec,
|
||||
filterDirection: fromHandle.type,
|
||||
pendingNodeId: fromHandle.nodeId,
|
||||
pendingHandleId: fromHandle.id,
|
||||
pendingHandleType: fromHandle.type,
|
||||
});
|
||||
}, []);
|
||||
}, [reactFlow]);
|
||||
|
||||
// ── Widget change callback ──────────────────────────────────────────
|
||||
|
||||
@@ -1670,18 +1738,18 @@ function Flow() {
|
||||
// Auto-connect if this was triggered by dropping a connection on blank space
|
||||
if (contextMenu.pendingHandleId) {
|
||||
const filterType = contextMenu.filterType;
|
||||
const filterSpec = contextMenu.filterSpec || filterType;
|
||||
|
||||
if (contextMenu.pendingHandleType === 'source') {
|
||||
// Dragged from an output → connect to the first matching input on the new node
|
||||
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
|
||||
const inputName = Object.entries(allInputs).find(([, spec]) => {
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
return socketTypesCompatible(filterType, type);
|
||||
return socketTypesCompatible(filterType, spec);
|
||||
})?.[0];
|
||||
if (inputName) {
|
||||
const targetType = (() => {
|
||||
const spec = allInputs[inputName];
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
const [type] = getSpecTypeAndOptions(spec);
|
||||
return type;
|
||||
})();
|
||||
const targetHandle = `input::${inputName}::${targetType}`;
|
||||
@@ -1697,11 +1765,10 @@ function Flow() {
|
||||
} else {
|
||||
// Dragged from an input → connect from the first matching output on the new node
|
||||
const outputIdx = def.output.findIndex((type) =>
|
||||
socketTypesCompatible(type, filterType)
|
||||
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
|
||||
outputTypeCanConnectToTarget(type, filterSpec)
|
||||
);
|
||||
if (outputIdx !== -1) {
|
||||
const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx];
|
||||
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
|
||||
const sourceHandle = `output::${outputIdx}::${outputType}`;
|
||||
const color = TYPE_COLORS[outputType] || 'var(--fallback-type)';
|
||||
setEdges((eds) => addEdge({
|
||||
@@ -2848,6 +2915,7 @@ function Flow() {
|
||||
onCreateGroup={createGroupFromSelection}
|
||||
onClose={() => setContextMenu(null)}
|
||||
filterType={contextMenu.filterType}
|
||||
filterSpec={contextMenu.filterSpec}
|
||||
filterDirection={contextMenu.filterDirection}
|
||||
selectedNodeCount={selectedNodeCount}
|
||||
/>
|
||||
|
||||
Reference in New Issue
Block a user