From 961b5d08c81137b7ff27b75e42b6216908f8ef22 Mon Sep 17 00:00:00 2001 From: matei jordache Date: Sun, 29 Mar 2026 22:48:29 -0700 Subject: [PATCH] implement plugin system --- backend/node_menu.py | 14 ++-- backend/node_registry.py | 2 +- backend/plugin_loader.py | 121 +++++++++++++++++++++++++++++++ backend/runtime_paths.py | 25 ++++++- backend/server.py | 64 ++++++++++++++++- frontend/src/App.jsx | 6 ++ plugins/example_normalize.py | 135 +++++++++++++++++++++++++++++++++++ 7 files changed, 358 insertions(+), 9 deletions(-) create mode 100644 backend/plugin_loader.py create mode 100644 plugins/example_normalize.py diff --git a/backend/node_menu.py b/backend/node_menu.py index 74940f5..ae7c916 100644 --- a/backend/node_menu.py +++ b/backend/node_menu.py @@ -124,18 +124,22 @@ for category, class_names in MENU_LAYOUT.items(): }) -def get_menu_metadata(class_name: str) -> dict[str, Any]: +def get_menu_metadata(class_name: str, cls: type | None = None) -> dict[str, Any]: metadata = _NODE_METADATA.get(class_name) if metadata is not None: return dict(metadata) + # Nodes not listed in MENU_LAYOUT (e.g. plugins) can declare their own + # menu category via a CATEGORY class attribute. Falls back to "Unsorted". + category = getattr(cls, "CATEGORY", "Unsorted") if cls else "Unsorted" + order = len(_CATEGORY_ORDER) return { - "category": "Unsorted", - "category_order": len(_CATEGORY_ORDER), + "category": category, + "category_order": order, "menu_order": 10_000, "menu_categories": [{ - "category": "Unsorted", - "category_order": len(_CATEGORY_ORDER), + "category": category, + "category_order": order, "menu_order": 10_000, }], } diff --git a/backend/node_registry.py b/backend/node_registry.py index c32e6a4..8455ecf 100644 --- a/backend/node_registry.py +++ b/backend/node_registry.py @@ -74,7 +74,7 @@ def get_node_info(class_name: str) -> dict[str, Any]: """ cls = NODE_CLASS_MAPPINGS[class_name] input_types: dict = cls.INPUT_TYPES() - menu_metadata = get_menu_metadata(class_name) + menu_metadata = get_menu_metadata(class_name, cls) return { "name": class_name, diff --git a/backend/plugin_loader.py b/backend/plugin_loader.py new file mode 100644 index 0000000..d0958a2 --- /dev/null +++ b/backend/plugin_loader.py @@ -0,0 +1,121 @@ +""" +Plugin loader for argonode. + +Scans a plugins directory for .py files and packages (directories containing +__init__.py), imports each one, and lets their @register_node decorators +self-register into NODE_CLASS_MAPPINGS. Errors are logged as warnings and +never crash the server. + +Plugin authors write a single .py file dropped into the plugins/ directory: + + from backend.node_registry import register_node + from backend.data_types import DataField + + @register_node(display_name="My Filter") + class MyFilter: + CATEGORY = "Plugins" + + @classmethod + def INPUT_TYPES(cls): + return {"required": {"field": ("DATA_FIELD",)}} + + OUTPUTS = (("DATA_FIELD", "result"),) + FUNCTION = "process" + + def process(self, field: DataField) -> tuple: + ... + return (field.replace(data=result),) + +Multi-file plugins: place a directory with __init__.py in the plugins folder. +Files and directories whose names start with '_' are skipped (private helpers). +""" +from __future__ import annotations + +import importlib.util +import logging +import sys +import traceback +from pathlib import Path + +log = logging.getLogger(__name__) + + +def load_plugins(plugins_dir: Path) -> list[tuple[str, str]]: + """ + Import every plugin found in *plugins_dir*. + + Returns a list of ``(plugin_name, error_traceback)`` for each plugin that + failed to load. An empty list means all plugins loaded without error. + Plugins that fail do not block subsequent plugins from loading. + """ + if not plugins_dir.exists(): + return [] + + candidates = _discover(plugins_dir) + if not candidates: + return [] + + log.info("Loading plugins from %s", plugins_dir) + errors: list[tuple[str, str]] = [] + + for name, path in candidates: + try: + _import_plugin(name, path) + log.info("Plugin loaded: %s", name) + except Exception: + msg = traceback.format_exc() + log.warning("Plugin %r failed to load:\n%s", name, msg) + errors.append((name, msg)) + + return errors + + +def _discover(plugins_dir: Path) -> list[tuple[str, Path]]: + """ + Return ``(name, path)`` pairs for every importable plugin entry. + + A plugin is either: + - a ``.py`` file (excluding ``__init__.py`` and ``_``-prefixed files), or + - a sub-directory that contains ``__init__.py`` (a package plugin). + + Both kinds must not have a leading ``_`` in their name so that private + helper modules placed alongside plugins are not mistakenly imported. + """ + found: list[tuple[str, Path]] = [] + for entry in sorted(plugins_dir.iterdir()): + if entry.name.startswith("_"): + continue + if entry.is_file() and entry.suffix == ".py": + found.append((entry.stem, entry)) + elif entry.is_dir() and (entry / "__init__.py").exists(): + found.append((entry.name, entry / "__init__.py")) + return found + + +def _import_plugin(name: str, path: Path) -> None: + """ + Import a single plugin file (or package ``__init__.py``) via importlib. + + The module is registered under ``argonode_plugins.`` in + ``sys.modules``. This namespace: + - avoids collisions with any PyPI package of the same name, and + - makes package-style plugins (with sub-modules) work correctly, because + their relative imports resolve against the ``argonode_plugins.*`` parent. + + If the module was previously imported (e.g. on a hot-reload call after an + upload), it is deleted from ``sys.modules`` first so the file is re-executed + and any updated ``@register_node`` decorators take effect. + """ + module_name = f"argonode_plugins.{name}" + + # Remove stale module to support hot-reload after /upload-plugin. + if module_name in sys.modules: + del sys.modules[module_name] + + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot create a module spec for {path}") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) # @register_node decorators fire here diff --git a/backend/runtime_paths.py b/backend/runtime_paths.py index 7ddbe72..31cd5b4 100644 --- a/backend/runtime_paths.py +++ b/backend/runtime_paths.py @@ -62,6 +62,29 @@ def output_dir() -> Path: return app_data_dir() / "output" -def ensure_runtime_dirs() -> None: +def plugins_dir() -> Path: + return app_data_dir() / "plugins" + + +def plugins_enabled(*, native: bool) -> bool: + """ + Return True when the plugin system should be active. + + Default behaviour: enabled on native/desktop builds, disabled for web. + Override with the ARGONODE_PLUGINS environment variable: + ARGONODE_PLUGINS=1 – force on (useful for testing plugins via main.py) + ARGONODE_PLUGINS=0 – force off (disable even on native builds) + """ + env = os.getenv("ARGONODE_PLUGINS", "").strip().lower() + if env in ("1", "true", "yes"): + return True + if env in ("0", "false", "no"): + return False + return native + + +def ensure_runtime_dirs(*, with_plugins: bool = False) -> None: input_dir().mkdir(parents=True, exist_ok=True) output_dir().mkdir(parents=True, exist_ok=True) + if with_plugins: + plugins_dir().mkdir(parents=True, exist_ok=True) diff --git a/backend/server.py b/backend/server.py index 095194f..31ca912 100644 --- a/backend/server.py +++ b/backend/server.py @@ -40,7 +40,7 @@ from pathlib import Path from aiohttp import web, WSMsgType from backend.frontend_build import FrontendBuildError, ensure_frontend_dist_ready -from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, project_root +from backend.runtime_paths import ensure_runtime_dirs, frontend_dir, frontend_dist_dir, plugins_dir, plugins_enabled, project_root from backend.session_runtime import ( PATH_INPUT_TYPES, SESSION_HEADER, @@ -110,10 +110,16 @@ def create_app( allow_local_filesystem: bool = False, ) -> web.Application: import backend.nodes # noqa: F401 + + _plugins_on = plugins_enabled(native=allow_local_filesystem) + if _plugins_on: + from backend.plugin_loader import load_plugins + load_plugins(plugins_dir()) + from backend.execution import ExecutionEngine, new_prompt_id from backend.node_registry import NODE_CLASS_MAPPINGS, get_all_node_info - ensure_runtime_dirs() + ensure_runtime_dirs(with_plugins=_plugins_on) session_engines: dict[str, ExecutionEngine] = {} session_websockets: dict[str, set[web.WebSocketResponse]] = defaultdict(set) @@ -343,6 +349,58 @@ def create_app( content_type="application/json", ) + async def upload_plugin(request: web.Request) -> web.Response: + """ + Accept a .py plugin file, save it to plugins_dir(), hot-reload all + plugins, and notify every connected WebSocket client to refresh /nodes. + + Warning: uploading Python files is equivalent to remote code execution. + This endpoint is intentionally unrestricted because argonode is a + local-first application; do not expose it on a public network. + """ + reader = await request.multipart() + filename = "" + file_bytes = None + + while True: + part = await reader.next() + if part is None: + break + if part.name == "file": + filename = Path(part.filename or "plugin.py").name + chunks = [] + while True: + chunk = await part.read_chunk(65536) + if not chunk: + break + chunks.append(chunk) + file_bytes = b"".join(chunks) + + if file_bytes is None: + raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body") + if not filename.endswith(".py"): + raise web.HTTPBadRequest(reason="Only .py plugin files are accepted") + + dest = plugins_dir() / filename + dest.write_bytes(file_bytes) + + # Hot-reload: re-run the loader (handles re-import of changed files). + load_plugins(plugins_dir()) + + # Tell every connected frontend to re-fetch GET /nodes. + msg = _dumps({"type": "nodes_updated"}) + for ws_set in session_websockets.values(): + for ws in list(ws_set): + try: + await ws.send_str(msg) + except Exception: + pass + + return web.Response( + text=_dumps({"filename": filename, "loaded": True}), + content_type="application/json", + ) + async def download_file(request: web.Request) -> web.Response: body = await request.read() filename = request.query.get("filename", "workflow.png") @@ -469,6 +527,8 @@ def create_app( app.router.add_get("/folder-files", get_folder_files) app.router.add_post("/upload-folder", create_upload_folder) app.router.add_post("/upload", upload_file) + if _plugins_on: + app.router.add_post("/upload-plugin", upload_plugin) app.router.add_post("/download", download_file) app.router.add_post("/save-workflow-png", save_workflow_png) app.router.add_get("/channels", get_channels) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 583cb48..0aac854 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -1328,6 +1328,12 @@ function Flow() { case 'node_warning': updateNodeData(msg.data.node_id, { warning: msg.data.message }); break; + case 'nodes_updated': + api.getNodes().then((defs) => { + nodeDefsRef.current = defs; + setStatus({ text: `Plugin loaded — ${Object.keys(defs).length} nodes available.`, level: 'info' }); + }).catch(() => {}); + break; } }); api.initWS(); diff --git a/plugins/example_normalize.py b/plugins/example_normalize.py new file mode 100644 index 0000000..1fd2893 --- /dev/null +++ b/plugins/example_normalize.py @@ -0,0 +1,135 @@ +""" +Example argonode plugin: Normalize Z Range + +Drop any .py file into this plugins/ folder and restart argonode (or upload it +via POST /upload-plugin) — the node will appear in the Add Node menu immediately. + +─── What you need to import ───────────────────────────────────────────────── + + from backend.node_registry import register_node ← the decorator + from backend.data_types import DataField ← the main SPM data type + +Other available types (import from backend.data_types as needed): + LineData - 1-D profile data (data, x_axis arrays + units) + MeshModel - 3-D triangle mesh (vertices, faces, colors arrays) + RecordTable - measurement table (list of dicts with schema) + IMAGE - uint8 numpy array (masks, greyscale, RGB images) + +─── Input types you can declare in INPUT_TYPES ────────────────────────────── + + ("DATA_FIELD",) - SPM height/signal field + ("IMAGE",) - mask or image (uint8 ndarray) + ("LINE",) - 1-D line/profile data + ("FLOAT", {...options...}) - float number widget + ("INT", {...options...}) - integer number widget + (["choice_a", "choice_b"],) - dropdown menu + ("STRING", {...}) - text input + +─── Output types you can declare in OUTPUTS ───────────────────────────────── + + ("DATA_FIELD", "name") - SPM field + ("IMAGE", "name") - mask / image + ("LINE", "name") - 1-D data + ("FLOAT", "name") - scalar number + ("RECORD_TABLE","name") - measurement table + +─── Inputs are passed as keyword arguments to your process() method ───────── +─── Outputs must be returned as a tuple, one item per OUTPUTS entry ───────── +""" + +import numpy as np +from backend.node_registry import register_node +from backend.data_types import DataField, RecordTable + + +@register_node(display_name="Normalize Z Range") +class NormalizeZRange: + """Rescale height values so the full range maps to [low, high].""" + + # Menu category shown in the Add Node popup. + # Any string works; nodes sharing a category are grouped together. + CATEGORY = "Plugins" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + # DATA_FIELD is the standard SPM field type. + "field": ("DATA_FIELD",), + + # FLOAT widget with default, min, and max. + "low": ("FLOAT", {"default": 0.0}), + "high": ("FLOAT", {"default": 1.0}), + }, + # Optional inputs don't need to be connected. + "optional": { + # A mask (uint8, 0 or 255) can restrict which pixels are + # used to compute the min/max for normalisation. + "mask": ("IMAGE",), + }, + } + + # Each entry is (output_type, output_name). + # The tuple length must match the tuple returned by process(). + OUTPUTS = ( + ("DATA_FIELD", "normalized"), + # RECORD_TABLE outputs appear as a "Print Table" connector and can be + # wired to the PrintTable display node or the Save node (CSV/JSON). + # The table is a RecordTable — a plain list of dicts, each with the + # keys "quantity", "value", and "unit". + ("RECORD_TABLE", "stats"), + ) + + # Name of the method to call when the node executes. + FUNCTION = "process" + + DESCRIPTION = ( + "Linearly rescale the Z values so the full data range maps to " + "[low, high]. If a mask is connected, only masked pixels are used " + "to compute the source min/max (unmasked pixels are still rescaled). " + "Also outputs a measurement table with the source range statistics." + ) + + def process( + self, + field: DataField, + low: float, + high: float, + mask=None, # optional: uint8 ndarray or None + ) -> tuple: + data = field.data.astype(np.float64) + + # Determine the source range from masked pixels if a mask was provided, + # otherwise use the full field. + if mask is not None and mask.shape == data.shape: + active = data[mask > 0] + else: + active = data.ravel() + + src_min = float(active.min()) if active.size > 0 else float(data.min()) + src_max = float(active.max()) if active.size > 0 else float(data.max()) + + span = src_max - src_min + if span == 0.0: + # Flat field: fill with low. + result = np.full_like(data, low) + else: + result = low + (data - src_min) / span * (high - low) + + # field.replace() copies all metadata (size, units, offsets) and + # substitutes a new data array. Always use this instead of building + # a DataField from scratch, so physical dimensions are preserved. + + # Build a RECORD_TABLE: a list of {"quantity", "value", "unit"} dicts. + # Use field.si_unit_z for the physical Z unit stored on the field + # (e.g. "m" for height data). Plain dimensionless numbers get "". + table = RecordTable([ + {"quantity": "Source min", "value": src_min, "unit": field.si_unit_z}, + {"quantity": "Source max", "value": src_max, "unit": field.si_unit_z}, + {"quantity": "Source span", "value": src_max - src_min, "unit": field.si_unit_z}, + {"quantity": "Output low", "value": low, "unit": ""}, + {"quantity": "Output high", "value": high, "unit": ""}, + ]) + + # Return one value per OUTPUTS entry, in the same order. + return (field.replace(data=result), table)