diff --git a/backend/execution.py b/backend/execution.py index 4172dfb..740f4c6 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -194,18 +194,17 @@ class ExecutionEngine: CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay LoadFile._broadcast_warning_fn = on_warning - SaveImage._broadcast_preview = ( - (lambda data_uri: on_preview("save", data_uri)) if on_preview else None - ) + SaveImage._broadcast_warning_fn = on_warning def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" from backend.nodes.display import PreviewImage, PrintTable, View3D from backend.nodes.analysis import CrossSection, LineCursors from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine - from backend.nodes.io import LoadFile + from backend.nodes.io import LoadFile, SaveImage if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors, - ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile): + ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, + LoadFile, SaveImage): cls._current_node_id = node_id def _auto_preview( @@ -262,11 +261,9 @@ class ExecutionEngine: cls: type, slot: int, result: tuple, - ) -> str | None: - """Render a LINE output as a small matplotlib plot, returned as a data URI.""" + ) -> dict | None: + """Return structured LINE preview data for responsive frontend rendering.""" import numpy as np - import base64 - import io as _io return_types = getattr(cls, "RETURN_TYPES", ()) @@ -281,17 +278,22 @@ class ExecutionEngine: return None # the first LINE already plotted both try: + import base64 + import io as _io import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt + y = np.asarray(y, dtype=np.float64).ravel() + if x is None: + x = np.arange(len(y), dtype=np.float64) + else: + x = np.asarray(x, dtype=np.float64).ravel()[:len(y)] + fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100) fig.patch.set_facecolor("#1e293b") ax.set_facecolor("#0f172a") - if x is not None: - ax.plot(x, y, color="#ff9800", linewidth=1.2) - else: - ax.plot(y, color="#ff9800", linewidth=1.2) + ax.plot(x, y, color="#ff9800", linewidth=1.2) ax.tick_params(colors="#94a3b8", labelsize=7) for spine in ax.spines.values(): spine.set_color("#334155") @@ -301,8 +303,15 @@ class ExecutionEngine: buf = _io.BytesIO() fig.savefig(buf, format="png", facecolor=fig.get_facecolor()) plt.close(fig) - b64 = base64.b64encode(buf.getvalue()).decode() - return f"data:image/png;base64,{b64}" + fallback_image = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + + return { + "kind": "line_plot", + "line": y.tolist(), + "x_axis": x.tolist(), + "interactive": False, + "fallback_image": fallback_image, + } except Exception: return None diff --git a/backend/node_registry.py b/backend/node_registry.py index 3510544..594d2e5 100644 --- a/backend/node_registry.py +++ b/backend/node_registry.py @@ -47,6 +47,7 @@ def get_node_info(class_name: str) -> dict[str, Any]: "output": list(cls.RETURN_TYPES), "output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)), "output_node": bool(getattr(cls, "OUTPUT_NODE", False)), + "manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)), "description": getattr(cls, "DESCRIPTION", ""), } diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index fc9e6a1..7d9ce13 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -11,7 +11,7 @@ Gwyddion equivalents: from __future__ import annotations import numpy as np from backend.node_registry import register_node -from backend.data_types import DataField +from backend.data_types import DataField, datafield_to_uint8, encode_preview # --------------------------------------------------------------------------- @@ -131,88 +131,49 @@ class LineCursors: self, line, x1: float, y1: float, x2: float, y2: float, x_axis=None, ) -> tuple: - import io as _io - import base64 - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - y = np.asarray(line, dtype=np.float64).ravel() n = len(y) if x_axis is not None: x = np.asarray(x_axis, dtype=np.float64).ravel()[:n] else: x = np.arange(n, dtype=np.float64) + x1 = float(np.clip(x1, 0.0, 1.0)) + x2 = float(np.clip(x2, 0.0, 1.0)) - # --- Render the base plot first to determine axes bounds --- - fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100) - fig.patch.set_facecolor("#1e293b") - ax.set_facecolor("#0f172a") - ax.plot(x, y, color="#ff9800", linewidth=1.2) - ax.tick_params(colors="#94a3b8", labelsize=7) - for spine in ax.spines.values(): - spine.set_color("#334155") - ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5) - fig.tight_layout(pad=0.4) + xmin = float(np.min(x)) if len(x) else 0.0 + xmax = float(np.max(x)) if len(x) else 1.0 - # Force a draw so transforms are valid - fig.canvas.draw() + def x_frac_to_idx(frac): + if n <= 1: + return 0 + if xmax == xmin: + return 0 + target_x = xmin + frac * (xmax - xmin) + return int(np.argmin(np.abs(x - target_x))) - # Get axes position in figure-fraction coordinates - ax_pos = ax.get_position() - ax_l, ax_b = ax_pos.x0, ax_pos.y0 - ax_w, ax_h = ax_pos.width, ax_pos.height - - # x1/y1 arrive as image-fraction from the frontend drag. - # Convert image-fraction x → axes-fraction → nearest data index. - def img_x_to_idx(ix): - axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1) - return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1)) - - idx_a = img_x_to_idx(x1) - idx_b = img_x_to_idx(x2) + idx_a = x_frac_to_idx(x1) + idx_b = x_frac_to_idx(x2) xa, ya = float(x[idx_a]), float(y[idx_a]) xb, yb = float(x[idx_b]), float(y[idx_b]) - # --- Draw cursor lines and markers on the plot --- - ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9) - ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9) - ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5) - ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5) - ax.annotate( - "", xy=(xb, yb), xytext=(xa, ya), - arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5), - ) - # --- Broadcast overlay --- if LineCursors._broadcast_overlay_fn is not None: - # Convert data-space positions back to image-fraction for markers - fig.canvas.draw() - inv = fig.transFigure.inverted() - fig_a = inv.transform(ax.transData.transform([xa, ya])) - fig_b = inv.transform(ax.transData.transform([xb, yb])) - - buf = _io.BytesIO() - fig.savefig(buf, format="png", facecolor=fig.get_facecolor()) - buf.seek(0) - image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode() - LineCursors._broadcast_overlay_fn( LineCursors._current_node_id, { - "image": image_uri, - "x1": float(fig_a[0]), - "y1": float(1.0 - fig_a[1]), # flip: image y=0 is top - "x2": float(fig_b[0]), - "y2": float(1.0 - fig_b[1]), + "kind": "line_plot", + "line": y.tolist(), + "x_axis": x.tolist(), + "x1": x1, + "x2": x2, + "y1": float(y1), + "y2": float(y2), "a_locked": False, "b_locked": False, }, ) - plt.close(fig) - # --- Output table --- table = [ {"quantity": "A position", "value": xa, "unit": ""}, @@ -414,8 +375,6 @@ class CrossSection: point_a=None, point_b=None, ) -> tuple: from scipy.ndimage import map_coordinates - import io, base64 - from matplotlib.figure import Figure # COORD inputs override widget values if point_a is not None: @@ -453,14 +412,9 @@ class CrossSection: # Broadcast overlay image with marker positions if CrossSection._broadcast_overlay_fn is not None: - fig = Figure(figsize=(3, 3), dpi=100) - ax = fig.add_axes([0, 0, 1, 1]) - ax.imshow(field.data, cmap="viridis", aspect="auto") - ax.axis("off") - buf = io.BytesIO() - fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) - buf.seek(0) - image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode() + # Use the field's native pixel grid for the overlay preview so enlarging + # the panel keeps the image as sharp as the source data allows. + image_uri = encode_preview(datafield_to_uint8(field, field.colormap)) CrossSection._broadcast_overlay_fn( CrossSection._current_node_id, diff --git a/backend/nodes/display.py b/backend/nodes/display.py index c6a1a0f..59340ff 100644 --- a/backend/nodes/display.py +++ b/backend/nodes/display.py @@ -78,7 +78,7 @@ class View3D: "required": { "field": ("DATA_FIELD",), "colormap": (["auto"] + list(COLORMAPS),), - "z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}), + "z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}), "resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}), } } @@ -134,7 +134,7 @@ class View3D: "colors": colors_b64, "z_min": zmin, "z_max": zmax, - "z_scale": float(z_scale), + "z_scale": float(z_scale * 0.1), "x_range": [float(field.xoff), float(field.xoff + field.xreal)], "y_range": [float(field.yoff), float(field.yoff + field.yreal)], } diff --git a/backend/nodes/io.py b/backend/nodes/io.py index 3d432c6..3c9de0d 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -399,54 +399,86 @@ class Coordinate: # SaveImage # --------------------------------------------------------------------------- -@register_node(display_name="Save Image") +_MAX_SAVE_FIELDS = 8 + +@register_node(display_name="Save Layers") class SaveImage: @classmethod def INPUT_TYPES(cls): + optional = {} + for i in range(_MAX_SAVE_FIELDS): + optional[f"field_{i}"] = ("DATA_FIELD",) return { "required": { - "image": ("IMAGE",), - "filename_prefix": ("STRING", {"default": "output"}), - "format": (["PNG", "TIFF", "NPY"],), - } + "filename": ("FILE_PICKER", {"default": ""}), + "format": (["TIFF", "NPZ"],), + }, + "optional": optional, } RETURN_TYPES = () FUNCTION = "save" CATEGORY = "io" OUTPUT_NODE = True - DESCRIPTION = "Save an image or array to the output folder." + MANUAL_TRIGGER = True + DESCRIPTION = ( + "Save one or more DATA_FIELD layers to a single file. " + "Connect fields to the inputs — a new slot appears as each is filled. " + "TIFF writes float32 multi-page; NPZ writes float64 named arrays. " + "Click Save to write (does not auto-run)." + ) - # Injected by server.py before execution begins - _broadcast_preview = None + _broadcast_warning_fn = None + _current_node_id = None - def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"): - OUTPUT_DIR.mkdir(exist_ok=True) + def save(self, filename: str, format: str = "TIFF", **kwargs): + # Collect connected fields in order + fields = [] + for i in range(_MAX_SAVE_FIELDS): + f = kwargs.get(f"field_{i}") + if f is not None: + fields.append(f) - # Find next available filename - idx = 1 - while True: - name = f"{filename_prefix}_{idx:04d}" - candidate = OUTPUT_DIR / f"{name}.{format.lower()}" - if not candidate.exists(): - break - idx += 1 + if not fields: + raise ValueError("No fields connected — connect at least one DATA_FIELD input.") - if format == "NPY": - np.save(str(OUTPUT_DIR / f"{name}.npy"), image) + if not filename or not filename.strip(): + raise ValueError("No output path selected — use Browse to pick a location.") + + path = Path(filename) + # Ensure parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + # Force correct extension + ext = ".tiff" if format == "TIFF" else ".npz" + if path.suffix.lower() != ext: + path = path.with_suffix(ext) + + if format == "TIFF": + self._save_tiff(path, fields) else: - from PIL import Image - arr = image_to_uint8(image) - if arr.ndim == 2: - pil_img = Image.fromarray(arr, mode="L") - else: - pil_img = Image.fromarray(arr, mode="RGB") - pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}")) + self._save_npz(path, fields) - # Emit preview over WebSocket if callback is set - if SaveImage._broadcast_preview is not None: - arr_u8 = image_to_uint8(image) - data_uri = encode_preview(arr_u8) - SaveImage._broadcast_preview(data_uri) + self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}") + return () + + def _save_tiff(self, path: Path, fields: list[DataField]): + from PIL import Image + images = [] + for f in fields: + images.append(Image.fromarray(f.data.astype(np.float32))) + images[0].save(str(path), save_all=True, append_images=images[1:]) + + def _save_npz(self, path: Path, fields: list[DataField]): + arrays = {} + for i, f in enumerate(fields): + arrays[f"layer_{i}"] = f.data + np.savez(str(path), **arrays) + + def _send_warning(self, message: str): + fn = SaveImage._broadcast_warning_fn + nid = SaveImage._current_node_id + if fn and nid: + fn(nid, message) return () diff --git a/backend/server.py b/backend/server.py index 7fb50d4..cc453f4 100644 --- a/backend/server.py +++ b/backend/server.py @@ -41,6 +41,7 @@ FRONTEND_DIR = frontend_dir() DIST_DIR = frontend_dist_dir() INPUT_DIR = input_dir() OUTPUT_DIR = output_dir() +PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" # --------------------------------------------------------------------------- @@ -63,6 +64,18 @@ def _dumps(obj) -> str: return json.dumps(obj, cls=_SafeEncoder) +def save_png_bytes(target_path: str, payload: bytes) -> Path: + path = Path(target_path).expanduser() + if not target_path.strip(): + raise ValueError("Missing save path") + if path.suffix.lower() != ".png": + path = path.with_suffix(".png") + if not payload.startswith(PNG_SIGNATURE): + raise ValueError("Payload is not a valid PNG") + path.write_bytes(payload) + return path + + # --------------------------------------------------------------------------- # Application factory # --------------------------------------------------------------------------- @@ -196,6 +209,20 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: }, ) + async def save_workflow_png(request: web.Request) -> web.Response: + body = await request.read() + target_path = request.query.get("path", "") + if not target_path: + raise web.HTTPBadRequest(reason="Missing path") + try: + saved_path = save_png_bytes(target_path, body) + except ValueError as exc: + raise web.HTTPBadRequest(reason=str(exc)) from exc + return web.Response( + text=_dumps({"path": str(saved_path)}), + content_type="application/json", + ) + async def get_channels(request: web.Request) -> web.Response: """Return available channels for a given file path.""" from backend.nodes.io import list_channels @@ -278,6 +305,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application: app.router.add_get("/browse", browse_dir) app.router.add_post("/upload", upload_file) app.router.add_post("/download", download_file) + app.router.add_post("/save-workflow-png", save_workflow_png) app.router.add_get("/channels", get_channels) app.router.add_post("/prompt", submit_prompt) app.router.add_get("/ws", websocket_handler) diff --git a/desktop.py b/desktop.py index aca7983..e77c683 100644 --- a/desktop.py +++ b/desktop.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import base64 import logging import socket import threading @@ -45,8 +44,8 @@ class _Api: return result[0] return None - def save_workflow_png(self, data_url: str, default_filename: str = "workflow.png") -> str | None: - """Open a native save dialog, write the PNG bytes, and return the saved path.""" + def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None: + """Open a native save dialog and return the chosen PNG path (or None).""" win = self._window_ref[0] if win is None: return None @@ -65,12 +64,6 @@ class _Api: path = Path(result[0] if isinstance(result, (list, tuple)) else result).expanduser() if path.suffix.lower() != ".png": path = path.with_suffix(".png") - - _, _, encoded = data_url.partition(",") - if not encoded: - raise ValueError("Invalid data URL payload") - - path.write_bytes(base64.b64decode(encoded)) return str(path) diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index cbf6075..26c5cff 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -13,6 +13,7 @@ import FileBrowser from './FileBrowser'; import * as api from './api'; import { toBlob } from 'html-to-image'; import { embedWorkflow, extractWorkflow } from './pngMetadata'; +import { hydrateWorkflowState } from './workflowHydration'; import { serializeWorkflowState } from './workflowSerialization'; // ── Constants ───────────────────────────────────────────────────────── @@ -43,15 +44,6 @@ function getOutputSlot(handleId) { return parseInt(handleId.split('::')[1], 10); } -function blobToDataUrl(blob) { - return new Promise((resolve, reject) => { - const reader = new FileReader(); - reader.onloadend = () => resolve(reader.result); - reader.onerror = () => reject(reader.error || new Error('Failed to read file')); - reader.readAsDataURL(blob); - }); -} - async function waitForImageElement(img) { if (img.complete && img.naturalWidth > 0) return; if (typeof img.decode === 'function') { @@ -73,6 +65,31 @@ async function waitForImageElement(img) { }); } +async function getCaptureImageDataUrl(img) { + const src = img.currentSrc || img.src; + if (!src) return null; + if (!src.startsWith('data:')) return src; + + const rect = img.getBoundingClientRect(); + const width = Math.max(1, Math.round(img.clientWidth || rect.width)); + const height = Math.max(1, Math.round(img.clientHeight || rect.height)); + const scale = Math.min(2, window.devicePixelRatio || 1); + + const canvas = document.createElement('canvas'); + canvas.width = Math.max(1, Math.round(width * scale)); + canvas.height = Math.max(1, Math.round(height * scale)); + + const ctx = canvas.getContext('2d'); + if (!ctx) return src; + + try { + ctx.drawImage(img, 0, 0, canvas.width, canvas.height); + return canvas.toDataURL('image/png'); + } catch { + return src; + } +} + function createCapturePlaceholder(el, dataUrl) { const rect = el.getBoundingClientRect(); const style = window.getComputedStyle(el); @@ -101,8 +118,9 @@ async function captureViewportBlob(viewportEl, options) { await Promise.all(images.map(waitForImageElement)); for (const img of images) { - const dataUrl = img.currentSrc || img.src; - if (!dataUrl || !img.parentNode) continue; + if (!img.parentNode) continue; + const dataUrl = await getCaptureImageDataUrl(img); + if (!dataUrl) continue; const placeholder = createCapturePlaceholder(img, dataUrl); img.parentNode.replaceChild(placeholder, img); restorers.push(() => { @@ -144,12 +162,13 @@ async function captureViewportBlob(viewportEl, options) { // ── Graph serialisation → backend prompt format ─────────────────────── -function serializeGraph(nodes, edges) { +function serializeGraph(nodes, edges, { excludeManualTrigger = false } = {}) { const prompt = {}; for (const node of nodes) { const { className, definition, widgetValues } = node.data; if (!definition) continue; + if (excludeManualTrigger && definition.manual_trigger) continue; const inputs = {}; @@ -551,10 +570,23 @@ function Flow() { // ── Node context value (stable) ───────────────────────────────────── + const onManualTrigger = useCallback((nodeId) => { + const currentNodes = reactFlow.getNodes(); + const currentEdges = reactFlow.getEdges(); + // Include ALL nodes (no excludeManualTrigger) so the save node is in the prompt + const prompt = serializeGraph(currentNodes, currentEdges); + if (!prompt || Object.keys(prompt).length === 0) return; + setStatus({ text: 'Saving…', level: 'info' }); + api.runPrompt(prompt).catch((err) => { + setStatus({ text: 'Save failed: ' + err.message, level: 'error' }); + }); + }, [reactFlow]); + const contextValue = useMemo(() => ({ onWidgetChange, openFileBrowser, - }), [onWidgetChange, openFileBrowser]); + onManualTrigger, + }), [onWidgetChange, openFileBrowser, onManualTrigger]); // ── Add node from context menu ────────────────────────────────────── @@ -687,13 +719,18 @@ function Flow() { const currentNodes = reactFlow.getNodes(); const currentEdges = reactFlow.getEdges(); - // Don't run if any node has unconnected required data inputs + // Don't run if any non-manual node has unconnected required data inputs + // or any FILE_PICKER widget is empty for (const node of currentNodes) { const def = node.data?.definition; - if (!def) continue; + if (!def || def.manual_trigger) continue; // skip manual-trigger nodes const required = def.input.required || {}; for (const [name, spec] of Object.entries(required)) { const [type] = Array.isArray(spec) ? spec : [spec]; + if (type === 'FILE_PICKER') { + if (!node.data.widgetValues?.[name]) return; // no file selected, skip + continue; + } if (!DATA_TYPES.has(type)) continue; const hasEdge = currentEdges.some( (e) => e.target === node.id && getInputName(e.targetHandle) === name @@ -702,7 +739,7 @@ function Flow() { } } - const prompt = serializeGraph(currentNodes, currentEdges); + const prompt = serializeGraph(currentNodes, currentEdges, { excludeManualTrigger: true }); if (!prompt || Object.keys(prompt).length === 0) return; setStatus({ text: 'Running…', level: 'info' }); api.runPrompt(prompt).catch((err) => { @@ -723,25 +760,10 @@ function Flow() { }, [setNodes, setEdges]); const applyWorkflowData = useCallback((data) => { - const loadedNodes = data.nodes || []; - const loadedEdges = data.edges || []; - const defs = nodeDefsRef.current; - const hydrated = loadedNodes.map((n) => ({ - ...n, - type: n.type || 'custom', - dragHandle: n.dragHandle || '.drag-handle', - data: { - ...n.data, - label: n.data?.label || n.data?.className || 'Node', - widgetValues: n.data?.widgetValues || {}, - definition: defs[n.data.className] || n.data.definition, - previewImage: null, tableRows: null, meshData: null, overlay: null, - }, - })); - setNodes(hydrated); - setEdges(loadedEdges); - const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0)); - nextIdRef.current = maxId + 1; + const hydrated = hydrateWorkflowState(data, nodeDefsRef.current); + setNodes(hydrated.nodes); + setEdges(hydrated.edges); + nextIdRef.current = hydrated.nextNodeId; }, [setNodes, setEdges]); const getWorkflowBlob = useCallback(async () => { @@ -778,9 +800,23 @@ function Flow() { try { const finalBlob = await getWorkflowBlob(); - if (window.pywebview?.api?.save_workflow_png) { - const dataUrl = await blobToDataUrl(finalBlob); - const savedPath = await window.pywebview.api.save_workflow_png(dataUrl, 'workflow.png'); + if (window.pywebview?.api?.choose_save_workflow_png_path) { + const requestedPath = await window.pywebview.api.choose_save_workflow_png_path('workflow.png'); + if (!requestedPath) { + setStatus({ text: 'Save cancelled.', level: 'info' }); + return; + } + const resp = await fetch(`/save-workflow-png?path=${encodeURIComponent(requestedPath)}`, { + method: 'POST', + headers: { + 'Content-Type': 'image/png', + }, + body: finalBlob, + }); + if (!resp.ok) { + throw new Error(await resp.text() || `Save failed (${resp.status})`); + } + const { path: savedPath } = await resp.json(); if (!savedPath) { setStatus({ text: 'Save cancelled.', level: 'info' }); return; diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 21ab15b..97ee57a 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -1,5 +1,6 @@ import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react'; -import { Handle, Position } from '@xyflow/react'; +import { Handle, Position, useStore } from '@xyflow/react'; +import LinePlotOverlay from './LinePlotOverlay'; const SurfaceView = lazy(() => import('./SurfaceView')); const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay')); @@ -29,6 +30,47 @@ const CAT_COLORS = { export const NodeContext = React.createContext(null); +class PreviewBoundary extends React.Component { + constructor(props) { + super(props); + this.state = { hasError: false }; + } + + static getDerivedStateFromError() { + return { hasError: true }; + } + + componentDidCatch(error) { + console.error('[argonode] preview render failed', error); + } + + componentDidUpdate(prevProps) { + if (prevProps.resetKey !== this.props.resetKey && this.state.hasError) { + this.setState({ hasError: false }); + } + } + + render() { + if (!this.state.hasError) { + return this.props.children; + } + + if (this.props.fallbackImage) { + return ( +
+ preview fallback +
+ ); + } + + return ( +
+ Preview unavailable. +
+ ); + } +} + // ── Draggable number input ──────────────────────────────────────────── function DraggableNumber({ value, step, min, max, precision, onChange }) { @@ -151,8 +193,39 @@ function CustomNode({ id, data }) { } } + // For manual-trigger nodes (Save), show progressive optional inputs: + // show field_N only if field_(N-1) is connected (or N==0). + const isProgressive = def.manual_trigger; + const connectedInputs = useStore( + useCallback( + (s) => { + if (!isProgressive) return null; + const set = new Set(); + for (const e of s.edges) { + if (e.target === id) { + const parts = e.targetHandle?.split('::'); + if (parts) set.add(parts[1]); + } + } + return set; + }, + [id, isProgressive], + ), + ); + for (const [name, spec] of Object.entries(optional)) { const [type] = Array.isArray(spec) ? spec : [spec]; + if (isProgressive && DATA_TYPES.has(type)) { + // Progressive: show this slot only if it's the first or the previous is connected + const match = name.match(/^field_(\d+)$/); + if (match) { + const idx = parseInt(match[1], 10); + if (idx === 0 || (connectedInputs && connectedInputs.has(`field_${idx - 1}`))) { + dataInputs.push({ name, type }); + } + continue; + } + } dataInputs.push({ name, type }); } @@ -229,6 +302,19 @@ function CustomNode({ id, data }) { ))} + {/* Manual trigger button (Save) */} + {def.manual_trigger && ( +
+ +
+ )} + {/* Interactive 3D surface view */} {data.meshData && ( @@ -241,9 +327,21 @@ function CustomNode({ id, data }) { {/* Collapsible preview image */} {data.previewImage && ( -
- preview -
+ + {typeof data.previewImage === 'string' ? ( +
+ preview +
+ ) : data.previewImage.kind === 'line_plot' ? ( + + ) : null} +
)} @@ -251,17 +349,29 @@ function CustomNode({ id, data }) { {data.overlay && hiddenWidgets.has('x1') && ( Loading...}> - + {data.overlay.kind === 'line_plot' ? ( + + ) : ( + + )} )} diff --git a/frontend/src/LinePlotOverlay.jsx b/frontend/src/LinePlotOverlay.jsx new file mode 100644 index 0000000..9e8d499 --- /dev/null +++ b/frontend/src/LinePlotOverlay.jsx @@ -0,0 +1,271 @@ +import React, { useEffect, useRef, useState, useCallback } from 'react'; + +const ASPECT_RATIO = 3.2 / 2.2; +const MARGINS = { top: 18, right: 16, bottom: 34, left: 56 }; + +function clamp(v, min, max) { + return Math.max(min, Math.min(max, v)); +} + +function round3(v) { + return parseFloat(v.toFixed(3)); +} + +function trimZeros(text) { + return text.replace(/(?:\.0+|(\.\d+?)0+)$/, '$1'); +} + +function formatTick(value) { + const abs = Math.abs(value); + if (abs === 0) return '0'; + if (abs >= 1e4 || abs < 1e-3) { + return value.toExponential(1).replace('e+', 'e'); + } + if (abs >= 100) return trimZeros(value.toFixed(0)); + if (abs >= 10) return trimZeros(value.toFixed(1)); + if (abs >= 1) return trimZeros(value.toFixed(2)); + return trimZeros(value.toFixed(3)); +} + +function makeTicks(min, max, count = 5) { + if (!Number.isFinite(min) || !Number.isFinite(max)) return []; + if (min === max) return [min]; + const ticks = []; + for (let i = 0; i < count; i += 1) { + ticks.push(min + ((max - min) * i) / (count - 1)); + } + return ticks; +} + +function getExtent(values, fallbackMin = 0, fallbackMax = 1) { + if (!Array.isArray(values) || values.length === 0) { + return [fallbackMin, fallbackMax]; + } + + let min = Infinity; + let max = -Infinity; + for (const value of values) { + if (!Number.isFinite(value)) continue; + if (value < min) min = value; + if (value > max) max = value; + } + + if (!Number.isFinite(min) || !Number.isFinite(max)) { + return [fallbackMin, fallbackMax]; + } + return [min, max]; +} + +export default function LinePlotOverlay({ + overlay, + x1, + x2, + aLocked, + bLocked, + nodeId, + onWidgetChange, + interactive = true, +}) { + const containerRef = useRef(null); + const [dragging, setDragging] = useState(null); + const [size, setSize] = useState({ width: 0, height: 0 }); + + useEffect(() => { + if (!containerRef.current) return undefined; + const updateSize = () => { + if (!containerRef.current) return; + setSize({ + width: Math.max(1, Math.round(containerRef.current.clientWidth || 320)), + height: Math.max(1, Math.round(containerRef.current.clientHeight || (containerRef.current.clientWidth / ASPECT_RATIO) || 220)), + }); + }; + + updateSize(); + + if (typeof ResizeObserver === 'function') { + const observer = new ResizeObserver((entries) => { + const entry = entries[0]; + if (!entry) return; + const { width, height } = entry.contentRect; + setSize({ + width: Math.max(1, Math.round(width)), + height: Math.max(1, Math.round(height)), + }); + }); + observer.observe(containerRef.current); + return () => observer.disconnect(); + } + + window.addEventListener('resize', updateSize); + return () => window.removeEventListener('resize', updateSize); + }, []); + + const xValues = Array.isArray(overlay?.x_axis) && overlay.x_axis.length === overlay.line?.length + ? overlay.x_axis + : overlay?.line?.map((_, i) => i) || []; + const yValues = Array.isArray(overlay?.line) ? overlay.line : []; + + const width = size.width || 320; + const height = size.height || Math.round(width / ASPECT_RATIO); + const plotLeft = MARGINS.left; + const plotTop = MARGINS.top; + const plotWidth = Math.max(1, width - MARGINS.left - MARGINS.right); + const plotHeight = Math.max(1, height - MARGINS.top - MARGINS.bottom); + + const [xMin, xMax] = getExtent(xValues, 0, 1); + const [yMinRaw, yMaxRaw] = getExtent(yValues, 0, 1); + const yPad = yMinRaw === yMaxRaw ? 1 : (yMaxRaw - yMinRaw) * 0.08; + const yMin = yMinRaw - yPad; + const yMax = yMaxRaw + yPad; + + const scaleX = useCallback((value) => { + if (xMax === xMin) return plotLeft + plotWidth / 2; + return plotLeft + ((value - xMin) / (xMax - xMin)) * plotWidth; + }, [plotLeft, plotWidth, xMin, xMax]); + + const scaleY = useCallback((value) => { + if (yMax === yMin) return plotTop + plotHeight / 2; + return plotTop + (1 - ((value - yMin) / (yMax - yMin))) * plotHeight; + }, [plotTop, plotHeight, yMin, yMax]); + + const pickCursorPoint = useCallback((fraction) => { + if (!xValues.length || !yValues.length) { + return { + x: plotLeft, + y: plotTop + plotHeight / 2, + yFraction: 0.5, + }; + } + + const frac = clamp(fraction ?? 0.5, 0, 1); + const targetX = xMin + frac * (xMax - xMin || 1); + + let idx = 0; + let best = Infinity; + for (let i = 0; i < xValues.length; i += 1) { + const dist = Math.abs(xValues[i] - targetX); + if (dist < best) { + best = dist; + idx = i; + } + } + + const x = xValues[idx]; + const y = yValues[idx]; + const yFraction = yMax === yMin ? 0.5 : clamp((y - yMin) / (yMax - yMin), 0, 1); + return { + x: scaleX(x), + y: scaleY(y), + yFraction, + }; + }, [plotLeft, plotTop, plotHeight, scaleX, scaleY, xValues, yValues, xMin, xMax, yMin, yMax]); + + const cursorA = pickCursorPoint(x1 ?? overlay?.x1 ?? 0.25); + const cursorB = pickCursorPoint(x2 ?? overlay?.x2 ?? 0.75); + + const path = yValues.map((y, i) => `${i === 0 ? 'M' : 'L'} ${scaleX(xValues[i])} ${scaleY(y)}`).join(' '); + const xTicks = makeTicks(xMin, xMax); + const yTicks = makeTicks(yMin, yMax); + const plotStroke = clamp(plotWidth / 240, 1.4, 2.6); + const gridStroke = clamp(plotWidth / 900, 0.6, 1.1); + const cursorStroke = clamp(plotWidth / 220, 1.4, 2.2); + const measureStroke = clamp(plotWidth / 180, 1.6, 2.8); + const markerRadius = clamp(plotWidth / 42, 5.5, 9); + + const updateCursor = useCallback((point, event) => { + if (!interactive || !onWidgetChange || !nodeId) return; + if (!containerRef.current) return; + const rect = containerRef.current.getBoundingClientRect(); + const xFrac = clamp((event.clientX - rect.left - plotLeft) / plotWidth, 0, 1); + const sample = pickCursorPoint(xFrac); + if (point === 'p1') { + onWidgetChange(nodeId, 'x1', round3(xFrac)); + onWidgetChange(nodeId, 'y1', round3(sample.yFraction)); + } else { + onWidgetChange(nodeId, 'x2', round3(xFrac)); + onWidgetChange(nodeId, 'y2', round3(sample.yFraction)); + } + }, [interactive, nodeId, onWidgetChange, pickCursorPoint, plotLeft, plotWidth]); + + const onPointerDown = useCallback((point) => (event) => { + if (!interactive) return; + if ((point === 'p1' && aLocked) || (point === 'p2' && bLocked)) return; + event.preventDefault(); + event.stopPropagation(); + event.currentTarget.setPointerCapture(event.pointerId); + setDragging(point); + }, [interactive, aLocked, bLocked]); + + const onPointerMove = useCallback((event) => { + if (!dragging) return; + updateCursor(dragging, event); + }, [dragging, updateCursor]); + + const onPointerUp = useCallback(() => { + setDragging(null); + }, []); + + return ( +
+ + + + {xTicks.map((tick) => { + const x = scaleX(tick); + return ( + + + + {formatTick(tick)} + + + ); + })} + + {yTicks.map((tick) => { + const y = scaleY(tick); + return ( + + + + {formatTick(tick)} + + + ); + })} + + + + + {interactive && ( + <> + + + + + + + + )} + +
+ ); +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index 41ee658..597f4d6 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -367,6 +367,41 @@ html, body, #root { opacity: 0.9; } +.lineplot-overlay { + width: 100%; + aspect-ratio: 32 / 22; + background: #0f172a; + border: 1px solid #334155; + border-radius: 6px; + overflow: hidden; + user-select: none; + touch-action: none; +} + +.lineplot-svg { + display: block; + width: 100%; + height: 100%; +} + +.lineplot-marker { + fill: #ffd700; + stroke: #fff; + stroke-width: 2px; + cursor: grab; + filter: drop-shadow(0 0 4px rgba(0, 0, 0, 0.45)); +} + +.lineplot-marker:active { + cursor: grabbing; +} + +.lineplot-marker-locked { + fill: #e91e63; + stroke: #e91e63; + cursor: default; +} + /* ── 3D surface view ──────────────────────────────────────────────── */ .surface-view-container { width: 100%; diff --git a/frontend/src/workflowHydration.js b/frontend/src/workflowHydration.js new file mode 100644 index 0000000..342686e --- /dev/null +++ b/frontend/src/workflowHydration.js @@ -0,0 +1,52 @@ +function mergeDefinition(nodeData, defs) { + const savedData = nodeData || {}; + const savedDefinition = savedData.definition && typeof savedData.definition === 'object' + ? savedData.definition + : null; + const registryDefinition = savedData.className ? defs[savedData.className] : null; + const definition = registryDefinition || savedDefinition; + + if (!definition) return null; + + const output = Array.isArray(savedData.output) + ? savedData.output + : (Array.isArray(savedDefinition?.output) ? savedDefinition.output : null); + const outputName = Array.isArray(savedData.output_name) + ? savedData.output_name + : (Array.isArray(savedDefinition?.output_name) ? savedDefinition.output_name : null); + + return { + ...definition, + ...(output ? { output } : {}), + ...(outputName ? { output_name: outputName } : {}), + }; +} + +export function hydrateWorkflowState(data, defs = {}) { + const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : []; + const loadedEdges = Array.isArray(data?.edges) ? data.edges : []; + + const nodes = loadedNodes.map((node) => ({ + ...node, + type: node.type || 'custom', + dragHandle: node.dragHandle || '.drag-handle', + data: { + ...node.data, + label: node.data?.label || node.data?.className || 'Node', + widgetValues: node.data?.widgetValues || {}, + definition: mergeDefinition(node.data, defs), + previewImage: null, + tableRows: null, + meshData: null, + overlay: null, + }, + })); + + const nextNodeId = Math.max(0, ...loadedNodes.map((node) => parseInt(node.id, 10) || 0)) + 1; + + return { + nodes, + edges: loadedEdges, + nextNodeId, + }; +} diff --git a/frontend/src/workflowSerialization.js b/frontend/src/workflowSerialization.js index ebab63f..cdac583 100644 --- a/frontend/src/workflowSerialization.js +++ b/frontend/src/workflowSerialization.js @@ -10,6 +10,8 @@ export function serializeWorkflowState(nodes, edges) { label: node.data?.label || node.data?.className || 'Node', className: node.data?.className || '', widgetValues: node.data?.widgetValues || {}, + output: node.data?.definition?.output || [], + output_name: node.data?.definition?.output_name || [], }, })), edges: edges.map((edge) => ({ @@ -18,7 +20,7 @@ export function serializeWorkflowState(nodes, edges) { sourceHandle: edge.sourceHandle, target: edge.target, targetHandle: edge.targetHandle, - style: edge.style, + ...(edge.style ? { style: edge.style } : {}), })), }; } diff --git a/frontend/tests/workflowSerialization.test.mjs b/frontend/tests/workflowSerialization.test.mjs index 8c05dda..ca60ec7 100644 --- a/frontend/tests/workflowSerialization.test.mjs +++ b/frontend/tests/workflowSerialization.test.mjs @@ -1,6 +1,7 @@ import test from 'node:test'; import assert from 'node:assert/strict'; +import { hydrateWorkflowState } from '../src/workflowHydration.js'; import { serializeWorkflowState } from '../src/workflowSerialization.js'; test('serializeWorkflowState keeps only stable workflow fields needed for reload', () => { @@ -59,6 +60,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload label: 'Demo Label', className: 'DemoNode', widgetValues: { threshold: 0.42, mode: 'fast' }, + output: [], + output_name: [], }, }, { @@ -70,6 +73,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload label: 'NoLabelNode', className: 'NoLabelNode', widgetValues: {}, + output: [], + output_name: [], }, }, ], @@ -89,3 +94,99 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload assert.equal('previewImage' in serialized.nodes[0].data, false); assert.equal('selected' in serialized.edges[0], false); }); + +test('hydrateWorkflowState restores saved dynamic outputs on top of current node definitions', () => { + const saved = { + version: 1, + nodes: [ + { + id: '12', + position: { x: 40, y: 80 }, + data: { + className: 'LoadFile', + widgetValues: { filename: 'scan.ibw', colormap: 'viridis' }, + output: ['DATA_FIELD', 'DATA_FIELD'], + output_name: ['Height', 'Phase'], + previewImage: 'stale', + }, + }, + ], + edges: [ + { + id: 'e12-3', + source: '12', + sourceHandle: 'output::1::DATA_FIELD', + target: '3', + targetHandle: 'input::field::DATA_FIELD', + }, + ], + }; + + const defs = { + LoadFile: { + category: 'io', + input: { required: { filename: ['FILE_PICKER', {}], colormap: [['viridis', 'gray'], {}] } }, + output: ['DATA_FIELD'], + output_name: ['field'], + manual_trigger: false, + }, + }; + + const hydrated = hydrateWorkflowState(saved, defs); + + assert.equal(hydrated.nextNodeId, 13); + assert.deepEqual(hydrated.edges, saved.edges); + assert.equal(hydrated.nodes[0].type, 'custom'); + assert.equal(hydrated.nodes[0].dragHandle, '.drag-handle'); + assert.equal(hydrated.nodes[0].data.label, 'LoadFile'); + assert.equal(hydrated.nodes[0].data.previewImage, null); + assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']); + assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']); + assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.LoadFile.input); +}); + +test('serializeWorkflowState and hydrateWorkflowState preserve reload-critical metadata for dynamic nodes', () => { + const nodes = [ + { + id: '7', + position: { x: 10, y: 20 }, + data: { + label: 'Load File', + className: 'LoadFile', + widgetValues: { filename: 'scan.gwy', colormap: 'gray' }, + definition: { + category: 'io', + input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } }, + output: ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD'], + output_name: ['Topography', 'Error', 'Mask'], + }, + previewImage: 'data:image/png;base64,stale', + }, + }, + ]; + const edges = [ + { + id: 'e7-9', + source: '7', + sourceHandle: 'output::2::DATA_FIELD', + target: '9', + targetHandle: 'input::field::DATA_FIELD', + }, + ]; + const defs = { + LoadFile: { + category: 'io', + input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } }, + output: ['DATA_FIELD'], + output_name: ['field'], + }, + }; + + const serialized = serializeWorkflowState(nodes, edges); + const hydrated = hydrateWorkflowState(serialized, defs); + + assert.deepEqual(hydrated.nodes[0].data.widgetValues, nodes[0].data.widgetValues); + assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD']); + assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Topography', 'Error', 'Mask']); + assert.deepEqual(hydrated.edges, edges); +}); diff --git a/tests/test_nodes.py b/tests/test_nodes.py index f81efa0..e22f062 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -564,31 +564,57 @@ def test_load_file(): def test_save_image(): - print("=== Test: SaveImage ===") + print("=== Test: SaveImage (Save Layers) ===") from backend.nodes.io import SaveImage node = SaveImage() + field_a = make_field(data=np.random.default_rng(4).random((32, 32))) + field_b = make_field(data=np.random.default_rng(5).random((32, 32))) + with tempfile.TemporaryDirectory() as tmpdir: - # Monkey-patch OUTPUT_DIR for testing - from pathlib import Path - import backend.nodes.io as io_mod - orig_dir = io_mod.OUTPUT_DIR - io_mod.OUTPUT_DIR = Path(tmpdir) + # Save single layer as TIFF + tiff_path = os.path.join(tmpdir, "out.tiff") + node.save(filename=tiff_path, format="TIFF", field_0=field_a) + assert os.path.exists(tiff_path), "TIFF file not created" + from PIL import Image + im = Image.open(tiff_path) + assert im.n_frames == 1 + arr_back = np.array(im) + assert arr_back.shape == (32, 32) + # Save multi-layer as TIFF + tiff_path2 = os.path.join(tmpdir, "multi.tiff") + node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b) + im2 = Image.open(tiff_path2) + assert im2.n_frames == 2 + + # Save as NPZ + npz_path = os.path.join(tmpdir, "out.npz") + node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b) + assert os.path.exists(npz_path) + npz = np.load(npz_path) + assert len(npz.files) == 2 + assert np.allclose(npz["layer_0"], field_a.data) + assert np.allclose(npz["layer_1"], field_b.data) + + # Extension is forced to match format + wrong_ext = os.path.join(tmpdir, "output.png") + node.save(filename=wrong_ext, format="TIFF", field_0=field_a) + assert os.path.exists(os.path.join(tmpdir, "output.tiff")) + + # No fields connected → error try: - arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8) + node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF") + assert False, "Should have raised ValueError" + except ValueError: + pass - # Save as PNG - node.save(image=arr, filename_prefix="test", format="PNG") - saved = os.listdir(tmpdir) - assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}" - - # Save as NPY - node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY") - saved = os.listdir(tmpdir) - assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}" - finally: - io_mod.OUTPUT_DIR = orig_dir + # No filename → error + try: + node.save(filename="", format="TIFF", field_0=field_a) + assert False, "Should have raised ValueError" + except ValueError: + pass print(" PASS\n") @@ -896,8 +922,11 @@ def test_line_cursors(): # Overlay should have been broadcast assert len(overlays) == 1 - assert "image" in overlays[0] - assert overlays[0]["image"].startswith("data:image/png;base64,") + assert overlays[0]["kind"] == "line_plot" + assert len(overlays[0]["line"]) == len(line) + assert len(overlays[0]["x_axis"]) == len(line) + assert 0.0 <= overlays[0]["x1"] <= 1.0 + assert 0.0 <= overlays[0]["x2"] <= 1.0 # With x_axis provided x_axis = np.linspace(0, 1, 100).astype(np.float64) diff --git a/tests/test_workflow_save.py b/tests/test_workflow_save.py new file mode 100644 index 0000000..e9796df --- /dev/null +++ b/tests/test_workflow_save.py @@ -0,0 +1,20 @@ +from pathlib import Path + +import pytest + +from backend.server import PNG_SIGNATURE, save_png_bytes + + +def test_save_png_bytes_writes_exact_png_payload(tmp_path: Path): + target = tmp_path / "workflow" + payload = PNG_SIGNATURE + b"argonode-test-payload" + + saved_path = save_png_bytes(str(target), payload) + + assert saved_path == tmp_path / "workflow.png" + assert saved_path.read_bytes() == payload + + +def test_save_png_bytes_rejects_invalid_payload(tmp_path: Path): + with pytest.raises(ValueError, match="valid PNG"): + save_png_bytes(str(tmp_path / "workflow.png"), b"not-a-png")