refactor socket types

This commit is contained in:
2026-03-28 13:56:22 -07:00
parent 4368aeb4a0
commit 1b831cda5d
20 changed files with 366 additions and 79 deletions

View File

@@ -85,7 +85,10 @@ class AngleMeasure:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}), "input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"color": ("STRING", {"default": ANGLE_DEFAULT_COLOR, "color_picker": True}), "color": ("STRING", {"default": ANGLE_DEFAULT_COLOR, "color_picker": True}),
"stroke_width": ("FLOAT", { "stroke_width": ("FLOAT", {
"default": 1.35, "default": 1.35,

View File

@@ -23,7 +23,10 @@ class Annotations:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}), "input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
"show_scale_bar": ("BOOLEAN", {"default": True}), "show_scale_bar": ("BOOLEAN", {"default": True}),
"show_color_map": ("BOOLEAN", {"default": True}), "show_color_map": ("BOOLEAN", {"default": True}),

View File

@@ -13,7 +13,10 @@ class Cursors:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"line": ("CURSOR_SOURCE", {"label": "input"}), "line": ("LINE", {
"label": "input",
"accepted_types": ["DATA_FIELD"],
}),
"x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "x1": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), "x2": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),

View File

@@ -21,7 +21,10 @@ class Markup:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}), "input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"shape": (["line", "rectangle", "circle", "arrow"], {"default": "arrow"}), "shape": (["line", "rectangle", "circle", "arrow"], {"default": "arrow"}),
"stroke_color": ("STRING", {"default": "#ff0000", "color_picker": True}), "stroke_color": ("STRING", {"default": "#ff0000", "color_picker": True}),
"stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}), "stroke_width": ("INT", {"default": 3, "min": 1, "max": 64, "step": 1}),

View File

@@ -22,7 +22,10 @@ class PreviewImage:
"colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}), "colormap": (["auto"] + list(COLORMAPS), {"hide_when_input_connected": "colormap_map"}),
}, },
"optional": { "optional": {
"input": ("ANNOTATION_SOURCE", {"label": "Input"}), "input": ("ANNOTATION_SOURCE", {
"label": "Input",
"accepted_types": ["DATA_FIELD", "IMAGE"],
}),
"colormap_map": ("COLORMAP", {"label": "colormap"}), "colormap_map": ("COLORMAP", {"label": "colormap"}),
} }
} }

View File

@@ -9,7 +9,9 @@ class PrintTable:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"table": ("ANY_TABLE",), "table": ("MEASURE_TABLE", {
"accepted_types": ["RECORD_TABLE"],
}),
} }
} }

View File

@@ -29,7 +29,18 @@ class Save:
"hide_when_input_connected": "directory", "hide_when_input_connected": "directory",
"top_socket_input": "directory", "top_socket_input": "directory",
}), }),
"value": ("SAVE_VALUE", {"label": "value"}), "value": ("DATA_FIELD", {
"label": "value",
"accepted_types": [
"IMAGE",
"ANNOTATION_SOURCE",
"LINE",
"MEASURE_TABLE",
"RECORD_TABLE",
"MESH_MODEL",
"FLOAT",
],
}),
"format": ("STRING", { "format": ("STRING", {
"default": "TIFF", "default": "TIFF",
"choices_by_source_type": { "choices_by_source_type": {

View File

@@ -17,7 +17,10 @@ class SaveImage:
"directory": ("DIRECTORY", {"label": "directory"}), "directory": ("DIRECTORY", {"label": "directory"}),
} }
for i in range(_MAX_SAVE_FIELDS): for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("SAVE_LAYER", {"label": f"layer {i + 1}"}) optional[f"field_{i}"] = ("DATA_FIELD", {
"label": f"layer {i + 1}",
"accepted_types": ["IMAGE", "ANNOTATION_SOURCE"],
})
optional[f"layer_name_{i}"] = ("STRING", { optional[f"layer_name_{i}"] = ("STRING", {
"default": "", "default": "",
"placeholder": "name", "placeholder": "name",

View File

@@ -26,7 +26,9 @@ class Stats:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"input": ("STATS_SOURCE",), "input": ("DATA_FIELD", {
"accepted_types": ["IMAGE", "LINE", "RECORD_TABLE"],
}),
"column": ("STRING", { "column": ("STRING", {
"default": "value", "default": "value",
"choices_from_table_input": "input", "choices_from_table_input": "input",

View File

@@ -11,7 +11,9 @@ class ValueDisplay:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"value": ("VALUE_SOURCE",), "value": ("FLOAT", {
"accepted_types": ["MEASURE_TABLE"],
}),
"measurement": ("STRING", { "measurement": ("STRING", {
"default": "", "default": "",
"choices_from_measure_input": "value", "choices_from_measure_input": "value",

View File

@@ -38,7 +38,11 @@ import {
import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js'; import { buildDefaultWidgetValues } from './nodeWidgetDefaults.js';
import { import {
DATA_TYPES, SOCKET_COMPATIBILITY, TYPE_COLORS, CAT_COLORS, CANVAS_COLORS, getSpecTypeAndOptions,
socketSpecAcceptsType,
TYPE_COLORS,
CAT_COLORS,
CANVAS_COLORS,
} from './constants'; } from './constants';
const NODE_TYPES = { custom: CustomNode }; const NODE_TYPES = { custom: CustomNode };
@@ -428,10 +432,54 @@ function compareMenuCategories(a, b) {
return String(a?.name || '').localeCompare(String(b?.name || '')); return String(a?.name || '').localeCompare(String(b?.name || ''));
} }
function socketTypesCompatible(sourceType, targetType) { function getResolvedHandleRef(nodeId, handleId) {
if (sourceType === targetType) return true; const proxy = parseGroupProxyHandle(handleId);
const accepted = SOCKET_COMPATIBILITY[targetType]; return {
return !!accepted?.has(sourceType); nodeId: proxy?.nodeId || nodeId,
handleId: proxy?.realHandle || handleId,
type: proxy?.type || getHandleType(handleId),
};
}
function getNodeInputSpecForHandle(node, handleId) {
const definition = node?.data?.definition;
if (!definition?.input) return null;
const inputName = getInputName(handleId);
return definition.input.required?.[inputName]
|| definition.input.optional?.[inputName]
|| null;
}
function socketTypesCompatible(sourceType, targetSpecOrType) {
return socketSpecAcceptsType(sourceType, targetSpecOrType);
}
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
if (socketTypesCompatible(outputType, targetSpecOrType)) {
return true;
}
return outputType === 'ANNOTATION_SOURCE'
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
&& (
socketTypesCompatible('DATA_FIELD', targetSpecOrType)
|| socketTypesCompatible('IMAGE', targetSpecOrType)
);
}
function resolveOutputTypeForTarget(outputType, targetSpecOrType) {
if (outputType !== 'ANNOTATION_SOURCE') {
return outputType;
}
if (socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)) {
return 'ANNOTATION_SOURCE';
}
if (socketTypesCompatible('DATA_FIELD', targetSpecOrType)) {
return 'DATA_FIELD';
}
if (socketTypesCompatible('IMAGE', targetSpecOrType)) {
return 'IMAGE';
}
return 'ANNOTATION_SOURCE';
} }
function getRenderedNodeBounds(nodes) { function getRenderedNodeBounds(nodes) {
@@ -592,7 +640,18 @@ async function captureViewportBlob(viewportEl, options) {
// ── Context menu component ──────────────────────────────────────────── // ── Context menu component ────────────────────────────────────────────
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection, selectedNodeCount = 0, onCreateGroup = null }) { function ContextMenu({
x,
y,
nodeDefs,
onAdd,
onClose,
filterType,
filterSpec = null,
filterDirection,
selectedNodeCount = 0,
onCreateGroup = null,
}) {
const [openCat, setOpenCat] = useState(null); const [openCat, setOpenCat] = useState(null);
const [search, setSearch] = useState(''); const [search, setSearch] = useState('');
const menuRef = useRef(null); const menuRef = useRef(null);
@@ -611,14 +670,12 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
const opt = def.input.optional || {}; const opt = def.input.optional || {};
const allInputs = { ...req, ...opt }; const allInputs = { ...req, ...opt };
const hasMatch = Object.values(allInputs).some((spec) => { const hasMatch = Object.values(allInputs).some((spec) => {
const [type] = Array.isArray(spec) ? spec : [spec]; return socketTypesCompatible(filterType, spec);
return socketTypesCompatible(filterType, type);
}); });
if (!hasMatch) continue; if (!hasMatch) continue;
} else { } else {
const hasMatch = def.output.some((type) => const hasMatch = def.output.some((type) =>
socketTypesCompatible(type, filterType) outputTypeCanConnectToTarget(type, filterSpec || filterType)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
); );
if (!hasMatch) continue; if (!hasMatch) continue;
} }
@@ -661,7 +718,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
items: [...category.items].sort(compareMenuNodes), items: [...category.items].sort(compareMenuNodes),
})) }))
.sort(compareMenuCategories); .sort(compareMenuCategories);
}, [nodeDefs, filterType, filterDirection]); }, [nodeDefs, filterDirection, filterSpec, filterType]);
// Flat filtered list for search // Flat filtered list for search
const searchResults = useMemo(() => { const searchResults = useMemo(() => {
@@ -1262,7 +1319,10 @@ function Flow() {
setEdges((prev) => prev.filter((edge) => { setEdges((prev) => prev.filter((edge) => {
if (edge.source !== nodeId) return true; if (edge.source !== nodeId) return true;
return socketTypesCompatible(outputType, getHandleType(edge.targetHandle)); const resolvedTarget = getResolvedHandleRef(edge.target, edge.targetHandle);
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(outputType, targetSpec);
})); }));
}, [reactFlow, setEdges, setNodeOutputs]); }, [reactFlow, setEdges, setNodeOutputs]);
@@ -1328,9 +1388,11 @@ function Flow() {
const isValidConnection = useCallback((connection) => { const isValidConnection = useCallback((connection) => {
const srcType = getConnectionHandleType(connection.sourceHandle); const srcType = getConnectionHandleType(connection.sourceHandle);
const tgtType = getConnectionHandleType(connection.targetHandle); const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
return socketTypesCompatible(srcType, tgtType); const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
}, []); const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
return socketTypesCompatible(srcType, targetSpec);
}, [reactFlow]);
const onConnect = useCallback((params) => { const onConnect = useCallback((params) => {
const sourceProxy = parseGroupProxyHandle(params.sourceHandle); const sourceProxy = parseGroupProxyHandle(params.sourceHandle);
@@ -1497,17 +1559,23 @@ function Flow() {
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event; const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
const handleType = getConnectionHandleType(fromHandle.id); const handleType = getConnectionHandleType(fromHandle.id);
const resolvedFromHandle = getResolvedHandleRef(fromHandle.nodeId, fromHandle.id);
const fromNode = reactFlow.getNode(resolvedFromHandle.nodeId);
const filterSpec = fromHandle.type === 'target'
? (getNodeInputSpecForHandle(fromNode, resolvedFromHandle.handleId) || handleType)
: handleType;
setContextMenu({ setContextMenu({
x: clientX, x: clientX,
y: clientY, y: clientY,
filterType: handleType, filterType: handleType,
filterSpec,
filterDirection: fromHandle.type, filterDirection: fromHandle.type,
pendingNodeId: fromHandle.nodeId, pendingNodeId: fromHandle.nodeId,
pendingHandleId: fromHandle.id, pendingHandleId: fromHandle.id,
pendingHandleType: fromHandle.type, pendingHandleType: fromHandle.type,
}); });
}, []); }, [reactFlow]);
// ── Widget change callback ────────────────────────────────────────── // ── Widget change callback ──────────────────────────────────────────
@@ -1670,18 +1738,18 @@ function Flow() {
// Auto-connect if this was triggered by dropping a connection on blank space // Auto-connect if this was triggered by dropping a connection on blank space
if (contextMenu.pendingHandleId) { if (contextMenu.pendingHandleId) {
const filterType = contextMenu.filterType; const filterType = contextMenu.filterType;
const filterSpec = contextMenu.filterSpec || filterType;
if (contextMenu.pendingHandleType === 'source') { if (contextMenu.pendingHandleType === 'source') {
// Dragged from an output → connect to the first matching input on the new node // Dragged from an output → connect to the first matching input on the new node
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) }; const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
const inputName = Object.entries(allInputs).find(([, spec]) => { const inputName = Object.entries(allInputs).find(([, spec]) => {
const [type] = Array.isArray(spec) ? spec : [spec]; return socketTypesCompatible(filterType, spec);
return socketTypesCompatible(filterType, type);
})?.[0]; })?.[0];
if (inputName) { if (inputName) {
const targetType = (() => { const targetType = (() => {
const spec = allInputs[inputName]; const spec = allInputs[inputName];
const [type] = Array.isArray(spec) ? spec : [spec]; const [type] = getSpecTypeAndOptions(spec);
return type; return type;
})(); })();
const targetHandle = `input::${inputName}::${targetType}`; const targetHandle = `input::${inputName}::${targetType}`;
@@ -1697,11 +1765,10 @@ function Flow() {
} else { } else {
// Dragged from an input → connect from the first matching output on the new node // Dragged from an input → connect from the first matching output on the new node
const outputIdx = def.output.findIndex((type) => const outputIdx = def.output.findIndex((type) =>
socketTypesCompatible(type, filterType) outputTypeCanConnectToTarget(type, filterSpec)
|| (type === 'ANNOTATION_SOURCE' && (filterType === 'DATA_FIELD' || filterType === 'IMAGE'))
); );
if (outputIdx !== -1) { if (outputIdx !== -1) {
const outputType = def.output[outputIdx] === 'ANNOTATION_SOURCE' ? filterType : def.output[outputIdx]; const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
const sourceHandle = `output::${outputIdx}::${outputType}`; const sourceHandle = `output::${outputIdx}::${outputType}`;
const color = TYPE_COLORS[outputType] || 'var(--fallback-type)'; const color = TYPE_COLORS[outputType] || 'var(--fallback-type)';
setEdges((eds) => addEdge({ setEdges((eds) => addEdge({
@@ -2848,6 +2915,7 @@ function Flow() {
onCreateGroup={createGroupFromSelection} onCreateGroup={createGroupFromSelection}
onClose={() => setContextMenu(null)} onClose={() => setContextMenu(null)}
filterType={contextMenu.filterType} filterType={contextMenu.filterType}
filterSpec={contextMenu.filterSpec}
filterDirection={contextMenu.filterDirection} filterDirection={contextMenu.filterDirection}
selectedNodeCount={selectedNodeCount} selectedNodeCount={selectedNodeCount}
/> />

View File

@@ -10,7 +10,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay'));
const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay')); const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay'));
import { import {
DATA_TYPES, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS, getSpecTypeAndOptions, isDataSocketSpec, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS,
} from './constants'; } from './constants';
import { getGroupMinimumSize } from './groupSizing.js'; import { getGroupMinimumSize } from './groupSizing.js';
import { buildCombinedInputNameByWidgetName, formatUiLabel } from './nodeWidgetLayout.js'; import { buildCombinedInputNameByWidgetName, formatUiLabel } from './nodeWidgetLayout.js';
@@ -898,8 +898,8 @@ function CustomNode({ id, data }) {
const hiddenWidgets = new Set(); const hiddenWidgets = new Set();
for (const [name, spec] of Object.entries(required)) { for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; const [type, opts] = getSpecTypeAndOptions(spec);
if (DATA_TYPES.has(type)) { if (isDataSocketSpec(spec)) {
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) }); dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
visibleInputNames.add(name); visibleInputNames.add(name);
} else if (opts?.hidden) { } else if (opts?.hidden) {
@@ -943,8 +943,8 @@ function CustomNode({ id, data }) {
); );
for (const [name, spec] of Object.entries(optional)) { for (const [name, spec] of Object.entries(optional)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; const [type, opts] = getSpecTypeAndOptions(spec);
if (isProgressive && DATA_TYPES.has(type)) { if (isProgressive && isDataSocketSpec(spec)) {
// Progressive: show this slot only if it's the first or the previous is connected // Progressive: show this slot only if it's the first or the previous is connected
const match = name.match(/^field_(\d+)$/); const match = name.match(/^field_(\d+)$/);
if (match) { if (match) {
@@ -958,7 +958,7 @@ function CustomNode({ id, data }) {
} }
if (opts?.hidden) { if (opts?.hidden) {
hiddenWidgets.add(name); hiddenWidgets.add(name);
} else if (DATA_TYPES.has(type)) { } else if (isDataSocketSpec(spec)) {
dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) }); dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) });
visibleInputNames.add(name); visibleInputNames.add(name);
} else { } else {

View File

@@ -1,9 +1,9 @@
// ── Shared type & color constants ───────────────────────────────────── // ── Shared type & color constants ─────────────────────────────────────
export const DATA_TYPES = new Set([ export const DATA_TYPES = new Set([
'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'ANY_TABLE', 'DATA_FIELD', 'IMAGE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE',
'COORD', 'STATS_SOURCE', 'CURSOR_SOURCE', 'VALUE_SOURCE', 'ANNOTATION_SOURCE', 'COLORMAP', 'COORD', 'ANNOTATION_SOURCE', 'COLORMAP',
'SAVE_LAYER', 'SAVE_VALUE', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR', 'MESH_MODEL', 'FONT', 'FILE_PATH', 'DIRECTORY', 'COORDPAIR',
]); ]);
export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']); export const SOCKET_WIDGET_TYPES = new Set(['FLOAT', 'INT']);
@@ -13,19 +13,13 @@ export const TYPE_COLORS = {
IMAGE: '#00ff08a0', IMAGE: '#00ff08a0',
LINE: '#ffbe5c', LINE: '#ffbe5c',
MEASURE_TABLE: '#35e2fd', MEASURE_TABLE: '#35e2fd',
RECORD_TABLE: '#fbbf24', RECORD_TABLE: '#ff7474',
ANY_TABLE: '#67e8f9',
COORD: '#e91ed1', COORD: '#e91ed1',
COORDPAIR: '#5c7cb8', COORDPAIR: '#5cb861',
FLOAT: '#ab3197', FLOAT: '#ab3197',
INT: '#38bdf8', INT: '#ffffff',
STATS_SOURCE: '#c084fc',
CURSOR_SOURCE: '#a78bfa',
VALUE_SOURCE: '#60a5fa',
ANNOTATION_SOURCE: '#06b6d4', ANNOTATION_SOURCE: '#06b6d4',
COLORMAP: '#f472b6', COLORMAP: '#f472b6',
SAVE_LAYER: '#22c55e',
SAVE_VALUE: '#4ade80',
MESH_MODEL: '#14b8a6', MESH_MODEL: '#14b8a6',
FONT: '#fb7185', FONT: '#fb7185',
FILE_PATH: '#f59e0b', FILE_PATH: '#f59e0b',
@@ -46,18 +40,60 @@ export const CAT_COLORS = {
}; };
export const SOCKET_COMPATIBILITY = { export const SOCKET_COMPATIBILITY = {
STATS_SOURCE: new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'RECORD_TABLE']),
CURSOR_SOURCE: new Set(['DATA_FIELD', 'LINE']),
ANY_TABLE: new Set(['MEASURE_TABLE', 'RECORD_TABLE']),
VALUE_SOURCE: new Set(['FLOAT', 'MEASURE_TABLE']),
ANNOTATION_SOURCE: new Set(['DATA_FIELD', 'IMAGE']),
SAVE_LAYER: new Set(['DATA_FIELD', 'IMAGE']),
SAVE_VALUE: new Set(['DATA_FIELD', 'IMAGE', 'ANNOTATION_SOURCE', 'LINE', 'MEASURE_TABLE', 'RECORD_TABLE', 'MESH_MODEL', 'FLOAT']),
FLOAT: new Set(['INT']), FLOAT: new Set(['INT']),
INT: new Set(['FLOAT']), INT: new Set(['FLOAT']),
LINE: new Set(['COORDPAIR']), LINE: new Set(['COORDPAIR']),
}; };
const EMPTY_SOCKET_TYPE_SET = new Set();
export function getSpecTypeAndOptions(spec) {
if (Array.isArray(spec)) {
return [spec[0], spec[1] || {}];
}
return [spec, {}];
}
export function isDataSocketType(type) {
return typeof type === 'string' && DATA_TYPES.has(type);
}
export function isDataSocketSpec(spec) {
const [type] = getSpecTypeAndOptions(spec);
return isDataSocketType(type);
}
export function getAcceptedSocketTypes(specOrType) {
const [type, opts] = Array.isArray(specOrType)
? getSpecTypeAndOptions(specOrType)
: [specOrType, {}];
if (typeof type !== 'string') {
return EMPTY_SOCKET_TYPE_SET;
}
const accepted = new Set([type]);
const explicitAccepted = Array.isArray(opts?.accepted_types) ? opts.accepted_types : [];
for (const acceptedType of explicitAccepted) {
if (typeof acceptedType === 'string' && acceptedType) {
accepted.add(acceptedType);
}
}
const fallbackAccepted = SOCKET_COMPATIBILITY[type];
if (fallbackAccepted) {
for (const acceptedType of fallbackAccepted) {
accepted.add(acceptedType);
}
}
return accepted;
}
export function socketSpecAcceptsType(sourceType, targetSpecOrType) {
if (typeof sourceType !== 'string' || !sourceType) return false;
return getAcceptedSocketTypes(targetSpecOrType).has(sourceType);
}
// Colors used in Canvas 2D / toBlob contexts where CSS var() is unavailable. // Colors used in Canvas 2D / toBlob contexts where CSS var() is unavailable.
export const CANVAS_COLORS = { export const CANVAS_COLORS = {
bgDeep: '#0f172a', bgDeep: '#0f172a',

View File

@@ -1,4 +1,4 @@
import { DATA_TYPES } from './constants.js'; import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
const OMITTED_WIDGET_INPUTS_BY_CLASS = { const OMITTED_WIDGET_INPUTS_BY_CLASS = {
View3D: new Set([ View3D: new Set([
@@ -91,8 +91,8 @@ export function serializeExecutionGraph(nodes, edges, { excludeManualTrigger = f
}; };
for (const [name, spec] of Object.entries(allWidgets)) { for (const [name, spec] of Object.entries(allWidgets)) {
if (omittedInputs?.has(name)) continue; if (omittedInputs?.has(name)) continue;
const [type] = Array.isArray(spec) ? spec : [spec]; const [type] = getSpecTypeAndOptions(spec);
if (DATA_TYPES.has(type)) continue; if (isDataSocketSpec(spec)) continue;
if (type === 'BUTTON') continue; if (type === 'BUTTON') continue;
if (valueBag[name] !== undefined) { if (valueBag[name] !== undefined) {
inputs[name] = valueBag[name]; inputs[name] = valueBag[name];
@@ -125,16 +125,16 @@ export function hasBlockingAutoRunInput(node, edges) {
const required = def.input.required || {}; const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) { for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; const [type, opts] = getSpecTypeAndOptions(spec);
const hiddenByConnectedInput = (() => { const hiddenByConnectedInput = (() => {
const raw = opts?.hide_when_input_connected; const raw = opts?.hide_when_input_connected;
if (!raw) return false; if (!raw) return false;
const inputs = Array.isArray(raw) ? raw : [raw]; const inputs = Array.isArray(raw) ? raw : [raw];
return inputs.some((inputName) => edges.some( return inputs.some((inputName) => edges.some(
(edge) => { (edge) => {
const resolved = resolveExecutionEdge(edge); const resolved = resolveExecutionEdge(edge);
return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName); return resolved.target === node.id && getInputName(resolved.targetHandle) === String(inputName);
} }
)); ));
})(); })();
@@ -144,7 +144,7 @@ export function hasBlockingAutoRunInput(node, edges) {
if (!node.data.widgetValues?.[name]) return true; if (!node.data.widgetValues?.[name]) return true;
continue; continue;
} }
if (!DATA_TYPES.has(type)) continue; if (!isDataSocketSpec(spec)) continue;
const hasEdge = edges.some( const hasEdge = edges.some(
(edge) => { (edge) => {
const resolved = resolveExecutionEdge(edge); const resolved = resolveExecutionEdge(edge);

View File

@@ -1,8 +1,8 @@
import { DATA_TYPES } from './constants.js'; import { getSpecTypeAndOptions, isDataSocketSpec } from './constants.js';
export function getDefaultWidgetValue(spec) { export function getDefaultWidgetValue(spec) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; const [type, opts] = getSpecTypeAndOptions(spec);
if (DATA_TYPES.has(type)) return undefined; if (isDataSocketSpec(spec)) return undefined;
if (type === 'BUTTON') return undefined; if (type === 'BUTTON') return undefined;
if (Array.isArray(type)) { if (Array.isArray(type)) {
if (typeof opts?.default === 'string' && type.includes(opts.default)) { if (typeof opts?.default === 'string' && type.includes(opts.default)) {

View File

@@ -1,8 +1,31 @@
import test from 'node:test'; import test from 'node:test';
import assert from 'node:assert/strict'; import assert from 'node:assert/strict';
import { SOCKET_COMPATIBILITY } from '../src/constants.js'; import {
DATA_TYPES,
getAcceptedSocketTypes,
isDataSocketSpec,
socketSpecAcceptsType,
} from '../src/constants.js';
test('SAVE_VALUE accepts ANNOTATION_SOURCE inputs', () => { test('intrinsic socket compatibility still allows INT to connect to FLOAT sockets', () => {
assert.equal(SOCKET_COMPATIBILITY.SAVE_VALUE.has('ANNOTATION_SOURCE'), true); assert.equal(socketSpecAcceptsType('INT', 'FLOAT'), true);
assert.equal(socketSpecAcceptsType('FLOAT', 'INT'), true);
});
test('retired save alias types are no longer first-class socket types', () => {
assert.equal(DATA_TYPES.has('SAVE_VALUE'), false);
assert.equal(DATA_TYPES.has('SAVE_LAYER'), false);
});
test('accepted_types extend canonical socket compatibility without reintroducing alias types', () => {
const spec = ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }];
assert.equal(isDataSocketSpec(spec), true);
assert.deepEqual(
Array.from(getAcceptedSocketTypes(spec)).sort(),
['MEASURE_TABLE', 'RECORD_TABLE'],
);
assert.equal(socketSpecAcceptsType('RECORD_TABLE', spec), true);
assert.equal(socketSpecAcceptsType('LINE', spec), false);
}); });

View File

@@ -478,3 +478,89 @@ test('hasBlockingAutoRunInput skips required file widgets when a connected socke
assert.equal(hasBlockingAutoRunInput(node, edges), false); assert.equal(hasBlockingAutoRunInput(node, edges), false);
}); });
test('serializeExecutionGraph treats accepted_types inputs as sockets, not widgets', () => {
const nodes = [
{
id: '1',
data: {
className: 'TableSource',
definition: {
input: { required: {}, optional: {} },
output: ['RECORD_TABLE'],
output_name: ['rows'],
manual_trigger: false,
},
widgetValues: {},
},
},
{
id: '2',
data: {
className: 'PrintTable',
definition: {
input: {
required: {
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
},
optional: {},
},
manual_trigger: false,
},
widgetValues: { table: 'should-not-serialize' },
},
},
];
const edges = [
{
source: '1',
sourceHandle: 'output::0::RECORD_TABLE',
target: '2',
targetHandle: 'input::table::MEASURE_TABLE',
},
];
const prompt = serializeExecutionGraph(nodes, edges);
assert.deepEqual(prompt, {
'1': {
class_type: 'TableSource',
inputs: {},
},
'2': {
class_type: 'PrintTable',
inputs: { table: ['1', 0] },
},
});
});
test('hasBlockingAutoRunInput still blocks unconnected accepted_types sockets', () => {
const node = {
id: '2',
data: {
definition: {
manual_trigger: false,
input: {
required: {
input: ['DATA_FIELD', { accepted_types: ['IMAGE', 'LINE', 'RECORD_TABLE'] }],
},
optional: {},
},
},
widgetValues: {},
},
};
assert.equal(hasBlockingAutoRunInput(node, []), true);
assert.equal(
hasBlockingAutoRunInput(node, [
{
source: '1',
sourceHandle: 'output::0::RECORD_TABLE',
target: '2',
targetHandle: 'input::input::DATA_FIELD',
},
]),
false,
);
});

View File

@@ -58,7 +58,7 @@ test('buildNodeClipboardPayload keeps only selected nodes and internal edges', (
source: '2', source: '2',
sourceHandle: 'output::0::IMAGE', sourceHandle: 'output::0::IMAGE',
target: '3', target: '3',
targetHandle: 'input::value::SAVE_VALUE', targetHandle: 'input::value::DATA_FIELD',
}, },
]; ];
@@ -166,7 +166,7 @@ test('buildNodeClipboardPayloadForIds can include upstream external edges for du
source: '2', source: '2',
sourceHandle: 'output::0::IMAGE', sourceHandle: 'output::0::IMAGE',
target: '3', target: '3',
targetHandle: 'input::value::SAVE_VALUE', targetHandle: 'input::value::DATA_FIELD',
}, },
]; ];

View File

@@ -16,6 +16,7 @@ test('buildDefaultWidgetValues keeps non-data required widget defaults', () => {
input: { input: {
required: { required: {
input: ['ANNOTATION_SOURCE', { label: 'Input' }], input: ['ANNOTATION_SOURCE', { label: 'Input' }],
table: ['MEASURE_TABLE', { accepted_types: ['RECORD_TABLE'] }],
shape: [['line', 'rectangle', 'circle', 'arrow'], { default: 'arrow' }], shape: [['line', 'rectangle', 'circle', 'arrow'], { default: 'arrow' }],
stroke_color: ['STRING', { default: '#ff0000', color_picker: true }], stroke_color: ['STRING', { default: '#ff0000', color_picker: true }],
stroke_width: ['INT', { default: 3 }], stroke_width: ['INT', { default: 3 }],

View File

@@ -881,6 +881,7 @@ def test_angle_measure():
assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"} assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"}
required_inputs = AngleMeasure.INPUT_TYPES()["required"] required_inputs = AngleMeasure.INPUT_TYPES()["required"]
optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {}) optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {})
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
assert required_inputs["color"][1]["default"] == "#ff9800" assert required_inputs["color"][1]["default"] == "#ff9800"
assert required_inputs["stroke_width"][1]["default"] == 1.35 assert required_inputs["stroke_width"][1]["default"] == 1.35
assert optional_inputs["line_thickness"][1]["hidden"] is True assert optional_inputs["line_thickness"][1]["hidden"] is True
@@ -1584,6 +1585,10 @@ def test_save_image():
from backend.nodes.save_image import SaveImage from backend.nodes.save_image import SaveImage
import tifffile import tifffile
node = SaveImage() node = SaveImage()
input_types = SaveImage.INPUT_TYPES()
field_spec = input_types["optional"]["field_0"]
assert field_spec[0] == "DATA_FIELD"
assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"]
field_a = make_field(data=np.random.default_rng(4).random((32, 32))) field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
field_b = make_field(data=np.random.default_rng(5).random((32, 32))) field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
@@ -1729,6 +1734,9 @@ def test_preview_image():
from backend.data_types import ImageData from backend.data_types import ImageData
from backend.execution_context import active_node, execution_callbacks from backend.execution_context import active_node, execution_callbacks
node = PreviewImage() node = PreviewImage()
preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"]
assert preview_input[0] == "ANNOTATION_SOURCE"
assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
# Set up a capture for the broadcast # Set up a capture for the broadcast
captured = [] captured = []
@@ -1794,6 +1802,9 @@ def test_annotations():
node = Annotations() node = Annotations()
font_node = Font() font_node = Font()
annotation_input = Annotations.INPUT_TYPES()["required"]["input"]
assert annotation_input[0] == "ANNOTATION_SOURCE"
assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
warnings = [] warnings = []
field = DataField( field = DataField(
data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64), data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64),
@@ -1920,6 +1931,7 @@ def test_markup():
assert _preview_markup_stroke_width(5, 128, 128) == 5 assert _preview_markup_stroke_width(5, 128, 128) == 5
assert _preview_markup_stroke_width(5, 2048, 2048) > 5 assert _preview_markup_stroke_width(5, 2048, 2048) > 5
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
assert required_inputs["shape"][1]["default"] == "arrow" assert required_inputs["shape"][1]["default"] == "arrow"
assert required_inputs["stroke_color"][1]["default"] == "#ff0000" assert required_inputs["stroke_color"][1]["default"] == "#ff0000"
@@ -1987,6 +1999,10 @@ def test_print_table():
from backend.nodes.print_table import PrintTable from backend.nodes.print_table import PrintTable
node = PrintTable() node = PrintTable()
table_spec = PrintTable.INPUT_TYPES()["required"]["table"]
assert table_spec[0] == "MEASURE_TABLE"
assert table_spec[1]["accepted_types"] == ["RECORD_TABLE"]
captured = [] captured = []
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows) PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
PrintTable._current_node_id = "test" PrintTable._current_node_id = "test"
@@ -2005,6 +2021,10 @@ def test_value_display():
from backend.nodes.value_display import ValueDisplay from backend.nodes.value_display import ValueDisplay
node = ValueDisplay() node = ValueDisplay()
value_spec = ValueDisplay.INPUT_TYPES()["required"]["value"]
assert value_spec[0] == "FLOAT"
assert value_spec[1]["accepted_types"] == ["MEASURE_TABLE"]
captured = [] captured = []
ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
ValueDisplay._current_node_id = "test" ValueDisplay._current_node_id = "test"
@@ -2599,6 +2619,9 @@ def test_line_cursors():
from backend.nodes.cursors import Cursors from backend.nodes.cursors import Cursors
node = Cursors() node = Cursors()
line_spec = Cursors.INPUT_TYPES()["required"]["line"]
assert line_spec[0] == "LINE"
assert line_spec[1]["accepted_types"] == ["DATA_FIELD"]
# Create a simple linear ramp # Create a simple linear ramp
line = np.linspace(0, 10, 100).astype(np.float64) line = np.linspace(0, 10, 100).astype(np.float64)
@@ -2814,6 +2837,10 @@ def test_stats():
from backend.nodes.stats import Stats from backend.nodes.stats import Stats
node = Stats() node = Stats()
input_spec = Stats.INPUT_TYPES()["required"]["input"]
assert input_spec[0] == "DATA_FIELD"
assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "RECORD_TABLE"]
captured = [] captured = []
Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload)) Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
Stats._current_node_id = "test" Stats._current_node_id = "test"
@@ -2998,6 +3025,17 @@ def test_save_generic():
from PIL import Image as PILImage from PIL import Image as PILImage
node = Save() node = Save()
value_spec = node.INPUT_TYPES()["required"]["value"]
assert value_spec[0] == "DATA_FIELD"
assert value_spec[1]["accepted_types"] == [
"IMAGE",
"ANNOTATION_SOURCE",
"LINE",
"MEASURE_TABLE",
"RECORD_TABLE",
"MESH_MODEL",
"FLOAT",
]
format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"] format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"]
assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"] assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"]