initial commit

This commit is contained in:
2026-03-23 00:35:30 -07:00
parent 5ecc913e28
commit 87b6905fba
48 changed files with 7012 additions and 1 deletions

0
backend/__init__.py Normal file
View File

134
backend/data_types.py Normal file
View File

@@ -0,0 +1,134 @@
"""
Core data types for argonode.
DataField mirrors Gwyddion's GwyDataField structure:
xres, yres pixel dimensions
xreal, yreal physical dimensions in metres
xoff, yoff position offset in metres
si_unit_xy lateral unit string (e.g. "m", "nm")
si_unit_z height/value unit string (e.g. "m", "V", "A")
domain "spatial" or "frequency" (set by FFT nodes)
"""
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
@dataclass
class DataField:
data: np.ndarray # shape (yres, xres), dtype float64
xres: int = 0
yres: int = 0
xreal: float = 1e-6 # physical width in metres
yreal: float = 1e-6 # physical height in metres
xoff: float = 0.0
yoff: float = 0.0
si_unit_xy: str = "m"
si_unit_z: str = "m"
domain: str = "spatial" # "spatial" or "frequency"
def __post_init__(self) -> None:
self.data = np.asarray(self.data, dtype=np.float64)
if self.data.ndim != 2:
raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}")
self.yres, self.xres = self.data.shape
def copy(self) -> "DataField":
"""Return a deep copy with independent data array."""
return DataField(
data=self.data.copy(),
xres=self.xres,
yres=self.yres,
xreal=self.xreal,
yreal=self.yreal,
xoff=self.xoff,
yoff=self.yoff,
si_unit_xy=self.si_unit_xy,
si_unit_z=self.si_unit_z,
domain=self.domain,
)
def replace(self, **kwargs) -> "DataField":
"""Return a copy with selected fields replaced. data is deep-copied unless provided."""
base = {
"data": self.data.copy(),
"xres": self.xres,
"yres": self.yres,
"xreal": self.xreal,
"yreal": self.yreal,
"xoff": self.xoff,
"yoff": self.yoff,
"si_unit_xy": self.si_unit_xy,
"si_unit_z": self.si_unit_z,
"domain": self.domain,
}
base.update(kwargs)
return DataField(**base)
@property
def dx(self) -> float:
"""Physical pixel size in x (metres)."""
return self.xreal / self.xres if self.xres else 1.0
@property
def dy(self) -> float:
"""Physical pixel size in y (metres)."""
return self.yreal / self.yres if self.yres else 1.0
# ---------------------------------------------------------------------------
# Utility helpers shared across nodes
# ---------------------------------------------------------------------------
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
"""
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
Returns shape (H, W, 3) uint8.
"""
import matplotlib.cm as cm
import matplotlib.colors as mcolors
data = df.data
dmin, dmax = data.min(), data.max()
if dmax > dmin:
normalized = (data - dmin) / (dmax - dmin)
else:
normalized = np.zeros_like(data)
cmap = cm.get_cmap(colormap)
rgba = cmap(normalized) # (H, W, 4) float [0,1]
rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
return rgb
def image_to_uint8(image: np.ndarray) -> np.ndarray:
"""
Convert an IMAGE (float or uint8, 2-D or 3-D) to uint8 (H,W,3) or (H,W) for PIL.
"""
if image.dtype == np.uint8:
return image
# float — normalize to [0, 255]
imin, imax = image.min(), image.max()
if imax > imin:
out = (image - imin) / (imax - imin) * 255.0
else:
out = np.zeros_like(image)
return out.astype(np.uint8)
def encode_preview(arr: np.ndarray) -> str:
"""
Encode a uint8 numpy array as a base64 data URI (PNG).
arr: (H, W) grayscale or (H, W, 3) RGB, uint8.
"""
import base64
import io
from PIL import Image
img = Image.fromarray(arr)
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"

294
backend/execution.py Normal file
View File

@@ -0,0 +1,294 @@
"""
Graph execution engine for argonode.
Prompt format (same as ComfyUI):
{
"node_id": {
"class_type": "GaussianFilter",
"inputs": {
"field": ["upstream_node_id", 0], # link: [src_id, output_slot]
"sigma": 2.0 # constant widget value
}
},
...
}
The engine:
1. Topologically sorts nodes (Kahn's algorithm).
2. Resolves input links to actual Python objects from earlier outputs.
3. Calls each node's FUNCTION method.
4. Emits progress callbacks after each node.
"""
from __future__ import annotations
import uuid
from collections import defaultdict, deque
from typing import Any, Callable
from backend.node_registry import NODE_CLASS_MAPPINGS
def _is_link(value: Any) -> bool:
"""A value is a link if it's a [node_id_str, slot_int] pair."""
return (
isinstance(value, (list, tuple))
and len(value) == 2
and isinstance(value[0], str)
and isinstance(value[1], int)
)
class ExecutionEngine:
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
def execute(
self,
prompt: dict[str, dict],
on_node_start: Callable[[str], None] | None = None,
on_node_done: Callable[[str], None] | None = None,
on_preview: Callable[[str, str], None] | None = None,
on_table: Callable[[str, list], None] | None = None,
on_mesh: Callable[[str, dict], None] | None = None,
on_overlay: Callable[[str, str], None] | None = None,
) -> dict[str, tuple]:
"""
Execute the workflow described by `prompt`.
Parameters
----------
prompt : workflow dict (node_id → {class_type, inputs})
on_node_start : called with node_id just before a node executes
on_node_done : called with node_id just after a node executes
on_preview : called with (node_id, data_uri) when a display node runs
on_table : called with (node_id, table_list) when PrintTable runs
on_overlay : called with (node_id, data_uri) for interactive overlays
Returns
-------
node_outputs : {node_id → tuple-of-outputs} for every executed node
"""
order = self._topological_sort(prompt)
node_outputs: dict[str, tuple] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
if class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Unknown node type: '{class_name}'")
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
inputs = self._resolve_inputs(raw_inputs, node_outputs)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
if on_node_start:
on_node_start(node_id)
instance = cls()
func = getattr(instance, cls.FUNCTION)
result = func(**inputs)
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
node_outputs[node_id] = result
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or TABLE output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table)
if on_node_done:
on_node_done(node_id)
return node_outputs
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _topological_sort(self, prompt: dict) -> list[str]:
"""Kahn's algorithm — returns node IDs in dependency order."""
in_degree: dict[str, int] = {nid: 0 for nid in prompt}
dependents: dict[str, list[str]] = defaultdict(list)
for node_id, node_def in prompt.items():
for value in node_def.get("inputs", {}).values():
if _is_link(value):
src_id = value[0]
if src_id in prompt:
in_degree[node_id] += 1
dependents[src_id].append(node_id)
queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
order: list[str] = []
while queue:
nid = queue.popleft()
order.append(nid)
for dep in dependents[nid]:
in_degree[dep] -= 1
if in_degree[dep] == 0:
queue.append(dep)
if len(order) != len(prompt):
raise ValueError("Cycle detected in workflow graph — cannot execute.")
return order
def _resolve_inputs(
self,
raw_inputs: dict[str, Any],
node_outputs: dict[str, tuple],
) -> dict[str, Any]:
"""Replace [src_id, slot] links with actual output values."""
resolved = {}
for key, value in raw_inputs.items():
if _is_link(value):
src_id, slot = value[0], int(value[1])
if src_id not in node_outputs:
raise KeyError(
f"Node '{src_id}' has no output yet — dependency ordering bug?"
)
outputs = node_outputs[src_id]
if slot >= len(outputs):
raise IndexError(
f"Node '{src_id}' only has {len(outputs)} outputs, "
f"but slot {slot} was requested."
)
resolved[key] = outputs[slot]
else:
resolved[key] = value
return resolved
def _inject_display_callbacks(
self,
on_preview: Callable | None,
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection
from backend.nodes.io import SaveImage
PreviewImage._broadcast_fn = on_preview
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
CrossSection._broadcast_overlay_fn = on_overlay
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
cls._current_node_id = node_id
def _auto_preview(
self,
cls: type,
node_id: str,
result: tuple,
on_preview: Callable | None,
on_table: Callable | None,
) -> None:
"""
After every node executes, inspect its outputs and broadcast
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
"""
import numpy as np
from backend.data_types import (
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
)
return_types = getattr(cls, "RETURN_TYPES", ())
for slot, type_name in enumerate(return_types):
if slot >= len(result):
break
value = result[slot]
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
arr = datafield_to_uint8(value, "viridis")
on_preview(node_id, encode_preview(arr))
return # one preview per node is enough
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
arr = image_to_uint8(value)
on_preview(node_id, encode_preview(arr))
return
if type_name == "LINE" and isinstance(value, np.ndarray) and on_preview:
preview = self._render_line_preview(cls, slot, result)
if preview:
on_preview(node_id, preview)
return
if type_name == "TABLE" and isinstance(value, list) and on_table:
on_table(node_id, value)
return
def _render_line_preview(
self,
cls: type,
slot: int,
result: tuple,
) -> str | None:
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
import numpy as np
import base64
import io as _io
return_types = getattr(cls, "RETURN_TYPES", ())
# Find the y-values (current slot) and try to find an x-axis
y = result[slot]
x = None
# If the next output is also LINE, use it as x-axis
if slot + 1 < len(return_types) and return_types[slot + 1] == "LINE":
x = result[slot + 1]
# Or if slot > 0 and previous is LINE, this slot is the x-axis — skip
if slot > 0 and return_types[slot - 1] == "LINE":
return None # the first LINE already plotted both
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
if x is not None:
ax.plot(x, y, color="#ff9800", linewidth=1.2)
else:
ax.plot(y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
fig.tight_layout(pad=0.4)
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
plt.close(fig)
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
except Exception:
return None
def new_prompt_id() -> str:
return str(uuid.uuid4())

47
backend/main.py Normal file
View File

@@ -0,0 +1,47 @@
"""
Entry point for argonode.
Run with:
python -m backend.main
or simply:
python backend/main.py
from the argonode/ directory.
"""
import asyncio
import logging
import sys
from pathlib import Path
# Allow running as `python backend/main.py` from the project root
sys.path.insert(0, str(Path(__file__).parent.parent))
from aiohttp import web
from backend.server import create_app
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(name)s%(message)s",
)
log = logging.getLogger(__name__)
HOST = "127.0.0.1"
PORT = 8188
def main() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
app = create_app(loop)
log.info("=" * 60)
log.info(" Argonode — Node-based image analysis")
log.info(" Open your browser at http://%s:%d", HOST, PORT)
log.info("=" * 60)
web.run_app(app, host=HOST, port=PORT, loop=loop, access_log=None)
if __name__ == "__main__":
main()

56
backend/node_registry.py Normal file
View File

@@ -0,0 +1,56 @@
"""
Node registry for argonode.
Nodes are plain Python classes decorated with @register_node.
NODE_CLASS_MAPPINGS is the single source of truth consumed by
the execution engine and the /nodes REST endpoint.
"""
from __future__ import annotations
from typing import Any
NODE_CLASS_MAPPINGS: dict[str, type] = {}
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
def register_node(display_name: str | None = None):
"""
Class decorator that registers a node class into NODE_CLASS_MAPPINGS.
Usage:
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
...
"""
def decorator(cls: type) -> type:
name = cls.__name__
NODE_CLASS_MAPPINGS[name] = cls
NODE_DISPLAY_NAME_MAPPINGS[name] = display_name or name
return cls
return decorator
def get_node_info(class_name: str) -> dict[str, Any]:
"""
Return a JSON-serialisable dict describing a node — consumed by GET /nodes.
Shape is compatible with what LiteGraph.js expects from the frontend.
"""
cls = NODE_CLASS_MAPPINGS[class_name]
input_types: dict = cls.INPUT_TYPES()
return {
"name": class_name,
"display_name": NODE_DISPLAY_NAME_MAPPINGS.get(class_name, class_name),
"category": getattr(cls, "CATEGORY", "uncategorized"),
"input": input_types,
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
"output": list(cls.RETURN_TYPES),
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
"description": getattr(cls, "DESCRIPTION", ""),
}
def get_all_node_info() -> dict[str, dict[str, Any]]:
"""Return info dicts for every registered node."""
return {name: get_node_info(name) for name in NODE_CLASS_MAPPINGS}

View File

@@ -0,0 +1,2 @@
# Import all node modules to trigger @register_node decorators.
from . import io, filters, level, analysis, grains, display

471
backend/nodes/analysis.py Normal file
View File

@@ -0,0 +1,471 @@
"""
Analysis nodes — statistics, histograms, FFT, cross sections.
Gwyddion equivalents:
StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h)
HeightHistogram → DH (height distribution), gwy_data_field_dh
FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf
CrossSection → gwy_data_field_get_profile (libprocess/datafield.c)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# StatisticsNode
# ---------------------------------------------------------------------------
@register_node(display_name="Statistics")
class StatisticsNode:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("stats",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
)
def process(self, field: DataField) -> tuple:
d = field.data
mean = float(d.mean())
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
table = [
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
{"quantity": "skewness", "value": skewness, "unit": ""},
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
]
return (table,)
# ---------------------------------------------------------------------------
# HeightHistogram
# ---------------------------------------------------------------------------
@register_node(display_name="Height Histogram")
class HeightHistogram:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
}
}
RETURN_TYPES = ("LINE", "LINE")
RETURN_NAMES = ("counts", "bin_centers")
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute the height distribution histogram (DH). "
"Equivalent to gwy_data_field_dh."
)
def process(self, field: DataField, n_bins: int) -> tuple:
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
return (counts.astype(np.float64), bin_centers)
# ---------------------------------------------------------------------------
# FFT2D
# ---------------------------------------------------------------------------
@register_node(display_name="2D FFT")
class FFT2D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"windowing": (["hann", "hamming", "blackman", "none"],),
"level": (["mean", "plane", "none"],),
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("spectrum",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
"Output can be log magnitude, magnitude, phase, or PSDF. "
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
)
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Level subtraction (Gwyddion-style, before windowing)
if level == "mean":
data -= data.mean()
elif level == "plane":
# Fit and subtract a plane: z = a + b*x + c*y
yy, xx = np.mgrid[0:yres, 0:xres]
xx_f = xx.ravel().astype(np.float64)
yy_f = yy.ravel().astype(np.float64)
zz_f = data.ravel()
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
data -= plane
# Windowing (Gwyddion uses (i+0.5)/n centred formulation)
if windowing != "none":
t_y = (np.arange(yres) + 0.5) / yres
t_x = (np.arange(xres) + 0.5) / xres
if windowing == "hann":
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
elif windowing == "hamming":
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
elif windowing == "blackman":
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
else:
wy = np.ones(yres)
wx = np.ones(xres)
data *= np.outer(wy, wx)
# 2D FFT, shifted so DC is at centre
F = np.fft.fftshift(np.fft.fft2(data))
n = xres * yres
if output == "log_magnitude":
mag = np.abs(F)
# Log scale with floor to avoid log(0)
result = np.log1p(mag)
elif output == "magnitude":
result = np.abs(F)
elif output == "phase":
result = np.angle(F)
elif output == "psdf":
# Gwyddion-equivalent PSDF: |F|^2 * dx * dy / (n * 4π²)
dx = field.xreal / xres
dy = field.yreal / yres
result = (np.abs(F) ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
else:
result = np.abs(F)
# Calibrate the output field in spatial-frequency units
if output == "psdf":
# Gwyddion uses angular frequency: 2π/dx, 2π/dy
freq_xreal = 2.0 * np.pi * xres / field.xreal
freq_yreal = 2.0 * np.pi * yres / field.yreal
z_unit = f"({field.si_unit_z})^2 m^2"
else:
freq_xreal = xres / field.xreal
freq_yreal = yres / field.yreal
z_unit = field.si_unit_z
out_field = DataField(
data=result,
xreal=freq_xreal,
yreal=freq_yreal,
si_unit_xy="1/m",
si_unit_z=z_unit,
domain="frequency",
)
return (out_field,)
# ---------------------------------------------------------------------------
# CrossSection
# ---------------------------------------------------------------------------
def _extend_to_edges(x1, y1, x2, y2):
"""
Extend the line through (x1,y1)-(x2,y2) to the boundaries of [0,1]x[0,1].
Returns the two intersection points (clipped to the unit square).
"""
dx = x2 - x1
dy = y2 - y1
# Collect parametric t values where line hits each boundary
t_candidates = []
if abs(dx) > 1e-12:
for bx in (0.0, 1.0):
t = (bx - x1) / dx
y_at_t = y1 + t * dy
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if abs(dy) > 1e-12:
for by in (0.0, 1.0):
t = (by - y1) / dy
x_at_t = x1 + t * dx
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if len(t_candidates) < 2:
return x1, y1, x2, y2
t_min = min(t_candidates)
t_max = max(t_candidates)
return (
np.clip(x1 + t_min * dx, 0, 1),
np.clip(y1 + t_min * dy, 0, 1),
np.clip(x1 + t_max * dx, 0, 1),
np.clip(y1 + t_max * dy, 0, 1),
)
@register_node(display_name="Cross Section")
class CrossSection:
"""Extract a 1-D height profile along an arbitrary line across the image."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"extend": (["none", "to_edges"],),
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
},
"optional": {
"point_a": ("COORD",),
"point_b": ("COORD",),
},
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("profile",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Extract a cross-section profile along a line between two points. "
"Drag the markers on the image to set the line endpoints. "
"Equivalent to gwy_data_field_get_profile."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, field: DataField,
x1: float, y1: float, x2: float, y2: float,
extend: str, n_samples: int,
point_a=None, point_b=None,
) -> tuple:
from scipy.ndimage import map_coordinates
import io, base64
from matplotlib.figure import Figure
# COORD inputs override widget values
if point_a is not None:
x1, y1 = float(point_a[0]), float(point_a[1])
if point_b is not None:
x2, y2 = float(point_b[0]), float(point_b[1])
# Remember marker positions (before extend)
marker_x1, marker_y1 = float(x1), float(y1)
marker_x2, marker_y2 = float(x2), float(y2)
xres, yres = field.xres, field.yres
if extend == "to_edges":
x1, y1, x2, y2 = _extend_to_edges(
float(x1), float(y1), float(x2), float(y2),
)
# Convert fractional [0,1] to pixel indices [0, res-1]
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
# Number of sample points
line_len_px = np.hypot(px2 - px1, py2 - py1)
if n_samples <= 0:
n_samples = max(2, int(np.ceil(line_len_px)))
# Sample coordinates along the line
t = np.linspace(0, 1, n_samples)
coords_y = py1 + t * (py2 - py1)
coords_x = px1 + t * (px2 - px1)
# Interpolate values along the line (cubic spline)
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
# Broadcast overlay image with marker positions
if CrossSection._broadcast_overlay_fn is not None:
fig = Figure(figsize=(3, 3), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])
ax.imshow(field.data, cmap="viridis", aspect="auto")
ax.axis("off")
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,
{
"image": image_uri,
"x1": marker_x1, "y1": marker_y1,
"x2": marker_x2, "y2": marker_y2,
"a_locked": point_a is not None,
"b_locked": point_b is not None,
},
)
return (profile.astype(np.float64),)
# ---------------------------------------------------------------------------
# LineMath — single scalar measurement from a LINE profile
# ---------------------------------------------------------------------------
def _safe_rq(d):
"""RMS of deviations from mean."""
return float(np.sqrt(np.mean(d * d)))
# Registry: name → (function(z) → float, unit_label)
# All functions receive the raw 1-D profile as float64.
LINE_OPS: dict[str, tuple] = {}
def _line_op(name, unit=""):
"""Decorator to register a LINE operation."""
def decorator(fn):
LINE_OPS[name] = (fn, unit)
return fn
return decorator
# ── Basic statistics ──────────────────────────────────────────────────────
@_line_op("min")
def _op_min(z):
return float(z.min())
@_line_op("max")
def _op_max(z):
return float(z.max())
@_line_op("mean")
def _op_mean(z):
return float(z.mean())
@_line_op("median")
def _op_median(z):
return float(np.median(z))
@_line_op("sum")
def _op_sum(z):
return float(z.sum())
@_line_op("range")
def _op_range(z):
return float(z.max() - z.min())
@_line_op("length", unit="pts")
def _op_length(z):
return float(len(z))
@_line_op("rms")
def _op_rms(z):
return float(np.sqrt(np.mean(z * z)))
# ── Roughness parameters ──────────────────────────
@_line_op("Ra")
def _op_ra(z):
return float(np.mean(np.abs(z - z.mean())))
@_line_op("Rq")
def _op_rq(z):
d = z - z.mean()
return _safe_rq(d)
@_line_op("Rsk")
def _op_rsk(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
@_line_op("Rku")
def _op_rku(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
@_line_op("Rp")
def _op_rp(z):
return float((z - z.mean()).max())
@_line_op("Rv")
def _op_rv(z):
return float(-(z - z.mean()).min())
@_line_op("Rt")
def _op_rt(z):
d = z - z.mean()
return float(d.max() - d.min())
@_line_op("Dq")
def _op_dq(z):
"""RMS slope (first derivative RMS)."""
dz = np.diff(z)
return float(np.sqrt(np.mean(dz * dz)))
@_line_op("Da")
def _op_da(z):
"""Mean absolute slope."""
return float(np.mean(np.abs(np.diff(z))))
@register_node(display_name="Line Math")
class LineMath:
"""Compute a single scalar value from a LINE profile."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"operation": (list(LINE_OPS.keys()),),
}
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("result",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute a single scalar measurement from a LINE profile. "
"Includes basic stats and Gwyddion-convention roughness parameters."
)
def process(self, line, operation: str) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
fn, unit = LINE_OPS[operation]
value = fn(z)
table = [{"quantity": operation, "value": value, "unit": unit}]
return (table,)

165
backend/nodes/display.py Normal file
View File

@@ -0,0 +1,165 @@
"""
Display / output nodes.
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
connect whichever type you have. The server injects _broadcast_fn
before execution begins.
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
},
"optional": {
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ()
FUNCTION = "preview"
CATEGORY = "display"
OUTPUT_NODE = True
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
_broadcast_fn = None
_current_node_id: str = ""
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
# Prefer field if both are connected; accept whichever is provided
if field is not None:
arr_u8 = datafield_to_uint8(field, colormap)
elif image is not None:
if image.dtype != np.uint8:
imin, imax = image.min(), image.max()
if imax > imin:
norm = (image - imin) / (imax - imin)
else:
norm = np.zeros_like(image)
arr_u8 = (norm * 255).astype(np.uint8)
else:
arr_u8 = image
if arr_u8.ndim == 2 and colormap != "gray":
import matplotlib.cm as cm
cmap = cm.get_cmap(colormap)
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return ()
@register_node(display_name="3D View")
class View3D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
}
}
RETURN_TYPES = ()
FUNCTION = "render"
CATEGORY = "display"
OUTPUT_NODE = True
DESCRIPTION = (
"Interactive 3D surface view of a DATA_FIELD. "
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
)
_broadcast_mesh_fn = None
_current_node_id: str = ""
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int,
) -> tuple:
import matplotlib.cm as cm
import base64
data = field.data
yres, xres = data.shape
# Downsample if larger than resolution
step_y = max(1, yres // resolution)
step_x = max(1, xres // resolution)
z = data[::step_y, ::step_x].astype(np.float32)
ny, nx = z.shape
# Normalize for colormap
zmin, zmax = float(z.min()), float(z.max())
if zmax > zmin:
z_norm = (z - zmin) / (zmax - zmin)
else:
z_norm = np.zeros_like(z)
cmap = cm.get_cmap(colormap)
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
# Base64-encode arrays for efficient WS transport
z_b64 = base64.b64encode(z.tobytes()).decode()
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
mesh_data = {
"width": nx,
"height": ny,
"z_data": z_b64,
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
return ()
@register_node(display_name="Print Table")
class PrintTable:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("TABLE",),
}
}
RETURN_TYPES = ()
FUNCTION = "print_table"
CATEGORY = "display"
OUTPUT_NODE = True
DESCRIPTION = "Send a TABLE to the browser as a WebSocket message for display."
_broadcast_table_fn = None
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()

115
backend/nodes/filters.py Normal file
View File

@@ -0,0 +1,115 @@
"""
Filter nodes — Gwyddion-equivalent image filters.
Gwyddion equivalents:
GaussianFilter → gwy_data_field_filter_gaussian
MedianFilter → gwy_data_field_filter_median
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# GaussianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data.copy(), sigma=float(sigma))
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# MedianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Median Filter")
class MedianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data.copy(), size=size)
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# EdgeDetect
# ---------------------------------------------------------------------------
@register_node(display_name="Edge Detect")
class EdgeDetect:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["sobel", "prewitt", "laplacian", "log"],),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("edges",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = (
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
)
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data.copy()
if method == "sobel":
sx = sobel(data, axis=1)
sy = sobel(data, axis=0)
result = np.hypot(sx, sy)
elif method == "prewitt":
px = prewitt(data, axis=1)
py = prewitt(data, axis=0)
result = np.hypot(px, py)
elif method == "laplacian":
result = laplace(data)
elif method == "log":
result = gaussian_laplace(data, sigma=float(sigma))
else:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)

127
backend/nodes/grains.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Grain/feature detection nodes.
Gwyddion equivalents:
ThresholdMask → threshold.c / otsu_threshold.c
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------
@register_node(display_name="Threshold Mask")
class ThresholdMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
# threshold is a fraction [0, 1] of the data range
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
return (mask,)
# ---------------------------------------------------------------------------
# GrainAnalysis
# ---------------------------------------------------------------------------
@register_node(display_name="Grain Analysis")
class GrainAnalysis:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"mask": ("IMAGE",),
"min_size": ("INT", {"default": 10, "min": 1, "max": 100000, "step": 1}),
}
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("grain_stats",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
"Label connected grain regions in a binary mask and compute per-grain statistics: "
"area, equivalent diameter, mean/max height, bounding box. "
"Equivalent to gwy_data_field_grains_get_values."
)
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
from scipy.ndimage import label, find_objects
binary = (mask > 127).astype(np.int32)
labeled, n_grains = label(binary)
pixel_area = field.dx * field.dy # m^2 per pixel
rows = []
for grain_id in range(1, n_grains + 1):
grain_pixels = labeled == grain_id
area_px = int(grain_pixels.sum())
if area_px < min_size:
continue
area_m2 = area_px * pixel_area
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
heights = field.data[grain_pixels]
mean_h = float(heights.mean())
max_h = float(heights.max())
# Bounding box
ys, xs = np.where(grain_pixels)
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
rows.append({
"grain_id": grain_id,
"area_px": area_px,
"area_m2": area_m2,
"equiv_diam_m": equiv_diam,
"mean_height": mean_h,
"max_height": max_h,
"bbox": bbox,
})
return (rows,)

277
backend/nodes/io.py Normal file
View File

@@ -0,0 +1,277 @@
"""
I/O nodes: load and save images and SPM data.
"""
from __future__ import annotations
import os
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview, image_to_uint8
# Resolved at server startup so nodes know where to look
INPUT_DIR = Path(__file__).parent.parent.parent / "input"
OUTPUT_DIR = Path(__file__).parent.parent.parent / "output"
# ---------------------------------------------------------------------------
# LoadImage
# ---------------------------------------------------------------------------
@register_node(display_name="Load Image")
class LoadImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
}
}
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
RETURN_NAMES = ("image", "field")
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
def load(self, filename: str):
# Accept absolute paths or filenames relative to input/
path = Path(filename)
if not path.is_absolute():
path = INPUT_DIR / filename
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
ext = path.suffix.lower()
if ext in (".npy",):
arr = np.load(str(path)).astype(np.float64)
elif ext in (".npz",):
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image
img = Image.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
# Convert to float64 grayscale for the DATA_FIELD output
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
field = DataField(data=gray)
return (arr, field)
# ---------------------------------------------------------------------------
# LoadSPM
# ---------------------------------------------------------------------------
@register_node(display_name="Load SPM File")
class LoadSPM:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
"channel": ("STRING", {"default": "Z"}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
def load(self, filename: str, channel: str = "Z"):
path = Path(filename)
if not path.is_absolute():
path = INPUT_DIR / filename
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
ext = path.suffix.lower()
if ext == ".gwy":
return (self._load_gwy(path, channel),)
elif ext == ".sxm":
return (self._load_sxm(path, channel),)
elif ext in (".ibw",):
return (self._load_ibw(path),)
elif ext in (".npy",):
data = np.load(str(path)).astype(np.float64)
return (DataField(data=data),)
elif ext in (".npz",):
npz = np.load(str(path))
key = list(npz.files)[0]
return (DataField(data=npz[key].astype(np.float64)),)
else:
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
def _load_gwy(self, path: Path, channel: str) -> DataField:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
# Try requested channel name, fall back to first available
ch = None
for key, df in channels.items():
if channel.lower() in key.lower():
ch = df
break
if ch is None:
ch = next(iter(channels.values()))
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
return DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
)
def _load_sxm(self, path: Path, channel: str) -> DataField:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
# Pick channel
ch_key = None
for key in signals:
if channel.upper() in key.upper():
ch_key = key
break
if ch_key is None:
ch_key = next(iter(signals))
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
return DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
)
def _load_ibw(self, path: Path) -> DataField:
try:
import igor.igorpy as igorpy
wave = igorpy.load(str(path))
data = wave.wave["wData"].squeeze().astype(np.float64)
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
if data.ndim == 1:
data = data.reshape(1, -1)
elif data.ndim != 2:
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
# ---------------------------------------------------------------------------
# Coordinate
# ---------------------------------------------------------------------------
@register_node(display_name="Coordinate")
class Coordinate:
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("COORD",)
RETURN_NAMES = ("point",)
FUNCTION = "process"
CATEGORY = "io"
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
def process(self, x: float, y: float) -> tuple:
return ((float(x), float(y)),)
# ---------------------------------------------------------------------------
# SaveImage
# ---------------------------------------------------------------------------
@register_node(display_name="Save Image")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "output"}),
"format": (["PNG", "TIFF", "NPY"],),
}
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "io"
OUTPUT_NODE = True
DESCRIPTION = "Save an image or array to the output folder."
# Injected by server.py before execution begins
_broadcast_preview = None
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
OUTPUT_DIR.mkdir(exist_ok=True)
# Find next available filename
idx = 1
while True:
name = f"{filename_prefix}_{idx:04d}"
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
if not candidate.exists():
break
idx += 1
if format == "NPY":
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
else:
from PIL import Image
arr = image_to_uint8(image)
if arr.ndim == 2:
pil_img = Image.fromarray(arr, mode="L")
else:
pil_img = Image.fromarray(arr, mode="RGB")
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
# Emit preview over WebSocket if callback is set
if SaveImage._broadcast_preview is not None:
arr_u8 = image_to_uint8(image)
data_uri = encode_preview(arr_u8)
SaveImage._broadcast_preview(data_uri)
return ()

150
backend/nodes/level.py Normal file
View File

@@ -0,0 +1,150 @@
"""
Leveling nodes — background removal and zero correction.
Gwyddion equivalents:
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
FixZero → fix_zero in level.c
Plane-fit algorithm follows Gwyddion's level.h definition:
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# PlaneLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Plane Level")
class PlaneLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("leveled",)
FUNCTION = "process"
CATEGORY = "level"
DESCRIPTION = (
"Fit and subtract a least-squares plane from the data. "
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
)
def process(self, field: DataField) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Normalised coordinate grids in [0, 1]
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Design matrix: [1, x, y] shape (N, 3)
A = np.column_stack([
np.ones(xres * yres),
xx.ravel(),
yy.ravel(),
])
z = data.ravel()
# Least-squares: solve A @ [pa, pbx, pby] = z
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
pa, pbx, pby = coeffs
plane = (pa + pbx * xx + pby * yy)
return (field.replace(data=data - plane),)
# ---------------------------------------------------------------------------
# PolyLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Polynomial Level")
class PolyLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("leveled", "background")
FUNCTION = "process"
CATEGORY = "level"
DESCRIPTION = (
"Fit and subtract a polynomial background of given degree in x and y. "
"Equivalent to gwy_data_field_fit_polynom."
)
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Build Vandermonde-style design matrix with all monomials x^i * y^j
cols = []
for i in range(degree_x + 1):
for j in range(degree_y + 1):
cols.append((xx ** i * yy ** j).ravel())
A = np.column_stack(cols)
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
background = (A @ coeffs).reshape(yres, xres)
leveled = data - background
return (field.replace(data=leveled), field.replace(data=background))
# ---------------------------------------------------------------------------
# FixZero
# ---------------------------------------------------------------------------
@register_node(display_name="Fix Zero")
class FixZero:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["min", "mean", "median"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("zeroed",)
FUNCTION = "process"
CATEGORY = "level"
DESCRIPTION = (
"Shift data so that the minimum (or mean/median) is zero. "
"Equivalent to fix_zero in Gwyddion's level.c."
)
def process(self, field: DataField, method: str) -> tuple:
data = field.data.copy()
if method == "min":
data -= data.min()
elif method == "mean":
data -= data.mean()
elif method == "median":
data -= np.median(data)
else:
raise ValueError(f"Unknown method: {method}")
return (field.replace(data=data),)

267
backend/server.py Normal file
View File

@@ -0,0 +1,267 @@
"""
aiohttp web server for argonode.
Routes
------
GET / → serve frontend/index.html
GET /static/{path} → serve frontend JS/CSS
GET /nodes → JSON dict of all registered node definitions
POST /upload → multipart file upload to input/
POST /prompt → submit a workflow; returns {prompt_id}
GET /ws → WebSocket upgrade
WebSocket message types sent to clients
----------------------------------------
{"type": "execution_start", "data": {"prompt_id": "..."}}
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
{"type": "execution_complete", "data": {"prompt_id": "..."}}
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from pathlib import Path
from aiohttp import web, WSMsgType
log = logging.getLogger(__name__)
FRONTEND_DIR = Path(__file__).parent.parent / "frontend"
DIST_DIR = FRONTEND_DIR / "dist"
INPUT_DIR = Path(__file__).parent.parent / "input"
OUTPUT_DIR = Path(__file__).parent.parent / "output"
# ---------------------------------------------------------------------------
# JSON helper — numpy scalars are not serialisable by default
# ---------------------------------------------------------------------------
class _SafeEncoder(json.JSONEncoder):
def default(self, obj):
import numpy as np
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
def _dumps(obj) -> str:
return json.dumps(obj, cls=_SafeEncoder)
# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
# Import nodes to trigger registration decorators
import backend.nodes # noqa: F401
from backend.node_registry import get_all_node_info
from backend.execution import ExecutionEngine, new_prompt_id
INPUT_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)
engine = ExecutionEngine()
websockets: set[web.WebSocketResponse] = set()
# ------------------------------------------------------------------
# WebSocket broadcast helpers
# ------------------------------------------------------------------
def broadcast(msg: dict) -> None:
"""Schedule a broadcast to all connected WebSocket clients."""
payload = _dumps(msg)
for ws in list(websockets):
if not ws.closed:
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
def on_preview(node_id: str, data_uri: str) -> None:
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
def on_table(node_id: str, rows: list) -> None:
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}})
def on_mesh(node_id: str, mesh_data: dict) -> None:
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
def on_overlay(node_id: str, overlay_data) -> None:
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
# ------------------------------------------------------------------
# Route handlers
# ------------------------------------------------------------------
async def index(request: web.Request) -> web.Response:
# Serve Vite build output if available, else raw frontend
if (DIST_DIR / "index.html").exists():
return web.FileResponse(DIST_DIR / "index.html")
return web.FileResponse(FRONTEND_DIR / "index.html")
async def get_nodes(request: web.Request) -> web.Response:
info = get_all_node_info()
return web.Response(
text=_dumps(info),
content_type="application/json",
)
async def list_files(request: web.Request) -> web.Response:
"""List files in the input/ directory for the file picker widget."""
files = sorted(
f.name for f in INPUT_DIR.iterdir()
if f.is_file() and not f.name.startswith(".")
) if INPUT_DIR.exists() else []
return web.Response(text=_dumps(files), content_type="application/json")
async def browse_dir(request: web.Request) -> web.Response:
"""
Server-side directory browser for local file picking.
GET /browse?dir=/some/path → {parent, dirs[], files[]}
"""
dir_path = request.query.get("dir", str(Path.home()))
p = Path(dir_path).expanduser().resolve()
if not p.is_dir():
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
dirs = []
files = []
try:
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
if entry.name.startswith("."):
continue
if entry.is_dir():
dirs.append(entry.name)
elif entry.is_file():
files.append(entry.name)
except PermissionError:
pass
return web.Response(
text=_dumps({
"path": str(p),
"parent": str(p.parent) if p.parent != p else None,
"dirs": dirs,
"files": files,
}),
content_type="application/json",
)
async def upload_file(request: web.Request) -> web.Response:
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "file":
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
filename = Path(field.filename).name # strip any path traversal
dest = INPUT_DIR / filename
with open(dest, "wb") as f:
while True:
chunk = await field.read_chunk(65536)
if not chunk:
break
f.write(chunk)
return web.Response(text=_dumps({"filename": filename}), content_type="application/json")
async def submit_prompt(request: web.Request) -> web.Response:
body = await request.json()
prompt = body.get("prompt")
if not isinstance(prompt, dict) or not prompt:
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
prompt_id = new_prompt_id()
# Run execution in a thread pool so scipy doesn't block the event loop
async def run():
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}})
def on_start(node_id: str) -> None:
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
try:
await loop.run_in_executor(
None,
lambda: engine.execute(
prompt,
on_node_start=on_start,
on_preview=on_preview,
on_table=on_table,
on_mesh=on_mesh,
on_overlay=on_overlay,
),
)
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
except Exception as exc:
log.exception("Execution error")
broadcast({
"type": "execution_error",
"data": {"node_id": "", "message": str(exc)},
})
asyncio.ensure_future(run())
return web.Response(
text=_dumps({"prompt_id": prompt_id}),
content_type="application/json",
)
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
websockets.add(ws)
log.info("WebSocket client connected (%d total)", len(websockets))
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
pass # clients don't need to send anything currently
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
break
finally:
websockets.discard(ws)
log.info("WebSocket client disconnected (%d total)", len(websockets))
return ws
# ------------------------------------------------------------------
# App assembly
# ------------------------------------------------------------------
app = web.Application()
app.router.add_get("/", index)
app.router.add_get("/nodes", get_nodes)
app.router.add_get("/files", list_files)
app.router.add_get("/browse", browse_dir)
app.router.add_post("/upload", upload_file)
app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler)
# Serve frontend static files (Vite build or raw)
if DIST_DIR.exists():
app.router.add_static("/assets", DIST_DIR / "assets")
app.router.add_static("/static", FRONTEND_DIR)
# CORS — allow any origin (local dev only)
async def _cors_middleware(app_, handler):
async def middleware(request):
if request.method == "OPTIONS":
return web.Response(headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
})
response = await handler(request)
response.headers["Access-Control-Allow-Origin"] = "*"
return response
return middleware
app.middlewares.append(_cors_middleware)
return app