# tono/tests/node_tests/execution.py
# Repository file-viewer header (last modified 2026-03-29 16:42:55 -07:00;
# 139 lines, 4.6 KiB, Python), kept as a comment so the module parses.

import backend.nodes # noqa: F401 — registers all nodes
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
def test_execution_engine_numeric_socket_coercion():
    """Check that one Number output is coerced to each consumer's socket type.

    A single ``Number`` node (value 3.6) feeds both an ``INT`` input and a
    ``FLOAT`` input. The ``INT`` consumer must receive 4 as a real ``int``.
    The ``FLOAT`` consumer must receive 3.6 as a ``float``.
    """

    @register_node(display_name="Test Echo Int")
    class TestEchoInt:
        """Pass-through node whose single input is declared as ``INT``."""

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("INT",)}}

        OUTPUTS = (('INT', 'value'),)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            return (value,)

    @register_node(display_name="Test Echo Float")
    class TestEchoFloat:
        """Pass-through node whose single input is declared as ``FLOAT``."""

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        OUTPUTS = (('FLOAT', 'value'),)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            return (value,)

    engine = ExecutionEngine()
    prompt = {
        "1": {"class_type": "Number", "inputs": {"value": 3.6}},
        "2": {"class_type": "TestEchoInt", "inputs": {"value": ["1", 0]}},
        "3": {"class_type": "TestEchoFloat", "inputs": {"value": ["1", 0]}},
    }
    outputs = engine.execute(prompt)

    # ``4.0 == 4`` is True in Python, so tuple equality alone cannot tell a
    # real int apart from a merely rounded float. Check the element types too.
    assert outputs["2"] == (4,)
    assert isinstance(outputs["2"][0], int)
    assert outputs["3"] == (3.6,)
    assert isinstance(outputs["3"][0], float)
def test_execution_engine_caches_unchanged_nodes():
    """Check that re-executing an unchanged prompt reuses cached results.

    Both nodes run once on the first ``execute`` call. The second call on the
    identical prompt must return the same outputs without invoking either
    node's ``process`` again. It must still report every node through
    ``on_node_done``, with an elapsed time of 0.0 ms.
    """

    @register_node(display_name="Test Cache Source")
    class TestCacheSource:
        calls = 0  # number of times process() has run

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        OUTPUTS = (('FLOAT', 'value'),)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestCacheSource.calls += 1
            return (float(value),)

    @register_node(display_name="Test Cache Downstream")
    class TestCacheDownstream:
        calls = 0  # number of times process() has run

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        OUTPUTS = (('FLOAT', 'value'),)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestCacheDownstream.calls += 1
            return (float(value) * 2.0,)

    TestCacheSource.calls = 0
    TestCacheDownstream.calls = 0
    engine = ExecutionEngine()
    prompt = {
        "1": {"class_type": "TestCacheSource", "inputs": {"value": 2.5}},
        "2": {"class_type": "TestCacheDownstream", "inputs": {"value": ["1", 0]}},
    }

    # First run: nothing is cached yet, so both nodes execute and report.
    first_timings = []
    first_outputs = engine.execute(
        prompt,
        on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms)),
    )
    assert first_outputs["2"] == (5.0,)
    assert {node_id for node_id, _ in first_timings} == {"1", "2"}
    assert TestCacheSource.calls == 1
    assert TestCacheDownstream.calls == 1

    # Second run on the unchanged prompt: results come from the cache, so
    # neither process() runs again, yet every node is still reported (0.0 ms).
    second_timings = []
    second_outputs = engine.execute(
        prompt,
        on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms)),
    )
    assert second_outputs["2"] == (5.0,)
    assert TestCacheSource.calls == 1
    assert TestCacheDownstream.calls == 1
    assert {node_id for node_id, _ in second_timings} == {"1", "2"}
    assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings)
def test_execution_engine_only_propagates_real_output_changes():
    """Check that downstream nodes re-run only when an upstream output changes.

    The source node rounds its FLOAT input to an INT. Changing its input from
    1.2 to 1.3 re-runs the source but yields the same output (1), so the
    downstream node must be skipped. Changing it to 2.2 changes the output to
    2 and must re-run the downstream node.
    """

    @register_node(display_name="Test Quantized Source")
    class TestQuantizedSource:
        calls = 0  # number of times process() has run

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        OUTPUTS = (("INT", "value"),)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestQuantizedSource.calls += 1
            rounded = round(float(value))
            return (int(rounded),)

    @register_node(display_name="Test Quantized Downstream")
    class TestQuantizedDownstream:
        calls = 0  # number of times process() has run

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("INT",)}}

        OUTPUTS = (("FLOAT", "value"),)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestQuantizedDownstream.calls += 1
            shifted = float(value) + 0.5
            return (shifted,)

    TestQuantizedSource.calls = TestQuantizedDownstream.calls = 0

    engine = ExecutionEngine()
    # Keep a handle on the source node's inputs dict; it is mutated in place
    # between runs so the same prompt object is re-executed each time.
    source_inputs = {"value": 1.2}
    graph = {
        "1": {"class_type": "TestQuantizedSource", "inputs": source_inputs},
        "2": {"class_type": "TestQuantizedDownstream", "inputs": {"value": ["1", 0]}},
    }

    # (source input, expected downstream output). 1.2 and 1.3 both round to 1;
    # 2.2 rounds to 2.
    steps = (
        (1.2, (1.5,)),
        (1.3, (1.5,)),
        (2.2, (2.5,)),
    )
    for source_value, expected in steps:
        source_inputs["value"] = source_value
        assert engine.execute(graph)["2"] == expected

    # The source ran on every execute; the downstream was skipped on the 1.3
    # run because its input (1) did not change.
    assert TestQuantizedSource.calls == 3
    assert TestQuantizedDownstream.calls == 2