initial commit

This commit is contained in:
2026-03-23 00:35:30 -07:00
parent 5ecc913e28
commit 87b6905fba
48 changed files with 7012 additions and 1 deletions

2
.gitignore vendored
View File

@@ -1 +1,3 @@
*__pycache__*
frontend/node_modules/
frontend/dist/

87
GWYDDION_FEATURE_GAP.md Normal file
View File

@@ -0,0 +1,87 @@
# Gwyddion Features Not Yet in Argonode
Reference for future implementation. Grouped by value to typical SPM workflows.
---
## High Value
| # | Feature | Gwyddion Source | Description |
|---|---------|---------------|-------------|
| 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. |
| 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). |
| 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. |
| 4 | Morphological Mask Ops | mask_morph.c | Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks. |
| 5 | 1D FFT Filter | fft_filter_1d.c | Bandpass/lowpass/highpass filtering of LINE profiles. |
| 6 | 2D FFT Filter | fft_filter_2d.c | Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.). |
| 7 | Autocorrelation (ACF) | acf2d.c | 2D autocorrelation function. Reveals periodic structures and correlation lengths. |
| 8 | PSDF | psdf2d.c | Radial/2D power spectral density function. Complementary to ACF for roughness characterization. |
| 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. |
| 10 | Curvature | curvature.c | Local mean/Gaussian curvature maps. Useful for feature identification. |
| 11 | Grain Distance Transform | mask_edt.c | Euclidean distance from grain boundaries. Useful for spatial distribution analysis. |
| 12 | Watershed Segmentation | grain_wshed.c | Automatic grain detection without manual threshold. More robust than simple thresholding. |
| 13 | Rotate / Flip | rotate.c, basicops.c | Basic geometric transforms (90°, arbitrary angle, mirror). |
| 14 | Crop | crop.c | Extract sub-region of a field. |
## Medium Value
| # | Feature | Gwyddion Source | Description |
|---|---------|---------------|-------------|
| 15 | Correlation / Pattern Matching | crosscor.c, maskcor.c | Find repeated features or align images via cross-correlation. |
| 16 | Slope Distribution | slope_dist.c | Angular histogram of surface slopes. Characterizes surface texture directionality. |
| 17 | Grain Filtering | grain_filter.c | Remove grains by size, height, or border contact. Refine grain masks post-detection. |
| 18 | Field Arithmetic | arithmetic.c | Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization. |
| 19 | Spot Removal | spotremove.c | Interpolate over selected point defects (dust, spikes). |
| 20 | Tip Modeling / Deconvolution | tip_blind.c, tip_model.c | Estimate tip shape from image, deconvolve to recover true surface. |
| 21 | Radial Profile | rprofile tool | Azimuthally averaged profile from a center point. Good for circular features. |
| 22 | Wavelet Transform | dwt.c, cwt.c | Discrete/continuous wavelet analysis. Multi-scale roughness decomposition. |
| 23 | Scale / Resample | scale.c, resample.c | Resize fields with interpolation. |
| 24 | Gradient | gradient.c | Compute x/y gradient magnitude maps. |
| 25 | Custom Convolution | convolution_filter.c | User-defined kernel convolution. |
| 26 | Local Contrast Enhancement | local_contrast.c | Enhance visibility of local features in images. |
## Lower Priority
| # | Feature | Gwyddion Source | Description |
|---|---------|---------------|-------------|
| 27 | Drift Correction | drift.c | Compensate for thermal/piezo drift between scan lines. |
| 28 | Affine / Perspective Correction | correct_affine.c, correct_perspective.c | Fix geometric distortions from scanner nonlinearity. |
| 29 | MFM Analysis | mfm_*.c | Magnetic force microscopy: field calculation, shift finding. |
| 30 | Lattice Measurement | measure_lattice.c | Detect and measure periodic lattice structures from ACF/FFT. |
| 31 | Hough Transform | hough.c | Detect lines and circles in images. |
| 32 | Image Stitching / Merging | merge.c, stitch.c | Combine multiple overlapping scans into one image. |
| 33 | Facet Analysis | facet_analysis.c | Orientation distribution of surface facets (stereographic projection). |
| 34 | Shape Fitting | fit-shape.c | Fit geometric primitives: sphere, paraboloid, cylinder, etc. |
| 35 | Synthetic Surface Generation | *_synth.c (~20 modules) | Generate test surfaces: FBM, noise, lattice, waves, particles, fibers, etc. |
| 36 | Entropy | entropy.c | Information entropy of height distribution. |
| 37 | Indentation Analysis | indent_analyze.c, hertz.c | Nanoindentation curve fitting (Hertz model). |
| 38 | Deconvolution | deconvolve.c | Blind/regularized deconvolution for image restoration. |
| 39 | Canny / Harris Detection | filters.c | Corner and edge feature detection beyond basic Sobel/Prewitt. |
| 40 | Kuwahara Filter | filters.c | Edge-preserving smoothing filter. |
---
## Already Implemented in Argonode
For reference, these Gwyddion equivalents are already covered:
| Argonode Node | Category | Gwyddion Equivalent |
|--------------|----------|-------------------|
| Load Image / Load SPM File | io | File import (gwy, sxm, ibw) |
| Save Image | io | File export |
| Coordinate | io | — |
| Plane Level | level | level.c |
| Polynomial Level | level | polylevel.c |
| Fix Zero | level | level.c (fix_zero) |
| Gaussian Filter | filters | filters.c (gaussian) |
| Median Filter | filters | filters.c (median) |
| Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) |
| Statistics | analysis | stats.c |
| Height Histogram | analysis | linestats.c (dh) |
| 2D FFT | analysis | fft.c |
| Cross Section | analysis | profile tool |
| Profile Roughness | analysis | roughness.c (Ra, Rq, Rsk, Rku, Rp, Rv, Rt) |
| Line Math | analysis | linestats.c |
| Threshold Mask | grains | threshold.c, otsu_threshold.c |
| Grain Analysis | grains | grain_stat.c |
| Preview / 3D View / Print Table | display | Presentation, 3D view |

0
backend/__init__.py Normal file
View File

134
backend/data_types.py Normal file
View File

@@ -0,0 +1,134 @@
"""
Core data types for argonode.
DataField mirrors Gwyddion's GwyDataField structure:
xres, yres pixel dimensions
xreal, yreal physical dimensions in metres
xoff, yoff position offset in metres
si_unit_xy lateral unit string (e.g. "m", "nm")
si_unit_z height/value unit string (e.g. "m", "V", "A")
domain "spatial" or "frequency" (set by FFT nodes)
"""
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
@dataclass
class DataField:
data: np.ndarray # shape (yres, xres), dtype float64
xres: int = 0
yres: int = 0
xreal: float = 1e-6 # physical width in metres
yreal: float = 1e-6 # physical height in metres
xoff: float = 0.0
yoff: float = 0.0
si_unit_xy: str = "m"
si_unit_z: str = "m"
domain: str = "spatial" # "spatial" or "frequency"
def __post_init__(self) -> None:
self.data = np.asarray(self.data, dtype=np.float64)
if self.data.ndim != 2:
raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}")
self.yres, self.xres = self.data.shape
def copy(self) -> "DataField":
"""Return a deep copy with independent data array."""
return DataField(
data=self.data.copy(),
xres=self.xres,
yres=self.yres,
xreal=self.xreal,
yreal=self.yreal,
xoff=self.xoff,
yoff=self.yoff,
si_unit_xy=self.si_unit_xy,
si_unit_z=self.si_unit_z,
domain=self.domain,
)
def replace(self, **kwargs) -> "DataField":
"""Return a copy with selected fields replaced. data is deep-copied unless provided."""
base = {
"data": self.data.copy(),
"xres": self.xres,
"yres": self.yres,
"xreal": self.xreal,
"yreal": self.yreal,
"xoff": self.xoff,
"yoff": self.yoff,
"si_unit_xy": self.si_unit_xy,
"si_unit_z": self.si_unit_z,
"domain": self.domain,
}
base.update(kwargs)
return DataField(**base)
@property
def dx(self) -> float:
"""Physical pixel size in x (metres)."""
return self.xreal / self.xres if self.xres else 1.0
@property
def dy(self) -> float:
"""Physical pixel size in y (metres)."""
return self.yreal / self.yres if self.yres else 1.0
# ---------------------------------------------------------------------------
# Utility helpers shared across nodes
# ---------------------------------------------------------------------------
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
"""
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
Returns shape (H, W, 3) uint8.
"""
import matplotlib.cm as cm
import matplotlib.colors as mcolors
data = df.data
dmin, dmax = data.min(), data.max()
if dmax > dmin:
normalized = (data - dmin) / (dmax - dmin)
else:
normalized = np.zeros_like(data)
cmap = cm.get_cmap(colormap)
rgba = cmap(normalized) # (H, W, 4) float [0,1]
rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
return rgb
def image_to_uint8(image: np.ndarray) -> np.ndarray:
"""
Convert an IMAGE (float or uint8, 2-D or 3-D) to uint8 (H,W,3) or (H,W) for PIL.
"""
if image.dtype == np.uint8:
return image
# float — normalize to [0, 255]
imin, imax = image.min(), image.max()
if imax > imin:
out = (image - imin) / (imax - imin) * 255.0
else:
out = np.zeros_like(image)
return out.astype(np.uint8)
def encode_preview(arr: np.ndarray) -> str:
"""
Encode a uint8 numpy array as a base64 data URI (PNG).
arr: (H, W) grayscale or (H, W, 3) RGB, uint8.
"""
import base64
import io
from PIL import Image
img = Image.fromarray(arr)
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"

294
backend/execution.py Normal file
View File

@@ -0,0 +1,294 @@
"""
Graph execution engine for argonode.
Prompt format (same as ComfyUI):
{
"node_id": {
"class_type": "GaussianFilter",
"inputs": {
"field": ["upstream_node_id", 0], # link: [src_id, output_slot]
"sigma": 2.0 # constant widget value
}
},
...
}
The engine:
1. Topologically sorts nodes (Kahn's algorithm).
2. Resolves input links to actual Python objects from earlier outputs.
3. Calls each node's FUNCTION method.
4. Emits progress callbacks after each node.
"""
from __future__ import annotations
import uuid
from collections import defaultdict, deque
from typing import Any, Callable
from backend.node_registry import NODE_CLASS_MAPPINGS
def _is_link(value: Any) -> bool:
"""A value is a link if it's a [node_id_str, slot_int] pair."""
return (
isinstance(value, (list, tuple))
and len(value) == 2
and isinstance(value[0], str)
and isinstance(value[1], int)
)
class ExecutionEngine:
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
def execute(
self,
prompt: dict[str, dict],
on_node_start: Callable[[str], None] | None = None,
on_node_done: Callable[[str], None] | None = None,
on_preview: Callable[[str, str], None] | None = None,
on_table: Callable[[str, list], None] | None = None,
on_mesh: Callable[[str, dict], None] | None = None,
on_overlay: Callable[[str, str], None] | None = None,
) -> dict[str, tuple]:
"""
Execute the workflow described by `prompt`.
Parameters
----------
prompt : workflow dict (node_id → {class_type, inputs})
on_node_start : called with node_id just before a node executes
on_node_done : called with node_id just after a node executes
on_preview : called with (node_id, data_uri) when a display node runs
on_table : called with (node_id, table_list) when PrintTable runs
on_overlay : called with (node_id, data_uri) for interactive overlays
Returns
-------
node_outputs : {node_id → tuple-of-outputs} for every executed node
"""
order = self._topological_sort(prompt)
node_outputs: dict[str, tuple] = {}
# Inject display callbacks before execution
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
for node_id in order:
node_def = prompt[node_id]
class_name = node_def["class_type"]
if class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Unknown node type: '{class_name}'")
cls = NODE_CLASS_MAPPINGS[class_name]
raw_inputs = node_def.get("inputs", {})
inputs = self._resolve_inputs(raw_inputs, node_outputs)
# Let display nodes know their node_id so they can tag WS messages
self._set_node_id_on_display(cls, node_id)
if on_node_start:
on_node_start(node_id)
instance = cls()
func = getattr(instance, cls.FUNCTION)
result = func(**inputs)
# Nodes must return a tuple; coerce single values just in case
if not isinstance(result, tuple):
result = (result,)
node_outputs[node_id] = result
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
# IMAGE, or TABLE output so every node shows its result.
if on_preview or on_table:
self._auto_preview(cls, node_id, result, on_preview, on_table)
if on_node_done:
on_node_done(node_id)
return node_outputs
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _topological_sort(self, prompt: dict) -> list[str]:
"""Kahn's algorithm — returns node IDs in dependency order."""
in_degree: dict[str, int] = {nid: 0 for nid in prompt}
dependents: dict[str, list[str]] = defaultdict(list)
for node_id, node_def in prompt.items():
for value in node_def.get("inputs", {}).values():
if _is_link(value):
src_id = value[0]
if src_id in prompt:
in_degree[node_id] += 1
dependents[src_id].append(node_id)
queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
order: list[str] = []
while queue:
nid = queue.popleft()
order.append(nid)
for dep in dependents[nid]:
in_degree[dep] -= 1
if in_degree[dep] == 0:
queue.append(dep)
if len(order) != len(prompt):
raise ValueError("Cycle detected in workflow graph — cannot execute.")
return order
def _resolve_inputs(
self,
raw_inputs: dict[str, Any],
node_outputs: dict[str, tuple],
) -> dict[str, Any]:
"""Replace [src_id, slot] links with actual output values."""
resolved = {}
for key, value in raw_inputs.items():
if _is_link(value):
src_id, slot = value[0], int(value[1])
if src_id not in node_outputs:
raise KeyError(
f"Node '{src_id}' has no output yet — dependency ordering bug?"
)
outputs = node_outputs[src_id]
if slot >= len(outputs):
raise IndexError(
f"Node '{src_id}' only has {len(outputs)} outputs, "
f"but slot {slot} was requested."
)
resolved[key] = outputs[slot]
else:
resolved[key] = value
return resolved
def _inject_display_callbacks(
self,
on_preview: Callable | None,
on_table: Callable | None,
on_mesh: Callable | None = None,
on_overlay: Callable | None = None,
) -> None:
"""Wire up broadcast callbacks on display node classes."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection
from backend.nodes.io import SaveImage
PreviewImage._broadcast_fn = on_preview
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
CrossSection._broadcast_overlay_fn = on_overlay
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
cls._current_node_id = node_id
def _auto_preview(
self,
cls: type,
node_id: str,
result: tuple,
on_preview: Callable | None,
on_table: Callable | None,
) -> None:
"""
After every node executes, inspect its outputs and broadcast
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
"""
import numpy as np
from backend.data_types import (
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
)
return_types = getattr(cls, "RETURN_TYPES", ())
for slot, type_name in enumerate(return_types):
if slot >= len(result):
break
value = result[slot]
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
arr = datafield_to_uint8(value, "viridis")
on_preview(node_id, encode_preview(arr))
return # one preview per node is enough
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
arr = image_to_uint8(value)
on_preview(node_id, encode_preview(arr))
return
if type_name == "LINE" and isinstance(value, np.ndarray) and on_preview:
preview = self._render_line_preview(cls, slot, result)
if preview:
on_preview(node_id, preview)
return
if type_name == "TABLE" and isinstance(value, list) and on_table:
on_table(node_id, value)
return
def _render_line_preview(
self,
cls: type,
slot: int,
result: tuple,
) -> str | None:
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
import numpy as np
import base64
import io as _io
return_types = getattr(cls, "RETURN_TYPES", ())
# Find the y-values (current slot) and try to find an x-axis
y = result[slot]
x = None
# If the next output is also LINE, use it as x-axis
if slot + 1 < len(return_types) and return_types[slot + 1] == "LINE":
x = result[slot + 1]
# Or if slot > 0 and previous is LINE, this slot is the x-axis — skip
if slot > 0 and return_types[slot - 1] == "LINE":
return None # the first LINE already plotted both
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
if x is not None:
ax.plot(x, y, color="#ff9800", linewidth=1.2)
else:
ax.plot(y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
fig.tight_layout(pad=0.4)
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
plt.close(fig)
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
except Exception:
return None
def new_prompt_id() -> str:
return str(uuid.uuid4())

47
backend/main.py Normal file
View File

@@ -0,0 +1,47 @@
"""
Entry point for argonode.
Run with:
python -m backend.main
or simply:
python backend/main.py
from the argonode/ directory.
"""
import asyncio
import logging
import sys
from pathlib import Path
# Allow running as `python backend/main.py` from the project root
sys.path.insert(0, str(Path(__file__).parent.parent))
from aiohttp import web
from backend.server import create_app
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(name)s%(message)s",
)
log = logging.getLogger(__name__)
HOST = "127.0.0.1"
PORT = 8188
def main() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
app = create_app(loop)
log.info("=" * 60)
log.info(" Argonode — Node-based image analysis")
log.info(" Open your browser at http://%s:%d", HOST, PORT)
log.info("=" * 60)
web.run_app(app, host=HOST, port=PORT, loop=loop, access_log=None)
if __name__ == "__main__":
main()

56
backend/node_registry.py Normal file
View File

@@ -0,0 +1,56 @@
"""
Node registry for argonode.
Nodes are plain Python classes decorated with @register_node.
NODE_CLASS_MAPPINGS is the single source of truth consumed by
the execution engine and the /nodes REST endpoint.
"""
from __future__ import annotations
from typing import Any
NODE_CLASS_MAPPINGS: dict[str, type] = {}
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
def register_node(display_name: str | None = None):
"""
Class decorator that registers a node class into NODE_CLASS_MAPPINGS.
Usage:
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
...
"""
def decorator(cls: type) -> type:
name = cls.__name__
NODE_CLASS_MAPPINGS[name] = cls
NODE_DISPLAY_NAME_MAPPINGS[name] = display_name or name
return cls
return decorator
def get_node_info(class_name: str) -> dict[str, Any]:
"""
Return a JSON-serialisable dict describing a node — consumed by GET /nodes.
Shape is compatible with what LiteGraph.js expects from the frontend.
"""
cls = NODE_CLASS_MAPPINGS[class_name]
input_types: dict = cls.INPUT_TYPES()
return {
"name": class_name,
"display_name": NODE_DISPLAY_NAME_MAPPINGS.get(class_name, class_name),
"category": getattr(cls, "CATEGORY", "uncategorized"),
"input": input_types,
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
"output": list(cls.RETURN_TYPES),
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
"description": getattr(cls, "DESCRIPTION", ""),
}
def get_all_node_info() -> dict[str, dict[str, Any]]:
"""Return info dicts for every registered node."""
return {name: get_node_info(name) for name in NODE_CLASS_MAPPINGS}

View File

@@ -0,0 +1,2 @@
# Import all node modules to trigger @register_node decorators.
from . import io, filters, level, analysis, grains, display

471
backend/nodes/analysis.py Normal file
View File

@@ -0,0 +1,471 @@
"""
Analysis nodes — statistics, histograms, FFT, cross sections.
Gwyddion equivalents:
StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h)
HeightHistogram → DH (height distribution), gwy_data_field_dh
FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf
CrossSection → gwy_data_field_get_profile (libprocess/datafield.c)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# StatisticsNode
# ---------------------------------------------------------------------------
@register_node(display_name="Statistics")
class StatisticsNode:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("stats",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
)
def process(self, field: DataField) -> tuple:
d = field.data
mean = float(d.mean())
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
table = [
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
{"quantity": "skewness", "value": skewness, "unit": ""},
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
]
return (table,)
# ---------------------------------------------------------------------------
# HeightHistogram
# ---------------------------------------------------------------------------
@register_node(display_name="Height Histogram")
class HeightHistogram:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
}
}
RETURN_TYPES = ("LINE", "LINE")
RETURN_NAMES = ("counts", "bin_centers")
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute the height distribution histogram (DH). "
"Equivalent to gwy_data_field_dh."
)
def process(self, field: DataField, n_bins: int) -> tuple:
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
return (counts.astype(np.float64), bin_centers)
# ---------------------------------------------------------------------------
# FFT2D
# ---------------------------------------------------------------------------
@register_node(display_name="2D FFT")
class FFT2D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"windowing": (["hann", "hamming", "blackman", "none"],),
"level": (["mean", "plane", "none"],),
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("spectrum",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
"Output can be log magnitude, magnitude, phase, or PSDF. "
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
)
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Level subtraction (Gwyddion-style, before windowing)
if level == "mean":
data -= data.mean()
elif level == "plane":
# Fit and subtract a plane: z = a + b*x + c*y
yy, xx = np.mgrid[0:yres, 0:xres]
xx_f = xx.ravel().astype(np.float64)
yy_f = yy.ravel().astype(np.float64)
zz_f = data.ravel()
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
data -= plane
# Windowing (Gwyddion uses (i+0.5)/n centred formulation)
if windowing != "none":
t_y = (np.arange(yres) + 0.5) / yres
t_x = (np.arange(xres) + 0.5) / xres
if windowing == "hann":
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
elif windowing == "hamming":
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
elif windowing == "blackman":
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
else:
wy = np.ones(yres)
wx = np.ones(xres)
data *= np.outer(wy, wx)
# 2D FFT, shifted so DC is at centre
F = np.fft.fftshift(np.fft.fft2(data))
n = xres * yres
if output == "log_magnitude":
mag = np.abs(F)
# Log scale with floor to avoid log(0)
result = np.log1p(mag)
elif output == "magnitude":
result = np.abs(F)
elif output == "phase":
result = np.angle(F)
elif output == "psdf":
# Gwyddion-equivalent PSDF: |F|^2 * dx * dy / (n * 4π²)
dx = field.xreal / xres
dy = field.yreal / yres
result = (np.abs(F) ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
else:
result = np.abs(F)
# Calibrate the output field in spatial-frequency units
if output == "psdf":
# Gwyddion uses angular frequency: 2π/dx, 2π/dy
freq_xreal = 2.0 * np.pi * xres / field.xreal
freq_yreal = 2.0 * np.pi * yres / field.yreal
z_unit = f"({field.si_unit_z})^2 m^2"
else:
freq_xreal = xres / field.xreal
freq_yreal = yres / field.yreal
z_unit = field.si_unit_z
out_field = DataField(
data=result,
xreal=freq_xreal,
yreal=freq_yreal,
si_unit_xy="1/m",
si_unit_z=z_unit,
domain="frequency",
)
return (out_field,)
# ---------------------------------------------------------------------------
# CrossSection
# ---------------------------------------------------------------------------
def _extend_to_edges(x1, y1, x2, y2):
"""
Extend the line through (x1,y1)-(x2,y2) to the boundaries of [0,1]x[0,1].
Returns the two intersection points (clipped to the unit square).
"""
dx = x2 - x1
dy = y2 - y1
# Collect parametric t values where line hits each boundary
t_candidates = []
if abs(dx) > 1e-12:
for bx in (0.0, 1.0):
t = (bx - x1) / dx
y_at_t = y1 + t * dy
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if abs(dy) > 1e-12:
for by in (0.0, 1.0):
t = (by - y1) / dy
x_at_t = x1 + t * dx
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if len(t_candidates) < 2:
return x1, y1, x2, y2
t_min = min(t_candidates)
t_max = max(t_candidates)
return (
np.clip(x1 + t_min * dx, 0, 1),
np.clip(y1 + t_min * dy, 0, 1),
np.clip(x1 + t_max * dx, 0, 1),
np.clip(y1 + t_max * dy, 0, 1),
)
@register_node(display_name="Cross Section")
class CrossSection:
"""Extract a 1-D height profile along an arbitrary line across the image."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
"extend": (["none", "to_edges"],),
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
},
"optional": {
"point_a": ("COORD",),
"point_b": ("COORD",),
},
}
RETURN_TYPES = ("LINE",)
RETURN_NAMES = ("profile",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Extract a cross-section profile along a line between two points. "
"Drag the markers on the image to set the line endpoints. "
"Equivalent to gwy_data_field_get_profile."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(
self, field: DataField,
x1: float, y1: float, x2: float, y2: float,
extend: str, n_samples: int,
point_a=None, point_b=None,
) -> tuple:
from scipy.ndimage import map_coordinates
import io, base64
from matplotlib.figure import Figure
# COORD inputs override widget values
if point_a is not None:
x1, y1 = float(point_a[0]), float(point_a[1])
if point_b is not None:
x2, y2 = float(point_b[0]), float(point_b[1])
# Remember marker positions (before extend)
marker_x1, marker_y1 = float(x1), float(y1)
marker_x2, marker_y2 = float(x2), float(y2)
xres, yres = field.xres, field.yres
if extend == "to_edges":
x1, y1, x2, y2 = _extend_to_edges(
float(x1), float(y1), float(x2), float(y2),
)
# Convert fractional [0,1] to pixel indices [0, res-1]
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
# Number of sample points
line_len_px = np.hypot(px2 - px1, py2 - py1)
if n_samples <= 0:
n_samples = max(2, int(np.ceil(line_len_px)))
# Sample coordinates along the line
t = np.linspace(0, 1, n_samples)
coords_y = py1 + t * (py2 - py1)
coords_x = px1 + t * (px2 - px1)
# Interpolate values along the line (cubic spline)
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
# Broadcast overlay image with marker positions
if CrossSection._broadcast_overlay_fn is not None:
fig = Figure(figsize=(3, 3), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])
ax.imshow(field.data, cmap="viridis", aspect="auto")
ax.axis("off")
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,
{
"image": image_uri,
"x1": marker_x1, "y1": marker_y1,
"x2": marker_x2, "y2": marker_y2,
"a_locked": point_a is not None,
"b_locked": point_b is not None,
},
)
return (profile.astype(np.float64),)
# ---------------------------------------------------------------------------
# LineMath — single scalar measurement from a LINE profile
# ---------------------------------------------------------------------------
def _safe_rq(d):
"""RMS of deviations from mean."""
return float(np.sqrt(np.mean(d * d)))
# Registry: name → (function(z) → float, unit_label)
# All functions receive the raw 1-D profile as float64.
LINE_OPS: dict[str, tuple] = {}
def _line_op(name, unit=""):
"""Decorator to register a LINE operation."""
def decorator(fn):
LINE_OPS[name] = (fn, unit)
return fn
return decorator
# ── Basic statistics ──────────────────────────────────────────────────────
@_line_op("min")
def _op_min(z):
return float(z.min())
@_line_op("max")
def _op_max(z):
return float(z.max())
@_line_op("mean")
def _op_mean(z):
return float(z.mean())
@_line_op("median")
def _op_median(z):
return float(np.median(z))
@_line_op("sum")
def _op_sum(z):
return float(z.sum())
@_line_op("range")
def _op_range(z):
return float(z.max() - z.min())
@_line_op("length", unit="pts")
def _op_length(z):
return float(len(z))
@_line_op("rms")
def _op_rms(z):
return float(np.sqrt(np.mean(z * z)))
# ── Roughness parameters ──────────────────────────
@_line_op("Ra")
def _op_ra(z):
return float(np.mean(np.abs(z - z.mean())))
@_line_op("Rq")
def _op_rq(z):
d = z - z.mean()
return _safe_rq(d)
@_line_op("Rsk")
def _op_rsk(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
@_line_op("Rku")
def _op_rku(z):
d = z - z.mean()
rq = _safe_rq(d)
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
@_line_op("Rp")
def _op_rp(z):
return float((z - z.mean()).max())
@_line_op("Rv")
def _op_rv(z):
return float(-(z - z.mean()).min())
@_line_op("Rt")
def _op_rt(z):
d = z - z.mean()
return float(d.max() - d.min())
@_line_op("Dq")
def _op_dq(z):
"""RMS slope (first derivative RMS)."""
dz = np.diff(z)
return float(np.sqrt(np.mean(dz * dz)))
@_line_op("Da")
def _op_da(z):
"""Mean absolute slope."""
return float(np.mean(np.abs(np.diff(z))))
@register_node(display_name="Line Math")
class LineMath:
"""Compute a single scalar value from a LINE profile."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"line": ("LINE",),
"operation": (list(LINE_OPS.keys()),),
}
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("result",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute a single scalar measurement from a LINE profile. "
"Includes basic stats and Gwyddion-convention roughness parameters."
)
def process(self, line, operation: str) -> tuple:
z = np.asarray(line, dtype=np.float64).ravel()
fn, unit = LINE_OPS[operation]
value = fn(z)
table = [{"quantity": operation, "value": value, "unit": unit}]
return (table,)

165
backend/nodes/display.py Normal file
View File

@@ -0,0 +1,165 @@
"""
Display / output nodes.
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
connect whichever type you have. The server injects _broadcast_fn
before execution begins.
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
@register_node(display_name="Preview")
class PreviewImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
},
"optional": {
"image": ("IMAGE",),
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ()
FUNCTION = "preview"
CATEGORY = "display"
OUTPUT_NODE = True
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
_broadcast_fn = None
_current_node_id: str = ""
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
# Prefer field if both are connected; accept whichever is provided
if field is not None:
arr_u8 = datafield_to_uint8(field, colormap)
elif image is not None:
if image.dtype != np.uint8:
imin, imax = image.min(), image.max()
if imax > imin:
norm = (image - imin) / (imax - imin)
else:
norm = np.zeros_like(image)
arr_u8 = (norm * 255).astype(np.uint8)
else:
arr_u8 = image
if arr_u8.ndim == 2 and colormap != "gray":
import matplotlib.cm as cm
cmap = cm.get_cmap(colormap)
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
else:
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
data_uri = encode_preview(arr_u8)
if PreviewImage._broadcast_fn is not None:
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
return ()
@register_node(display_name="3D View")
class View3D:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
}
}
RETURN_TYPES = ()
FUNCTION = "render"
CATEGORY = "display"
OUTPUT_NODE = True
DESCRIPTION = (
"Interactive 3D surface view of a DATA_FIELD. "
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
)
_broadcast_mesh_fn = None
_current_node_id: str = ""
def render(
self, field: DataField,
colormap: str, z_scale: float, resolution: int,
) -> tuple:
import matplotlib.cm as cm
import base64
data = field.data
yres, xres = data.shape
# Downsample if larger than resolution
step_y = max(1, yres // resolution)
step_x = max(1, xres // resolution)
z = data[::step_y, ::step_x].astype(np.float32)
ny, nx = z.shape
# Normalize for colormap
zmin, zmax = float(z.min()), float(z.max())
if zmax > zmin:
z_norm = (z - zmin) / (zmax - zmin)
else:
z_norm = np.zeros_like(z)
cmap = cm.get_cmap(colormap)
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
# Base64-encode arrays for efficient WS transport
z_b64 = base64.b64encode(z.tobytes()).decode()
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
mesh_data = {
"width": nx,
"height": ny,
"z_data": z_b64,
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}
if View3D._broadcast_mesh_fn is not None:
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
return ()
@register_node(display_name="Print Table")
class PrintTable:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("TABLE",),
}
}
RETURN_TYPES = ()
FUNCTION = "print_table"
CATEGORY = "display"
OUTPUT_NODE = True
DESCRIPTION = "Send a TABLE to the browser as a WebSocket message for display."
_broadcast_table_fn = None
_current_node_id: str = ""
def print_table(self, table: list) -> tuple:
if PrintTable._broadcast_table_fn is not None:
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
return ()

115
backend/nodes/filters.py Normal file
View File

@@ -0,0 +1,115 @@
"""
Filter nodes — Gwyddion-equivalent image filters.
Gwyddion equivalents:
GaussianFilter → gwy_data_field_filter_gaussian
MedianFilter → gwy_data_field_filter_median
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# GaussianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Gaussian Filter")
class GaussianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data.copy(), sigma=float(sigma))
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# MedianFilter
# ---------------------------------------------------------------------------
@register_node(display_name="Median Filter")
class MedianFilter:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("filtered",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data.copy(), size=size)
return (field.replace(data=data),)
# ---------------------------------------------------------------------------
# EdgeDetect
# ---------------------------------------------------------------------------
@register_node(display_name="Edge Detect")
class EdgeDetect:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["sobel", "prewitt", "laplacian", "log"],),
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("edges",)
FUNCTION = "process"
CATEGORY = "filters"
DESCRIPTION = (
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
)
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data.copy()
if method == "sobel":
sx = sobel(data, axis=1)
sy = sobel(data, axis=0)
result = np.hypot(sx, sy)
elif method == "prewitt":
px = prewitt(data, axis=1)
py = prewitt(data, axis=0)
result = np.hypot(px, py)
elif method == "laplacian":
result = laplace(data)
elif method == "log":
result = gaussian_laplace(data, sigma=float(sigma))
else:
raise ValueError(f"Unknown edge detection method: {method}")
return (field.replace(data=result),)

127
backend/nodes/grains.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Grain/feature detection nodes.
Gwyddion equivalents:
ThresholdMask → threshold.c / otsu_threshold.c
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------
@register_node(display_name="Threshold Mask")
class ThresholdMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["otsu", "absolute", "relative"],),
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
"direction": (["above", "below"],),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
"Create a binary mask by thresholding data. "
"Otsu automatically finds the optimal threshold. "
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
)
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
data = field.data
if method == "otsu":
from skimage.filters import threshold_otsu
t = threshold_otsu(data)
elif method == "absolute":
t = float(threshold)
elif method == "relative":
# threshold is a fraction [0, 1] of the data range
dmin, dmax = data.min(), data.max()
t = dmin + float(threshold) * (dmax - dmin)
else:
raise ValueError(f"Unknown threshold method: {method}")
if direction == "above":
mask = (data >= t).astype(np.uint8) * 255
else:
mask = (data < t).astype(np.uint8) * 255
return (mask,)
# ---------------------------------------------------------------------------
# GrainAnalysis
# ---------------------------------------------------------------------------
@register_node(display_name="Grain Analysis")
class GrainAnalysis:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"mask": ("IMAGE",),
"min_size": ("INT", {"default": 10, "min": 1, "max": 100000, "step": 1}),
}
}
RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("grain_stats",)
FUNCTION = "process"
CATEGORY = "grains"
DESCRIPTION = (
"Label connected grain regions in a binary mask and compute per-grain statistics: "
"area, equivalent diameter, mean/max height, bounding box. "
"Equivalent to gwy_data_field_grains_get_values."
)
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
from scipy.ndimage import label, find_objects
binary = (mask > 127).astype(np.int32)
labeled, n_grains = label(binary)
pixel_area = field.dx * field.dy # m^2 per pixel
rows = []
for grain_id in range(1, n_grains + 1):
grain_pixels = labeled == grain_id
area_px = int(grain_pixels.sum())
if area_px < min_size:
continue
area_m2 = area_px * pixel_area
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
heights = field.data[grain_pixels]
mean_h = float(heights.mean())
max_h = float(heights.max())
# Bounding box
ys, xs = np.where(grain_pixels)
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
rows.append({
"grain_id": grain_id,
"area_px": area_px,
"area_m2": area_m2,
"equiv_diam_m": equiv_diam,
"mean_height": mean_h,
"max_height": max_h,
"bbox": bbox,
})
return (rows,)

277
backend/nodes/io.py Normal file
View File

@@ -0,0 +1,277 @@
"""
I/O nodes: load and save images and SPM data.
"""
from __future__ import annotations
import os
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, encode_preview, image_to_uint8
# Resolved at server startup so nodes know where to look
INPUT_DIR = Path(__file__).parent.parent.parent / "input"
OUTPUT_DIR = Path(__file__).parent.parent.parent / "output"
# ---------------------------------------------------------------------------
# LoadImage
# ---------------------------------------------------------------------------
@register_node(display_name="Load Image")
class LoadImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
}
}
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
RETURN_NAMES = ("image", "field")
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
def load(self, filename: str):
# Accept absolute paths or filenames relative to input/
path = Path(filename)
if not path.is_absolute():
path = INPUT_DIR / filename
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
ext = path.suffix.lower()
if ext in (".npy",):
arr = np.load(str(path)).astype(np.float64)
elif ext in (".npz",):
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image
img = Image.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
# Convert to float64 grayscale for the DATA_FIELD output
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
field = DataField(data=gray)
return (arr, field)
# ---------------------------------------------------------------------------
# LoadSPM
# ---------------------------------------------------------------------------
@register_node(display_name="Load SPM File")
class LoadSPM:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"filename": ("FILE_PICKER", {"default": ""}),
"channel": ("STRING", {"default": "Z"}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "load"
CATEGORY = "io"
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
def load(self, filename: str, channel: str = "Z"):
path = Path(filename)
if not path.is_absolute():
path = INPUT_DIR / filename
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
ext = path.suffix.lower()
if ext == ".gwy":
return (self._load_gwy(path, channel),)
elif ext == ".sxm":
return (self._load_sxm(path, channel),)
elif ext in (".ibw",):
return (self._load_ibw(path),)
elif ext in (".npy",):
data = np.load(str(path)).astype(np.float64)
return (DataField(data=data),)
elif ext in (".npz",):
npz = np.load(str(path))
key = list(npz.files)[0]
return (DataField(data=npz[key].astype(np.float64)),)
else:
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
def _load_gwy(self, path: Path, channel: str) -> DataField:
try:
import gwyfile
except ImportError:
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
# Try requested channel name, fall back to first available
ch = None
for key, df in channels.items():
if channel.lower() in key.lower():
ch = df
break
if ch is None:
ch = next(iter(channels.values()))
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
return DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
)
def _load_sxm(self, path: Path, channel: str) -> DataField:
try:
import nanonispy as nap
except ImportError:
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
sxm = nap.read.Scan(str(path))
signals = sxm.signals
# Pick channel
ch_key = None
for key in signals:
if channel.upper() in key.upper():
ch_key = key
break
if ch_key is None:
ch_key = next(iter(signals))
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
return DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
)
def _load_ibw(self, path: Path) -> DataField:
try:
import igor.igorpy as igorpy
wave = igorpy.load(str(path))
data = wave.wave["wData"].squeeze().astype(np.float64)
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
if data.ndim == 1:
data = data.reshape(1, -1)
elif data.ndim != 2:
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
# ---------------------------------------------------------------------------
# Coordinate
# ---------------------------------------------------------------------------
@register_node(display_name="Coordinate")
class Coordinate:
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("COORD",)
RETURN_NAMES = ("point",)
FUNCTION = "process"
CATEGORY = "io"
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
def process(self, x: float, y: float) -> tuple:
return ((float(x), float(y)),)
# ---------------------------------------------------------------------------
# SaveImage
# ---------------------------------------------------------------------------
@register_node(display_name="Save Image")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "output"}),
"format": (["PNG", "TIFF", "NPY"],),
}
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "io"
OUTPUT_NODE = True
DESCRIPTION = "Save an image or array to the output folder."
# Injected by server.py before execution begins
_broadcast_preview = None
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
OUTPUT_DIR.mkdir(exist_ok=True)
# Find next available filename
idx = 1
while True:
name = f"{filename_prefix}_{idx:04d}"
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
if not candidate.exists():
break
idx += 1
if format == "NPY":
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
else:
from PIL import Image
arr = image_to_uint8(image)
if arr.ndim == 2:
pil_img = Image.fromarray(arr, mode="L")
else:
pil_img = Image.fromarray(arr, mode="RGB")
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
# Emit preview over WebSocket if callback is set
if SaveImage._broadcast_preview is not None:
arr_u8 = image_to_uint8(image)
data_uri = encode_preview(arr_u8)
SaveImage._broadcast_preview(data_uri)
return ()

150
backend/nodes/level.py Normal file
View File

@@ -0,0 +1,150 @@
"""
Leveling nodes — background removal and zero correction.
Gwyddion equivalents:
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
FixZero → fix_zero in level.c
Plane-fit algorithm follows Gwyddion's level.h definition:
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
"""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
# ---------------------------------------------------------------------------
# PlaneLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Plane Level")
class PlaneLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("leveled",)
FUNCTION = "process"
CATEGORY = "level"
DESCRIPTION = (
"Fit and subtract a least-squares plane from the data. "
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
)
def process(self, field: DataField) -> tuple:
data = field.data.copy()
yres, xres = data.shape
# Normalised coordinate grids in [0, 1]
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Design matrix: [1, x, y] shape (N, 3)
A = np.column_stack([
np.ones(xres * yres),
xx.ravel(),
yy.ravel(),
])
z = data.ravel()
# Least-squares: solve A @ [pa, pbx, pby] = z
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
pa, pbx, pby = coeffs
plane = (pa + pbx * xx + pby * yy)
return (field.replace(data=data - plane),)
# ---------------------------------------------------------------------------
# PolyLevelField
# ---------------------------------------------------------------------------
@register_node(display_name="Polynomial Level")
class PolyLevelField:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
}
}
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
RETURN_NAMES = ("leveled", "background")
FUNCTION = "process"
CATEGORY = "level"
DESCRIPTION = (
"Fit and subtract a polynomial background of given degree in x and y. "
"Equivalent to gwy_data_field_fit_polynom."
)
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
data = field.data.copy()
yres, xres = data.shape
x = np.linspace(0.0, 1.0, xres)
y = np.linspace(0.0, 1.0, yres)
xx, yy = np.meshgrid(x, y)
# Build Vandermonde-style design matrix with all monomials x^i * y^j
cols = []
for i in range(degree_x + 1):
for j in range(degree_y + 1):
cols.append((xx ** i * yy ** j).ravel())
A = np.column_stack(cols)
z = data.ravel()
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
background = (A @ coeffs).reshape(yres, xres)
leveled = data - background
return (field.replace(data=leveled), field.replace(data=background))
# ---------------------------------------------------------------------------
# FixZero
# ---------------------------------------------------------------------------
@register_node(display_name="Fix Zero")
class FixZero:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"method": (["min", "mean", "median"],),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("zeroed",)
FUNCTION = "process"
CATEGORY = "level"
DESCRIPTION = (
"Shift data so that the minimum (or mean/median) is zero. "
"Equivalent to fix_zero in Gwyddion's level.c."
)
def process(self, field: DataField, method: str) -> tuple:
data = field.data.copy()
if method == "min":
data -= data.min()
elif method == "mean":
data -= data.mean()
elif method == "median":
data -= np.median(data)
else:
raise ValueError(f"Unknown method: {method}")
return (field.replace(data=data),)

267
backend/server.py Normal file
View File

@@ -0,0 +1,267 @@
"""
aiohttp web server for argonode.
Routes
------
GET / → serve frontend/index.html
GET /static/{path} → serve frontend JS/CSS
GET /nodes → JSON dict of all registered node definitions
POST /upload → multipart file upload to input/
POST /prompt → submit a workflow; returns {prompt_id}
GET /ws → WebSocket upgrade
WebSocket message types sent to clients
----------------------------------------
{"type": "execution_start", "data": {"prompt_id": "..."}}
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
{"type": "execution_complete", "data": {"prompt_id": "..."}}
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from pathlib import Path
from aiohttp import web, WSMsgType
log = logging.getLogger(__name__)
FRONTEND_DIR = Path(__file__).parent.parent / "frontend"
DIST_DIR = FRONTEND_DIR / "dist"
INPUT_DIR = Path(__file__).parent.parent / "input"
OUTPUT_DIR = Path(__file__).parent.parent / "output"
# ---------------------------------------------------------------------------
# JSON helper — numpy scalars are not serialisable by default
# ---------------------------------------------------------------------------
class _SafeEncoder(json.JSONEncoder):
def default(self, obj):
import numpy as np
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
def _dumps(obj) -> str:
return json.dumps(obj, cls=_SafeEncoder)
# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
# Import nodes to trigger registration decorators
import backend.nodes # noqa: F401
from backend.node_registry import get_all_node_info
from backend.execution import ExecutionEngine, new_prompt_id
INPUT_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)
engine = ExecutionEngine()
websockets: set[web.WebSocketResponse] = set()
# ------------------------------------------------------------------
# WebSocket broadcast helpers
# ------------------------------------------------------------------
def broadcast(msg: dict) -> None:
"""Schedule a broadcast to all connected WebSocket clients."""
payload = _dumps(msg)
for ws in list(websockets):
if not ws.closed:
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
def on_preview(node_id: str, data_uri: str) -> None:
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
def on_table(node_id: str, rows: list) -> None:
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}})
def on_mesh(node_id: str, mesh_data: dict) -> None:
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
def on_overlay(node_id: str, overlay_data) -> None:
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
# ------------------------------------------------------------------
# Route handlers
# ------------------------------------------------------------------
async def index(request: web.Request) -> web.Response:
# Serve Vite build output if available, else raw frontend
if (DIST_DIR / "index.html").exists():
return web.FileResponse(DIST_DIR / "index.html")
return web.FileResponse(FRONTEND_DIR / "index.html")
async def get_nodes(request: web.Request) -> web.Response:
info = get_all_node_info()
return web.Response(
text=_dumps(info),
content_type="application/json",
)
async def list_files(request: web.Request) -> web.Response:
"""List files in the input/ directory for the file picker widget."""
files = sorted(
f.name for f in INPUT_DIR.iterdir()
if f.is_file() and not f.name.startswith(".")
) if INPUT_DIR.exists() else []
return web.Response(text=_dumps(files), content_type="application/json")
async def browse_dir(request: web.Request) -> web.Response:
"""
Server-side directory browser for local file picking.
GET /browse?dir=/some/path → {parent, dirs[], files[]}
"""
dir_path = request.query.get("dir", str(Path.home()))
p = Path(dir_path).expanduser().resolve()
if not p.is_dir():
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
dirs = []
files = []
try:
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
if entry.name.startswith("."):
continue
if entry.is_dir():
dirs.append(entry.name)
elif entry.is_file():
files.append(entry.name)
except PermissionError:
pass
return web.Response(
text=_dumps({
"path": str(p),
"parent": str(p.parent) if p.parent != p else None,
"dirs": dirs,
"files": files,
}),
content_type="application/json",
)
async def upload_file(request: web.Request) -> web.Response:
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "file":
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
filename = Path(field.filename).name # strip any path traversal
dest = INPUT_DIR / filename
with open(dest, "wb") as f:
while True:
chunk = await field.read_chunk(65536)
if not chunk:
break
f.write(chunk)
return web.Response(text=_dumps({"filename": filename}), content_type="application/json")
async def submit_prompt(request: web.Request) -> web.Response:
body = await request.json()
prompt = body.get("prompt")
if not isinstance(prompt, dict) or not prompt:
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
prompt_id = new_prompt_id()
# Run execution in a thread pool so scipy doesn't block the event loop
async def run():
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}})
def on_start(node_id: str) -> None:
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
try:
await loop.run_in_executor(
None,
lambda: engine.execute(
prompt,
on_node_start=on_start,
on_preview=on_preview,
on_table=on_table,
on_mesh=on_mesh,
on_overlay=on_overlay,
),
)
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
except Exception as exc:
log.exception("Execution error")
broadcast({
"type": "execution_error",
"data": {"node_id": "", "message": str(exc)},
})
asyncio.ensure_future(run())
return web.Response(
text=_dumps({"prompt_id": prompt_id}),
content_type="application/json",
)
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
websockets.add(ws)
log.info("WebSocket client connected (%d total)", len(websockets))
try:
async for msg in ws:
if msg.type == WSMsgType.TEXT:
pass # clients don't need to send anything currently
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
break
finally:
websockets.discard(ws)
log.info("WebSocket client disconnected (%d total)", len(websockets))
return ws
# ------------------------------------------------------------------
# App assembly
# ------------------------------------------------------------------
app = web.Application()
app.router.add_get("/", index)
app.router.add_get("/nodes", get_nodes)
app.router.add_get("/files", list_files)
app.router.add_get("/browse", browse_dir)
app.router.add_post("/upload", upload_file)
app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler)
# Serve frontend static files (Vite build or raw)
if DIST_DIR.exists():
app.router.add_static("/assets", DIST_DIR / "assets")
app.router.add_static("/static", FRONTEND_DIR)
# CORS — allow any origin (local dev only)
async def _cors_middleware(app_, handler):
async def middleware(request):
if request.method == "OPTIONS":
return web.Response(headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
})
response = await handler(request)
response.headers["Access-Control-Allow-Origin"] = "*"
return response
return middleware
app.middlewares.append(_cors_middleware)
return app

12
frontend/index.html Normal file
View File

@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Argonode — Image Analysis</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.jsx"></script>
</body>
</html>

1951
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

20
frontend/package.json Normal file
View File

@@ -0,0 +1,20 @@
{
"name": "argonode-frontend",
"private": true,
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"@xyflow/react": "^12.0.0",
"react": "^18.3.0",
"react-dom": "^18.3.0",
"three": "^0.183.2"
},
"devDependencies": {
"@vitejs/plugin-react": "^4.3.0",
"vite": "^5.4.0"
}
}

601
frontend/src/App.jsx Normal file
View File

@@ -0,0 +1,601 @@
import React, {
useState, useCallback, useEffect, useRef, useMemo,
} from 'react';
import {
ReactFlow, Background, Controls, MiniMap,
useNodesState, useEdgesState, addEdge, useReactFlow,
ReactFlowProvider,
} from '@xyflow/react';
import '@xyflow/react/dist/style.css';
import CustomNode, { NodeContext } from './CustomNode';
import FileBrowser from './FileBrowser';
import * as api from './api';
// ── Constants ─────────────────────────────────────────────────────────
const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']);
const TYPE_COLORS = {
DATA_FIELD: '#3a7abf',
IMAGE: '#4caf50',
LINE: '#ff9800',
TABLE: '#fdd835',
COORD: '#e91e63',
};
const NODE_TYPES = { custom: CustomNode };
// ── Handle ID helpers ─────────────────────────────────────────────────
function getHandleType(handleId) {
return handleId.split('::')[2];
}
function getInputName(handleId) {
return handleId.split('::')[1];
}
function getOutputSlot(handleId) {
return parseInt(handleId.split('::')[1], 10);
}
// ── Graph serialisation → backend prompt format ───────────────────────
function serializeGraph(nodes, edges) {
const prompt = {};
for (const node of nodes) {
const { className, definition, widgetValues } = node.data;
if (!definition) continue;
const inputs = {};
// Widget (scalar) values
const required = definition.input.required || {};
for (const [name, spec] of Object.entries(required)) {
const [type] = Array.isArray(spec) ? spec : [spec];
if (DATA_TYPES.has(type)) continue; // socket, handled via edges
if (widgetValues[name] !== undefined) {
inputs[name] = widgetValues[name];
}
}
// Connected (socket) inputs from edges
const incoming = edges.filter((e) => e.target === node.id);
for (const edge of incoming) {
const inputName = getInputName(edge.targetHandle);
const outputSlot = getOutputSlot(edge.sourceHandle);
inputs[inputName] = [edge.source, outputSlot];
}
prompt[node.id] = { class_type: className, inputs };
}
return prompt;
}
// ── Context menu component ────────────────────────────────────────────
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection }) {
// Group by category, optionally filtering to compatible nodes
const categories = {};
for (const [className, def] of Object.entries(nodeDefs)) {
// If filtering: only show nodes with a matching input or output
if (filterType && filterDirection) {
if (filterDirection === 'source') {
// Dragged from an output — show nodes that have a matching INPUT
const req = def.input.required || {};
const opt = def.input.optional || {};
const allInputs = { ...req, ...opt };
const hasMatch = Object.values(allInputs).some((spec) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return type === filterType;
});
if (!hasMatch) continue;
} else {
// Dragged from an input — show nodes that have a matching OUTPUT
if (!def.output.includes(filterType)) continue;
}
}
const cat = def.category || 'uncategorized';
if (!categories[cat]) categories[cat] = [];
categories[cat].push({ className, def });
}
if (Object.keys(categories).length === 0) {
return (
<div className="context-menu" style={{ left: x, top: y }} onClick={(e) => e.stopPropagation()}>
<div className="context-item" style={{ color: '#64748b' }}>No compatible nodes</div>
</div>
);
}
return (
<div
className="context-menu"
style={{ left: x, top: y }}
onClick={(e) => e.stopPropagation()}
>
{Object.entries(categories).map(([cat, items]) => (
<div key={cat}>
<div className="context-category">{cat}</div>
{items.map(({ className, def }) => (
<div
key={className}
className="context-item"
onClick={() => { onAdd(className, def); onClose(); }}
>
{def.display_name || className}
</div>
))}
</div>
))}
</div>
);
}
// ── Main flow component (needs ReactFlowProvider ancestor) ────────────
function Flow() {
const [nodes, setNodes, onNodesChange] = useNodesState([]);
const [edges, setEdges, onEdgesChange] = useEdgesState([]);
const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' });
const [contextMenu, setContextMenu] = useState(null);
const [fileBrowserCb, setFileBrowserCb] = useState(null);
const nodeDefsRef = useRef({});
const nextIdRef = useRef(1);
const autoRunTimer = useRef(null);
const autoRunRef = useRef(null);
const reactFlow = useReactFlow();
// ── Load node definitions ───────────────────────────────────────────
useEffect(() => {
api.getNodes().then((defs) => {
nodeDefsRef.current = defs;
setStatus({ text: `Loaded ${Object.keys(defs).length} nodes.`, level: 'info' });
}).catch((err) => {
setStatus({ text: 'Failed to load nodes: ' + err.message, level: 'error' });
});
}, []);
// ── WebSocket ───────────────────────────────────────────────────────
const updateNodeData = useCallback((nodeId, patch) => {
setNodes((ns) => ns.map((n) =>
n.id !== nodeId ? n : { ...n, data: { ...n.data, ...patch } }
));
}, [setNodes]);
useEffect(() => {
api.setMessageHandler((msg) => {
console.log('[argonode] WS:', msg.type, msg.data?.node_id || msg.data?.node || '');
switch (msg.type) {
case 'execution_start':
setStatus({ text: 'Running workflow…', level: 'info' });
break;
case 'executing':
setStatus({ text: `Executing node ${msg.data.node}`, level: 'info' });
break;
case 'execution_complete':
setStatus({ text: 'Done.', level: 'info' });
break;
case 'execution_error':
setStatus({ text: 'Error: ' + msg.data.message, level: 'error' });
console.error('[argonode] execution error', msg.data);
break;
case 'preview':
updateNodeData(msg.data.node_id, { previewImage: msg.data.image });
break;
case 'table':
updateNodeData(msg.data.node_id, { tableRows: msg.data.rows });
break;
case 'mesh3d':
updateNodeData(msg.data.node_id, { meshData: msg.data.mesh });
break;
case 'overlay':
updateNodeData(msg.data.node_id, { overlay: msg.data.overlay });
break;
}
});
api.initWS();
return () => api.closeWS();
}, [updateNodeData]);
// ── Connection handling ─────────────────────────────────────────────
const isValidConnection = useCallback((connection) => {
const srcType = getHandleType(connection.sourceHandle);
const tgtType = getHandleType(connection.targetHandle);
return srcType === tgtType;
}, []);
const onConnect = useCallback((params) => {
const type = getHandleType(params.sourceHandle);
const color = TYPE_COLORS[type] || '#999';
setEdges((eds) => {
// Enforce single connection per input handle
const filtered = eds.filter(
(e) => !(e.target === params.target && e.targetHandle === params.targetHandle)
);
return addEdge(
{ ...params, style: { stroke: color, strokeWidth: 2 } },
filtered
);
});
scheduleAutoRun();
}, [setEdges]);
// ── Drop-on-blank: open filtered context menu ──────────────────────
const onConnectEnd = useCallback((event, connectionState) => {
// If the connection was completed (dropped on a valid handle), do nothing
if (connectionState.isValid) return;
const fromHandle = connectionState.fromHandle;
if (!fromHandle || !fromHandle.id) return;
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
const handleType = getHandleType(fromHandle.id);
setContextMenu({
x: clientX,
y: clientY,
filterType: handleType,
filterDirection: fromHandle.type,
pendingNodeId: fromHandle.nodeId,
pendingHandleId: fromHandle.id,
pendingHandleType: fromHandle.type,
});
}, []);
// ── Widget change callback ──────────────────────────────────────────
const onWidgetChange = useCallback((nodeId, name, value) => {
setNodes((ns) => ns.map((n) => {
if (n.id !== nodeId) return n;
return {
...n,
data: {
...n.data,
widgetValues: { ...n.data.widgetValues, [name]: value },
},
};
}));
scheduleAutoRun();
}, [setNodes]); // scheduleAutoRun is stable (no deps)
// ── File browser ────────────────────────────────────────────────────
const openFileBrowser = useCallback((callback) => {
setFileBrowserCb(() => callback);
}, []);
// ── Node context value (stable) ─────────────────────────────────────
const contextValue = useMemo(() => ({
onWidgetChange,
openFileBrowser,
}), [onWidgetChange, openFileBrowser]);
// ── Add node from context menu ──────────────────────────────────────
const addNode = useCallback((className, def) => {
if (!contextMenu) return;
const position = reactFlow.screenToFlowPosition({
x: contextMenu.x,
y: contextMenu.y,
});
// Build default widget values
const widgetValues = {};
const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) continue;
if (Array.isArray(type)) {
widgetValues[name] = type[0]; // combo default = first option
} else {
widgetValues[name] = opts?.default ?? '';
}
}
const newNodeId = String(nextIdRef.current++);
const newNode = {
id: newNodeId,
type: 'custom',
position,
dragHandle: '.drag-handle',
data: {
label: def.display_name || className,
className,
definition: def,
widgetValues,
previewImage: null,
tableRows: null,
meshData: null,
overlay: null,
},
};
setNodes((ns) => [...ns, newNode]);
// Auto-connect if this was triggered by dropping a connection on blank space
if (contextMenu.pendingHandleId) {
const filterType = contextMenu.filterType;
if (contextMenu.pendingHandleType === 'source') {
// Dragged from an output → connect to the first matching input on the new node
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
const inputName = Object.entries(allInputs).find(([, spec]) => {
const [type] = Array.isArray(spec) ? spec : [spec];
return type === filterType;
})?.[0];
if (inputName) {
const targetHandle = `input::${inputName}::${filterType}`;
const color = TYPE_COLORS[filterType] || '#999';
setEdges((eds) => addEdge({
source: contextMenu.pendingNodeId,
sourceHandle: contextMenu.pendingHandleId,
target: newNodeId,
targetHandle,
style: { stroke: color, strokeWidth: 2 },
}, eds));
}
} else {
// Dragged from an input → connect from the first matching output on the new node
const outputIdx = def.output.indexOf(filterType);
if (outputIdx !== -1) {
const sourceHandle = `output::${outputIdx}::${filterType}`;
const color = TYPE_COLORS[filterType] || '#999';
setEdges((eds) => addEdge({
source: newNodeId,
sourceHandle,
target: contextMenu.pendingNodeId,
targetHandle: contextMenu.pendingHandleId,
style: { stroke: color, strokeWidth: 2 },
}, eds));
}
}
}
setContextMenu(null);
scheduleAutoRun();
}, [contextMenu, reactFlow, setNodes, setEdges]);
// ── Toolbar actions ─────────────────────────────────────────────────
const runWorkflow = useCallback(async () => {
// Read current state via functional ref to avoid stale closure
const currentNodes = reactFlow.getNodes();
const currentEdges = reactFlow.getEdges();
const prompt = serializeGraph(currentNodes, currentEdges);
if (!prompt || Object.keys(prompt).length === 0) {
setStatus({ text: 'Graph is empty — add some nodes first.', level: 'error' });
return;
}
setStatus({ text: 'Running…', level: 'info' });
try {
await api.runPrompt(prompt);
} catch (err) {
setStatus({ text: 'Failed: ' + err.message, level: 'error' });
}
}, [reactFlow]);
// Debounced auto-run via ref to avoid dependency chains
autoRunRef.current = () => {
const currentNodes = reactFlow.getNodes();
const currentEdges = reactFlow.getEdges();
// Don't run if any node has unconnected required data inputs
for (const node of currentNodes) {
const def = node.data?.definition;
if (!def) continue;
const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) {
const [type] = Array.isArray(spec) ? spec : [spec];
if (!DATA_TYPES.has(type)) continue;
const hasEdge = currentEdges.some(
(e) => e.target === node.id && getInputName(e.targetHandle) === name
);
if (!hasEdge) return; // incomplete graph, skip auto-run
}
}
const prompt = serializeGraph(currentNodes, currentEdges);
if (!prompt || Object.keys(prompt).length === 0) return;
setStatus({ text: 'Running…', level: 'info' });
api.runPrompt(prompt).catch((err) => {
setStatus({ text: 'Failed: ' + err.message, level: 'error' });
});
};
const scheduleAutoRun = useCallback(() => {
clearTimeout(autoRunTimer.current);
autoRunTimer.current = setTimeout(() => autoRunRef.current?.(), 300);
}, []);
const clearGraph = useCallback(() => {
setNodes([]);
setEdges([]);
nextIdRef.current = 1;
setStatus({ text: 'Graph cleared.', level: 'info' });
}, [setNodes, setEdges]);
const saveWorkflow = useCallback(() => {
const currentNodes = reactFlow.getNodes().map((n) => ({
...n,
data: { ...n.data, previewImage: null, tableRows: null, meshData: null, overlay: null },
}));
const data = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() };
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
a.download = 'workflow.json';
a.click();
}, [reactFlow]);
const loadWorkflow = useCallback(() => {
const input = document.createElement('input');
input.type = 'file';
input.accept = '.json';
input.onchange = async (e) => {
const file = e.target.files[0];
if (!file) return;
const text = await file.text();
try {
const data = JSON.parse(text);
const loadedNodes = data.nodes || [];
const loadedEdges = data.edges || [];
// Re-populate definitions from current nodeDefs
const defs = nodeDefsRef.current;
const hydrated = loadedNodes.map((n) => ({
...n,
data: {
...n.data,
definition: defs[n.data.className] || n.data.definition,
previewImage: null,
tableRows: null,
meshData: null,
overlay: null,
},
}));
setNodes(hydrated);
setEdges(loadedEdges);
// Update ID counter to avoid collisions
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
nextIdRef.current = maxId + 1;
setStatus({ text: 'Workflow loaded.', level: 'info' });
} catch {
setStatus({ text: 'Invalid workflow JSON.', level: 'error' });
}
};
input.click();
}, [setNodes, setEdges]);
// ── Keyboard shortcut ───────────────────────────────────────────────
useEffect(() => {
const handler = (e) => {
if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
e.preventDefault();
runWorkflow();
}
};
window.addEventListener('keydown', handler);
return () => window.removeEventListener('keydown', handler);
}, [runWorkflow]);
// ── Context menu ────────────────────────────────────────────────────
const onPaneContextMenu = useCallback((event) => {
event.preventDefault();
setContextMenu({ x: event.clientX, y: event.clientY });
}, []);
// ── Render ──────────────────────────────────────────────────────────
return (
<NodeContext.Provider value={contextValue}>
<div className="app-container">
{/* Toolbar */}
<div id="toolbar">
<span id="app-title">Argonode</span>
<div className="toolbar-group">
<button className="btn btn-primary" onClick={runWorkflow} title="Run workflow (Ctrl+Enter)">
Run
</button>
<button className="btn" onClick={clearGraph} title="Clear graph">
Clear
</button>
</div>
<div className="toolbar-group">
<button className="btn" onClick={saveWorkflow} title="Save workflow JSON">
Save
</button>
<button className="btn" onClick={loadWorkflow} title="Load workflow JSON">
Load
</button>
</div>
<div className={`status-bar ${status.level}`}>{status.text}</div>
</div>
{/* React Flow canvas */}
<div className="flow-container" onMouseDown={(e) => {
if (!e.target.closest('.context-menu')) setContextMenu(null);
}}>
<ReactFlow
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
isValidConnection={isValidConnection}
nodeTypes={NODE_TYPES}
onPaneContextMenu={onPaneContextMenu}
colorMode="dark"
fitView
deleteKeyCode={['Backspace', 'Delete']}
defaultEdgeOptions={{ type: 'default' }}
>
<Background />
<Controls />
<MiniMap
nodeColor={(n) => {
const cat = n.data?.definition?.category;
const colors = {
io: '#37474f', filters: '#1a237e', level: '#1b5e20',
analysis: '#4a148c', grains: '#bf360c', display: '#212121',
};
return colors[cat] || '#333';
}}
/>
</ReactFlow>
{contextMenu && (
<ContextMenu
x={contextMenu.x}
y={contextMenu.y}
nodeDefs={nodeDefsRef.current}
onAdd={addNode}
onClose={() => setContextMenu(null)}
filterType={contextMenu.filterType}
filterDirection={contextMenu.filterDirection}
/>
)}
</div>
{/* File browser modal */}
{fileBrowserCb && (
<FileBrowser
onSelect={(path) => { fileBrowserCb(path); setFileBrowserCb(null); }}
onClose={() => setFileBrowserCb(null)}
/>
)}
</div>
</NodeContext.Provider>
);
}
// ── App wrapper with ReactFlowProvider ────────────────────────────────
export default function App() {
return (
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
);
}

View File

@@ -0,0 +1,86 @@
import React, { useRef, useState, useCallback } from 'react';
/**
* Image preview with two endpoint markers for cross-section line control.
* Markers are draggable when unlocked (no COORD input connected),
* and fixed when locked (COORD input provides the position).
*
* Marker positions are driven by widget values (immediate React state),
* not by backend overlay coords, so they move instantly during drag.
*/
export default function CrossSectionOverlay({
image, x1, y1, x2, y2,
aLocked, bLocked,
nodeId, onWidgetChange,
}) {
const containerRef = useRef(null);
const [dragging, setDragging] = useState(null); // 'p1' or 'p2'
const getCoords = useCallback((e) => {
const rect = containerRef.current.getBoundingClientRect();
return {
fx: Math.max(0, Math.min(1, (e.clientX - rect.left) / rect.width)),
fy: Math.max(0, Math.min(1, (e.clientY - rect.top) / rect.height)),
};
}, []);
const onPointerDown = useCallback((point) => (e) => {
if (point === 'p1' && aLocked) return;
if (point === 'p2' && bLocked) return;
e.stopPropagation();
e.preventDefault();
e.target.setPointerCapture(e.pointerId);
setDragging(point);
}, [aLocked, bLocked]);
const onPointerMove = useCallback((e) => {
if (!dragging || !containerRef.current) return;
const { fx, fy } = getCoords(e);
const vx = parseFloat(fx.toFixed(3));
const vy = parseFloat(fy.toFixed(3));
if (dragging === 'p1') {
onWidgetChange(nodeId, 'x1', vx);
onWidgetChange(nodeId, 'y1', vy);
} else {
onWidgetChange(nodeId, 'x2', vx);
onWidgetChange(nodeId, 'y2', vy);
}
}, [dragging, nodeId, onWidgetChange, getCoords]);
const onPointerUp = useCallback(() => {
setDragging(null);
}, []);
return (
<div
ref={containerRef}
className="nodrag nowheel cs-overlay"
onPointerMove={onPointerMove}
onPointerUp={onPointerUp}
onLostPointerCapture={onPointerUp}
>
<img src={image} alt="field" draggable={false} className="cs-image" />
{/* Line connecting the two markers */}
<svg className="cs-svg">
<line
x1={`${x1 * 100}%`} y1={`${y1 * 100}%`}
x2={`${x2 * 100}%`} y2={`${y2 * 100}%`}
stroke="#ffd700" strokeWidth="2" strokeDasharray="6 3"
/>
</svg>
{/* Endpoint markers — locked markers get a different style */}
<div
className={`cs-marker ${aLocked ? 'cs-marker-locked' : ''}`}
style={{ left: `${x1 * 100}%`, top: `${y1 * 100}%` }}
onPointerDown={onPointerDown('p1')}
/>
<div
className={`cs-marker ${bLocked ? 'cs-marker-locked' : ''}`}
style={{ left: `${x2 * 100}%`, top: `${y2 * 100}%` }}
onPointerDown={onPointerDown('p2')}
/>
</div>
);
}

396
frontend/src/CustomNode.jsx Normal file
View File

@@ -0,0 +1,396 @@
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
import { Handle, Position } from '@xyflow/react';
const SurfaceView = lazy(() => import('./SurfaceView'));
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
// ── Constants ─────────────────────────────────────────────────────────
const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']);
const TYPE_COLORS = {
DATA_FIELD: '#3a7abf',
IMAGE: '#4caf50',
LINE: '#ff9800',
TABLE: '#fdd835',
COORD: '#e91e63',
};
const CAT_COLORS = {
io: '#37474f',
filters: '#1a237e',
level: '#1b5e20',
analysis: '#4a148c',
grains: '#bf360c',
display: '#212121',
};
// ── Context (provided by App) ─────────────────────────────────────────
export const NodeContext = React.createContext(null);
// ── Draggable number input ────────────────────────────────────────────
function DraggableNumber({ value, step, min, max, precision, onChange }) {
const [editing, setEditing] = useState(false);
const [editText, setEditText] = useState('');
const dragState = useRef(null);
const elRef = useRef(null);
const display = precision != null ? Number(value).toFixed(precision) : String(value);
const clamp = useCallback((v) => {
if (min != null && v < min) v = min;
if (max != null && v > max) v = max;
return v;
}, [min, max]);
const onPointerDown = useCallback((e) => {
if (editing) return;
e.preventDefault();
dragState.current = { startX: e.clientX, startVal: Number(value) };
elRef.current?.setPointerCapture(e.pointerId);
}, [editing, value]);
const onPointerMove = useCallback((e) => {
if (!dragState.current) return;
const dx = e.clientX - dragState.current.startX;
const delta = dx * (step || 0.01);
const raw = dragState.current.startVal + delta;
const rounded = precision != null
? parseFloat(raw.toFixed(precision))
: Math.round(raw);
onChange(clamp(rounded));
}, [step, precision, clamp, onChange]);
const onPointerUp = useCallback((e) => {
if (!dragState.current) return;
const dx = Math.abs(e.clientX - dragState.current.startX);
dragState.current = null;
// If barely moved, enter text-edit mode
if (dx < 3) {
setEditText(display);
setEditing(true);
}
}, [display]);
const commitEdit = useCallback(() => {
setEditing(false);
const parsed = parseFloat(editText);
if (!isNaN(parsed)) onChange(clamp(precision != null ? parseFloat(parsed.toFixed(precision)) : Math.round(parsed)));
}, [editText, precision, clamp, onChange]);
if (editing) {
return (
<input
className="nodrag drag-number-edit"
type="text"
autoFocus
value={editText}
onChange={(e) => setEditText(e.target.value)}
onBlur={commitEdit}
onKeyDown={(e) => { if (e.key === 'Enter') commitEdit(); if (e.key === 'Escape') setEditing(false); }}
/>
);
}
return (
<div
ref={elRef}
className="nodrag drag-number"
onPointerDown={onPointerDown}
onPointerMove={onPointerMove}
onPointerUp={onPointerUp}
>
<span className="drag-number-val">{display}</span>
</div>
);
}
// ── Collapsible section ───────────────────────────────────────────────
function CollapsibleSection({ title, defaultOpen, children }) {
const [open, setOpen] = useState(defaultOpen);
return (
<div className="collapsible">
<button
className="nodrag collapsible-toggle"
onClick={() => setOpen((o) => !o)}
>
<span className="collapsible-arrow">{open ? '▾' : '▸'}</span>
{title}
</button>
{open && children}
</div>
);
}
// ── CustomNode component ──────────────────────────────────────────────
function CustomNode({ id, data }) {
const ctx = useContext(NodeContext);
const def = data.definition;
// Parse inputs into data handles and widgets
const required = def.input.required || {};
const optional = def.input.optional || {};
const dataInputs = [];
const widgets = [];
const hiddenWidgets = new Set();
for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) {
dataInputs.push({ name, type });
} else if (opts?.hidden) {
hiddenWidgets.add(name);
} else {
widgets.push({ name, type, opts: opts || {} });
}
}
for (const [name, spec] of Object.entries(optional)) {
const [type] = Array.isArray(spec) ? spec : [spec];
dataInputs.push({ name, type });
}
const outputs = def.output.map((type, i) => ({
name: def.output_name[i] || type,
type,
slot: i,
}));
const catColor = CAT_COLORS[def.category] || '#333';
const maxIORows = Math.max(dataInputs.length, outputs.length);
return (
<div className="custom-node">
{/* Title */}
<div className="node-title drag-handle" style={{ background: catColor }}>
{data.label}
</div>
<div className="node-body">
{/* I/O rows — pair inputs[i] with outputs[i] */}
{Array.from({ length: maxIORows }, (_, i) => {
const inp = dataInputs[i];
const out = outputs[i];
return (
<div className="io-row" key={`io-${i}`}>
<div className="io-left">
{inp && (
<>
<Handle
type="target"
position={Position.Left}
id={`input::${inp.name}::${inp.type}`}
className="typed-handle"
style={{ background: TYPE_COLORS[inp.type] || '#999' }}
/>
<span className="io-label">{inp.name}</span>
</>
)}
</div>
<div className="io-right">
{out && (
<>
<span className="io-label">{out.name}</span>
<Handle
type="source"
position={Position.Right}
id={`output::${out.slot}::${out.type}`}
className="typed-handle"
style={{ background: TYPE_COLORS[out.type] || '#999' }}
/>
</>
)}
</div>
</div>
);
})}
{/* Widget rows */}
{widgets.map((w) => (
<div className="widget-row" key={w.name}>
<WidgetControl
widget={w}
nodeId={id}
value={data.widgetValues[w.name]}
onChange={ctx.onWidgetChange}
openFileBrowser={ctx.openFileBrowser}
/>
</div>
))}
{/* Interactive 3D surface view */}
{data.meshData && (
<CollapsibleSection title="3D View" defaultOpen={true}>
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading 3D...</div>}>
<SurfaceView meshData={data.meshData} />
</Suspense>
</CollapsibleSection>
)}
{/* Collapsible preview image */}
{data.previewImage && (
<CollapsibleSection title="Preview" defaultOpen={true}>
<div className="node-preview">
<img src={data.previewImage} alt="preview" draggable={false} />
</div>
</CollapsibleSection>
)}
{/* Interactive cross-section overlay */}
{data.overlay && hiddenWidgets.has('x1') && (
<CollapsibleSection title="Cross Section" defaultOpen={true}>
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
<CrossSectionOverlay
image={data.overlay.image}
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
aLocked={data.overlay.a_locked}
bLocked={data.overlay.b_locked}
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
</Suspense>
</CollapsibleSection>
)}
{/* Collapsible table data */}
{data.tableRows && data.tableRows.length > 0 && (
<CollapsibleSection title="Table" defaultOpen={true}>
<div className="node-table">
{data.tableRows.map((row, i) => {
let line;
if (row.quantity !== undefined) {
const val = typeof row.value === 'number' ? row.value.toExponential(3) : row.value;
line = `${row.quantity}: ${val} ${row.unit || ''}`;
} else {
line = Object.entries(row)
.slice(0, 3)
.map(([k, v]) => `${k}: ${typeof v === 'number' ? v.toExponential(2) : v}`)
.join(' ');
}
return <div key={i} className="table-line">{line}</div>;
})}
</div>
</CollapsibleSection>
)}
</div>
</div>
);
}
// ── Widget renderer ───────────────────────────────────────────────────
function WidgetControl({ widget, nodeId, value, onChange, openFileBrowser }) {
const { name, type, opts } = widget;
const val = value ?? opts?.default ?? '';
// Combo / enum — type itself is the array of options
if (Array.isArray(type)) {
return (
<>
<label>{name}</label>
<select
className="nodrag"
value={val || type[0]}
onChange={(e) => onChange(nodeId, name, e.target.value)}
>
{type.map((opt) => (
<option key={opt} value={opt}>{opt}</option>
))}
</select>
</>
);
}
if (type === 'FILE_PICKER') {
return (
<>
<label>{name}</label>
<div className="file-picker-row">
<input
className="nodrag"
type="text"
value={val}
onChange={(e) => onChange(nodeId, name, e.target.value)}
placeholder="Select file…"
/>
<button
className="nodrag browse-btn"
onClick={() => openFileBrowser((path) => onChange(nodeId, name, path))}
>
Browse
</button>
</div>
</>
);
}
if (type === 'FLOAT') {
return (
<>
<label>{name}</label>
<DraggableNumber
value={val || 0}
step={opts?.step ?? 0.01}
min={opts?.min}
max={opts?.max}
precision={4}
onChange={(v) => onChange(nodeId, name, v)}
/>
</>
);
}
if (type === 'INT') {
return (
<>
<label>{name}</label>
<DraggableNumber
value={val || 0}
step={opts?.step ?? 1}
min={opts?.min}
max={opts?.max}
precision={0}
onChange={(v) => onChange(nodeId, name, v)}
/>
</>
);
}
if (type === 'BOOLEAN') {
return (
<>
<label>{name}</label>
<input
className="nodrag"
type="checkbox"
checked={!!val}
onChange={(e) => onChange(nodeId, name, e.target.checked)}
/>
</>
);
}
// STRING and anything else
return (
<>
<label>{name}</label>
<input
className="nodrag"
type="text"
value={val}
onChange={(e) => onChange(nodeId, name, e.target.value)}
/>
</>
);
}
export default memo(CustomNode);

View File

@@ -0,0 +1,94 @@
import React, { useState, useEffect, useCallback } from 'react';
import * as api from './api';
/**
* Server-side file browser modal.
*
* Props:
* onSelect(absolutePath) — called when user picks a file
* onClose() — called when user dismisses the dialog
*/
export default function FileBrowser({ onSelect, onClose }) {
const [path, setPath] = useState('');
const [parent, setParent] = useState(null);
const [dirs, setDirs] = useState([]);
const [files, setFiles] = useState([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
const navigate = useCallback(async (dir) => {
setLoading(true);
setError(null);
try {
const data = await api.browse(dir);
setPath(data.path);
setParent(data.parent);
setDirs(data.dirs);
setFiles(data.files);
} catch (err) {
setError(err.message);
} finally {
setLoading(false);
}
}, []);
// Start at home directory on mount
useEffect(() => {
navigate(null);
}, [navigate]);
return (
<div className="fb-backdrop" onClick={(e) => { if (e.target === e.currentTarget) onClose(); }}>
<div className="fb-dialog">
{/* Header */}
<div className="fb-header">
<span className="fb-path">{path}</span>
<button className="fb-close" onClick={onClose}></button>
</div>
{/* File list */}
<div className="fb-list">
{loading && <div className="fb-loading">Loading</div>}
{error && <div className="fb-loading">Error: {error}</div>}
{!loading && !error && (
<>
{/* Parent directory */}
{parent && (
<div className="fb-entry fb-dir" onClick={() => navigate(parent)}>
..
</div>
)}
{/* Directories */}
{dirs.map((d) => (
<div
key={d}
className="fb-entry fb-dir"
onClick={() => navigate(path + '/' + d)}
>
📁 {d}
</div>
))}
{/* Files */}
{files.map((f) => (
<div
key={f}
className="fb-entry fb-file"
onClick={() => { onSelect(path + '/' + f); onClose(); }}
>
{f}
</div>
))}
{dirs.length === 0 && files.length === 0 && (
<div className="fb-loading">Empty directory</div>
)}
</>
)}
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,183 @@
import React, { useRef, useEffect, useCallback } from 'react';
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
/**
* Interactive 3D surface viewer using Three.js.
* Props:
* meshData: { width, height, z_data (b64 float32), colors (b64 uint8 RGB),
* z_min, z_max, z_scale, x_range, y_range }
*/
export default function SurfaceView({ meshData }) {
const containerRef = useRef(null);
const threeRef = useRef(null); // { renderer, scene, camera, controls, mesh }
// Decode base64 to typed arrays
const decode = useCallback((b64, ArrayType) => {
const bin = atob(b64);
const bytes = new Uint8Array(bin.length);
for (let i = 0; i < bin.length; i++) bytes[i] = bin.charCodeAt(i);
return new ArrayType(bytes.buffer);
}, []);
// Initialize Three.js scene once
useEffect(() => {
const container = containerRef.current;
if (!container || threeRef.current) return;
const width = container.clientWidth;
const height = width; // 1:1 aspect
const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: false });
renderer.setSize(width, height);
renderer.setPixelRatio(window.devicePixelRatio);
renderer.setClearColor(0x0f172a);
container.appendChild(renderer.domElement);
const scene = new THREE.Scene();
const camera = new THREE.PerspectiveCamera(45, 1, 0.01, 1000);
camera.position.set(1.2, 0.8, 1.2);
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.1;
controls.minDistance = 0.3;
controls.maxDistance = 10;
// Lighting
const ambient = new THREE.AmbientLight(0xffffff, 0.4);
scene.add(ambient);
const dir = new THREE.DirectionalLight(0xffffff, 0.8);
dir.position.set(1, 2, 1.5);
scene.add(dir);
const dir2 = new THREE.DirectionalLight(0xffffff, 0.3);
dir2.position.set(-1, 0.5, -1);
scene.add(dir2);
// Animation loop
let animId;
const animate = () => {
animId = requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
};
animate();
threeRef.current = { renderer, scene, camera, controls, mesh: null, animId };
// Resize observer to maintain 1:1 aspect when node width changes
const ro = new ResizeObserver((entries) => {
const entry = entries[0];
if (!entry || !threeRef.current) return;
const w = entry.contentRect.width;
if (w < 1) return;
const { renderer: r, camera: c } = threeRef.current;
r.setSize(w, w);
c.aspect = 1;
c.updateProjectionMatrix();
});
ro.observe(container);
return () => {
ro.disconnect();
cancelAnimationFrame(animId);
controls.dispose();
renderer.dispose();
if (container.contains(renderer.domElement)) {
container.removeChild(renderer.domElement);
}
threeRef.current = null;
};
}, []);
// Update mesh when data changes
useEffect(() => {
if (!threeRef.current || !meshData) return;
const { scene, camera, controls } = threeRef.current;
const { width: nx, height: ny, z_data, colors, z_min, z_max, z_scale, x_range, y_range } = meshData;
// Decode arrays
const zArr = decode(z_data, Float32Array);
const colArr = decode(colors, Uint8Array);
// Remove old mesh
if (threeRef.current.mesh) {
scene.remove(threeRef.current.mesh);
threeRef.current.mesh.geometry.dispose();
threeRef.current.mesh.material.dispose();
}
// Build geometry
const geom = new THREE.BufferGeometry();
const positions = new Float32Array(nx * ny * 3);
const colorAttr = new Float32Array(nx * ny * 3);
// Normalize coordinates to roughly [-0.5, 0.5] for good camera framing
const zRange = z_max - z_min || 1;
for (let iy = 0; iy < ny; iy++) {
for (let ix = 0; ix < nx; ix++) {
const idx = iy * nx + ix;
const px = ix / (nx - 1) - 0.5; // [-0.5, 0.5]
const py = iy / (ny - 1) - 0.5;
const pz = ((zArr[idx] - z_min) / zRange - 0.5) * z_scale;
positions[idx * 3] = px;
positions[idx * 3 + 1] = pz; // height on Y axis
positions[idx * 3 + 2] = py;
colorAttr[idx * 3] = colArr[idx * 3] / 255;
colorAttr[idx * 3 + 1] = colArr[idx * 3 + 1] / 255;
colorAttr[idx * 3 + 2] = colArr[idx * 3 + 2] / 255;
}
}
geom.setAttribute('position', new THREE.BufferAttribute(positions, 3));
geom.setAttribute('color', new THREE.BufferAttribute(colorAttr, 3));
// Build index (triangles from grid)
const indices = [];
for (let iy = 0; iy < ny - 1; iy++) {
for (let ix = 0; ix < nx - 1; ix++) {
const a = iy * nx + ix;
const b = a + 1;
const c = a + nx;
const d = c + 1;
indices.push(a, c, b);
indices.push(b, c, d);
}
}
geom.setIndex(indices);
geom.computeVertexNormals();
const mat = new THREE.MeshPhongMaterial({
vertexColors: true,
side: THREE.DoubleSide,
shininess: 30,
flatShading: false,
});
const mesh = new THREE.Mesh(geom, mat);
scene.add(mesh);
threeRef.current.mesh = mesh;
// Reset camera target to center of mesh
controls.target.set(0, 0, 0);
controls.update();
}, [meshData, decode]);
// Prevent scroll events from propagating to React Flow
const onWheel = useCallback((e) => {
e.stopPropagation();
}, []);
return (
<div
ref={containerRef}
className="nodrag nowheel surface-view-container"
onWheelCapture={onWheel}
/>
);
}

93
frontend/src/api.js Normal file
View File

@@ -0,0 +1,93 @@
/**
* api.js — REST + WebSocket client for argonode backend.
*
* Uses relative URLs so the Vite dev proxy (port 5173 → 8188)
* and production same-origin serving both work transparently.
*/
// ── REST helpers ──────────────────────────────────────────────────────
export async function getNodes() {
const r = await fetch('/nodes');
if (!r.ok) throw new Error(`GET /nodes failed: ${r.status}`);
return r.json();
}
export async function getFiles() {
const r = await fetch('/files');
if (!r.ok) return [];
return r.json();
}
export async function browse(dir) {
const url = dir ? `/browse?dir=${encodeURIComponent(dir)}` : '/browse';
const r = await fetch(url);
if (!r.ok) throw new Error(`Browse failed: ${r.status}`);
return r.json();
}
export async function uploadFile(file) {
const fd = new FormData();
fd.append('file', file);
const r = await fetch('/upload', { method: 'POST', body: fd });
if (!r.ok) throw new Error(`Upload failed: ${r.status}`);
return r.json();
}
export async function runPrompt(prompt) {
const r = await fetch('/prompt', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ prompt }),
});
if (!r.ok) {
const text = await r.text();
throw new Error(`POST /prompt failed (${r.status}): ${text}`);
}
return r.json();
}
// ── WebSocket ─────────────────────────────────────────────────────────
let _ws = null;
let _handler = null;
let _reconnectTimer = null;
export function setMessageHandler(fn) {
_handler = fn;
}
export function initWS() {
if (_ws && _ws.readyState < 2) return; // already open or connecting
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
_ws = new WebSocket(`${protocol}//${window.location.host}/ws`);
_ws.onopen = () => {
console.log('[argonode] WebSocket connected');
};
_ws.onclose = () => {
console.log('[argonode] WebSocket closed, reconnecting in 3s…');
clearTimeout(_reconnectTimer);
_reconnectTimer = setTimeout(() => initWS(), 3000);
};
_ws.onerror = (e) => {
console.error('[argonode] WebSocket error', e);
};
_ws.onmessage = (e) => {
try {
const msg = JSON.parse(e.data);
if (_handler) _handler(msg);
} catch {
// ignore malformed messages
}
};
}
export function closeWS() {
clearTimeout(_reconnectTimer);
if (_ws) _ws.close();
}

6
frontend/src/main.jsx Normal file
View File

@@ -0,0 +1,6 @@
import React from 'react';
import { createRoot } from 'react-dom/client';
import App from './App';
import './styles.css';
createRoot(document.getElementById('root')).render(<App />);

509
frontend/src/styles.css Normal file
View File

@@ -0,0 +1,509 @@
/* ── Reset & base ──────────────────────────────────────────────────── */
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
html, body, #root {
width: 100%;
height: 100%;
background: #1a1a2e;
color: #e0e0e0;
font-family: "Inter", "Segoe UI", system-ui, sans-serif;
font-size: 13px;
overflow: hidden;
}
.app-container {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
}
/* ── Toolbar ───────────────────────────────────────────────────────── */
#toolbar {
height: 44px;
background: #16213e;
border-bottom: 1px solid #0f3460;
display: flex;
align-items: center;
padding: 0 12px;
gap: 10px;
z-index: 100;
user-select: none;
flex-shrink: 0;
}
#app-title {
font-size: 15px;
font-weight: 700;
letter-spacing: 0.5px;
color: #e94560;
margin-right: 8px;
flex-shrink: 0;
}
.toolbar-group {
display: flex;
gap: 6px;
flex-shrink: 0;
}
/* ── Buttons ───────────────────────────────────────────────────────── */
.btn {
padding: 5px 12px;
border: 1px solid #0f3460;
border-radius: 5px;
background: #0f3460;
color: #e0e0e0;
font-size: 12px;
cursor: pointer;
transition: background 0.15s, border-color 0.15s;
white-space: nowrap;
}
.btn:hover {
background: #1a4a8a;
border-color: #3a7abf;
}
.btn:active {
background: #0a2040;
}
.btn-primary {
background: #e94560;
border-color: #e94560;
font-weight: 600;
}
.btn-primary:hover {
background: #ff6b81;
border-color: #ff6b81;
}
/* ── Status bar ────────────────────────────────────────────────────── */
.status-bar {
margin-left: auto;
padding: 4px 10px;
border-radius: 4px;
font-size: 11px;
max-width: 400px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
flex-shrink: 1;
}
.status-bar.info { color: #90caf9; }
.status-bar.error { color: #ef9a9a; background: rgba(183,28,28,0.2); }
/* ── React Flow container ──────────────────────────────────────────── */
.flow-container {
flex: 1;
position: relative;
}
/* ── React Flow dark overrides ─────────────────────────────────────── */
.react-flow {
background: #0d1117 !important;
}
/* ── Custom node ───────────────────────────────────────────────────── */
.custom-node {
background: #1e293b;
border: 1px solid #334155;
border-radius: 6px;
font-size: 11px;
color: #e0e0e0;
width: 200px;
min-width: 160px;
resize: horizontal;
overflow: hidden;
}
/* Let React Flow node wrapper fit to the custom-node's size */
.react-flow__node-custom {
width: auto !important;
height: auto !important;
}
/* Title bar is the drag handle for moving the node */
.drag-handle {
cursor: grab;
}
.drag-handle:active {
cursor: grabbing;
}
.custom-node.selected {
border-color: #90caf9;
}
.node-title {
padding: 5px 10px;
font-weight: 600;
font-size: 12px;
color: white;
border-radius: 5px 5px 0 0;
border-bottom: 1px solid rgba(0, 0, 0, 0.3);
}
.node-body {
padding: 4px 0;
}
/* ── I/O rows ──────────────────────────────────────────────────────── */
.io-row {
display: flex;
justify-content: space-between;
align-items: center;
padding: 3px 12px;
min-height: 22px;
position: relative;
}
.io-left, .io-right {
display: flex;
align-items: center;
gap: 4px;
}
.io-label {
font-size: 10px;
color: #94a3b8;
}
/* ── Handles ───────────────────────────────────────────────────────── */
.typed-handle {
width: 10px !important;
height: 10px !important;
border: 2px solid #1e293b !important;
border-radius: 50% !important;
}
/* ── Widget rows ───────────────────────────────────────────────────── */
.widget-row {
padding: 3px 10px;
display: flex;
align-items: center;
gap: 6px;
}
.widget-row label {
font-size: 10px;
color: #64748b;
min-width: 40px;
flex-shrink: 0;
}
.widget-row input[type="text"],
.widget-row input[type="number"],
.widget-row select {
background: #0f172a;
color: #e0e0e0;
border: 1px solid #334155;
border-radius: 3px;
padding: 2px 5px;
font-size: 11px;
flex: 1;
min-width: 0;
}
.widget-row input[type="checkbox"] {
accent-color: #3a7abf;
}
.widget-row input:focus,
.widget-row select:focus {
outline: none;
border-color: #3a7abf;
}
.file-picker-row {
display: flex;
gap: 4px;
flex: 1;
min-width: 0;
}
.file-picker-row input {
flex: 1;
min-width: 0;
}
/* ── Draggable number ──────────────────────────────────────────────── */
.drag-number {
flex: 1;
min-width: 0;
background: #0f172a;
border: 1px solid #334155;
border-radius: 3px;
padding: 2px 6px;
cursor: ew-resize;
user-select: none;
text-align: center;
font-size: 11px;
color: #e0e0e0;
touch-action: none;
}
.drag-number:hover {
border-color: #3a7abf;
}
.drag-number-val {
pointer-events: none;
}
.drag-number-edit {
flex: 1;
min-width: 0;
background: #0f172a;
border: 1px solid #3a7abf;
border-radius: 3px;
padding: 2px 5px;
font-size: 11px;
color: #e0e0e0;
text-align: center;
outline: none;
}
.browse-btn {
background: #0f3460;
color: #e0e0e0;
border: 1px solid #334155;
border-radius: 3px;
padding: 2px 6px;
font-size: 10px;
cursor: pointer;
white-space: nowrap;
}
.browse-btn:hover {
background: #1a4a8a;
}
/* ── Collapsible section ───────────────────────────────────────────── */
.collapsible {
border-top: 1px solid #334155;
margin-top: 4px;
}
.collapsible-toggle {
display: flex;
align-items: center;
gap: 4px;
width: 100%;
background: none;
border: none;
color: #64748b;
font-size: 10px;
padding: 3px 10px;
cursor: pointer;
text-align: left;
}
.collapsible-toggle:hover {
color: #94a3b8;
}
.collapsible-arrow {
font-size: 9px;
}
/* ── Node preview ──────────────────────────────────────────────────── */
.node-preview {
overflow: hidden;
}
.node-preview img {
width: 100%;
max-width: 100%;
display: block;
}
/* ── Cross-section overlay ────────────────────────────────────────── */
.cs-overlay {
position: relative;
user-select: none;
touch-action: none;
overflow: hidden;
}
.cs-image {
width: 100%;
display: block;
}
.cs-svg {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
pointer-events: none;
}
.cs-marker {
position: absolute;
width: 14px;
height: 14px;
border-radius: 50%;
background: #ffd700;
border: 2px solid #fff;
transform: translate(-50%, -50%);
cursor: grab;
box-shadow: 0 0 4px rgba(0,0,0,0.6);
z-index: 1;
}
.cs-marker:active:not(.cs-marker-locked) {
cursor: grabbing;
background: #ffeb3b;
transform: translate(-50%, -50%) scale(1.2);
}
.cs-marker-locked {
background: #e91e63;
border-color: #e91e63;
cursor: default;
opacity: 0.9;
}
/* ── 3D surface view ──────────────────────────────────────────────── */
.surface-view-container {
width: 100%;
aspect-ratio: 1 / 1;
cursor: grab;
overflow: hidden;
}
.surface-view-container:active {
cursor: grabbing;
}
.surface-view-container canvas {
display: block;
}
/* ── Node table ────────────────────────────────────────────────────── */
.node-table {
padding: 4px 10px;
font-family: "SF Mono", "Fira Code", monospace;
font-size: 10px;
color: #cbd5e1;
}
.table-line {
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
line-height: 1.5;
}
/* ── Node resize handles ───────────────────────────────────────────── */
.node-resize-line {
border-color: #90caf9 !important;
}
.node-resize-handle {
background: #90caf9 !important;
width: 8px !important;
height: 8px !important;
}
/* ── Context menu ──────────────────────────────────────────────────── */
.context-menu {
position: fixed;
z-index: 1000;
background: #16213e;
border: 1px solid #0f3460;
border-radius: 6px;
min-width: 180px;
max-height: 60vh;
overflow-y: auto;
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
padding: 4px 0;
}
.context-category {
padding: 6px 12px 3px;
font-size: 10px;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.5px;
color: #64748b;
border-top: 1px solid #0f3460;
}
.context-category:first-child {
border-top: none;
}
.context-item {
padding: 5px 20px;
font-size: 12px;
cursor: pointer;
color: #e0e0e0;
}
.context-item:hover {
background: #0f3460;
}
/* ── File browser dialog ──────────────────────────────────────────── */
.fb-backdrop {
position: fixed;
inset: 0;
background: rgba(0, 0, 0, 0.6);
z-index: 2000;
display: flex;
align-items: center;
justify-content: center;
}
.fb-dialog {
background: #16213e;
border: 1px solid #0f3460;
border-radius: 8px;
width: 520px;
max-height: 70vh;
display: flex;
flex-direction: column;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5);
}
.fb-header {
display: flex;
align-items: center;
justify-content: space-between;
padding: 10px 14px;
border-bottom: 1px solid #0f3460;
}
.fb-path {
font-size: 12px;
color: #90caf9;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
flex: 1;
margin-right: 10px;
}
.fb-close {
background: none;
border: none;
color: #e0e0e0;
font-size: 16px;
cursor: pointer;
padding: 2px 6px;
}
.fb-close:hover { color: #e94560; }
.fb-list {
overflow-y: auto;
padding: 6px 0;
flex: 1;
}
.fb-entry {
padding: 6px 14px;
cursor: pointer;
font-size: 13px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.fb-entry:hover { background: #0f3460; }
.fb-dir { color: #90caf9; }
.fb-file { color: #e0e0e0; }
.fb-loading {
padding: 16px;
text-align: center;
color: #607d8b;
}
/* ── Scrollbar styling ─────────────────────────────────────────────── */
::-webkit-scrollbar { width: 6px; height: 6px; }
::-webkit-scrollbar-track { background: #1a1a2e; }
::-webkit-scrollbar-thumb { background: #0f3460; border-radius: 3px; }
::-webkit-scrollbar-thumb:hover { background: #3a7abf; }
/* ── React Flow MiniMap ────────────────────────────────────────────── */
.react-flow__minimap {
background: #16213e !important;
border: 1px solid #0f3460 !important;
border-radius: 4px !important;
}

23
frontend/vite.config.js Normal file
View File

@@ -0,0 +1,23 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
export default defineConfig({
plugins: [react()],
server: {
port: 5173,
proxy: {
'/nodes': 'http://127.0.0.1:8188',
'/files': 'http://127.0.0.1:8188',
'/browse': 'http://127.0.0.1:8188',
'/upload': 'http://127.0.0.1:8188',
'/prompt': 'http://127.0.0.1:8188',
'/ws': {
target: 'http://127.0.0.1:8188',
ws: true,
},
},
},
build: {
outDir: 'dist',
},
});

0
tests/__init__.py Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 981 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

258
tests/test_fft.py Normal file
View File

@@ -0,0 +1,258 @@
"""
Test the FFT2D node against known inputs and Gwyddion-equivalent results.
Run from project root:
python -m tests.test_fft
"""
import sys
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField
from backend.nodes.analysis import FFT2D
def make_field(data, xreal=1e-6, yreal=1e-6):
"""Create a DataField from a 2D array."""
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
def test_dc_removal():
"""A constant image should produce near-zero FFT after mean subtraction."""
print("=== Test: DC removal ===")
data = np.ones((64, 64)) * 42.0
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
peak = result.data.max()
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
assert peak < 1e-10, f"Expected ~0, got {peak}"
print(" PASS\n")
def test_single_frequency():
"""A pure sine wave should produce two peaks at the known frequency."""
print("=== Test: Single frequency detection ===")
N = 128
xreal = 1e-6 # 1 micron
freq_cycles = 10 # 10 cycles across the image
x = np.linspace(0, 1, N, endpoint=False)
data = np.sin(2 * np.pi * freq_cycles * x)[np.newaxis, :] * np.ones((N, 1))
field = make_field(data, xreal=xreal, yreal=xreal)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
# The peak should be at column offset = freq_cycles from center
mag = result.data
cy, cx = N // 2, N // 2 # center (DC)
# Find the peak location (excluding DC which should be ~0 after mean sub)
mag_copy = mag.copy()
mag_copy[cy, cx] = 0
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
peak_col_offset = abs(peak_idx[1] - cx)
print(f" Image: {N}x{N}, {freq_cycles} horizontal cycles")
print(f" Expected peak at column offset {freq_cycles} from center")
print(f" Found peak at {peak_idx} (offset {peak_col_offset})")
print(f" DC value: {mag[cy, cx]:.2e}")
print(f" Peak value: {mag[peak_idx]:.2e}")
assert peak_col_offset == freq_cycles, f"Expected offset {freq_cycles}, got {peak_col_offset}"
assert peak_idx[0] == cy, f"Expected peak on center row, got row {peak_idx[0]}"
print(" PASS\n")
def test_2d_frequency():
"""A 2D sine should produce peaks at the correct (kx, ky) position."""
print("=== Test: 2D frequency detection ===")
N = 128
fx, fy = 8, 5 # cycles in x and y
y, x = np.mgrid[0:N, 0:N] / N
data = np.sin(2 * np.pi * (fx * x + fy * y))
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="none", level="mean", output="magnitude")
mag = result.data
cy, cx = N // 2, N // 2
mag_copy = mag.copy()
mag_copy[cy, cx] = 0
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
dx = abs(peak_idx[1] - cx)
dy = abs(peak_idx[0] - cy)
print(f" Input: sin(2π({fx}x + {fy}y))")
print(f" Expected peak offset: ({fy}, {fx}) from center")
print(f" Found peak at {peak_idx} (offset dy={dy}, dx={dx})")
assert dx == fx and dy == fy, f"Expected ({fy},{fx}), got ({dy},{dx})"
print(" PASS\n")
def test_psdf_normalization():
"""
PSDF of white noise should integrate to the variance.
Parseval's theorem: sum of PSDF * dk_x * dk_y ≈ variance of the signal.
"""
print("=== Test: PSDF normalization (Parseval) ===")
N = 256
xreal = 1e-6
rng = np.random.default_rng(42)
data = rng.standard_normal((N, N))
variance = data.var()
field = make_field(data, xreal=xreal, yreal=xreal)
node = FFT2D()
result, = node.process(field, windowing="none", level="none", output="psdf")
psdf = result.data
# Integrate: sum of PSDF * dk_x * dk_y
# Our output field has xreal = 2π*N/xreal (angular freq range)
dk_x = result.xreal / N
dk_y = result.yreal / N
integral = psdf.sum() * dk_x * dk_y
ratio = integral / variance
print(f" Signal variance: {variance:.6f}")
print(f" PSDF integral: {integral:.6f}")
print(f" Ratio (should be ~1.0): {ratio:.4f}")
# Allow 20% tolerance for finite-size effects
assert 0.8 < ratio < 1.2, f"Parseval's theorem violated: ratio = {ratio}"
print(" PASS\n")
def test_windowing_reduces_leakage():
"""Windowing should reduce spectral leakage from a non-integer frequency."""
print("=== Test: Windowing reduces leakage ===")
N = 128
freq = 10.5 # non-integer → spectral leakage without windowing
x = np.linspace(0, 1, N, endpoint=False)
data = np.sin(2 * np.pi * freq * x)[np.newaxis, :] * np.ones((N, 1))
field = make_field(data)
node = FFT2D()
# Without windowing
r_none, = node.process(field, windowing="none", level="mean", output="magnitude")
mag_none = r_none.data[N // 2, :] # center row
# With Hann windowing
r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude")
mag_hann = r_hann.data[N // 2, :]
# Measure leakage: ratio of energy far from peak vs total
peak_col = np.argmax(mag_none)
far_mask = np.ones(N, dtype=bool)
far_mask[max(0, peak_col - 3):peak_col + 4] = False
# Also mask the symmetric peak
sym_col = N - peak_col
far_mask[max(0, sym_col - 3):sym_col + 4] = False
leakage_none = mag_none[far_mask].sum() / mag_none.sum()
leakage_hann = mag_hann[far_mask].sum() / mag_hann.sum()
print(f" Non-integer frequency: {freq}")
print(f" Leakage without windowing: {leakage_none:.4f}")
print(f" Leakage with Hann window: {leakage_hann:.4f}")
assert leakage_hann < leakage_none, "Hann window should reduce leakage"
print(" PASS\n")
def test_plane_subtraction():
"""Plane subtraction should remove linear gradients."""
print("=== Test: Plane subtraction ===")
N = 64
y, x = np.mgrid[0:N, 0:N] / N
# Tilted plane + sine wave
data = 100 * x + 50 * y + np.sin(2 * np.pi * 8 * x)
field = make_field(data)
node = FFT2D()
# Without leveling — huge DC and low-freq energy
r_none, = node.process(field, windowing="none", level="none", output="magnitude")
dc_none = r_none.data[N // 2, N // 2]
# With mean subtraction — DC removed but gradient leaks
r_mean, = node.process(field, windowing="none", level="mean", output="magnitude")
dc_mean = r_mean.data[N // 2, N // 2]
# With plane subtraction — gradient removed
r_plane, = node.process(field, windowing="none", level="plane", output="magnitude")
dc_plane = r_plane.data[N // 2, N // 2]
# With plane subtraction, check the low-freq energy near DC is reduced
# (plane subtraction removes gradients that leak into low frequencies)
r = 3 # radius around DC to check
cy, cx = N // 2, N // 2
lowfreq_none = r_none.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
lowfreq_plane = r_plane.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
print(f" DC magnitude (no leveling): {dc_none:.2e}")
print(f" DC magnitude (mean subtract): {dc_mean:.2e}")
print(f" DC magnitude (plane subtract): {dc_plane:.2e}")
print(f" Low-freq energy (no level): {lowfreq_none:.2e}")
print(f" Low-freq energy (plane sub): {lowfreq_plane:.2e}")
assert dc_mean < dc_none, "Mean subtraction should reduce DC"
assert lowfreq_plane < lowfreq_none * 0.01, "Plane subtraction should reduce low-freq energy"
print(" PASS\n")
def test_non_square():
"""FFT should work on non-square, non-power-of-2 images."""
print("=== Test: Non-square image ===")
data = np.random.default_rng(99).standard_normal((100, 150))
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
node = FFT2D()
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
print(f" Shape: {result.data.shape}")
print(f" Output range: [{result.data.min():.4f}, {result.data.max():.4f}]")
print(" PASS\n")
def test_log_magnitude_visual_range():
"""Log magnitude should produce a reasonable dynamic range for display."""
print("=== Test: Log magnitude visual range ===")
N = 128
x = np.linspace(0, 1, N, endpoint=False)
# Multi-frequency test image
y, x = np.mgrid[0:N, 0:N] / N
data = (np.sin(2 * np.pi * 5 * x) +
0.5 * np.sin(2 * np.pi * 15 * x + 2 * np.pi * 10 * y) +
0.1 * np.random.default_rng(7).standard_normal((N, N)))
field = make_field(data)
node = FFT2D()
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
vmin, vmax = result.data.min(), result.data.max()
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
print(f" Log magnitude range: [{vmin:.4f}, {vmax:.4f}]")
print(f" Dynamic range: {dynamic_range:.2f}")
assert vmax > vmin, "Log magnitude should have nonzero range"
assert np.all(np.isfinite(result.data)), "Non-finite values in log magnitude"
print(" PASS\n")
if __name__ == "__main__":
test_dc_removal()
test_single_frequency()
test_2d_frequency()
test_psdf_normalization()
test_windowing_reduces_leakage()
test_plane_subtraction()
test_non_square()
test_log_magnitude_visual_range()
print("All tests passed!")

97
tests/test_fft_visual.py Normal file
View File

@@ -0,0 +1,97 @@
"""
Generate test images and their FFT outputs for visual comparison with Gwyddion.
Saves PNG files to tests/output/.
Run: .venv/bin/python -m tests.test_fft_visual
"""
import sys
import os
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField, datafield_to_uint8, encode_preview
from backend.nodes.analysis import FFT2D
OUT_DIR = os.path.join(os.path.dirname(__file__), "output")
os.makedirs(OUT_DIR, exist_ok=True)
def save_field(field, name, colormap="viridis"):
"""Save a DataField as a PNG for visual inspection."""
from PIL import Image
arr = datafield_to_uint8(field, colormap)
img = Image.fromarray(arr)
path = os.path.join(OUT_DIR, f"{name}.png")
img.save(path)
print(f" Saved {path} (range: [{field.data.min():.4g}, {field.data.max():.4g}])")
def make_field(data, xreal=1e-6, yreal=1e-6):
return DataField(data=data, xreal=xreal, yreal=yreal)
def main():
node = FFT2D()
N = 256
# --- Test 1: Multi-frequency sine waves ---
print("Test 1: Multi-frequency sine waves")
y, x = np.mgrid[0:N, 0:N] / N
data = (np.sin(2 * np.pi * 10 * x)
+ 0.7 * np.sin(2 * np.pi * 25 * y)
+ 0.3 * np.sin(2 * np.pi * (15 * x + 8 * y)))
field = make_field(data)
save_field(field, "01_sines_input")
for output_mode in ["log_magnitude", "magnitude", "psdf"]:
result, = node.process(field, windowing="hann", level="mean", output=output_mode)
save_field(result, f"01_sines_{output_mode}")
# --- Test 2: Real-world-like surface with noise + tilt ---
print("\nTest 2: Tilted surface with features")
rng = np.random.default_rng(42)
data = (50 * x + 30 * y # tilt
+ np.sin(2 * np.pi * 20 * x) # periodic feature
+ 0.5 * rng.standard_normal((N, N))) # noise
field = make_field(data)
save_field(field, "02_surface_input")
for level_mode in ["none", "mean", "plane"]:
result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude")
save_field(result, f"02_surface_fft_level_{level_mode}")
# --- Test 3: Checkerboard pattern ---
print("\nTest 3: Checkerboard")
freq = 16
data = np.sign(np.sin(2 * np.pi * freq * x) * np.sin(2 * np.pi * freq * y))
field = make_field(data)
save_field(field, "03_checker_input")
result, = node.process(field, windowing="none", level="mean", output="log_magnitude")
save_field(result, "03_checker_fft")
# --- Test 4: Concentric rings (radial frequency) ---
print("\nTest 4: Concentric rings")
r = np.sqrt((x - 0.5)**2 + (y - 0.5)**2)
data = np.sin(2 * np.pi * 30 * r)
field = make_field(data)
save_field(field, "04_rings_input")
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
save_field(result, "04_rings_fft")
# --- Test 5: Compare windowing effects ---
print("\nTest 5: Windowing comparison")
data = np.sin(2 * np.pi * 10.5 * x) + 0.5 * np.sin(2 * np.pi * 30.3 * y)
field = make_field(data)
save_field(field, "05_window_input")
for win in ["none", "hann", "hamming", "blackman"]:
result, = node.process(field, windowing=win, level="mean", output="log_magnitude")
save_field(result, f"05_window_{win}")
print(f"\nAll outputs saved to {OUT_DIR}/")
if __name__ == "__main__":
main()

488
tests/test_nodes.py Normal file
View File

@@ -0,0 +1,488 @@
"""
Tests for all argonode backend nodes (excluding FFT2D which has its own test file).
Run from project root:
.venv/bin/python -m tests.test_nodes
"""
import sys
import os
import tempfile
import numpy as np
sys.path.insert(0, ".")
from backend.data_types import DataField
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
"""Create a DataField, optionally from given data or a random field."""
if data is None:
data = np.random.default_rng(42).standard_normal(shape)
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
# =========================================================================
# Filters
# =========================================================================
def test_gaussian_filter():
print("=== Test: GaussianFilter ===")
from backend.nodes.filters import GaussianFilter
node = GaussianFilter()
field = make_field()
result, = node.process(field, sigma=2.0)
assert result.data.shape == field.data.shape
assert result.xreal == field.xreal
assert result.si_unit_z == field.si_unit_z
# Gaussian blur should reduce variance
assert result.data.std() < field.data.std()
# With very small sigma, output should be nearly unchanged
result_tiny, = node.process(field, sigma=0.01)
assert np.allclose(result_tiny.data, field.data, atol=1e-6)
print(" PASS\n")
def test_median_filter():
print("=== Test: MedianFilter ===")
from backend.nodes.filters import MedianFilter
node = MedianFilter()
# Median filter should remove salt-and-pepper noise
data = np.zeros((64, 64))
rng = np.random.default_rng(7)
noise_idx = rng.choice(64 * 64, size=100, replace=False)
data.ravel()[noise_idx] = 1.0
field = make_field(data=data)
result, = node.process(field, size=3)
assert result.data.shape == field.data.shape
# Should remove most impulse noise
assert result.data.sum() < field.data.sum()
# Size=1 should be identity
result_1, = node.process(field, size=1)
assert np.array_equal(result_1.data, field.data)
print(" PASS\n")
def test_edge_detect():
print("=== Test: EdgeDetect ===")
from backend.nodes.filters import EdgeDetect
node = EdgeDetect()
# Create an image with a sharp vertical edge
data = np.zeros((64, 64))
data[:, 32:] = 1.0
field = make_field(data=data)
for method in ["sobel", "prewitt", "laplacian", "log"]:
result, = node.process(field, method=method, sigma=1.0)
assert result.data.shape == field.data.shape
# Edge response should be strongest near column 32
col_energy = np.abs(result.data).sum(axis=0)
peak_col = np.argmax(col_energy)
assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32"
print(" PASS\n")
# =========================================================================
# Level
# =========================================================================
def test_plane_level():
print("=== Test: PlaneLevelField ===")
from backend.nodes.level import PlaneLevelField
node = PlaneLevelField()
# Create a tilted plane + small signal
N = 64
y, x = np.mgrid[0:N, 0:N] / N
signal = np.sin(2 * np.pi * 5 * x)
data = 100 * x + 50 * y + signal
field = make_field(data=data)
result, = node.process(field)
assert result.data.shape == field.data.shape
# After plane leveling, mean should be near zero
assert abs(result.data.mean()) < 1e-10
# The signal should remain (correlation with original sine)
corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1]
assert corr > 0.98, f"Signal correlation after leveling: {corr}"
print(" PASS\n")
def test_poly_level():
print("=== Test: PolyLevelField ===")
from backend.nodes.level import PolyLevelField
node = PolyLevelField()
N = 64
y, x = np.mgrid[0:N, 0:N] / N
# Quadratic background + signal
background = 50 * x**2 + 30 * y**2 + 10 * x * y
signal = np.sin(2 * np.pi * 8 * x)
data = background + signal
field = make_field(data=data)
leveled, bg = node.process(field, degree_x=2, degree_y=2)
assert leveled.data.shape == field.data.shape
assert bg.data.shape == field.data.shape
# leveled + bg should reconstruct original
assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10)
# Signal should be preserved after leveling
corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1]
assert corr > 0.95, f"Signal correlation after poly leveling: {corr}"
# Degree 0 should just subtract the mean
leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0)
assert abs(leveled_0.data.mean()) < 1e-10
print(" PASS\n")
def test_fix_zero():
print("=== Test: FixZero ===")
from backend.nodes.level import FixZero
node = FixZero()
field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))
result_min, = node.process(field, method="min")
assert result_min.data.min() == 0.0
assert result_min.data.max() == 30.0
result_mean, = node.process(field, method="mean")
assert abs(result_mean.data.mean()) < 1e-10
result_median, = node.process(field, method="median")
assert abs(np.median(result_median.data)) < 1e-10
print(" PASS\n")
# =========================================================================
# Analysis (non-FFT)
# =========================================================================
def test_statistics():
print("=== Test: StatisticsNode ===")
from backend.nodes.analysis import StatisticsNode
node = StatisticsNode()
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
field = make_field(data=data)
table, = node.process(field)
stats = {row["quantity"]: row["value"] for row in table}
assert stats["min"] == 1.0
assert stats["max"] == 4.0
assert stats["mean"] == 2.5
assert stats["median"] == 2.5
assert stats["range"] == 3.0
# RMS = sqrt(mean((x - mean)^2))
expected_rms = np.sqrt(np.mean((data - 2.5) ** 2))
assert abs(stats["RMS"] - expected_rms) < 1e-10
# Constant data should have RMS=0, skewness=0, kurtosis=0
const_field = make_field(data=np.ones((4, 4)) * 5.0)
table_const, = node.process(const_field)
const_stats = {row["quantity"]: row["value"] for row in table_const}
assert const_stats["RMS"] == 0.0
assert const_stats["skewness"] == 0.0
assert const_stats["kurtosis"] == 0.0
print(" PASS\n")
def test_height_histogram():
print("=== Test: HeightHistogram ===")
from backend.nodes.analysis import HeightHistogram
node = HeightHistogram()
# Uniform data should give a roughly flat histogram
data = np.linspace(0, 1, 1000).reshape(25, 40)
field = make_field(data=data)
counts, bin_centers = node.process(field, n_bins=10)
assert len(counts) == 10
assert len(bin_centers) == 10
assert counts.dtype == np.float64
# Total counts should equal number of pixels
assert counts.sum() == 1000
# For uniform data, each bin should have ~100 counts
assert np.std(counts) < 10, f"Histogram not flat enough: std={np.std(counts)}"
# Bin centers should span the data range
assert bin_centers[0] > 0.0
assert bin_centers[-1] < 1.0
print(" PASS\n")
def test_cross_section():
print("=== Test: CrossSection ===")
from backend.nodes.analysis import CrossSection
node = CrossSection()
# Create a field with a known horizontal gradient
N = 100
y, x = np.mgrid[0:N, 0:N] / N
data = x * 10.0 # value = 10 * x_fraction
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# Horizontal cross section at y=0.5
(profile,) = node.process(
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
extend="none", n_samples=100,
)
assert len(profile) == 100
# Profile should be a linear ramp from ~0 to ~10
assert profile[0] < 0.5, f"Start of profile: {profile[0]}"
assert profile[-1] > 9.5, f"End of profile: {profile[-1]}"
# n_samples=0 should auto-calculate
(profile_auto,) = node.process(
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
extend="none", n_samples=0,
)
assert len(profile_auto) >= 2
# Test extend to edges — a short segment should be extended
(profile_ext,) = node.process(
field, x1=0.3, y1=0.5, x2=0.7, y2=0.5,
extend="to_edges", n_samples=100,
)
# Extended profile should start near 0 and end near 10
assert profile_ext[0] < 0.5
assert profile_ext[-1] > 9.5
# Diagonal cross section
(profile_diag,) = node.process(
field, x1=0.0, y1=0.0, x2=1.0, y2=1.0,
extend="none", n_samples=50,
)
assert len(profile_diag) == 50
print(" PASS\n")
# =========================================================================
# Grains
# =========================================================================
def test_threshold_mask():
print("=== Test: ThresholdMask ===")
from backend.nodes.grains import ThresholdMask
node = ThresholdMask()
# Clear bimodal data: left half = 0, right half = 1
data = np.zeros((64, 64))
data[:, 32:] = 1.0
field = make_field(data=data)
# Absolute threshold at 0.5
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
assert mask.dtype == np.uint8
assert mask.shape == (64, 64)
assert np.all(mask[:, :32] == 0)
assert np.all(mask[:, 32:] == 255)
# Direction "below"
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
assert np.all(mask_below[:, :32] == 255)
assert np.all(mask_below[:, 32:] == 0)
# Relative threshold at 0.5 (midpoint of range)
mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
assert np.all(mask_rel[:, 32:] == 255)
# Otsu should find the bimodal threshold
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
print(" PASS\n")
def test_grain_analysis():
print("=== Test: GrainAnalysis ===")
from backend.nodes.grains import GrainAnalysis
node = GrainAnalysis()
# Create a field with two distinct "grains"
N = 64
data = np.zeros((N, N))
# Grain 1: 10x10 block at top-left with height 5
data[5:15, 5:15] = 5.0
# Grain 2: 8x8 block at bottom-right with height 3
data[45:53, 45:53] = 3.0
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# Create matching mask
mask = np.zeros((N, N), dtype=np.uint8)
mask[5:15, 5:15] = 255
mask[45:53, 45:53] = 255
table, = node.process(field, mask=mask, min_size=10)
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
# Sort by area descending
table.sort(key=lambda r: r["area_px"], reverse=True)
assert table[0]["area_px"] == 100 # 10x10
assert table[1]["area_px"] == 64 # 8x8
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
# min_size filtering: only keep grains >= 80 px
table_filtered, = node.process(field, mask=mask, min_size=80)
assert len(table_filtered) == 1
assert table_filtered[0]["area_px"] == 100
print(" PASS\n")
# =========================================================================
# I/O
# =========================================================================
def test_load_image():
print("=== Test: LoadImage ===")
from backend.nodes.io import LoadImage
from PIL import Image
node = LoadImage()
with tempfile.TemporaryDirectory() as tmpdir:
# Test loading a grayscale PNG
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
img = Image.fromarray(arr, mode="L")
path = os.path.join(tmpdir, "test_gray.png")
img.save(path)
image, field = node.load(filename=path)
assert image.shape == (48, 64)
assert field.data.shape == (48, 64)
assert field.data.dtype == np.float64
# Test loading an RGB PNG (should average to grayscale for field)
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
path_rgb = os.path.join(tmpdir, "test_rgb.png")
img_rgb.save(path_rgb)
image_rgb, field_rgb = node.load(filename=path_rgb)
assert image_rgb.shape == (32, 32, 3)
assert field_rgb.data.shape == (32, 32)
# Test loading a .npy file
data_npy = np.random.default_rng(3).standard_normal((50, 60))
path_npy = os.path.join(tmpdir, "test.npy")
np.save(path_npy, data_npy)
image_npy, field_npy = node.load(filename=path_npy)
assert np.allclose(field_npy.data, data_npy)
print(" PASS\n")
def test_save_image():
print("=== Test: SaveImage ===")
from backend.nodes.io import SaveImage
node = SaveImage()
with tempfile.TemporaryDirectory() as tmpdir:
# Monkey-patch OUTPUT_DIR for testing
from pathlib import Path
import backend.nodes.io as io_mod
orig_dir = io_mod.OUTPUT_DIR
io_mod.OUTPUT_DIR = Path(tmpdir)
try:
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)
# Save as PNG
node.save(image=arr, filename_prefix="test", format="PNG")
saved = os.listdir(tmpdir)
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}"
# Save as NPY
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
saved = os.listdir(tmpdir)
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
finally:
io_mod.OUTPUT_DIR = orig_dir
print(" PASS\n")
# =========================================================================
# Display (limited testing — these are output nodes with WS callbacks)
# =========================================================================
def test_preview_image():
print("=== Test: PreviewImage ===")
from backend.nodes.display import PreviewImage
node = PreviewImage()
# Set up a capture for the broadcast
captured = []
PreviewImage._broadcast_fn = lambda node_id, data_uri: captured.append(data_uri)
PreviewImage._current_node_id = "test"
# Preview with a DataField
field = make_field()
node.preview(colormap="viridis", field=field)
assert len(captured) == 1
assert captured[0].startswith("data:image/png;base64,")
# Preview with an IMAGE array
captured.clear()
arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
node.preview(colormap="gray", image=arr)
assert len(captured) == 1
# Clean up
PreviewImage._broadcast_fn = None
print(" PASS\n")
def test_print_table():
print("=== Test: PrintTable ===")
from backend.nodes.display import PrintTable
node = PrintTable()
captured = []
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
PrintTable._current_node_id = "test"
table = [{"quantity": "test", "value": 42.0, "unit": "m"}]
node.print_table(table=table)
assert len(captured) == 1
assert captured[0] == table
PrintTable._broadcast_table_fn = None
print(" PASS\n")
# =========================================================================
# Run all tests
# =========================================================================
if __name__ == "__main__":
# Filters
test_gaussian_filter()
test_median_filter()
test_edge_detect()
# Level
test_plane_level()
test_poly_level()
test_fix_zero()
# Analysis
test_statistics()
test_height_histogram()
test_cross_section()
# Grains
test_threshold_mask()
test_grain_analysis()
# I/O
test_load_image()
test_save_image()
# Display
test_preview_image()
test_print_table()
print("All tests passed!")