caching nodes to improve performance

This commit is contained in:
2026-03-26 21:43:08 -07:00
parent 0429f39a8d
commit 8be53e9e6d
2 changed files with 360 additions and 7 deletions

View File

@@ -21,9 +21,13 @@ The engine:
"""
from __future__ import annotations
import hashlib
import json
import uuid
from copy import deepcopy
from collections import defaultdict, deque
from math import isfinite
from threading import RLock
from time import perf_counter
from typing import Any, Callable
@@ -43,6 +47,10 @@ def _is_link(value: Any) -> bool:
class ExecutionEngine:
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
def __init__(self) -> None:
self._node_cache: dict[str, dict[str, Any]] = {}
self._cache_lock = RLock()
def execute(
self,
prompt: dict[str, dict],
@@ -75,6 +83,7 @@ class ExecutionEngine:
"""
order = self._topological_sort(prompt)
node_outputs: dict[str, tuple] = {}
node_output_signatures: dict[str, tuple[str, ...]] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
@@ -90,10 +99,16 @@ class ExecutionEngine:
raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
@@ -108,6 +123,17 @@ class ExecutionEngine:
result = (result,)
node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
@@ -208,6 +234,193 @@ class ExecutionEngine:
return value
def _node_cacheable(self, cls: type) -> bool:
return not bool(getattr(cls, "manual_trigger", False))
def _get_cached_entry(self, node_id: str, class_name: str, input_signature: str) -> dict[str, Any] | None:
if not self._node_cacheable(NODE_CLASS_MAPPINGS[class_name]):
return None
with self._cache_lock:
entry = self._node_cache.get(node_id)
if not entry:
return None
if entry.get("class_name") != class_name:
return None
if entry.get("input_signature") != input_signature:
return None
return entry
def _store_cache_entry(
self,
*,
node_id: str,
class_name: str,
input_signature: str,
output_signatures: tuple[str, ...],
outputs: tuple,
) -> None:
with self._cache_lock:
self._node_cache[node_id] = {
"class_name": class_name,
"input_signature": input_signature,
"output_signatures": output_signatures,
"outputs": outputs,
}
def _build_input_signature(
self,
class_name: str,
raw_inputs: dict[str, Any],
node_output_signatures: dict[str, tuple[str, ...]],
) -> str:
normalized_inputs: dict[str, Any] = {}
for key in sorted(raw_inputs):
value = raw_inputs[key]
if _is_link(value):
src_id, slot = value[0], int(value[1])
source_signatures = node_output_signatures.get(src_id)
if source_signatures is None:
raise KeyError(f"Node '{src_id}' has no output signature yet — dependency ordering bug?")
if slot >= len(source_signatures):
raise IndexError(
f"Node '{src_id}' only has {len(source_signatures)} output signatures, "
f"but slot {slot} was requested."
)
normalized_inputs[key] = {
"kind": "link",
"source": src_id,
"slot": slot,
"signature": source_signatures[slot],
}
else:
normalized_inputs[key] = {
"kind": "value",
"signature": self._fingerprint_value(value),
}
return self._fingerprint_value({
"class_type": class_name,
"inputs": normalized_inputs,
})
def _fingerprint_value(self, value: Any) -> str:
return hashlib.blake2b(self._fingerprint_bytes(value), digest_size=16).hexdigest()
def _fingerprint_bytes(self, value: Any) -> bytes:
import numpy as np
from backend.data_types import DataField, ImageData, LineData, MeasureTable, MeshModel, RecordTable
if value is None:
return b"null"
if isinstance(value, bool):
return b"bool:1" if value else b"bool:0"
if isinstance(value, int) and not isinstance(value, bool):
return f"int:{value}".encode()
if isinstance(value, float):
return json.dumps(float(value), sort_keys=True, separators=(",", ":")).encode()
if isinstance(value, str):
return ("str:" + value).encode("utf-8", errors="surrogatepass")
if isinstance(value, DataField):
return b"|".join([
b"DataField",
self._fingerprint_bytes(value.data),
f"xreal:{value.xreal}".encode(),
f"yreal:{value.yreal}".encode(),
f"xoff:{value.xoff}".encode(),
f"yoff:{value.yoff}".encode(),
("ux:" + value.si_unit_xy).encode(),
("uz:" + value.si_unit_z).encode(),
("domain:" + value.domain).encode(),
self._fingerprint_bytes(value.colormap),
f"display_offset:{value.display_offset}".encode(),
f"display_scale:{value.display_scale}".encode(),
self._fingerprint_bytes(value.overlays),
])
if isinstance(value, LineData):
return b"|".join([
b"LineData",
self._fingerprint_bytes(value.data),
self._fingerprint_bytes(value.x_axis.tolist() if value.x_axis is not None else None),
("x_unit:" + value.x_unit).encode(),
("y_unit:" + value.y_unit).encode(),
])
if isinstance(value, MeshModel):
return b"|".join([
b"MeshModel",
self._fingerprint_bytes(value.vertices),
self._fingerprint_bytes(value.faces),
self._fingerprint_bytes(value.colors if value.colors is not None else None),
])
if isinstance(value, ImageData):
return b"|".join([
b"ImageData",
self._fingerprint_bytes(np.asarray(value)),
self._fingerprint_bytes(getattr(value, "metadata", {})),
])
if isinstance(value, np.ndarray):
array = np.ascontiguousarray(value)
header = json.dumps(
{"dtype": str(array.dtype), "shape": list(array.shape)},
sort_keys=True,
separators=(",", ":"),
).encode()
return b"|".join([b"ndarray", header, memoryview(array).tobytes()])
if isinstance(value, (MeasureTable, RecordTable, list)):
return b"[" + b",".join(self._fingerprint_bytes(item) for item in value) + b"]"
if isinstance(value, tuple):
return b"(" + b",".join(self._fingerprint_bytes(item) for item in value) + b")"
if isinstance(value, dict):
items = []
for key in sorted(value):
items.append(
self._fingerprint_bytes(str(key))
+ b":"
+ self._fingerprint_bytes(value[key])
)
return b"{" + b",".join(items) + b"}"
return repr(value).encode("utf-8", errors="surrogatepass")
def _clone_cached_outputs(self, outputs: tuple) -> tuple:
return tuple(self._clone_cached_value(value) for value in outputs)
def _clone_cached_value(self, value: Any) -> Any:
import numpy as np
from backend.data_types import DataField, ImageData, LineData, MeshModel
if isinstance(value, DataField):
return value.copy()
if isinstance(value, LineData):
return LineData(
data=value.data.copy(),
x_axis=value.x_axis.copy() if value.x_axis is not None else None,
x_unit=value.x_unit,
y_unit=value.y_unit,
)
if isinstance(value, MeshModel):
return MeshModel(
vertices=value.vertices.copy(),
faces=value.faces.copy(),
colors=value.colors.copy() if value.colors is not None else None,
)
if isinstance(value, ImageData):
return value.copy_with_metadata(data=np.asarray(value).copy())
if isinstance(value, np.ndarray):
return value.copy()
if isinstance(value, tuple):
return tuple(self._clone_cached_value(item) for item in value)
if isinstance(value, (list, dict)):
return deepcopy(value)
return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,

View File

@@ -1759,6 +1759,146 @@ def test_execution_engine_numeric_socket_coercion():
print(" PASS\n")
def test_execution_engine_caches_unchanged_nodes():
print("=== Test: ExecutionEngine caches unchanged nodes ===")
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
@register_node(display_name="Test Cache Source")
class TestCacheSource:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestCacheSource.calls += 1
return (float(value),)
@register_node(display_name="Test Cache Downstream")
class TestCacheDownstream:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestCacheDownstream.calls += 1
return (float(value) * 2.0,)
TestCacheSource.calls = 0
TestCacheDownstream.calls = 0
engine = ExecutionEngine()
prompt = {
"1": {
"class_type": "TestCacheSource",
"inputs": {"value": 2.5},
},
"2": {
"class_type": "TestCacheDownstream",
"inputs": {"value": ["1", 0]},
},
}
first_timings = []
first_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms)))
second_timings = []
second_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms)))
assert first_outputs["2"] == (5.0,)
assert second_outputs["2"] == (5.0,)
assert TestCacheSource.calls == 1
assert TestCacheDownstream.calls == 1
assert {node_id for node_id, _ in second_timings} == {"1", "2"}
assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings)
print(" PASS\n")
def test_execution_engine_only_propagates_real_output_changes():
print("=== Test: ExecutionEngine propagates only real upstream output changes ===")
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
@register_node(display_name="Test Quantized Source")
class TestQuantizedSource:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("INT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestQuantizedSource.calls += 1
return (int(round(float(value))),)
@register_node(display_name="Test Quantized Downstream")
class TestQuantizedDownstream:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("INT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestQuantizedDownstream.calls += 1
return (float(value) + 0.5,)
TestQuantizedSource.calls = 0
TestQuantizedDownstream.calls = 0
engine = ExecutionEngine()
prompt = {
"1": {
"class_type": "TestQuantizedSource",
"inputs": {"value": 1.2},
},
"2": {
"class_type": "TestQuantizedDownstream",
"inputs": {"value": ["1", 0]},
},
}
outputs_first = engine.execute(prompt)
assert outputs_first["2"] == (1.5,)
prompt["1"]["inputs"]["value"] = 1.3
outputs_second = engine.execute(prompt)
assert outputs_second["2"] == (1.5,)
prompt["1"]["inputs"]["value"] = 2.2
outputs_third = engine.execute(prompt)
assert outputs_third["2"] == (2.5,)
assert TestQuantizedSource.calls == 3
assert TestQuantizedDownstream.calls == 2
print(" PASS\n")
# =========================================================================
# Analysis — Cursors
# =========================================================================