caching nodes to improve performance

This commit is contained in:
2026-03-26 21:43:08 -07:00
parent 0429f39a8d
commit 8be53e9e6d
2 changed files with 360 additions and 7 deletions

View File

@@ -21,9 +21,13 @@ The engine:
"""
from __future__ import annotations
import hashlib
import json
import uuid
from copy import deepcopy
from collections import defaultdict, deque
from math import isfinite
from threading import RLock
from time import perf_counter
from typing import Any, Callable
@@ -43,6 +47,10 @@ def _is_link(value: Any) -> bool:
class ExecutionEngine:
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
def __init__(self) -> None:
self._node_cache: dict[str, dict[str, Any]] = {}
self._cache_lock = RLock()
def execute(
self,
prompt: dict[str, dict],
@@ -75,6 +83,7 @@ class ExecutionEngine:
"""
order = self._topological_sort(prompt)
node_outputs: dict[str, tuple] = {}
node_output_signatures: dict[str, tuple[str, ...]] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning)
@@ -90,24 +99,41 @@ class ExecutionEngine:
raw_inputs = node_def.get("inputs", {})
input_types = cls.INPUT_TYPES()
inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types)
input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
if on_node_start:
on_node_start(node_id)
cache_entry = self._get_cached_entry(node_id, class_name, input_signature)
if cache_entry is not None:
result = self._clone_cached_outputs(cache_entry["outputs"])
elapsed_ms = 0.0
else:
if on_node_start:
on_node_start(node_id)
instance = cls()
func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
instance = cls()
func = getattr(instance, cls.FUNCTION)
start_time = perf_counter()
result = func(**inputs)
elapsed_ms = (perf_counter() - start_time) * 1000.0
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
node_outputs[node_id] = result
output_signatures = tuple(self._fingerprint_value(value) for value in result)
node_output_signatures[node_id] = output_signatures
if cache_entry is None and self._node_cacheable(cls):
self._store_cache_entry(
node_id=node_id,
class_name=class_name,
input_signature=input_signature,
output_signatures=output_signatures,
outputs=self._clone_cached_outputs(result),
)
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or table-like output so every node shows its result.
@@ -208,6 +234,193 @@ class ExecutionEngine:
return value
def _node_cacheable(self, cls: type) -> bool:
return not bool(getattr(cls, "manual_trigger", False))
def _get_cached_entry(self, node_id: str, class_name: str, input_signature: str) -> dict[str, Any] | None:
    """Look up a reusable cache record for ``node_id``.

    Args:
        node_id: Identifier of the node being executed.
        class_name: Key into ``NODE_CLASS_MAPPINGS`` for the node's class.
        input_signature: Signature of the node's current inputs.

    Returns:
        The stored record itself (not a copy) when the node class is
        cacheable and the record's ``class_name`` and ``input_signature``
        both equal the arguments; otherwise ``None``.

    Raises:
        KeyError: If ``class_name`` is not in ``NODE_CLASS_MAPPINGS``.
    """
    node_cls = NODE_CLASS_MAPPINGS[class_name]
    if not self._node_cacheable(node_cls):
        return None

    with self._cache_lock:
        cached = self._node_cache.get(node_id)
        if not cached:
            return None
        same_class = cached.get("class_name") == class_name
        same_inputs = cached.get("input_signature") == input_signature
        return cached if (same_class and same_inputs) else None
def _store_cache_entry(
self,
*,
node_id: str,
class_name: str,
input_signature: str,
output_signatures: tuple[str, ...],
outputs: tuple,
) -> None:
with self._cache_lock:
self._node_cache[node_id] = {
"class_name": class_name,
"input_signature": input_signature,
"output_signatures": output_signatures,
"outputs": outputs,
}
def _build_input_signature(
    self,
    class_name: str,
    raw_inputs: dict[str, Any],
    node_output_signatures: dict[str, tuple[str, ...]],
) -> str:
    """Derive a signature string for a node's class name and inputs.

    Inputs are visited in sorted key order. A link input contributes its
    source node id, slot index and the fingerprint already recorded for
    that upstream output slot; any other input contributes the
    fingerprint of its literal value.

    Args:
        class_name: Node class name, mixed into the signature.
        raw_inputs: The node's unresolved inputs (literals and links).
        node_output_signatures: Output fingerprints of nodes executed so
            far, keyed by node id.

    Returns:
        The fingerprint of the class name together with the normalised
        input description.

    Raises:
        KeyError: If a linked source node has no recorded signatures.
        IndexError: If a link's slot is not below the number of
            signatures recorded for its source node.
    """

    def describe(raw: Any) -> dict[str, Any]:
        # Literal (non-link) inputs are identified purely by content.
        if not _is_link(raw):
            return {
                "kind": "value",
                "signature": self._fingerprint_value(raw),
            }

        src_id = raw[0]
        slot = int(raw[1])
        upstream = node_output_signatures.get(src_id)
        if upstream is None:
            raise KeyError(f"Node '{src_id}' has no output signature yet — dependency ordering bug?")
        if slot >= len(upstream):
            raise IndexError(
                f"Node '{src_id}' only has {len(upstream)} output signatures, "
                f"but slot {slot} was requested."
            )
        return {
            "kind": "link",
            "source": src_id,
            "slot": slot,
            "signature": upstream[slot],
        }

    normalized_inputs = {name: describe(raw_inputs[name]) for name in sorted(raw_inputs)}
    return self._fingerprint_value({
        "class_type": class_name,
        "inputs": normalized_inputs,
    })
def _fingerprint_value(self, value: Any) -> str:
return hashlib.blake2b(self._fingerprint_bytes(value), digest_size=16).hexdigest()
def _fingerprint_bytes(self, value: Any) -> bytes:
    """Serialise ``value`` into deterministic bytes suitable for hashing.

    Scalars keep a short type-tagged text form. Every composite value
    (project data types, arrays, lists, tuples, dicts) is written as a
    type tag followed by its components, each prefixed with its byte
    length. The length prefix makes the encoding unambiguous: with bare
    ``,``/``|`` separators, ``["a,str:b"]`` and ``["a", "b"]`` serialise
    to the same bytes and could produce a false cache hit.

    Args:
        value: Any node input or output value.

    Returns:
        A byte string that is equal for values with equal content, as far
        as this method knows how to inspect them.

    Notes:
        * ``MeasureTable`` and ``RecordTable`` are iterated like lists.
        * Dict keys are compared via ``str(key)``; entries are ordered by
          their encoded bytes, so dicts whose keys cannot be ordered
          against each other (e.g. ``int`` mixed with ``str``) no longer
          raise ``TypeError``.
        * Unknown types fall back to ``repr(value)``. A type whose repr
          is not a pure function of its content will not fingerprint
          stably.
    """
    import numpy as np
    from backend.data_types import DataField, ImageData, LineData, MeasureTable, MeshModel, RecordTable

    encode = self._fingerprint_bytes

    def framed(tag: bytes, parts: list[bytes]) -> bytes:
        # Netstring-style framing: "<len>:<bytes>" per component, so no
        # component's content can be mistaken for a separator.
        chunks = [tag]
        for part in parts:
            chunks.append(b"%d:" % len(part))
            chunks.append(part)
        return b"".join(chunks)

    if value is None:
        return b"null"
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return b"bool:1" if value else b"bool:0"
    if isinstance(value, int):
        return f"int:{value}".encode()
    if isinstance(value, float):
        return json.dumps(float(value), sort_keys=True, separators=(",", ":")).encode()
    if isinstance(value, str):
        return ("str:" + value).encode("utf-8", errors="surrogatepass")

    if isinstance(value, DataField):
        return framed(b"DataField", [
            encode(value.data),
            f"xreal:{value.xreal}".encode(),
            f"yreal:{value.yreal}".encode(),
            f"xoff:{value.xoff}".encode(),
            f"yoff:{value.yoff}".encode(),
            ("ux:" + value.si_unit_xy).encode(),
            ("uz:" + value.si_unit_z).encode(),
            ("domain:" + value.domain).encode(),
            encode(value.colormap),
            f"display_offset:{value.display_offset}".encode(),
            f"display_scale:{value.display_scale}".encode(),
            encode(value.overlays),
        ])
    if isinstance(value, LineData):
        return framed(b"LineData", [
            encode(value.data),
            encode(value.x_axis.tolist() if value.x_axis is not None else None),
            ("x_unit:" + value.x_unit).encode(),
            ("y_unit:" + value.y_unit).encode(),
        ])
    if isinstance(value, MeshModel):
        return framed(b"MeshModel", [
            encode(value.vertices),
            encode(value.faces),
            encode(value.colors),
        ])
    # ImageData is checked before the generic ndarray branch so that its
    # metadata takes part in the fingerprint.
    if isinstance(value, ImageData):
        return framed(b"ImageData", [
            encode(np.asarray(value)),
            encode(getattr(value, "metadata", {})),
        ])
    if isinstance(value, np.ndarray):
        array = np.ascontiguousarray(value)
        header = json.dumps(
            {"dtype": str(array.dtype), "shape": list(array.shape)},
            sort_keys=True,
            separators=(",", ":"),
        ).encode()
        # ndarray.tobytes() copies the raw buffer directly and, unlike
        # memoryview(array), does not depend on the dtype being
        # exportable through the buffer protocol.
        return framed(b"ndarray", [header, array.tobytes()])

    if isinstance(value, (MeasureTable, RecordTable, list)):
        return framed(b"list", [encode(item) for item in value])
    if isinstance(value, tuple):
        return framed(b"tuple", [encode(item) for item in value])
    if isinstance(value, dict):
        # Sort on the encoded (key, value) bytes: deterministic for any
        # key types, with the value as tie-breaker when two keys share
        # the same str() form.
        pairs = sorted((encode(str(key)), encode(item)) for key, item in value.items())
        return framed(b"dict", [chunk for pair in pairs for chunk in pair])

    return repr(value).encode("utf-8", errors="surrogatepass")
def _clone_cached_outputs(self, outputs: tuple) -> tuple:
return tuple(self._clone_cached_value(value) for value in outputs)
def _clone_cached_value(self, value: Any) -> Any:
    """Return a copy of ``value`` that does not share mutable buffers.

    Handling by type, checked in this order:
      * ``DataField`` -> its own ``copy()``.
      * ``LineData`` / ``MeshModel`` -> rebuilt from copies of their arrays.
      * ``ImageData`` -> ``copy_with_metadata`` over a copied array.
      * ``numpy.ndarray`` -> ``copy()``.
      * ``tuple`` -> new tuple with each element cloned by this method.
      * ``list`` / ``dict`` -> ``deepcopy``.
      * anything else -> returned as-is (same object).
    """
    import numpy as np
    from backend.data_types import DataField, ImageData, LineData, MeshModel

    if isinstance(value, DataField):
        return value.copy()

    if isinstance(value, LineData):
        samples = value.data.copy()
        axis = None if value.x_axis is None else value.x_axis.copy()
        return LineData(
            data=samples,
            x_axis=axis,
            x_unit=value.x_unit,
            y_unit=value.y_unit,
        )

    if isinstance(value, MeshModel):
        vertices = value.vertices.copy()
        faces = value.faces.copy()
        colors = None if value.colors is None else value.colors.copy()
        return MeshModel(vertices=vertices, faces=faces, colors=colors)

    # ImageData must be tested before the plain ndarray branch so that
    # its metadata-preserving copy is used.
    if isinstance(value, ImageData):
        pixels = np.asarray(value).copy()
        return value.copy_with_metadata(data=pixels)

    if isinstance(value, np.ndarray):
        return value.copy()

    if isinstance(value, tuple):
        return tuple(map(self._clone_cached_value, value))

    if isinstance(value, (list, dict)):
        return deepcopy(value)

    return value
def _inject_display_callbacks(
self,
on_preview: Callable | None,