From 8be53e9e6d2c66925f9d046ec564ec3bfbeac39e Mon Sep 17 00:00:00 2001 From: matei jordache Date: Thu, 26 Mar 2026 21:43:08 -0700 Subject: [PATCH] caching nodes to improve performance --- backend/execution.py | 227 +++++++++++++++++++++++++++++++++++++++++-- tests/test_nodes.py | 140 ++++++++++++++++++++++++++ 2 files changed, 360 insertions(+), 7 deletions(-) diff --git a/backend/execution.py b/backend/execution.py index ccb4596..0d14ed0 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -21,9 +21,13 @@ The engine: """ from __future__ import annotations +import hashlib +import json import uuid +from copy import deepcopy from collections import defaultdict, deque from math import isfinite +from threading import RLock from time import perf_counter from typing import Any, Callable @@ -43,6 +47,10 @@ def _is_link(value: Any) -> bool: class ExecutionEngine: """Synchronous (blocking) graph executor. Run inside a thread pool from async code.""" + def __init__(self) -> None: + self._node_cache: dict[str, dict[str, Any]] = {} + self._cache_lock = RLock() + def execute( self, prompt: dict[str, dict], @@ -75,6 +83,7 @@ class ExecutionEngine: """ order = self._topological_sort(prompt) node_outputs: dict[str, tuple] = {} + node_output_signatures: dict[str, tuple[str, ...]] = {} # Inject display callbacks before execution self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_value, on_warning) @@ -90,24 +99,41 @@ class ExecutionEngine: raw_inputs = node_def.get("inputs", {}) input_types = cls.INPUT_TYPES() inputs = self._resolve_inputs(raw_inputs, node_outputs, input_types) + input_signature = self._build_input_signature(class_name, raw_inputs, node_output_signatures) # Let display nodes know their node_id so they can tag WS messages self._set_node_id_on_display(cls, node_id) - if on_node_start: - on_node_start(node_id) + cache_entry = self._get_cached_entry(node_id, class_name, input_signature) + if cache_entry is not None: + result = self._clone_cached_outputs(cache_entry["outputs"]) + elapsed_ms = 0.0 + else: + if on_node_start: + on_node_start(node_id) - instance = cls() - func = getattr(instance, cls.FUNCTION) - start_time = perf_counter() - result = func(**inputs) - elapsed_ms = (perf_counter() - start_time) * 1000.0 + instance = cls() + func = getattr(instance, cls.FUNCTION) + start_time = perf_counter() + result = func(**inputs) + elapsed_ms = (perf_counter() - start_time) * 1000.0 # Nodes must return a tuple; coerce single values just in case if not isinstance(result, tuple): result = (result,) node_outputs[node_id] = result + output_signatures = tuple(self._fingerprint_value(value) for value in result) + node_output_signatures[node_id] = output_signatures + + if cache_entry is None and self._node_cacheable(cls): + self._store_cache_entry( + node_id=node_id, + class_name=class_name, + input_signature=input_signature, + output_signatures=output_signatures, + outputs=self._clone_cached_outputs(result), + ) # Auto-preview: broadcast a thumbnail for any DATA_FIELD, # IMAGE, or table-like output so every node shows its result. @@ -208,6 +234,193 @@ class ExecutionEngine: return value + def _node_cacheable(self, cls: type) -> bool: + return not bool(getattr(cls, "manual_trigger", False)) + + def _get_cached_entry(self, node_id: str, class_name: str, input_signature: str) -> dict[str, Any] | None: + if not self._node_cacheable(NODE_CLASS_MAPPINGS[class_name]): + return None + with self._cache_lock: + entry = self._node_cache.get(node_id) + if not entry: + return None + if entry.get("class_name") != class_name: + return None + if entry.get("input_signature") != input_signature: + return None + return entry + + def _store_cache_entry( + self, + *, + node_id: str, + class_name: str, + input_signature: str, + output_signatures: tuple[str, ...], + outputs: tuple, + ) -> None: + with self._cache_lock: + self._node_cache[node_id] = { + "class_name": class_name, + "input_signature": input_signature, + "output_signatures": output_signatures, + "outputs": outputs, + } + + def _build_input_signature( + self, + class_name: str, + raw_inputs: dict[str, Any], + node_output_signatures: dict[str, tuple[str, ...]], + ) -> str: + normalized_inputs: dict[str, Any] = {} + for key in sorted(raw_inputs): + value = raw_inputs[key] + if _is_link(value): + src_id, slot = value[0], int(value[1]) + source_signatures = node_output_signatures.get(src_id) + if source_signatures is None: + raise KeyError(f"Node '{src_id}' has no output signature yet — dependency ordering bug?") + if slot >= len(source_signatures): + raise IndexError( + f"Node '{src_id}' only has {len(source_signatures)} output signatures, " + f"but slot {slot} was requested." + ) + normalized_inputs[key] = { + "kind": "link", + "source": src_id, + "slot": slot, + "signature": source_signatures[slot], + } + else: + normalized_inputs[key] = { + "kind": "value", + "signature": self._fingerprint_value(value), + } + + return self._fingerprint_value({ + "class_type": class_name, + "inputs": normalized_inputs, + }) + + def _fingerprint_value(self, value: Any) -> str: + return hashlib.blake2b(self._fingerprint_bytes(value), digest_size=16).hexdigest() + + def _fingerprint_bytes(self, value: Any) -> bytes: + import numpy as np + from backend.data_types import DataField, ImageData, LineData, MeasureTable, MeshModel, RecordTable + + if value is None: + return b"null" + if isinstance(value, bool): + return b"bool:1" if value else b"bool:0" + if isinstance(value, int) and not isinstance(value, bool): + return f"int:{value}".encode() + if isinstance(value, float): + return json.dumps(float(value), sort_keys=True, separators=(",", ":")).encode() + if isinstance(value, str): + return ("str:" + value).encode("utf-8", errors="surrogatepass") + + if isinstance(value, DataField): + return b"|".join([ + b"DataField", + self._fingerprint_bytes(value.data), + f"xreal:{value.xreal}".encode(), + f"yreal:{value.yreal}".encode(), + f"xoff:{value.xoff}".encode(), + f"yoff:{value.yoff}".encode(), + ("ux:" + value.si_unit_xy).encode(), + ("uz:" + value.si_unit_z).encode(), + ("domain:" + value.domain).encode(), + self._fingerprint_bytes(value.colormap), + f"display_offset:{value.display_offset}".encode(), + f"display_scale:{value.display_scale}".encode(), + self._fingerprint_bytes(value.overlays), + ]) + + if isinstance(value, LineData): + return b"|".join([ + b"LineData", + self._fingerprint_bytes(value.data), + self._fingerprint_bytes(value.x_axis.tolist() if value.x_axis is not None else None), + ("x_unit:" + value.x_unit).encode(), + ("y_unit:" + value.y_unit).encode(), + ]) + + if isinstance(value, MeshModel): + return b"|".join([ + b"MeshModel", + self._fingerprint_bytes(value.vertices), + self._fingerprint_bytes(value.faces), + self._fingerprint_bytes(value.colors if value.colors is not None else None), + ]) + + if isinstance(value, ImageData): + return b"|".join([ + b"ImageData", + self._fingerprint_bytes(np.asarray(value)), + self._fingerprint_bytes(getattr(value, "metadata", {})), + ]) + + if isinstance(value, np.ndarray): + array = np.ascontiguousarray(value) + header = json.dumps( + {"dtype": str(array.dtype), "shape": list(array.shape)}, + sort_keys=True, + separators=(",", ":"), + ).encode() + return b"|".join([b"ndarray", header, memoryview(array).tobytes()]) + + if isinstance(value, (MeasureTable, RecordTable, list)): + return b"[" + b",".join(self._fingerprint_bytes(item) for item in value) + b"]" + + if isinstance(value, tuple): + return b"(" + b",".join(self._fingerprint_bytes(item) for item in value) + b")" + + if isinstance(value, dict): + items = [] + for key in sorted(value): + items.append( + self._fingerprint_bytes(str(key)) + + b":" + + self._fingerprint_bytes(value[key]) + ) + return b"{" + b",".join(items) + b"}" + + return repr(value).encode("utf-8", errors="surrogatepass") + + def _clone_cached_outputs(self, outputs: tuple) -> tuple: + return tuple(self._clone_cached_value(value) for value in outputs) + + def _clone_cached_value(self, value: Any) -> Any: + import numpy as np + from backend.data_types import DataField, ImageData, LineData, MeshModel + + if isinstance(value, DataField): + return value.copy() + if isinstance(value, LineData): + return LineData( + data=value.data.copy(), + x_axis=value.x_axis.copy() if value.x_axis is not None else None, + x_unit=value.x_unit, + y_unit=value.y_unit, + ) + if isinstance(value, MeshModel): + return MeshModel( + vertices=value.vertices.copy(), + faces=value.faces.copy(), + colors=value.colors.copy() if value.colors is not None else None, + ) + if isinstance(value, ImageData): + return value.copy_with_metadata(data=np.asarray(value).copy()) + if isinstance(value, np.ndarray): + return value.copy() + if isinstance(value, tuple): + return tuple(self._clone_cached_value(item) for item in value) + if isinstance(value, (list, dict)): + return deepcopy(value) + return value + def _inject_display_callbacks( self, on_preview: Callable | None, diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 114a17e..7f4926c 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1759,6 +1759,146 @@ def test_execution_engine_numeric_socket_coercion(): print(" PASS\n") +def test_execution_engine_caches_unchanged_nodes(): + print("=== Test: ExecutionEngine caches unchanged nodes ===") + from backend.execution import ExecutionEngine + from backend.node_registry import register_node + + @register_node(display_name="Test Cache Source") + class TestCacheSource: + calls = 0 + + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "tests" + + def process(self, value): + TestCacheSource.calls += 1 + return (float(value),) + + @register_node(display_name="Test Cache Downstream") + class TestCacheDownstream: + calls = 0 + + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "tests" + + def process(self, value): + TestCacheDownstream.calls += 1 + return (float(value) * 2.0,) + + TestCacheSource.calls = 0 + TestCacheDownstream.calls = 0 + + engine = ExecutionEngine() + prompt = { + "1": { + "class_type": "TestCacheSource", + "inputs": {"value": 2.5}, + }, + "2": { + "class_type": "TestCacheDownstream", + "inputs": {"value": ["1", 0]}, + }, + } + + first_timings = [] + first_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: first_timings.append((node_id, elapsed_ms))) + second_timings = [] + second_outputs = engine.execute(prompt, on_node_done=lambda node_id, elapsed_ms: second_timings.append((node_id, elapsed_ms))) + + assert first_outputs["2"] == (5.0,) + assert second_outputs["2"] == (5.0,) + assert TestCacheSource.calls == 1 + assert TestCacheDownstream.calls == 1 + assert {node_id for node_id, _ in second_timings} == {"1", "2"} + assert all(elapsed_ms == 0.0 for _, elapsed_ms in second_timings) + + print(" PASS\n") + + +def test_execution_engine_only_propagates_real_output_changes(): + print("=== Test: ExecutionEngine propagates only real upstream output changes ===") + from backend.execution import ExecutionEngine + from backend.node_registry import register_node + + @register_node(display_name="Test Quantized Source") + class TestQuantizedSource: + calls = 0 + + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("FLOAT",)}} + + RETURN_TYPES = ("INT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "tests" + + def process(self, value): + TestQuantizedSource.calls += 1 + return (int(round(float(value))),) + + @register_node(display_name="Test Quantized Downstream") + class TestQuantizedDownstream: + calls = 0 + + @classmethod + def INPUT_TYPES(cls): + return {"required": {"value": ("INT",)}} + + RETURN_TYPES = ("FLOAT",) + RETURN_NAMES = ("value",) + FUNCTION = "process" + CATEGORY = "tests" + + def process(self, value): + TestQuantizedDownstream.calls += 1 + return (float(value) + 0.5,) + + TestQuantizedSource.calls = 0 + TestQuantizedDownstream.calls = 0 + + engine = ExecutionEngine() + prompt = { + "1": { + "class_type": "TestQuantizedSource", + "inputs": {"value": 1.2}, + }, + "2": { + "class_type": "TestQuantizedDownstream", + "inputs": {"value": ["1", 0]}, + }, + } + + outputs_first = engine.execute(prompt) + assert outputs_first["2"] == (1.5,) + + prompt["1"]["inputs"]["value"] = 1.3 + outputs_second = engine.execute(prompt) + assert outputs_second["2"] == (1.5,) + + prompt["1"]["inputs"]["value"] = 2.2 + outputs_third = engine.execute(prompt) + assert outputs_third["2"] == (2.5,) + + assert TestQuantizedSource.calls == 3 + assert TestQuantizedDownstream.calls == 2 + + print(" PASS\n") + + # ========================================================================= # Analysis — Cursors # =========================================================================