refactor socket types

This commit is contained in:
2026-03-28 13:56:22 -07:00
parent 4368aeb4a0
commit 1b831cda5d
20 changed files with 366 additions and 79 deletions

View File

@@ -85,7 +85,10 @@ class AngleMeasure:
def INPUT_TYPES(cls):
return {
"required": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
"input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"color": ("STRING", {"default": ANGLE_DEFAULT_COLOR, "color_picker": True}),
"stroke_width": ("FLOAT", {
"default": 1.35,

View File

@@ -23,7 +23,10 @@ class Annotations:
def INPUT_TYPES(cls):
return {
"required": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
"input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}),

View File

@@ -13,7 +13,10 @@ class Cursors:
def INPUT_TYPES(cls):
return {
"required": {
"line": ("CURSOR_SOURCE", {"label": "input"}),
"line": ("LINE", {
"label": "input",
"accepted_types": ["DATA_FIELD"],
}),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),

View File

@@ -21,7 +21,10 @@ class Markup:
def INPUT_TYPES(cls):
return {
"required": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
"input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "arrow"}),
"stroke_color": ("STRING", {"default": "#ff0000", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),

View File

@@ -22,7 +22,10 @@ class PreviewImage:
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
},
"optional": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}),
"input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"colormap_map": ("COLORMAP", {"label": "colormap"}),
}
}

View File

@@ -9,7 +9,9 @@ class PrintTable:
def INPUT_TYPES(cls):
return {
"required": {
"table": ("ANY_TABLE",),
"table": ("MEASURE_TABLE", {
"accepted_types": ["RECORD_TABLE"],
}),
}
}

View File

@@ -29,7 +29,18 @@ class Save:
"hide_when_input_connected": "directory",
"top_socket_input": "directory",
}),
"value": ("SAVE_VALUE", {"label": "value"}),
"value": ("DATA_FIELD", {
"label": "value",
"accepted_types": [
"IMAGE",
"ANNOTATION_SOURCE",
"LINE",
"MEASURE_TABLE",
"RECORD_TABLE",
"MESH_MODEL",
"FLOAT",
],
}),
"format": ("STRING", {
"default": "TIFF",
"choices_by_source_type": {

View File

@@ -17,7 +17,10 @@ class SaveImage:
"directory": ("DIRECTORY", {"label": "directory"}),
}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"})
optional[f"field_{i}"] = ("DATA_FIELD", {
"label": f"layer {i + 1}",
"accepted_types": ["IMAGE", "ANNOTATION_SOURCE"],
})
optional[f"layer_name_{i}"] = ("STRING", {
"default": "",
"placeholder": "name",

View File

@@ -26,7 +26,9 @@ class Stats:
def INPUT_TYPES(cls):
return {
"required": {
"input": ("STATS_SOURCE",),
"input": ("DATA_FIELD", {
"accepted_types": ["IMAGE", "LINE", "RECORD_TABLE"],
}),
"column": ("STRING", {
"default": "value",
"choices_from_table_input": "input",

View File

@@ -11,7 +11,9 @@ class ValueDisplay:
def INPUT_TYPES(cls):
return {
"required": {
"value": ("VALUE_SOURCE",),
"value": ("FLOAT", {
"accepted_types": ["MEASURE_TABLE"],
}),
"measurement": ("STRING", {
"default": "",
"choices_from_measure_input": "value",

View File

@@ -38,7 +38,11 @@ import {
import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js';
import {
DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS,
getSpecTypeAndOptions,
socketSpecAcceptsType,
TYPE_COLORS,
CAT_COLORS,
CANVAS_COLORS,
} from './constants';
const NODE_TYPES = { custom: CustomNode };
@@ -428,10 +432,54 @@ function compareMenuCategories(a, b) {
return String(a?.name || '').localeCompare(String(b?.name || ''));
}
function socketTypesCompatible(sourceType, targetType) {
if (sourceType === targetType) return true;
const accepted = SOCKET_COMPATIBILITY[targetType];
return !!accepted?.has(sourceType);
function getResolvedHandleRef(nodeId, handleId) {
const proxy = parseGroupProxyHandle(handleId);
return {
nodeId: proxy?.nodeId || nodeId,
handleId: proxy?.realHandle || handleId,
type: proxy?.type || getHandleType(handleId),
};
}
function getNodeInputSpecForHandle(node, handleId) {
const definition = node?.data?.definition;
if (!definition?.input) return null;
const inputName = getInputName(handleId);
return definition.input.required?.[inputName]
|| definition.input.optional?.[inputName]
|| null;
}
function socketTypesCompatible(sourceType, targetSpecOrType) {
return socketSpecAcceptsType(sourceType, targetSpecOrType);
}
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
if (socketTypesCompatible(outputType, targetSpecOrType)) {
return true;
}
return outputType === 'ANNOTATION_SOURCE'
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
&& (
socketTypesCompatible('DATA_FIELD', targetSpecOrType)
|| socketTypesCompatible('IMAGE', targetSpecOrType)
);
}
function resolveOutputTypeForTarget(outputType, targetSpecOrType) {
if (outputType !== 'ANNOTATION_SOURCE') {
return outputType;
}
if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) {
return 'ANNOTATION_SOURCE';
}
if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) {
return 'DATA_FIELD';
}
if (socketTypesCompatible('IMAGE', targetSpecOrType)) {
return 'IMAGE';
}
return 'ANNOTATION_SOURCE';
}
function getRenderedNodeBounds(nodes) {
@@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) {
// ── Context menu component ────────────────────────────────────────────
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) {
function ContextMenu({
x,
y,
nodeDefs,
onAdd,
onClose,
filterType,
filterSpec = null,
filterDirection,
selectedNodeCount = 0,
onCreateGroup = null,
}) {
const [openCat, setOpenCat] = useState(null);
const [search, setSearch] = useState('');
const menuRef = useRef(null);
@@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
const opt = def.input.optional || {};
const allInputs = { ...req, ...opt };
const hasMatch = Object.values(allInputs).some((spec) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return socketTypesCompatible(filterType, type);
return socketTypesCompatible(filterType, spec);
});
if (!hasMatch) continue;
} else {
const hasMatch = def.output.some((type) =>
socketTypesCompatible(type, filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
outputTypeCanConnectToTarget(type, filterSpec || filterType)
);
if (!hasMatch) continue;
}
@@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
items: [...category.items].sort(compareMenuNodes),
}))
.sort(compareMenuCategories);
}, [nodeDefs, filterType, filterDirection]);
}, [nodeDefs, filterDirection, filterSpec, filterType]);
// Flat filtered list for search
const searchResults = useMemo(() => {
@@ -1262,7 +1319,10 @@ function Flow() {
setEdges((prev) => prev.filter((edge) => {
if (edge.source !== nodeId) return true;
return socketTypesCompatible(outputType, getHandleType(edge.targetHandle));
const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(outputType, targetSpec);
}));
}, [reactFlow, setEdges, setNodeOutputs]);
@@ -1328,9 +1388,11 @@ function Flow() {
const isValidConnection = useCallback((connection) => {
const srcType = getConnectionHandleType(connection.sourceHandle);
const tgtType = getConnectionHandleType(connection.targetHandle);
return socketTypesCompatible(srcType, tgtType);
}, []);
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(srcType, targetSpec);
}, [reactFlow]);
const onConnect = useCallback((params) => {
const sourceProxy = parseGroupProxyHandle(params.sourceHandle);
@@ -1497,17 +1559,23 @@ function Flow() {
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
const handleType = getConnectionHandleType(fromHandle.id);
const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id);
const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId);
const filterSpec = fromHandle.type === 'target'
? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType)
: handleType;
setContextMenu({
x: clientX,
y: clientY,
filterType: handleType,
filterSpec,
filterDirection: fromHandle.type,
pendingNodeId: fromHandle.nodeId,
pendingHandleId: fromHandle.id,
pendingHandleType: fromHandle.type,
});
}, []);
}, [reactFlow]);
// ── Widget change callback ──────────────────────────────────────────
@@ -1670,18 +1738,18 @@ function Flow() {
// Auto-connect if this was triggered by dropping a connection on blank space
if (contextMenu.pendingHandleId) {
const filterType = contextMenu.filterType;
const filterSpec = contextMenu.filterSpec || filterType;
if (contextMenu.pendingHandleType === 'source') {
// Dragged from an output → connect to the first matching input on the new node
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
const inputName = Object.entries(allInputs).find(([, spec]) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return socketTypesCompatible(filterType, type);
return socketTypesCompatible(filterType, spec);
})?.[0];
if (inputName) {
const targetType = (() => {
const spec = allInputs[inputName];
const [type] = Array.isArray(spec) ? spec : [spec];
const [type] = getSpecTypeAndOptions(spec);
return type;
})();
const targetHandle = `input::${inputName}::${targetType}`;
@@ -1697,11 +1765,10 @@ function Flow() {
} else {
// Dragged from an input → connect from the first matching output on the new node
const outputIdx = def.output.findIndex((type) =>
socketTypesCompatible(type, filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
outputTypeCanConnectToTarget(type, filterSpec)
);
if (outputIdx !== -1) {
const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx];
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
const sourceHandle = `output::${outputIdx}::${outputType}`;
const color = TYPE_COLORS[outputType] || 'var(--fallback-type)';
setEdges((eds) => addEdge({
@@ -2848,6 +2915,7 @@ function Flow() {
onCreateGroup={createGroupFromSelection}
onClose={() => setContextMenu(null)}
filterType={contextMenu.filterType}
filterSpec={contextMenu.filterSpec}
filterDirection={contextMenu.filterDirection}
selectedNodeCount={selectedNodeCount}
/>

View File

@@ -10,7 +10,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay'));
const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay'));
import {
DATA_TYPES, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
getSpecTypeAndOptions, isDataSocketSpec, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
} from './constants';
import { getGroupMinimumSize } from './groupSizing.js';
import { buildCombinedInputNameByWidgetName, formatUiLabel } from './nodeWidgetLayout.js';
@@ -898,8 +898,8 @@ function CustomNode({ id, data }) {
const hiddenWidgets = new Set();
for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) {
const [type, opts] = getSpecTypeAndOptions(spec);
if (isDataSocketSpec(spec)) {
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
visibleInputNames.add(name);
} else if (opts?.hidden) {
@@ -943,8 +943,8 @@ function CustomNode({ id, data }) {
);
for (const [name, spec] of Object.entries(optional)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (isProgressive && DATA_TYPES.has(type)) {
const [type, opts] = getSpecTypeAndOptions(spec);
if (isProgressive && isDataSocketSpec(spec)) {
// Progressive: show this slot only if it's the first or the previous is connected
const match = name.match(/^field_(\d+)$/);
if (match) {
@@ -958,7 +958,7 @@ function CustomNode({ id, data }) {
}
if (opts?.hidden) {
hiddenWidgets.add(name);
} else if (DATA_TYPES.has(type)) {
} else if (isDataSocketSpec(spec)) {
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
visibleInputNames.add(name);
} else {

View File

@@ -1,9 +1,9 @@
// ── Shared type & color constants ─────────────────────────────────────
export const DATA_TYPES = new Set([
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE',
'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'ANNOTATION_SOURCE', 'COLORMAP',
'SAVE_LAYER', 'SAVE_VALUE', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE',
'COORD', 'ANNOTATION_SOURCE', 'COLORMAP',
'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
]);
export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']);
@@ -13,19 +13,13 @@ export const TYPE_COLORS = {
IMAGE: '#00ff08a0',
LINE: '#ffbe5c',
MEASURE_TABLE: '#35e2fd',
RECORD_TABLE: '#fbbf24',
ANY_TABLE: '#67e8f9',
RECORD_TABLE: '#ff7474',
COORD: '#e91ed1',
COORDPAIR: '#5c7cb8',
COORDPAIR: '#5cb861',
FLOAT: '#ab3197',
INT: '#38bdf8',
STATS_SOURCE: '#c084fc',
CURSOR_SOURCE: '#a78bfa',
VALUE_SOURCE: '#60a5fa',
INT: '#ffffff',
ANNOTATION_SOURCE: '#06b6d4',
COLORMAP: '#f472b6',
SAVE_LAYER: '#22c55e',
SAVE_VALUE: '#4ade80',
MESH_MODEL: '#14b8a6',
FONT: '#fb7185',
FILE_PATH: '#f59e0b',
@@ -46,18 +40,60 @@ export const CAT_COLORS = {
};
export const SOCKET_COMPATIBILITY = {
STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']),
CURSOR_SOURCE: new Set(['DATA_FIELD', 'LINE']),
ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']),
VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']),
ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']),
SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']),
SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'ANNOTATION_SOURCE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']),
FLOAT: new Set(['INT']),
INT: new Set(['FLOAT']),
LINE: new Set(['COORDPAIR']),
};
const EMPTY_SOCKET_TYPE_SET = new Set();
export function getSpecTypeAndOptions(spec) {
if (Array.isArray(spec)) {
return [spec[0], spec[1] || {}];
}
return [spec, {}];
}
export function isDataSocketType(type) {
return typeof type === 'string' && DATA_TYPES.has(type);
}
export function isDataSocketSpec(spec) {
const [type] = getSpecTypeAndOptions(spec);
return isDataSocketType(type);
}
export function getAcceptedSocketTypes(specOrType) {
const [type, opts] = Array.isArray(specOrType)
? getSpecTypeAndOptions(specOrType)
: [specOrType, {}];
if (typeof type !== 'string') {
return EMPTY_SOCKET_TYPE_SET;
}
const accepted = new Set([type]);
const explicitAccepted = Array.isArray(opts?.accepted_types) ? opts.accepted_types : [];
for (const acceptedType of explicitAccepted) {
if (typeof acceptedType === 'string' && acceptedType) {
accepted.add(acceptedType);
}
}
const fallbackAccepted = SOCKET_COMPATIBILITY[type];
if (fallbackAccepted) {
for (const acceptedType of fallbackAccepted) {
accepted.add(acceptedType);
}
}
return accepted;
}
export function socketSpecAcceptsType(sourceType, targetSpecOrType) {
if (typeof sourceType !== 'string' || !sourceType) return false;
return getAcceptedSocketTypes(targetSpecOrType).has(sourceType);
}
// Colors used in Canvas 2D / toBlob contexts where CSS var() is unavailable.
export const CANVAS_COLORS = {
bgDeep: '#0f172a',

View File

@@ -1,4 +1,4 @@
import { DATA_TYPES } from './constants.js';
import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
const OMITTED_WIDGET_INPUTS_BY_CLASS = {
View3D: new Set([
@@ -91,8 +91,8 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f
};
for (const [name, spec] of Object.entries(allWidgets)) {
if (omittedInputs?.has(name)) continue;
const [type] = Array.isArray(spec) ? spec : [spec];
if (DATA_TYPES.has(type)) continue;
const [type] = getSpecTypeAndOptions(spec);
if (isDataSocketSpec(spec)) continue;
if (type === 'BUTTON') continue;
if (valueBag[name] !== undefined) {
inputs[name] = valueBag[name];
@@ -125,16 +125,16 @@ export function hasBlockingAutoRunInput(node, edges) {
const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
const hiddenByConnectedInput = (() => {
const raw = opts?.hide_when_input_connected;
if (!raw) return false;
const inputs = Array.isArray(raw) ? raw : [raw];
return inputs.some((inputName) => edges.some(
(edge) => {
const resolved = resolveExecutionEdge(edge);
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
}
const [type, opts] = getSpecTypeAndOptions(spec);
const hiddenByConnectedInput = (() => {
const raw = opts?.hide_when_input_connected;
if (!raw) return false;
const inputs = Array.isArray(raw) ? raw : [raw];
return inputs.some((inputName) => edges.some(
(edge) => {
const resolved = resolveExecutionEdge(edge);
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
}
));
})();
@@ -144,7 +144,7 @@ export function hasBlockingAutoRunInput(node, edges) {
if (!node.data.widgetValues?.[name]) return true;
continue;
}
if (!DATA_TYPES.has(type)) continue;
if (!isDataSocketSpec(spec)) continue;
const hasEdge = edges.some(
(edge) => {
const resolved = resolveExecutionEdge(edge);

View File

@@ -1,8 +1,8 @@
import { DATA_TYPES } from './constants.js';
import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
export function getDefaultWidgetValue(spec) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) return undefined;
const [type, opts] = getSpecTypeAndOptions(spec);
if (isDataSocketSpec(spec)) return undefined;
if (type === 'BUTTON') return undefined;
if (Array.isArray(type)) {
if (typeof opts?.default === 'string' && type.includes(opts.default)) {

View File

@@ -1,8 +1,31 @@
import test from 'node:test';
import assert from 'node:assert/strict';
import { SOCKET_COMPATIBILITY } from '../src/constants.js';
import {
DATA_TYPES,
getAcceptedSocketTypes,
isDataSocketSpec,
socketSpecAcceptsType,
} from '../src/constants.js';
test('SAVE_VALUE accepts ANNOTATION_SOURCE inputs', () => {
assert.equal(SOCKET_COMPATIBILITY.SAVE_VALUE.has('ANNOTATION_SOURCE'), true);
test('intrinsic socket compatibility still allows INT to connect to FLOAT sockets', () => {
assert.equal(socketSpecAcceptsType('INT', 'FLOAT'), true);
assert.equal(socketSpecAcceptsType('FLOAT', 'INT'), true);
});
test('retired save alias types are no longer first-class socket types', () => {
assert.equal(DATA_TYPES.has('SAVE_VALUE'), false);
assert.equal(DATA_TYPES.has('SAVE_LAYER'), false);
});
test('accepted_types extend canonical socket compatibility without reintroducing alias types', () => {
const spec = ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }];
assert.equal(isDataSocketSpec(spec), true);
assert.deepEqual(
Array.from(getAcceptedSocketTypes(spec)).sort(),
['MEASURE_TABLE', 'RECORD_TABLE'],
);
assert.equal(socketSpecAcceptsType('RECORD_TABLE', spec), true);
assert.equal(socketSpecAcceptsType('LINE', spec), false);
});

View File

@@ -478,3 +478,89 @@ test('hasBlockingAutoRunInput skips required file widgets when a connected socke
assert.equal(hasBlockingAutoRunInput(node, edges), false);
});
test('serializeExecutionGraph treats accepted_types inputs as sockets, not widgets', () => {
const nodes = [
{
id: '1',
data: {
className: 'TableSource',
definition: {
input: { required: {}, optional: {} },
output: ['RECORD_TABLE'],
output_name: ['rows'],
manual_trigger: false,
},
widgetValues: {},
},
},
{
id: '2',
data: {
className: 'PrintTable',
definition: {
input: {
required: {
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
},
optional: {},
},
manual_trigger: false,
},
widgetValues: { table: 'should-not-serialize' },
},
},
];
const edges = [
{
source: '1',
sourceHandle: 'output::0::RECORD_TABLE',
target: '2',
targetHandle: 'input::table::MEASURE_TABLE',
},
];
const prompt = serializeExecutionGraph(nodes, edges);
assert.deepEqual(prompt, {
'1': {
class_type: 'TableSource',
inputs: {},
},
'2': {
class_type: 'PrintTable',
inputs: { table: ['1', 0] },
},
});
});
test('hasBlockingAutoRunInput still blocks unconnected accepted_types sockets', () => {
const node = {
id: '2',
data: {
definition: {
manual_trigger: false,
input: {
required: {
input: ['DATA_FIELD', { accepted_types: ['IMAGE', 'LINE', 'RECORD_TABLE'] }],
},
optional: {},
},
},
widgetValues: {},
},
};
assert.equal(hasBlockingAutoRunInput(node, []), true);
assert.equal(
hasBlockingAutoRunInput(node, [
{
source: '1',
sourceHandle: 'output::0::RECORD_TABLE',
target: '2',
targetHandle: 'input::input::DATA_FIELD',
},
]),
false,
);
});

View File

@@ -58,7 +58,7 @@ test('buildNodeClipboardPayload keeps only selected nodes and internal edges', (
source: '2',
sourceHandle: 'output::0::IMAGE',
target: '3',
targetHandle: 'input::value::SAVE_VALUE',
targetHandle: 'input::value::DATA_FIELD',
},
];
@@ -166,7 +166,7 @@ test('buildNodeClipboardPayloadForIds can include upstream external edges for du
source: '2',
sourceHandle: 'output::0::IMAGE',
target: '3',
targetHandle: 'input::value::SAVE_VALUE',
targetHandle: 'input::value::DATA_FIELD',
},
];

View File

@@ -16,6 +16,7 @@ test('buildDefaultWidgetValues keeps non-data required widget defaults', () => {
input: {
required: {
input: ['ANNOTATION_SOURCE', { label: 'Input' }],
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
shape: [['line', 'rectangle', 'circle', 'arrow'], { default: 'arrow' }],
stroke_color: ['STRING', { default: '#ff0000', color_picker: true }],
stroke_width: ['INT', { default: 3 }],

View File

@@ -881,6 +881,7 @@ def test_angle_measure():
assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"}
required_inputs = AngleMeasure.INPUT_TYPES()["required"]
optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {})
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
assert required_inputs["color"][1]["default"] == "#ff9800"
assert required_inputs["stroke_width"][1]["default"] == 1.35
assert optional_inputs["line_thickness"][1]["hidden"] is True
@@ -1584,6 +1585,10 @@ def test_save_image():
from backend.nodes.save_image import SaveImage
import tifffile
node = SaveImage()
input_types = SaveImage.INPUT_TYPES()
field_spec = input_types["optional"]["field_0"]
assert field_spec[0] == "DATA_FIELD"
assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"]
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
@@ -1729,6 +1734,9 @@ def test_preview_image():
from backend.data_types import ImageData
from backend.execution_context import active_node, execution_callbacks
node = PreviewImage()
preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"]
assert preview_input[0] == "ANNOTATION_SOURCE"
assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
# Set up a capture for the broadcast
captured = []
@@ -1794,6 +1802,9 @@ def test_annotations():
node = Annotations()
font_node = Font()
annotation_input = Annotations.INPUT_TYPES()["required"]["input"]
assert annotation_input[0] == "ANNOTATION_SOURCE"
assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
warnings = []
field = DataField(
data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64),
@@ -1920,6 +1931,7 @@ def test_markup():
assert _preview_markup_stroke_width(5, 128, 128) == 5
assert _preview_markup_stroke_width(5, 2048, 2048) > 5
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
assert required_inputs["shape"][1]["default"] == "arrow"
assert required_inputs["stroke_color"][1]["default"] == "#ff0000"
@@ -1987,6 +1999,10 @@ def test_print_table():
from backend.nodes.print_table import PrintTable
node = PrintTable()
table_spec = PrintTable.INPUT_TYPES()["required"]["table"]
assert table_spec[0] == "MEASURE_TABLE"
assert table_spec[1]["accepted_types"] == ["RECORD_TABLE"]
captured = []
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
PrintTable._current_node_id = "test"
@@ -2005,6 +2021,10 @@ def test_value_display():
from backend.nodes.value_display import ValueDisplay
node = ValueDisplay()
value_spec = ValueDisplay.INPUT_TYPES()["required"]["value"]
assert value_spec[0] == "FLOAT"
assert value_spec[1]["accepted_types"] == ["MEASURE_TABLE"]
captured = []
ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
ValueDisplay._current_node_id = "test"
@@ -2599,6 +2619,9 @@ def test_line_cursors():
from backend.nodes.cursors import Cursors
node = Cursors()
line_spec = Cursors.INPUT_TYPES()["required"]["line"]
assert line_spec[0] == "LINE"
assert line_spec[1]["accepted_types"] == ["DATA_FIELD"]
# Create a simple linear ramp
line = np.linspace(0, 10, 100).astype(np.float64)
@@ -2814,6 +2837,10 @@ def test_stats():
from backend.nodes.stats import Stats
node = Stats()
input_spec = Stats.INPUT_TYPES()["required"]["input"]
assert input_spec[0] == "DATA_FIELD"
assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "RECORD_TABLE"]
captured = []
Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
Stats._current_node_id = "test"
@@ -2998,6 +3025,17 @@ def test_save_generic():
from PIL import Image as PILImage
node = Save()
value_spec = node.INPUT_TYPES()["required"]["value"]
assert value_spec[0] == "DATA_FIELD"
assert value_spec[1]["accepted_types"] == [
"IMAGE",
"ANNOTATION_SOURCE",
"LINE",
"MEASURE_TABLE",
"RECORD_TABLE",
"MESH_MODEL",
"FLOAT",
]
format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"]
assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"]