caching nodes to improve performance
This commit is contained in:
@@ -1759,6 +1759,146 @@ def test_execution_engine_numeric_socket_coercion():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_execution_engine_caches_unchanged_nodes():
|
||||
print("=== Test: ExecutionEngine caches unchanged nodes ===")
|
||||
from backend.execution import ExecutionEngine
|
||||
from backend.node_registry import register_node
|
||||
|
||||
@register_node(display_name="Test Cache Source")
|
||||
class TestCacheSource:
|
||||
calls = 0
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"value": ("FLOAT",)}}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
RETURN_NAMES = ("value",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
|
||||
def process(self, value):
|
||||
TestCacheSource.calls += 1
|
||||
return (float(value),)
|
||||
|
||||
@register_node(display_name="Test Cache Downstream")
|
||||
class TestCacheDownstream:
|
||||
calls = 0
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"value": ("FLOAT",)}}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
RETURN_NAMES = ("value",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
|
||||
def process(self, value):
|
||||
TestCacheDownstream.calls += 1
|
||||
return (float(value) * 2.0,)
|
||||
|
||||
TestCacheSource.calls = 0
|
||||
TestCacheDownstream.calls = 0
|
||||
|
||||
engine = ExecutionEngine()
|
||||
prompt = {
|
||||
"1": {
|
||||
"class_type": "TestCacheSource",
|
||||
"inputs": {"value": 2.5},
|
||||
},
|
||||
"2": {
|
||||
"class_type": "TestCacheDownstream",
|
||||
"inputs": {"value": ["1", 0]},
|
||||
},
|
||||
}
|
||||
|
||||
first_timings = []
|
||||
first_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms)))
|
||||
second_timings = []
|
||||
second_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms)))
|
||||
|
||||
assert first_outputs["2"] == (5.0,)
|
||||
assert second_outputs["2"] == (5.0,)
|
||||
assert TestCacheSource.calls == 1
|
||||
assert TestCacheDownstream.calls == 1
|
||||
assert {node_id for node_id, _ in second_timings} == {"1", "2"}
|
||||
assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_execution_engine_only_propagates_real_output_changes():
|
||||
print("=== Test: ExecutionEngine propagates only real upstream output changes ===")
|
||||
from backend.execution import ExecutionEngine
|
||||
from backend.node_registry import register_node
|
||||
|
||||
@register_node(display_name="Test Quantized Source")
|
||||
class TestQuantizedSource:
|
||||
calls = 0
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"value": ("FLOAT",)}}
|
||||
|
||||
RETURN_TYPES = ("INT",)
|
||||
RETURN_NAMES = ("value",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
|
||||
def process(self, value):
|
||||
TestQuantizedSource.calls += 1
|
||||
return (int(round(float(value))),)
|
||||
|
||||
@register_node(display_name="Test Quantized Downstream")
|
||||
class TestQuantizedDownstream:
|
||||
calls = 0
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"value": ("INT",)}}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
RETURN_NAMES = ("value",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "tests"
|
||||
|
||||
def process(self, value):
|
||||
TestQuantizedDownstream.calls += 1
|
||||
return (float(value) + 0.5,)
|
||||
|
||||
TestQuantizedSource.calls = 0
|
||||
TestQuantizedDownstream.calls = 0
|
||||
|
||||
engine = ExecutionEngine()
|
||||
prompt = {
|
||||
"1": {
|
||||
"class_type": "TestQuantizedSource",
|
||||
"inputs": {"value": 1.2},
|
||||
},
|
||||
"2": {
|
||||
"class_type": "TestQuantizedDownstream",
|
||||
"inputs": {"value": ["1", 0]},
|
||||
},
|
||||
}
|
||||
|
||||
outputs_first = engine.execute(prompt)
|
||||
assert outputs_first["2"] == (1.5,)
|
||||
|
||||
prompt["1"]["inputs"]["value"] = 1.3
|
||||
outputs_second = engine.execute(prompt)
|
||||
assert outputs_second["2"] == (1.5,)
|
||||
|
||||
prompt["1"]["inputs"]["value"] = 2.2
|
||||
outputs_third = engine.execute(prompt)
|
||||
assert outputs_third["2"] == (2.5,)
|
||||
|
||||
assert TestQuantizedSource.calls == 3
|
||||
assert TestQuantizedDownstream.calls == 2
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — Cursors
|
||||
# =========================================================================
|
||||
|
||||
Reference in New Issue
Block a user