fix preview and save on native

This commit is contained in:
2026-03-24 22:52:24 -07:00
parent a60b0c15ca
commit 6959c62c8f
16 changed files with 875 additions and 202 deletions

View File

@@ -194,18 +194,17 @@ class ExecutionEngine:
CrossSection._broadcast_overlay_fn = on_overlay CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay
LoadFile._broadcast_warning_fn = on_warning LoadFile._broadcast_warning_fn = on_warning
SaveImage._broadcast_preview = ( SaveImage._broadcast_warning_fn = on_warning
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
def _set_node_id_on_display(self, cls: type, node_id: str) -> None: def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging.""" """Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import LoadFile from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile): ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
LoadFile, SaveImage):
cls._current_node_id = node_id cls._current_node_id = node_id
def _auto_preview( def _auto_preview(
@@ -262,11 +261,9 @@ class ExecutionEngine:
cls: type, cls: type,
slot: int, slot: int,
result: tuple, result: tuple,
) -> str | None: ) -> dict | None:
"""Render a LINE output as a small matplotlib plot, returned as a data URI.""" """Return structured LINE preview data for responsive frontend rendering."""
import numpy as np import numpy as np
import base64
import io as _io
return_types = getattr(cls, "RETURN_TYPES", ()) return_types = getattr(cls, "RETURN_TYPES", ())
@@ -281,17 +278,22 @@ class ExecutionEngine:
return None # the first LINE already plotted both return None # the first LINE already plotted both
try: try:
import base64
import io as _io
import matplotlib import matplotlib
matplotlib.use("Agg") matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
y = np.asarray(y, dtype=np.float64).ravel()
if x is None:
x = np.arange(len(y), dtype=np.float64)
else:
x = np.asarray(x, dtype=np.float64).ravel()[:len(y)]
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100) fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
fig.patch.set_facecolor("#1e293b") fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a") ax.set_facecolor("#0f172a")
if x is not None: ax.plot(x, y, color="#ff9800", linewidth=1.2)
ax.plot(x, y, color="#ff9800", linewidth=1.2)
else:
ax.plot(y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7) ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_color("#334155") spine.set_color("#334155")
@@ -301,8 +303,15 @@ class ExecutionEngine:
buf = _io.BytesIO() buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor()) fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
plt.close(fig) plt.close(fig)
b64 = base64.b64encode(buf.getvalue()).decode() fallback_image = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
return f"data:image/png;base64,{b64}"
return {
"kind": "line_plot",
"line": y.tolist(),
"x_axis": x.tolist(),
"interactive": False,
"fallback_image": fallback_image,
}
except Exception: except Exception:
return None return None

View File

@@ -47,6 +47,7 @@ def get_node_info(class_name: str) -> dict[str, Any]:
"output": list(cls.RETURN_TYPES), "output": list(cls.RETURN_TYPES),
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)), "output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)), "output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
"manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)),
"description": getattr(cls, "DESCRIPTION", ""), "description": getattr(cls, "DESCRIPTION", ""),
} }

View File

@@ -11,7 +11,7 @@ Gwyddion equivalents:
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.data_types import DataField from backend.data_types import DataField, datafield_to_uint8, encode_preview
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -131,88 +131,49 @@ class LineCursors:
self, line, x1: float, y1: float, x2: float, y2: float, self, line, x1: float, y1: float, x2: float, y2: float,
x_axis=None, x_axis=None,
) -> tuple: ) -> tuple:
import io as _io
import base64
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
y = np.asarray(line, dtype=np.float64).ravel() y = np.asarray(line, dtype=np.float64).ravel()
n = len(y) n = len(y)
if x_axis is not None: if x_axis is not None:
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n] x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
else: else:
x = np.arange(n, dtype=np.float64) x = np.arange(n, dtype=np.float64)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
# --- Render the base plot first to determine axes bounds --- xmin = float(np.min(x)) if len(x) else 0.0
fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100) xmax = float(np.max(x)) if len(x) else 1.0
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
ax.plot(x, y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
fig.tight_layout(pad=0.4)
# Force a draw so transforms are valid def x_frac_to_idx(frac):
fig.canvas.draw() if n <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(x - target_x)))
# Get axes position in figure-fraction coordinates idx_a = x_frac_to_idx(x1)
ax_pos = ax.get_position() idx_b = x_frac_to_idx(x2)
ax_l, ax_b = ax_pos.x0, ax_pos.y0
ax_w, ax_h = ax_pos.width, ax_pos.height
# x1/y1 arrive as image-fraction from the frontend drag.
# Convert image-fraction x → axes-fraction → nearest data index.
def img_x_to_idx(ix):
axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
idx_a = img_x_to_idx(x1)
idx_b = img_x_to_idx(x2)
xa, ya = float(x[idx_a]), float(y[idx_a]) xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b]) xb, yb = float(x[idx_b]), float(y[idx_b])
# --- Draw cursor lines and markers on the plot ---
ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
ax.annotate(
"", xy=(xb, yb), xytext=(xa, ya),
arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
)
# --- Broadcast overlay --- # --- Broadcast overlay ---
if LineCursors._broadcast_overlay_fn is not None: if LineCursors._broadcast_overlay_fn is not None:
# Convert data-space positions back to image-fraction for markers
fig.canvas.draw()
inv = fig.transFigure.inverted()
fig_a = inv.transform(ax.transData.transform([xa, ya]))
fig_b = inv.transform(ax.transData.transform([xb, yb]))
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
LineCursors._broadcast_overlay_fn( LineCursors._broadcast_overlay_fn(
LineCursors._current_node_id, LineCursors._current_node_id,
{ {
"image": image_uri, "kind": "line_plot",
"x1": float(fig_a[0]), "line": y.tolist(),
"y1": float(1.0 - fig_a[1]), # flip: image y=0 is top "x_axis": x.tolist(),
"x2": float(fig_b[0]), "x1": x1,
"y2": float(1.0 - fig_b[1]), "x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": False, "a_locked": False,
"b_locked": False, "b_locked": False,
}, },
) )
plt.close(fig)
# --- Output table --- # --- Output table ---
table = [ table = [
{"quantity": "A position", "value": xa, "unit": ""}, {"quantity": "A position", "value": xa, "unit": ""},
@@ -414,8 +375,6 @@ class CrossSection:
point_a=None, point_b=None, point_a=None, point_b=None,
) -> tuple: ) -> tuple:
from scipy.ndimage import map_coordinates from scipy.ndimage import map_coordinates
import io, base64
from matplotlib.figure import Figure
# COORD inputs override widget values # COORD inputs override widget values
if point_a is not None: if point_a is not None:
@@ -453,14 +412,9 @@ class CrossSection:
# Broadcast overlay image with marker positions # Broadcast overlay image with marker positions
if CrossSection._broadcast_overlay_fn is not None: if CrossSection._broadcast_overlay_fn is not None:
fig = Figure(figsize=(3, 3), dpi=100) # Use the field's native pixel grid for the overlay preview so enlarging
ax = fig.add_axes([0, 0, 1, 1]) # the panel keeps the image as sharp as the source data allows.
ax.imshow(field.data, cmap="viridis", aspect="auto") image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
ax.axis("off")
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
CrossSection._broadcast_overlay_fn( CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id, CrossSection._current_node_id,

View File

@@ -78,7 +78,7 @@ class View3D:
"required": { "required": {
"field": ("DATA_FIELD",), "field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS),), "colormap": (["auto"] + list(COLORMAPS),),
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}), "z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
} }
} }
@@ -134,7 +134,7 @@ class View3D:
"colors": colors_b64, "colors": colors_b64,
"z_min": zmin, "z_min": zmin,
"z_max": zmax, "z_max": zmax,
"z_scale": float(z_scale), "z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)], "x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)], "y_range": [float(field.yoff), float(field.yoff + field.yreal)],
} }

View File

@@ -399,54 +399,86 @@ class Coordinate:
# SaveImage # SaveImage
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@register_node(display_name="Save Image") _MAX_SAVE_FIELDS = 8
@register_node(display_name="Save Layers")
class SaveImage: class SaveImage:
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
optional = {}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("DATA_FIELD",)
return { return {
"required": { "required": {
"image": ("IMAGE",), "filename": ("FILE_PICKER", {"default": ""}),
"filename_prefix": ("STRING", {"default": "output"}), "format": (["TIFF", "NPZ"],),
"format": (["PNG", "TIFF", "NPY"],), },
} "optional": optional,
} }
RETURN_TYPES = () RETURN_TYPES = ()
FUNCTION = "save" FUNCTION = "save"
CATEGORY = "io" CATEGORY = "io"
OUTPUT_NODE = True OUTPUT_NODE = True
DESCRIPTION = "Save an image or array to the output folder." MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more DATA_FIELD layers to a single file. "
"Connect fields to the inputs — a new slot appears as each is filled. "
"TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
"Click Save to write (does not auto-run)."
)
# Injected by server.py before execution begins _broadcast_warning_fn = None
_broadcast_preview = None _current_node_id = None
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"): def save(self, filename: str, format: str = "TIFF", **kwargs):
OUTPUT_DIR.mkdir(exist_ok=True) # Collect connected fields in order
fields = []
for i in range(_MAX_SAVE_FIELDS):
f = kwargs.get(f"field_{i}")
if f is not None:
fields.append(f)
# Find next available filename if not fields:
idx = 1 raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
while True:
name = f"{filename_prefix}_{idx:04d}"
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
if not candidate.exists():
break
idx += 1
if format == "NPY": if not filename or not filename.strip():
np.save(str(OUTPUT_DIR / f"{name}.npy"), image) raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(filename)
# Ensure parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)
# Force correct extension
ext = ".tiff" if format == "TIFF" else ".npz"
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
if format == "TIFF":
self._save_tiff(path, fields)
else: else:
from PIL import Image self._save_npz(path, fields)
arr = image_to_uint8(image)
if arr.ndim == 2:
pil_img = Image.fromarray(arr, mode="L")
else:
pil_img = Image.fromarray(arr, mode="RGB")
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
# Emit preview over WebSocket if callback is set self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
if SaveImage._broadcast_preview is not None: return ()
arr_u8 = image_to_uint8(image)
data_uri = encode_preview(arr_u8) def _save_tiff(self, path: Path, fields: list[DataField]):
SaveImage._broadcast_preview(data_uri) from PIL import Image
images = []
for f in fields:
images.append(Image.fromarray(f.data.astype(np.float32)))
images[0].save(str(path), save_all=True, append_images=images[1:])
def _save_npz(self, path: Path, fields: list[DataField]):
arrays = {}
for i, f in enumerate(fields):
arrays[f"layer_{i}"] = f.data
np.savez(str(path), **arrays)
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return () return ()

View File

@@ -41,6 +41,7 @@ FRONTEND_DIR = frontend_dir()
DIST_DIR = frontend_dist_dir() DIST_DIR = frontend_dist_dir()
INPUT_DIR = input_dir() INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir() OUTPUT_DIR = output_dir()
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -63,6 +64,18 @@ def _dumps(obj) -> str:
return json.dumps(obj, cls=_SafeEncoder) return json.dumps(obj, cls=_SafeEncoder)
def save_png_bytes(target_path: str, payload: bytes) -> Path:
path = Path(target_path).expanduser()
if not target_path.strip():
raise ValueError("Missing save path")
if path.suffix.lower() != ".png":
path = path.with_suffix(".png")
if not payload.startswith(PNG_SIGNATURE):
raise ValueError("Payload is not a valid PNG")
path.write_bytes(payload)
return path
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Application factory # Application factory
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -196,6 +209,20 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
}, },
) )
async def save_workflow_png(request: web.Request) -> web.Response:
body = await request.read()
target_path = request.query.get("path", "")
if not target_path:
raise web.HTTPBadRequest(reason="Missing path")
try:
saved_path = save_png_bytes(target_path, body)
except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) from exc
return web.Response(
text=_dumps({"path": str(saved_path)}),
content_type="application/json",
)
async def get_channels(request: web.Request) -> web.Response: async def get_channels(request: web.Request) -> web.Response:
"""Return available channels for a given file path.""" """Return available channels for a given file path."""
from backend.nodes.io import list_channels from backend.nodes.io import list_channels
@@ -278,6 +305,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
app.router.add_get("/browse", browse_dir) app.router.add_get("/browse", browse_dir)
app.router.add_post("/upload", upload_file) app.router.add_post("/upload", upload_file)
app.router.add_post("/download", download_file) app.router.add_post("/download", download_file)
app.router.add_post("/save-workflow-png", save_workflow_png)
app.router.add_get("/channels", get_channels) app.router.add_get("/channels", get_channels)
app.router.add_post("/prompt", submit_prompt) app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler) app.router.add_get("/ws", websocket_handler)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
import logging import logging
import socket import socket
import threading import threading
@@ -45,8 +44,8 @@ class _Api:
return result[0] return result[0]
return None return None
def save_workflow_png(self, data_url: str, default_filename: str = "workflow.png") -> str | None: def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None:
"""Open a native save dialog, write the PNG bytes, and return the saved path.""" """Open a native save dialog and return the chosen PNG path (or None)."""
win = self._window_ref[0] win = self._window_ref[0]
if win is None: if win is None:
return None return None
@@ -65,12 +64,6 @@ class _Api:
path = Path(result[0] if isinstance(result, (list, tuple)) else result).expanduser() path = Path(result[0] if isinstance(result, (list, tuple)) else result).expanduser()
if path.suffix.lower() != ".png": if path.suffix.lower() != ".png":
path = path.with_suffix(".png") path = path.with_suffix(".png")
_, _, encoded = data_url.partition(",")
if not encoded:
raise ValueError("Invalid data URL payload")
path.write_bytes(base64.b64decode(encoded))
return str(path) return str(path)

View File

@@ -13,6 +13,7 @@ import FileBrowser from './FileBrowser';
import * as api from './api'; import * as api from './api';
import { toBlob } from 'html-to-image'; import { toBlob } from 'html-to-image';
import { embedWorkflow, extractWorkflow } from './pngMetadata'; import { embedWorkflow, extractWorkflow } from './pngMetadata';
import { hydrateWorkflowState } from './workflowHydration';
import { serializeWorkflowState } from './workflowSerialization'; import { serializeWorkflowState } from './workflowSerialization';
// ── Constants ───────────────────────────────────────────────────────── // ── Constants ─────────────────────────────────────────────────────────
@@ -43,15 +44,6 @@ function getOutputSlot(handleId) {
return parseInt(handleId.split('::')[1], 10); return parseInt(handleId.split('::')[1], 10);
} }
function blobToDataUrl(blob) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onloadend = () => resolve(reader.result);
reader.onerror = () => reject(reader.error || new Error('Failed to read file'));
reader.readAsDataURL(blob);
});
}
async function waitForImageElement(img) { async function waitForImageElement(img) {
if (img.complete && img.naturalWidth > 0) return; if (img.complete && img.naturalWidth > 0) return;
if (typeof img.decode === 'function') { if (typeof img.decode === 'function') {
@@ -73,6 +65,31 @@ async function waitForImageElement(img) {
}); });
} }
async function getCaptureImageDataUrl(img) {
const src = img.currentSrc || img.src;
if (!src) return null;
if (!src.startsWith('data:')) return src;
const rect = img.getBoundingClientRect();
const width = Math.max(1, Math.round(img.clientWidth || rect.width));
const height = Math.max(1, Math.round(img.clientHeight || rect.height));
const scale = Math.min(2, window.devicePixelRatio || 1);
const canvas = document.createElement('canvas');
canvas.width = Math.max(1, Math.round(width * scale));
canvas.height = Math.max(1, Math.round(height * scale));
const ctx = canvas.getContext('2d');
if (!ctx) return src;
try {
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
return canvas.toDataURL('image/png');
} catch {
return src;
}
}
function createCapturePlaceholder(el, dataUrl) { function createCapturePlaceholder(el, dataUrl) {
const rect = el.getBoundingClientRect(); const rect = el.getBoundingClientRect();
const style = window.getComputedStyle(el); const style = window.getComputedStyle(el);
@@ -101,8 +118,9 @@ async function captureViewportBlob(viewportEl, options) {
await Promise.all(images.map(waitForImageElement)); await Promise.all(images.map(waitForImageElement));
for (const img of images) { for (const img of images) {
const dataUrl = img.currentSrc || img.src; if (!img.parentNode) continue;
if (!dataUrl || !img.parentNode) continue; const dataUrl = await getCaptureImageDataUrl(img);
if (!dataUrl) continue;
const placeholder = createCapturePlaceholder(img, dataUrl); const placeholder = createCapturePlaceholder(img, dataUrl);
img.parentNode.replaceChild(placeholder, img); img.parentNode.replaceChild(placeholder, img);
restorers.push(() => { restorers.push(() => {
@@ -144,12 +162,13 @@ async function captureViewportBlob(viewportEl, options) {
// ── Graph serialisation → backend prompt format ─────────────────────── // ── Graph serialisation → backend prompt format ───────────────────────
function serializeGraph(nodes, edges) { function serializeGraph(nodes, edges, { excludeManualTrigger = false } = {}) {
const prompt = {}; const prompt = {};
for (const node of nodes) { for (const node of nodes) {
const { className, definition, widgetValues } = node.data; const { className, definition, widgetValues } = node.data;
if (!definition) continue; if (!definition) continue;
if (excludeManualTrigger && definition.manual_trigger) continue;
const inputs = {}; const inputs = {};
@@ -551,10 +570,23 @@ function Flow() {
// ── Node context value (stable) ───────────────────────────────────── // ── Node context value (stable) ─────────────────────────────────────
const onManualTrigger = useCallback((nodeId) => {
const currentNodes = reactFlow.getNodes();
const currentEdges = reactFlow.getEdges();
// Include ALL nodes (no excludeManualTrigger) so the save node is in the prompt
const prompt = serializeGraph(currentNodes, currentEdges);
if (!prompt || Object.keys(prompt).length === 0) return;
setStatus({ text: 'Saving…', level: 'info' });
api.runPrompt(prompt).catch((err) => {
setStatus({ text: 'Save failed: ' + err.message, level: 'error' });
});
}, [reactFlow]);
const contextValue = useMemo(() => ({ const contextValue = useMemo(() => ({
onWidgetChange, onWidgetChange,
openFileBrowser, openFileBrowser,
}), [onWidgetChange, openFileBrowser]); onManualTrigger,
}), [onWidgetChange, openFileBrowser, onManualTrigger]);
// ── Add node from context menu ────────────────────────────────────── // ── Add node from context menu ──────────────────────────────────────
@@ -687,13 +719,18 @@ function Flow() {
const currentNodes = reactFlow.getNodes(); const currentNodes = reactFlow.getNodes();
const currentEdges = reactFlow.getEdges(); const currentEdges = reactFlow.getEdges();
// Don't run if any node has unconnected required data inputs // Don't run if any non-manual node has unconnected required data inputs
// or any FILE_PICKER widget is empty
for (const node of currentNodes) { for (const node of currentNodes) {
const def = node.data?.definition; const def = node.data?.definition;
if (!def) continue; if (!def || def.manual_trigger) continue; // skip manual-trigger nodes
const required = def.input.required || {}; const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) { for (const [name, spec] of Object.entries(required)) {
const [type] = Array.isArray(spec) ? spec : [spec]; const [type] = Array.isArray(spec) ? spec : [spec];
if (type === 'FILE_PICKER') {
if (!node.data.widgetValues?.[name]) return; // no file selected, skip
continue;
}
if (!DATA_TYPES.has(type)) continue; if (!DATA_TYPES.has(type)) continue;
const hasEdge = currentEdges.some( const hasEdge = currentEdges.some(
(e) => e.target === node.id && getInputName(e.targetHandle) === name (e) => e.target === node.id && getInputName(e.targetHandle) === name
@@ -702,7 +739,7 @@ function Flow() {
} }
} }
const prompt = serializeGraph(currentNodes, currentEdges); const prompt = serializeGraph(currentNodes, currentEdges, { excludeManualTrigger: true });
if (!prompt || Object.keys(prompt).length === 0) return; if (!prompt || Object.keys(prompt).length === 0) return;
setStatus({ text: 'Running…', level: 'info' }); setStatus({ text: 'Running…', level: 'info' });
api.runPrompt(prompt).catch((err) => { api.runPrompt(prompt).catch((err) => {
@@ -723,25 +760,10 @@ function Flow() {
}, [setNodes, setEdges]); }, [setNodes, setEdges]);
const applyWorkflowData = useCallback((data) => { const applyWorkflowData = useCallback((data) => {
const loadedNodes = data.nodes || []; const hydrated = hydrateWorkflowState(data, nodeDefsRef.current);
const loadedEdges = data.edges || []; setNodes(hydrated.nodes);
const defs = nodeDefsRef.current; setEdges(hydrated.edges);
const hydrated = loadedNodes.map((n) => ({ nextIdRef.current = hydrated.nextNodeId;
...n,
type: n.type || 'custom',
dragHandle: n.dragHandle || '.drag-handle',
data: {
...n.data,
label: n.data?.label || n.data?.className || 'Node',
widgetValues: n.data?.widgetValues || {},
definition: defs[n.data.className] || n.data.definition,
previewImage: null, tableRows: null, meshData: null, overlay: null,
},
}));
setNodes(hydrated);
setEdges(loadedEdges);
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
nextIdRef.current = maxId + 1;
}, [setNodes, setEdges]); }, [setNodes, setEdges]);
const getWorkflowBlob = useCallback(async () => { const getWorkflowBlob = useCallback(async () => {
@@ -778,9 +800,23 @@ function Flow() {
try { try {
const finalBlob = await getWorkflowBlob(); const finalBlob = await getWorkflowBlob();
if (window.pywebview?.api?.save_workflow_png) { if (window.pywebview?.api?.choose_save_workflow_png_path) {
const dataUrl = await blobToDataUrl(finalBlob); const requestedPath = await window.pywebview.api.choose_save_workflow_png_path('workflow.png');
const savedPath = await window.pywebview.api.save_workflow_png(dataUrl, 'workflow.png'); if (!requestedPath) {
setStatus({ text: 'Save cancelled.', level: 'info' });
return;
}
const resp = await fetch(`/save-workflow-png?path=${encodeURIComponent(requestedPath)}`, {
method: 'POST',
headers: {
'Content-Type': 'image/png',
},
body: finalBlob,
});
if (!resp.ok) {
throw new Error(await resp.text() || `Save failed (${resp.status})`);
}
const { path: savedPath } = await resp.json();
if (!savedPath) { if (!savedPath) {
setStatus({ text: 'Save cancelled.', level: 'info' }); setStatus({ text: 'Save cancelled.', level: 'info' });
return; return;

View File

@@ -1,5 +1,6 @@
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react'; import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
import { Handle, Position } from '@xyflow/react'; import { Handle, Position, useStore } from '@xyflow/react';
import LinePlotOverlay from './LinePlotOverlay';
const SurfaceView = lazy(() => import('./SurfaceView')); const SurfaceView = lazy(() => import('./SurfaceView'));
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay')); const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
@@ -29,6 +30,47 @@ const CAT_COLORS = {
export const NodeContext = React.createContext(null); export const NodeContext = React.createContext(null);
class PreviewBoundary extends React.Component {
constructor(props) {
super(props);
this.state = { hasError: false };
}
static getDerivedStateFromError() {
return { hasError: true };
}
componentDidCatch(error) {
console.error('[argonode] preview render failed', error);
}
componentDidUpdate(prevProps) {
if (prevProps.resetKey !== this.props.resetKey && this.state.hasError) {
this.setState({ hasError: false });
}
}
render() {
if (!this.state.hasError) {
return this.props.children;
}
if (this.props.fallbackImage) {
return (
<div className="node-preview">
<img src={this.props.fallbackImage} alt="preview fallback" draggable={false} />
</div>
);
}
return (
<div className="node-preview" style={{ color: '#94a3b8', padding: 8 }}>
Preview unavailable.
</div>
);
}
}
// ── Draggable number input ──────────────────────────────────────────── // ── Draggable number input ────────────────────────────────────────────
function DraggableNumber({ value, step, min, max, precision, onChange }) { function DraggableNumber({ value, step, min, max, precision, onChange }) {
@@ -151,8 +193,39 @@ function CustomNode({ id, data }) {
} }
} }
// For manual-trigger nodes (Save), show progressive optional inputs:
// show field_N only if field_(N-1) is connected (or N==0).
const isProgressive = def.manual_trigger;
const connectedInputs = useStore(
useCallback(
(s) => {
if (!isProgressive) return null;
const set = new Set();
for (const e of s.edges) {
if (e.target === id) {
const parts = e.targetHandle?.split('::');
if (parts) set.add(parts[1]);
}
}
return set;
},
[id, isProgressive],
),
);
for (const [name, spec] of Object.entries(optional)) { for (const [name, spec] of Object.entries(optional)) {
const [type] = Array.isArray(spec) ? spec : [spec]; const [type] = Array.isArray(spec) ? spec : [spec];
if (isProgressive && DATA_TYPES.has(type)) {
// Progressive: show this slot only if it's the first or the previous is connected
const match = name.match(/^field_(\d+)$/);
if (match) {
const idx = parseInt(match[1], 10);
if (idx === 0 || (connectedInputs && connectedInputs.has(`field_${idx - 1}`))) {
dataInputs.push({ name, type });
}
continue;
}
}
dataInputs.push({ name, type }); dataInputs.push({ name, type });
} }
@@ -229,6 +302,19 @@ function CustomNode({ id, data }) {
</div> </div>
))} ))}
{/* Manual trigger button (Save) */}
{def.manual_trigger && (
<div className="widget-row">
<button
className="nodrag btn btn-primary"
style={{ flex: 1 }}
onClick={() => ctx.onManualTrigger?.(id)}
>
Save to Disk
</button>
</div>
)}
{/* Interactive 3D surface view */} {/* Interactive 3D surface view */}
{data.meshData && ( {data.meshData && (
<CollapsibleSection title="3D View" defaultOpen={true}> <CollapsibleSection title="3D View" defaultOpen={true}>
@@ -241,9 +327,21 @@ function CustomNode({ id, data }) {
{/* Collapsible preview image */} {/* Collapsible preview image */}
{data.previewImage && ( {data.previewImage && (
<CollapsibleSection title="Preview" defaultOpen={true}> <CollapsibleSection title="Preview" defaultOpen={true}>
<div className="node-preview"> <PreviewBoundary
<img src={data.previewImage} alt="preview" draggable={false} /> resetKey={typeof data.previewImage === 'string' ? data.previewImage : JSON.stringify({
</div> kind: data.previewImage.kind,
len: data.previewImage.line?.length,
})}
fallbackImage={typeof data.previewImage === 'object' ? data.previewImage.fallback_image : null}
>
{typeof data.previewImage === 'string' ? (
<div className="node-preview">
<img src={data.previewImage} alt="preview" draggable={false} />
</div>
) : data.previewImage.kind === 'line_plot' ? (
<LinePlotOverlay overlay={data.previewImage} interactive={false} />
) : null}
</PreviewBoundary>
</CollapsibleSection> </CollapsibleSection>
)} )}
@@ -251,17 +349,29 @@ function CustomNode({ id, data }) {
{data.overlay && hiddenWidgets.has('x1') && ( {data.overlay && hiddenWidgets.has('x1') && (
<CollapsibleSection title="Cross Section" defaultOpen={true}> <CollapsibleSection title="Cross Section" defaultOpen={true}>
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}> <Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
<CrossSectionOverlay {data.overlay.kind === 'line_plot' ? (
image={data.overlay.image} <LinePlotOverlay
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)} overlay={data.overlay}
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)} x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)} x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)} aLocked={data.overlay.a_locked}
aLocked={data.overlay.a_locked} bLocked={data.overlay.b_locked}
bLocked={data.overlay.b_locked} nodeId={id}
nodeId={id} onWidgetChange={ctx.onWidgetChange}
onWidgetChange={ctx.onWidgetChange} />
/> ) : (
<CrossSectionOverlay
image={data.overlay.image}
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
aLocked={data.overlay.a_locked}
bLocked={data.overlay.b_locked}
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
)}
</Suspense> </Suspense>
</CollapsibleSection> </CollapsibleSection>
)} )}

View File

@@ -0,0 +1,271 @@
import React, { useEffect, useRef, useState, useCallback } from 'react';
const ASPECT_RATIO = 3.2 / 2.2;
const MARGINS = { top: 18, right: 16, bottom: 34, left: 56 };
function clamp(v, min, max) {
return Math.max(min, Math.min(max, v));
}
function round3(v) {
return parseFloat(v.toFixed(3));
}
function trimZeros(text) {
return text.replace(/(?:\.0+|(\.\d+?)0+)$/, '$1');
}
function formatTick(value) {
const abs = Math.abs(value);
if (abs === 0) return '0';
if (abs >= 1e4 || abs < 1e-3) {
return value.toExponential(1).replace('e+', 'e');
}
if (abs >= 100) return trimZeros(value.toFixed(0));
if (abs >= 10) return trimZeros(value.toFixed(1));
if (abs >= 1) return trimZeros(value.toFixed(2));
return trimZeros(value.toFixed(3));
}
function makeTicks(min, max, count = 5) {
if (!Number.isFinite(min) || !Number.isFinite(max)) return [];
if (min === max) return [min];
const ticks = [];
for (let i = 0; i < count; i += 1) {
ticks.push(min + ((max - min) * i) / (count - 1));
}
return ticks;
}
function getExtent(values, fallbackMin = 0, fallbackMax = 1) {
if (!Array.isArray(values) || values.length === 0) {
return [fallbackMin, fallbackMax];
}
let min = Infinity;
let max = -Infinity;
for (const value of values) {
if (!Number.isFinite(value)) continue;
if (value < min) min = value;
if (value > max) max = value;
}
if (!Number.isFinite(min) || !Number.isFinite(max)) {
return [fallbackMin, fallbackMax];
}
return [min, max];
}
export default function LinePlotOverlay({
overlay,
x1,
x2,
aLocked,
bLocked,
nodeId,
onWidgetChange,
interactive = true,
}) {
const containerRef = useRef(null);
const [dragging, setDragging] = useState(null);
const [size, setSize] = useState({ width: 0, height: 0 });
useEffect(() => {
if (!containerRef.current) return undefined;
const updateSize = () => {
if (!containerRef.current) return;
setSize({
width: Math.max(1, Math.round(containerRef.current.clientWidth || 320)),
height: Math.max(1, Math.round(containerRef.current.clientHeight || (containerRef.current.clientWidth / ASPECT_RATIO) || 220)),
});
};
updateSize();
if (typeof ResizeObserver === 'function') {
const observer = new ResizeObserver((entries) => {
const entry = entries[0];
if (!entry) return;
const { width, height } = entry.contentRect;
setSize({
width: Math.max(1, Math.round(width)),
height: Math.max(1, Math.round(height)),
});
});
observer.observe(containerRef.current);
return () => observer.disconnect();
}
window.addEventListener('resize', updateSize);
return () => window.removeEventListener('resize', updateSize);
}, []);
const xValues = Array.isArray(overlay?.x_axis) && overlay.x_axis.length === overlay.line?.length
? overlay.x_axis
: overlay?.line?.map((_, i) => i) || [];
const yValues = Array.isArray(overlay?.line) ? overlay.line : [];
const width = size.width || 320;
const height = size.height || Math.round(width / ASPECT_RATIO);
const plotLeft = MARGINS.left;
const plotTop = MARGINS.top;
const plotWidth = Math.max(1, width - MARGINS.left - MARGINS.right);
const plotHeight = Math.max(1, height - MARGINS.top - MARGINS.bottom);
const [xMin, xMax] = getExtent(xValues, 0, 1);
const [yMinRaw, yMaxRaw] = getExtent(yValues, 0, 1);
const yPad = yMinRaw === yMaxRaw ? 1 : (yMaxRaw - yMinRaw) * 0.08;
const yMin = yMinRaw - yPad;
const yMax = yMaxRaw + yPad;
const scaleX = useCallback((value) => {
if (xMax === xMin) return plotLeft + plotWidth / 2;
return plotLeft + ((value - xMin) / (xMax - xMin)) * plotWidth;
}, [plotLeft, plotWidth, xMin, xMax]);
const scaleY = useCallback((value) => {
if (yMax === yMin) return plotTop + plotHeight / 2;
return plotTop + (1 - ((value - yMin) / (yMax - yMin))) * plotHeight;
}, [plotTop, plotHeight, yMin, yMax]);
const pickCursorPoint = useCallback((fraction) => {
if (!xValues.length || !yValues.length) {
return {
x: plotLeft,
y: plotTop + plotHeight / 2,
yFraction: 0.5,
};
}
const frac = clamp(fraction ?? 0.5, 0, 1);
const targetX = xMin + frac * (xMax - xMin || 1);
let idx = 0;
let best = Infinity;
for (let i = 0; i < xValues.length; i += 1) {
const dist = Math.abs(xValues[i] - targetX);
if (dist < best) {
best = dist;
idx = i;
}
}
const x = xValues[idx];
const y = yValues[idx];
const yFraction = yMax === yMin ? 0.5 : clamp((y - yMin) / (yMax - yMin), 0, 1);
return {
x: scaleX(x),
y: scaleY(y),
yFraction,
};
}, [plotLeft, plotTop, plotHeight, scaleX, scaleY, xValues, yValues, xMin, xMax, yMin, yMax]);
const cursorA = pickCursorPoint(x1 ?? overlay?.x1 ?? 0.25);
const cursorB = pickCursorPoint(x2 ?? overlay?.x2 ?? 0.75);
const path = yValues.map((y, i) => `${i === 0 ? 'M' : 'L'} ${scaleX(xValues[i])} ${scaleY(y)}`).join(' ');
const xTicks = makeTicks(xMin, xMax);
const yTicks = makeTicks(yMin, yMax);
const plotStroke = clamp(plotWidth / 240, 1.4, 2.6);
const gridStroke = clamp(plotWidth / 900, 0.6, 1.1);
const cursorStroke = clamp(plotWidth / 220, 1.4, 2.2);
const measureStroke = clamp(plotWidth / 180, 1.6, 2.8);
const markerRadius = clamp(plotWidth / 42, 5.5, 9);
const updateCursor = useCallback((point, event) => {
if (!interactive || !onWidgetChange || !nodeId) return;
if (!containerRef.current) return;
const rect = containerRef.current.getBoundingClientRect();
const xFrac = clamp((event.clientX - rect.left - plotLeft) / plotWidth, 0, 1);
const sample = pickCursorPoint(xFrac);
if (point === 'p1') {
onWidgetChange(nodeId, 'x1', round3(xFrac));
onWidgetChange(nodeId, 'y1', round3(sample.yFraction));
} else {
onWidgetChange(nodeId, 'x2', round3(xFrac));
onWidgetChange(nodeId, 'y2', round3(sample.yFraction));
}
}, [interactive, nodeId, onWidgetChange, pickCursorPoint, plotLeft, plotWidth]);
const onPointerDown = useCallback((point) => (event) => {
if (!interactive) return;
if ((point === 'p1' && aLocked) || (point === 'p2' && bLocked)) return;
event.preventDefault();
event.stopPropagation();
event.currentTarget.setPointerCapture(event.pointerId);
setDragging(point);
}, [interactive, aLocked, bLocked]);
const onPointerMove = useCallback((event) => {
if (!dragging) return;
updateCursor(dragging, event);
}, [dragging, updateCursor]);
const onPointerUp = useCallback(() => {
setDragging(null);
}, []);
return (
<div
ref={containerRef}
className="nodrag nowheel lineplot-overlay"
onPointerMove={onPointerMove}
onPointerUp={onPointerUp}
onLostPointerCapture={onPointerUp}
>
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} className="lineplot-svg">
<rect x="0" y="0" width={width} height={height} fill="#0f172a" />
{xTicks.map((tick) => {
const x = scaleX(tick);
return (
<g key={`x-${tick}`}>
<line x1={x} y1={plotTop} x2={x} y2={plotTop + plotHeight} stroke="#334155" strokeWidth={gridStroke} opacity="0.45" />
<text x={x} y={height - 10} textAnchor="middle" fontSize="11" fill="#94a3b8">
{formatTick(tick)}
</text>
</g>
);
})}
{yTicks.map((tick) => {
const y = scaleY(tick);
return (
<g key={`y-${tick}`}>
<line x1={plotLeft} y1={y} x2={plotLeft + plotWidth} y2={y} stroke="#334155" strokeWidth={gridStroke} opacity="0.45" />
<text x={plotLeft - 10} y={y + 4} textAnchor="end" fontSize="11" fill="#94a3b8">
{formatTick(tick)}
</text>
</g>
);
})}
<rect x={plotLeft} y={plotTop} width={plotWidth} height={plotHeight} fill="none" stroke="#334155" strokeWidth={gridStroke + 0.3} />
<path d={path} fill="none" stroke="#ff9800" strokeWidth={plotStroke} strokeLinecap="round" strokeLinejoin="round" />
{interactive && (
<>
<line x1={cursorA.x} y1={plotTop} x2={cursorA.x} y2={plotTop + plotHeight} stroke="#ffd700" strokeWidth={cursorStroke} strokeDasharray="10 6" opacity="0.95" />
<line x1={cursorB.x} y1={plotTop} x2={cursorB.x} y2={plotTop + plotHeight} stroke="#ffd700" strokeWidth={cursorStroke} strokeDasharray="10 6" opacity="0.95" />
<line x1={cursorA.x} y1={cursorA.y} x2={cursorB.x} y2={cursorB.y} stroke="#90caf9" strokeWidth={measureStroke} opacity="0.95" />
<circle
cx={cursorA.x}
cy={cursorA.y}
r={markerRadius}
className={`lineplot-marker ${aLocked ? 'lineplot-marker-locked' : ''}`}
onPointerDown={onPointerDown('p1')}
/>
<circle
cx={cursorB.x}
cy={cursorB.y}
r={markerRadius}
className={`lineplot-marker ${bLocked ? 'lineplot-marker-locked' : ''}`}
onPointerDown={onPointerDown('p2')}
/>
</>
)}
</svg>
</div>
);
}

View File

@@ -367,6 +367,41 @@ html, body, #root {
opacity: 0.9; opacity: 0.9;
} }
.lineplot-overlay {
width: 100%;
aspect-ratio: 32 / 22;
background: #0f172a;
border: 1px solid #334155;
border-radius: 6px;
overflow: hidden;
user-select: none;
touch-action: none;
}
.lineplot-svg {
display: block;
width: 100%;
height: 100%;
}
.lineplot-marker {
fill: #ffd700;
stroke: #fff;
stroke-width: 2px;
cursor: grab;
filter: drop-shadow(0 0 4px rgba(0, 0, 0, 0.45));
}
.lineplot-marker:active {
cursor: grabbing;
}
.lineplot-marker-locked {
fill: #e91e63;
stroke: #e91e63;
cursor: default;
}
/* ── 3D surface view ──────────────────────────────────────────────── */ /* ── 3D surface view ──────────────────────────────────────────────── */
.surface-view-container { .surface-view-container {
width: 100%; width: 100%;

View File

@@ -0,0 +1,52 @@
function mergeDefinition(nodeData, defs) {
const savedData = nodeData || {};
const savedDefinition = savedData.definition && typeof savedData.definition === 'object'
? savedData.definition
: null;
const registryDefinition = savedData.className ? defs[savedData.className] : null;
const definition = registryDefinition || savedDefinition;
if (!definition) return null;
const output = Array.isArray(savedData.output)
? savedData.output
: (Array.isArray(savedDefinition?.output) ? savedDefinition.output : null);
const outputName = Array.isArray(savedData.output_name)
? savedData.output_name
: (Array.isArray(savedDefinition?.output_name) ? savedDefinition.output_name : null);
return {
...definition,
...(output ? { output } : {}),
...(outputName ? { output_name: outputName } : {}),
};
}
export function hydrateWorkflowState(data, defs = {}) {
const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : [];
const loadedEdges = Array.isArray(data?.edges) ? data.edges : [];
const nodes = loadedNodes.map((node) => ({
...node,
type: node.type || 'custom',
dragHandle: node.dragHandle || '.drag-handle',
data: {
...node.data,
label: node.data?.label || node.data?.className || 'Node',
widgetValues: node.data?.widgetValues || {},
definition: mergeDefinition(node.data, defs),
previewImage: null,
tableRows: null,
meshData: null,
overlay: null,
},
}));
const nextNodeId = Math.max(0, ...loadedNodes.map((node) => parseInt(node.id, 10) || 0)) + 1;
return {
nodes,
edges: loadedEdges,
nextNodeId,
};
}

View File

@@ -10,6 +10,8 @@ export function serializeWorkflowState(nodes, edges) {
label: node.data?.label || node.data?.className || 'Node', label: node.data?.label || node.data?.className || 'Node',
className: node.data?.className || '', className: node.data?.className || '',
widgetValues: node.data?.widgetValues || {}, widgetValues: node.data?.widgetValues || {},
output: node.data?.definition?.output || [],
output_name: node.data?.definition?.output_name || [],
}, },
})), })),
edges: edges.map((edge) => ({ edges: edges.map((edge) => ({
@@ -18,7 +20,7 @@ export function serializeWorkflowState(nodes, edges) {
sourceHandle: edge.sourceHandle, sourceHandle: edge.sourceHandle,
target: edge.target, target: edge.target,
targetHandle: edge.targetHandle, targetHandle: edge.targetHandle,
style: edge.style, ...(edge.style ? { style: edge.style } : {}),
})), })),
}; };
} }

View File

@@ -1,6 +1,7 @@
import test from 'node:test'; import test from 'node:test';
import assert from 'node:assert/strict'; import assert from 'node:assert/strict';
import { hydrateWorkflowState } from '../src/workflowHydration.js';
import { serializeWorkflowState } from '../src/workflowSerialization.js'; import { serializeWorkflowState } from '../src/workflowSerialization.js';
test('serializeWorkflowState keeps only stable workflow fields needed for reload', () => { test('serializeWorkflowState keeps only stable workflow fields needed for reload', () => {
@@ -59,6 +60,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
label: 'Demo Label', label: 'Demo Label',
className: 'DemoNode', className: 'DemoNode',
widgetValues: { threshold: 0.42, mode: 'fast' }, widgetValues: { threshold: 0.42, mode: 'fast' },
output: [],
output_name: [],
}, },
}, },
{ {
@@ -70,6 +73,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
label: 'NoLabelNode', label: 'NoLabelNode',
className: 'NoLabelNode', className: 'NoLabelNode',
widgetValues: {}, widgetValues: {},
output: [],
output_name: [],
}, },
}, },
], ],
@@ -89,3 +94,99 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
assert.equal('previewImage' in serialized.nodes[0].data, false); assert.equal('previewImage' in serialized.nodes[0].data, false);
assert.equal('selected' in serialized.edges[0], false); assert.equal('selected' in serialized.edges[0], false);
}); });
test('hydrateWorkflowState restores saved dynamic outputs on top of current node definitions', () => {
const saved = {
version: 1,
nodes: [
{
id: '12',
position: { x: 40, y: 80 },
data: {
className: 'LoadFile',
widgetValues: { filename: 'scan.ibw', colormap: 'viridis' },
output: ['DATA_FIELD', 'DATA_FIELD'],
output_name: ['Height', 'Phase'],
previewImage: 'stale',
},
},
],
edges: [
{
id: 'e12-3',
source: '12',
sourceHandle: 'output::1::DATA_FIELD',
target: '3',
targetHandle: 'input::field::DATA_FIELD',
},
],
};
const defs = {
LoadFile: {
category: 'io',
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['viridis', 'gray'], {}] } },
output: ['DATA_FIELD'],
output_name: ['field'],
manual_trigger: false,
},
};
const hydrated = hydrateWorkflowState(saved, defs);
assert.equal(hydrated.nextNodeId, 13);
assert.deepEqual(hydrated.edges, saved.edges);
assert.equal(hydrated.nodes[0].type, 'custom');
assert.equal(hydrated.nodes[0].dragHandle, '.drag-handle');
assert.equal(hydrated.nodes[0].data.label, 'LoadFile');
assert.equal(hydrated.nodes[0].data.previewImage, null);
assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']);
assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']);
assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.LoadFile.input);
});
test('serializeWorkflowState and hydrateWorkflowState preserve reload-critical metadata for dynamic nodes', () => {
const nodes = [
{
id: '7',
position: { x: 10, y: 20 },
data: {
label: 'Load File',
className: 'LoadFile',
widgetValues: { filename: 'scan.gwy', colormap: 'gray' },
definition: {
category: 'io',
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } },
output: ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD'],
output_name: ['Topography', 'Error', 'Mask'],
},
previewImage: 'data:image/png;base64,stale',
},
},
];
const edges = [
{
id: 'e7-9',
source: '7',
sourceHandle: 'output::2::DATA_FIELD',
target: '9',
targetHandle: 'input::field::DATA_FIELD',
},
];
const defs = {
LoadFile: {
category: 'io',
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } },
output: ['DATA_FIELD'],
output_name: ['field'],
},
};
const serialized = serializeWorkflowState(nodes, edges);
const hydrated = hydrateWorkflowState(serialized, defs);
assert.deepEqual(hydrated.nodes[0].data.widgetValues, nodes[0].data.widgetValues);
assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD']);
assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Topography', 'Error', 'Mask']);
assert.deepEqual(hydrated.edges, edges);
});

View File

@@ -564,31 +564,57 @@ def test_load_file():
def test_save_image(): def test_save_image():
print("=== Test: SaveImage ===") print("=== Test: SaveImage (Save Layers) ===")
from backend.nodes.io import SaveImage from backend.nodes.io import SaveImage
node = SaveImage() node = SaveImage()
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
# Monkey-patch OUTPUT_DIR for testing # Save single layer as TIFF
from pathlib import Path tiff_path = os.path.join(tmpdir, "out.tiff")
import backend.nodes.io as io_mod node.save(filename=tiff_path, format="TIFF", field_0=field_a)
orig_dir = io_mod.OUTPUT_DIR assert os.path.exists(tiff_path), "TIFF file not created"
io_mod.OUTPUT_DIR = Path(tmpdir) from PIL import Image
im = Image.open(tiff_path)
assert im.n_frames == 1
arr_back = np.array(im)
assert arr_back.shape == (32, 32)
# Save multi-layer as TIFF
tiff_path2 = os.path.join(tmpdir, "multi.tiff")
node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b)
im2 = Image.open(tiff_path2)
assert im2.n_frames == 2
# Save as NPZ
npz_path = os.path.join(tmpdir, "out.npz")
node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b)
assert os.path.exists(npz_path)
npz = np.load(npz_path)
assert len(npz.files) == 2
assert np.allclose(npz["layer_0"], field_a.data)
assert np.allclose(npz["layer_1"], field_b.data)
# Extension is forced to match format
wrong_ext = os.path.join(tmpdir, "output.png")
node.save(filename=wrong_ext, format="TIFF", field_0=field_a)
assert os.path.exists(os.path.join(tmpdir, "output.tiff"))
# No fields connected → error
try: try:
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8) node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF")
assert False, "Should have raised ValueError"
except ValueError:
pass
# Save as PNG # No filename → error
node.save(image=arr, filename_prefix="test", format="PNG") try:
saved = os.listdir(tmpdir) node.save(filename="", format="TIFF", field_0=field_a)
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}" assert False, "Should have raised ValueError"
except ValueError:
# Save as NPY pass
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
saved = os.listdir(tmpdir)
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
finally:
io_mod.OUTPUT_DIR = orig_dir
print(" PASS\n") print(" PASS\n")
@@ -896,8 +922,11 @@ def test_line_cursors():
# Overlay should have been broadcast # Overlay should have been broadcast
assert len(overlays) == 1 assert len(overlays) == 1
assert "image" in overlays[0] assert overlays[0]["kind"] == "line_plot"
assert overlays[0]["image"].startswith("data:image/png;base64,") assert len(overlays[0]["line"]) == len(line)
assert len(overlays[0]["x_axis"]) == len(line)
assert 0.0 <= overlays[0]["x1"] <= 1.0
assert 0.0 <= overlays[0]["x2"] <= 1.0
# With x_axis provided # With x_axis provided
x_axis = np.linspace(0, 1, 100).astype(np.float64) x_axis = np.linspace(0, 1, 100).astype(np.float64)

View File

@@ -0,0 +1,20 @@
from pathlib import Path
import pytest
from backend.server import PNG_SIGNATURE, save_png_bytes
def test_save_png_bytes_writes_exact_png_payload(tmp_path: Path):
target = tmp_path / "workflow"
payload = PNG_SIGNATURE + b"argonode-test-payload"
saved_path = save_png_bytes(str(target), payload)
assert saved_path == tmp_path / "workflow.png"
assert saved_path.read_bytes() == payload
def test_save_png_bytes_rejects_invalid_payload(tmp_path: Path):
with pytest.raises(ValueError, match="valid PNG"):
save_png_bytes(str(tmp_path / "workflow.png"), b"not-a-png")