initial commit
This commit is contained in:
294
backend/execution.py
Normal file
294
backend/execution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Graph execution engine for argonode.
|
||||
|
||||
Prompt format (same as ComfyUI):
|
||||
{
|
||||
"node_id": {
|
||||
"class_type": "GaussianFilter",
|
||||
"inputs": {
|
||||
"field": ["upstream_node_id", 0], # link: [src_id, output_slot]
|
||||
"sigma": 2.0 # constant widget value
|
||||
}
|
||||
},
|
||||
...
|
||||
}
|
||||
|
||||
The engine:
|
||||
1. Topologically sorts nodes (Kahn's algorithm).
|
||||
2. Resolves input links to actual Python objects from earlier outputs.
|
||||
3. Calls each node's FUNCTION method.
|
||||
4. Emits progress callbacks after each node.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.node_registry import NODE_CLASS_MAPPINGS
|
||||
|
||||
|
||||
def _is_link(value: Any) -> bool:
|
||||
"""A value is a link if it's a [node_id_str, slot_int] pair."""
|
||||
return (
|
||||
isinstance(value, (list, tuple))
|
||||
and len(value) == 2
|
||||
and isinstance(value[0], str)
|
||||
and isinstance(value[1], int)
|
||||
)
|
||||
|
||||
|
||||
class ExecutionEngine:
|
||||
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
|
||||
|
||||
def execute(
|
||||
self,
|
||||
prompt: dict[str, dict],
|
||||
on_node_start: Callable[[str], None] | None = None,
|
||||
on_node_done: Callable[[str], None] | None = None,
|
||||
on_preview: Callable[[str, str], None] | None = None,
|
||||
on_table: Callable[[str, list], None] | None = None,
|
||||
on_mesh: Callable[[str, dict], None] | None = None,
|
||||
on_overlay: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, tuple]:
|
||||
"""
|
||||
Execute the workflow described by `prompt`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prompt : workflow dict (node_id → {class_type, inputs})
|
||||
on_node_start : called with node_id just before a node executes
|
||||
on_node_done : called with node_id just after a node executes
|
||||
on_preview : called with (node_id, data_uri) when a display node runs
|
||||
on_table : called with (node_id, table_list) when PrintTable runs
|
||||
on_overlay : called with (node_id, data_uri) for interactive overlays
|
||||
|
||||
Returns
|
||||
-------
|
||||
node_outputs : {node_id → tuple-of-outputs} for every executed node
|
||||
"""
|
||||
order = self._topological_sort(prompt)
|
||||
node_outputs: dict[str, tuple] = {}
|
||||
|
||||
# Inject display callbacks before execution
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
|
||||
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
class_name = node_def["class_type"]
|
||||
|
||||
if class_name not in NODE_CLASS_MAPPINGS:
|
||||
raise ValueError(f"Unknown node type: '{class_name}'")
|
||||
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs)
|
||||
|
||||
# Let display nodes know their node_id so they can tag WS messages
|
||||
self._set_node_id_on_display(cls, node_id)
|
||||
|
||||
if on_node_start:
|
||||
on_node_start(node_id)
|
||||
|
||||
instance = cls()
|
||||
func = getattr(instance, cls.FUNCTION)
|
||||
result = func(**inputs)
|
||||
|
||||
# Nodes must return a tuple; coerce single values just in case
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
|
||||
node_outputs[node_id] = result
|
||||
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or TABLE output so every node shows its result.
|
||||
if on_preview or on_table:
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table)
|
||||
|
||||
if on_node_done:
|
||||
on_node_done(node_id)
|
||||
|
||||
return node_outputs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _topological_sort(self, prompt: dict) -> list[str]:
|
||||
"""Kahn's algorithm — returns node IDs in dependency order."""
|
||||
in_degree: dict[str, int] = {nid: 0 for nid in prompt}
|
||||
dependents: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for node_id, node_def in prompt.items():
|
||||
for value in node_def.get("inputs", {}).values():
|
||||
if _is_link(value):
|
||||
src_id = value[0]
|
||||
if src_id in prompt:
|
||||
in_degree[node_id] += 1
|
||||
dependents[src_id].append(node_id)
|
||||
|
||||
queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
|
||||
order: list[str] = []
|
||||
|
||||
while queue:
|
||||
nid = queue.popleft()
|
||||
order.append(nid)
|
||||
for dep in dependents[nid]:
|
||||
in_degree[dep] -= 1
|
||||
if in_degree[dep] == 0:
|
||||
queue.append(dep)
|
||||
|
||||
if len(order) != len(prompt):
|
||||
raise ValueError("Cycle detected in workflow graph — cannot execute.")
|
||||
|
||||
return order
|
||||
|
||||
def _resolve_inputs(
|
||||
self,
|
||||
raw_inputs: dict[str, Any],
|
||||
node_outputs: dict[str, tuple],
|
||||
) -> dict[str, Any]:
|
||||
"""Replace [src_id, slot] links with actual output values."""
|
||||
resolved = {}
|
||||
for key, value in raw_inputs.items():
|
||||
if _is_link(value):
|
||||
src_id, slot = value[0], int(value[1])
|
||||
if src_id not in node_outputs:
|
||||
raise KeyError(
|
||||
f"Node '{src_id}' has no output yet — dependency ordering bug?"
|
||||
)
|
||||
outputs = node_outputs[src_id]
|
||||
if slot >= len(outputs):
|
||||
raise IndexError(
|
||||
f"Node '{src_id}' only has {len(outputs)} outputs, "
|
||||
f"but slot {slot} was requested."
|
||||
)
|
||||
resolved[key] = outputs[slot]
|
||||
else:
|
||||
resolved[key] = value
|
||||
return resolved
|
||||
|
||||
def _inject_display_callbacks(
|
||||
self,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
on_mesh: Callable | None = None,
|
||||
on_overlay: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection
|
||||
from backend.nodes.io import SaveImage
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
View3D._broadcast_mesh_fn = on_mesh
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
SaveImage._broadcast_preview = (
|
||||
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
|
||||
)
|
||||
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection
|
||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
self,
|
||||
cls: type,
|
||||
node_id: str,
|
||||
result: tuple,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
) -> None:
|
||||
"""
|
||||
After every node executes, inspect its outputs and broadcast
|
||||
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
|
||||
"""
|
||||
import numpy as np
|
||||
from backend.data_types import (
|
||||
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
|
||||
)
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
for slot, type_name in enumerate(return_types):
|
||||
if slot >= len(result):
|
||||
break
|
||||
value = result[slot]
|
||||
|
||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||
arr = datafield_to_uint8(value, "viridis")
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return # one preview per node is enough
|
||||
|
||||
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
|
||||
arr = image_to_uint8(value)
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return
|
||||
|
||||
if type_name == "LINE" and isinstance(value, np.ndarray) and on_preview:
|
||||
preview = self._render_line_preview(cls, slot, result)
|
||||
if preview:
|
||||
on_preview(node_id, preview)
|
||||
return
|
||||
|
||||
if type_name == "TABLE" and isinstance(value, list) and on_table:
|
||||
on_table(node_id, value)
|
||||
return
|
||||
|
||||
|
||||
def _render_line_preview(
|
||||
self,
|
||||
cls: type,
|
||||
slot: int,
|
||||
result: tuple,
|
||||
) -> str | None:
|
||||
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
|
||||
import numpy as np
|
||||
import base64
|
||||
import io as _io
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
# Find the y-values (current slot) and try to find an x-axis
|
||||
y = result[slot]
|
||||
x = None
|
||||
# If the next output is also LINE, use it as x-axis
|
||||
if slot + 1 < len(return_types) and return_types[slot + 1] == "LINE":
|
||||
x = result[slot + 1]
|
||||
# Or if slot > 0 and previous is LINE, this slot is the x-axis — skip
|
||||
if slot > 0 and return_types[slot - 1] == "LINE":
|
||||
return None # the first LINE already plotted both
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
|
||||
fig.patch.set_facecolor("#1e293b")
|
||||
ax.set_facecolor("#0f172a")
|
||||
if x is not None:
|
||||
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
||||
else:
|
||||
ax.plot(y, color="#ff9800", linewidth=1.2)
|
||||
ax.tick_params(colors="#94a3b8", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#334155")
|
||||
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
|
||||
fig.tight_layout(pad=0.4)
|
||||
|
||||
buf = _io.BytesIO()
|
||||
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
||||
plt.close(fig)
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def new_prompt_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
Reference in New Issue
Block a user