refactor socket types
This commit is contained in:
@@ -85,7 +85,10 @@ class AngleMeasure:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
|
||||
"input": ("ANNOTATION_SOURCE", {
|
||||
"label": "Input",
|
||||
"accepted_types": ["DATA_FIELD", "IMAGE"],
|
||||
}),
|
||||
"color": ("STRING", {"default": ANGLE_DEFAULT_COLOR, "color_picker": True}),
|
||||
"stroke_width": ("FLOAT", {
|
||||
"default": 1.35,
|
||||
|
||||
@@ -23,7 +23,10 @@ class Annotations:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
|
||||
"input": ("ANNOTATION_SOURCE", {
|
||||
"label": "Input",
|
||||
"accepted_types": ["DATA_FIELD", "IMAGE"],
|
||||
}),
|
||||
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
"show_scale_bar": ("BOOLEAN", {"default": True}),
|
||||
"show_color_map": ("BOOLEAN", {"default": True}),
|
||||
|
||||
@@ -13,7 +13,10 @@ class Cursors:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"line": ("CURSOR_SOURCE", {"label": "input"}),
|
||||
"line": ("LINE", {
|
||||
"label": "input",
|
||||
"accepted_types": ["DATA_FIELD"],
|
||||
}),
|
||||
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
|
||||
@@ -21,7 +21,10 @@ class Markup:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
|
||||
"input": ("ANNOTATION_SOURCE", {
|
||||
"label": "Input",
|
||||
"accepted_types": ["DATA_FIELD", "IMAGE"],
|
||||
}),
|
||||
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "arrow"}),
|
||||
"stroke_color": ("STRING", {"default": "#ff0000", "color_picker": True}),
|
||||
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),
|
||||
|
||||
@@ -22,7 +22,10 @@ class PreviewImage:
|
||||
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
|
||||
},
|
||||
"optional": {
|
||||
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
|
||||
"input": ("ANNOTATION_SOURCE", {
|
||||
"label": "Input",
|
||||
"accepted_types": ["DATA_FIELD", "IMAGE"],
|
||||
}),
|
||||
"colormap_map": ("COLORMAP", {"label": "colormap"}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,9 @@ class PrintTable:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"table": ("ANY_TABLE",),
|
||||
"table": ("MEASURE_TABLE", {
|
||||
"accepted_types": ["RECORD_TABLE"],
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,18 @@ class Save:
|
||||
"hide_when_input_connected": "directory",
|
||||
"top_socket_input": "directory",
|
||||
}),
|
||||
"value": ("SAVE_VALUE", {"label": "value"}),
|
||||
"value": ("DATA_FIELD", {
|
||||
"label": "value",
|
||||
"accepted_types": [
|
||||
"IMAGE",
|
||||
"ANNOTATION_SOURCE",
|
||||
"LINE",
|
||||
"MEASURE_TABLE",
|
||||
"RECORD_TABLE",
|
||||
"MESH_MODEL",
|
||||
"FLOAT",
|
||||
],
|
||||
}),
|
||||
"format": ("STRING", {
|
||||
"default": "TIFF",
|
||||
"choices_by_source_type": {
|
||||
|
||||
@@ -17,7 +17,10 @@ class SaveImage:
|
||||
"directory": ("DIRECTORY", {"label": "directory"}),
|
||||
}
|
||||
for i in range(_MAX_SAVE_FIELDS):
|
||||
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
|
||||
optional[f"field_{i}"] = ("DATA_FIELD", {
|
||||
"label": f"layer {i + 1}",
|
||||
"accepted_types": ["IMAGE", "ANNOTATION_SOURCE"],
|
||||
})
|
||||
optional[f"layer_name_{i}"] = ("STRING", {
|
||||
"default": "",
|
||||
"placeholder": "name",
|
||||
|
||||
@@ -26,7 +26,9 @@ class Stats:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input": ("STATS_SOURCE",),
|
||||
"input": ("DATA_FIELD", {
|
||||
"accepted_types": ["IMAGE", "LINE", "RECORD_TABLE"],
|
||||
}),
|
||||
"column": ("STRING", {
|
||||
"default": "value",
|
||||
"choices_from_table_input": "input",
|
||||
|
||||
@@ -11,7 +11,9 @@ class ValueDisplay:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("VALUE_SOURCE",),
|
||||
"value": ("FLOAT", {
|
||||
"accepted_types": ["MEASURE_TABLE"],
|
||||
}),
|
||||
"measurement": ("STRING", {
|
||||
"default": "",
|
||||
"choices_from_measure_input": "value",
|
||||
|
||||
@@ -38,7 +38,11 @@ import {
|
||||
import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js';
|
||||
|
||||
import {
|
||||
DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS,
|
||||
getSpecTypeAndOptions,
|
||||
socketSpecAcceptsType,
|
||||
TYPE_COLORS,
|
||||
CAT_COLORS,
|
||||
CANVAS_COLORS,
|
||||
} from './constants';
|
||||
|
||||
const NODE_TYPES = { custom: CustomNode };
|
||||
@@ -428,10 +432,54 @@ function compareMenuCategories(a, b) {
|
||||
return String(a?.name || '').localeCompare(String(b?.name || ''));
|
||||
}
|
||||
|
||||
function socketTypesCompatible(sourceType, targetType) {
|
||||
if (sourceType === targetType) return true;
|
||||
const accepted = SOCKET_COMPATIBILITY[targetType];
|
||||
return !!accepted?.has(sourceType);
|
||||
function getResolvedHandleRef(nodeId, handleId) {
|
||||
const proxy = parseGroupProxyHandle(handleId);
|
||||
return {
|
||||
nodeId: proxy?.nodeId || nodeId,
|
||||
handleId: proxy?.realHandle || handleId,
|
||||
type: proxy?.type || getHandleType(handleId),
|
||||
};
|
||||
}
|
||||
|
||||
function getNodeInputSpecForHandle(node, handleId) {
|
||||
const definition = node?.data?.definition;
|
||||
if (!definition?.input) return null;
|
||||
const inputName = getInputName(handleId);
|
||||
return definition.input.required?.[inputName]
|
||||
|| definition.input.optional?.[inputName]
|
||||
|| null;
|
||||
}
|
||||
|
||||
function socketTypesCompatible(sourceType, targetSpecOrType) {
|
||||
return socketSpecAcceptsType(sourceType, targetSpecOrType);
|
||||
}
|
||||
|
||||
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
|
||||
if (socketTypesCompatible(outputType, targetSpecOrType)) {
|
||||
return true;
|
||||
}
|
||||
return outputType === 'ANNOTATION_SOURCE'
|
||||
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
|
||||
&& (
|
||||
socketTypesCompatible('DATA_FIELD', targetSpecOrType)
|
||||
|| socketTypesCompatible('IMAGE', targetSpecOrType)
|
||||
);
|
||||
}
|
||||
|
||||
function resolveOutputTypeForTarget(outputType, targetSpecOrType) {
|
||||
if (outputType !== 'ANNOTATION_SOURCE') {
|
||||
return outputType;
|
||||
}
|
||||
if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) {
|
||||
return 'ANNOTATION_SOURCE';
|
||||
}
|
||||
if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) {
|
||||
return 'DATA_FIELD';
|
||||
}
|
||||
if (socketTypesCompatible('IMAGE', targetSpecOrType)) {
|
||||
return 'IMAGE';
|
||||
}
|
||||
return 'ANNOTATION_SOURCE';
|
||||
}
|
||||
|
||||
function getRenderedNodeBounds(nodes) {
|
||||
@@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) {
|
||||
|
||||
// ── Context menu component ────────────────────────────────────────────
|
||||
|
||||
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) {
|
||||
function ContextMenu({
|
||||
x,
|
||||
y,
|
||||
nodeDefs,
|
||||
onAdd,
|
||||
onClose,
|
||||
filterType,
|
||||
filterSpec = null,
|
||||
filterDirection,
|
||||
selectedNodeCount = 0,
|
||||
onCreateGroup = null,
|
||||
}) {
|
||||
const [openCat, setOpenCat] = useState(null);
|
||||
const [search, setSearch] = useState('');
|
||||
const menuRef = useRef(null);
|
||||
@@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
|
||||
const opt = def.input.optional || {};
|
||||
const allInputs = { ...req, ...opt };
|
||||
const hasMatch = Object.values(allInputs).some((spec) => {
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
return socketTypesCompatible(filterType, type);
|
||||
return socketTypesCompatible(filterType, spec);
|
||||
});
|
||||
if (!hasMatch) continue;
|
||||
} else {
|
||||
const hasMatch = def.output.some((type) =>
|
||||
socketTypesCompatible(type, filterType)
|
||||
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
|
||||
outputTypeCanConnectToTarget(type, filterSpec || filterType)
|
||||
);
|
||||
if (!hasMatch) continue;
|
||||
}
|
||||
@@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
|
||||
items: [...category.items].sort(compareMenuNodes),
|
||||
}))
|
||||
.sort(compareMenuCategories);
|
||||
}, [nodeDefs, filterType, filterDirection]);
|
||||
}, [nodeDefs, filterDirection, filterSpec, filterType]);
|
||||
|
||||
// Flat filtered list for search
|
||||
const searchResults = useMemo(() => {
|
||||
@@ -1262,7 +1319,10 @@ function Flow() {
|
||||
|
||||
setEdges((prev) => prev.filter((edge) => {
|
||||
if (edge.source !== nodeId) return true;
|
||||
return socketTypesCompatible(outputType, getHandleType(edge.targetHandle));
|
||||
const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle);
|
||||
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
|
||||
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
|
||||
return socketTypesCompatible(outputType, targetSpec);
|
||||
}));
|
||||
}, [reactFlow, setEdges, setNodeOutputs]);
|
||||
|
||||
@@ -1328,9 +1388,11 @@ function Flow() {
|
||||
|
||||
const isValidConnection = useCallback((connection) => {
|
||||
const srcType = getConnectionHandleType(connection.sourceHandle);
|
||||
const tgtType = getConnectionHandleType(connection.targetHandle);
|
||||
return socketTypesCompatible(srcType, tgtType);
|
||||
}, []);
|
||||
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
|
||||
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
|
||||
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
|
||||
return socketTypesCompatible(srcType, targetSpec);
|
||||
}, [reactFlow]);
|
||||
|
||||
const onConnect = useCallback((params) => {
|
||||
const sourceProxy = parseGroupProxyHandle(params.sourceHandle);
|
||||
@@ -1497,17 +1559,23 @@ function Flow() {
|
||||
|
||||
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
|
||||
const handleType = getConnectionHandleType(fromHandle.id);
|
||||
const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id);
|
||||
const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId);
|
||||
const filterSpec = fromHandle.type === 'target'
|
||||
? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType)
|
||||
: handleType;
|
||||
|
||||
setContextMenu({
|
||||
x: clientX,
|
||||
y: clientY,
|
||||
filterType: handleType,
|
||||
filterSpec,
|
||||
filterDirection: fromHandle.type,
|
||||
pendingNodeId: fromHandle.nodeId,
|
||||
pendingHandleId: fromHandle.id,
|
||||
pendingHandleType: fromHandle.type,
|
||||
});
|
||||
}, []);
|
||||
}, [reactFlow]);
|
||||
|
||||
// ── Widget change callback ──────────────────────────────────────────
|
||||
|
||||
@@ -1670,18 +1738,18 @@ function Flow() {
|
||||
// Auto-connect if this was triggered by dropping a connection on blank space
|
||||
if (contextMenu.pendingHandleId) {
|
||||
const filterType = contextMenu.filterType;
|
||||
const filterSpec = contextMenu.filterSpec || filterType;
|
||||
|
||||
if (contextMenu.pendingHandleType === 'source') {
|
||||
// Dragged from an output → connect to the first matching input on the new node
|
||||
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
|
||||
const inputName = Object.entries(allInputs).find(([, spec]) => {
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
return socketTypesCompatible(filterType, type);
|
||||
return socketTypesCompatible(filterType, spec);
|
||||
})?.[0];
|
||||
if (inputName) {
|
||||
const targetType = (() => {
|
||||
const spec = allInputs[inputName];
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
const [type] = getSpecTypeAndOptions(spec);
|
||||
return type;
|
||||
})();
|
||||
const targetHandle = `input::${inputName}::${targetType}`;
|
||||
@@ -1697,11 +1765,10 @@ function Flow() {
|
||||
} else {
|
||||
// Dragged from an input → connect from the first matching output on the new node
|
||||
const outputIdx = def.output.findIndex((type) =>
|
||||
socketTypesCompatible(type, filterType)
|
||||
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
|
||||
outputTypeCanConnectToTarget(type, filterSpec)
|
||||
);
|
||||
if (outputIdx !== -1) {
|
||||
const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx];
|
||||
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
|
||||
const sourceHandle = `output::${outputIdx}::${outputType}`;
|
||||
const color = TYPE_COLORS[outputType] || 'var(--fallback-type)';
|
||||
setEdges((eds) => addEdge({
|
||||
@@ -2848,6 +2915,7 @@ function Flow() {
|
||||
onCreateGroup={createGroupFromSelection}
|
||||
onClose={() => setContextMenu(null)}
|
||||
filterType={contextMenu.filterType}
|
||||
filterSpec={contextMenu.filterSpec}
|
||||
filterDirection={contextMenu.filterDirection}
|
||||
selectedNodeCount={selectedNodeCount}
|
||||
/>
|
||||
|
||||
@@ -10,7 +10,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay'));
|
||||
const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay'));
|
||||
|
||||
import {
|
||||
DATA_TYPES, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
|
||||
getSpecTypeAndOptions, isDataSocketSpec, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
|
||||
} from './constants';
|
||||
import { getGroupMinimumSize } from './groupSizing.js';
|
||||
import { buildCombinedInputNameByWidgetName, formatUiLabel } from './nodeWidgetLayout.js';
|
||||
@@ -898,8 +898,8 @@ function CustomNode({ id, data }) {
|
||||
const hiddenWidgets = new Set();
|
||||
|
||||
for (const [name, spec] of Object.entries(required)) {
|
||||
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
|
||||
if (DATA_TYPES.has(type)) {
|
||||
const [type, opts] = getSpecTypeAndOptions(spec);
|
||||
if (isDataSocketSpec(spec)) {
|
||||
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
|
||||
visibleInputNames.add(name);
|
||||
} else if (opts?.hidden) {
|
||||
@@ -943,8 +943,8 @@ function CustomNode({ id, data }) {
|
||||
);
|
||||
|
||||
for (const [name, spec] of Object.entries(optional)) {
|
||||
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
|
||||
if (isProgressive && DATA_TYPES.has(type)) {
|
||||
const [type, opts] = getSpecTypeAndOptions(spec);
|
||||
if (isProgressive && isDataSocketSpec(spec)) {
|
||||
// Progressive: show this slot only if it's the first or the previous is connected
|
||||
const match = name.match(/^field_(\d+)$/);
|
||||
if (match) {
|
||||
@@ -958,7 +958,7 @@ function CustomNode({ id, data }) {
|
||||
}
|
||||
if (opts?.hidden) {
|
||||
hiddenWidgets.add(name);
|
||||
} else if (DATA_TYPES.has(type)) {
|
||||
} else if (isDataSocketSpec(spec)) {
|
||||
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
|
||||
visibleInputNames.add(name);
|
||||
} else {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// ── Shared type & color constants ─────────────────────────────────────
|
||||
|
||||
export const DATA_TYPES = new Set([
|
||||
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE',
|
||||
'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'ANNOTATION_SOURCE', 'COLORMAP',
|
||||
'SAVE_LAYER', 'SAVE_VALUE', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
|
||||
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE',
|
||||
'COORD', 'ANNOTATION_SOURCE', 'COLORMAP',
|
||||
'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
|
||||
]);
|
||||
|
||||
export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']);
|
||||
@@ -13,19 +13,13 @@ export const TYPE_COLORS = {
|
||||
IMAGE: '#00ff08a0',
|
||||
LINE: '#ffbe5c',
|
||||
MEASURE_TABLE: '#35e2fd',
|
||||
RECORD_TABLE: '#fbbf24',
|
||||
ANY_TABLE: '#67e8f9',
|
||||
RECORD_TABLE: '#ff7474',
|
||||
COORD: '#e91ed1',
|
||||
COORDPAIR: '#5c7cb8',
|
||||
COORDPAIR: '#5cb861',
|
||||
FLOAT: '#ab3197',
|
||||
INT: '#38bdf8',
|
||||
STATS_SOURCE: '#c084fc',
|
||||
CURSOR_SOURCE: '#a78bfa',
|
||||
VALUE_SOURCE: '#60a5fa',
|
||||
INT: '#ffffff',
|
||||
ANNOTATION_SOURCE: '#06b6d4',
|
||||
COLORMAP: '#f472b6',
|
||||
SAVE_LAYER: '#22c55e',
|
||||
SAVE_VALUE: '#4ade80',
|
||||
MESH_MODEL: '#14b8a6',
|
||||
FONT: '#fb7185',
|
||||
FILE_PATH: '#f59e0b',
|
||||
@@ -46,18 +40,60 @@ export const CAT_COLORS = {
|
||||
};
|
||||
|
||||
export const SOCKET_COMPATIBILITY = {
|
||||
STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']),
|
||||
CURSOR_SOURCE: new Set(['DATA_FIELD', 'LINE']),
|
||||
ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']),
|
||||
VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']),
|
||||
ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']),
|
||||
SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']),
|
||||
SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'ANNOTATION_SOURCE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']),
|
||||
FLOAT: new Set(['INT']),
|
||||
INT: new Set(['FLOAT']),
|
||||
LINE: new Set(['COORDPAIR']),
|
||||
};
|
||||
|
||||
const EMPTY_SOCKET_TYPE_SET = new Set();
|
||||
|
||||
export function getSpecTypeAndOptions(spec) {
|
||||
if (Array.isArray(spec)) {
|
||||
return [spec[0], spec[1] || {}];
|
||||
}
|
||||
return [spec, {}];
|
||||
}
|
||||
|
||||
export function isDataSocketType(type) {
|
||||
return typeof type === 'string' && DATA_TYPES.has(type);
|
||||
}
|
||||
|
||||
export function isDataSocketSpec(spec) {
|
||||
const [type] = getSpecTypeAndOptions(spec);
|
||||
return isDataSocketType(type);
|
||||
}
|
||||
|
||||
export function getAcceptedSocketTypes(specOrType) {
|
||||
const [type, opts] = Array.isArray(specOrType)
|
||||
? getSpecTypeAndOptions(specOrType)
|
||||
: [specOrType, {}];
|
||||
if (typeof type !== 'string') {
|
||||
return EMPTY_SOCKET_TYPE_SET;
|
||||
}
|
||||
|
||||
const accepted = new Set([type]);
|
||||
const explicitAccepted = Array.isArray(opts?.accepted_types) ? opts.accepted_types : [];
|
||||
for (const acceptedType of explicitAccepted) {
|
||||
if (typeof acceptedType === 'string' && acceptedType) {
|
||||
accepted.add(acceptedType);
|
||||
}
|
||||
}
|
||||
|
||||
const fallbackAccepted = SOCKET_COMPATIBILITY[type];
|
||||
if (fallbackAccepted) {
|
||||
for (const acceptedType of fallbackAccepted) {
|
||||
accepted.add(acceptedType);
|
||||
}
|
||||
}
|
||||
|
||||
return accepted;
|
||||
}
|
||||
|
||||
export function socketSpecAcceptsType(sourceType, targetSpecOrType) {
|
||||
if (typeof sourceType !== 'string' || !sourceType) return false;
|
||||
return getAcceptedSocketTypes(targetSpecOrType).has(sourceType);
|
||||
}
|
||||
|
||||
// Colors used in Canvas 2D / toBlob contexts where CSS var() is unavailable.
|
||||
export const CANVAS_COLORS = {
|
||||
bgDeep: '#0f172a',
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { DATA_TYPES } from './constants.js';
|
||||
import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
|
||||
|
||||
const OMITTED_WIDGET_INPUTS_BY_CLASS = {
|
||||
View3D: new Set([
|
||||
@@ -91,8 +91,8 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f
|
||||
};
|
||||
for (const [name, spec] of Object.entries(allWidgets)) {
|
||||
if (omittedInputs?.has(name)) continue;
|
||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||
if (DATA_TYPES.has(type)) continue;
|
||||
const [type] = getSpecTypeAndOptions(spec);
|
||||
if (isDataSocketSpec(spec)) continue;
|
||||
if (type === 'BUTTON') continue;
|
||||
if (valueBag[name] !== undefined) {
|
||||
inputs[name] = valueBag[name];
|
||||
@@ -125,16 +125,16 @@ export function hasBlockingAutoRunInput(node, edges) {
|
||||
|
||||
const required = def.input.required || {};
|
||||
for (const [name, spec] of Object.entries(required)) {
|
||||
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
|
||||
const hiddenByConnectedInput = (() => {
|
||||
const raw = opts?.hide_when_input_connected;
|
||||
if (!raw) return false;
|
||||
const inputs = Array.isArray(raw) ? raw : [raw];
|
||||
return inputs.some((inputName) => edges.some(
|
||||
(edge) => {
|
||||
const resolved = resolveExecutionEdge(edge);
|
||||
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
|
||||
}
|
||||
const [type, opts] = getSpecTypeAndOptions(spec);
|
||||
const hiddenByConnectedInput = (() => {
|
||||
const raw = opts?.hide_when_input_connected;
|
||||
if (!raw) return false;
|
||||
const inputs = Array.isArray(raw) ? raw : [raw];
|
||||
return inputs.some((inputName) => edges.some(
|
||||
(edge) => {
|
||||
const resolved = resolveExecutionEdge(edge);
|
||||
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
|
||||
}
|
||||
));
|
||||
})();
|
||||
|
||||
@@ -144,7 +144,7 @@ export function hasBlockingAutoRunInput(node, edges) {
|
||||
if (!node.data.widgetValues?.[name]) return true;
|
||||
continue;
|
||||
}
|
||||
if (!DATA_TYPES.has(type)) continue;
|
||||
if (!isDataSocketSpec(spec)) continue;
|
||||
const hasEdge = edges.some(
|
||||
(edge) => {
|
||||
const resolved = resolveExecutionEdge(edge);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { DATA_TYPES } from './constants.js';
|
||||
import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
|
||||
|
||||
export function getDefaultWidgetValue(spec) {
|
||||
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
|
||||
if (DATA_TYPES.has(type)) return undefined;
|
||||
const [type, opts] = getSpecTypeAndOptions(spec);
|
||||
if (isDataSocketSpec(spec)) return undefined;
|
||||
if (type === 'BUTTON') return undefined;
|
||||
if (Array.isArray(type)) {
|
||||
if (typeof opts?.default === 'string' && type.includes(opts.default)) {
|
||||
|
||||
@@ -1,8 +1,31 @@
|
||||
import test from 'node:test';
|
||||
import assert from 'node:assert/strict';
|
||||
|
||||
import { SOCKET_COMPATIBILITY } from '../src/constants.js';
|
||||
import {
|
||||
DATA_TYPES,
|
||||
getAcceptedSocketTypes,
|
||||
isDataSocketSpec,
|
||||
socketSpecAcceptsType,
|
||||
} from '../src/constants.js';
|
||||
|
||||
test('SAVE_VALUE accepts ANNOTATION_SOURCE inputs', () => {
|
||||
assert.equal(SOCKET_COMPATIBILITY.SAVE_VALUE.has('ANNOTATION_SOURCE'), true);
|
||||
test('intrinsic socket compatibility still allows INT to connect to FLOAT sockets', () => {
|
||||
assert.equal(socketSpecAcceptsType('INT', 'FLOAT'), true);
|
||||
assert.equal(socketSpecAcceptsType('FLOAT', 'INT'), true);
|
||||
});
|
||||
|
||||
test('retired save alias types are no longer first-class socket types', () => {
|
||||
assert.equal(DATA_TYPES.has('SAVE_VALUE'), false);
|
||||
assert.equal(DATA_TYPES.has('SAVE_LAYER'), false);
|
||||
});
|
||||
|
||||
test('accepted_types extend canonical socket compatibility without reintroducing alias types', () => {
|
||||
const spec = ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }];
|
||||
|
||||
assert.equal(isDataSocketSpec(spec), true);
|
||||
assert.deepEqual(
|
||||
Array.from(getAcceptedSocketTypes(spec)).sort(),
|
||||
['MEASURE_TABLE', 'RECORD_TABLE'],
|
||||
);
|
||||
assert.equal(socketSpecAcceptsType('RECORD_TABLE', spec), true);
|
||||
assert.equal(socketSpecAcceptsType('LINE', spec), false);
|
||||
});
|
||||
|
||||
@@ -478,3 +478,89 @@ test('hasBlockingAutoRunInput skips required file widgets when a connected socke
|
||||
|
||||
assert.equal(hasBlockingAutoRunInput(node, edges), false);
|
||||
});
|
||||
|
||||
test('serializeExecutionGraph treats accepted_types inputs as sockets, not widgets', () => {
|
||||
const nodes = [
|
||||
{
|
||||
id: '1',
|
||||
data: {
|
||||
className: 'TableSource',
|
||||
definition: {
|
||||
input: { required: {}, optional: {} },
|
||||
output: ['RECORD_TABLE'],
|
||||
output_name: ['rows'],
|
||||
manual_trigger: false,
|
||||
},
|
||||
widgetValues: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
data: {
|
||||
className: 'PrintTable',
|
||||
definition: {
|
||||
input: {
|
||||
required: {
|
||||
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
|
||||
},
|
||||
optional: {},
|
||||
},
|
||||
manual_trigger: false,
|
||||
},
|
||||
widgetValues: { table: 'should-not-serialize' },
|
||||
},
|
||||
},
|
||||
];
|
||||
const edges = [
|
||||
{
|
||||
source: '1',
|
||||
sourceHandle: 'output::0::RECORD_TABLE',
|
||||
target: '2',
|
||||
targetHandle: 'input::table::MEASURE_TABLE',
|
||||
},
|
||||
];
|
||||
|
||||
const prompt = serializeExecutionGraph(nodes, edges);
|
||||
|
||||
assert.deepEqual(prompt, {
|
||||
'1': {
|
||||
class_type: 'TableSource',
|
||||
inputs: {},
|
||||
},
|
||||
'2': {
|
||||
class_type: 'PrintTable',
|
||||
inputs: { table: ['1', 0] },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('hasBlockingAutoRunInput still blocks unconnected accepted_types sockets', () => {
|
||||
const node = {
|
||||
id: '2',
|
||||
data: {
|
||||
definition: {
|
||||
manual_trigger: false,
|
||||
input: {
|
||||
required: {
|
||||
input: ['DATA_FIELD', { accepted_types: ['IMAGE', 'LINE', 'RECORD_TABLE'] }],
|
||||
},
|
||||
optional: {},
|
||||
},
|
||||
},
|
||||
widgetValues: {},
|
||||
},
|
||||
};
|
||||
|
||||
assert.equal(hasBlockingAutoRunInput(node, []), true);
|
||||
assert.equal(
|
||||
hasBlockingAutoRunInput(node, [
|
||||
{
|
||||
source: '1',
|
||||
sourceHandle: 'output::0::RECORD_TABLE',
|
||||
target: '2',
|
||||
targetHandle: 'input::input::DATA_FIELD',
|
||||
},
|
||||
]),
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -58,7 +58,7 @@ test('buildNodeClipboardPayload keeps only selected nodes and internal edges', (
|
||||
source: '2',
|
||||
sourceHandle: 'output::0::IMAGE',
|
||||
target: '3',
|
||||
targetHandle: 'input::value::SAVE_VALUE',
|
||||
targetHandle: 'input::value::DATA_FIELD',
|
||||
},
|
||||
];
|
||||
|
||||
@@ -166,7 +166,7 @@ test('buildNodeClipboardPayloadForIds can include upstream external edges for du
|
||||
source: '2',
|
||||
sourceHandle: 'output::0::IMAGE',
|
||||
target: '3',
|
||||
targetHandle: 'input::value::SAVE_VALUE',
|
||||
targetHandle: 'input::value::DATA_FIELD',
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ test('buildDefaultWidgetValues keeps non-data required widget defaults', () => {
|
||||
input: {
|
||||
required: {
|
||||
input: ['ANNOTATION_SOURCE', { label: 'Input' }],
|
||||
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
|
||||
shape: [['line', 'rectangle', 'circle', 'arrow'], { default: 'arrow' }],
|
||||
stroke_color: ['STRING', { default: '#ff0000', color_picker: true }],
|
||||
stroke_width: ['INT', { default: 3 }],
|
||||
|
||||
@@ -881,6 +881,7 @@ def test_angle_measure():
|
||||
assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"}
|
||||
required_inputs = AngleMeasure.INPUT_TYPES()["required"]
|
||||
optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {})
|
||||
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
|
||||
assert required_inputs["color"][1]["default"] == "#ff9800"
|
||||
assert required_inputs["stroke_width"][1]["default"] == 1.35
|
||||
assert optional_inputs["line_thickness"][1]["hidden"] is True
|
||||
@@ -1584,6 +1585,10 @@ def test_save_image():
|
||||
from backend.nodes.save_image import SaveImage
|
||||
import tifffile
|
||||
node = SaveImage()
|
||||
input_types = SaveImage.INPUT_TYPES()
|
||||
field_spec = input_types["optional"]["field_0"]
|
||||
assert field_spec[0] == "DATA_FIELD"
|
||||
assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"]
|
||||
|
||||
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
|
||||
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
|
||||
@@ -1729,6 +1734,9 @@ def test_preview_image():
|
||||
from backend.data_types import ImageData
|
||||
from backend.execution_context import active_node, execution_callbacks
|
||||
node = PreviewImage()
|
||||
preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"]
|
||||
assert preview_input[0] == "ANNOTATION_SOURCE"
|
||||
assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
|
||||
|
||||
# Set up a capture for the broadcast
|
||||
captured = []
|
||||
@@ -1794,6 +1802,9 @@ def test_annotations():
|
||||
|
||||
node = Annotations()
|
||||
font_node = Font()
|
||||
annotation_input = Annotations.INPUT_TYPES()["required"]["input"]
|
||||
assert annotation_input[0] == "ANNOTATION_SOURCE"
|
||||
assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
|
||||
warnings = []
|
||||
field = DataField(
|
||||
data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64),
|
||||
@@ -1920,6 +1931,7 @@ def test_markup():
|
||||
|
||||
assert _preview_markup_stroke_width(5, 128, 128) == 5
|
||||
assert _preview_markup_stroke_width(5, 2048, 2048) > 5
|
||||
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
|
||||
assert required_inputs["shape"][1]["default"] == "arrow"
|
||||
assert required_inputs["stroke_color"][1]["default"] == "#ff0000"
|
||||
|
||||
@@ -1987,6 +1999,10 @@ def test_print_table():
|
||||
from backend.nodes.print_table import PrintTable
|
||||
node = PrintTable()
|
||||
|
||||
table_spec = PrintTable.INPUT_TYPES()["required"]["table"]
|
||||
assert table_spec[0] == "MEASURE_TABLE"
|
||||
assert table_spec[1]["accepted_types"] == ["RECORD_TABLE"]
|
||||
|
||||
captured = []
|
||||
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
|
||||
PrintTable._current_node_id = "test"
|
||||
@@ -2005,6 +2021,10 @@ def test_value_display():
|
||||
from backend.nodes.value_display import ValueDisplay
|
||||
|
||||
node = ValueDisplay()
|
||||
value_spec = ValueDisplay.INPUT_TYPES()["required"]["value"]
|
||||
assert value_spec[0] == "FLOAT"
|
||||
assert value_spec[1]["accepted_types"] == ["MEASURE_TABLE"]
|
||||
|
||||
captured = []
|
||||
ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
|
||||
ValueDisplay._current_node_id = "test"
|
||||
@@ -2599,6 +2619,9 @@ def test_line_cursors():
|
||||
from backend.nodes.cursors import Cursors
|
||||
|
||||
node = Cursors()
|
||||
line_spec = Cursors.INPUT_TYPES()["required"]["line"]
|
||||
assert line_spec[0] == "LINE"
|
||||
assert line_spec[1]["accepted_types"] == ["DATA_FIELD"]
|
||||
|
||||
# Create a simple linear ramp
|
||||
line = np.linspace(0, 10, 100).astype(np.float64)
|
||||
@@ -2814,6 +2837,10 @@ def test_stats():
|
||||
from backend.nodes.stats import Stats
|
||||
|
||||
node = Stats()
|
||||
input_spec = Stats.INPUT_TYPES()["required"]["input"]
|
||||
assert input_spec[0] == "DATA_FIELD"
|
||||
assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "RECORD_TABLE"]
|
||||
|
||||
captured = []
|
||||
Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
|
||||
Stats._current_node_id = "test"
|
||||
@@ -2998,6 +3025,17 @@ def test_save_generic():
|
||||
from PIL import Image as PILImage
|
||||
|
||||
node = Save()
|
||||
value_spec = node.INPUT_TYPES()["required"]["value"]
|
||||
assert value_spec[0] == "DATA_FIELD"
|
||||
assert value_spec[1]["accepted_types"] == [
|
||||
"IMAGE",
|
||||
"ANNOTATION_SOURCE",
|
||||
"LINE",
|
||||
"MEASURE_TABLE",
|
||||
"RECORD_TABLE",
|
||||
"MESH_MODEL",
|
||||
"FLOAT",
|
||||
]
|
||||
format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"]
|
||||
assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user