refactor socket types

This commit is contained in:
2026-03-28 13:56:22 -07:00
parent 4368aeb4a0
commit 1b831cda5d
20 changed files with 366 additions and 79 deletions

View File

@@ -881,6 +881,7 @@ def test_angle_measure():
assert {entry["category"] for entry in info["menu_categories"]} == {"Overlay", "Measure"}
required_inputs = AngleMeasure.INPUT_TYPES()["required"]
optional_inputs = AngleMeasure.INPUT_TYPES().get("optional", {})
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
assert required_inputs["color"][1]["default"] == "#ff9800"
assert required_inputs["stroke_width"][1]["default"] == 1.35
assert optional_inputs["line_thickness"][1]["hidden"] is True
@@ -1584,6 +1585,10 @@ def test_save_image():
from backend.nodes.save_image import SaveImage
import tifffile
node = SaveImage()
input_types = SaveImage.INPUT_TYPES()
field_spec = input_types["optional"]["field_0"]
assert field_spec[0] == "DATA_FIELD"
assert field_spec[1]["accepted_types"] == ["IMAGE", "ANNOTATION_SOURCE"]
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
@@ -1729,6 +1734,9 @@ def test_preview_image():
from backend.data_types import ImageData
from backend.execution_context import active_node, execution_callbacks
node = PreviewImage()
preview_input = PreviewImage.INPUT_TYPES()["optional"]["input"]
assert preview_input[0] == "ANNOTATION_SOURCE"
assert preview_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
# Set up a capture for the broadcast
captured = []
@@ -1794,6 +1802,9 @@ def test_annotations():
node = Annotations()
font_node = Font()
annotation_input = Annotations.INPUT_TYPES()["required"]["input"]
assert annotation_input[0] == "ANNOTATION_SOURCE"
assert annotation_input[1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
warnings = []
field = DataField(
data=np.linspace(0.0, 1.0, 64 * 64, dtype=np.float64).reshape(64, 64),
@@ -1920,6 +1931,7 @@ def test_markup():
assert _preview_markup_stroke_width(5, 128, 128) == 5
assert _preview_markup_stroke_width(5, 2048, 2048) > 5
assert required_inputs["input"][1]["accepted_types"] == ["DATA_FIELD", "IMAGE"]
assert required_inputs["shape"][1]["default"] == "arrow"
assert required_inputs["stroke_color"][1]["default"] == "#ff0000"
@@ -1987,6 +1999,10 @@ def test_print_table():
from backend.nodes.print_table import PrintTable
node = PrintTable()
table_spec = PrintTable.INPUT_TYPES()["required"]["table"]
assert table_spec[0] == "MEASURE_TABLE"
assert table_spec[1]["accepted_types"] == ["RECORD_TABLE"]
captured = []
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
PrintTable._current_node_id = "test"
@@ -2005,6 +2021,10 @@ def test_value_display():
from backend.nodes.value_display import ValueDisplay
node = ValueDisplay()
value_spec = ValueDisplay.INPUT_TYPES()["required"]["value"]
assert value_spec[0] == "FLOAT"
assert value_spec[1]["accepted_types"] == ["MEASURE_TABLE"]
captured = []
ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
ValueDisplay._current_node_id = "test"
@@ -2599,6 +2619,9 @@ def test_line_cursors():
from backend.nodes.cursors import Cursors
node = Cursors()
line_spec = Cursors.INPUT_TYPES()["required"]["line"]
assert line_spec[0] == "LINE"
assert line_spec[1]["accepted_types"] == ["DATA_FIELD"]
# Create a simple linear ramp
line = np.linspace(0, 10, 100).astype(np.float64)
@@ -2814,6 +2837,10 @@ def test_stats():
from backend.nodes.stats import Stats
node = Stats()
input_spec = Stats.INPUT_TYPES()["required"]["input"]
assert input_spec[0] == "DATA_FIELD"
assert input_spec[1]["accepted_types"] == ["IMAGE", "LINE", "RECORD_TABLE"]
captured = []
Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
Stats._current_node_id = "test"
@@ -2998,6 +3025,17 @@ def test_save_generic():
from PIL import Image as PILImage
node = Save()
value_spec = node.INPUT_TYPES()["required"]["value"]
assert value_spec[0] == "DATA_FIELD"
assert value_spec[1]["accepted_types"] == [
"IMAGE",
"ANNOTATION_SOURCE",
"LINE",
"MEASURE_TABLE",
"RECORD_TABLE",
"MESH_MODEL",
"FLOAT",
]
format_choices = node.INPUT_TYPES()["required"]["format"][1]["choices_by_source_type"]
assert format_choices["ANNOTATION_SOURCE"] == format_choices["IMAGE"]