initial commit
2
.gitignore
vendored
@@ -1 +1,3 @@
|
|||||||
*__pycache__*
|
*__pycache__*
|
||||||
|
frontend/node_modules/
|
||||||
|
frontend/dist/
|
||||||
87
GWYDDION_FEATURE_GAP.md
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
# Gwyddion Features Not Yet in Argonode
|
||||||
|
|
||||||
|
Reference for future implementation. Grouped by value to typical SPM workflows.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## High Value
|
||||||
|
|
||||||
|
| # | Feature | Gwyddion Source | Description |
|
||||||
|
|---|---------|---------------|-------------|
|
||||||
|
| 1 | Line Correction | linecorrect.c, linematch.c | Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts. |
|
||||||
|
| 2 | Scar Removal | scars.c | Detect and interpolate scan-line defects (horizontal streaks). |
|
||||||
|
| 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. |
|
||||||
|
| 4 | Morphological Mask Ops | mask_morph.c | Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks. |
|
||||||
|
| 5 | 1D FFT Filter | fft_filter_1d.c | Bandpass/lowpass/highpass filtering of LINE profiles. |
|
||||||
|
| 6 | 2D FFT Filter | fft_filter_2d.c | Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.). |
|
||||||
|
| 7 | Autocorrelation (ACF) | acf2d.c | 2D autocorrelation function. Reveals periodic structures and correlation lengths. |
|
||||||
|
| 8 | PSDF | psdf2d.c | Radial/2D power spectral density function. Complementary to ACF for roughness characterization. |
|
||||||
|
| 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. |
|
||||||
|
| 10 | Curvature | curvature.c | Local mean/Gaussian curvature maps. Useful for feature identification. |
|
||||||
|
| 11 | Grain Distance Transform | mask_edt.c | Euclidean distance from grain boundaries. Useful for spatial distribution analysis. |
|
||||||
|
| 12 | Watershed Segmentation | grain_wshed.c | Automatic grain detection without manual threshold. More robust than simple thresholding. |
|
||||||
|
| 13 | Rotate / Flip | rotate.c, basicops.c | Basic geometric transforms (90°, arbitrary angle, mirror). |
|
||||||
|
| 14 | Crop | crop.c | Extract sub-region of a field. |
|
||||||
|
|
||||||
|
## Medium Value
|
||||||
|
|
||||||
|
| # | Feature | Gwyddion Source | Description |
|
||||||
|
|---|---------|---------------|-------------|
|
||||||
|
| 15 | Correlation / Pattern Matching | crosscor.c, maskcor.c | Find repeated features or align images via cross-correlation. |
|
||||||
|
| 16 | Slope Distribution | slope_dist.c | Angular histogram of surface slopes. Characterizes surface texture directionality. |
|
||||||
|
| 17 | Grain Filtering | grain_filter.c | Remove grains by size, height, or border contact. Refine grain masks post-detection. |
|
||||||
|
| 18 | Field Arithmetic | arithmetic.c | Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization. |
|
||||||
|
| 19 | Spot Removal | spotremove.c | Interpolate over selected point defects (dust, spikes). |
|
||||||
|
| 20 | Tip Modeling / Deconvolution | tip_blind.c, tip_model.c | Estimate tip shape from image, deconvolve to recover true surface. |
|
||||||
|
| 21 | Radial Profile | rprofile tool | Azimuthally averaged profile from a center point. Good for circular features. |
|
||||||
|
| 22 | Wavelet Transform | dwt.c, cwt.c | Discrete/continuous wavelet analysis. Multi-scale roughness decomposition. |
|
||||||
|
| 23 | Scale / Resample | scale.c, resample.c | Resize fields with interpolation. |
|
||||||
|
| 24 | Gradient | gradient.c | Compute x/y gradient magnitude maps. |
|
||||||
|
| 25 | Custom Convolution | convolution_filter.c | User-defined kernel convolution. |
|
||||||
|
| 26 | Local Contrast Enhancement | local_contrast.c | Enhance visibility of local features in images. |
|
||||||
|
|
||||||
|
## Lower Priority
|
||||||
|
|
||||||
|
| # | Feature | Gwyddion Source | Description |
|
||||||
|
|---|---------|---------------|-------------|
|
||||||
|
| 27 | Drift Correction | drift.c | Compensate for thermal/piezo drift between scan lines. |
|
||||||
|
| 28 | Affine / Perspective Correction | correct_affine.c, correct_perspective.c | Fix geometric distortions from scanner nonlinearity. |
|
||||||
|
| 29 | MFM Analysis | mfm_*.c | Magnetic force microscopy: field calculation, shift finding. |
|
||||||
|
| 30 | Lattice Measurement | measure_lattice.c | Detect and measure periodic lattice structures from ACF/FFT. |
|
||||||
|
| 31 | Hough Transform | hough.c | Detect lines and circles in images. |
|
||||||
|
| 32 | Image Stitching / Merging | merge.c, stitch.c | Combine multiple overlapping scans into one image. |
|
||||||
|
| 33 | Facet Analysis | facet_analysis.c | Orientation distribution of surface facets (stereographic projection). |
|
||||||
|
| 34 | Shape Fitting | fit-shape.c | Fit geometric primitives: sphere, paraboloid, cylinder, etc. |
|
||||||
|
| 35 | Synthetic Surface Generation | *_synth.c (~20 modules) | Generate test surfaces: FBM, noise, lattice, waves, particles, fibers, etc. |
|
||||||
|
| 36 | Entropy | entropy.c | Information entropy of height distribution. |
|
||||||
|
| 37 | Indentation Analysis | indent_analyze.c, hertz.c | Nanoindentation curve fitting (Hertz model). |
|
||||||
|
| 38 | Deconvolution | deconvolve.c | Blind/regularized deconvolution for image restoration. |
|
||||||
|
| 39 | Canny / Harris Detection | filters.c | Corner and edge feature detection beyond basic Sobel/Prewitt. |
|
||||||
|
| 40 | Kuwahara Filter | filters.c | Edge-preserving smoothing filter. |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Already Implemented in Argonode
|
||||||
|
|
||||||
|
For reference, these Gwyddion equivalents are already covered:
|
||||||
|
|
||||||
|
| Argonode Node | Category | Gwyddion Equivalent |
|
||||||
|
|--------------|----------|-------------------|
|
||||||
|
| Load Image / Load SPM File | io | File import (gwy, sxm, ibw) |
|
||||||
|
| Save Image | io | File export |
|
||||||
|
| Coordinate | io | — |
|
||||||
|
| Plane Level | level | level.c |
|
||||||
|
| Polynomial Level | level | polylevel.c |
|
||||||
|
| Fix Zero | level | level.c (fix_zero) |
|
||||||
|
| Gaussian Filter | filters | filters.c (gaussian) |
|
||||||
|
| Median Filter | filters | filters.c (median) |
|
||||||
|
| Edge Detect | filters | edge.c (sobel, prewitt, laplacian, LoG) |
|
||||||
|
| Statistics | analysis | stats.c |
|
||||||
|
| Height Histogram | analysis | linestats.c (dh) |
|
||||||
|
| 2D FFT | analysis | fft.c |
|
||||||
|
| Cross Section | analysis | profile tool |
|
||||||
|
| Profile Roughness | analysis | roughness.c (Ra, Rq, Rsk, Rku, Rp, Rv, Rt) |
|
||||||
|
| Line Math | analysis | linestats.c |
|
||||||
|
| Threshold Mask | grains | threshold.c, otsu_threshold.c |
|
||||||
|
| Grain Analysis | grains | grain_stat.c |
|
||||||
|
| Preview / 3D View / Print Table | display | Presentation, 3D view |
|
||||||
0
backend/__init__.py
Normal file
134
backend/data_types.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
Core data types for argonode.
|
||||||
|
|
||||||
|
DataField mirrors Gwyddion's GwyDataField structure:
|
||||||
|
xres, yres – pixel dimensions
|
||||||
|
xreal, yreal – physical dimensions in metres
|
||||||
|
xoff, yoff – position offset in metres
|
||||||
|
si_unit_xy – lateral unit string (e.g. "m", "nm")
|
||||||
|
si_unit_z – height/value unit string (e.g. "m", "V", "A")
|
||||||
|
domain – "spatial" or "frequency" (set by FFT nodes)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataField:
|
||||||
|
data: np.ndarray # shape (yres, xres), dtype float64
|
||||||
|
xres: int = 0
|
||||||
|
yres: int = 0
|
||||||
|
xreal: float = 1e-6 # physical width in metres
|
||||||
|
yreal: float = 1e-6 # physical height in metres
|
||||||
|
xoff: float = 0.0
|
||||||
|
yoff: float = 0.0
|
||||||
|
si_unit_xy: str = "m"
|
||||||
|
si_unit_z: str = "m"
|
||||||
|
domain: str = "spatial" # "spatial" or "frequency"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.data = np.asarray(self.data, dtype=np.float64)
|
||||||
|
if self.data.ndim != 2:
|
||||||
|
raise ValueError(f"DataField.data must be 2-D, got shape {self.data.shape}")
|
||||||
|
self.yres, self.xres = self.data.shape
|
||||||
|
|
||||||
|
def copy(self) -> "DataField":
|
||||||
|
"""Return a deep copy with independent data array."""
|
||||||
|
return DataField(
|
||||||
|
data=self.data.copy(),
|
||||||
|
xres=self.xres,
|
||||||
|
yres=self.yres,
|
||||||
|
xreal=self.xreal,
|
||||||
|
yreal=self.yreal,
|
||||||
|
xoff=self.xoff,
|
||||||
|
yoff=self.yoff,
|
||||||
|
si_unit_xy=self.si_unit_xy,
|
||||||
|
si_unit_z=self.si_unit_z,
|
||||||
|
domain=self.domain,
|
||||||
|
)
|
||||||
|
|
||||||
|
def replace(self, **kwargs) -> "DataField":
|
||||||
|
"""Return a copy with selected fields replaced. data is deep-copied unless provided."""
|
||||||
|
base = {
|
||||||
|
"data": self.data.copy(),
|
||||||
|
"xres": self.xres,
|
||||||
|
"yres": self.yres,
|
||||||
|
"xreal": self.xreal,
|
||||||
|
"yreal": self.yreal,
|
||||||
|
"xoff": self.xoff,
|
||||||
|
"yoff": self.yoff,
|
||||||
|
"si_unit_xy": self.si_unit_xy,
|
||||||
|
"si_unit_z": self.si_unit_z,
|
||||||
|
"domain": self.domain,
|
||||||
|
}
|
||||||
|
base.update(kwargs)
|
||||||
|
return DataField(**base)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dx(self) -> float:
|
||||||
|
"""Physical pixel size in x (metres)."""
|
||||||
|
return self.xreal / self.xres if self.xres else 1.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dy(self) -> float:
|
||||||
|
"""Physical pixel size in y (metres)."""
|
||||||
|
return self.yreal / self.yres if self.yres else 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Utility helpers shared across nodes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
|
||||||
|
Returns shape (H, W, 3) uint8.
|
||||||
|
"""
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
|
||||||
|
data = df.data
|
||||||
|
dmin, dmax = data.min(), data.max()
|
||||||
|
if dmax > dmin:
|
||||||
|
normalized = (data - dmin) / (dmax - dmin)
|
||||||
|
else:
|
||||||
|
normalized = np.zeros_like(data)
|
||||||
|
|
||||||
|
cmap = cm.get_cmap(colormap)
|
||||||
|
rgba = cmap(normalized) # (H, W, 4) float [0,1]
|
||||||
|
rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||||
|
return rgb
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_uint8(image: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert an IMAGE (float or uint8, 2-D or 3-D) to uint8 (H,W,3) or (H,W) for PIL.
|
||||||
|
"""
|
||||||
|
if image.dtype == np.uint8:
|
||||||
|
return image
|
||||||
|
# float — normalize to [0, 255]
|
||||||
|
imin, imax = image.min(), image.max()
|
||||||
|
if imax > imin:
|
||||||
|
out = (image - imin) / (imax - imin) * 255.0
|
||||||
|
else:
|
||||||
|
out = np.zeros_like(image)
|
||||||
|
return out.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_preview(arr: np.ndarray) -> str:
|
||||||
|
"""
|
||||||
|
Encode a uint8 numpy array as a base64 data URI (PNG).
|
||||||
|
arr: (H, W) grayscale or (H, W, 3) RGB, uint8.
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
img = Image.fromarray(arr)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format="PNG")
|
||||||
|
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||||
|
return f"data:image/png;base64,{b64}"
|
||||||
294
backend/execution.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""
|
||||||
|
Graph execution engine for argonode.
|
||||||
|
|
||||||
|
Prompt format (same as ComfyUI):
|
||||||
|
{
|
||||||
|
"node_id": {
|
||||||
|
"class_type": "GaussianFilter",
|
||||||
|
"inputs": {
|
||||||
|
"field": ["upstream_node_id", 0], # link: [src_id, output_slot]
|
||||||
|
"sigma": 2.0 # constant widget value
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
The engine:
|
||||||
|
1. Topologically sorts nodes (Kahn's algorithm).
|
||||||
|
2. Resolves input links to actual Python objects from earlier outputs.
|
||||||
|
3. Calls each node's FUNCTION method.
|
||||||
|
4. Emits progress callbacks after each node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from backend.node_registry import NODE_CLASS_MAPPINGS
|
||||||
|
|
||||||
|
|
||||||
|
def _is_link(value: Any) -> bool:
|
||||||
|
"""A value is a link if it's a [node_id_str, slot_int] pair."""
|
||||||
|
return (
|
||||||
|
isinstance(value, (list, tuple))
|
||||||
|
and len(value) == 2
|
||||||
|
and isinstance(value[0], str)
|
||||||
|
and isinstance(value[1], int)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionEngine:
|
||||||
|
"""Synchronous (blocking) graph executor. Run inside a thread pool from async code."""
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
prompt: dict[str, dict],
|
||||||
|
on_node_start: Callable[[str], None] | None = None,
|
||||||
|
on_node_done: Callable[[str], None] | None = None,
|
||||||
|
on_preview: Callable[[str, str], None] | None = None,
|
||||||
|
on_table: Callable[[str, list], None] | None = None,
|
||||||
|
on_mesh: Callable[[str, dict], None] | None = None,
|
||||||
|
on_overlay: Callable[[str, str], None] | None = None,
|
||||||
|
) -> dict[str, tuple]:
|
||||||
|
"""
|
||||||
|
Execute the workflow described by `prompt`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
prompt : workflow dict (node_id → {class_type, inputs})
|
||||||
|
on_node_start : called with node_id just before a node executes
|
||||||
|
on_node_done : called with node_id just after a node executes
|
||||||
|
on_preview : called with (node_id, data_uri) when a display node runs
|
||||||
|
on_table : called with (node_id, table_list) when PrintTable runs
|
||||||
|
on_overlay : called with (node_id, data_uri) for interactive overlays
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
node_outputs : {node_id → tuple-of-outputs} for every executed node
|
||||||
|
"""
|
||||||
|
order = self._topological_sort(prompt)
|
||||||
|
node_outputs: dict[str, tuple] = {}
|
||||||
|
|
||||||
|
# Inject display callbacks before execution
|
||||||
|
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
|
||||||
|
|
||||||
|
for node_id in order:
|
||||||
|
node_def = prompt[node_id]
|
||||||
|
class_name = node_def["class_type"]
|
||||||
|
|
||||||
|
if class_name not in NODE_CLASS_MAPPINGS:
|
||||||
|
raise ValueError(f"Unknown node type: '{class_name}'")
|
||||||
|
|
||||||
|
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||||
|
raw_inputs = node_def.get("inputs", {})
|
||||||
|
inputs = self._resolve_inputs(raw_inputs, node_outputs)
|
||||||
|
|
||||||
|
# Let display nodes know their node_id so they can tag WS messages
|
||||||
|
self._set_node_id_on_display(cls, node_id)
|
||||||
|
|
||||||
|
if on_node_start:
|
||||||
|
on_node_start(node_id)
|
||||||
|
|
||||||
|
instance = cls()
|
||||||
|
func = getattr(instance, cls.FUNCTION)
|
||||||
|
result = func(**inputs)
|
||||||
|
|
||||||
|
# Nodes must return a tuple; coerce single values just in case
|
||||||
|
if not isinstance(result, tuple):
|
||||||
|
result = (result,)
|
||||||
|
|
||||||
|
node_outputs[node_id] = result
|
||||||
|
|
||||||
|
# Auto-preview: broadcast a thumbnail for any DATA_FIELD,
|
||||||
|
# IMAGE, or TABLE output so every node shows its result.
|
||||||
|
if on_preview or on_table:
|
||||||
|
self._auto_preview(cls, node_id, result, on_preview, on_table)
|
||||||
|
|
||||||
|
if on_node_done:
|
||||||
|
on_node_done(node_id)
|
||||||
|
|
||||||
|
return node_outputs
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Private helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _topological_sort(self, prompt: dict) -> list[str]:
|
||||||
|
"""Kahn's algorithm — returns node IDs in dependency order."""
|
||||||
|
in_degree: dict[str, int] = {nid: 0 for nid in prompt}
|
||||||
|
dependents: dict[str, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
|
for node_id, node_def in prompt.items():
|
||||||
|
for value in node_def.get("inputs", {}).values():
|
||||||
|
if _is_link(value):
|
||||||
|
src_id = value[0]
|
||||||
|
if src_id in prompt:
|
||||||
|
in_degree[node_id] += 1
|
||||||
|
dependents[src_id].append(node_id)
|
||||||
|
|
||||||
|
queue: deque[str] = deque(nid for nid, deg in in_degree.items() if deg == 0)
|
||||||
|
order: list[str] = []
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
nid = queue.popleft()
|
||||||
|
order.append(nid)
|
||||||
|
for dep in dependents[nid]:
|
||||||
|
in_degree[dep] -= 1
|
||||||
|
if in_degree[dep] == 0:
|
||||||
|
queue.append(dep)
|
||||||
|
|
||||||
|
if len(order) != len(prompt):
|
||||||
|
raise ValueError("Cycle detected in workflow graph — cannot execute.")
|
||||||
|
|
||||||
|
return order
|
||||||
|
|
||||||
|
def _resolve_inputs(
|
||||||
|
self,
|
||||||
|
raw_inputs: dict[str, Any],
|
||||||
|
node_outputs: dict[str, tuple],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Replace [src_id, slot] links with actual output values."""
|
||||||
|
resolved = {}
|
||||||
|
for key, value in raw_inputs.items():
|
||||||
|
if _is_link(value):
|
||||||
|
src_id, slot = value[0], int(value[1])
|
||||||
|
if src_id not in node_outputs:
|
||||||
|
raise KeyError(
|
||||||
|
f"Node '{src_id}' has no output yet — dependency ordering bug?"
|
||||||
|
)
|
||||||
|
outputs = node_outputs[src_id]
|
||||||
|
if slot >= len(outputs):
|
||||||
|
raise IndexError(
|
||||||
|
f"Node '{src_id}' only has {len(outputs)} outputs, "
|
||||||
|
f"but slot {slot} was requested."
|
||||||
|
)
|
||||||
|
resolved[key] = outputs[slot]
|
||||||
|
else:
|
||||||
|
resolved[key] = value
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
def _inject_display_callbacks(
|
||||||
|
self,
|
||||||
|
on_preview: Callable | None,
|
||||||
|
on_table: Callable | None,
|
||||||
|
on_mesh: Callable | None = None,
|
||||||
|
on_overlay: Callable | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Wire up broadcast callbacks on display node classes."""
|
||||||
|
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||||
|
from backend.nodes.analysis import CrossSection
|
||||||
|
from backend.nodes.io import SaveImage
|
||||||
|
|
||||||
|
PreviewImage._broadcast_fn = on_preview
|
||||||
|
View3D._broadcast_mesh_fn = on_mesh
|
||||||
|
PrintTable._broadcast_table_fn = on_table
|
||||||
|
CrossSection._broadcast_overlay_fn = on_overlay
|
||||||
|
SaveImage._broadcast_preview = (
|
||||||
|
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||||
|
"""Inform display nodes of their current node_id for WS tagging."""
|
||||||
|
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||||
|
from backend.nodes.analysis import CrossSection
|
||||||
|
if cls in (PreviewImage, PrintTable, View3D, CrossSection):
|
||||||
|
cls._current_node_id = node_id
|
||||||
|
|
||||||
|
def _auto_preview(
|
||||||
|
self,
|
||||||
|
cls: type,
|
||||||
|
node_id: str,
|
||||||
|
result: tuple,
|
||||||
|
on_preview: Callable | None,
|
||||||
|
on_table: Callable | None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
After every node executes, inspect its outputs and broadcast
|
||||||
|
a preview for the first DATA_FIELD, IMAGE, or TABLE found.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from backend.data_types import (
|
||||||
|
DataField, datafield_to_uint8, image_to_uint8, encode_preview,
|
||||||
|
)
|
||||||
|
|
||||||
|
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||||
|
|
||||||
|
for slot, type_name in enumerate(return_types):
|
||||||
|
if slot >= len(result):
|
||||||
|
break
|
||||||
|
value = result[slot]
|
||||||
|
|
||||||
|
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||||
|
arr = datafield_to_uint8(value, "viridis")
|
||||||
|
on_preview(node_id, encode_preview(arr))
|
||||||
|
return # one preview per node is enough
|
||||||
|
|
||||||
|
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
|
||||||
|
arr = image_to_uint8(value)
|
||||||
|
on_preview(node_id, encode_preview(arr))
|
||||||
|
return
|
||||||
|
|
||||||
|
if type_name == "LINE" and isinstance(value, np.ndarray) and on_preview:
|
||||||
|
preview = self._render_line_preview(cls, slot, result)
|
||||||
|
if preview:
|
||||||
|
on_preview(node_id, preview)
|
||||||
|
return
|
||||||
|
|
||||||
|
if type_name == "TABLE" and isinstance(value, list) and on_table:
|
||||||
|
on_table(node_id, value)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _render_line_preview(
|
||||||
|
self,
|
||||||
|
cls: type,
|
||||||
|
slot: int,
|
||||||
|
result: tuple,
|
||||||
|
) -> str | None:
|
||||||
|
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
|
||||||
|
import numpy as np
|
||||||
|
import base64
|
||||||
|
import io as _io
|
||||||
|
|
||||||
|
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||||
|
|
||||||
|
# Find the y-values (current slot) and try to find an x-axis
|
||||||
|
y = result[slot]
|
||||||
|
x = None
|
||||||
|
# If the next output is also LINE, use it as x-axis
|
||||||
|
if slot + 1 < len(return_types) and return_types[slot + 1] == "LINE":
|
||||||
|
x = result[slot + 1]
|
||||||
|
# Or if slot > 0 and previous is LINE, this slot is the x-axis — skip
|
||||||
|
if slot > 0 and return_types[slot - 1] == "LINE":
|
||||||
|
return None # the first LINE already plotted both
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
|
||||||
|
fig.patch.set_facecolor("#1e293b")
|
||||||
|
ax.set_facecolor("#0f172a")
|
||||||
|
if x is not None:
|
||||||
|
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
||||||
|
else:
|
||||||
|
ax.plot(y, color="#ff9800", linewidth=1.2)
|
||||||
|
ax.tick_params(colors="#94a3b8", labelsize=7)
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_color("#334155")
|
||||||
|
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
|
||||||
|
fig.tight_layout(pad=0.4)
|
||||||
|
|
||||||
|
buf = _io.BytesIO()
|
||||||
|
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
||||||
|
plt.close(fig)
|
||||||
|
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||||
|
return f"data:image/png;base64,{b64}"
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def new_prompt_id() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
47
backend/main.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
Entry point for argonode.
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
python -m backend.main
|
||||||
|
or simply:
|
||||||
|
python backend/main.py
|
||||||
|
from the argonode/ directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Allow running as `python backend/main.py` from the project root
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from backend.server import create_app
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
|
||||||
|
)
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HOST = "127.0.0.1"
|
||||||
|
PORT = 8188
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
app = create_app(loop)
|
||||||
|
|
||||||
|
log.info("=" * 60)
|
||||||
|
log.info(" Argonode — Node-based image analysis")
|
||||||
|
log.info(" Open your browser at http://%s:%d", HOST, PORT)
|
||||||
|
log.info("=" * 60)
|
||||||
|
|
||||||
|
web.run_app(app, host=HOST, port=PORT, loop=loop, access_log=None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
56
backend/node_registry.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
Node registry for argonode.
|
||||||
|
|
||||||
|
Nodes are plain Python classes decorated with @register_node.
|
||||||
|
NODE_CLASS_MAPPINGS is the single source of truth consumed by
|
||||||
|
the execution engine and the /nodes REST endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS: dict[str, type] = {}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_node(display_name: str | None = None):
|
||||||
|
"""
|
||||||
|
Class decorator that registers a node class into NODE_CLASS_MAPPINGS.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@register_node(display_name="Gaussian Filter")
|
||||||
|
class GaussianFilter:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
def decorator(cls: type) -> type:
|
||||||
|
name = cls.__name__
|
||||||
|
NODE_CLASS_MAPPINGS[name] = cls
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS[name] = display_name or name
|
||||||
|
return cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_info(class_name: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Return a JSON-serialisable dict describing a node — consumed by GET /nodes.
|
||||||
|
Shape is compatible with what LiteGraph.js expects from the frontend.
|
||||||
|
"""
|
||||||
|
cls = NODE_CLASS_MAPPINGS[class_name]
|
||||||
|
input_types: dict = cls.INPUT_TYPES()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": class_name,
|
||||||
|
"display_name": NODE_DISPLAY_NAME_MAPPINGS.get(class_name, class_name),
|
||||||
|
"category": getattr(cls, "CATEGORY", "uncategorized"),
|
||||||
|
"input": input_types,
|
||||||
|
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
|
||||||
|
"output": list(cls.RETURN_TYPES),
|
||||||
|
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
|
||||||
|
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
||||||
|
"description": getattr(cls, "DESCRIPTION", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_node_info() -> dict[str, dict[str, Any]]:
|
||||||
|
"""Return info dicts for every registered node."""
|
||||||
|
return {name: get_node_info(name) for name in NODE_CLASS_MAPPINGS}
|
||||||
2
backend/nodes/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# Import all node modules to trigger @register_node decorators.
|
||||||
|
from . import io, filters, level, analysis, grains, display
|
||||||
471
backend/nodes/analysis.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
"""
|
||||||
|
Analysis nodes — statistics, histograms, FFT, cross sections.
|
||||||
|
|
||||||
|
Gwyddion equivalents:
|
||||||
|
StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h)
|
||||||
|
HeightHistogram → DH (height distribution), gwy_data_field_dh
|
||||||
|
FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf
|
||||||
|
CrossSection → gwy_data_field_get_profile (libprocess/datafield.c)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# StatisticsNode
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Statistics")
|
||||||
|
class StatisticsNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("TABLE",)
|
||||||
|
RETURN_NAMES = ("stats",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "analysis"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
|
||||||
|
"and skewness. Equivalent to gwy_data_field_get_min/max/avg/rms."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField) -> tuple:
|
||||||
|
d = field.data
|
||||||
|
mean = float(d.mean())
|
||||||
|
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
|
||||||
|
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
|
||||||
|
kurtosis = float(np.mean(((d - mean) / rms) ** 4)) if rms > 0 else 0.0
|
||||||
|
|
||||||
|
table = [
|
||||||
|
{"quantity": "min", "value": float(d.min()), "unit": field.si_unit_z},
|
||||||
|
{"quantity": "max", "value": float(d.max()), "unit": field.si_unit_z},
|
||||||
|
{"quantity": "mean", "value": mean, "unit": field.si_unit_z},
|
||||||
|
{"quantity": "RMS", "value": rms, "unit": field.si_unit_z},
|
||||||
|
{"quantity": "median", "value": float(np.median(d)), "unit": field.si_unit_z},
|
||||||
|
{"quantity": "skewness", "value": skewness, "unit": ""},
|
||||||
|
{"quantity": "kurtosis", "value": kurtosis, "unit": ""},
|
||||||
|
{"quantity": "range", "value": float(d.max() - d.min()), "unit": field.si_unit_z},
|
||||||
|
]
|
||||||
|
return (table,)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HeightHistogram
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Height Histogram")
|
||||||
|
class HeightHistogram:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"n_bins": ("INT", {"default": 256, "min": 10, "max": 1000, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LINE", "LINE")
|
||||||
|
RETURN_NAMES = ("counts", "bin_centers")
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "analysis"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Compute the height distribution histogram (DH). "
|
||||||
|
"Equivalent to gwy_data_field_dh."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, n_bins: int) -> tuple:
|
||||||
|
counts, bin_edges = np.histogram(field.data.ravel(), bins=int(n_bins))
|
||||||
|
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
|
||||||
|
return (counts.astype(np.float64), bin_centers)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FFT2D
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="2D FFT")
|
||||||
|
class FFT2D:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"windowing": (["hann", "hamming", "blackman", "none"],),
|
||||||
|
"level": (["mean", "plane", "none"],),
|
||||||
|
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("spectrum",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "analysis"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
|
||||||
|
"Output can be log magnitude, magnitude, phase, or PSDF. "
|
||||||
|
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
|
||||||
|
data = field.data.copy()
|
||||||
|
yres, xres = data.shape
|
||||||
|
|
||||||
|
# Level subtraction (Gwyddion-style, before windowing)
|
||||||
|
if level == "mean":
|
||||||
|
data -= data.mean()
|
||||||
|
elif level == "plane":
|
||||||
|
# Fit and subtract a plane: z = a + b*x + c*y
|
||||||
|
yy, xx = np.mgrid[0:yres, 0:xres]
|
||||||
|
xx_f = xx.ravel().astype(np.float64)
|
||||||
|
yy_f = yy.ravel().astype(np.float64)
|
||||||
|
zz_f = data.ravel()
|
||||||
|
A = np.column_stack([np.ones_like(xx_f), xx_f, yy_f])
|
||||||
|
coeffs, _, _, _ = np.linalg.lstsq(A, zz_f, rcond=None)
|
||||||
|
plane = (coeffs[0] + coeffs[1] * xx + coeffs[2] * yy)
|
||||||
|
data -= plane
|
||||||
|
|
||||||
|
# Windowing (Gwyddion uses (i+0.5)/n centred formulation)
|
||||||
|
if windowing != "none":
|
||||||
|
t_y = (np.arange(yres) + 0.5) / yres
|
||||||
|
t_x = (np.arange(xres) + 0.5) / xres
|
||||||
|
if windowing == "hann":
|
||||||
|
wy = 0.5 - 0.5 * np.cos(2 * np.pi * t_y)
|
||||||
|
wx = 0.5 - 0.5 * np.cos(2 * np.pi * t_x)
|
||||||
|
elif windowing == "hamming":
|
||||||
|
wy = 0.54 - 0.46 * np.cos(2 * np.pi * t_y)
|
||||||
|
wx = 0.54 - 0.46 * np.cos(2 * np.pi * t_x)
|
||||||
|
elif windowing == "blackman":
|
||||||
|
wy = 0.42 - 0.5 * np.cos(2 * np.pi * t_y) + 0.08 * np.cos(4 * np.pi * t_y)
|
||||||
|
wx = 0.42 - 0.5 * np.cos(2 * np.pi * t_x) + 0.08 * np.cos(4 * np.pi * t_x)
|
||||||
|
else:
|
||||||
|
wy = np.ones(yres)
|
||||||
|
wx = np.ones(xres)
|
||||||
|
data *= np.outer(wy, wx)
|
||||||
|
|
||||||
|
# 2D FFT, shifted so DC is at centre
|
||||||
|
F = np.fft.fftshift(np.fft.fft2(data))
|
||||||
|
n = xres * yres
|
||||||
|
|
||||||
|
if output == "log_magnitude":
|
||||||
|
mag = np.abs(F)
|
||||||
|
# Log scale with floor to avoid log(0)
|
||||||
|
result = np.log1p(mag)
|
||||||
|
elif output == "magnitude":
|
||||||
|
result = np.abs(F)
|
||||||
|
elif output == "phase":
|
||||||
|
result = np.angle(F)
|
||||||
|
elif output == "psdf":
|
||||||
|
# Gwyddion-equivalent PSDF: |F|^2 * dx * dy / (n * 4π²)
|
||||||
|
dx = field.xreal / xres
|
||||||
|
dy = field.yreal / yres
|
||||||
|
result = (np.abs(F) ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
|
||||||
|
else:
|
||||||
|
result = np.abs(F)
|
||||||
|
|
||||||
|
# Calibrate the output field in spatial-frequency units
|
||||||
|
if output == "psdf":
|
||||||
|
# Gwyddion uses angular frequency: 2π/dx, 2π/dy
|
||||||
|
freq_xreal = 2.0 * np.pi * xres / field.xreal
|
||||||
|
freq_yreal = 2.0 * np.pi * yres / field.yreal
|
||||||
|
z_unit = f"({field.si_unit_z})^2 m^2"
|
||||||
|
else:
|
||||||
|
freq_xreal = xres / field.xreal
|
||||||
|
freq_yreal = yres / field.yreal
|
||||||
|
z_unit = field.si_unit_z
|
||||||
|
|
||||||
|
out_field = DataField(
|
||||||
|
data=result,
|
||||||
|
xreal=freq_xreal,
|
||||||
|
yreal=freq_yreal,
|
||||||
|
si_unit_xy="1/m",
|
||||||
|
si_unit_z=z_unit,
|
||||||
|
domain="frequency",
|
||||||
|
)
|
||||||
|
return (out_field,)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CrossSection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _extend_to_edges(x1, y1, x2, y2):
|
||||||
|
"""
|
||||||
|
Extend the line through (x1,y1)-(x2,y2) to the boundaries of [0,1]x[0,1].
|
||||||
|
Returns the two intersection points (clipped to the unit square).
|
||||||
|
"""
|
||||||
|
dx = x2 - x1
|
||||||
|
dy = y2 - y1
|
||||||
|
|
||||||
|
# Collect parametric t values where line hits each boundary
|
||||||
|
t_candidates = []
|
||||||
|
if abs(dx) > 1e-12:
|
||||||
|
for bx in (0.0, 1.0):
|
||||||
|
t = (bx - x1) / dx
|
||||||
|
y_at_t = y1 + t * dy
|
||||||
|
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
|
||||||
|
t_candidates.append(t)
|
||||||
|
if abs(dy) > 1e-12:
|
||||||
|
for by in (0.0, 1.0):
|
||||||
|
t = (by - y1) / dy
|
||||||
|
x_at_t = x1 + t * dx
|
||||||
|
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
|
||||||
|
t_candidates.append(t)
|
||||||
|
|
||||||
|
if len(t_candidates) < 2:
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
t_min = min(t_candidates)
|
||||||
|
t_max = max(t_candidates)
|
||||||
|
|
||||||
|
return (
|
||||||
|
np.clip(x1 + t_min * dx, 0, 1),
|
||||||
|
np.clip(y1 + t_min * dy, 0, 1),
|
||||||
|
np.clip(x1 + t_max * dx, 0, 1),
|
||||||
|
np.clip(y1 + t_max * dy, 0, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Cross Section")
|
||||||
|
class CrossSection:
|
||||||
|
"""Extract a 1-D height profile along an arbitrary line across the image."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||||
|
"y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||||
|
"x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||||
|
"y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}),
|
||||||
|
"extend": (["none", "to_edges"],),
|
||||||
|
"n_samples": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"point_a": ("COORD",),
|
||||||
|
"point_b": ("COORD",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LINE",)
|
||||||
|
RETURN_NAMES = ("profile",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "analysis"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Extract a cross-section profile along a line between two points. "
|
||||||
|
"Drag the markers on the image to set the line endpoints. "
|
||||||
|
"Equivalent to gwy_data_field_get_profile."
|
||||||
|
)
|
||||||
|
|
||||||
|
_broadcast_overlay_fn = None
|
||||||
|
_current_node_id: str = ""
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self, field: DataField,
|
||||||
|
x1: float, y1: float, x2: float, y2: float,
|
||||||
|
extend: str, n_samples: int,
|
||||||
|
point_a=None, point_b=None,
|
||||||
|
) -> tuple:
|
||||||
|
from scipy.ndimage import map_coordinates
|
||||||
|
import io, base64
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
# COORD inputs override widget values
|
||||||
|
if point_a is not None:
|
||||||
|
x1, y1 = float(point_a[0]), float(point_a[1])
|
||||||
|
if point_b is not None:
|
||||||
|
x2, y2 = float(point_b[0]), float(point_b[1])
|
||||||
|
|
||||||
|
# Remember marker positions (before extend)
|
||||||
|
marker_x1, marker_y1 = float(x1), float(y1)
|
||||||
|
marker_x2, marker_y2 = float(x2), float(y2)
|
||||||
|
|
||||||
|
xres, yres = field.xres, field.yres
|
||||||
|
|
||||||
|
if extend == "to_edges":
|
||||||
|
x1, y1, x2, y2 = _extend_to_edges(
|
||||||
|
float(x1), float(y1), float(x2), float(y2),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert fractional [0,1] to pixel indices [0, res-1]
|
||||||
|
px1, py1 = float(x1) * (xres - 1), float(y1) * (yres - 1)
|
||||||
|
px2, py2 = float(x2) * (xres - 1), float(y2) * (yres - 1)
|
||||||
|
|
||||||
|
# Number of sample points
|
||||||
|
line_len_px = np.hypot(px2 - px1, py2 - py1)
|
||||||
|
if n_samples <= 0:
|
||||||
|
n_samples = max(2, int(np.ceil(line_len_px)))
|
||||||
|
|
||||||
|
# Sample coordinates along the line
|
||||||
|
t = np.linspace(0, 1, n_samples)
|
||||||
|
coords_y = py1 + t * (py2 - py1)
|
||||||
|
coords_x = px1 + t * (px2 - px1)
|
||||||
|
|
||||||
|
# Interpolate values along the line (cubic spline)
|
||||||
|
profile = map_coordinates(field.data, [coords_y, coords_x], order=3, mode="nearest")
|
||||||
|
|
||||||
|
# Broadcast overlay image with marker positions
|
||||||
|
if CrossSection._broadcast_overlay_fn is not None:
|
||||||
|
fig = Figure(figsize=(3, 3), dpi=100)
|
||||||
|
ax = fig.add_axes([0, 0, 1, 1])
|
||||||
|
ax.imshow(field.data, cmap="viridis", aspect="auto")
|
||||||
|
ax.axis("off")
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
||||||
|
buf.seek(0)
|
||||||
|
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
|
||||||
|
|
||||||
|
CrossSection._broadcast_overlay_fn(
|
||||||
|
CrossSection._current_node_id,
|
||||||
|
{
|
||||||
|
"image": image_uri,
|
||||||
|
"x1": marker_x1, "y1": marker_y1,
|
||||||
|
"x2": marker_x2, "y2": marker_y2,
|
||||||
|
"a_locked": point_a is not None,
|
||||||
|
"b_locked": point_b is not None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return (profile.astype(np.float64),)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LineMath — single scalar measurement from a LINE profile
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _safe_rq(d):
|
||||||
|
"""RMS of deviations from mean."""
|
||||||
|
return float(np.sqrt(np.mean(d * d)))
|
||||||
|
|
||||||
|
# Registry: name → (function(z) → float, unit_label)
|
||||||
|
# All functions receive the raw 1-D profile as float64.
|
||||||
|
LINE_OPS: dict[str, tuple] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _line_op(name, unit=""):
|
||||||
|
"""Decorator to register a LINE operation."""
|
||||||
|
def decorator(fn):
|
||||||
|
LINE_OPS[name] = (fn, unit)
|
||||||
|
return fn
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# ── Basic statistics ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@_line_op("min")
|
||||||
|
def _op_min(z):
|
||||||
|
return float(z.min())
|
||||||
|
|
||||||
|
@_line_op("max")
|
||||||
|
def _op_max(z):
|
||||||
|
return float(z.max())
|
||||||
|
|
||||||
|
@_line_op("mean")
|
||||||
|
def _op_mean(z):
|
||||||
|
return float(z.mean())
|
||||||
|
|
||||||
|
@_line_op("median")
|
||||||
|
def _op_median(z):
|
||||||
|
return float(np.median(z))
|
||||||
|
|
||||||
|
@_line_op("sum")
|
||||||
|
def _op_sum(z):
|
||||||
|
return float(z.sum())
|
||||||
|
|
||||||
|
@_line_op("range")
|
||||||
|
def _op_range(z):
|
||||||
|
return float(z.max() - z.min())
|
||||||
|
|
||||||
|
@_line_op("length", unit="pts")
|
||||||
|
def _op_length(z):
|
||||||
|
return float(len(z))
|
||||||
|
|
||||||
|
@_line_op("rms")
|
||||||
|
def _op_rms(z):
|
||||||
|
return float(np.sqrt(np.mean(z * z)))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Roughness parameters ──────────────────────────
|
||||||
|
|
||||||
|
@_line_op("Ra")
|
||||||
|
def _op_ra(z):
|
||||||
|
return float(np.mean(np.abs(z - z.mean())))
|
||||||
|
|
||||||
|
@_line_op("Rq")
|
||||||
|
def _op_rq(z):
|
||||||
|
d = z - z.mean()
|
||||||
|
return _safe_rq(d)
|
||||||
|
|
||||||
|
@_line_op("Rsk")
|
||||||
|
def _op_rsk(z):
|
||||||
|
d = z - z.mean()
|
||||||
|
rq = _safe_rq(d)
|
||||||
|
return float(np.mean(d**3) / rq**3) if rq > 0 else 0.0
|
||||||
|
|
||||||
|
@_line_op("Rku")
|
||||||
|
def _op_rku(z):
|
||||||
|
d = z - z.mean()
|
||||||
|
rq = _safe_rq(d)
|
||||||
|
return float(np.mean(d**4) / rq**4) if rq > 0 else 0.0
|
||||||
|
|
||||||
|
@_line_op("Rp")
|
||||||
|
def _op_rp(z):
|
||||||
|
return float((z - z.mean()).max())
|
||||||
|
|
||||||
|
@_line_op("Rv")
|
||||||
|
def _op_rv(z):
|
||||||
|
return float(-(z - z.mean()).min())
|
||||||
|
|
||||||
|
@_line_op("Rt")
|
||||||
|
def _op_rt(z):
|
||||||
|
d = z - z.mean()
|
||||||
|
return float(d.max() - d.min())
|
||||||
|
|
||||||
|
@_line_op("Dq")
|
||||||
|
def _op_dq(z):
|
||||||
|
"""RMS slope (first derivative RMS)."""
|
||||||
|
dz = np.diff(z)
|
||||||
|
return float(np.sqrt(np.mean(dz * dz)))
|
||||||
|
|
||||||
|
@_line_op("Da")
|
||||||
|
def _op_da(z):
|
||||||
|
"""Mean absolute slope."""
|
||||||
|
return float(np.mean(np.abs(np.diff(z))))
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Line Math")
|
||||||
|
class LineMath:
|
||||||
|
"""Compute a single scalar value from a LINE profile."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"line": ("LINE",),
|
||||||
|
"operation": (list(LINE_OPS.keys()),),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("TABLE",)
|
||||||
|
RETURN_NAMES = ("result",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "analysis"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Compute a single scalar measurement from a LINE profile. "
|
||||||
|
"Includes basic stats and Gwyddion-convention roughness parameters."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, line, operation: str) -> tuple:
|
||||||
|
z = np.asarray(line, dtype=np.float64).ravel()
|
||||||
|
fn, unit = LINE_OPS[operation]
|
||||||
|
value = fn(z)
|
||||||
|
table = [{"quantity": operation, "value": value, "unit": unit}]
|
||||||
|
return (table,)
|
||||||
165
backend/nodes/display.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
Display / output nodes.
|
||||||
|
|
||||||
|
Preview accepts both DATA_FIELD and IMAGE via optional inputs —
|
||||||
|
connect whichever type you have. The server injects _broadcast_fn
|
||||||
|
before execution begins.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Preview")
|
||||||
|
class PreviewImage:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "preview"
|
||||||
|
CATEGORY = "display"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
DESCRIPTION = "Display an IMAGE or DATA_FIELD as a coloured thumbnail. Connect either input."
|
||||||
|
|
||||||
|
_broadcast_fn = None
|
||||||
|
_current_node_id: str = ""
|
||||||
|
|
||||||
|
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
|
||||||
|
# Prefer field if both are connected; accept whichever is provided
|
||||||
|
if field is not None:
|
||||||
|
arr_u8 = datafield_to_uint8(field, colormap)
|
||||||
|
elif image is not None:
|
||||||
|
if image.dtype != np.uint8:
|
||||||
|
imin, imax = image.min(), image.max()
|
||||||
|
if imax > imin:
|
||||||
|
norm = (image - imin) / (imax - imin)
|
||||||
|
else:
|
||||||
|
norm = np.zeros_like(image)
|
||||||
|
arr_u8 = (norm * 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
|
arr_u8 = image
|
||||||
|
|
||||||
|
if arr_u8.ndim == 2 and colormap != "gray":
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
cmap = cm.get_cmap(colormap)
|
||||||
|
rgba = cmap(arr_u8.astype(np.float32) / 255.0)
|
||||||
|
arr_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
|
raise ValueError("Connect either an IMAGE or DATA_FIELD input to Preview.")
|
||||||
|
|
||||||
|
data_uri = encode_preview(arr_u8)
|
||||||
|
|
||||||
|
if PreviewImage._broadcast_fn is not None:
|
||||||
|
PreviewImage._broadcast_fn(PreviewImage._current_node_id, data_uri)
|
||||||
|
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="3D View")
|
||||||
|
class View3D:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
|
||||||
|
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
|
||||||
|
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "render"
|
||||||
|
CATEGORY = "display"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Interactive 3D surface view of a DATA_FIELD. "
|
||||||
|
"Drag to rotate, scroll to zoom. z_scale exaggerates height."
|
||||||
|
)
|
||||||
|
|
||||||
|
_broadcast_mesh_fn = None
|
||||||
|
_current_node_id: str = ""
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self, field: DataField,
|
||||||
|
colormap: str, z_scale: float, resolution: int,
|
||||||
|
) -> tuple:
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
import base64
|
||||||
|
|
||||||
|
data = field.data
|
||||||
|
yres, xres = data.shape
|
||||||
|
|
||||||
|
# Downsample if larger than resolution
|
||||||
|
step_y = max(1, yres // resolution)
|
||||||
|
step_x = max(1, xres // resolution)
|
||||||
|
z = data[::step_y, ::step_x].astype(np.float32)
|
||||||
|
ny, nx = z.shape
|
||||||
|
|
||||||
|
# Normalize for colormap
|
||||||
|
zmin, zmax = float(z.min()), float(z.max())
|
||||||
|
if zmax > zmin:
|
||||||
|
z_norm = (z - zmin) / (zmax - zmin)
|
||||||
|
else:
|
||||||
|
z_norm = np.zeros_like(z)
|
||||||
|
|
||||||
|
cmap = cm.get_cmap(colormap)
|
||||||
|
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
|
||||||
|
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Base64-encode arrays for efficient WS transport
|
||||||
|
z_b64 = base64.b64encode(z.tobytes()).decode()
|
||||||
|
colors_b64 = base64.b64encode(colors_u8.tobytes()).decode()
|
||||||
|
|
||||||
|
mesh_data = {
|
||||||
|
"width": nx,
|
||||||
|
"height": ny,
|
||||||
|
"z_data": z_b64,
|
||||||
|
"colors": colors_b64,
|
||||||
|
"z_min": zmin,
|
||||||
|
"z_max": zmax,
|
||||||
|
"z_scale": float(z_scale),
|
||||||
|
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
|
||||||
|
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
|
||||||
|
}
|
||||||
|
|
||||||
|
if View3D._broadcast_mesh_fn is not None:
|
||||||
|
View3D._broadcast_mesh_fn(View3D._current_node_id, mesh_data)
|
||||||
|
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Print Table")
|
||||||
|
class PrintTable:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"table": ("TABLE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "print_table"
|
||||||
|
CATEGORY = "display"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
DESCRIPTION = "Send a TABLE to the browser as a WebSocket message for display."
|
||||||
|
|
||||||
|
_broadcast_table_fn = None
|
||||||
|
_current_node_id: str = ""
|
||||||
|
|
||||||
|
def print_table(self, table: list) -> tuple:
|
||||||
|
if PrintTable._broadcast_table_fn is not None:
|
||||||
|
PrintTable._broadcast_table_fn(PrintTable._current_node_id, table)
|
||||||
|
return ()
|
||||||
115
backend/nodes/filters.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""
|
||||||
|
Filter nodes — Gwyddion-equivalent image filters.
|
||||||
|
|
||||||
|
Gwyddion equivalents:
|
||||||
|
GaussianFilter → gwy_data_field_filter_gaussian
|
||||||
|
MedianFilter → gwy_data_field_filter_median
|
||||||
|
EdgeDetect → gwy_data_field_filter_sobel / laplacian / log
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GaussianFilter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Gaussian Filter")
|
||||||
|
class GaussianFilter:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"sigma": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 50.0, "step": 0.1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("filtered",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "filters"
|
||||||
|
DESCRIPTION = "Apply a Gaussian blur. Equivalent to gwy_data_field_filter_gaussian."
|
||||||
|
|
||||||
|
def process(self, field: DataField, sigma: float) -> tuple:
|
||||||
|
from scipy.ndimage import gaussian_filter
|
||||||
|
data = gaussian_filter(field.data.copy(), sigma=float(sigma))
|
||||||
|
return (field.replace(data=data),)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MedianFilter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Median Filter")
|
||||||
|
class MedianFilter:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"size": ("INT", {"default": 3, "min": 1, "max": 21, "step": 2}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("filtered",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "filters"
|
||||||
|
DESCRIPTION = "Apply a median filter. Equivalent to gwy_data_field_filter_median."
|
||||||
|
|
||||||
|
def process(self, field: DataField, size: int) -> tuple:
|
||||||
|
from scipy.ndimage import median_filter
|
||||||
|
size = max(1, int(size))
|
||||||
|
data = median_filter(field.data.copy(), size=size)
|
||||||
|
return (field.replace(data=data),)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EdgeDetect
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Edge Detect")
|
||||||
|
class EdgeDetect:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"method": (["sobel", "prewitt", "laplacian", "log"],),
|
||||||
|
"sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("edges",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "filters"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Detect edges using Sobel, Prewitt, Laplacian, or LoG operators. "
|
||||||
|
"Equivalent to gwy_data_field_filter_sobel / gwy_data_field_filter_laplacian."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, method: str, sigma: float) -> tuple:
|
||||||
|
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
|
||||||
|
data = field.data.copy()
|
||||||
|
|
||||||
|
if method == "sobel":
|
||||||
|
sx = sobel(data, axis=1)
|
||||||
|
sy = sobel(data, axis=0)
|
||||||
|
result = np.hypot(sx, sy)
|
||||||
|
elif method == "prewitt":
|
||||||
|
px = prewitt(data, axis=1)
|
||||||
|
py = prewitt(data, axis=0)
|
||||||
|
result = np.hypot(px, py)
|
||||||
|
elif method == "laplacian":
|
||||||
|
result = laplace(data)
|
||||||
|
elif method == "log":
|
||||||
|
result = gaussian_laplace(data, sigma=float(sigma))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown edge detection method: {method}")
|
||||||
|
|
||||||
|
return (field.replace(data=result),)
|
||||||
127
backend/nodes/grains.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Grain/feature detection nodes.
|
||||||
|
|
||||||
|
Gwyddion equivalents:
|
||||||
|
ThresholdMask → threshold.c / otsu_threshold.c
|
||||||
|
GrainAnalysis → gwy_data_field_grains_get_values (grains-values.c)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ThresholdMask
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Threshold Mask")
|
||||||
|
class ThresholdMask:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"method": (["otsu", "absolute", "relative"],),
|
||||||
|
"threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
|
||||||
|
"direction": (["above", "below"],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
RETURN_NAMES = ("mask",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "grains"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Create a binary mask by thresholding data. "
|
||||||
|
"Otsu automatically finds the optimal threshold. "
|
||||||
|
"Equivalent to Gwyddion's threshold and otsu_threshold modules."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
|
||||||
|
data = field.data
|
||||||
|
|
||||||
|
if method == "otsu":
|
||||||
|
from skimage.filters import threshold_otsu
|
||||||
|
t = threshold_otsu(data)
|
||||||
|
elif method == "absolute":
|
||||||
|
t = float(threshold)
|
||||||
|
elif method == "relative":
|
||||||
|
# threshold is a fraction [0, 1] of the data range
|
||||||
|
dmin, dmax = data.min(), data.max()
|
||||||
|
t = dmin + float(threshold) * (dmax - dmin)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown threshold method: {method}")
|
||||||
|
|
||||||
|
if direction == "above":
|
||||||
|
mask = (data >= t).astype(np.uint8) * 255
|
||||||
|
else:
|
||||||
|
mask = (data < t).astype(np.uint8) * 255
|
||||||
|
|
||||||
|
return (mask,)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GrainAnalysis
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Grain Analysis")
|
||||||
|
class GrainAnalysis:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"mask": ("IMAGE",),
|
||||||
|
"min_size": ("INT", {"default": 10, "min": 1, "max": 100000, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("TABLE",)
|
||||||
|
RETURN_NAMES = ("grain_stats",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "grains"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Label connected grain regions in a binary mask and compute per-grain statistics: "
|
||||||
|
"area, equivalent diameter, mean/max height, bounding box. "
|
||||||
|
"Equivalent to gwy_data_field_grains_get_values."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:
|
||||||
|
from scipy.ndimage import label, find_objects
|
||||||
|
|
||||||
|
binary = (mask > 127).astype(np.int32)
|
||||||
|
labeled, n_grains = label(binary)
|
||||||
|
|
||||||
|
pixel_area = field.dx * field.dy # m^2 per pixel
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for grain_id in range(1, n_grains + 1):
|
||||||
|
grain_pixels = labeled == grain_id
|
||||||
|
area_px = int(grain_pixels.sum())
|
||||||
|
if area_px < min_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
area_m2 = area_px * pixel_area
|
||||||
|
equiv_diam = float(2.0 * np.sqrt(area_m2 / np.pi))
|
||||||
|
|
||||||
|
heights = field.data[grain_pixels]
|
||||||
|
mean_h = float(heights.mean())
|
||||||
|
max_h = float(heights.max())
|
||||||
|
|
||||||
|
# Bounding box
|
||||||
|
ys, xs = np.where(grain_pixels)
|
||||||
|
bbox = f"({int(xs.min())},{int(ys.min())})-({int(xs.max())},{int(ys.max())})"
|
||||||
|
|
||||||
|
rows.append({
|
||||||
|
"grain_id": grain_id,
|
||||||
|
"area_px": area_px,
|
||||||
|
"area_m2": area_m2,
|
||||||
|
"equiv_diam_m": equiv_diam,
|
||||||
|
"mean_height": mean_h,
|
||||||
|
"max_height": max_h,
|
||||||
|
"bbox": bbox,
|
||||||
|
})
|
||||||
|
|
||||||
|
return (rows,)
|
||||||
277
backend/nodes/io.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""
|
||||||
|
I/O nodes: load and save images and SPM data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField, encode_preview, image_to_uint8
|
||||||
|
|
||||||
|
# Resolved at server startup so nodes know where to look
|
||||||
|
INPUT_DIR = Path(__file__).parent.parent.parent / "input"
|
||||||
|
OUTPUT_DIR = Path(__file__).parent.parent.parent / "output"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoadImage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Load Image")
|
||||||
|
class LoadImage:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"filename": ("FILE_PICKER", {"default": ""}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
|
||||||
|
RETURN_NAMES = ("image", "field")
|
||||||
|
FUNCTION = "load"
|
||||||
|
CATEGORY = "io"
|
||||||
|
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
|
||||||
|
|
||||||
|
def load(self, filename: str):
|
||||||
|
# Accept absolute paths or filenames relative to input/
|
||||||
|
path = Path(filename)
|
||||||
|
if not path.is_absolute():
|
||||||
|
path = INPUT_DIR / filename
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {path}")
|
||||||
|
|
||||||
|
ext = path.suffix.lower()
|
||||||
|
if ext in (".npy",):
|
||||||
|
arr = np.load(str(path)).astype(np.float64)
|
||||||
|
elif ext in (".npz",):
|
||||||
|
npz = np.load(str(path))
|
||||||
|
key = list(npz.files)[0]
|
||||||
|
arr = npz[key].astype(np.float64)
|
||||||
|
else:
|
||||||
|
from PIL import Image
|
||||||
|
img = Image.open(str(path))
|
||||||
|
arr = np.array(img)
|
||||||
|
if arr.dtype != np.uint8:
|
||||||
|
arr = arr.astype(np.float64)
|
||||||
|
|
||||||
|
# Convert to float64 grayscale for the DATA_FIELD output
|
||||||
|
if arr.ndim == 3:
|
||||||
|
gray = np.mean(arr.astype(np.float64), axis=2)
|
||||||
|
else:
|
||||||
|
gray = arr.astype(np.float64)
|
||||||
|
|
||||||
|
field = DataField(data=gray)
|
||||||
|
return (arr, field)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoadSPM
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Load SPM File")
|
||||||
|
class LoadSPM:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"filename": ("FILE_PICKER", {"default": ""}),
|
||||||
|
"channel": ("STRING", {"default": "Z"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("field",)
|
||||||
|
FUNCTION = "load"
|
||||||
|
CATEGORY = "io"
|
||||||
|
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
|
||||||
|
|
||||||
|
def load(self, filename: str, channel: str = "Z"):
|
||||||
|
path = Path(filename)
|
||||||
|
if not path.is_absolute():
|
||||||
|
path = INPUT_DIR / filename
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {path}")
|
||||||
|
|
||||||
|
ext = path.suffix.lower()
|
||||||
|
|
||||||
|
if ext == ".gwy":
|
||||||
|
return (self._load_gwy(path, channel),)
|
||||||
|
elif ext == ".sxm":
|
||||||
|
return (self._load_sxm(path, channel),)
|
||||||
|
elif ext in (".ibw",):
|
||||||
|
return (self._load_ibw(path),)
|
||||||
|
elif ext in (".npy",):
|
||||||
|
data = np.load(str(path)).astype(np.float64)
|
||||||
|
return (DataField(data=data),)
|
||||||
|
elif ext in (".npz",):
|
||||||
|
npz = np.load(str(path))
|
||||||
|
key = list(npz.files)[0]
|
||||||
|
return (DataField(data=npz[key].astype(np.float64)),)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
|
||||||
|
|
||||||
|
def _load_gwy(self, path: Path, channel: str) -> DataField:
|
||||||
|
try:
|
||||||
|
import gwyfile
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
|
||||||
|
|
||||||
|
obj = gwyfile.load(str(path))
|
||||||
|
channels = gwyfile.util.get_datafields(obj)
|
||||||
|
if not channels:
|
||||||
|
raise ValueError(f"No data channels found in {path.name}")
|
||||||
|
|
||||||
|
# Try requested channel name, fall back to first available
|
||||||
|
ch = None
|
||||||
|
for key, df in channels.items():
|
||||||
|
if channel.lower() in key.lower():
|
||||||
|
ch = df
|
||||||
|
break
|
||||||
|
if ch is None:
|
||||||
|
ch = next(iter(channels.values()))
|
||||||
|
|
||||||
|
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
|
||||||
|
return DataField(
|
||||||
|
data=data,
|
||||||
|
xreal=float(ch.xreal),
|
||||||
|
yreal=float(ch.yreal),
|
||||||
|
xoff=float(getattr(ch, "xoff", 0.0)),
|
||||||
|
yoff=float(getattr(ch, "yoff", 0.0)),
|
||||||
|
si_unit_xy="m",
|
||||||
|
si_unit_z="m",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_sxm(self, path: Path, channel: str) -> DataField:
|
||||||
|
try:
|
||||||
|
import nanonispy as nap
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
|
||||||
|
|
||||||
|
sxm = nap.read.Scan(str(path))
|
||||||
|
signals = sxm.signals
|
||||||
|
|
||||||
|
# Pick channel
|
||||||
|
ch_key = None
|
||||||
|
for key in signals:
|
||||||
|
if channel.upper() in key.upper():
|
||||||
|
ch_key = key
|
||||||
|
break
|
||||||
|
if ch_key is None:
|
||||||
|
ch_key = next(iter(signals))
|
||||||
|
|
||||||
|
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
|
||||||
|
data = np.asarray(data, dtype=np.float64)
|
||||||
|
if data.ndim != 2:
|
||||||
|
data = data.reshape(data.shape[-2], data.shape[-1])
|
||||||
|
|
||||||
|
header = sxm.header
|
||||||
|
scan_range = header.get("scan_range", [1e-6, 1e-6])
|
||||||
|
return DataField(
|
||||||
|
data=data,
|
||||||
|
xreal=float(scan_range[0]),
|
||||||
|
yreal=float(scan_range[1]),
|
||||||
|
si_unit_xy="m",
|
||||||
|
si_unit_z="m",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_ibw(self, path: Path) -> DataField:
|
||||||
|
try:
|
||||||
|
import igor.igorpy as igorpy
|
||||||
|
wave = igorpy.load(str(path))
|
||||||
|
data = wave.wave["wData"].squeeze().astype(np.float64)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
|
||||||
|
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
elif data.ndim != 2:
|
||||||
|
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
|
||||||
|
|
||||||
|
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Coordinate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Coordinate")
|
||||||
|
class Coordinate:
|
||||||
|
"""Provide a fractional (x, y) point for use with Cross Section or other nodes."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("COORD",)
|
||||||
|
RETURN_NAMES = ("point",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "io"
|
||||||
|
DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."
|
||||||
|
|
||||||
|
def process(self, x: float, y: float) -> tuple:
|
||||||
|
return ((float(x), float(y)),)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SaveImage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Save Image")
|
||||||
|
class SaveImage:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"filename_prefix": ("STRING", {"default": "output"}),
|
||||||
|
"format": (["PNG", "TIFF", "NPY"],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
CATEGORY = "io"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
DESCRIPTION = "Save an image or array to the output folder."
|
||||||
|
|
||||||
|
# Injected by server.py before execution begins
|
||||||
|
_broadcast_preview = None
|
||||||
|
|
||||||
|
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
|
||||||
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Find next available filename
|
||||||
|
idx = 1
|
||||||
|
while True:
|
||||||
|
name = f"{filename_prefix}_{idx:04d}"
|
||||||
|
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
|
||||||
|
if not candidate.exists():
|
||||||
|
break
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
if format == "NPY":
|
||||||
|
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
|
||||||
|
else:
|
||||||
|
from PIL import Image
|
||||||
|
arr = image_to_uint8(image)
|
||||||
|
if arr.ndim == 2:
|
||||||
|
pil_img = Image.fromarray(arr, mode="L")
|
||||||
|
else:
|
||||||
|
pil_img = Image.fromarray(arr, mode="RGB")
|
||||||
|
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
|
||||||
|
|
||||||
|
# Emit preview over WebSocket if callback is set
|
||||||
|
if SaveImage._broadcast_preview is not None:
|
||||||
|
arr_u8 = image_to_uint8(image)
|
||||||
|
data_uri = encode_preview(arr_u8)
|
||||||
|
SaveImage._broadcast_preview(data_uri)
|
||||||
|
|
||||||
|
return ()
|
||||||
150
backend/nodes/level.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""
|
||||||
|
Leveling nodes — background removal and zero correction.
|
||||||
|
|
||||||
|
Gwyddion equivalents:
|
||||||
|
PlaneLevelField → gwy_data_field_fit_plane + gwy_data_field_plane_level
|
||||||
|
PolyLevelField → gwy_data_field_fit_polynom (via level.c polylevel module)
|
||||||
|
FixZero → fix_zero in level.c
|
||||||
|
|
||||||
|
Plane-fit algorithm follows Gwyddion's level.h definition:
|
||||||
|
z_fit = pa + pbx * x + pby * y (least-squares over all pixels)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PlaneLevelField
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Plane Level")
|
||||||
|
class PlaneLevelField:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("leveled",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "level"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Fit and subtract a least-squares plane from the data. "
|
||||||
|
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField) -> tuple:
|
||||||
|
data = field.data.copy()
|
||||||
|
yres, xres = data.shape
|
||||||
|
|
||||||
|
# Normalised coordinate grids in [0, 1]
|
||||||
|
x = np.linspace(0.0, 1.0, xres)
|
||||||
|
y = np.linspace(0.0, 1.0, yres)
|
||||||
|
xx, yy = np.meshgrid(x, y)
|
||||||
|
|
||||||
|
# Design matrix: [1, x, y] shape (N, 3)
|
||||||
|
A = np.column_stack([
|
||||||
|
np.ones(xres * yres),
|
||||||
|
xx.ravel(),
|
||||||
|
yy.ravel(),
|
||||||
|
])
|
||||||
|
z = data.ravel()
|
||||||
|
|
||||||
|
# Least-squares: solve A @ [pa, pbx, pby] = z
|
||||||
|
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||||
|
pa, pbx, pby = coeffs
|
||||||
|
|
||||||
|
plane = (pa + pbx * xx + pby * yy)
|
||||||
|
return (field.replace(data=data - plane),)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PolyLevelField
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Polynomial Level")
|
||||||
|
class PolyLevelField:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"degree_x": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
|
||||||
|
"degree_y": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD")
|
||||||
|
RETURN_NAMES = ("leveled", "background")
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "level"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Fit and subtract a polynomial background of given degree in x and y. "
|
||||||
|
"Equivalent to gwy_data_field_fit_polynom."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, degree_x: int, degree_y: int) -> tuple:
|
||||||
|
data = field.data.copy()
|
||||||
|
yres, xres = data.shape
|
||||||
|
|
||||||
|
x = np.linspace(0.0, 1.0, xres)
|
||||||
|
y = np.linspace(0.0, 1.0, yres)
|
||||||
|
xx, yy = np.meshgrid(x, y)
|
||||||
|
|
||||||
|
# Build Vandermonde-style design matrix with all monomials x^i * y^j
|
||||||
|
cols = []
|
||||||
|
for i in range(degree_x + 1):
|
||||||
|
for j in range(degree_y + 1):
|
||||||
|
cols.append((xx ** i * yy ** j).ravel())
|
||||||
|
A = np.column_stack(cols)
|
||||||
|
z = data.ravel()
|
||||||
|
|
||||||
|
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||||
|
|
||||||
|
background = (A @ coeffs).reshape(yres, xres)
|
||||||
|
leveled = data - background
|
||||||
|
|
||||||
|
return (field.replace(data=leveled), field.replace(data=background))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FixZero
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@register_node(display_name="Fix Zero")
|
||||||
|
class FixZero:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"method": (["min", "mean", "median"],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("DATA_FIELD",)
|
||||||
|
RETURN_NAMES = ("zeroed",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "level"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Shift data so that the minimum (or mean/median) is zero. "
|
||||||
|
"Equivalent to fix_zero in Gwyddion's level.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, method: str) -> tuple:
|
||||||
|
data = field.data.copy()
|
||||||
|
if method == "min":
|
||||||
|
data -= data.min()
|
||||||
|
elif method == "mean":
|
||||||
|
data -= data.mean()
|
||||||
|
elif method == "median":
|
||||||
|
data -= np.median(data)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {method}")
|
||||||
|
return (field.replace(data=data),)
|
||||||
267
backend/server.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
aiohttp web server for argonode.
|
||||||
|
|
||||||
|
Routes
|
||||||
|
------
|
||||||
|
GET / → serve frontend/index.html
|
||||||
|
GET /static/{path} → serve frontend JS/CSS
|
||||||
|
GET /nodes → JSON dict of all registered node definitions
|
||||||
|
POST /upload → multipart file upload to input/
|
||||||
|
POST /prompt → submit a workflow; returns {prompt_id}
|
||||||
|
GET /ws → WebSocket upgrade
|
||||||
|
|
||||||
|
WebSocket message types sent to clients
|
||||||
|
----------------------------------------
|
||||||
|
{"type": "execution_start", "data": {"prompt_id": "..."}}
|
||||||
|
{"type": "executing", "data": {"node": "...", "prompt_id": "..."}}
|
||||||
|
{"type": "preview", "data": {"node_id": "...", "image": "data:..."}}
|
||||||
|
{"type": "table", "data": {"node_id": "...", "rows": [...]}}
|
||||||
|
{"type": "execution_error", "data": {"node_id": "...", "message": "..."}}
|
||||||
|
{"type": "execution_complete", "data": {"prompt_id": "..."}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from aiohttp import web, WSMsgType
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FRONTEND_DIR = Path(__file__).parent.parent / "frontend"
|
||||||
|
DIST_DIR = FRONTEND_DIR / "dist"
|
||||||
|
INPUT_DIR = Path(__file__).parent.parent / "input"
|
||||||
|
OUTPUT_DIR = Path(__file__).parent.parent / "output"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JSON helper — numpy scalars are not serialisable by default
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class _SafeEncoder(json.JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
import numpy as np
|
||||||
|
if isinstance(obj, (np.integer,)):
|
||||||
|
return int(obj)
|
||||||
|
if isinstance(obj, (np.floating,)):
|
||||||
|
return float(obj)
|
||||||
|
if isinstance(obj, np.ndarray):
|
||||||
|
return obj.tolist()
|
||||||
|
return super().default(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _dumps(obj) -> str:
|
||||||
|
return json.dumps(obj, cls=_SafeEncoder)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Application factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||||
|
# Import nodes to trigger registration decorators
|
||||||
|
import backend.nodes # noqa: F401
|
||||||
|
from backend.node_registry import get_all_node_info
|
||||||
|
from backend.execution import ExecutionEngine, new_prompt_id
|
||||||
|
|
||||||
|
INPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
engine = ExecutionEngine()
|
||||||
|
websockets: set[web.WebSocketResponse] = set()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# WebSocket broadcast helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def broadcast(msg: dict) -> None:
|
||||||
|
"""Schedule a broadcast to all connected WebSocket clients."""
|
||||||
|
payload = _dumps(msg)
|
||||||
|
for ws in list(websockets):
|
||||||
|
if not ws.closed:
|
||||||
|
asyncio.run_coroutine_threadsafe(ws.send_str(payload), loop)
|
||||||
|
|
||||||
|
def on_preview(node_id: str, data_uri: str) -> None:
|
||||||
|
broadcast({"type": "preview", "data": {"node_id": node_id, "image": data_uri}})
|
||||||
|
|
||||||
|
def on_table(node_id: str, rows: list) -> None:
|
||||||
|
broadcast({"type": "table", "data": {"node_id": node_id, "rows": rows}})
|
||||||
|
|
||||||
|
def on_mesh(node_id: str, mesh_data: dict) -> None:
|
||||||
|
broadcast({"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}})
|
||||||
|
|
||||||
|
def on_overlay(node_id: str, overlay_data) -> None:
|
||||||
|
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Route handlers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def index(request: web.Request) -> web.Response:
|
||||||
|
# Serve Vite build output if available, else raw frontend
|
||||||
|
if (DIST_DIR / "index.html").exists():
|
||||||
|
return web.FileResponse(DIST_DIR / "index.html")
|
||||||
|
return web.FileResponse(FRONTEND_DIR / "index.html")
|
||||||
|
|
||||||
|
async def get_nodes(request: web.Request) -> web.Response:
|
||||||
|
info = get_all_node_info()
|
||||||
|
return web.Response(
|
||||||
|
text=_dumps(info),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_files(request: web.Request) -> web.Response:
|
||||||
|
"""List files in the input/ directory for the file picker widget."""
|
||||||
|
files = sorted(
|
||||||
|
f.name for f in INPUT_DIR.iterdir()
|
||||||
|
if f.is_file() and not f.name.startswith(".")
|
||||||
|
) if INPUT_DIR.exists() else []
|
||||||
|
return web.Response(text=_dumps(files), content_type="application/json")
|
||||||
|
|
||||||
|
async def browse_dir(request: web.Request) -> web.Response:
|
||||||
|
"""
|
||||||
|
Server-side directory browser for local file picking.
|
||||||
|
GET /browse?dir=/some/path → {parent, dirs[], files[]}
|
||||||
|
"""
|
||||||
|
dir_path = request.query.get("dir", str(Path.home()))
|
||||||
|
p = Path(dir_path).expanduser().resolve()
|
||||||
|
|
||||||
|
if not p.is_dir():
|
||||||
|
raise web.HTTPBadRequest(reason=f"Not a directory: {p}")
|
||||||
|
|
||||||
|
dirs = []
|
||||||
|
files = []
|
||||||
|
try:
|
||||||
|
for entry in sorted(p.iterdir(), key=lambda e: e.name.lower()):
|
||||||
|
if entry.name.startswith("."):
|
||||||
|
continue
|
||||||
|
if entry.is_dir():
|
||||||
|
dirs.append(entry.name)
|
||||||
|
elif entry.is_file():
|
||||||
|
files.append(entry.name)
|
||||||
|
except PermissionError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
text=_dumps({
|
||||||
|
"path": str(p),
|
||||||
|
"parent": str(p.parent) if p.parent != p else None,
|
||||||
|
"dirs": dirs,
|
||||||
|
"files": files,
|
||||||
|
}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def upload_file(request: web.Request) -> web.Response:
|
||||||
|
reader = await request.multipart()
|
||||||
|
field = await reader.next()
|
||||||
|
if field is None or field.name != "file":
|
||||||
|
raise web.HTTPBadRequest(reason="Expected a 'file' field in multipart body")
|
||||||
|
|
||||||
|
filename = Path(field.filename).name # strip any path traversal
|
||||||
|
dest = INPUT_DIR / filename
|
||||||
|
with open(dest, "wb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk(65536)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
return web.Response(text=_dumps({"filename": filename}), content_type="application/json")
|
||||||
|
|
||||||
|
async def submit_prompt(request: web.Request) -> web.Response:
|
||||||
|
body = await request.json()
|
||||||
|
prompt = body.get("prompt")
|
||||||
|
if not isinstance(prompt, dict) or not prompt:
|
||||||
|
raise web.HTTPBadRequest(reason="'prompt' must be a non-empty dict")
|
||||||
|
|
||||||
|
prompt_id = new_prompt_id()
|
||||||
|
|
||||||
|
# Run execution in a thread pool so scipy doesn't block the event loop
|
||||||
|
async def run():
|
||||||
|
broadcast({"type": "execution_start", "data": {"prompt_id": prompt_id}})
|
||||||
|
|
||||||
|
def on_start(node_id: str) -> None:
|
||||||
|
broadcast({"type": "executing", "data": {"node": node_id, "prompt_id": prompt_id}})
|
||||||
|
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: engine.execute(
|
||||||
|
prompt,
|
||||||
|
on_node_start=on_start,
|
||||||
|
on_preview=on_preview,
|
||||||
|
on_table=on_table,
|
||||||
|
on_mesh=on_mesh,
|
||||||
|
on_overlay=on_overlay,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
|
||||||
|
except Exception as exc:
|
||||||
|
log.exception("Execution error")
|
||||||
|
broadcast({
|
||||||
|
"type": "execution_error",
|
||||||
|
"data": {"node_id": "", "message": str(exc)},
|
||||||
|
})
|
||||||
|
|
||||||
|
asyncio.ensure_future(run())
|
||||||
|
return web.Response(
|
||||||
|
text=_dumps({"prompt_id": prompt_id}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
websockets.add(ws)
|
||||||
|
log.info("WebSocket client connected (%d total)", len(websockets))
|
||||||
|
try:
|
||||||
|
async for msg in ws:
|
||||||
|
if msg.type == WSMsgType.TEXT:
|
||||||
|
pass # clients don't need to send anything currently
|
||||||
|
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
websockets.discard(ws)
|
||||||
|
log.info("WebSocket client disconnected (%d total)", len(websockets))
|
||||||
|
return ws
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# App assembly
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
|
||||||
|
app.router.add_get("/", index)
|
||||||
|
app.router.add_get("/nodes", get_nodes)
|
||||||
|
app.router.add_get("/files", list_files)
|
||||||
|
app.router.add_get("/browse", browse_dir)
|
||||||
|
app.router.add_post("/upload", upload_file)
|
||||||
|
app.router.add_post("/prompt", submit_prompt)
|
||||||
|
app.router.add_get("/ws", websocket_handler)
|
||||||
|
|
||||||
|
# Serve frontend static files (Vite build or raw)
|
||||||
|
if DIST_DIR.exists():
|
||||||
|
app.router.add_static("/assets", DIST_DIR / "assets")
|
||||||
|
app.router.add_static("/static", FRONTEND_DIR)
|
||||||
|
|
||||||
|
# CORS — allow any origin (local dev only)
|
||||||
|
async def _cors_middleware(app_, handler):
|
||||||
|
async def middleware(request):
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
return web.Response(headers={
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "Content-Type",
|
||||||
|
})
|
||||||
|
response = await handler(request)
|
||||||
|
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||||
|
return response
|
||||||
|
return middleware
|
||||||
|
|
||||||
|
app.middlewares.append(_cors_middleware)
|
||||||
|
|
||||||
|
return app
|
||||||
12
frontend/index.html
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Argonode — Image Analysis</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.jsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
1951
frontend/package-lock.json
generated
Normal file
20
frontend/package.json
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"name": "argonode-frontend",
|
||||||
|
"private": true,
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "vite build",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@xyflow/react": "^12.0.0",
|
||||||
|
"react": "^18.3.0",
|
||||||
|
"react-dom": "^18.3.0",
|
||||||
|
"three": "^0.183.2"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@vitejs/plugin-react": "^4.3.0",
|
||||||
|
"vite": "^5.4.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
601
frontend/src/App.jsx
Normal file
@@ -0,0 +1,601 @@
|
|||||||
|
import React, {
|
||||||
|
useState, useCallback, useEffect, useRef, useMemo,
|
||||||
|
} from 'react';
|
||||||
|
import {
|
||||||
|
ReactFlow, Background, Controls, MiniMap,
|
||||||
|
useNodesState, useEdgesState, addEdge, useReactFlow,
|
||||||
|
ReactFlowProvider,
|
||||||
|
} from '@xyflow/react';
|
||||||
|
import '@xyflow/react/dist/style.css';
|
||||||
|
|
||||||
|
import CustomNode, { NodeContext } from './CustomNode';
|
||||||
|
import FileBrowser from './FileBrowser';
|
||||||
|
import * as api from './api';
|
||||||
|
|
||||||
|
// ── Constants ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']);
|
||||||
|
|
||||||
|
const TYPE_COLORS = {
|
||||||
|
DATA_FIELD: '#3a7abf',
|
||||||
|
IMAGE: '#4caf50',
|
||||||
|
LINE: '#ff9800',
|
||||||
|
TABLE: '#fdd835',
|
||||||
|
COORD: '#e91e63',
|
||||||
|
};
|
||||||
|
|
||||||
|
const NODE_TYPES = { custom: CustomNode };
|
||||||
|
|
||||||
|
// ── Handle ID helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function getHandleType(handleId) {
|
||||||
|
return handleId.split('::')[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
function getInputName(handleId) {
|
||||||
|
return handleId.split('::')[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
function getOutputSlot(handleId) {
|
||||||
|
return parseInt(handleId.split('::')[1], 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Graph serialisation → backend prompt format ───────────────────────
|
||||||
|
|
||||||
|
function serializeGraph(nodes, edges) {
|
||||||
|
const prompt = {};
|
||||||
|
|
||||||
|
for (const node of nodes) {
|
||||||
|
const { className, definition, widgetValues } = node.data;
|
||||||
|
if (!definition) continue;
|
||||||
|
|
||||||
|
const inputs = {};
|
||||||
|
|
||||||
|
// Widget (scalar) values
|
||||||
|
const required = definition.input.required || {};
|
||||||
|
for (const [name, spec] of Object.entries(required)) {
|
||||||
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
if (DATA_TYPES.has(type)) continue; // socket, handled via edges
|
||||||
|
if (widgetValues[name] !== undefined) {
|
||||||
|
inputs[name] = widgetValues[name];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connected (socket) inputs from edges
|
||||||
|
const incoming = edges.filter((e) => e.target === node.id);
|
||||||
|
for (const edge of incoming) {
|
||||||
|
const inputName = getInputName(edge.targetHandle);
|
||||||
|
const outputSlot = getOutputSlot(edge.sourceHandle);
|
||||||
|
inputs[inputName] = [edge.source, outputSlot];
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt[node.id] = { class_type: className, inputs };
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Context menu component ────────────────────────────────────────────
|
||||||
|
|
||||||
|
function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirection }) {
|
||||||
|
// Group by category, optionally filtering to compatible nodes
|
||||||
|
const categories = {};
|
||||||
|
for (const [className, def] of Object.entries(nodeDefs)) {
|
||||||
|
// If filtering: only show nodes with a matching input or output
|
||||||
|
if (filterType && filterDirection) {
|
||||||
|
if (filterDirection === 'source') {
|
||||||
|
// Dragged from an output — show nodes that have a matching INPUT
|
||||||
|
const req = def.input.required || {};
|
||||||
|
const opt = def.input.optional || {};
|
||||||
|
const allInputs = { ...req, ...opt };
|
||||||
|
const hasMatch = Object.values(allInputs).some((spec) => {
|
||||||
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
return type === filterType;
|
||||||
|
});
|
||||||
|
if (!hasMatch) continue;
|
||||||
|
} else {
|
||||||
|
// Dragged from an input — show nodes that have a matching OUTPUT
|
||||||
|
if (!def.output.includes(filterType)) continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const cat = def.category || 'uncategorized';
|
||||||
|
if (!categories[cat]) categories[cat] = [];
|
||||||
|
categories[cat].push({ className, def });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Object.keys(categories).length === 0) {
|
||||||
|
return (
|
||||||
|
<div className="context-menu" style={{ left: x, top: y }} onClick={(e) => e.stopPropagation()}>
|
||||||
|
<div className="context-item" style={{ color: '#64748b' }}>No compatible nodes</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className="context-menu"
|
||||||
|
style={{ left: x, top: y }}
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
|
>
|
||||||
|
{Object.entries(categories).map(([cat, items]) => (
|
||||||
|
<div key={cat}>
|
||||||
|
<div className="context-category">{cat}</div>
|
||||||
|
{items.map(({ className, def }) => (
|
||||||
|
<div
|
||||||
|
key={className}
|
||||||
|
className="context-item"
|
||||||
|
onClick={() => { onAdd(className, def); onClose(); }}
|
||||||
|
>
|
||||||
|
{def.display_name || className}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Main flow component (needs ReactFlowProvider ancestor) ────────────
|
||||||
|
|
||||||
|
function Flow() {
|
||||||
|
const [nodes, setNodes, onNodesChange] = useNodesState([]);
|
||||||
|
const [edges, setEdges, onEdgesChange] = useEdgesState([]);
|
||||||
|
const [status, setStatus] = useState({ text: 'Connecting…', level: 'info' });
|
||||||
|
const [contextMenu, setContextMenu] = useState(null);
|
||||||
|
const [fileBrowserCb, setFileBrowserCb] = useState(null);
|
||||||
|
|
||||||
|
const nodeDefsRef = useRef({});
|
||||||
|
const nextIdRef = useRef(1);
|
||||||
|
const autoRunTimer = useRef(null);
|
||||||
|
const autoRunRef = useRef(null);
|
||||||
|
const reactFlow = useReactFlow();
|
||||||
|
|
||||||
|
// ── Load node definitions ───────────────────────────────────────────
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
api.getNodes().then((defs) => {
|
||||||
|
nodeDefsRef.current = defs;
|
||||||
|
setStatus({ text: `Loaded ${Object.keys(defs).length} nodes.`, level: 'info' });
|
||||||
|
}).catch((err) => {
|
||||||
|
setStatus({ text: 'Failed to load nodes: ' + err.message, level: 'error' });
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── WebSocket ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const updateNodeData = useCallback((nodeId, patch) => {
|
||||||
|
setNodes((ns) => ns.map((n) =>
|
||||||
|
n.id !== nodeId ? n : { ...n, data: { ...n.data, ...patch } }
|
||||||
|
));
|
||||||
|
}, [setNodes]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
api.setMessageHandler((msg) => {
|
||||||
|
console.log('[argonode] WS:', msg.type, msg.data?.node_id || msg.data?.node || '');
|
||||||
|
switch (msg.type) {
|
||||||
|
case 'execution_start':
|
||||||
|
setStatus({ text: 'Running workflow…', level: 'info' });
|
||||||
|
break;
|
||||||
|
case 'executing':
|
||||||
|
setStatus({ text: `Executing node ${msg.data.node}…`, level: 'info' });
|
||||||
|
break;
|
||||||
|
case 'execution_complete':
|
||||||
|
setStatus({ text: 'Done.', level: 'info' });
|
||||||
|
break;
|
||||||
|
case 'execution_error':
|
||||||
|
setStatus({ text: 'Error: ' + msg.data.message, level: 'error' });
|
||||||
|
console.error('[argonode] execution error', msg.data);
|
||||||
|
break;
|
||||||
|
case 'preview':
|
||||||
|
updateNodeData(msg.data.node_id, { previewImage: msg.data.image });
|
||||||
|
break;
|
||||||
|
case 'table':
|
||||||
|
updateNodeData(msg.data.node_id, { tableRows: msg.data.rows });
|
||||||
|
break;
|
||||||
|
case 'mesh3d':
|
||||||
|
updateNodeData(msg.data.node_id, { meshData: msg.data.mesh });
|
||||||
|
break;
|
||||||
|
case 'overlay':
|
||||||
|
updateNodeData(msg.data.node_id, { overlay: msg.data.overlay });
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
api.initWS();
|
||||||
|
return () => api.closeWS();
|
||||||
|
}, [updateNodeData]);
|
||||||
|
|
||||||
|
// ── Connection handling ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
const isValidConnection = useCallback((connection) => {
|
||||||
|
const srcType = getHandleType(connection.sourceHandle);
|
||||||
|
const tgtType = getHandleType(connection.targetHandle);
|
||||||
|
return srcType === tgtType;
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const onConnect = useCallback((params) => {
|
||||||
|
const type = getHandleType(params.sourceHandle);
|
||||||
|
const color = TYPE_COLORS[type] || '#999';
|
||||||
|
|
||||||
|
setEdges((eds) => {
|
||||||
|
// Enforce single connection per input handle
|
||||||
|
const filtered = eds.filter(
|
||||||
|
(e) => !(e.target === params.target && e.targetHandle === params.targetHandle)
|
||||||
|
);
|
||||||
|
return addEdge(
|
||||||
|
{ ...params, style: { stroke: color, strokeWidth: 2 } },
|
||||||
|
filtered
|
||||||
|
);
|
||||||
|
});
|
||||||
|
scheduleAutoRun();
|
||||||
|
}, [setEdges]);
|
||||||
|
|
||||||
|
// ── Drop-on-blank: open filtered context menu ──────────────────────
|
||||||
|
|
||||||
|
const onConnectEnd = useCallback((event, connectionState) => {
|
||||||
|
// If the connection was completed (dropped on a valid handle), do nothing
|
||||||
|
if (connectionState.isValid) return;
|
||||||
|
|
||||||
|
const fromHandle = connectionState.fromHandle;
|
||||||
|
if (!fromHandle || !fromHandle.id) return;
|
||||||
|
|
||||||
|
const { clientX, clientY } = 'changedTouches' in event ? event.changedTouches[0] : event;
|
||||||
|
const handleType = getHandleType(fromHandle.id);
|
||||||
|
|
||||||
|
setContextMenu({
|
||||||
|
x: clientX,
|
||||||
|
y: clientY,
|
||||||
|
filterType: handleType,
|
||||||
|
filterDirection: fromHandle.type,
|
||||||
|
pendingNodeId: fromHandle.nodeId,
|
||||||
|
pendingHandleId: fromHandle.id,
|
||||||
|
pendingHandleType: fromHandle.type,
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Widget change callback ──────────────────────────────────────────
|
||||||
|
|
||||||
|
const onWidgetChange = useCallback((nodeId, name, value) => {
|
||||||
|
setNodes((ns) => ns.map((n) => {
|
||||||
|
if (n.id !== nodeId) return n;
|
||||||
|
return {
|
||||||
|
...n,
|
||||||
|
data: {
|
||||||
|
...n.data,
|
||||||
|
widgetValues: { ...n.data.widgetValues, [name]: value },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}));
|
||||||
|
scheduleAutoRun();
|
||||||
|
}, [setNodes]); // scheduleAutoRun is stable (no deps)
|
||||||
|
|
||||||
|
// ── File browser ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const openFileBrowser = useCallback((callback) => {
|
||||||
|
setFileBrowserCb(() => callback);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Node context value (stable) ─────────────────────────────────────
|
||||||
|
|
||||||
|
const contextValue = useMemo(() => ({
|
||||||
|
onWidgetChange,
|
||||||
|
openFileBrowser,
|
||||||
|
}), [onWidgetChange, openFileBrowser]);
|
||||||
|
|
||||||
|
// ── Add node from context menu ──────────────────────────────────────
|
||||||
|
|
||||||
|
const addNode = useCallback((className, def) => {
|
||||||
|
if (!contextMenu) return;
|
||||||
|
const position = reactFlow.screenToFlowPosition({
|
||||||
|
x: contextMenu.x,
|
||||||
|
y: contextMenu.y,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Build default widget values
|
||||||
|
const widgetValues = {};
|
||||||
|
const required = def.input.required || {};
|
||||||
|
for (const [name, spec] of Object.entries(required)) {
|
||||||
|
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
|
||||||
|
if (DATA_TYPES.has(type)) continue;
|
||||||
|
if (Array.isArray(type)) {
|
||||||
|
widgetValues[name] = type[0]; // combo default = first option
|
||||||
|
} else {
|
||||||
|
widgetValues[name] = opts?.default ?? '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const newNodeId = String(nextIdRef.current++);
|
||||||
|
const newNode = {
|
||||||
|
id: newNodeId,
|
||||||
|
type: 'custom',
|
||||||
|
position,
|
||||||
|
dragHandle: '.drag-handle',
|
||||||
|
data: {
|
||||||
|
label: def.display_name || className,
|
||||||
|
className,
|
||||||
|
definition: def,
|
||||||
|
widgetValues,
|
||||||
|
previewImage: null,
|
||||||
|
tableRows: null,
|
||||||
|
meshData: null,
|
||||||
|
overlay: null,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
setNodes((ns) => [...ns, newNode]);
|
||||||
|
|
||||||
|
// Auto-connect if this was triggered by dropping a connection on blank space
|
||||||
|
if (contextMenu.pendingHandleId) {
|
||||||
|
const filterType = contextMenu.filterType;
|
||||||
|
|
||||||
|
if (contextMenu.pendingHandleType === 'source') {
|
||||||
|
// Dragged from an output → connect to the first matching input on the new node
|
||||||
|
const allInputs = { ...(def.input.required || {}), ...(def.input.optional || {}) };
|
||||||
|
const inputName = Object.entries(allInputs).find(([, spec]) => {
|
||||||
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
return type === filterType;
|
||||||
|
})?.[0];
|
||||||
|
if (inputName) {
|
||||||
|
const targetHandle = `input::${inputName}::${filterType}`;
|
||||||
|
const color = TYPE_COLORS[filterType] || '#999';
|
||||||
|
setEdges((eds) => addEdge({
|
||||||
|
source: contextMenu.pendingNodeId,
|
||||||
|
sourceHandle: contextMenu.pendingHandleId,
|
||||||
|
target: newNodeId,
|
||||||
|
targetHandle,
|
||||||
|
style: { stroke: color, strokeWidth: 2 },
|
||||||
|
}, eds));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Dragged from an input → connect from the first matching output on the new node
|
||||||
|
const outputIdx = def.output.indexOf(filterType);
|
||||||
|
if (outputIdx !== -1) {
|
||||||
|
const sourceHandle = `output::${outputIdx}::${filterType}`;
|
||||||
|
const color = TYPE_COLORS[filterType] || '#999';
|
||||||
|
setEdges((eds) => addEdge({
|
||||||
|
source: newNodeId,
|
||||||
|
sourceHandle,
|
||||||
|
target: contextMenu.pendingNodeId,
|
||||||
|
targetHandle: contextMenu.pendingHandleId,
|
||||||
|
style: { stroke: color, strokeWidth: 2 },
|
||||||
|
}, eds));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setContextMenu(null);
|
||||||
|
scheduleAutoRun();
|
||||||
|
}, [contextMenu, reactFlow, setNodes, setEdges]);
|
||||||
|
|
||||||
|
// ── Toolbar actions ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const runWorkflow = useCallback(async () => {
|
||||||
|
// Read current state via functional ref to avoid stale closure
|
||||||
|
const currentNodes = reactFlow.getNodes();
|
||||||
|
const currentEdges = reactFlow.getEdges();
|
||||||
|
const prompt = serializeGraph(currentNodes, currentEdges);
|
||||||
|
|
||||||
|
if (!prompt || Object.keys(prompt).length === 0) {
|
||||||
|
setStatus({ text: 'Graph is empty — add some nodes first.', level: 'error' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setStatus({ text: 'Running…', level: 'info' });
|
||||||
|
try {
|
||||||
|
await api.runPrompt(prompt);
|
||||||
|
} catch (err) {
|
||||||
|
setStatus({ text: 'Failed: ' + err.message, level: 'error' });
|
||||||
|
}
|
||||||
|
}, [reactFlow]);
|
||||||
|
|
||||||
|
// Debounced auto-run via ref to avoid dependency chains
|
||||||
|
autoRunRef.current = () => {
|
||||||
|
const currentNodes = reactFlow.getNodes();
|
||||||
|
const currentEdges = reactFlow.getEdges();
|
||||||
|
|
||||||
|
// Don't run if any node has unconnected required data inputs
|
||||||
|
for (const node of currentNodes) {
|
||||||
|
const def = node.data?.definition;
|
||||||
|
if (!def) continue;
|
||||||
|
const required = def.input.required || {};
|
||||||
|
for (const [name, spec] of Object.entries(required)) {
|
||||||
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
if (!DATA_TYPES.has(type)) continue;
|
||||||
|
const hasEdge = currentEdges.some(
|
||||||
|
(e) => e.target === node.id && getInputName(e.targetHandle) === name
|
||||||
|
);
|
||||||
|
if (!hasEdge) return; // incomplete graph, skip auto-run
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const prompt = serializeGraph(currentNodes, currentEdges);
|
||||||
|
if (!prompt || Object.keys(prompt).length === 0) return;
|
||||||
|
setStatus({ text: 'Running…', level: 'info' });
|
||||||
|
api.runPrompt(prompt).catch((err) => {
|
||||||
|
setStatus({ text: 'Failed: ' + err.message, level: 'error' });
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const scheduleAutoRun = useCallback(() => {
|
||||||
|
clearTimeout(autoRunTimer.current);
|
||||||
|
autoRunTimer.current = setTimeout(() => autoRunRef.current?.(), 300);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const clearGraph = useCallback(() => {
|
||||||
|
setNodes([]);
|
||||||
|
setEdges([]);
|
||||||
|
nextIdRef.current = 1;
|
||||||
|
setStatus({ text: 'Graph cleared.', level: 'info' });
|
||||||
|
}, [setNodes, setEdges]);
|
||||||
|
|
||||||
|
const saveWorkflow = useCallback(() => {
|
||||||
|
const currentNodes = reactFlow.getNodes().map((n) => ({
|
||||||
|
...n,
|
||||||
|
data: { ...n.data, previewImage: null, tableRows: null, meshData: null, overlay: null },
|
||||||
|
}));
|
||||||
|
const data = { version: 1, nodes: currentNodes, edges: reactFlow.getEdges() };
|
||||||
|
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' });
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = URL.createObjectURL(blob);
|
||||||
|
a.download = 'workflow.json';
|
||||||
|
a.click();
|
||||||
|
}, [reactFlow]);
|
||||||
|
|
||||||
|
const loadWorkflow = useCallback(() => {
|
||||||
|
const input = document.createElement('input');
|
||||||
|
input.type = 'file';
|
||||||
|
input.accept = '.json';
|
||||||
|
input.onchange = async (e) => {
|
||||||
|
const file = e.target.files[0];
|
||||||
|
if (!file) return;
|
||||||
|
const text = await file.text();
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(text);
|
||||||
|
const loadedNodes = data.nodes || [];
|
||||||
|
const loadedEdges = data.edges || [];
|
||||||
|
|
||||||
|
// Re-populate definitions from current nodeDefs
|
||||||
|
const defs = nodeDefsRef.current;
|
||||||
|
const hydrated = loadedNodes.map((n) => ({
|
||||||
|
...n,
|
||||||
|
data: {
|
||||||
|
...n.data,
|
||||||
|
definition: defs[n.data.className] || n.data.definition,
|
||||||
|
previewImage: null,
|
||||||
|
tableRows: null,
|
||||||
|
meshData: null,
|
||||||
|
overlay: null,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
setNodes(hydrated);
|
||||||
|
setEdges(loadedEdges);
|
||||||
|
|
||||||
|
// Update ID counter to avoid collisions
|
||||||
|
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
|
||||||
|
nextIdRef.current = maxId + 1;
|
||||||
|
|
||||||
|
setStatus({ text: 'Workflow loaded.', level: 'info' });
|
||||||
|
} catch {
|
||||||
|
setStatus({ text: 'Invalid workflow JSON.', level: 'error' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
input.click();
|
||||||
|
}, [setNodes, setEdges]);
|
||||||
|
|
||||||
|
// ── Keyboard shortcut ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handler = (e) => {
|
||||||
|
if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
|
||||||
|
e.preventDefault();
|
||||||
|
runWorkflow();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
window.addEventListener('keydown', handler);
|
||||||
|
return () => window.removeEventListener('keydown', handler);
|
||||||
|
}, [runWorkflow]);
|
||||||
|
|
||||||
|
// ── Context menu ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const onPaneContextMenu = useCallback((event) => {
|
||||||
|
event.preventDefault();
|
||||||
|
setContextMenu({ x: event.clientX, y: event.clientY });
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Render ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
return (
|
||||||
|
<NodeContext.Provider value={contextValue}>
|
||||||
|
<div className="app-container">
|
||||||
|
{/* Toolbar */}
|
||||||
|
<div id="toolbar">
|
||||||
|
<span id="app-title">Argonode</span>
|
||||||
|
|
||||||
|
<div className="toolbar-group">
|
||||||
|
<button className="btn btn-primary" onClick={runWorkflow} title="Run workflow (Ctrl+Enter)">
|
||||||
|
▶ Run
|
||||||
|
</button>
|
||||||
|
<button className="btn" onClick={clearGraph} title="Clear graph">
|
||||||
|
✕ Clear
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="toolbar-group">
|
||||||
|
<button className="btn" onClick={saveWorkflow} title="Save workflow JSON">
|
||||||
|
⤓ Save
|
||||||
|
</button>
|
||||||
|
<button className="btn" onClick={loadWorkflow} title="Load workflow JSON">
|
||||||
|
⤒ Load
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className={`status-bar ${status.level}`}>{status.text}</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* React Flow canvas */}
|
||||||
|
<div className="flow-container" onMouseDown={(e) => {
|
||||||
|
if (!e.target.closest('.context-menu')) setContextMenu(null);
|
||||||
|
}}>
|
||||||
|
<ReactFlow
|
||||||
|
nodes={nodes}
|
||||||
|
edges={edges}
|
||||||
|
onNodesChange={onNodesChange}
|
||||||
|
onEdgesChange={onEdgesChange}
|
||||||
|
onConnect={onConnect}
|
||||||
|
onConnectEnd={onConnectEnd}
|
||||||
|
isValidConnection={isValidConnection}
|
||||||
|
nodeTypes={NODE_TYPES}
|
||||||
|
onPaneContextMenu={onPaneContextMenu}
|
||||||
|
colorMode="dark"
|
||||||
|
fitView
|
||||||
|
deleteKeyCode={['Backspace', 'Delete']}
|
||||||
|
defaultEdgeOptions={{ type: 'default' }}
|
||||||
|
>
|
||||||
|
<Background />
|
||||||
|
<Controls />
|
||||||
|
<MiniMap
|
||||||
|
nodeColor={(n) => {
|
||||||
|
const cat = n.data?.definition?.category;
|
||||||
|
const colors = {
|
||||||
|
io: '#37474f', filters: '#1a237e', level: '#1b5e20',
|
||||||
|
analysis: '#4a148c', grains: '#bf360c', display: '#212121',
|
||||||
|
};
|
||||||
|
return colors[cat] || '#333';
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</ReactFlow>
|
||||||
|
|
||||||
|
{contextMenu && (
|
||||||
|
<ContextMenu
|
||||||
|
x={contextMenu.x}
|
||||||
|
y={contextMenu.y}
|
||||||
|
nodeDefs={nodeDefsRef.current}
|
||||||
|
onAdd={addNode}
|
||||||
|
onClose={() => setContextMenu(null)}
|
||||||
|
filterType={contextMenu.filterType}
|
||||||
|
filterDirection={contextMenu.filterDirection}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* File browser modal */}
|
||||||
|
{fileBrowserCb && (
|
||||||
|
<FileBrowser
|
||||||
|
onSelect={(path) => { fileBrowserCb(path); setFileBrowserCb(null); }}
|
||||||
|
onClose={() => setFileBrowserCb(null)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</NodeContext.Provider>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── App wrapper with ReactFlowProvider ────────────────────────────────
|
||||||
|
|
||||||
|
export default function App() {
|
||||||
|
return (
|
||||||
|
<ReactFlowProvider>
|
||||||
|
<Flow />
|
||||||
|
</ReactFlowProvider>
|
||||||
|
);
|
||||||
|
}
|
||||||
86
frontend/src/CrossSectionOverlay.jsx
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import React, { useRef, useState, useCallback } from 'react';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Image preview with two endpoint markers for cross-section line control.
|
||||||
|
* Markers are draggable when unlocked (no COORD input connected),
|
||||||
|
* and fixed when locked (COORD input provides the position).
|
||||||
|
*
|
||||||
|
* Marker positions are driven by widget values (immediate React state),
|
||||||
|
* not by backend overlay coords, so they move instantly during drag.
|
||||||
|
*/
|
||||||
|
export default function CrossSectionOverlay({
|
||||||
|
image, x1, y1, x2, y2,
|
||||||
|
aLocked, bLocked,
|
||||||
|
nodeId, onWidgetChange,
|
||||||
|
}) {
|
||||||
|
const containerRef = useRef(null);
|
||||||
|
const [dragging, setDragging] = useState(null); // 'p1' or 'p2'
|
||||||
|
|
||||||
|
const getCoords = useCallback((e) => {
|
||||||
|
const rect = containerRef.current.getBoundingClientRect();
|
||||||
|
return {
|
||||||
|
fx: Math.max(0, Math.min(1, (e.clientX - rect.left) / rect.width)),
|
||||||
|
fy: Math.max(0, Math.min(1, (e.clientY - rect.top) / rect.height)),
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const onPointerDown = useCallback((point) => (e) => {
|
||||||
|
if (point === 'p1' && aLocked) return;
|
||||||
|
if (point === 'p2' && bLocked) return;
|
||||||
|
e.stopPropagation();
|
||||||
|
e.preventDefault();
|
||||||
|
e.target.setPointerCapture(e.pointerId);
|
||||||
|
setDragging(point);
|
||||||
|
}, [aLocked, bLocked]);
|
||||||
|
|
||||||
|
const onPointerMove = useCallback((e) => {
|
||||||
|
if (!dragging || !containerRef.current) return;
|
||||||
|
const { fx, fy } = getCoords(e);
|
||||||
|
const vx = parseFloat(fx.toFixed(3));
|
||||||
|
const vy = parseFloat(fy.toFixed(3));
|
||||||
|
if (dragging === 'p1') {
|
||||||
|
onWidgetChange(nodeId, 'x1', vx);
|
||||||
|
onWidgetChange(nodeId, 'y1', vy);
|
||||||
|
} else {
|
||||||
|
onWidgetChange(nodeId, 'x2', vx);
|
||||||
|
onWidgetChange(nodeId, 'y2', vy);
|
||||||
|
}
|
||||||
|
}, [dragging, nodeId, onWidgetChange, getCoords]);
|
||||||
|
|
||||||
|
const onPointerUp = useCallback(() => {
|
||||||
|
setDragging(null);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
ref={containerRef}
|
||||||
|
className="nodrag nowheel cs-overlay"
|
||||||
|
onPointerMove={onPointerMove}
|
||||||
|
onPointerUp={onPointerUp}
|
||||||
|
onLostPointerCapture={onPointerUp}
|
||||||
|
>
|
||||||
|
<img src={image} alt="field" draggable={false} className="cs-image" />
|
||||||
|
|
||||||
|
{/* Line connecting the two markers */}
|
||||||
|
<svg className="cs-svg">
|
||||||
|
<line
|
||||||
|
x1={`${x1 * 100}%`} y1={`${y1 * 100}%`}
|
||||||
|
x2={`${x2 * 100}%`} y2={`${y2 * 100}%`}
|
||||||
|
stroke="#ffd700" strokeWidth="2" strokeDasharray="6 3"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
|
||||||
|
{/* Endpoint markers — locked markers get a different style */}
|
||||||
|
<div
|
||||||
|
className={`cs-marker ${aLocked ? 'cs-marker-locked' : ''}`}
|
||||||
|
style={{ left: `${x1 * 100}%`, top: `${y1 * 100}%` }}
|
||||||
|
onPointerDown={onPointerDown('p1')}
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
className={`cs-marker ${bLocked ? 'cs-marker-locked' : ''}`}
|
||||||
|
style={{ left: `${x2 * 100}%`, top: `${y2 * 100}%` }}
|
||||||
|
onPointerDown={onPointerDown('p2')}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
396
frontend/src/CustomNode.jsx
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
|
||||||
|
import { Handle, Position } from '@xyflow/react';
|
||||||
|
|
||||||
|
const SurfaceView = lazy(() => import('./SurfaceView'));
|
||||||
|
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
|
||||||
|
|
||||||
|
// ── Constants ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const DATA_TYPES = new Set(['DATA_FIELD', 'IMAGE', 'LINE', 'TABLE', 'COORD']);
|
||||||
|
|
||||||
|
const TYPE_COLORS = {
|
||||||
|
DATA_FIELD: '#3a7abf',
|
||||||
|
IMAGE: '#4caf50',
|
||||||
|
LINE: '#ff9800',
|
||||||
|
TABLE: '#fdd835',
|
||||||
|
COORD: '#e91e63',
|
||||||
|
};
|
||||||
|
|
||||||
|
const CAT_COLORS = {
|
||||||
|
io: '#37474f',
|
||||||
|
filters: '#1a237e',
|
||||||
|
level: '#1b5e20',
|
||||||
|
analysis: '#4a148c',
|
||||||
|
grains: '#bf360c',
|
||||||
|
display: '#212121',
|
||||||
|
};
|
||||||
|
|
||||||
|
// ── Context (provided by App) ─────────────────────────────────────────
|
||||||
|
|
||||||
|
export const NodeContext = React.createContext(null);
|
||||||
|
|
||||||
|
// ── Draggable number input ────────────────────────────────────────────
|
||||||
|
|
||||||
|
function DraggableNumber({ value, step, min, max, precision, onChange }) {
|
||||||
|
const [editing, setEditing] = useState(false);
|
||||||
|
const [editText, setEditText] = useState('');
|
||||||
|
const dragState = useRef(null);
|
||||||
|
const elRef = useRef(null);
|
||||||
|
|
||||||
|
const display = precision != null ? Number(value).toFixed(precision) : String(value);
|
||||||
|
|
||||||
|
const clamp = useCallback((v) => {
|
||||||
|
if (min != null && v < min) v = min;
|
||||||
|
if (max != null && v > max) v = max;
|
||||||
|
return v;
|
||||||
|
}, [min, max]);
|
||||||
|
|
||||||
|
const onPointerDown = useCallback((e) => {
|
||||||
|
if (editing) return;
|
||||||
|
e.preventDefault();
|
||||||
|
dragState.current = { startX: e.clientX, startVal: Number(value) };
|
||||||
|
elRef.current?.setPointerCapture(e.pointerId);
|
||||||
|
}, [editing, value]);
|
||||||
|
|
||||||
|
const onPointerMove = useCallback((e) => {
|
||||||
|
if (!dragState.current) return;
|
||||||
|
const dx = e.clientX - dragState.current.startX;
|
||||||
|
const delta = dx * (step || 0.01);
|
||||||
|
const raw = dragState.current.startVal + delta;
|
||||||
|
const rounded = precision != null
|
||||||
|
? parseFloat(raw.toFixed(precision))
|
||||||
|
: Math.round(raw);
|
||||||
|
onChange(clamp(rounded));
|
||||||
|
}, [step, precision, clamp, onChange]);
|
||||||
|
|
||||||
|
const onPointerUp = useCallback((e) => {
|
||||||
|
if (!dragState.current) return;
|
||||||
|
const dx = Math.abs(e.clientX - dragState.current.startX);
|
||||||
|
dragState.current = null;
|
||||||
|
// If barely moved, enter text-edit mode
|
||||||
|
if (dx < 3) {
|
||||||
|
setEditText(display);
|
||||||
|
setEditing(true);
|
||||||
|
}
|
||||||
|
}, [display]);
|
||||||
|
|
||||||
|
const commitEdit = useCallback(() => {
|
||||||
|
setEditing(false);
|
||||||
|
const parsed = parseFloat(editText);
|
||||||
|
if (!isNaN(parsed)) onChange(clamp(precision != null ? parseFloat(parsed.toFixed(precision)) : Math.round(parsed)));
|
||||||
|
}, [editText, precision, clamp, onChange]);
|
||||||
|
|
||||||
|
if (editing) {
|
||||||
|
return (
|
||||||
|
<input
|
||||||
|
className="nodrag drag-number-edit"
|
||||||
|
type="text"
|
||||||
|
autoFocus
|
||||||
|
value={editText}
|
||||||
|
onChange={(e) => setEditText(e.target.value)}
|
||||||
|
onBlur={commitEdit}
|
||||||
|
onKeyDown={(e) => { if (e.key === 'Enter') commitEdit(); if (e.key === 'Escape') setEditing(false); }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
ref={elRef}
|
||||||
|
className="nodrag drag-number"
|
||||||
|
onPointerDown={onPointerDown}
|
||||||
|
onPointerMove={onPointerMove}
|
||||||
|
onPointerUp={onPointerUp}
|
||||||
|
>
|
||||||
|
<span className="drag-number-val">{display}</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Collapsible section ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
function CollapsibleSection({ title, defaultOpen, children }) {
|
||||||
|
const [open, setOpen] = useState(defaultOpen);
|
||||||
|
return (
|
||||||
|
<div className="collapsible">
|
||||||
|
<button
|
||||||
|
className="nodrag collapsible-toggle"
|
||||||
|
onClick={() => setOpen((o) => !o)}
|
||||||
|
>
|
||||||
|
<span className="collapsible-arrow">{open ? '▾' : '▸'}</span>
|
||||||
|
{title}
|
||||||
|
</button>
|
||||||
|
{open && children}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CustomNode component ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
function CustomNode({ id, data }) {
|
||||||
|
const ctx = useContext(NodeContext);
|
||||||
|
const def = data.definition;
|
||||||
|
|
||||||
|
// Parse inputs into data handles and widgets
|
||||||
|
const required = def.input.required || {};
|
||||||
|
const optional = def.input.optional || {};
|
||||||
|
|
||||||
|
const dataInputs = [];
|
||||||
|
const widgets = [];
|
||||||
|
|
||||||
|
const hiddenWidgets = new Set();
|
||||||
|
|
||||||
|
for (const [name, spec] of Object.entries(required)) {
|
||||||
|
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
|
||||||
|
if (DATA_TYPES.has(type)) {
|
||||||
|
dataInputs.push({ name, type });
|
||||||
|
} else if (opts?.hidden) {
|
||||||
|
hiddenWidgets.add(name);
|
||||||
|
} else {
|
||||||
|
widgets.push({ name, type, opts: opts || {} });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [name, spec] of Object.entries(optional)) {
|
||||||
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
dataInputs.push({ name, type });
|
||||||
|
}
|
||||||
|
|
||||||
|
const outputs = def.output.map((type, i) => ({
|
||||||
|
name: def.output_name[i] || type,
|
||||||
|
type,
|
||||||
|
slot: i,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const catColor = CAT_COLORS[def.category] || '#333';
|
||||||
|
const maxIORows = Math.max(dataInputs.length, outputs.length);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="custom-node">
|
||||||
|
{/* Title */}
|
||||||
|
<div className="node-title drag-handle" style={{ background: catColor }}>
|
||||||
|
{data.label}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="node-body">
|
||||||
|
{/* I/O rows — pair inputs[i] with outputs[i] */}
|
||||||
|
{Array.from({ length: maxIORows }, (_, i) => {
|
||||||
|
const inp = dataInputs[i];
|
||||||
|
const out = outputs[i];
|
||||||
|
return (
|
||||||
|
<div className="io-row" key={`io-${i}`}>
|
||||||
|
<div className="io-left">
|
||||||
|
{inp && (
|
||||||
|
<>
|
||||||
|
<Handle
|
||||||
|
type="target"
|
||||||
|
position={Position.Left}
|
||||||
|
id={`input::${inp.name}::${inp.type}`}
|
||||||
|
className="typed-handle"
|
||||||
|
style={{ background: TYPE_COLORS[inp.type] || '#999' }}
|
||||||
|
/>
|
||||||
|
<span className="io-label">{inp.name}</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="io-right">
|
||||||
|
{out && (
|
||||||
|
<>
|
||||||
|
<span className="io-label">{out.name}</span>
|
||||||
|
<Handle
|
||||||
|
type="source"
|
||||||
|
position={Position.Right}
|
||||||
|
id={`output::${out.slot}::${out.type}`}
|
||||||
|
className="typed-handle"
|
||||||
|
style={{ background: TYPE_COLORS[out.type] || '#999' }}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
{/* Widget rows */}
|
||||||
|
{widgets.map((w) => (
|
||||||
|
<div className="widget-row" key={w.name}>
|
||||||
|
<WidgetControl
|
||||||
|
widget={w}
|
||||||
|
nodeId={id}
|
||||||
|
value={data.widgetValues[w.name]}
|
||||||
|
onChange={ctx.onWidgetChange}
|
||||||
|
openFileBrowser={ctx.openFileBrowser}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{/* Interactive 3D surface view */}
|
||||||
|
{data.meshData && (
|
||||||
|
<CollapsibleSection title="3D View" defaultOpen={true}>
|
||||||
|
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading 3D...</div>}>
|
||||||
|
<SurfaceView meshData={data.meshData} />
|
||||||
|
</Suspense>
|
||||||
|
</CollapsibleSection>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Collapsible preview image */}
|
||||||
|
{data.previewImage && (
|
||||||
|
<CollapsibleSection title="Preview" defaultOpen={true}>
|
||||||
|
<div className="node-preview">
|
||||||
|
<img src={data.previewImage} alt="preview" draggable={false} />
|
||||||
|
</div>
|
||||||
|
</CollapsibleSection>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Interactive cross-section overlay */}
|
||||||
|
{data.overlay && hiddenWidgets.has('x1') && (
|
||||||
|
<CollapsibleSection title="Cross Section" defaultOpen={true}>
|
||||||
|
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
|
||||||
|
<CrossSectionOverlay
|
||||||
|
image={data.overlay.image}
|
||||||
|
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
|
||||||
|
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
|
||||||
|
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
|
||||||
|
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
|
||||||
|
aLocked={data.overlay.a_locked}
|
||||||
|
bLocked={data.overlay.b_locked}
|
||||||
|
nodeId={id}
|
||||||
|
onWidgetChange={ctx.onWidgetChange}
|
||||||
|
/>
|
||||||
|
</Suspense>
|
||||||
|
</CollapsibleSection>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Collapsible table data */}
|
||||||
|
{data.tableRows && data.tableRows.length > 0 && (
|
||||||
|
<CollapsibleSection title="Table" defaultOpen={true}>
|
||||||
|
<div className="node-table">
|
||||||
|
{data.tableRows.map((row, i) => {
|
||||||
|
let line;
|
||||||
|
if (row.quantity !== undefined) {
|
||||||
|
const val = typeof row.value === 'number' ? row.value.toExponential(3) : row.value;
|
||||||
|
line = `${row.quantity}: ${val} ${row.unit || ''}`;
|
||||||
|
} else {
|
||||||
|
line = Object.entries(row)
|
||||||
|
.slice(0, 3)
|
||||||
|
.map(([k, v]) => `${k}: ${typeof v === 'number' ? v.toExponential(2) : v}`)
|
||||||
|
.join(' ');
|
||||||
|
}
|
||||||
|
return <div key={i} className="table-line">{line}</div>;
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</CollapsibleSection>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Widget renderer ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function WidgetControl({ widget, nodeId, value, onChange, openFileBrowser }) {
|
||||||
|
const { name, type, opts } = widget;
|
||||||
|
const val = value ?? opts?.default ?? '';
|
||||||
|
|
||||||
|
// Combo / enum — type itself is the array of options
|
||||||
|
if (Array.isArray(type)) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<label>{name}</label>
|
||||||
|
<select
|
||||||
|
className="nodrag"
|
||||||
|
value={val || type[0]}
|
||||||
|
onChange={(e) => onChange(nodeId, name, e.target.value)}
|
||||||
|
>
|
||||||
|
{type.map((opt) => (
|
||||||
|
<option key={opt} value={opt}>{opt}</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'FILE_PICKER') {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<label>{name}</label>
|
||||||
|
<div className="file-picker-row">
|
||||||
|
<input
|
||||||
|
className="nodrag"
|
||||||
|
type="text"
|
||||||
|
value={val}
|
||||||
|
onChange={(e) => onChange(nodeId, name, e.target.value)}
|
||||||
|
placeholder="Select file…"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
className="nodrag browse-btn"
|
||||||
|
onClick={() => openFileBrowser((path) => onChange(nodeId, name, path))}
|
||||||
|
>
|
||||||
|
Browse
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'FLOAT') {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<label>{name}</label>
|
||||||
|
<DraggableNumber
|
||||||
|
value={val || 0}
|
||||||
|
step={opts?.step ?? 0.01}
|
||||||
|
min={opts?.min}
|
||||||
|
max={opts?.max}
|
||||||
|
precision={4}
|
||||||
|
onChange={(v) => onChange(nodeId, name, v)}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'INT') {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<label>{name}</label>
|
||||||
|
<DraggableNumber
|
||||||
|
value={val || 0}
|
||||||
|
step={opts?.step ?? 1}
|
||||||
|
min={opts?.min}
|
||||||
|
max={opts?.max}
|
||||||
|
precision={0}
|
||||||
|
onChange={(v) => onChange(nodeId, name, v)}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'BOOLEAN') {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<label>{name}</label>
|
||||||
|
<input
|
||||||
|
className="nodrag"
|
||||||
|
type="checkbox"
|
||||||
|
checked={!!val}
|
||||||
|
onChange={(e) => onChange(nodeId, name, e.target.checked)}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// STRING and anything else
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<label>{name}</label>
|
||||||
|
<input
|
||||||
|
className="nodrag"
|
||||||
|
type="text"
|
||||||
|
value={val}
|
||||||
|
onChange={(e) => onChange(nodeId, name, e.target.value)}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(CustomNode);
|
||||||
94
frontend/src/FileBrowser.jsx
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import React, { useState, useEffect, useCallback } from 'react';
|
||||||
|
import * as api from './api';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Server-side file browser modal.
|
||||||
|
*
|
||||||
|
* Props:
|
||||||
|
* onSelect(absolutePath) — called when user picks a file
|
||||||
|
* onClose() — called when user dismisses the dialog
|
||||||
|
*/
|
||||||
|
export default function FileBrowser({ onSelect, onClose }) {
|
||||||
|
const [path, setPath] = useState('');
|
||||||
|
const [parent, setParent] = useState(null);
|
||||||
|
const [dirs, setDirs] = useState([]);
|
||||||
|
const [files, setFiles] = useState([]);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [error, setError] = useState(null);
|
||||||
|
|
||||||
|
const navigate = useCallback(async (dir) => {
|
||||||
|
setLoading(true);
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const data = await api.browse(dir);
|
||||||
|
setPath(data.path);
|
||||||
|
setParent(data.parent);
|
||||||
|
setDirs(data.dirs);
|
||||||
|
setFiles(data.files);
|
||||||
|
} catch (err) {
|
||||||
|
setError(err.message);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Start at home directory on mount
|
||||||
|
useEffect(() => {
|
||||||
|
navigate(null);
|
||||||
|
}, [navigate]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="fb-backdrop" onClick={(e) => { if (e.target === e.currentTarget) onClose(); }}>
|
||||||
|
<div className="fb-dialog">
|
||||||
|
{/* Header */}
|
||||||
|
<div className="fb-header">
|
||||||
|
<span className="fb-path">{path}</span>
|
||||||
|
<button className="fb-close" onClick={onClose}>✕</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* File list */}
|
||||||
|
<div className="fb-list">
|
||||||
|
{loading && <div className="fb-loading">Loading…</div>}
|
||||||
|
{error && <div className="fb-loading">Error: {error}</div>}
|
||||||
|
|
||||||
|
{!loading && !error && (
|
||||||
|
<>
|
||||||
|
{/* Parent directory */}
|
||||||
|
{parent && (
|
||||||
|
<div className="fb-entry fb-dir" onClick={() => navigate(parent)}>
|
||||||
|
⬆ ..
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Directories */}
|
||||||
|
{dirs.map((d) => (
|
||||||
|
<div
|
||||||
|
key={d}
|
||||||
|
className="fb-entry fb-dir"
|
||||||
|
onClick={() => navigate(path + '/' + d)}
|
||||||
|
>
|
||||||
|
📁 {d}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{/* Files */}
|
||||||
|
{files.map((f) => (
|
||||||
|
<div
|
||||||
|
key={f}
|
||||||
|
className="fb-entry fb-file"
|
||||||
|
onClick={() => { onSelect(path + '/' + f); onClose(); }}
|
||||||
|
>
|
||||||
|
{f}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{dirs.length === 0 && files.length === 0 && (
|
||||||
|
<div className="fb-loading">Empty directory</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
183
frontend/src/SurfaceView.jsx
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import React, { useRef, useEffect, useCallback } from 'react';
|
||||||
|
import * as THREE from 'three';
|
||||||
|
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interactive 3D surface viewer using Three.js.
|
||||||
|
* Props:
|
||||||
|
* meshData: { width, height, z_data (b64 float32), colors (b64 uint8 RGB),
|
||||||
|
* z_min, z_max, z_scale, x_range, y_range }
|
||||||
|
*/
|
||||||
|
export default function SurfaceView({ meshData }) {
|
||||||
|
const containerRef = useRef(null);
|
||||||
|
const threeRef = useRef(null); // { renderer, scene, camera, controls, mesh }
|
||||||
|
|
||||||
|
// Decode base64 to typed arrays
|
||||||
|
const decode = useCallback((b64, ArrayType) => {
|
||||||
|
const bin = atob(b64);
|
||||||
|
const bytes = new Uint8Array(bin.length);
|
||||||
|
for (let i = 0; i < bin.length; i++) bytes[i] = bin.charCodeAt(i);
|
||||||
|
return new ArrayType(bytes.buffer);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Initialize Three.js scene once
|
||||||
|
useEffect(() => {
|
||||||
|
const container = containerRef.current;
|
||||||
|
if (!container || threeRef.current) return;
|
||||||
|
|
||||||
|
const width = container.clientWidth;
|
||||||
|
const height = width; // 1:1 aspect
|
||||||
|
|
||||||
|
const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: false });
|
||||||
|
renderer.setSize(width, height);
|
||||||
|
renderer.setPixelRatio(window.devicePixelRatio);
|
||||||
|
renderer.setClearColor(0x0f172a);
|
||||||
|
container.appendChild(renderer.domElement);
|
||||||
|
|
||||||
|
const scene = new THREE.Scene();
|
||||||
|
|
||||||
|
const camera = new THREE.PerspectiveCamera(45, 1, 0.01, 1000);
|
||||||
|
camera.position.set(1.2, 0.8, 1.2);
|
||||||
|
|
||||||
|
const controls = new OrbitControls(camera, renderer.domElement);
|
||||||
|
controls.enableDamping = true;
|
||||||
|
controls.dampingFactor = 0.1;
|
||||||
|
controls.minDistance = 0.3;
|
||||||
|
controls.maxDistance = 10;
|
||||||
|
|
||||||
|
// Lighting
|
||||||
|
const ambient = new THREE.AmbientLight(0xffffff, 0.4);
|
||||||
|
scene.add(ambient);
|
||||||
|
const dir = new THREE.DirectionalLight(0xffffff, 0.8);
|
||||||
|
dir.position.set(1, 2, 1.5);
|
||||||
|
scene.add(dir);
|
||||||
|
const dir2 = new THREE.DirectionalLight(0xffffff, 0.3);
|
||||||
|
dir2.position.set(-1, 0.5, -1);
|
||||||
|
scene.add(dir2);
|
||||||
|
|
||||||
|
// Animation loop
|
||||||
|
let animId;
|
||||||
|
const animate = () => {
|
||||||
|
animId = requestAnimationFrame(animate);
|
||||||
|
controls.update();
|
||||||
|
renderer.render(scene, camera);
|
||||||
|
};
|
||||||
|
animate();
|
||||||
|
|
||||||
|
threeRef.current = { renderer, scene, camera, controls, mesh: null, animId };
|
||||||
|
|
||||||
|
// Resize observer to maintain 1:1 aspect when node width changes
|
||||||
|
const ro = new ResizeObserver((entries) => {
|
||||||
|
const entry = entries[0];
|
||||||
|
if (!entry || !threeRef.current) return;
|
||||||
|
const w = entry.contentRect.width;
|
||||||
|
if (w < 1) return;
|
||||||
|
const { renderer: r, camera: c } = threeRef.current;
|
||||||
|
r.setSize(w, w);
|
||||||
|
c.aspect = 1;
|
||||||
|
c.updateProjectionMatrix();
|
||||||
|
});
|
||||||
|
ro.observe(container);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
ro.disconnect();
|
||||||
|
cancelAnimationFrame(animId);
|
||||||
|
controls.dispose();
|
||||||
|
renderer.dispose();
|
||||||
|
if (container.contains(renderer.domElement)) {
|
||||||
|
container.removeChild(renderer.domElement);
|
||||||
|
}
|
||||||
|
threeRef.current = null;
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Update mesh when data changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (!threeRef.current || !meshData) return;
|
||||||
|
|
||||||
|
const { scene, camera, controls } = threeRef.current;
|
||||||
|
const { width: nx, height: ny, z_data, colors, z_min, z_max, z_scale, x_range, y_range } = meshData;
|
||||||
|
|
||||||
|
// Decode arrays
|
||||||
|
const zArr = decode(z_data, Float32Array);
|
||||||
|
const colArr = decode(colors, Uint8Array);
|
||||||
|
|
||||||
|
// Remove old mesh
|
||||||
|
if (threeRef.current.mesh) {
|
||||||
|
scene.remove(threeRef.current.mesh);
|
||||||
|
threeRef.current.mesh.geometry.dispose();
|
||||||
|
threeRef.current.mesh.material.dispose();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build geometry
|
||||||
|
const geom = new THREE.BufferGeometry();
|
||||||
|
const positions = new Float32Array(nx * ny * 3);
|
||||||
|
const colorAttr = new Float32Array(nx * ny * 3);
|
||||||
|
|
||||||
|
// Normalize coordinates to roughly [-0.5, 0.5] for good camera framing
|
||||||
|
const zRange = z_max - z_min || 1;
|
||||||
|
|
||||||
|
for (let iy = 0; iy < ny; iy++) {
|
||||||
|
for (let ix = 0; ix < nx; ix++) {
|
||||||
|
const idx = iy * nx + ix;
|
||||||
|
const px = ix / (nx - 1) - 0.5; // [-0.5, 0.5]
|
||||||
|
const py = iy / (ny - 1) - 0.5;
|
||||||
|
const pz = ((zArr[idx] - z_min) / zRange - 0.5) * z_scale;
|
||||||
|
|
||||||
|
positions[idx * 3] = px;
|
||||||
|
positions[idx * 3 + 1] = pz; // height on Y axis
|
||||||
|
positions[idx * 3 + 2] = py;
|
||||||
|
|
||||||
|
colorAttr[idx * 3] = colArr[idx * 3] / 255;
|
||||||
|
colorAttr[idx * 3 + 1] = colArr[idx * 3 + 1] / 255;
|
||||||
|
colorAttr[idx * 3 + 2] = colArr[idx * 3 + 2] / 255;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
geom.setAttribute('position', new THREE.BufferAttribute(positions, 3));
|
||||||
|
geom.setAttribute('color', new THREE.BufferAttribute(colorAttr, 3));
|
||||||
|
|
||||||
|
// Build index (triangles from grid)
|
||||||
|
const indices = [];
|
||||||
|
for (let iy = 0; iy < ny - 1; iy++) {
|
||||||
|
for (let ix = 0; ix < nx - 1; ix++) {
|
||||||
|
const a = iy * nx + ix;
|
||||||
|
const b = a + 1;
|
||||||
|
const c = a + nx;
|
||||||
|
const d = c + 1;
|
||||||
|
indices.push(a, c, b);
|
||||||
|
indices.push(b, c, d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
geom.setIndex(indices);
|
||||||
|
geom.computeVertexNormals();
|
||||||
|
|
||||||
|
const mat = new THREE.MeshPhongMaterial({
|
||||||
|
vertexColors: true,
|
||||||
|
side: THREE.DoubleSide,
|
||||||
|
shininess: 30,
|
||||||
|
flatShading: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
const mesh = new THREE.Mesh(geom, mat);
|
||||||
|
scene.add(mesh);
|
||||||
|
threeRef.current.mesh = mesh;
|
||||||
|
|
||||||
|
// Reset camera target to center of mesh
|
||||||
|
controls.target.set(0, 0, 0);
|
||||||
|
controls.update();
|
||||||
|
}, [meshData, decode]);
|
||||||
|
|
||||||
|
// Prevent scroll events from propagating to React Flow
|
||||||
|
const onWheel = useCallback((e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
ref={containerRef}
|
||||||
|
className="nodrag nowheel surface-view-container"
|
||||||
|
onWheelCapture={onWheel}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
93
frontend/src/api.js
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
/**
|
||||||
|
* api.js — REST + WebSocket client for argonode backend.
|
||||||
|
*
|
||||||
|
* Uses relative URLs so the Vite dev proxy (port 5173 → 8188)
|
||||||
|
* and production same-origin serving both work transparently.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ── REST helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
export async function getNodes() {
|
||||||
|
const r = await fetch('/nodes');
|
||||||
|
if (!r.ok) throw new Error(`GET /nodes failed: ${r.status}`);
|
||||||
|
return r.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getFiles() {
|
||||||
|
const r = await fetch('/files');
|
||||||
|
if (!r.ok) return [];
|
||||||
|
return r.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function browse(dir) {
|
||||||
|
const url = dir ? `/browse?dir=${encodeURIComponent(dir)}` : '/browse';
|
||||||
|
const r = await fetch(url);
|
||||||
|
if (!r.ok) throw new Error(`Browse failed: ${r.status}`);
|
||||||
|
return r.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function uploadFile(file) {
|
||||||
|
const fd = new FormData();
|
||||||
|
fd.append('file', file);
|
||||||
|
const r = await fetch('/upload', { method: 'POST', body: fd });
|
||||||
|
if (!r.ok) throw new Error(`Upload failed: ${r.status}`);
|
||||||
|
return r.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function runPrompt(prompt) {
|
||||||
|
const r = await fetch('/prompt', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ prompt }),
|
||||||
|
});
|
||||||
|
if (!r.ok) {
|
||||||
|
const text = await r.text();
|
||||||
|
throw new Error(`POST /prompt failed (${r.status}): ${text}`);
|
||||||
|
}
|
||||||
|
return r.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── WebSocket ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
let _ws = null;
|
||||||
|
let _handler = null;
|
||||||
|
let _reconnectTimer = null;
|
||||||
|
|
||||||
|
export function setMessageHandler(fn) {
|
||||||
|
_handler = fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function initWS() {
|
||||||
|
if (_ws && _ws.readyState < 2) return; // already open or connecting
|
||||||
|
|
||||||
|
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||||
|
_ws = new WebSocket(`${protocol}//${window.location.host}/ws`);
|
||||||
|
|
||||||
|
_ws.onopen = () => {
|
||||||
|
console.log('[argonode] WebSocket connected');
|
||||||
|
};
|
||||||
|
|
||||||
|
_ws.onclose = () => {
|
||||||
|
console.log('[argonode] WebSocket closed, reconnecting in 3s…');
|
||||||
|
clearTimeout(_reconnectTimer);
|
||||||
|
_reconnectTimer = setTimeout(() => initWS(), 3000);
|
||||||
|
};
|
||||||
|
|
||||||
|
_ws.onerror = (e) => {
|
||||||
|
console.error('[argonode] WebSocket error', e);
|
||||||
|
};
|
||||||
|
|
||||||
|
_ws.onmessage = (e) => {
|
||||||
|
try {
|
||||||
|
const msg = JSON.parse(e.data);
|
||||||
|
if (_handler) _handler(msg);
|
||||||
|
} catch {
|
||||||
|
// ignore malformed messages
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function closeWS() {
|
||||||
|
clearTimeout(_reconnectTimer);
|
||||||
|
if (_ws) _ws.close();
|
||||||
|
}
|
||||||
6
frontend/src/main.jsx
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { createRoot } from 'react-dom/client';
|
||||||
|
import App from './App';
|
||||||
|
import './styles.css';
|
||||||
|
|
||||||
|
createRoot(document.getElementById('root')).render(<App />);
|
||||||
509
frontend/src/styles.css
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
/* ── Reset & base ──────────────────────────────────────────────────── */
|
||||||
|
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
||||||
|
|
||||||
|
html, body, #root {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
background: #1a1a2e;
|
||||||
|
color: #e0e0e0;
|
||||||
|
font-family: "Inter", "Segoe UI", system-ui, sans-serif;
|
||||||
|
font-size: 13px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.app-container {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Toolbar ───────────────────────────────────────────────────────── */
|
||||||
|
#toolbar {
|
||||||
|
height: 44px;
|
||||||
|
background: #16213e;
|
||||||
|
border-bottom: 1px solid #0f3460;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
padding: 0 12px;
|
||||||
|
gap: 10px;
|
||||||
|
z-index: 100;
|
||||||
|
user-select: none;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#app-title {
|
||||||
|
font-size: 15px;
|
||||||
|
font-weight: 700;
|
||||||
|
letter-spacing: 0.5px;
|
||||||
|
color: #e94560;
|
||||||
|
margin-right: 8px;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.toolbar-group {
|
||||||
|
display: flex;
|
||||||
|
gap: 6px;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Buttons ───────────────────────────────────────────────────────── */
|
||||||
|
.btn {
|
||||||
|
padding: 5px 12px;
|
||||||
|
border: 1px solid #0f3460;
|
||||||
|
border-radius: 5px;
|
||||||
|
background: #0f3460;
|
||||||
|
color: #e0e0e0;
|
||||||
|
font-size: 12px;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: background 0.15s, border-color 0.15s;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
.btn:hover {
|
||||||
|
background: #1a4a8a;
|
||||||
|
border-color: #3a7abf;
|
||||||
|
}
|
||||||
|
.btn:active {
|
||||||
|
background: #0a2040;
|
||||||
|
}
|
||||||
|
.btn-primary {
|
||||||
|
background: #e94560;
|
||||||
|
border-color: #e94560;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
.btn-primary:hover {
|
||||||
|
background: #ff6b81;
|
||||||
|
border-color: #ff6b81;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Status bar ────────────────────────────────────────────────────── */
|
||||||
|
.status-bar {
|
||||||
|
margin-left: auto;
|
||||||
|
padding: 4px 10px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 11px;
|
||||||
|
max-width: 400px;
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
flex-shrink: 1;
|
||||||
|
}
|
||||||
|
.status-bar.info { color: #90caf9; }
|
||||||
|
.status-bar.error { color: #ef9a9a; background: rgba(183,28,28,0.2); }
|
||||||
|
|
||||||
|
/* ── React Flow container ──────────────────────────────────────────── */
|
||||||
|
.flow-container {
|
||||||
|
flex: 1;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── React Flow dark overrides ─────────────────────────────────────── */
|
||||||
|
.react-flow {
|
||||||
|
background: #0d1117 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Custom node ───────────────────────────────────────────────────── */
|
||||||
|
.custom-node {
|
||||||
|
background: #1e293b;
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 6px;
|
||||||
|
font-size: 11px;
|
||||||
|
color: #e0e0e0;
|
||||||
|
width: 200px;
|
||||||
|
min-width: 160px;
|
||||||
|
resize: horizontal;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Let React Flow node wrapper fit to the custom-node's size */
|
||||||
|
.react-flow__node-custom {
|
||||||
|
width: auto !important;
|
||||||
|
height: auto !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Title bar is the drag handle for moving the node */
|
||||||
|
.drag-handle {
|
||||||
|
cursor: grab;
|
||||||
|
}
|
||||||
|
.drag-handle:active {
|
||||||
|
cursor: grabbing;
|
||||||
|
}
|
||||||
|
|
||||||
|
.custom-node.selected {
|
||||||
|
border-color: #90caf9;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-title {
|
||||||
|
padding: 5px 10px;
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 12px;
|
||||||
|
color: white;
|
||||||
|
border-radius: 5px 5px 0 0;
|
||||||
|
border-bottom: 1px solid rgba(0, 0, 0, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-body {
|
||||||
|
padding: 4px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── I/O rows ──────────────────────────────────────────────────────── */
|
||||||
|
.io-row {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
padding: 3px 12px;
|
||||||
|
min-height: 22px;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.io-left, .io-right {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.io-label {
|
||||||
|
font-size: 10px;
|
||||||
|
color: #94a3b8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Handles ───────────────────────────────────────────────────────── */
|
||||||
|
.typed-handle {
|
||||||
|
width: 10px !important;
|
||||||
|
height: 10px !important;
|
||||||
|
border: 2px solid #1e293b !important;
|
||||||
|
border-radius: 50% !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Widget rows ───────────────────────────────────────────────────── */
|
||||||
|
.widget-row {
|
||||||
|
padding: 3px 10px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.widget-row label {
|
||||||
|
font-size: 10px;
|
||||||
|
color: #64748b;
|
||||||
|
min-width: 40px;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.widget-row input[type="text"],
|
||||||
|
.widget-row input[type="number"],
|
||||||
|
.widget-row select {
|
||||||
|
background: #0f172a;
|
||||||
|
color: #e0e0e0;
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 3px;
|
||||||
|
padding: 2px 5px;
|
||||||
|
font-size: 11px;
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.widget-row input[type="checkbox"] {
|
||||||
|
accent-color: #3a7abf;
|
||||||
|
}
|
||||||
|
|
||||||
|
.widget-row input:focus,
|
||||||
|
.widget-row select:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #3a7abf;
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-picker-row {
|
||||||
|
display: flex;
|
||||||
|
gap: 4px;
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-picker-row input {
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Draggable number ──────────────────────────────────────────────── */
|
||||||
|
.drag-number {
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
background: #0f172a;
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 3px;
|
||||||
|
padding: 2px 6px;
|
||||||
|
cursor: ew-resize;
|
||||||
|
user-select: none;
|
||||||
|
text-align: center;
|
||||||
|
font-size: 11px;
|
||||||
|
color: #e0e0e0;
|
||||||
|
touch-action: none;
|
||||||
|
}
|
||||||
|
.drag-number:hover {
|
||||||
|
border-color: #3a7abf;
|
||||||
|
}
|
||||||
|
.drag-number-val {
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
.drag-number-edit {
|
||||||
|
flex: 1;
|
||||||
|
min-width: 0;
|
||||||
|
background: #0f172a;
|
||||||
|
border: 1px solid #3a7abf;
|
||||||
|
border-radius: 3px;
|
||||||
|
padding: 2px 5px;
|
||||||
|
font-size: 11px;
|
||||||
|
color: #e0e0e0;
|
||||||
|
text-align: center;
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.browse-btn {
|
||||||
|
background: #0f3460;
|
||||||
|
color: #e0e0e0;
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 3px;
|
||||||
|
padding: 2px 6px;
|
||||||
|
font-size: 10px;
|
||||||
|
cursor: pointer;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
.browse-btn:hover {
|
||||||
|
background: #1a4a8a;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Collapsible section ───────────────────────────────────────────── */
|
||||||
|
.collapsible {
|
||||||
|
border-top: 1px solid #334155;
|
||||||
|
margin-top: 4px;
|
||||||
|
}
|
||||||
|
.collapsible-toggle {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 4px;
|
||||||
|
width: 100%;
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
color: #64748b;
|
||||||
|
font-size: 10px;
|
||||||
|
padding: 3px 10px;
|
||||||
|
cursor: pointer;
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
.collapsible-toggle:hover {
|
||||||
|
color: #94a3b8;
|
||||||
|
}
|
||||||
|
.collapsible-arrow {
|
||||||
|
font-size: 9px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Node preview ──────────────────────────────────────────────────── */
|
||||||
|
.node-preview {
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.node-preview img {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 100%;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Cross-section overlay ────────────────────────────────────────── */
|
||||||
|
.cs-overlay {
|
||||||
|
position: relative;
|
||||||
|
user-select: none;
|
||||||
|
touch-action: none;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.cs-image {
|
||||||
|
width: 100%;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
.cs-svg {
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
.cs-marker {
|
||||||
|
position: absolute;
|
||||||
|
width: 14px;
|
||||||
|
height: 14px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: #ffd700;
|
||||||
|
border: 2px solid #fff;
|
||||||
|
transform: translate(-50%, -50%);
|
||||||
|
cursor: grab;
|
||||||
|
box-shadow: 0 0 4px rgba(0,0,0,0.6);
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
.cs-marker:active:not(.cs-marker-locked) {
|
||||||
|
cursor: grabbing;
|
||||||
|
background: #ffeb3b;
|
||||||
|
transform: translate(-50%, -50%) scale(1.2);
|
||||||
|
}
|
||||||
|
.cs-marker-locked {
|
||||||
|
background: #e91e63;
|
||||||
|
border-color: #e91e63;
|
||||||
|
cursor: default;
|
||||||
|
opacity: 0.9;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── 3D surface view ──────────────────────────────────────────────── */
|
||||||
|
.surface-view-container {
|
||||||
|
width: 100%;
|
||||||
|
aspect-ratio: 1 / 1;
|
||||||
|
cursor: grab;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.surface-view-container:active {
|
||||||
|
cursor: grabbing;
|
||||||
|
}
|
||||||
|
.surface-view-container canvas {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Node table ────────────────────────────────────────────────────── */
|
||||||
|
.node-table {
|
||||||
|
padding: 4px 10px;
|
||||||
|
font-family: "SF Mono", "Fira Code", monospace;
|
||||||
|
font-size: 10px;
|
||||||
|
color: #cbd5e1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.table-line {
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
line-height: 1.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Node resize handles ───────────────────────────────────────────── */
|
||||||
|
.node-resize-line {
|
||||||
|
border-color: #90caf9 !important;
|
||||||
|
}
|
||||||
|
.node-resize-handle {
|
||||||
|
background: #90caf9 !important;
|
||||||
|
width: 8px !important;
|
||||||
|
height: 8px !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Context menu ──────────────────────────────────────────────────── */
|
||||||
|
.context-menu {
|
||||||
|
position: fixed;
|
||||||
|
z-index: 1000;
|
||||||
|
background: #16213e;
|
||||||
|
border: 1px solid #0f3460;
|
||||||
|
border-radius: 6px;
|
||||||
|
min-width: 180px;
|
||||||
|
max-height: 60vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.5);
|
||||||
|
padding: 4px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.context-category {
|
||||||
|
padding: 6px 12px 3px;
|
||||||
|
font-size: 10px;
|
||||||
|
font-weight: 700;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.5px;
|
||||||
|
color: #64748b;
|
||||||
|
border-top: 1px solid #0f3460;
|
||||||
|
}
|
||||||
|
.context-category:first-child {
|
||||||
|
border-top: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.context-item {
|
||||||
|
padding: 5px 20px;
|
||||||
|
font-size: 12px;
|
||||||
|
cursor: pointer;
|
||||||
|
color: #e0e0e0;
|
||||||
|
}
|
||||||
|
.context-item:hover {
|
||||||
|
background: #0f3460;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── File browser dialog ──────────────────────────────────────────── */
|
||||||
|
.fb-backdrop {
|
||||||
|
position: fixed;
|
||||||
|
inset: 0;
|
||||||
|
background: rgba(0, 0, 0, 0.6);
|
||||||
|
z-index: 2000;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
.fb-dialog {
|
||||||
|
background: #16213e;
|
||||||
|
border: 1px solid #0f3460;
|
||||||
|
border-radius: 8px;
|
||||||
|
width: 520px;
|
||||||
|
max-height: 70vh;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5);
|
||||||
|
}
|
||||||
|
.fb-header {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
padding: 10px 14px;
|
||||||
|
border-bottom: 1px solid #0f3460;
|
||||||
|
}
|
||||||
|
.fb-path {
|
||||||
|
font-size: 12px;
|
||||||
|
color: #90caf9;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
flex: 1;
|
||||||
|
margin-right: 10px;
|
||||||
|
}
|
||||||
|
.fb-close {
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
color: #e0e0e0;
|
||||||
|
font-size: 16px;
|
||||||
|
cursor: pointer;
|
||||||
|
padding: 2px 6px;
|
||||||
|
}
|
||||||
|
.fb-close:hover { color: #e94560; }
|
||||||
|
.fb-list {
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 6px 0;
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
.fb-entry {
|
||||||
|
padding: 6px 14px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 13px;
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
}
|
||||||
|
.fb-entry:hover { background: #0f3460; }
|
||||||
|
.fb-dir { color: #90caf9; }
|
||||||
|
.fb-file { color: #e0e0e0; }
|
||||||
|
.fb-loading {
|
||||||
|
padding: 16px;
|
||||||
|
text-align: center;
|
||||||
|
color: #607d8b;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Scrollbar styling ─────────────────────────────────────────────── */
|
||||||
|
::-webkit-scrollbar { width: 6px; height: 6px; }
|
||||||
|
::-webkit-scrollbar-track { background: #1a1a2e; }
|
||||||
|
::-webkit-scrollbar-thumb { background: #0f3460; border-radius: 3px; }
|
||||||
|
::-webkit-scrollbar-thumb:hover { background: #3a7abf; }
|
||||||
|
|
||||||
|
/* ── React Flow MiniMap ────────────────────────────────────────────── */
|
||||||
|
.react-flow__minimap {
|
||||||
|
background: #16213e !important;
|
||||||
|
border: 1px solid #0f3460 !important;
|
||||||
|
border-radius: 4px !important;
|
||||||
|
}
|
||||||
23
frontend/vite.config.js
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import { defineConfig } from 'vite';
|
||||||
|
import react from '@vitejs/plugin-react';
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
server: {
|
||||||
|
port: 5173,
|
||||||
|
proxy: {
|
||||||
|
'/nodes': 'http://127.0.0.1:8188',
|
||||||
|
'/files': 'http://127.0.0.1:8188',
|
||||||
|
'/browse': 'http://127.0.0.1:8188',
|
||||||
|
'/upload': 'http://127.0.0.1:8188',
|
||||||
|
'/prompt': 'http://127.0.0.1:8188',
|
||||||
|
'/ws': {
|
||||||
|
target: 'http://127.0.0.1:8188',
|
||||||
|
ws: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
build: {
|
||||||
|
outDir: 'dist',
|
||||||
|
},
|
||||||
|
});
|
||||||
0
tests/__init__.py
Normal file
BIN
tests/output/01_sines_input.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
tests/output/01_sines_log_magnitude.png
Normal file
|
After Width: | Height: | Size: 981 B |
BIN
tests/output/01_sines_magnitude.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
BIN
tests/output/01_sines_psdf.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
BIN
tests/output/02_surface_fft_level_mean.png
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
tests/output/02_surface_fft_level_none.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
tests/output/02_surface_fft_level_plane.png
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
tests/output/02_surface_input.png
Normal file
|
After Width: | Height: | Size: 70 KiB |
BIN
tests/output/03_checker_fft.png
Normal file
|
After Width: | Height: | Size: 8.1 KiB |
BIN
tests/output/03_checker_input.png
Normal file
|
After Width: | Height: | Size: 1.0 KiB |
BIN
tests/output/04_rings_fft.png
Normal file
|
After Width: | Height: | Size: 23 KiB |
BIN
tests/output/04_rings_input.png
Normal file
|
After Width: | Height: | Size: 144 KiB |
BIN
tests/output/05_window_blackman.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
BIN
tests/output/05_window_hamming.png
Normal file
|
After Width: | Height: | Size: 2.4 KiB |
BIN
tests/output/05_window_hann.png
Normal file
|
After Width: | Height: | Size: 2.0 KiB |
BIN
tests/output/05_window_input.png
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
tests/output/05_window_none.png
Normal file
|
After Width: | Height: | Size: 1.7 KiB |
258
tests/test_fft.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""
|
||||||
|
Test the FFT2D node against known inputs and Gwyddion-equivalent results.
|
||||||
|
|
||||||
|
Run from project root:
|
||||||
|
python -m tests.test_fft
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
from backend.data_types import DataField
|
||||||
|
from backend.nodes.analysis import FFT2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_field(data, xreal=1e-6, yreal=1e-6):
|
||||||
|
"""Create a DataField from a 2D array."""
|
||||||
|
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dc_removal():
|
||||||
|
"""A constant image should produce near-zero FFT after mean subtraction."""
|
||||||
|
print("=== Test: DC removal ===")
|
||||||
|
data = np.ones((64, 64)) * 42.0
|
||||||
|
field = make_field(data)
|
||||||
|
node = FFT2D()
|
||||||
|
|
||||||
|
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||||
|
peak = result.data.max()
|
||||||
|
print(f" Peak magnitude after mean subtraction of constant image: {peak:.2e}")
|
||||||
|
assert peak < 1e-10, f"Expected ~0, got {peak}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_frequency():
|
||||||
|
"""A pure sine wave should produce two peaks at the known frequency."""
|
||||||
|
print("=== Test: Single frequency detection ===")
|
||||||
|
N = 128
|
||||||
|
xreal = 1e-6 # 1 micron
|
||||||
|
freq_cycles = 10 # 10 cycles across the image
|
||||||
|
|
||||||
|
x = np.linspace(0, 1, N, endpoint=False)
|
||||||
|
data = np.sin(2 * np.pi * freq_cycles * x)[np.newaxis, :] * np.ones((N, 1))
|
||||||
|
field = make_field(data, xreal=xreal, yreal=xreal)
|
||||||
|
|
||||||
|
node = FFT2D()
|
||||||
|
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||||
|
|
||||||
|
# The peak should be at column offset = freq_cycles from center
|
||||||
|
mag = result.data
|
||||||
|
cy, cx = N // 2, N // 2 # center (DC)
|
||||||
|
|
||||||
|
# Find the peak location (excluding DC which should be ~0 after mean sub)
|
||||||
|
mag_copy = mag.copy()
|
||||||
|
mag_copy[cy, cx] = 0
|
||||||
|
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
|
||||||
|
peak_col_offset = abs(peak_idx[1] - cx)
|
||||||
|
|
||||||
|
print(f" Image: {N}x{N}, {freq_cycles} horizontal cycles")
|
||||||
|
print(f" Expected peak at column offset {freq_cycles} from center")
|
||||||
|
print(f" Found peak at {peak_idx} (offset {peak_col_offset})")
|
||||||
|
print(f" DC value: {mag[cy, cx]:.2e}")
|
||||||
|
print(f" Peak value: {mag[peak_idx]:.2e}")
|
||||||
|
assert peak_col_offset == freq_cycles, f"Expected offset {freq_cycles}, got {peak_col_offset}"
|
||||||
|
assert peak_idx[0] == cy, f"Expected peak on center row, got row {peak_idx[0]}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_2d_frequency():
|
||||||
|
"""A 2D sine should produce peaks at the correct (kx, ky) position."""
|
||||||
|
print("=== Test: 2D frequency detection ===")
|
||||||
|
N = 128
|
||||||
|
fx, fy = 8, 5 # cycles in x and y
|
||||||
|
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
data = np.sin(2 * np.pi * (fx * x + fy * y))
|
||||||
|
field = make_field(data)
|
||||||
|
|
||||||
|
node = FFT2D()
|
||||||
|
result, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||||
|
mag = result.data
|
||||||
|
|
||||||
|
cy, cx = N // 2, N // 2
|
||||||
|
mag_copy = mag.copy()
|
||||||
|
mag_copy[cy, cx] = 0
|
||||||
|
peak_idx = np.unravel_index(np.argmax(mag_copy), mag.shape)
|
||||||
|
|
||||||
|
dx = abs(peak_idx[1] - cx)
|
||||||
|
dy = abs(peak_idx[0] - cy)
|
||||||
|
|
||||||
|
print(f" Input: sin(2π({fx}x + {fy}y))")
|
||||||
|
print(f" Expected peak offset: ({fy}, {fx}) from center")
|
||||||
|
print(f" Found peak at {peak_idx} (offset dy={dy}, dx={dx})")
|
||||||
|
assert dx == fx and dy == fy, f"Expected ({fy},{fx}), got ({dy},{dx})"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_psdf_normalization():
|
||||||
|
"""
|
||||||
|
PSDF of white noise should integrate to the variance.
|
||||||
|
|
||||||
|
Parseval's theorem: sum of PSDF * dk_x * dk_y ≈ variance of the signal.
|
||||||
|
"""
|
||||||
|
print("=== Test: PSDF normalization (Parseval) ===")
|
||||||
|
N = 256
|
||||||
|
xreal = 1e-6
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
data = rng.standard_normal((N, N))
|
||||||
|
variance = data.var()
|
||||||
|
|
||||||
|
field = make_field(data, xreal=xreal, yreal=xreal)
|
||||||
|
node = FFT2D()
|
||||||
|
|
||||||
|
result, = node.process(field, windowing="none", level="none", output="psdf")
|
||||||
|
psdf = result.data
|
||||||
|
|
||||||
|
# Integrate: sum of PSDF * dk_x * dk_y
|
||||||
|
# Our output field has xreal = 2π*N/xreal (angular freq range)
|
||||||
|
dk_x = result.xreal / N
|
||||||
|
dk_y = result.yreal / N
|
||||||
|
integral = psdf.sum() * dk_x * dk_y
|
||||||
|
|
||||||
|
ratio = integral / variance
|
||||||
|
print(f" Signal variance: {variance:.6f}")
|
||||||
|
print(f" PSDF integral: {integral:.6f}")
|
||||||
|
print(f" Ratio (should be ~1.0): {ratio:.4f}")
|
||||||
|
# Allow 20% tolerance for finite-size effects
|
||||||
|
assert 0.8 < ratio < 1.2, f"Parseval's theorem violated: ratio = {ratio}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_windowing_reduces_leakage():
|
||||||
|
"""Windowing should reduce spectral leakage from a non-integer frequency."""
|
||||||
|
print("=== Test: Windowing reduces leakage ===")
|
||||||
|
N = 128
|
||||||
|
freq = 10.5 # non-integer → spectral leakage without windowing
|
||||||
|
|
||||||
|
x = np.linspace(0, 1, N, endpoint=False)
|
||||||
|
data = np.sin(2 * np.pi * freq * x)[np.newaxis, :] * np.ones((N, 1))
|
||||||
|
field = make_field(data)
|
||||||
|
|
||||||
|
node = FFT2D()
|
||||||
|
|
||||||
|
# Without windowing
|
||||||
|
r_none, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||||
|
mag_none = r_none.data[N // 2, :] # center row
|
||||||
|
|
||||||
|
# With Hann windowing
|
||||||
|
r_hann, = node.process(field, windowing="hann", level="mean", output="magnitude")
|
||||||
|
mag_hann = r_hann.data[N // 2, :]
|
||||||
|
|
||||||
|
# Measure leakage: ratio of energy far from peak vs total
|
||||||
|
peak_col = np.argmax(mag_none)
|
||||||
|
far_mask = np.ones(N, dtype=bool)
|
||||||
|
far_mask[max(0, peak_col - 3):peak_col + 4] = False
|
||||||
|
# Also mask the symmetric peak
|
||||||
|
sym_col = N - peak_col
|
||||||
|
far_mask[max(0, sym_col - 3):sym_col + 4] = False
|
||||||
|
|
||||||
|
leakage_none = mag_none[far_mask].sum() / mag_none.sum()
|
||||||
|
leakage_hann = mag_hann[far_mask].sum() / mag_hann.sum()
|
||||||
|
|
||||||
|
print(f" Non-integer frequency: {freq}")
|
||||||
|
print(f" Leakage without windowing: {leakage_none:.4f}")
|
||||||
|
print(f" Leakage with Hann window: {leakage_hann:.4f}")
|
||||||
|
assert leakage_hann < leakage_none, "Hann window should reduce leakage"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_plane_subtraction():
|
||||||
|
"""Plane subtraction should remove linear gradients."""
|
||||||
|
print("=== Test: Plane subtraction ===")
|
||||||
|
N = 64
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
# Tilted plane + sine wave
|
||||||
|
data = 100 * x + 50 * y + np.sin(2 * np.pi * 8 * x)
|
||||||
|
field = make_field(data)
|
||||||
|
|
||||||
|
node = FFT2D()
|
||||||
|
|
||||||
|
# Without leveling — huge DC and low-freq energy
|
||||||
|
r_none, = node.process(field, windowing="none", level="none", output="magnitude")
|
||||||
|
dc_none = r_none.data[N // 2, N // 2]
|
||||||
|
|
||||||
|
# With mean subtraction — DC removed but gradient leaks
|
||||||
|
r_mean, = node.process(field, windowing="none", level="mean", output="magnitude")
|
||||||
|
dc_mean = r_mean.data[N // 2, N // 2]
|
||||||
|
|
||||||
|
# With plane subtraction — gradient removed
|
||||||
|
r_plane, = node.process(field, windowing="none", level="plane", output="magnitude")
|
||||||
|
dc_plane = r_plane.data[N // 2, N // 2]
|
||||||
|
|
||||||
|
# With plane subtraction, check the low-freq energy near DC is reduced
|
||||||
|
# (plane subtraction removes gradients that leak into low frequencies)
|
||||||
|
r = 3 # radius around DC to check
|
||||||
|
cy, cx = N // 2, N // 2
|
||||||
|
lowfreq_none = r_none.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
|
||||||
|
lowfreq_plane = r_plane.data[cy-r:cy+r+1, cx-r:cx+r+1].sum()
|
||||||
|
|
||||||
|
print(f" DC magnitude (no leveling): {dc_none:.2e}")
|
||||||
|
print(f" DC magnitude (mean subtract): {dc_mean:.2e}")
|
||||||
|
print(f" DC magnitude (plane subtract): {dc_plane:.2e}")
|
||||||
|
print(f" Low-freq energy (no level): {lowfreq_none:.2e}")
|
||||||
|
print(f" Low-freq energy (plane sub): {lowfreq_plane:.2e}")
|
||||||
|
assert dc_mean < dc_none, "Mean subtraction should reduce DC"
|
||||||
|
assert lowfreq_plane < lowfreq_none * 0.01, "Plane subtraction should reduce low-freq energy"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_square():
|
||||||
|
"""FFT should work on non-square, non-power-of-2 images."""
|
||||||
|
print("=== Test: Non-square image ===")
|
||||||
|
data = np.random.default_rng(99).standard_normal((100, 150))
|
||||||
|
field = make_field(data, xreal=1.5e-6, yreal=1.0e-6)
|
||||||
|
node = FFT2D()
|
||||||
|
|
||||||
|
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
||||||
|
assert result.data.shape == (100, 150), f"Shape mismatch: {result.data.shape}"
|
||||||
|
assert np.all(np.isfinite(result.data)), "Non-finite values in output"
|
||||||
|
print(f" Shape: {result.data.shape}")
|
||||||
|
print(f" Output range: [{result.data.min():.4f}, {result.data.max():.4f}]")
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_magnitude_visual_range():
|
||||||
|
"""Log magnitude should produce a reasonable dynamic range for display."""
|
||||||
|
print("=== Test: Log magnitude visual range ===")
|
||||||
|
N = 128
|
||||||
|
x = np.linspace(0, 1, N, endpoint=False)
|
||||||
|
# Multi-frequency test image
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
data = (np.sin(2 * np.pi * 5 * x) +
|
||||||
|
0.5 * np.sin(2 * np.pi * 15 * x + 2 * np.pi * 10 * y) +
|
||||||
|
0.1 * np.random.default_rng(7).standard_normal((N, N)))
|
||||||
|
field = make_field(data)
|
||||||
|
|
||||||
|
node = FFT2D()
|
||||||
|
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
||||||
|
|
||||||
|
vmin, vmax = result.data.min(), result.data.max()
|
||||||
|
dynamic_range = vmax - vmin if vmin > 0 else vmax / max(abs(vmin), 1e-30)
|
||||||
|
|
||||||
|
print(f" Log magnitude range: [{vmin:.4f}, {vmax:.4f}]")
|
||||||
|
print(f" Dynamic range: {dynamic_range:.2f}")
|
||||||
|
assert vmax > vmin, "Log magnitude should have nonzero range"
|
||||||
|
assert np.all(np.isfinite(result.data)), "Non-finite values in log magnitude"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_dc_removal()
|
||||||
|
test_single_frequency()
|
||||||
|
test_2d_frequency()
|
||||||
|
test_psdf_normalization()
|
||||||
|
test_windowing_reduces_leakage()
|
||||||
|
test_plane_subtraction()
|
||||||
|
test_non_square()
|
||||||
|
test_log_magnitude_visual_range()
|
||||||
|
print("All tests passed!")
|
||||||
97
tests/test_fft_visual.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
Generate test images and their FFT outputs for visual comparison with Gwyddion.
|
||||||
|
Saves PNG files to tests/output/.
|
||||||
|
|
||||||
|
Run: .venv/bin/python -m tests.test_fft_visual
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
from backend.data_types import DataField, datafield_to_uint8, encode_preview
|
||||||
|
from backend.nodes.analysis import FFT2D
|
||||||
|
|
||||||
|
OUT_DIR = os.path.join(os.path.dirname(__file__), "output")
|
||||||
|
os.makedirs(OUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def save_field(field, name, colormap="viridis"):
|
||||||
|
"""Save a DataField as a PNG for visual inspection."""
|
||||||
|
from PIL import Image
|
||||||
|
arr = datafield_to_uint8(field, colormap)
|
||||||
|
img = Image.fromarray(arr)
|
||||||
|
path = os.path.join(OUT_DIR, f"{name}.png")
|
||||||
|
img.save(path)
|
||||||
|
print(f" Saved {path} (range: [{field.data.min():.4g}, {field.data.max():.4g}])")
|
||||||
|
|
||||||
|
|
||||||
|
def make_field(data, xreal=1e-6, yreal=1e-6):
|
||||||
|
return DataField(data=data, xreal=xreal, yreal=yreal)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
node = FFT2D()
|
||||||
|
N = 256
|
||||||
|
|
||||||
|
# --- Test 1: Multi-frequency sine waves ---
|
||||||
|
print("Test 1: Multi-frequency sine waves")
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
data = (np.sin(2 * np.pi * 10 * x)
|
||||||
|
+ 0.7 * np.sin(2 * np.pi * 25 * y)
|
||||||
|
+ 0.3 * np.sin(2 * np.pi * (15 * x + 8 * y)))
|
||||||
|
field = make_field(data)
|
||||||
|
save_field(field, "01_sines_input")
|
||||||
|
|
||||||
|
for output_mode in ["log_magnitude", "magnitude", "psdf"]:
|
||||||
|
result, = node.process(field, windowing="hann", level="mean", output=output_mode)
|
||||||
|
save_field(result, f"01_sines_{output_mode}")
|
||||||
|
|
||||||
|
# --- Test 2: Real-world-like surface with noise + tilt ---
|
||||||
|
print("\nTest 2: Tilted surface with features")
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
data = (50 * x + 30 * y # tilt
|
||||||
|
+ np.sin(2 * np.pi * 20 * x) # periodic feature
|
||||||
|
+ 0.5 * rng.standard_normal((N, N))) # noise
|
||||||
|
field = make_field(data)
|
||||||
|
save_field(field, "02_surface_input")
|
||||||
|
|
||||||
|
for level_mode in ["none", "mean", "plane"]:
|
||||||
|
result, = node.process(field, windowing="hann", level=level_mode, output="log_magnitude")
|
||||||
|
save_field(result, f"02_surface_fft_level_{level_mode}")
|
||||||
|
|
||||||
|
# --- Test 3: Checkerboard pattern ---
|
||||||
|
print("\nTest 3: Checkerboard")
|
||||||
|
freq = 16
|
||||||
|
data = np.sign(np.sin(2 * np.pi * freq * x) * np.sin(2 * np.pi * freq * y))
|
||||||
|
field = make_field(data)
|
||||||
|
save_field(field, "03_checker_input")
|
||||||
|
|
||||||
|
result, = node.process(field, windowing="none", level="mean", output="log_magnitude")
|
||||||
|
save_field(result, "03_checker_fft")
|
||||||
|
|
||||||
|
# --- Test 4: Concentric rings (radial frequency) ---
|
||||||
|
print("\nTest 4: Concentric rings")
|
||||||
|
r = np.sqrt((x - 0.5)**2 + (y - 0.5)**2)
|
||||||
|
data = np.sin(2 * np.pi * 30 * r)
|
||||||
|
field = make_field(data)
|
||||||
|
save_field(field, "04_rings_input")
|
||||||
|
|
||||||
|
result, = node.process(field, windowing="hann", level="mean", output="log_magnitude")
|
||||||
|
save_field(result, "04_rings_fft")
|
||||||
|
|
||||||
|
# --- Test 5: Compare windowing effects ---
|
||||||
|
print("\nTest 5: Windowing comparison")
|
||||||
|
data = np.sin(2 * np.pi * 10.5 * x) + 0.5 * np.sin(2 * np.pi * 30.3 * y)
|
||||||
|
field = make_field(data)
|
||||||
|
save_field(field, "05_window_input")
|
||||||
|
|
||||||
|
for win in ["none", "hann", "hamming", "blackman"]:
|
||||||
|
result, = node.process(field, windowing=win, level="mean", output="log_magnitude")
|
||||||
|
save_field(result, f"05_window_{win}")
|
||||||
|
|
||||||
|
print(f"\nAll outputs saved to {OUT_DIR}/")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
488
tests/test_nodes.py
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
"""
|
||||||
|
Tests for all argonode backend nodes (excluding FFT2D which has its own test file).
|
||||||
|
|
||||||
|
Run from project root:
|
||||||
|
.venv/bin/python -m tests.test_nodes
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
|
||||||
|
"""Create a DataField, optionally from given data or a random field."""
|
||||||
|
if data is None:
|
||||||
|
data = np.random.default_rng(42).standard_normal(shape)
|
||||||
|
return DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Filters
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_gaussian_filter():
|
||||||
|
print("=== Test: GaussianFilter ===")
|
||||||
|
from backend.nodes.filters import GaussianFilter
|
||||||
|
node = GaussianFilter()
|
||||||
|
field = make_field()
|
||||||
|
|
||||||
|
result, = node.process(field, sigma=2.0)
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
assert result.xreal == field.xreal
|
||||||
|
assert result.si_unit_z == field.si_unit_z
|
||||||
|
# Gaussian blur should reduce variance
|
||||||
|
assert result.data.std() < field.data.std()
|
||||||
|
# With very small sigma, output should be nearly unchanged
|
||||||
|
result_tiny, = node.process(field, sigma=0.01)
|
||||||
|
assert np.allclose(result_tiny.data, field.data, atol=1e-6)
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_median_filter():
|
||||||
|
print("=== Test: MedianFilter ===")
|
||||||
|
from backend.nodes.filters import MedianFilter
|
||||||
|
node = MedianFilter()
|
||||||
|
|
||||||
|
# Median filter should remove salt-and-pepper noise
|
||||||
|
data = np.zeros((64, 64))
|
||||||
|
rng = np.random.default_rng(7)
|
||||||
|
noise_idx = rng.choice(64 * 64, size=100, replace=False)
|
||||||
|
data.ravel()[noise_idx] = 1.0
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
result, = node.process(field, size=3)
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
# Should remove most impulse noise
|
||||||
|
assert result.data.sum() < field.data.sum()
|
||||||
|
# Size=1 should be identity
|
||||||
|
result_1, = node.process(field, size=1)
|
||||||
|
assert np.array_equal(result_1.data, field.data)
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_edge_detect():
|
||||||
|
print("=== Test: EdgeDetect ===")
|
||||||
|
from backend.nodes.filters import EdgeDetect
|
||||||
|
node = EdgeDetect()
|
||||||
|
|
||||||
|
# Create an image with a sharp vertical edge
|
||||||
|
data = np.zeros((64, 64))
|
||||||
|
data[:, 32:] = 1.0
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
for method in ["sobel", "prewitt", "laplacian", "log"]:
|
||||||
|
result, = node.process(field, method=method, sigma=1.0)
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
# Edge response should be strongest near column 32
|
||||||
|
col_energy = np.abs(result.data).sum(axis=0)
|
||||||
|
peak_col = np.argmax(col_energy)
|
||||||
|
assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32"
|
||||||
|
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Level
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_plane_level():
|
||||||
|
print("=== Test: PlaneLevelField ===")
|
||||||
|
from backend.nodes.level import PlaneLevelField
|
||||||
|
node = PlaneLevelField()
|
||||||
|
|
||||||
|
# Create a tilted plane + small signal
|
||||||
|
N = 64
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
signal = np.sin(2 * np.pi * 5 * x)
|
||||||
|
data = 100 * x + 50 * y + signal
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
result, = node.process(field)
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
# After plane leveling, mean should be near zero
|
||||||
|
assert abs(result.data.mean()) < 1e-10
|
||||||
|
# The signal should remain (correlation with original sine)
|
||||||
|
corr = np.corrcoef(result.data.ravel(), signal.ravel())[0, 1]
|
||||||
|
assert corr > 0.98, f"Signal correlation after leveling: {corr}"
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_poly_level():
|
||||||
|
print("=== Test: PolyLevelField ===")
|
||||||
|
from backend.nodes.level import PolyLevelField
|
||||||
|
node = PolyLevelField()
|
||||||
|
|
||||||
|
N = 64
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
# Quadratic background + signal
|
||||||
|
background = 50 * x**2 + 30 * y**2 + 10 * x * y
|
||||||
|
signal = np.sin(2 * np.pi * 8 * x)
|
||||||
|
data = background + signal
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
leveled, bg = node.process(field, degree_x=2, degree_y=2)
|
||||||
|
assert leveled.data.shape == field.data.shape
|
||||||
|
assert bg.data.shape == field.data.shape
|
||||||
|
# leveled + bg should reconstruct original
|
||||||
|
assert np.allclose(leveled.data + bg.data, field.data, atol=1e-10)
|
||||||
|
# Signal should be preserved after leveling
|
||||||
|
corr = np.corrcoef(leveled.data.ravel(), signal.ravel())[0, 1]
|
||||||
|
assert corr > 0.95, f"Signal correlation after poly leveling: {corr}"
|
||||||
|
# Degree 0 should just subtract the mean
|
||||||
|
leveled_0, bg_0 = node.process(field, degree_x=0, degree_y=0)
|
||||||
|
assert abs(leveled_0.data.mean()) < 1e-10
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_zero():
|
||||||
|
print("=== Test: FixZero ===")
|
||||||
|
from backend.nodes.level import FixZero
|
||||||
|
node = FixZero()
|
||||||
|
field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))
|
||||||
|
|
||||||
|
result_min, = node.process(field, method="min")
|
||||||
|
assert result_min.data.min() == 0.0
|
||||||
|
assert result_min.data.max() == 30.0
|
||||||
|
|
||||||
|
result_mean, = node.process(field, method="mean")
|
||||||
|
assert abs(result_mean.data.mean()) < 1e-10
|
||||||
|
|
||||||
|
result_median, = node.process(field, method="median")
|
||||||
|
assert abs(np.median(result_median.data)) < 1e-10
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Analysis (non-FFT)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_statistics():
|
||||||
|
print("=== Test: StatisticsNode ===")
|
||||||
|
from backend.nodes.analysis import StatisticsNode
|
||||||
|
node = StatisticsNode()
|
||||||
|
|
||||||
|
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
table, = node.process(field)
|
||||||
|
stats = {row["quantity"]: row["value"] for row in table}
|
||||||
|
|
||||||
|
assert stats["min"] == 1.0
|
||||||
|
assert stats["max"] == 4.0
|
||||||
|
assert stats["mean"] == 2.5
|
||||||
|
assert stats["median"] == 2.5
|
||||||
|
assert stats["range"] == 3.0
|
||||||
|
# RMS = sqrt(mean((x - mean)^2))
|
||||||
|
expected_rms = np.sqrt(np.mean((data - 2.5) ** 2))
|
||||||
|
assert abs(stats["RMS"] - expected_rms) < 1e-10
|
||||||
|
|
||||||
|
# Constant data should have RMS=0, skewness=0, kurtosis=0
|
||||||
|
const_field = make_field(data=np.ones((4, 4)) * 5.0)
|
||||||
|
table_const, = node.process(const_field)
|
||||||
|
const_stats = {row["quantity"]: row["value"] for row in table_const}
|
||||||
|
assert const_stats["RMS"] == 0.0
|
||||||
|
assert const_stats["skewness"] == 0.0
|
||||||
|
assert const_stats["kurtosis"] == 0.0
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_height_histogram():
|
||||||
|
print("=== Test: HeightHistogram ===")
|
||||||
|
from backend.nodes.analysis import HeightHistogram
|
||||||
|
node = HeightHistogram()
|
||||||
|
|
||||||
|
# Uniform data should give a roughly flat histogram
|
||||||
|
data = np.linspace(0, 1, 1000).reshape(25, 40)
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
counts, bin_centers = node.process(field, n_bins=10)
|
||||||
|
assert len(counts) == 10
|
||||||
|
assert len(bin_centers) == 10
|
||||||
|
assert counts.dtype == np.float64
|
||||||
|
# Total counts should equal number of pixels
|
||||||
|
assert counts.sum() == 1000
|
||||||
|
# For uniform data, each bin should have ~100 counts
|
||||||
|
assert np.std(counts) < 10, f"Histogram not flat enough: std={np.std(counts)}"
|
||||||
|
# Bin centers should span the data range
|
||||||
|
assert bin_centers[0] > 0.0
|
||||||
|
assert bin_centers[-1] < 1.0
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_section():
|
||||||
|
print("=== Test: CrossSection ===")
|
||||||
|
from backend.nodes.analysis import CrossSection
|
||||||
|
node = CrossSection()
|
||||||
|
|
||||||
|
# Create a field with a known horizontal gradient
|
||||||
|
N = 100
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
data = x * 10.0 # value = 10 * x_fraction
|
||||||
|
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||||
|
|
||||||
|
# Horizontal cross section at y=0.5
|
||||||
|
(profile,) = node.process(
|
||||||
|
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
|
||||||
|
extend="none", n_samples=100,
|
||||||
|
)
|
||||||
|
assert len(profile) == 100
|
||||||
|
# Profile should be a linear ramp from ~0 to ~10
|
||||||
|
assert profile[0] < 0.5, f"Start of profile: {profile[0]}"
|
||||||
|
assert profile[-1] > 9.5, f"End of profile: {profile[-1]}"
|
||||||
|
|
||||||
|
# n_samples=0 should auto-calculate
|
||||||
|
(profile_auto,) = node.process(
|
||||||
|
field, x1=0.0, y1=0.5, x2=1.0, y2=0.5,
|
||||||
|
extend="none", n_samples=0,
|
||||||
|
)
|
||||||
|
assert len(profile_auto) >= 2
|
||||||
|
|
||||||
|
# Test extend to edges — a short segment should be extended
|
||||||
|
(profile_ext,) = node.process(
|
||||||
|
field, x1=0.3, y1=0.5, x2=0.7, y2=0.5,
|
||||||
|
extend="to_edges", n_samples=100,
|
||||||
|
)
|
||||||
|
# Extended profile should start near 0 and end near 10
|
||||||
|
assert profile_ext[0] < 0.5
|
||||||
|
assert profile_ext[-1] > 9.5
|
||||||
|
|
||||||
|
# Diagonal cross section
|
||||||
|
(profile_diag,) = node.process(
|
||||||
|
field, x1=0.0, y1=0.0, x2=1.0, y2=1.0,
|
||||||
|
extend="none", n_samples=50,
|
||||||
|
)
|
||||||
|
assert len(profile_diag) == 50
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Grains
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_threshold_mask():
|
||||||
|
print("=== Test: ThresholdMask ===")
|
||||||
|
from backend.nodes.grains import ThresholdMask
|
||||||
|
node = ThresholdMask()
|
||||||
|
|
||||||
|
# Clear bimodal data: left half = 0, right half = 1
|
||||||
|
data = np.zeros((64, 64))
|
||||||
|
data[:, 32:] = 1.0
|
||||||
|
field = make_field(data=data)
|
||||||
|
|
||||||
|
# Absolute threshold at 0.5
|
||||||
|
mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
|
||||||
|
assert mask.dtype == np.uint8
|
||||||
|
assert mask.shape == (64, 64)
|
||||||
|
assert np.all(mask[:, :32] == 0)
|
||||||
|
assert np.all(mask[:, 32:] == 255)
|
||||||
|
|
||||||
|
# Direction "below"
|
||||||
|
mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
|
||||||
|
assert np.all(mask_below[:, :32] == 255)
|
||||||
|
assert np.all(mask_below[:, 32:] == 0)
|
||||||
|
|
||||||
|
# Relative threshold at 0.5 (midpoint of range)
|
||||||
|
mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||||
|
assert np.all(mask_rel[:, 32:] == 255)
|
||||||
|
|
||||||
|
# Otsu should find the bimodal threshold
|
||||||
|
mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||||
|
assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_grain_analysis():
|
||||||
|
print("=== Test: GrainAnalysis ===")
|
||||||
|
from backend.nodes.grains import GrainAnalysis
|
||||||
|
node = GrainAnalysis()
|
||||||
|
|
||||||
|
# Create a field with two distinct "grains"
|
||||||
|
N = 64
|
||||||
|
data = np.zeros((N, N))
|
||||||
|
# Grain 1: 10x10 block at top-left with height 5
|
||||||
|
data[5:15, 5:15] = 5.0
|
||||||
|
# Grain 2: 8x8 block at bottom-right with height 3
|
||||||
|
data[45:53, 45:53] = 3.0
|
||||||
|
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||||
|
|
||||||
|
# Create matching mask
|
||||||
|
mask = np.zeros((N, N), dtype=np.uint8)
|
||||||
|
mask[5:15, 5:15] = 255
|
||||||
|
mask[45:53, 45:53] = 255
|
||||||
|
|
||||||
|
table, = node.process(field, mask=mask, min_size=10)
|
||||||
|
assert len(table) == 2, f"Expected 2 grains, got {len(table)}"
|
||||||
|
|
||||||
|
# Sort by area descending
|
||||||
|
table.sort(key=lambda r: r["area_px"], reverse=True)
|
||||||
|
assert table[0]["area_px"] == 100 # 10x10
|
||||||
|
assert table[1]["area_px"] == 64 # 8x8
|
||||||
|
assert abs(table[0]["mean_height"] - 5.0) < 1e-10
|
||||||
|
assert abs(table[1]["mean_height"] - 3.0) < 1e-10
|
||||||
|
|
||||||
|
# min_size filtering: only keep grains >= 80 px
|
||||||
|
table_filtered, = node.process(field, mask=mask, min_size=80)
|
||||||
|
assert len(table_filtered) == 1
|
||||||
|
assert table_filtered[0]["area_px"] == 100
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# I/O
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_load_image():
|
||||||
|
print("=== Test: LoadImage ===")
|
||||||
|
from backend.nodes.io import LoadImage
|
||||||
|
from PIL import Image
|
||||||
|
node = LoadImage()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# Test loading a grayscale PNG
|
||||||
|
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
|
||||||
|
img = Image.fromarray(arr, mode="L")
|
||||||
|
path = os.path.join(tmpdir, "test_gray.png")
|
||||||
|
img.save(path)
|
||||||
|
|
||||||
|
image, field = node.load(filename=path)
|
||||||
|
assert image.shape == (48, 64)
|
||||||
|
assert field.data.shape == (48, 64)
|
||||||
|
assert field.data.dtype == np.float64
|
||||||
|
|
||||||
|
# Test loading an RGB PNG (should average to grayscale for field)
|
||||||
|
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
|
||||||
|
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
|
||||||
|
path_rgb = os.path.join(tmpdir, "test_rgb.png")
|
||||||
|
img_rgb.save(path_rgb)
|
||||||
|
|
||||||
|
image_rgb, field_rgb = node.load(filename=path_rgb)
|
||||||
|
assert image_rgb.shape == (32, 32, 3)
|
||||||
|
assert field_rgb.data.shape == (32, 32)
|
||||||
|
|
||||||
|
# Test loading a .npy file
|
||||||
|
data_npy = np.random.default_rng(3).standard_normal((50, 60))
|
||||||
|
path_npy = os.path.join(tmpdir, "test.npy")
|
||||||
|
np.save(path_npy, data_npy)
|
||||||
|
|
||||||
|
image_npy, field_npy = node.load(filename=path_npy)
|
||||||
|
assert np.allclose(field_npy.data, data_npy)
|
||||||
|
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image():
|
||||||
|
print("=== Test: SaveImage ===")
|
||||||
|
from backend.nodes.io import SaveImage
|
||||||
|
node = SaveImage()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# Monkey-patch OUTPUT_DIR for testing
|
||||||
|
from pathlib import Path
|
||||||
|
import backend.nodes.io as io_mod
|
||||||
|
orig_dir = io_mod.OUTPUT_DIR
|
||||||
|
io_mod.OUTPUT_DIR = Path(tmpdir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)
|
||||||
|
|
||||||
|
# Save as PNG
|
||||||
|
node.save(image=arr, filename_prefix="test", format="PNG")
|
||||||
|
saved = os.listdir(tmpdir)
|
||||||
|
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}"
|
||||||
|
|
||||||
|
# Save as NPY
|
||||||
|
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
|
||||||
|
saved = os.listdir(tmpdir)
|
||||||
|
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
|
||||||
|
finally:
|
||||||
|
io_mod.OUTPUT_DIR = orig_dir
|
||||||
|
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Display (limited testing — these are output nodes with WS callbacks)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def test_preview_image():
|
||||||
|
print("=== Test: PreviewImage ===")
|
||||||
|
from backend.nodes.display import PreviewImage
|
||||||
|
node = PreviewImage()
|
||||||
|
|
||||||
|
# Set up a capture for the broadcast
|
||||||
|
captured = []
|
||||||
|
PreviewImage._broadcast_fn = lambda node_id, data_uri: captured.append(data_uri)
|
||||||
|
PreviewImage._current_node_id = "test"
|
||||||
|
|
||||||
|
# Preview with a DataField
|
||||||
|
field = make_field()
|
||||||
|
node.preview(colormap="viridis", field=field)
|
||||||
|
assert len(captured) == 1
|
||||||
|
assert captured[0].startswith("data:image/png;base64,")
|
||||||
|
|
||||||
|
# Preview with an IMAGE array
|
||||||
|
captured.clear()
|
||||||
|
arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
|
||||||
|
node.preview(colormap="gray", image=arr)
|
||||||
|
assert len(captured) == 1
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
PreviewImage._broadcast_fn = None
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_print_table():
|
||||||
|
print("=== Test: PrintTable ===")
|
||||||
|
from backend.nodes.display import PrintTable
|
||||||
|
node = PrintTable()
|
||||||
|
|
||||||
|
captured = []
|
||||||
|
PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
|
||||||
|
PrintTable._current_node_id = "test"
|
||||||
|
|
||||||
|
table = [{"quantity": "test", "value": 42.0, "unit": "m"}]
|
||||||
|
node.print_table(table=table)
|
||||||
|
assert len(captured) == 1
|
||||||
|
assert captured[0] == table
|
||||||
|
|
||||||
|
PrintTable._broadcast_table_fn = None
|
||||||
|
print(" PASS\n")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Run all tests
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Filters
|
||||||
|
test_gaussian_filter()
|
||||||
|
test_median_filter()
|
||||||
|
test_edge_detect()
|
||||||
|
|
||||||
|
# Level
|
||||||
|
test_plane_level()
|
||||||
|
test_poly_level()
|
||||||
|
test_fix_zero()
|
||||||
|
|
||||||
|
# Analysis
|
||||||
|
test_statistics()
|
||||||
|
test_height_histogram()
|
||||||
|
test_cross_section()
|
||||||
|
|
||||||
|
# Grains
|
||||||
|
test_threshold_mask()
|
||||||
|
test_grain_analysis()
|
||||||
|
|
||||||
|
# I/O
|
||||||
|
test_load_image()
|
||||||
|
test_save_image()
|
||||||
|
|
||||||
|
# Display
|
||||||
|
test_preview_image()
|
||||||
|
test_print_table()
|
||||||
|
|
||||||
|
print("All tests passed!")
|
||||||