initial commit
This commit is contained in:
0
backend/__init__.py
Normal file
0
backend/__init__.py
Normal file
134
backend/data_types.py
Normal file
134
backend/data_types.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Core data types for argonode.
|
||||
|
||||
DataField mirrors Gwyddion's GwyDataField structure:
|
||||
xres, yres – pixel dimensions
|
||||
xreal, yreal – physical dimensions in metres
|
||||
xoff, yoff – position offset in metres
|
||||
si_unit_xy – lateral unit string (e.g. "m", "nm")
|
||||
si_unit_z – height/value unit string (e.g. "m", "V", "A")
|
||||
domain – "spatial" or "frequency" (set by FFT nodes)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataField:
|
||||
data: np.ndarray # shape (yres, xres), dtype float64
|
||||
xres: int = 0
|
||||
yres: int = 0
|
||||
xreal: float = 1e-6 # physical width in metres
|
||||
yreal: float = 1e-6 # physical height in metres
|
||||
xoff: float = 0.0
|
||||
yoff: float = 0.0
|
||||
si_unit_xy: str = "m"
|
||||
si_unit_z: str = "m"
|
||||
domain: str = "spatial" # "spatial" or "frequency"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.data = np.asarray(self.data, dtype=np.float64)
|
||||
if self.data.ndim != 2:
|
||||
raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}")
|
||||
self.yres, self.xres = self.data.shape
|
||||
|
||||
def copy(self) -> "DataField":
|
||||
"""Return a deep copy with independent data array."""
|
||||
return DataField(
|
||||
data=self.data.copy(),
|
||||
xres=self.xres,
|
||||
yres=self.yres,
|
||||
xreal=self.xreal,
|
||||
yreal=self.yreal,
|
||||
xoff=self.xoff,
|
||||
yoff=self.yoff,
|
||||
si_unit_xy=self.si_unit_xy,
|
||||
si_unit_z=self.si_unit_z,
|
||||
domain=self.domain,
|
||||
)
|
||||
|
||||
def replace(self, **kwargs) -> "DataField":
|
||||
"""Return a copy with selected fields replaced. data is deep-copied unless provided."""
|
||||
base = {
|
||||
"data": self.data.copy(),
|
||||
"xres": self.xres,
|
||||
"yres": self.yres,
|
||||
"xreal": self.xreal,
|
||||
"yreal": self.yreal,
|
||||
"xoff": self.xoff,
|
||||
"yoff": self.yoff,
|
||||
"si_unit_xy": self.si_unit_xy,
|
||||
"si_unit_z": self.si_unit_z,
|
||||
"domain": self.domain,
|
||||
}
|
||||
base.update(kwargs)
|
||||
return DataField(**base)
|
||||
|
||||
@property
|
||||
def dx(self) -> float:
|
||||
"""Physical pixel size in x (metres)."""
|
||||
return self.xreal / self.xres if self.xres else 1.0
|
||||
|
||||
@property
|
||||
def dy(self) -> float:
|
||||
"""Physical pixel size in y (metres)."""
|
||||
return self.yreal / self.yres if self.yres else 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility helpers shared across nodes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
|
||||
"""
|
||||
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
|
||||
Returns shape (H, W, 3) uint8.
|
||||
"""
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.colors as mcolors
|
||||
|
||||
data = df.data
|
||||
dmin, dmax = data.min(), data.max()
|
||||
if dmax > dmin:
|
||||
normalized = (data - dmin) / (dmax - dmin)
|
||||
else:
|
||||
normalized = np.zeros_like(data)
|
||||
|
||||
cmap = cm.get_cmap(colormap)
|
||||
rgba = cmap(normalized) # (H, W, 4) float [0,1]
|
||||
rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
return rgb
|
||||
|
||||
|
||||
def image_to_uint8(image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert an IMAGE (float or uint8, 2-D or 3-D) to uint8 (H,W,3) or (H,W) for PIL.
|
||||
"""
|
||||
if image.dtype == np.uint8:
|
||||
return image
|
||||
# float — normalize to [0, 255]
|
||||
imin, imax = image.min(), image.max()
|
||||
if imax > imin:
|
||||
out = (image - imin) / (imax - imin) * 255.0
|
||||
else:
|
||||
out = np.zeros_like(image)
|
||||
return out.astype(np.uint8)
|
||||
|
||||
|
||||
def encode_preview(arr: np.ndarray) -> str:
|
||||
"""
|
||||
Encode a uint8 numpy array as a base64 data URI (PNG).
|
||||
arr: (H, W) grayscale or (H, W, 3) RGB, uint8.
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
from PIL import Image
|
||||
|
||||
img = Image.fromarray(arr)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
294
backend/execution.py
Normal file
294
backend/execution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Graph execution engine for argonode.
|
||||
|
||||
Prompt format (same as ComfyUI):
|
||||
{
|
||||
"node_id": {
|
||||
"class_type": "GaussianFilter",
|
||||
"inputs": {
|
||||
"field": ["upstream_node_id", 0], # link: [src_id, output_slot]
|
||||
"sigma": 2.0 # constant widget value
|
||||
}
|
||||
},
|
||||
...
|
||||
}
|
||||
|
||||
The engine:
|
||||
1. Topologically sorts nodes (Kahn's algorithm).
|
||||
2. Resolves input links to actual Python objects from earlier outputs.
|
||||
3. Calls each node's FUNCTION method.
|
||||
4. Emits progress callbacks after each node.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.node_registry import NODE_CLASS_MAPPINGS
|
||||
|
||||
|
||||
def _is_link(value: Any) -> bool:
|
||||
"""A value is a link if it's a [node_id_str, slot_int] pair."""
|
||||
return (
|
||||
isinstance(value, (list, tuple))
|
||||
and len(value) == 2
|
||||
and isinstance(value[0], str)
|
||||
and isinstance(value[1], int)
|
||||
)
|
||||
|
||||
|
||||
class ExecutionEngine:
|
||||
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
|
||||
|
||||
def execute(
|
||||
self,
|
||||
prompt: dict[str, dict],
|
||||
on_node_start: Callable[[str], None] | None = None,
|
||||
on_node_done: Callable[[str], None] | None = None,
|
||||
on_preview: Callable[[str, str], None] | None = None,
|
||||
on_table: Callable[[str, list], None] | None = None,
|
||||
on_mesh: Callable[[str, dict], None] | None = None,
|
||||
on_overlay: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, tuple]:
|
||||
"""
|
||||
Execute the workflow described by `prompt`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prompt : workflow dict (node_id → {class_type, inputs})
|
||||
on_node_start : called with node_id just before a node executes
|
||||
on_node_done : called with node_id just after a node executes
|
||||
on_preview : called with (node_id, data_uri) when a display node runs
|
||||
on_table : called with (node_id, table_list) when PrintTable runs
|
||||
on_overlay : called with (node_id, data_uri) for interactive overlays
|
||||
|
||||
Returns
|
||||
-------
|
||||
node_outputs : {node_id → tuple-of-outputs} for every executed node
|
||||
"""
|
||||
order = self._topological_sort(prompt)
|
||||
node_outputs: dict[str, tuple] = {}
|
||||
|
||||
# Inject display callbacks before execution
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
|
||||
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
class_name = node_def["class_type"]
|
||||
|
||||
if class_name not in NODE_CLASS_MAPPINGS:
|
||||
raise ValueError(f"Unknown node type: '{class_name}'")
|
||||
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
raw_inputs = node_def.get("inputs", {})
|
||||
inputs = self._resolve_inputs(raw_inputs, node_outputs)
|
||||
|
||||
# Let display nodes know their node_id so they can tag WS messages
|
||||
self._set_node_id_on_display(cls, node_id)
|
||||
|
||||
if on_node_start:
|
||||
on_node_start(node_id)
|
||||
|
||||
instance = cls()
|
||||
func = getattr(instance, cls.FUNCTION)
|
||||
result = func(**inputs)
|
||||
|
||||
# Nodes must return a tuple; coerce single values just in case
|
||||
if not isinstance(result, tuple):
|
||||
result = (result,)
|
||||
|
||||
node_outputs[node_id] = result
|
||||
|
||||
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||
# IMAGE, or TABLE output so every node shows its result.
|
||||
if on_preview or on_table:
|
||||
self._auto_preview(cls, node_id, result, on_preview, on_table)
|
||||
|
||||
if on_node_done:
|
||||
on_node_done(node_id)
|
||||
|
||||
return node_outputs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _topological_sort(self, prompt: dict) -> list[str]:
|
||||
"""Kahn's algorithm — returns node IDs in dependency order."""
|
||||
in_degree: dict[str, int] = {nid: 0 for nid in prompt}
|
||||
dependents: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for node_id, node_def in prompt.items():
|
||||
for value in node_def.get("inputs", {}).values():
|
||||
if _is_link(value):
|
||||
src_id = value[0]
|
||||
if src_id in prompt:
|
||||
in_degree[node_id] += 1
|
||||
dependents[src_id].append(node_id)
|
||||
|
||||
queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
|
||||
order: list[str] = []
|
||||
|
||||
while queue:
|
||||
nid = queue.popleft()
|
||||
order.append(nid)
|
||||
for dep in dependents[nid]:
|
||||
in_degree[dep] -= 1
|
||||
if in_degree[dep] == 0:
|
||||
queue.append(dep)
|
||||
|
||||
if len(order) != len(prompt):
|
||||
raise ValueError("Cycle detected in workflow graph — cannot execute.")
|
||||
|
||||
return order
|
||||
|
||||
def _resolve_inputs(
|
||||
self,
|
||||
raw_inputs: dict[str, Any],
|
||||
node_outputs: dict[str, tuple],
|
||||
) -> dict[str, Any]:
|
||||
"""Replace [src_id, slot] links with actual output values."""
|
||||
resolved = {}
|
||||
for key, value in raw_inputs.items():
|
||||
if _is_link(value):
|
||||
src_id, slot = value[0], int(value[1])
|
||||
if src_id not in node_outputs:
|
||||
raise KeyError(
|
||||
f"Node '{src_id}' has no output yet — dependency ordering bug?"
|
||||
)
|
||||
outputs = node_outputs[src_id]
|
||||
if slot >= len(outputs):
|
||||
raise IndexError(
|
||||
f"Node '{src_id}' only has {len(outputs)} outputs, "
|
||||
f"but slot {slot} was requested."
|
||||
)
|
||||
resolved[key] = outputs[slot]
|
||||
else:
|
||||
resolved[key] = value
|
||||
return resolved
|
||||
|
||||
def _inject_display_callbacks(
|
||||
self,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
on_mesh: Callable | None = None,
|
||||
on_overlay: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection
|
||||
from backend.nodes.io import SaveImage
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
View3D._broadcast_mesh_fn = on_mesh
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
SaveImage._broadcast_preview = (
|
||||
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
|
||||
)
|
||||
|
||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||
"""Inform display nodes of their current node_id for WS tagging."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection
|
||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
self,
|
||||
cls: type,
|
||||
node_id: str,
|
||||
result: tuple,
|
||||
on_preview: Callable | None,
|
||||
on_table: Callable | None,
|
||||
) -> None:
|
||||
"""
|
||||
After every node executes, inspect its outputs and broadcast
|
||||
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
|
||||
"""
|
||||
import numpy as np
|
||||
from backend.data_types import (
|
||||
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
|
||||
)
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
for slot, type_name in enumerate(return_types):
|
||||
if slot >= len(result):
|
||||
break
|
||||
value = result[slot]
|
||||
|
||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||
arr = datafield_to_uint8(value, "viridis")
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return # one preview per node is enough
|
||||
|
||||
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
|
||||
arr = image_to_uint8(value)
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return
|
||||
|
||||
if type_name == "LINE" and isinstance(value, np.ndarray) and on_preview:
|
||||
preview = self._render_line_preview(cls, slot, result)
|
||||
if preview:
|
||||
on_preview(node_id, preview)
|
||||
return
|
||||
|
||||
if type_name == "TABLE" and isinstance(value, list) and on_table:
|
||||
on_table(node_id, value)
|
||||
return
|
||||
|
||||
|
||||
def _render_line_preview(
|
||||
self,
|
||||
cls: type,
|
||||
slot: int,
|
||||
result: tuple,
|
||||
) -> str | None:
|
||||
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
|
||||
import numpy as np
|
||||
import base64
|
||||
import io as _io
|
||||
|
||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||
|
||||
# Find the y-values (current slot) and try to find an x-axis
|
||||
y = result[slot]
|
||||
x = None
|
||||
# If the next output is also LINE, use it as x-axis
|
||||
if slot + 1 < len(return_types) and return_types[slot + 1] == "LINE":
|
||||
x = result[slot + 1]
|
||||
# Or if slot > 0 and previous is LINE, this slot is the x-axis — skip
|
||||
if slot > 0 and return_types[slot - 1] == "LINE":
|
||||
return None # the first LINE already plotted both
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
|
||||
fig.patch.set_facecolor("#1e293b")
|
||||
ax.set_facecolor("#0f172a")
|
||||
if x is not None:
|
||||
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
||||
else:
|
||||
ax.plot(y, color="#ff9800", linewidth=1.2)
|
||||
ax.tick_params(colors="#94a3b8", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#334155")
|
||||
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
|
||||
fig.tight_layout(pad=0.4)
|
||||
|
||||
buf = _io.BytesIO()
|
||||
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
||||
plt.close(fig)
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def new_prompt_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
47
backend/main.py
Normal file
47
backend/main.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Entry point for argonode.
|
||||
|
||||
Run with:
|
||||
python -m backend.main
|
||||
or simply:
|
||||
python backend/main.py
|
||||
from the argonode/ directory.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Allow running as `python backend/main.py` from the project root
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from aiohttp import web
|
||||
from backend.server import create_app
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
HOST = "127.0.0.1"
|
||||
PORT = 8188
|
||||
|
||||
|
||||
def main() -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
app = create_app(loop)
|
||||
|
||||
log.info("=" * 60)
|
||||
log.info(" Argonode — Node-based image analysis")
|
||||
log.info(" Open your browser at http://%s:%d", HOST, PORT)
|
||||
log.info("=" * 60)
|
||||
|
||||
web.run_app(app, host=HOST, port=PORT, loop=loop, access_log=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
56
backend/node_registry.py
Normal file
56
backend/node_registry.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Node registry for argonode.
|
||||
|
||||
Nodes are plain Python classes decorated with @register_node.
|
||||
NODE_CLASS_MAPPINGS is the single source of truth consumed by
|
||||
the execution engine and the /nodes REST endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
NODE_CLASS_MAPPINGS: dict[str, type] = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
|
||||
|
||||
|
||||
def register_node(display_name: str | None = None):
|
||||
"""
|
||||
Class decorator that registers a node class into NODE_CLASS_MAPPINGS.
|
||||
|
||||
Usage:
|
||||
@register_node(display_name="Gaussian Filter")
|
||||
class GaussianFilter:
|
||||
...
|
||||
"""
|
||||
def decorator(cls: type) -> type:
|
||||
name = cls.__name__
|
||||
NODE_CLASS_MAPPINGS[name] = cls
|
||||
NODE_DISPLAY_NAME_MAPPINGS[name] = display_name or name
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
|
||||
def get_node_info(class_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
Return a JSON-serialisable dict describing a node — consumed by GET /nodes.
|
||||
Shape is compatible with what LiteGraph.js expects from the frontend.
|
||||
"""
|
||||
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||
input_types: dict = cls.INPUT_TYPES()
|
||||
|
||||
return {
|
||||
"name": class_name,
|
||||
"display_name": NODE_DISPLAY_NAME_MAPPINGS.get(class_name, class_name),
|
||||
"category": getattr(cls, "CATEGORY", "uncategorized"),
|
||||
"input": input_types,
|
||||
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
|
||||
"output": list(cls.RETURN_TYPES),
|
||||
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
|
||||
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
||||
"description": getattr(cls, "DESCRIPTION", ""),
|
||||
}
|
||||
|
||||
|
||||
def get_all_node_info() -> dict[str, dict[str, Any]]:
|
||||
"""Return info dicts for every registered node."""
|
||||
return {name: get_node_info(name) for name in NODE_CLASS_MAPPINGS}
|
||||
2
backend/nodes/__init__.py
Normal file
2
backend/nodes/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Import all node modules to trigger @register_node decorators.
|
||||
from . import io, filters, level, analysis, grains, display
|
||||
471
backend/nodes/analysis.py
Normal file
471
backend/nodes/analysis.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Analysis nodes — statistics, histograms, FFT, cross sections.
|
||||
|
||||
Gwyddion equivalents:
|
||||
StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h)
|
||||
HeightHistogram → DH (height distribution), gwy_data_field_dh
|
||||
FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf
|
||||
CrossSection → gwy_data_field_get_profile (libprocess/datafield.c)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StatisticsNode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Statistics")
|
||||
class StatisticsNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("stats",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
|
||||
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
|
||||
)
|
||||
|
||||
def process(self, field: DataField) -> tuple:
|
||||
d = field.data
|
||||
mean = float(d.mean())
|
||||
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
|
||||
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
|
||||
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
|
||||
|
||||
table = [
|
||||
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
|
||||
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
|
||||
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
|
||||
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
|
||||
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
|
||||
{"quantity": "skewness", "value": skewness, "unit": ""},
|
||||
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
|
||||
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
|
||||
]
|
||||
return (table,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HeightHistogram
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Height Histogram")
|
||||
class HeightHistogram:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LINE", "LINE")
|
||||
RETURN_NAMES = ("counts", "bin_centers")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute the height distribution histogram (DH). "
|
||||
"Equivalent to gwy_data_field_dh."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, n_bins: int) -> tuple:
|
||||
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
|
||||
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
|
||||
return (counts.astype(np.float64), bin_centers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FFT2D
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="2D FFT")
|
||||
class FFT2D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"windowing": (["hann", "hamming", "blackman", "none"],),
|
||||
"level": (["mean", "plane", "none"],),
|
||||
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("spectrum",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
|
||||
"Output can be log magnitude, magnitude, phase, or PSDF. "
|
||||
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
# Level subtraction (Gwyddion-style, before windowing)
|
||||
if level == "mean":
|
||||
data -= data.mean()
|
||||
elif level == "plane":
|
||||
# Fit and subtract a plane: z = a + b*x + c*y
|
||||
yy, xx = np.mgrid[0:yres, 0:xres]
|
||||
xx_f = xx.ravel().astype(np.float64)
|
||||
yy_f = yy.ravel().astype(np.float64)
|
||||
zz_f = data.ravel()
|
||||
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
|
||||
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
|
||||
data -= plane
|
||||
|
||||
# Windowing (Gwyddion uses (i+0.5)/n centred formulation)
|
||||
if windowing != "none":
|
||||
t_y = (np.arange(yres) + 0.5) / yres
|
||||
t_x = (np.arange(xres) + 0.5) / xres
|
||||
if windowing == "hann":
|
||||
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
|
||||
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
|
||||
elif windowing == "hamming":
|
||||
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
|
||||
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
|
||||
elif windowing == "blackman":
|
||||
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
|
||||
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
|
||||
else:
|
||||
wy = np.ones(yres)
|
||||
wx = np.ones(xres)
|
||||
data *= np.outer(wy, wx)
|
||||
|
||||
# 2D FFT, shifted so DC is at centre
|
||||
F = np.fft.fftshift(np.fft.fft2(data))
|
||||
n = xres * yres
|
||||
|
||||
if output == "log_magnitude":
|
||||
mag = np.abs(F)
|
||||
# Log scale with floor to avoid log(0)
|
||||
result = np.log1p(mag)
|
||||
elif output == "magnitude":
|
||||
result = np.abs(F)
|
||||
elif output == "phase":
|
||||
result = np.angle(F)
|
||||
elif output == "psdf":
|
||||
# Gwyddion-equivalent PSDF: |F|^2 * dx * dy / (n * 4π²)
|
||||
dx = field.xreal / xres
|
||||
dy = field.yreal / yres
|
||||
result = (np.abs(F) ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
|
||||
else:
|
||||
result = np.abs(F)
|
||||
|
||||
# Calibrate the output field in spatial-frequency units
|
||||
if output == "psdf":
|
||||
# Gwyddion uses angular frequency: 2π/dx, 2π/dy
|
||||
freq_xreal = 2.0 * np.pi * xres / field.xreal
|
||||
freq_yreal = 2.0 * np.pi * yres / field.yreal
|
||||
z_unit = f"({field.si_unit_z})^2 m^2"
|
||||
else:
|
||||
freq_xreal = xres / field.xreal
|
||||
freq_yreal = yres / field.yreal
|
||||
z_unit = field.si_unit_z
|
||||
|
||||
out_field = DataField(
|
||||
data=result,
|
||||
xreal=freq_xreal,
|
||||
yreal=freq_yreal,
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=z_unit,
|
||||
domain="frequency",
|
||||
)
|
||||
return (out_field,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CrossSection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extend_to_edges(x1, y1, x2, y2):
|
||||
"""
|
||||
Extend the line through (x1,y1)-(x2,y2) to the boundaries of [0,1]x[0,1].
|
||||
Returns the two intersection points (clipped to the unit square).
|
||||
"""
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
|
||||
# Collect parametric t values where line hits each boundary
|
||||
t_candidates = []
|
||||
if abs(dx) > 1e-12:
|
||||
for bx in (0.0, 1.0):
|
||||
t = (bx - x1) / dx
|
||||
y_at_t = y1 + t * dy
|
||||
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
|
||||
t_candidates.append(t)
|
||||
if abs(dy) > 1e-12:
|
||||
for by in (0.0, 1.0):
|
||||
t = (by - y1) / dy
|
||||
x_at_t = x1 + t * dx
|
||||
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
|
||||
t_candidates.append(t)
|
||||
|
||||
if len(t_candidates) < 2:
|
||||
return x1, y1, x2, y2
|
||||
|
||||
t_min = min(t_candidates)
|
||||
t_max = max(t_candidates)
|
||||
|
||||
return (
|
||||
np.clip(x1 + t_min * dx, 0, 1),
|
||||
np.clip(y1 + t_min * dy, 0, 1),
|
||||
np.clip(x1 + t_max * dx, 0, 1),
|
||||
np.clip(y1 + t_max * dy, 0, 1),
|
||||
)
|
||||
|
||||
|
||||
@register_node(display_name="Cross Section")
|
||||
class CrossSection:
|
||||
"""Extract a 1-D height profile along an arbitrary line across the image."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||
"extend": (["none", "to_edges"],),
|
||||
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"point_a": ("COORD",),
|
||||
"point_b": ("COORD",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LINE",)
|
||||
RETURN_NAMES = ("profile",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Extract a cross-section profile along a line between two points. "
|
||||
"Drag the markers on the image to set the line endpoints. "
|
||||
"Equivalent to gwy_data_field_get_profile."
|
||||
)
|
||||
|
||||
_broadcast_overlay_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(
|
||||
self, field: DataField,
|
||||
x1: float, y1: float, x2: float, y2: float,
|
||||
extend: str, n_samples: int,
|
||||
point_a=None, point_b=None,
|
||||
) -> tuple:
|
||||
from scipy.ndimage import map_coordinates
|
||||
import io, base64
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
# COORD inputs override widget values
|
||||
if point_a is not None:
|
||||
x1, y1 = float(point_a[0]), float(point_a[1])
|
||||
if point_b is not None:
|
||||
x2, y2 = float(point_b[0]), float(point_b[1])
|
||||
|
||||
# Remember marker positions (before extend)
|
||||
marker_x1, marker_y1 = float(x1), float(y1)
|
||||
marker_x2, marker_y2 = float(x2), float(y2)
|
||||
|
||||
xres, yres = field.xres, field.yres
|
||||
|
||||
if extend == "to_edges":
|
||||
x1, y1, x2, y2 = _extend_to_edges(
|
||||
float(x1), float(y1), float(x2), float(y2),
|
||||
)
|
||||
|
||||
# Convert fractional [0,1] to pixel indices [0, res-1]
|
||||
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
|
||||
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
|
||||
|
||||
# Number of sample points
|
||||
line_len_px = np.hypot(px2 - px1, py2 - py1)
|
||||
if n_samples <= 0:
|
||||
n_samples = max(2, int(np.ceil(line_len_px)))
|
||||
|
||||
# Sample coordinates along the line
|
||||
t = np.linspace(0, 1, n_samples)
|
||||
coords_y = py1 + t * (py2 - py1)
|
||||
coords_x = px1 + t * (px2 - px1)
|
||||
|
||||
# Interpolate values along the line (cubic spline)
|
||||
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
|
||||
|
||||
# Broadcast overlay image with marker positions
|
||||
if CrossSection._broadcast_overlay_fn is not None:
|
||||
fig = Figure(figsize=(3, 3), dpi=100)
|
||||
ax = fig.add_axes([0, 0, 1, 1])
|
||||
ax.imshow(field.data, cmap="viridis", aspect="auto")
|
||||
ax.axis("off")
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
||||
buf.seek(0)
|
||||
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
|
||||
|
||||
CrossSection._broadcast_overlay_fn(
|
||||
CrossSection._current_node_id,
|
||||
{
|
||||
"image": image_uri,
|
||||
"x1": marker_x1, "y1": marker_y1,
|
||||
"x2": marker_x2, "y2": marker_y2,
|
||||
"a_locked": point_a is not None,
|
||||
"b_locked": point_b is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return (profile.astype(np.float64),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LineMath — single scalar measurement from a LINE profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_rq(d):
|
||||
"""RMS of deviations from mean."""
|
||||
return float(np.sqrt(np.mean(d * d)))
|
||||
|
||||
# Registry: name → (function(z) → float, unit_label)
|
||||
# All functions receive the raw 1-D profile as float64.
|
||||
LINE_OPS: dict[str, tuple] = {}
|
||||
|
||||
|
||||
def _line_op(name, unit=""):
|
||||
"""Decorator to register a LINE operation."""
|
||||
def decorator(fn):
|
||||
LINE_OPS[name] = (fn, unit)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
|
||||
# ── Basic statistics ──────────────────────────────────────────────────────
|
||||
|
||||
@_line_op("min")
|
||||
def _op_min(z):
|
||||
return float(z.min())
|
||||
|
||||
@_line_op("max")
|
||||
def _op_max(z):
|
||||
return float(z.max())
|
||||
|
||||
@_line_op("mean")
|
||||
def _op_mean(z):
|
||||
return float(z.mean())
|
||||
|
||||
@_line_op("median")
|
||||
def _op_median(z):
|
||||
return float(np.median(z))
|
||||
|
||||
@_line_op("sum")
|
||||
def _op_sum(z):
|
||||
return float(z.sum())
|
||||
|
||||
@_line_op("range")
|
||||
def _op_range(z):
|
||||
return float(z.max() - z.min())
|
||||
|
||||
@_line_op("length", unit="pts")
|
||||
def _op_length(z):
|
||||
return float(len(z))
|
||||
|
||||
@_line_op("rms")
|
||||
def _op_rms(z):
|
||||
return float(np.sqrt(np.mean(z * z)))
|
||||
|
||||
|
||||
# ── Roughness parameters ──────────────────────────
|
||||
|
||||
@_line_op("Ra")
|
||||
def _op_ra(z):
|
||||
return float(np.mean(np.abs(z - z.mean())))
|
||||
|
||||
@_line_op("Rq")
|
||||
def _op_rq(z):
|
||||
d = z - z.mean()
|
||||
return _safe_rq(d)
|
||||
|
||||
@_line_op("Rsk")
|
||||
def _op_rsk(z):
|
||||
d = z - z.mean()
|
||||
rq = _safe_rq(d)
|
||||
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
|
||||
|
||||
@_line_op("Rku")
|
||||
def _op_rku(z):
|
||||
d = z - z.mean()
|
||||
rq = _safe_rq(d)
|
||||
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
|
||||
|
||||
@_line_op("Rp")
|
||||
def _op_rp(z):
|
||||
return float((z - z.mean()).max())
|
||||
|
||||
@_line_op("Rv")
|
||||
def _op_rv(z):
|
||||
return float(-(z - z.mean()).min())
|
||||
|
||||
@_line_op("Rt")
|
||||
def _op_rt(z):
|
||||
d = z - z.mean()
|
||||
return float(d.max() - d.min())
|
||||
|
||||
@_line_op("Dq")
|
||||
def _op_dq(z):
|
||||
"""RMS slope (first derivative RMS)."""
|
||||
dz = np.diff(z)
|
||||
return float(np.sqrt(np.mean(dz * dz)))
|
||||
|
||||
@_line_op("Da")
|
||||
def _op_da(z):
|
||||
"""Mean absolute slope."""
|
||||
return float(np.mean(np.abs(np.diff(z))))
|
||||
|
||||
|
||||
@register_node(display_name="Line Math")
|
||||
class LineMath:
|
||||
"""Compute a single scalar value from a LINE profile."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"line": ("LINE",),
|
||||
"operation": (list(LINE_OPS.keys()),),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("result",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute a single scalar measurement from a LINE profile. "
|
||||
"Includes basic stats and Gwyddion-convention roughness parameters."
|
||||
)
|
||||
|
||||
def process(self, line, operation: str) -> tuple:
|
||||
z = np.asarray(line, dtype=np.float64).ravel()
|
||||
fn, unit = LINE_OPS[operation]
|
||||
value = fn(z)
|
||||
table = [{"quantity": operation, "value": value, "unit": unit}]
|
||||
return (table,)
|
||||
165
backend/nodes/display.py
Normal file
165
backend/nodes/display.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Display / output nodes.
|
||||
|
||||
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
|
||||
connect whichever type you have. The server injects _broadcast_fn
|
||||
before execution begins.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
|
||||
|
||||
|
||||
@register_node(display_name="Preview")
|
||||
class PreviewImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
|
||||
},
|
||||
"optional": {
|
||||
"image": ("IMAGE",),
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "preview"
|
||||
CATEGORY = "display"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
|
||||
|
||||
_broadcast_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
|
||||
# Prefer field if both are connected; accept whichever is provided
|
||||
if field is not None:
|
||||
arr_u8 = datafield_to_uint8(field, colormap)
|
||||
elif image is not None:
|
||||
if image.dtype != np.uint8:
|
||||
imin, imax = image.min(), image.max()
|
||||
if imax > imin:
|
||||
norm = (image - imin) / (imax - imin)
|
||||
else:
|
||||
norm = np.zeros_like(image)
|
||||
arr_u8 = (norm * 255).astype(np.uint8)
|
||||
else:
|
||||
arr_u8 = image
|
||||
|
||||
if arr_u8.ndim == 2 and colormap != "gray":
|
||||
import matplotlib.cm as cm
|
||||
cmap = cm.get_cmap(colormap)
|
||||
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
|
||||
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
else:
|
||||
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
|
||||
|
||||
data_uri = encode_preview(arr_u8)
|
||||
|
||||
if PreviewImage._broadcast_fn is not None:
|
||||
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
|
||||
|
||||
return ()
|
||||
|
||||
|
||||
@register_node(display_name="3D View")
|
||||
class View3D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
|
||||
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
|
||||
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "render"
|
||||
CATEGORY = "display"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = (
|
||||
"Interactive 3D surface view of a DATA_FIELD. "
|
||||
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
|
||||
)
|
||||
|
||||
_broadcast_mesh_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def render(
|
||||
self, field: DataField,
|
||||
colormap: str, z_scale: float, resolution: int,
|
||||
) -> tuple:
|
||||
import matplotlib.cm as cm
|
||||
import base64
|
||||
|
||||
data = field.data
|
||||
yres, xres = data.shape
|
||||
|
||||
# Downsample if larger than resolution
|
||||
step_y = max(1, yres // resolution)
|
||||
step_x = max(1, xres // resolution)
|
||||
z = data[::step_y, ::step_x].astype(np.float32)
|
||||
ny, nx = z.shape
|
||||
|
||||
# Normalize for colormap
|
||||
zmin, zmax = float(z.min()), float(z.max())
|
||||
if zmax > zmin:
|
||||
z_norm = (z - zmin) / (zmax - zmin)
|
||||
else:
|
||||
z_norm = np.zeros_like(z)
|
||||
|
||||
cmap = cm.get_cmap(colormap)
|
||||
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
|
||||
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
|
||||
# Base64-encode arrays for efficient WS transport
|
||||
z_b64 = base64.b64encode(z.tobytes()).decode()
|
||||
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
|
||||
|
||||
mesh_data = {
|
||||
"width": nx,
|
||||
"height": ny,
|
||||
"z_data": z_b64,
|
||||
"colors": colors_b64,
|
||||
"z_min": zmin,
|
||||
"z_max": zmax,
|
||||
"z_scale": float(z_scale),
|
||||
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
|
||||
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
|
||||
}
|
||||
|
||||
if View3D._broadcast_mesh_fn is not None:
|
||||
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
|
||||
|
||||
return ()
|
||||
|
||||
|
||||
@register_node(display_name="Print Table")
|
||||
class PrintTable:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"table": ("TABLE",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "print_table"
|
||||
CATEGORY = "display"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Send a TABLE to the browser as a WebSocket message for display."
|
||||
|
||||
_broadcast_table_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def print_table(self, table: list) -> tuple:
|
||||
if PrintTable._broadcast_table_fn is not None:
|
||||
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
|
||||
return ()
|
||||
115
backend/nodes/filters.py
Normal file
115
backend/nodes/filters.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Filter nodes — Gwyddion-equivalent image filters.
|
||||
|
||||
Gwyddion equivalents:
|
||||
GaussianFilter → gwy_data_field_filter_gaussian
|
||||
MedianFilter → gwy_data_field_filter_median
|
||||
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GaussianFilter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Gaussian Filter")
|
||||
class GaussianFilter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("filtered",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
|
||||
|
||||
def process(self, field: DataField, sigma: float) -> tuple:
|
||||
from scipy.ndimage import gaussian_filter
|
||||
data = gaussian_filter(field.data.copy(), sigma=float(sigma))
|
||||
return (field.replace(data=data),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MedianFilter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Median Filter")
|
||||
class MedianFilter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("filtered",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
|
||||
|
||||
def process(self, field: DataField, size: int) -> tuple:
|
||||
from scipy.ndimage import median_filter
|
||||
size = max(1, int(size))
|
||||
data = median_filter(field.data.copy(), size=size)
|
||||
return (field.replace(data=data),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EdgeDetect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Edge Detect")
|
||||
class EdgeDetect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["sobel", "prewitt", "laplacian", "log"],),
|
||||
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("edges",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "filters"
|
||||
DESCRIPTION = (
|
||||
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
|
||||
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str, sigma: float) -> tuple:
|
||||
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
|
||||
data = field.data.copy()
|
||||
|
||||
if method == "sobel":
|
||||
sx = sobel(data, axis=1)
|
||||
sy = sobel(data, axis=0)
|
||||
result = np.hypot(sx, sy)
|
||||
elif method == "prewitt":
|
||||
px = prewitt(data, axis=1)
|
||||
py = prewitt(data, axis=0)
|
||||
result = np.hypot(px, py)
|
||||
elif method == "laplacian":
|
||||
result = laplace(data)
|
||||
elif method == "log":
|
||||
result = gaussian_laplace(data, sigma=float(sigma))
|
||||
else:
|
||||
raise ValueError(f"Unknown edge detection method: {method}")
|
||||
|
||||
return (field.replace(data=result),)
|
||||
127
backend/nodes/grains.py
Normal file
127
backend/nodes/grains.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Grain/feature detection nodes.
|
||||
|
||||
Gwyddion equivalents:
|
||||
ThresholdMask → threshold.c / otsu_threshold.c
|
||||
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThresholdMask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Threshold Mask")
|
||||
class ThresholdMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["otsu", "absolute", "relative"],),
|
||||
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
|
||||
"direction": (["above", "below"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "grains"
|
||||
DESCRIPTION = (
|
||||
"Create a binary mask by thresholding data. "
|
||||
"Otsu automatically finds the optimal threshold. "
|
||||
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
|
||||
data = field.data
|
||||
|
||||
if method == "otsu":
|
||||
from skimage.filters import threshold_otsu
|
||||
t = threshold_otsu(data)
|
||||
elif method == "absolute":
|
||||
t = float(threshold)
|
||||
elif method == "relative":
|
||||
# threshold is a fraction [0, 1] of the data range
|
||||
dmin, dmax = data.min(), data.max()
|
||||
t = dmin + float(threshold) * (dmax - dmin)
|
||||
else:
|
||||
raise ValueError(f"Unknown threshold method: {method}")
|
||||
|
||||
if direction == "above":
|
||||
mask = (data >= t).astype(np.uint8) * 255
|
||||
else:
|
||||
mask = (data < t).astype(np.uint8) * 255
|
||||
|
||||
return (mask,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GrainAnalysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Grain Analysis")
|
||||
class GrainAnalysis:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"mask": ("IMAGE",),
|
||||
"min_size": ("INT", {"default": 10, "min": 1, "max": 100000, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TABLE",)
|
||||
RETURN_NAMES = ("grain_stats",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "grains"
|
||||
DESCRIPTION = (
|
||||
"Label connected grain regions in a binary mask and compute per-grain statistics: "
|
||||
"area, equivalent diameter, mean/max height, bounding box. "
|
||||
"Equivalent to gwy_data_field_grains_get_values."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
|
||||
from scipy.ndimage import label, find_objects
|
||||
|
||||
binary = (mask > 127).astype(np.int32)
|
||||
labeled, n_grains = label(binary)
|
||||
|
||||
pixel_area = field.dx * field.dy # m^2 per pixel
|
||||
|
||||
rows = []
|
||||
for grain_id in range(1, n_grains + 1):
|
||||
grain_pixels = labeled == grain_id
|
||||
area_px = int(grain_pixels.sum())
|
||||
if area_px < min_size:
|
||||
continue
|
||||
|
||||
area_m2 = area_px * pixel_area
|
||||
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
|
||||
|
||||
heights = field.data[grain_pixels]
|
||||
mean_h = float(heights.mean())
|
||||
max_h = float(heights.max())
|
||||
|
||||
# Bounding box
|
||||
ys, xs = np.where(grain_pixels)
|
||||
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
|
||||
|
||||
rows.append({
|
||||
"grain_id": grain_id,
|
||||
"area_px": area_px,
|
||||
"area_m2": area_m2,
|
||||
"equiv_diam_m": equiv_diam,
|
||||
"mean_height": mean_h,
|
||||
"max_height": max_h,
|
||||
"bbox": bbox,
|
||||
})
|
||||
|
||||
return (rows,)
|
||||
277
backend/nodes/io.py
Normal file
277
backend/nodes/io.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
I/O nodes: load and save images and SPM data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, encode_preview, image_to_uint8
|
||||
|
||||
# Resolved at server startup so nodes know where to look
|
||||
INPUT_DIR = Path(__file__).parent.parent.parent / "input"
|
||||
OUTPUT_DIR = Path(__file__).parent.parent.parent / "output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadImage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load Image")
|
||||
class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
|
||||
RETURN_NAMES = ("image", "field")
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
|
||||
|
||||
def load(self, filename: str):
|
||||
# Accept absolute paths or filenames relative to input/
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = INPUT_DIR / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
if ext in (".npy",):
|
||||
arr = np.load(str(path)).astype(np.float64)
|
||||
elif ext in (".npz",):
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
arr = npz[key].astype(np.float64)
|
||||
else:
|
||||
from PIL import Image
|
||||
img = Image.open(str(path))
|
||||
arr = np.array(img)
|
||||
if arr.dtype != np.uint8:
|
||||
arr = arr.astype(np.float64)
|
||||
|
||||
# Convert to float64 grayscale for the DATA_FIELD output
|
||||
if arr.ndim == 3:
|
||||
gray = np.mean(arr.astype(np.float64), axis=2)
|
||||
else:
|
||||
gray = arr.astype(np.float64)
|
||||
|
||||
field = DataField(data=gray)
|
||||
return (arr, field)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadSPM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load SPM File")
|
||||
class LoadSPM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
"channel": ("STRING", {"default": "Z"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("field",)
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
|
||||
|
||||
def load(self, filename: str, channel: str = "Z"):
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = INPUT_DIR / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
if ext == ".gwy":
|
||||
return (self._load_gwy(path, channel),)
|
||||
elif ext == ".sxm":
|
||||
return (self._load_sxm(path, channel),)
|
||||
elif ext in (".ibw",):
|
||||
return (self._load_ibw(path),)
|
||||
elif ext in (".npy",):
|
||||
data = np.load(str(path)).astype(np.float64)
|
||||
return (DataField(data=data),)
|
||||
elif ext in (".npz",):
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
return (DataField(data=npz[key].astype(np.float64)),)
|
||||
else:
|
||||
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
|
||||
|
||||
def _load_gwy(self, path: Path, channel: str) -> DataField:
|
||||
try:
|
||||
import gwyfile
|
||||
except ImportError:
|
||||
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
|
||||
|
||||
obj = gwyfile.load(str(path))
|
||||
channels = gwyfile.util.get_datafields(obj)
|
||||
if not channels:
|
||||
raise ValueError(f"No data channels found in {path.name}")
|
||||
|
||||
# Try requested channel name, fall back to first available
|
||||
ch = None
|
||||
for key, df in channels.items():
|
||||
if channel.lower() in key.lower():
|
||||
ch = df
|
||||
break
|
||||
if ch is None:
|
||||
ch = next(iter(channels.values()))
|
||||
|
||||
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
|
||||
return DataField(
|
||||
data=data,
|
||||
xreal=float(ch.xreal),
|
||||
yreal=float(ch.yreal),
|
||||
xoff=float(getattr(ch, "xoff", 0.0)),
|
||||
yoff=float(getattr(ch, "yoff", 0.0)),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
def _load_sxm(self, path: Path, channel: str) -> DataField:
|
||||
try:
|
||||
import nanonispy as nap
|
||||
except ImportError:
|
||||
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
|
||||
|
||||
sxm = nap.read.Scan(str(path))
|
||||
signals = sxm.signals
|
||||
|
||||
# Pick channel
|
||||
ch_key = None
|
||||
for key in signals:
|
||||
if channel.upper() in key.upper():
|
||||
ch_key = key
|
||||
break
|
||||
if ch_key is None:
|
||||
ch_key = next(iter(signals))
|
||||
|
||||
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
|
||||
data = np.asarray(data, dtype=np.float64)
|
||||
if data.ndim != 2:
|
||||
data = data.reshape(data.shape[-2], data.shape[-1])
|
||||
|
||||
header = sxm.header
|
||||
scan_range = header.get("scan_range", [1e-6, 1e-6])
|
||||
return DataField(
|
||||
data=data,
|
||||
xreal=float(scan_range[0]),
|
||||
yreal=float(scan_range[1]),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
def _load_ibw(self, path: Path) -> DataField:
|
||||
try:
|
||||
import igor.igorpy as igorpy
|
||||
wave = igorpy.load(str(path))
|
||||
data = wave.wave["wData"].squeeze().astype(np.float64)
|
||||
except ImportError:
|
||||
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
|
||||
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
elif data.ndim != 2:
|
||||
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
|
||||
|
||||
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coordinate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Coordinate")
|
||||
class Coordinate:
|
||||
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("COORD",)
|
||||
RETURN_NAMES = ("point",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
|
||||
|
||||
def process(self, x: float, y: float) -> tuple:
|
||||
return ((float(x), float(y)),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SaveImage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Save Image")
|
||||
class SaveImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"filename_prefix": ("STRING", {"default": "output"}),
|
||||
"format": (["PNG", "TIFF", "NPY"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
CATEGORY = "io"
|
||||
OUTPUT_NODE = True
|
||||
DESCRIPTION = "Save an image or array to the output folder."
|
||||
|
||||
# Injected by server.py before execution begins
|
||||
_broadcast_preview = None
|
||||
|
||||
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Find next available filename
|
||||
idx = 1
|
||||
while True:
|
||||
name = f"{filename_prefix}_{idx:04d}"
|
||||
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
|
||||
if not candidate.exists():
|
||||
break
|
||||
idx += 1
|
||||
|
||||
if format == "NPY":
|
||||
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
|
||||
else:
|
||||
from PIL import Image
|
||||
arr = image_to_uint8(image)
|
||||
if arr.ndim == 2:
|
||||
pil_img = Image.fromarray(arr, mode="L")
|
||||
else:
|
||||
pil_img = Image.fromarray(arr, mode="RGB")
|
||||
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
|
||||
|
||||
# Emit preview over WebSocket if callback is set
|
||||
if SaveImage._broadcast_preview is not None:
|
||||
arr_u8 = image_to_uint8(image)
|
||||
data_uri = encode_preview(arr_u8)
|
||||
SaveImage._broadcast_preview(data_uri)
|
||||
|
||||
return ()
|
||||
150
backend/nodes/level.py
Normal file
150
backend/nodes/level.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Leveling nodes — background removal and zero correction.
|
||||
|
||||
Gwyddion equivalents:
|
||||
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
|
||||
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
|
||||
FixZero → fix_zero in level.c
|
||||
|
||||
Plane-fit algorithm follows Gwyddion's level.h definition:
|
||||
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PlaneLevelField
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Plane Level")
|
||||
class PlaneLevelField:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("leveled",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "level"
|
||||
DESCRIPTION = (
|
||||
"Fit and subtract a least-squares plane from the data. "
|
||||
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
|
||||
)
|
||||
|
||||
def process(self, field: DataField) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
# Normalised coordinate grids in [0, 1]
|
||||
x = np.linspace(0.0, 1.0, xres)
|
||||
y = np.linspace(0.0, 1.0, yres)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
# Design matrix: [1, x, y] shape (N, 3)
|
||||
A = np.column_stack([
|
||||
np.ones(xres * yres),
|
||||
xx.ravel(),
|
||||
yy.ravel(),
|
||||
])
|
||||
z = data.ravel()
|
||||
|
||||
# Least-squares: solve A @ [pa, pbx, pby] = z
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||
pa, pbx, pby = coeffs
|
||||
|
||||
plane = (pa + pbx * xx + pby * yy)
|
||||
return (field.replace(data=data - plane),)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PolyLevelField
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Polynomial Level")
|
||||
class PolyLevelField:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
|
||||
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
|
||||
RETURN_NAMES = ("leveled", "background")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "level"
|
||||
DESCRIPTION = (
|
||||
"Fit and subtract a polynomial background of given degree in x and y. "
|
||||
"Equivalent to gwy_data_field_fit_polynom."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
x = np.linspace(0.0, 1.0, xres)
|
||||
y = np.linspace(0.0, 1.0, yres)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
# Build Vandermonde-style design matrix with all monomials x^i * y^j
|
||||
cols = []
|
||||
for i in range(degree_x + 1):
|
||||
for j in range(degree_y + 1):
|
||||
cols.append((xx ** i * yy ** j).ravel())
|
||||
A = np.column_stack(cols)
|
||||
z = data.ravel()
|
||||
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||
|
||||
background = (A @ coeffs).reshape(yres, xres)
|
||||
leveled = data - background
|
||||
|
||||
return (field.replace(data=leveled), field.replace(data=background))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FixZero
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Fix Zero")
|
||||
class FixZero:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["min", "mean", "median"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("zeroed",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "level"
|
||||
DESCRIPTION = (
|
||||
"Shift data so that the minimum (or mean/median) is zero. "
|
||||
"Equivalent to fix_zero in Gwyddion's level.c."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, method: str) -> tuple:
|
||||
data = field.data.copy()
|
||||
if method == "min":
|
||||
data -= data.min()
|
||||
elif method == "mean":
|
||||
data -= data.mean()
|
||||
elif method == "median":
|
||||
data -= np.median(data)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
return (field.replace(data=data),)
|
||||
267
backend/server.py
Normal file
267
backend/server.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
aiohttp web server for argonode.
|
||||
|
||||
Routes
|
||||
------
|
||||
GET / → serve frontend/index.html
|
||||
GET /static/{path} → serve frontend JS/CSS
|
||||
GET /nodes → JSON dict of all registered node definitions
|
||||
POST /upload → multipart file upload to input/
|
||||
POST /prompt → submit a workflow; returns {prompt_id}
|
||||
GET /ws → WebSocket upgrade
|
||||
|
||||
WebSocket message types sent to clients
|
||||
----------------------------------------
|
||||
{"type": "execution_start", "data": {"prompt_id": "..."}}
|
||||
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
|
||||
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
|
||||
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
|
||||
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
|
||||
{"type": "execution_complete", "data": {"prompt_id": "..."}}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from aiohttp import web, WSMsgType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FRONTEND_DIR = Path(__file__).parent.parent / "frontend"
|
||||
DIST_DIR = FRONTEND_DIR / "dist"
|
||||
INPUT_DIR = Path(__file__).parent.parent / "input"
|
||||
OUTPUT_DIR = Path(__file__).parent.parent / "output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON helper — numpy scalars are not serialisable by default
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _SafeEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
import numpy as np
|
||||
if isinstance(obj, (np.integer,)):
|
||||
return int(obj)
|
||||
if isinstance(obj, (np.floating,)):
|
||||
return float(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def _dumps(obj) -> str:
|
||||
return json.dumps(obj, cls=_SafeEncoder)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Application factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
# Import nodes to trigger registration decorators
|
||||
import backend.nodes # noqa: F401
|
||||
from backend.node_registry import get_all_node_info
|
||||
from backend.execution import ExecutionEngine, new_prompt_id
|
||||
|
||||
INPUT_DIR.mkdir(exist_ok=True)
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
engine = ExecutionEngine()
|
||||
websockets: set[web.WebSocketResponse] = set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket broadcast helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def broadcast(msg: dict) -> None:
|
||||
"""Schedule a broadcast to all connected WebSocket clients."""
|
||||
payload = _dumps(msg)
|
||||
for ws in list(websockets):
|
||||
if not ws.closed:
|
||||
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
|
||||
|
||||
def on_preview(node_id: str, data_uri: str) -> None:
|
||||
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
|
||||
|
||||
def on_table(node_id: str, rows: list) -> None:
|
||||
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}})
|
||||
|
||||
def on_mesh(node_id: str, mesh_data: dict) -> None:
|
||||
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
|
||||
|
||||
def on_overlay(node_id: str, overlay_data) -> None:
|
||||
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def index(request: web.Request) -> web.Response:
|
||||
# Serve Vite build output if available, else raw frontend
|
||||
if (DIST_DIR / "index.html").exists():
|
||||
return web.FileResponse(DIST_DIR / "index.html")
|
||||
return web.FileResponse(FRONTEND_DIR / "index.html")
|
||||
|
||||
async def get_nodes(request: web.Request) -> web.Response:
|
||||
info = get_all_node_info()
|
||||
return web.Response(
|
||||
text=_dumps(info),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def list_files(request: web.Request) -> web.Response:
|
||||
"""List files in the input/ directory for the file picker widget."""
|
||||
files = sorted(
|
||||
f.name for f in INPUT_DIR.iterdir()
|
||||
if f.is_file() and not f.name.startswith(".")
|
||||
) if INPUT_DIR.exists() else []
|
||||
return web.Response(text=_dumps(files), content_type="application/json")
|
||||
|
||||
async def browse_dir(request: web.Request) -> web.Response:
|
||||
"""
|
||||
Server-side directory browser for local file picking.
|
||||
GET /browse?dir=/some/path → {parent, dirs[], files[]}
|
||||
"""
|
||||
dir_path = request.query.get("dir", str(Path.home()))
|
||||
p = Path(dir_path).expanduser().resolve()
|
||||
|
||||
if not p.is_dir():
|
||||
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
|
||||
|
||||
dirs = []
|
||||
files = []
|
||||
try:
|
||||
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
|
||||
if entry.name.startswith("."):
|
||||
continue
|
||||
if entry.is_dir():
|
||||
dirs.append(entry.name)
|
||||
elif entry.is_file():
|
||||
files.append(entry.name)
|
||||
except PermissionError:
|
||||
pass
|
||||
|
||||
return web.Response(
|
||||
text=_dumps({
|
||||
"path": str(p),
|
||||
"parent": str(p.parent) if p.parent != p else None,
|
||||
"dirs": dirs,
|
||||
"files": files,
|
||||
}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def upload_file(request: web.Request) -> web.Response:
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "file":
|
||||
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
|
||||
|
||||
filename = Path(field.filename).name # strip any path traversal
|
||||
dest = INPUT_DIR / filename
|
||||
with open(dest, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(65536)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
return web.Response(text=_dumps({"filename": filename}), content_type="application/json")
|
||||
|
||||
async def submit_prompt(request: web.Request) -> web.Response:
|
||||
body = await request.json()
|
||||
prompt = body.get("prompt")
|
||||
if not isinstance(prompt, dict) or not prompt:
|
||||
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
|
||||
|
||||
prompt_id = new_prompt_id()
|
||||
|
||||
# Run execution in a thread pool so scipy doesn't block the event loop
|
||||
async def run():
|
||||
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}})
|
||||
|
||||
def on_start(node_id: str) -> None:
|
||||
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: engine.execute(
|
||||
prompt,
|
||||
on_node_start=on_start,
|
||||
on_preview=on_preview,
|
||||
on_table=on_table,
|
||||
on_mesh=on_mesh,
|
||||
on_overlay=on_overlay,
|
||||
),
|
||||
)
|
||||
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
|
||||
except Exception as exc:
|
||||
log.exception("Execution error")
|
||||
broadcast({
|
||||
"type": "execution_error",
|
||||
"data": {"node_id": "", "message": str(exc)},
|
||||
})
|
||||
|
||||
asyncio.ensure_future(run())
|
||||
return web.Response(
|
||||
text=_dumps({"prompt_id": prompt_id}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
websockets.add(ws)
|
||||
log.info("WebSocket client connected (%d total)", len(websockets))
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
pass # clients don't need to send anything currently
|
||||
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
|
||||
break
|
||||
finally:
|
||||
websockets.discard(ws)
|
||||
log.info("WebSocket client disconnected (%d total)", len(websockets))
|
||||
return ws
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# App assembly
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
app = web.Application()
|
||||
|
||||
app.router.add_get("/", index)
|
||||
app.router.add_get("/nodes", get_nodes)
|
||||
app.router.add_get("/files", list_files)
|
||||
app.router.add_get("/browse", browse_dir)
|
||||
app.router.add_post("/upload", upload_file)
|
||||
app.router.add_post("/prompt", submit_prompt)
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
|
||||
# Serve frontend static files (Vite build or raw)
|
||||
if DIST_DIR.exists():
|
||||
app.router.add_static("/assets", DIST_DIR / "assets")
|
||||
app.router.add_static("/static", FRONTEND_DIR)
|
||||
|
||||
# CORS — allow any origin (local dev only)
|
||||
async def _cors_middleware(app_, handler):
|
||||
async def middleware(request):
|
||||
if request.method == "OPTIONS":
|
||||
return web.Response(headers={
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
})
|
||||
response = await handler(request)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
return response
|
||||
return middleware
|
||||
|
||||
app.middlewares.append(_cors_middleware)
|
||||
|
||||
return app
|
||||
Reference in New Issue
Block a user