caching nodes to improve performance
This commit is contained in:
@@ -21,9 +21,13 @@ The engine:
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict, deque
|
||||
from math import isfinite
|
||||
from threading import RLock
|
||||
from time import perf_counter
|
||||
from typing import Any, Callable
|
||||
|
||||
@@ -43,6 +47,10 @@ def _is_link(value: Any) -> bool:
|
||||
class ExecutionEngine:
|
||||
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
|
||||
|
||||
def __init__(self) -> None:
    """Create an engine with an empty per-node result cache and its lock."""
    # Re-entrant lock used by the cache accessors when touching the dict below.
    self._cache_lock = RLock()
    # Maps node_id -> cache record for that node's most recent stored result.
    self._node_cache: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def execute(
|
||||
self,
|
||||
prompt: dict[str, dict],
|
||||
@@ -75,6 +83,7 @@ class ExecutionEngine:
|
||||
"""
|
||||
order = self._topological_sort(prompt)
|
||||
node_outputs: dict[str, tuple] = {}
|
||||
node_output_signatures: dict[str, tuple[str, ...]] = {}
|
||||
|
||||
# Inject display callbacks before execution
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
|
||||
@@ -90,10 +99,16 @@ class ExecutionEngine:
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
input_types = cls.INPUT_TYPES()
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
|
||||
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
|
||||
|
||||
# Let display nodes know their node_id so they can tag WS messages
|
||||
self._set_node_id_on_display(cls, node_id)
|
||||
|
||||
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
|
||||
if cache_entry is not None:
|
||||
result = self._clone_cached_outputs(cache_entry["outputs"])
|
||||
elapsed_ms = 0.0
|
||||
else:
|
||||
if on_node_start:
|
||||
on_node_start(node_id)
|
||||
|
||||
@@ -108,6 +123,17 @@ class ExecutionEngine:
|
||||
result = (result,)
|
||||
|
||||
node_outputs[node_id] = result
|
||||
output_signatures = tuple(self._fingerprint_value(value) for value in result)
|
||||
node_output_signatures[node_id] = output_signatures
|
||||
|
||||
if cache_entry is None and self._node_cacheable(cls):
|
||||
self._store_cache_entry(
|
||||
node_id=node_id,
|
||||
class_name=class_name,
|
||||
input_signature=input_signature,
|
||||
output_signatures=output_signatures,
|
||||
outputs=self._clone_cached_outputs(result),
|
||||
)
|
||||
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or table-like output so every node shows its result.
|
||||
@@ -208,6 +234,193 @@ class ExecutionEngine:
|
||||
|
||||
return value
|
||||
|
||||
def _node_cacheable(self, cls: type) -> bool:
    """Report whether results produced by node class *cls* may be cached.

    A class opts out by exposing a truthy ``manual_trigger`` attribute;
    classes without that attribute (or with a falsy one) are cacheable.
    """
    opted_out = getattr(cls, "manual_trigger", False)
    return False if opted_out else True
|
||||
|
||||
def _get_cached_entry(self, node_id: str, class_name: str, input_signature: str) -> dict[str, Any] | None:
    """Return the cache record for *node_id* if it can be reused, else ``None``.

    A record is reusable only when the node class is cacheable and the record
    was stored for the same ``class_name`` and the same ``input_signature``.
    """
    node_cls = NODE_CLASS_MAPPINGS[class_name]
    if not self._node_cacheable(node_cls):
        return None

    with self._cache_lock:
        record = self._node_cache.get(node_id)
        if not record:
            return None
        same_class = record.get("class_name") == class_name
        same_inputs = same_class and record.get("input_signature") == input_signature
        return record if same_inputs else None
|
||||
|
||||
def _store_cache_entry(
    self,
    *,
    node_id: str,
    class_name: str,
    input_signature: str,
    output_signatures: tuple[str, ...],
    outputs: tuple,
) -> None:
    """Record the latest result of *node_id*, replacing any earlier record.

    The stored record holds the node's class name, the signature of the
    inputs that produced the result, the per-output signatures and the
    output values themselves.
    """
    record = dict(
        class_name=class_name,
        input_signature=input_signature,
        output_signatures=output_signatures,
        outputs=outputs,
    )
    with self._cache_lock:
        self._node_cache[node_id] = record
|
||||
|
||||
def _build_input_signature(
    self,
    class_name: str,
    raw_inputs: dict[str, Any],
    node_output_signatures: dict[str, tuple[str, ...]],
) -> str:
    """Fingerprint a node's class name together with all of its inputs.

    Literal inputs contribute the fingerprint of their value. Linked inputs
    contribute the source node id, the slot index and the signature already
    recorded for that upstream output.

    Raises:
        KeyError: a linked source node has no recorded output signatures.
        IndexError: the requested slot is beyond the recorded signatures.
    """
    described: dict[str, Any] = {}
    for name in sorted(raw_inputs):
        raw = raw_inputs[name]

        if not _is_link(raw):
            described[name] = {
                "kind": "value",
                "signature": self._fingerprint_value(raw),
            }
            continue

        src_id = raw[0]
        slot = int(raw[1])
        upstream = node_output_signatures.get(src_id)
        if upstream is None:
            raise KeyError(f"Node '{src_id}' has no output signature yet — dependency ordering bug?")
        available = len(upstream)
        if slot >= available:
            raise IndexError(
                f"Node '{src_id}' only has {available} output signatures, "
                f"but slot {slot} was requested."
            )
        described[name] = {
            "kind": "link",
            "source": src_id,
            "slot": slot,
            "signature": upstream[slot],
        }

    summary = {"class_type": class_name, "inputs": described}
    return self._fingerprint_value(summary)
|
||||
|
||||
def _fingerprint_value(self, value: Any) -> str:
    """Return a 32-character hex BLAKE2b digest of *value*'s fingerprint bytes."""
    canonical = self._fingerprint_bytes(value)
    hasher = hashlib.blake2b(canonical, digest_size=16)
    return hasher.hexdigest()
|
||||
|
||||
def _fingerprint_bytes(self, value: Any) -> bytes:
    """Serialise *value* into deterministic bytes suitable for hashing.

    Primitives become a short type-tagged byte string. Composite values
    (project data types, arrays, lists, tuples, dicts) become a type tag
    followed by their components, each prefixed with its byte length so that
    the content of one component can never be mistaken for a delimiter or
    for part of a neighbouring component. Anything else falls back to its
    type-qualified ``repr``.

    The byte layout is an internal detail: results are only meaningful when
    compared with other results of this method.

    Args:
        value: Any node input or output value.

    Returns:
        The canonical byte representation of *value*.
    """
    # Primitives are handled before the imports below so that the (frequent)
    # recursive calls on scalars never pay for the import statements.
    if value is None:
        return b"null"
    if isinstance(value, bool):
        return b"bool:1" if value else b"bool:0"
    if isinstance(value, int):
        return f"int:{value}".encode()
    if isinstance(value, float):
        # Tagged like the other primitives so a float cannot collide with
        # the repr() fallback of an unrelated object.
        return b"float:" + json.dumps(float(value)).encode()
    if isinstance(value, str):
        return ("str:" + value).encode("utf-8", errors="surrogatepass")

    import numpy as np
    from backend.data_types import DataField, ImageData, LineData, MeasureTable, MeshModel, RecordTable

    fingerprint = self._fingerprint_bytes

    def frame(tag: bytes, parts: list[bytes]) -> bytes:
        """Join *parts* behind *tag*, length-prefixing each one."""
        chunks = [tag]
        for part in parts:
            chunks.append(b"%d:" % len(part))
            chunks.append(part)
        return b"".join(chunks)

    if isinstance(value, DataField):
        return frame(b"DataField", [
            fingerprint(value.data),
            f"xreal:{value.xreal}".encode(),
            f"yreal:{value.yreal}".encode(),
            f"xoff:{value.xoff}".encode(),
            f"yoff:{value.yoff}".encode(),
            ("ux:" + value.si_unit_xy).encode(),
            ("uz:" + value.si_unit_z).encode(),
            ("domain:" + value.domain).encode(),
            fingerprint(value.colormap),
            f"display_offset:{value.display_offset}".encode(),
            f"display_scale:{value.display_scale}".encode(),
            fingerprint(value.overlays),
        ])

    if isinstance(value, LineData):
        return frame(b"LineData", [
            fingerprint(value.data),
            fingerprint(value.x_axis.tolist() if value.x_axis is not None else None),
            ("x_unit:" + value.x_unit).encode(),
            ("y_unit:" + value.y_unit).encode(),
        ])

    if isinstance(value, MeshModel):
        return frame(b"MeshModel", [
            fingerprint(value.vertices),
            fingerprint(value.faces),
            fingerprint(value.colors),
        ])

    # Checked before the plain ndarray branch, as in the original ordering.
    if isinstance(value, ImageData):
        return frame(b"ImageData", [
            fingerprint(np.asarray(value)),
            fingerprint(getattr(value, "metadata", {})),
        ])

    if isinstance(value, np.ndarray):
        header = json.dumps(
            {"dtype": str(value.dtype), "shape": list(value.shape)},
            sort_keys=True,
            separators=(",", ":"),
        ).encode()
        if value.dtype.hasobject:
            # Raw bytes of an object array are element pointers, not content;
            # fingerprint the elements themselves instead.
            payload = fingerprint(value.tolist())
        else:
            # tobytes() emits C-order data for any layout and also covers
            # dtypes that memoryview() rejects (e.g. datetime64).
            payload = value.tobytes()
        return frame(b"ndarray", [header, payload])

    if isinstance(value, (MeasureTable, RecordTable)):
        return frame(b"table", [fingerprint(item) for item in value])

    if isinstance(value, list):
        return frame(b"list", [fingerprint(item) for item in value])

    if isinstance(value, tuple):
        return frame(b"tuple", [fingerprint(item) for item in value])

    if isinstance(value, dict):
        # Order entries by the key's own fingerprint: deterministic, works for
        # keys of mixed types, and keeps 1 and "1" distinct.
        entries = sorted((fingerprint(key), fingerprint(item)) for key, item in value.items())
        return frame(b"dict", [part for entry in entries for part in entry])

    kind = type(value)
    prefix = f"{kind.__module__}.{kind.__qualname__}:"
    return (prefix + repr(value)).encode("utf-8", errors="surrogatepass")
|
||||
|
||||
def _clone_cached_outputs(self, outputs: tuple) -> tuple:
    """Return a new tuple holding a clone of every item in *outputs*."""
    clones = [self._clone_cached_value(item) for item in outputs]
    return tuple(clones)
|
||||
|
||||
def _clone_cached_value(self, value: Any) -> Any:
    """Return a copy of *value* that does not share mutable data with it.

    Project data types and numpy arrays are copied through their own copy
    routines, tuples are rebuilt item by item, lists and dicts are
    deep-copied, and every other value is returned unchanged.
    """
    import numpy as np
    from backend.data_types import DataField, ImageData, LineData, MeshModel

    if isinstance(value, DataField):
        return value.copy()

    if isinstance(value, LineData):
        data_copy = value.data.copy()
        axis_copy = None if value.x_axis is None else value.x_axis.copy()
        return LineData(
            data=data_copy,
            x_axis=axis_copy,
            x_unit=value.x_unit,
            y_unit=value.y_unit,
        )

    if isinstance(value, MeshModel):
        vertices_copy = value.vertices.copy()
        faces_copy = value.faces.copy()
        colors_copy = None if value.colors is None else value.colors.copy()
        return MeshModel(vertices=vertices_copy, faces=faces_copy, colors=colors_copy)

    # ImageData is tested before the generic ndarray branch.
    if isinstance(value, ImageData):
        pixels = np.asarray(value).copy()
        return value.copy_with_metadata(data=pixels)

    if isinstance(value, np.ndarray):
        return value.copy()

    if isinstance(value, tuple):
        return tuple(self._clone_cached_value(member) for member in value)

    if isinstance(value, (list, dict)):
        return deepcopy(value)

    return value
|
||||
|
||||
def _inject_display_callbacks(
|
||||
self,
|
||||
on_preview: Callable | None,
|
||||
|
||||
@@ -1759,6 +1759,146 @@ def test_execution_engine_numeric_socket_coercion():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_execution_engine_caches_unchanged_nodes():
    """Execute the same two-node prompt twice; the second pass must not re-run any node."""
    print("=== Test: ExecutionEngine caches unchanged nodes ===")
    from backend.execution import ExecutionEngine
    from backend.node_registry import register_node

    @register_node(display_name="Test Cache Source")
    class TestCacheSource:
        calls = 0

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        RETURN_TYPES = ("FLOAT",)
        RETURN_NAMES = ("value",)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            # Count real executions so cache hits can be detected.
            TestCacheSource.calls += 1
            return (float(value),)

    @register_node(display_name="Test Cache Downstream")
    class TestCacheDownstream:
        calls = 0

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        RETURN_TYPES = ("FLOAT",)
        RETURN_NAMES = ("value",)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestCacheDownstream.calls += 1
            return (float(value) * 2.0,)

    TestCacheSource.calls = 0
    TestCacheDownstream.calls = 0

    def recorder(bucket):
        """Build an on_node_done callback that appends (node_id, elapsed_ms) to *bucket*."""
        def on_node_done(node_id, elapsed_ms):
            bucket.append((node_id, elapsed_ms))
        return on_node_done

    engine = ExecutionEngine()
    source_node = {"class_type": "TestCacheSource", "inputs": {"value": 2.5}}
    downstream_node = {"class_type": "TestCacheDownstream", "inputs": {"value": ["1", 0]}}
    prompt = {"1": source_node, "2": downstream_node}

    first_timings = []
    first_outputs = engine.execute(prompt, on_node_done=recorder(first_timings))
    second_timings = []
    second_outputs = engine.execute(prompt, on_node_done=recorder(second_timings))

    assert first_outputs["2"] == (5.0,)
    assert second_outputs["2"] == (5.0,)
    # Each node class ran exactly once across both executions.
    assert TestCacheSource.calls == 1
    assert TestCacheDownstream.calls == 1
    reported_ids = {entry[0] for entry in second_timings}
    assert reported_ids == {"1", "2"}
    assert all(entry[1] == 0.0 for entry in second_timings)

    print(" PASS\n")
|
||||
|
||||
|
||||
def test_execution_engine_only_propagates_real_output_changes():
    """A changed upstream input re-runs the downstream node only if the upstream output changes."""
    print("=== Test: ExecutionEngine propagates only real upstream output changes ===")
    from backend.execution import ExecutionEngine
    from backend.node_registry import register_node

    @register_node(display_name="Test Quantized Source")
    class TestQuantizedSource:
        calls = 0

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("FLOAT",)}}

        RETURN_TYPES = ("INT",)
        RETURN_NAMES = ("value",)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestQuantizedSource.calls += 1
            # Rounding maps nearby float inputs onto the same integer output.
            return (int(round(float(value))),)

    @register_node(display_name="Test Quantized Downstream")
    class TestQuantizedDownstream:
        calls = 0

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"value": ("INT",)}}

        RETURN_TYPES = ("FLOAT",)
        RETURN_NAMES = ("value",)
        FUNCTION = "process"
        CATEGORY = "tests"

        def process(self, value):
            TestQuantizedDownstream.calls += 1
            return (float(value) + 0.5,)

    TestQuantizedSource.calls = 0
    TestQuantizedDownstream.calls = 0

    engine = ExecutionEngine()
    prompt = {
        "1": {"class_type": "TestQuantizedSource", "inputs": {"value": 1.2}},
        "2": {"class_type": "TestQuantizedDownstream", "inputs": {"value": ["1", 0]}},
    }

    # (source input, expected downstream output) for three consecutive runs.
    scenarios = [
        (1.2, (1.5,)),
        (1.3, (1.5,)),
        (2.2, (2.5,)),
    ]
    for source_value, expected in scenarios:
        prompt["1"]["inputs"]["value"] = source_value
        outputs = engine.execute(prompt)
        assert outputs["2"] == expected

    # The source ran for every distinct input; the downstream node ran only
    # for the two distinct integers it received.
    assert TestQuantizedSource.calls == 3
    assert TestQuantizedDownstream.calls == 2

    print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — Cursors
|
||||
# =========================================================================
|
||||
|
||||
Reference in New Issue
Block a user