caching nodes to improve performance

This commit is contained in:
2026-03-26 21:43:08 -07:00
parent 0429f39a8d
commit 8be53e9e6d
2 changed files with 360 additions and 7 deletions

View File

@@ -21,9 +21,13 @@ The engine:
""" """
from __future__ import annotations from __future__ import annotations
import hashlib
import json
import uuid import uuid
from copy import deepcopy
from collections import defaultdict, deque from collections import defaultdict, deque
from math import isfinite from math import isfinite
from threading import RLock
from time import perf_counter from time import perf_counter
from typing import Any, Callable from typing import Any, Callable
@@ -43,6 +47,10 @@ def _is_link(value: Any) -> bool:
class ExecutionEngine: class ExecutionEngine:
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code.""" """Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
def __init__(self) -> None:
self._node_cache: dict[str, dict[str, Any]] = {}
self._cache_lock = RLock()
def execute( def execute(
self, self,
prompt: dict[str, dict], prompt: dict[str, dict],
@@ -75,6 +83,7 @@ class ExecutionEngine:
""" """
order = self._topological_sort(prompt) order = self._topological_sort(prompt)
node_outputs: dict[str, tuple] = {} node_outputs: dict[str, tuple] = {}
node_output_signatures: dict[str, tuple[str, ...]] = {}
# Inject display callbacks before execution # Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning) self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
@@ -90,10 +99,16 @@ class ExecutionEngine:
raw_inputs = node_def.get("inputs", {}) raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES() input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
# Let display nodes know their node_id so they can tag WS messages # Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id) self._set_node_id_on_display(cls, node_id)
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start: if on_node_start:
on_node_start(node_id) on_node_start(node_id)
@@ -108,6 +123,17 @@ class ExecutionEngine:
result = (result,) result = (result,)
node_outputs[node_id] = result node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
# Auto-preview: broadcast a thumbnail for any DATA_FIELD, # Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result. # IMAGE, or table-like output so every node shows its result.
@@ -208,6 +234,193 @@ class ExecutionEngine:
return value return value
def _node_cacheable(self, cls: type) -> bool:
return not bool(getattr(cls, "manual_trigger", False))
def _get_cached_entry(self, node_id: str, class_name: str, input_signature: str) -> dict[str, Any] | None:
if not self._node_cacheable(NODE_CLASS_MAPPINGS[class_name]):
return None
with self._cache_lock:
entry = self._node_cache.get(node_id)
if not entry:
return None
if entry.get("class_name") != class_name:
return None
if entry.get("input_signature") != input_signature:
return None
return entry
def _store_cache_entry(
self,
*,
node_id: str,
class_name: str,
input_signature: str,
output_signatures: tuple[str, ...],
outputs: tuple,
) -> None:
with self._cache_lock:
self._node_cache[node_id] = {
"class_name": class_name,
"input_signature": input_signature,
"output_signatures": output_signatures,
"outputs": outputs,
}
def _build_input_signature(
self,
class_name: str,
raw_inputs: dict[str, Any],
node_output_signatures: dict[str, tuple[str, ...]],
) -> str:
normalized_inputs: dict[str, Any] = {}
for key in sorted(raw_inputs):
value = raw_inputs[key]
if _is_link(value):
src_id, slot = value[0], int(value[1])
source_signatures = node_output_signatures.get(src_id)
if source_signatures is None:
raise KeyError(f"Node '{src_id}' has no output signature yet — dependency ordering bug?")
if slot >= len(source_signatures):
raise IndexError(
f"Node '{src_id}' only has {len(source_signatures)} output signatures, "
f"but slot {slot} was requested."
)
normalized_inputs[key] = {
"kind": "link",
"source": src_id,
"slot": slot,
"signature": source_signatures[slot],
}
else:
normalized_inputs[key] = {
"kind": "value",
"signature": self._fingerprint_value(value),
}
return self._fingerprint_value({
"class_type": class_name,
"inputs": normalized_inputs,
})
def _fingerprint_value(self, value: Any) -> str:
return hashlib.blake2b(self._fingerprint_bytes(value), digest_size=16).hexdigest()
def _fingerprint_bytes(self, value: Any) -> bytes:
import numpy as np
from backend.data_types import DataField, ImageData, LineData, MeasureTable, MeshModel, RecordTable
if value is None:
return b"null"
if isinstance(value, bool):
return b"bool:1" if value else b"bool:0"
if isinstance(value, int) and not isinstance(value, bool):
return f"int:{value}".encode()
if isinstance(value, float):
return json.dumps(float(value), sort_keys=True, separators=(",", ":")).encode()
if isinstance(value, str):
return ("str:" + value).encode("utf-8", errors="surrogatepass")
if isinstance(value, DataField):
return b"|".join([
b"DataField",
self._fingerprint_bytes(value.data),
f"xreal:{value.xreal}".encode(),
f"yreal:{value.yreal}".encode(),
f"xoff:{value.xoff}".encode(),
f"yoff:{value.yoff}".encode(),
("ux:" + value.si_unit_xy).encode(),
("uz:" + value.si_unit_z).encode(),
("domain:" + value.domain).encode(),
self._fingerprint_bytes(value.colormap),
f"display_offset:{value.display_offset}".encode(),
f"display_scale:{value.display_scale}".encode(),
self._fingerprint_bytes(value.overlays),
])
if isinstance(value, LineData):
return b"|".join([
b"LineData",
self._fingerprint_bytes(value.data),
self._fingerprint_bytes(value.x_axis.tolist() if value.x_axis is not None else None),
("x_unit:" + value.x_unit).encode(),
("y_unit:" + value.y_unit).encode(),
])
if isinstance(value, MeshModel):
return b"|".join([
b"MeshModel",
self._fingerprint_bytes(value.vertices),
self._fingerprint_bytes(value.faces),
self._fingerprint_bytes(value.colors if value.colors is not None else None),
])
if isinstance(value, ImageData):
return b"|".join([
b"ImageData",
self._fingerprint_bytes(np.asarray(value)),
self._fingerprint_bytes(getattr(value, "metadata", {})),
])
if isinstance(value, np.ndarray):
array = np.ascontiguousarray(value)
header = json.dumps(
{"dtype": str(array.dtype), "shape": list(array.shape)},
sort_keys=True,
separators=(",", ":"),
).encode()
return b"|".join([b"ndarray", header, memoryview(array).tobytes()])
if isinstance(value, (MeasureTable, RecordTable, list)):
return b"[" + b",".join(self._fingerprint_bytes(item) for item in value) + b"]"
if isinstance(value, tuple):
return b"(" + b",".join(self._fingerprint_bytes(item) for item in value) + b")"
if isinstance(value, dict):
items = []
for key in sorted(value):
items.append(
self._fingerprint_bytes(str(key))
+ b":"
+ self._fingerprint_bytes(value[key])
)
return b"{" + b",".join(items) + b"}"
return repr(value).encode("utf-8", errors="surrogatepass")
def _clone_cached_outputs(self, outputs: tuple) -> tuple:
return tuple(self._clone_cached_value(value) for value in outputs)
def _clone_cached_value(self, value: Any) -> Any:
import numpy as np
from backend.data_types import DataField, ImageData, LineData, MeshModel
if isinstance(value, DataField):
return value.copy()
if isinstance(value, LineData):
return LineData(
data=value.data.copy(),
x_axis=value.x_axis.copy() if value.x_axis is not None else None,
x_unit=value.x_unit,
y_unit=value.y_unit,
)
if isinstance(value, MeshModel):
return MeshModel(
vertices=value.vertices.copy(),
faces=value.faces.copy(),
colors=value.colors.copy() if value.colors is not None else None,
)
if isinstance(value, ImageData):
return value.copy_with_metadata(data=np.asarray(value).copy())
if isinstance(value, np.ndarray):
return value.copy()
if isinstance(value, tuple):
return tuple(self._clone_cached_value(item) for item in value)
if isinstance(value, (list, dict)):
return deepcopy(value)
return value
def _inject_display_callbacks( def _inject_display_callbacks(
self, self,
on_preview: Callable | None, on_preview: Callable | None,

View File

@@ -1759,6 +1759,146 @@ def test_execution_engine_numeric_socket_coercion():
print(" PASS\n") print(" PASS\n")
def test_execution_engine_caches_unchanged_nodes():
print("=== Test: ExecutionEngine caches unchanged nodes ===")
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
@register_node(display_name="Test Cache Source")
class TestCacheSource:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestCacheSource.calls += 1
return (float(value),)
@register_node(display_name="Test Cache Downstream")
class TestCacheDownstream:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestCacheDownstream.calls += 1
return (float(value) * 2.0,)
TestCacheSource.calls = 0
TestCacheDownstream.calls = 0
engine = ExecutionEngine()
prompt = {
"1": {
"class_type": "TestCacheSource",
"inputs": {"value": 2.5},
},
"2": {
"class_type": "TestCacheDownstream",
"inputs": {"value": ["1", 0]},
},
}
first_timings = []
first_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms)))
second_timings = []
second_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms)))
assert first_outputs["2"] == (5.0,)
assert second_outputs["2"] == (5.0,)
assert TestCacheSource.calls == 1
assert TestCacheDownstream.calls == 1
assert {node_id for node_id, _ in second_timings} == {"1", "2"}
assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings)
print(" PASS\n")
def test_execution_engine_only_propagates_real_output_changes():
print("=== Test: ExecutionEngine propagates only real upstream output changes ===")
from backend.execution import ExecutionEngine
from backend.node_registry import register_node
@register_node(display_name="Test Quantized Source")
class TestQuantizedSource:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("FLOAT",)}}
RETURN_TYPES = ("INT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestQuantizedSource.calls += 1
return (int(round(float(value))),)
@register_node(display_name="Test Quantized Downstream")
class TestQuantizedDownstream:
calls = 0
@classmethod
def INPUT_TYPES(cls):
return {"required": {"value": ("INT",)}}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "tests"
def process(self, value):
TestQuantizedDownstream.calls += 1
return (float(value) + 0.5,)
TestQuantizedSource.calls = 0
TestQuantizedDownstream.calls = 0
engine = ExecutionEngine()
prompt = {
"1": {
"class_type": "TestQuantizedSource",
"inputs": {"value": 1.2},
},
"2": {
"class_type": "TestQuantizedDownstream",
"inputs": {"value": ["1", 0]},
},
}
outputs_first = engine.execute(prompt)
assert outputs_first["2"] == (1.5,)
prompt["1"]["inputs"]["value"] = 1.3
outputs_second = engine.execute(prompt)
assert outputs_second["2"] == (1.5,)
prompt["1"]["inputs"]["value"] = 2.2
outputs_third = engine.execute(prompt)
assert outputs_third["2"] == (2.5,)
assert TestQuantizedSource.calls == 3
assert TestQuantizedDownstream.calls == 2
print(" PASS\n")
# ========================================================================= # =========================================================================
# Analysis — Cursors # Analysis — Cursors
# ========================================================================= # =========================================================================