fix preview and save on native

This commit is contained in:
2026-03-24 22:52:24 -07:00
parent a60b0c15ca
commit 6959c62c8f
16 changed files with 875 additions and 202 deletions

View File

@@ -194,18 +194,17 @@ class ExecutionEngine:
CrossSection._broadcast_overlay_fn = on_overlay
LineCursors._broadcast_overlay_fn = on_overlay
LoadFile._broadcast_warning_fn = on_warning
SaveImage._broadcast_preview = (
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
)
SaveImage._broadcast_warning_fn = on_warning
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
"""Inform display nodes of their current node_id for WS tagging."""
from backend.nodes.display import PreviewImage, PrintTable, View3D
from backend.nodes.analysis import CrossSection, LineCursors
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.io import LoadFile
from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile):
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
LoadFile, SaveImage):
cls._current_node_id = node_id
def _auto_preview(
@@ -262,11 +261,9 @@ class ExecutionEngine:
cls: type,
slot: int,
result: tuple,
) -> str | None:
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
) -> dict | None:
"""Return structured LINE preview data for responsive frontend rendering."""
import numpy as np
import base64
import io as _io
return_types = getattr(cls, "RETURN_TYPES", ())
@@ -281,17 +278,22 @@ class ExecutionEngine:
return None # the first LINE already plotted both
try:
import base64
import io as _io
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
y = np.asarray(y, dtype=np.float64).ravel()
if x is None:
x = np.arange(len(y), dtype=np.float64)
else:
x = np.asarray(x, dtype=np.float64).ravel()[:len(y)]
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
if x is not None:
ax.plot(x, y, color="#ff9800", linewidth=1.2)
else:
ax.plot(y, color="#ff9800", linewidth=1.2)
ax.plot(x, y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
@@ -301,8 +303,15 @@ class ExecutionEngine:
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
plt.close(fig)
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
fallback_image = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
return {
"kind": "line_plot",
"line": y.tolist(),
"x_axis": x.tolist(),
"interactive": False,
"fallback_image": fallback_image,
}
except Exception:
return None

View File

@@ -47,6 +47,7 @@ def get_node_info(class_name: str) -> dict[str, Any]:
"output": list(cls.RETURN_TYPES),
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
"manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)),
"description": getattr(cls, "DESCRIPTION", ""),
}

View File

@@ -11,7 +11,7 @@ Gwyddion equivalents:
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
from backend.data_types import DataField, datafield_to_uint8, encode_preview
# ---------------------------------------------------------------------------
@@ -131,88 +131,49 @@ class LineCursors:
self, line, x1: float, y1: float, x2: float, y2: float,
x_axis=None,
) -> tuple:
import io as _io
import base64
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
y = np.asarray(line, dtype=np.float64).ravel()
n = len(y)
if x_axis is not None:
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
else:
x = np.arange(n, dtype=np.float64)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
# --- Render the base plot first to determine axes bounds ---
fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100)
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
ax.plot(x, y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
fig.tight_layout(pad=0.4)
xmin = float(np.min(x)) if len(x) else 0.0
xmax = float(np.max(x)) if len(x) else 1.0
# Force a draw so transforms are valid
fig.canvas.draw()
def x_frac_to_idx(frac):
if n <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(x - target_x)))
# Get axes position in figure-fraction coordinates
ax_pos = ax.get_position()
ax_l, ax_b = ax_pos.x0, ax_pos.y0
ax_w, ax_h = ax_pos.width, ax_pos.height
# x1/y1 arrive as image-fraction from the frontend drag.
# Convert image-fraction x → axes-fraction → nearest data index.
def img_x_to_idx(ix):
axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
idx_a = img_x_to_idx(x1)
idx_b = img_x_to_idx(x2)
idx_a = x_frac_to_idx(x1)
idx_b = x_frac_to_idx(x2)
xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b])
# --- Draw cursor lines and markers on the plot ---
ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
ax.annotate(
"", xy=(xb, yb), xytext=(xa, ya),
arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
)
# --- Broadcast overlay ---
if LineCursors._broadcast_overlay_fn is not None:
# Convert data-space positions back to image-fraction for markers
fig.canvas.draw()
inv = fig.transFigure.inverted()
fig_a = inv.transform(ax.transData.transform([xa, ya]))
fig_b = inv.transform(ax.transData.transform([xb, yb]))
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
LineCursors._broadcast_overlay_fn(
LineCursors._current_node_id,
{
"image": image_uri,
"x1": float(fig_a[0]),
"y1": float(1.0 - fig_a[1]), # flip: image y=0 is top
"x2": float(fig_b[0]),
"y2": float(1.0 - fig_b[1]),
"kind": "line_plot",
"line": y.tolist(),
"x_axis": x.tolist(),
"x1": x1,
"x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": False,
"b_locked": False,
},
)
plt.close(fig)
# --- Output table ---
table = [
{"quantity": "A position", "value": xa, "unit": ""},
@@ -414,8 +375,6 @@ class CrossSection:
point_a=None, point_b=None,
) -> tuple:
from scipy.ndimage import map_coordinates
import io, base64
from matplotlib.figure import Figure
# COORD inputs override widget values
if point_a is not None:
@@ -453,14 +412,9 @@ class CrossSection:
# Broadcast overlay image with marker positions
if CrossSection._broadcast_overlay_fn is not None:
fig = Figure(figsize=(3, 3), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])
ax.imshow(field.data, cmap="viridis", aspect="auto")
ax.axis("off")
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
# Use the field's native pixel grid for the overlay preview so enlarging
# the panel keeps the image as sharp as the source data allows.
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,

View File

@@ -78,7 +78,7 @@ class View3D:
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS),),
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
}
}
@@ -134,7 +134,7 @@ class View3D:
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale),
"z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}

View File

@@ -399,54 +399,86 @@ class Coordinate:
# SaveImage
# ---------------------------------------------------------------------------
@register_node(display_name="Save Image")
_MAX_SAVE_FIELDS = 8
@register_node(display_name="Save Layers")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("DATA_FIELD",)
return {
"required": {
"image": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "output"}),
"format": (["PNG", "TIFF", "NPY"],),
}
"filename": ("FILE_PICKER", {"default": ""}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "io"
OUTPUT_NODE = True
DESCRIPTION = "Save an image or array to the output folder."
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more DATA_FIELD layers to a single file. "
"Connect fields to the inputs — a new slot appears as each is filled. "
"TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
"Click Save to write (does not auto-run)."
)
# Injected by server.py before execution begins
_broadcast_preview = None
_broadcast_warning_fn = None
_current_node_id = None
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
OUTPUT_DIR.mkdir(exist_ok=True)
def save(self, filename: str, format: str = "TIFF", **kwargs):
# Collect connected fields in order
fields = []
for i in range(_MAX_SAVE_FIELDS):
f = kwargs.get(f"field_{i}")
if f is not None:
fields.append(f)
# Find next available filename
idx = 1
while True:
name = f"{filename_prefix}_{idx:04d}"
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
if not candidate.exists():
break
idx += 1
if not fields:
raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
if format == "NPY":
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
if not filename or not filename.strip():
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(filename)
# Ensure parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)
# Force correct extension
ext = ".tiff" if format == "TIFF" else ".npz"
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
if format == "TIFF":
self._save_tiff(path, fields)
else:
from PIL import Image
arr = image_to_uint8(image)
if arr.ndim == 2:
pil_img = Image.fromarray(arr, mode="L")
else:
pil_img = Image.fromarray(arr, mode="RGB")
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
self._save_npz(path, fields)
# Emit preview over WebSocket if callback is set
if SaveImage._broadcast_preview is not None:
arr_u8 = image_to_uint8(image)
data_uri = encode_preview(arr_u8)
SaveImage._broadcast_preview(data_uri)
self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, fields: list[DataField]):
from PIL import Image
images = []
for f in fields:
images.append(Image.fromarray(f.data.astype(np.float32)))
images[0].save(str(path), save_all=True, append_images=images[1:])
def _save_npz(self, path: Path, fields: list[DataField]):
arrays = {}
for i, f in enumerate(fields):
arrays[f"layer_{i}"] = f.data
np.savez(str(path), **arrays)
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return ()

View File

@@ -41,6 +41,7 @@ FRONTEND_DIR = frontend_dir()
DIST_DIR = frontend_dist_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# ---------------------------------------------------------------------------
@@ -63,6 +64,18 @@ def _dumps(obj) -> str:
return json.dumps(obj, cls=_SafeEncoder)
def save_png_bytes(target_path: str, payload: bytes) -> Path:
path = Path(target_path).expanduser()
if not target_path.strip():
raise ValueError("Missing save path")
if path.suffix.lower() != ".png":
path = path.with_suffix(".png")
if not payload.startswith(PNG_SIGNATURE):
raise ValueError("Payload is not a valid PNG")
path.write_bytes(payload)
return path
# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------
@@ -196,6 +209,20 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
},
)
async def save_workflow_png(request: web.Request) -> web.Response:
body = await request.read()
target_path = request.query.get("path", "")
if not target_path:
raise web.HTTPBadRequest(reason="Missing path")
try:
saved_path = save_png_bytes(target_path, body)
except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) from exc
return web.Response(
text=_dumps({"path": str(saved_path)}),
content_type="application/json",
)
async def get_channels(request: web.Request) -> web.Response:
"""Return available channels for a given file path."""
from backend.nodes.io import list_channels
@@ -278,6 +305,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
app.router.add_get("/browse", browse_dir)
app.router.add_post("/upload", upload_file)
app.router.add_post("/download", download_file)
app.router.add_post("/save-workflow-png", save_workflow_png)
app.router.add_get("/channels", get_channels)
app.router.add_post("/prompt", submit_prompt)
app.router.add_get("/ws", websocket_handler)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import base64
import logging
import socket
import threading
@@ -45,8 +44,8 @@ class _Api:
return result[0]
return None
def save_workflow_png(self, data_url: str, default_filename: str = "workflow.png") -> str | None:
"""Open a native save dialog, write the PNG bytes, and return the saved path."""
def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None:
"""Open a native save dialog and return the chosen PNG path (or None)."""
win = self._window_ref[0]
if win is None:
return None
@@ -65,12 +64,6 @@ class _Api:
path = Path(result[0] if isinstance(result, (list, tuple)) else result).expanduser()
if path.suffix.lower() != ".png":
path = path.with_suffix(".png")
_, _, encoded = data_url.partition(",")
if not encoded:
raise ValueError("Invalid data URL payload")
path.write_bytes(base64.b64decode(encoded))
return str(path)

View File

@@ -13,6 +13,7 @@ import FileBrowser from './FileBrowser';
import * as api from './api';
import { toBlob } from 'html-to-image';
import { embedWorkflow, extractWorkflow } from './pngMetadata';
import { hydrateWorkflowState } from './workflowHydration';
import { serializeWorkflowState } from './workflowSerialization';
// ── Constants ─────────────────────────────────────────────────────────
@@ -43,15 +44,6 @@ function getOutputSlot(handleId) {
return parseInt(handleId.split('::')[1], 10);
}
function blobToDataUrl(blob) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onloadend = () => resolve(reader.result);
reader.onerror = () => reject(reader.error || new Error('Failed to read file'));
reader.readAsDataURL(blob);
});
}
async function waitForImageElement(img) {
if (img.complete && img.naturalWidth > 0) return;
if (typeof img.decode === 'function') {
@@ -73,6 +65,31 @@ async function waitForImageElement(img) {
});
}
async function getCaptureImageDataUrl(img) {
const src = img.currentSrc || img.src;
if (!src) return null;
if (!src.startsWith('data:')) return src;
const rect = img.getBoundingClientRect();
const width = Math.max(1, Math.round(img.clientWidth || rect.width));
const height = Math.max(1, Math.round(img.clientHeight || rect.height));
const scale = Math.min(2, window.devicePixelRatio || 1);
const canvas = document.createElement('canvas');
canvas.width = Math.max(1, Math.round(width * scale));
canvas.height = Math.max(1, Math.round(height * scale));
const ctx = canvas.getContext('2d');
if (!ctx) return src;
try {
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
return canvas.toDataURL('image/png');
} catch {
return src;
}
}
function createCapturePlaceholder(el, dataUrl) {
const rect = el.getBoundingClientRect();
const style = window.getComputedStyle(el);
@@ -101,8 +118,9 @@ async function captureViewportBlob(viewportEl, options) {
await Promise.all(images.map(waitForImageElement));
for (const img of images) {
const dataUrl = img.currentSrc || img.src;
if (!dataUrl || !img.parentNode) continue;
if (!img.parentNode) continue;
const dataUrl = await getCaptureImageDataUrl(img);
if (!dataUrl) continue;
const placeholder = createCapturePlaceholder(img, dataUrl);
img.parentNode.replaceChild(placeholder, img);
restorers.push(() => {
@@ -144,12 +162,13 @@ async function captureViewportBlob(viewportEl, options) {
// ── Graph serialisation → backend prompt format ───────────────────────
function serializeGraph(nodes, edges) {
function serializeGraph(nodes, edges, { excludeManualTrigger = false } = {}) {
const prompt = {};
for (const node of nodes) {
const { className, definition, widgetValues } = node.data;
if (!definition) continue;
if (excludeManualTrigger && definition.manual_trigger) continue;
const inputs = {};
@@ -551,10 +570,23 @@ function Flow() {
// ── Node context value (stable) ─────────────────────────────────────
const onManualTrigger = useCallback((nodeId) => {
const currentNodes = reactFlow.getNodes();
const currentEdges = reactFlow.getEdges();
// Include ALL nodes (no excludeManualTrigger) so the save node is in the prompt
const prompt = serializeGraph(currentNodes, currentEdges);
if (!prompt || Object.keys(prompt).length === 0) return;
setStatus({ text: 'Saving…', level: 'info' });
api.runPrompt(prompt).catch((err) => {
setStatus({ text: 'Save failed: ' + err.message, level: 'error' });
});
}, [reactFlow]);
const contextValue = useMemo(() => ({
onWidgetChange,
openFileBrowser,
}), [onWidgetChange, openFileBrowser]);
onManualTrigger,
}), [onWidgetChange, openFileBrowser, onManualTrigger]);
// ── Add node from context menu ──────────────────────────────────────
@@ -687,13 +719,18 @@ function Flow() {
const currentNodes = reactFlow.getNodes();
const currentEdges = reactFlow.getEdges();
// Don't run if any node has unconnected required data inputs
// Don't run if any non-manual node has unconnected required data inputs
// or any FILE_PICKER widget is empty
for (const node of currentNodes) {
const def = node.data?.definition;
if (!def) continue;
if (!def || def.manual_trigger) continue; // skip manual-trigger nodes
const required = def.input.required || {};
for (const [name, spec] of Object.entries(required)) {
const [type] = Array.isArray(spec) ? spec : [spec];
if (type === 'FILE_PICKER') {
if (!node.data.widgetValues?.[name]) return; // no file selected, skip
continue;
}
if (!DATA_TYPES.has(type)) continue;
const hasEdge = currentEdges.some(
(e) => e.target === node.id && getInputName(e.targetHandle) === name
@@ -702,7 +739,7 @@ function Flow() {
}
}
const prompt = serializeGraph(currentNodes, currentEdges);
const prompt = serializeGraph(currentNodes, currentEdges, { excludeManualTrigger: true });
if (!prompt || Object.keys(prompt).length === 0) return;
setStatus({ text: 'Running…', level: 'info' });
api.runPrompt(prompt).catch((err) => {
@@ -723,25 +760,10 @@ function Flow() {
}, [setNodes, setEdges]);
const applyWorkflowData = useCallback((data) => {
const loadedNodes = data.nodes || [];
const loadedEdges = data.edges || [];
const defs = nodeDefsRef.current;
const hydrated = loadedNodes.map((n) => ({
...n,
type: n.type || 'custom',
dragHandle: n.dragHandle || '.drag-handle',
data: {
...n.data,
label: n.data?.label || n.data?.className || 'Node',
widgetValues: n.data?.widgetValues || {},
definition: defs[n.data.className] || n.data.definition,
previewImage: null, tableRows: null, meshData: null, overlay: null,
},
}));
setNodes(hydrated);
setEdges(loadedEdges);
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
nextIdRef.current = maxId + 1;
const hydrated = hydrateWorkflowState(data, nodeDefsRef.current);
setNodes(hydrated.nodes);
setEdges(hydrated.edges);
nextIdRef.current = hydrated.nextNodeId;
}, [setNodes, setEdges]);
const getWorkflowBlob = useCallback(async () => {
@@ -778,9 +800,23 @@ function Flow() {
try {
const finalBlob = await getWorkflowBlob();
if (window.pywebview?.api?.save_workflow_png) {
const dataUrl = await blobToDataUrl(finalBlob);
const savedPath = await window.pywebview.api.save_workflow_png(dataUrl, 'workflow.png');
if (window.pywebview?.api?.choose_save_workflow_png_path) {
const requestedPath = await window.pywebview.api.choose_save_workflow_png_path('workflow.png');
if (!requestedPath) {
setStatus({ text: 'Save cancelled.', level: 'info' });
return;
}
const resp = await fetch(`/save-workflow-png?path=${encodeURIComponent(requestedPath)}`, {
method: 'POST',
headers: {
'Content-Type': 'image/png',
},
body: finalBlob,
});
if (!resp.ok) {
throw new Error(await resp.text() || `Save failed (${resp.status})`);
}
const { path: savedPath } = await resp.json();
if (!savedPath) {
setStatus({ text: 'Save cancelled.', level: 'info' });
return;

View File

@@ -1,5 +1,6 @@
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
import { Handle, Position } from '@xyflow/react';
import { Handle, Position, useStore } from '@xyflow/react';
import LinePlotOverlay from './LinePlotOverlay';
const SurfaceView = lazy(() => import('./SurfaceView'));
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
@@ -29,6 +30,47 @@ const CAT_COLORS = {
export const NodeContext = React.createContext(null);
class PreviewBoundary extends React.Component {
constructor(props) {
super(props);
this.state = { hasError: false };
}
static getDerivedStateFromError() {
return { hasError: true };
}
componentDidCatch(error) {
console.error('[argonode] preview render failed', error);
}
componentDidUpdate(prevProps) {
if (prevProps.resetKey !== this.props.resetKey && this.state.hasError) {
this.setState({ hasError: false });
}
}
render() {
if (!this.state.hasError) {
return this.props.children;
}
if (this.props.fallbackImage) {
return (
<div className="node-preview">
<img src={this.props.fallbackImage} alt="preview fallback" draggable={false} />
</div>
);
}
return (
<div className="node-preview" style={{ color: '#94a3b8', padding: 8 }}>
Preview unavailable.
</div>
);
}
}
// ── Draggable number input ────────────────────────────────────────────
function DraggableNumber({ value, step, min, max, precision, onChange }) {
@@ -151,8 +193,39 @@ function CustomNode({ id, data }) {
}
}
// For manual-trigger nodes (Save), show progressive optional inputs:
// show field_N only if field_(N-1) is connected (or N==0).
const isProgressive = def.manual_trigger;
const connectedInputs = useStore(
useCallback(
(s) => {
if (!isProgressive) return null;
const set = new Set();
for (const e of s.edges) {
if (e.target === id) {
const parts = e.targetHandle?.split('::');
if (parts) set.add(parts[1]);
}
}
return set;
},
[id, isProgressive],
),
);
for (const [name, spec] of Object.entries(optional)) {
const [type] = Array.isArray(spec) ? spec : [spec];
if (isProgressive && DATA_TYPES.has(type)) {
// Progressive: show this slot only if it's the first or the previous is connected
const match = name.match(/^field_(\d+)$/);
if (match) {
const idx = parseInt(match[1], 10);
if (idx === 0 || (connectedInputs && connectedInputs.has(`field_${idx - 1}`))) {
dataInputs.push({ name, type });
}
continue;
}
}
dataInputs.push({ name, type });
}
@@ -229,6 +302,19 @@ function CustomNode({ id, data }) {
</div>
))}
{/* Manual trigger button (Save) */}
{def.manual_trigger && (
<div className="widget-row">
<button
className="nodrag btn btn-primary"
style={{ flex: 1 }}
onClick={() => ctx.onManualTrigger?.(id)}
>
Save to Disk
</button>
</div>
)}
{/* Interactive 3D surface view */}
{data.meshData && (
<CollapsibleSection title="3D View" defaultOpen={true}>
@@ -241,9 +327,21 @@ function CustomNode({ id, data }) {
{/* Collapsible preview image */}
{data.previewImage && (
<CollapsibleSection title="Preview" defaultOpen={true}>
<div className="node-preview">
<img src={data.previewImage} alt="preview" draggable={false} />
</div>
<PreviewBoundary
resetKey={typeof data.previewImage === 'string' ? data.previewImage : JSON.stringify({
kind: data.previewImage.kind,
len: data.previewImage.line?.length,
})}
fallbackImage={typeof data.previewImage === 'object' ? data.previewImage.fallback_image : null}
>
{typeof data.previewImage === 'string' ? (
<div className="node-preview">
<img src={data.previewImage} alt="preview" draggable={false} />
</div>
) : data.previewImage.kind === 'line_plot' ? (
<LinePlotOverlay overlay={data.previewImage} interactive={false} />
) : null}
</PreviewBoundary>
</CollapsibleSection>
)}
@@ -251,17 +349,29 @@ function CustomNode({ id, data }) {
{data.overlay && hiddenWidgets.has('x1') && (
<CollapsibleSection title="Cross Section" defaultOpen={true}>
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
<CrossSectionOverlay
image={data.overlay.image}
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
aLocked={data.overlay.a_locked}
bLocked={data.overlay.b_locked}
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
{data.overlay.kind === 'line_plot' ? (
<LinePlotOverlay
overlay={data.overlay}
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
aLocked={data.overlay.a_locked}
bLocked={data.overlay.b_locked}
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
) : (
<CrossSectionOverlay
image={data.overlay.image}
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
aLocked={data.overlay.a_locked}
bLocked={data.overlay.b_locked}
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
)}
</Suspense>
</CollapsibleSection>
)}

View File

@@ -0,0 +1,271 @@
import React, { useEffect, useRef, useState, useCallback } from 'react';
const ASPECT_RATIO = 3.2 / 2.2;
const MARGINS = { top: 18, right: 16, bottom: 34, left: 56 };
function clamp(v, min, max) {
return Math.max(min, Math.min(max, v));
}
function round3(v) {
return parseFloat(v.toFixed(3));
}
function trimZeros(text) {
return text.replace(/(?:\.0+|(\.\d+?)0+)$/, '$1');
}
function formatTick(value) {
const abs = Math.abs(value);
if (abs === 0) return '0';
if (abs >= 1e4 || abs < 1e-3) {
return value.toExponential(1).replace('e+', 'e');
}
if (abs >= 100) return trimZeros(value.toFixed(0));
if (abs >= 10) return trimZeros(value.toFixed(1));
if (abs >= 1) return trimZeros(value.toFixed(2));
return trimZeros(value.toFixed(3));
}
function makeTicks(min, max, count = 5) {
if (!Number.isFinite(min) || !Number.isFinite(max)) return [];
if (min === max) return [min];
const ticks = [];
for (let i = 0; i < count; i += 1) {
ticks.push(min + ((max - min) * i) / (count - 1));
}
return ticks;
}
function getExtent(values, fallbackMin = 0, fallbackMax = 1) {
if (!Array.isArray(values) || values.length === 0) {
return [fallbackMin, fallbackMax];
}
let min = Infinity;
let max = -Infinity;
for (const value of values) {
if (!Number.isFinite(value)) continue;
if (value < min) min = value;
if (value > max) max = value;
}
if (!Number.isFinite(min) || !Number.isFinite(max)) {
return [fallbackMin, fallbackMax];
}
return [min, max];
}
export default function LinePlotOverlay({
overlay,
x1,
x2,
aLocked,
bLocked,
nodeId,
onWidgetChange,
interactive = true,
}) {
const containerRef = useRef(null);
const [dragging, setDragging] = useState(null);
const [size, setSize] = useState({ width: 0, height: 0 });
useEffect(() => {
if (!containerRef.current) return undefined;
const updateSize = () => {
if (!containerRef.current) return;
setSize({
width: Math.max(1, Math.round(containerRef.current.clientWidth || 320)),
height: Math.max(1, Math.round(containerRef.current.clientHeight || (containerRef.current.clientWidth / ASPECT_RATIO) || 220)),
});
};
updateSize();
if (typeof ResizeObserver === 'function') {
const observer = new ResizeObserver((entries) => {
const entry = entries[0];
if (!entry) return;
const { width, height } = entry.contentRect;
setSize({
width: Math.max(1, Math.round(width)),
height: Math.max(1, Math.round(height)),
});
});
observer.observe(containerRef.current);
return () => observer.disconnect();
}
window.addEventListener('resize', updateSize);
return () => window.removeEventListener('resize', updateSize);
}, []);
const xValues = Array.isArray(overlay?.x_axis) && overlay.x_axis.length === overlay.line?.length
? overlay.x_axis
: overlay?.line?.map((_, i) => i) || [];
const yValues = Array.isArray(overlay?.line) ? overlay.line : [];
const width = size.width || 320;
const height = size.height || Math.round(width / ASPECT_RATIO);
const plotLeft = MARGINS.left;
const plotTop = MARGINS.top;
const plotWidth = Math.max(1, width - MARGINS.left - MARGINS.right);
const plotHeight = Math.max(1, height - MARGINS.top - MARGINS.bottom);
const [xMin, xMax] = getExtent(xValues, 0, 1);
const [yMinRaw, yMaxRaw] = getExtent(yValues, 0, 1);
const yPad = yMinRaw === yMaxRaw ? 1 : (yMaxRaw - yMinRaw) * 0.08;
const yMin = yMinRaw - yPad;
const yMax = yMaxRaw + yPad;
const scaleX = useCallback((value) => {
if (xMax === xMin) return plotLeft + plotWidth / 2;
return plotLeft + ((value - xMin) / (xMax - xMin)) * plotWidth;
}, [plotLeft, plotWidth, xMin, xMax]);
const scaleY = useCallback((value) => {
if (yMax === yMin) return plotTop + plotHeight / 2;
return plotTop + (1 - ((value - yMin) / (yMax - yMin))) * plotHeight;
}, [plotTop, plotHeight, yMin, yMax]);
const pickCursorPoint = useCallback((fraction) => {
if (!xValues.length || !yValues.length) {
return {
x: plotLeft,
y: plotTop + plotHeight / 2,
yFraction: 0.5,
};
}
const frac = clamp(fraction ?? 0.5, 0, 1);
const targetX = xMin + frac * (xMax - xMin || 1);
let idx = 0;
let best = Infinity;
for (let i = 0; i < xValues.length; i += 1) {
const dist = Math.abs(xValues[i] - targetX);
if (dist < best) {
best = dist;
idx = i;
}
}
const x = xValues[idx];
const y = yValues[idx];
const yFraction = yMax === yMin ? 0.5 : clamp((y - yMin) / (yMax - yMin), 0, 1);
return {
x: scaleX(x),
y: scaleY(y),
yFraction,
};
}, [plotLeft, plotTop, plotHeight, scaleX, scaleY, xValues, yValues, xMin, xMax, yMin, yMax]);
const cursorA = pickCursorPoint(x1 ?? overlay?.x1 ?? 0.25);
const cursorB = pickCursorPoint(x2 ?? overlay?.x2 ?? 0.75);
const path = yValues.map((y, i) => `${i === 0 ? 'M' : 'L'} ${scaleX(xValues[i])} ${scaleY(y)}`).join(' ');
const xTicks = makeTicks(xMin, xMax);
const yTicks = makeTicks(yMin, yMax);
const plotStroke = clamp(plotWidth / 240, 1.4, 2.6);
const gridStroke = clamp(plotWidth / 900, 0.6, 1.1);
const cursorStroke = clamp(plotWidth / 220, 1.4, 2.2);
const measureStroke = clamp(plotWidth / 180, 1.6, 2.8);
const markerRadius = clamp(plotWidth / 42, 5.5, 9);
const updateCursor = useCallback((point, event) => {
if (!interactive || !onWidgetChange || !nodeId) return;
if (!containerRef.current) return;
const rect = containerRef.current.getBoundingClientRect();
const xFrac = clamp((event.clientX - rect.left - plotLeft) / plotWidth, 0, 1);
const sample = pickCursorPoint(xFrac);
if (point === 'p1') {
onWidgetChange(nodeId, 'x1', round3(xFrac));
onWidgetChange(nodeId, 'y1', round3(sample.yFraction));
} else {
onWidgetChange(nodeId, 'x2', round3(xFrac));
onWidgetChange(nodeId, 'y2', round3(sample.yFraction));
}
}, [interactive, nodeId, onWidgetChange, pickCursorPoint, plotLeft, plotWidth]);
const onPointerDown = useCallback((point) => (event) => {
if (!interactive) return;
if ((point === 'p1' && aLocked) || (point === 'p2' && bLocked)) return;
event.preventDefault();
event.stopPropagation();
event.currentTarget.setPointerCapture(event.pointerId);
setDragging(point);
}, [interactive, aLocked, bLocked]);
const onPointerMove = useCallback((event) => {
if (!dragging) return;
updateCursor(dragging, event);
}, [dragging, updateCursor]);
const onPointerUp = useCallback(() => {
setDragging(null);
}, []);
return (
<div
ref={containerRef}
className="nodrag nowheel lineplot-overlay"
onPointerMove={onPointerMove}
onPointerUp={onPointerUp}
onLostPointerCapture={onPointerUp}
>
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} className="lineplot-svg">
<rect x="0" y="0" width={width} height={height} fill="#0f172a" />
{xTicks.map((tick) => {
const x = scaleX(tick);
return (
<g key={`x-${tick}`}>
<line x1={x} y1={plotTop} x2={x} y2={plotTop + plotHeight} stroke="#334155" strokeWidth={gridStroke} opacity="0.45" />
<text x={x} y={height - 10} textAnchor="middle" fontSize="11" fill="#94a3b8">
{formatTick(tick)}
</text>
</g>
);
})}
{yTicks.map((tick) => {
const y = scaleY(tick);
return (
<g key={`y-${tick}`}>
<line x1={plotLeft} y1={y} x2={plotLeft + plotWidth} y2={y} stroke="#334155" strokeWidth={gridStroke} opacity="0.45" />
<text x={plotLeft - 10} y={y + 4} textAnchor="end" fontSize="11" fill="#94a3b8">
{formatTick(tick)}
</text>
</g>
);
})}
<rect x={plotLeft} y={plotTop} width={plotWidth} height={plotHeight} fill="none" stroke="#334155" strokeWidth={gridStroke + 0.3} />
<path d={path} fill="none" stroke="#ff9800" strokeWidth={plotStroke} strokeLinecap="round" strokeLinejoin="round" />
{interactive && (
<>
<line x1={cursorA.x} y1={plotTop} x2={cursorA.x} y2={plotTop + plotHeight} stroke="#ffd700" strokeWidth={cursorStroke} strokeDasharray="10 6" opacity="0.95" />
<line x1={cursorB.x} y1={plotTop} x2={cursorB.x} y2={plotTop + plotHeight} stroke="#ffd700" strokeWidth={cursorStroke} strokeDasharray="10 6" opacity="0.95" />
<line x1={cursorA.x} y1={cursorA.y} x2={cursorB.x} y2={cursorB.y} stroke="#90caf9" strokeWidth={measureStroke} opacity="0.95" />
<circle
cx={cursorA.x}
cy={cursorA.y}
r={markerRadius}
className={`lineplot-marker ${aLocked ? 'lineplot-marker-locked' : ''}`}
onPointerDown={onPointerDown('p1')}
/>
<circle
cx={cursorB.x}
cy={cursorB.y}
r={markerRadius}
className={`lineplot-marker ${bLocked ? 'lineplot-marker-locked' : ''}`}
onPointerDown={onPointerDown('p2')}
/>
</>
)}
</svg>
</div>
);
}

View File

@@ -367,6 +367,41 @@ html, body, #root {
opacity: 0.9;
}
.lineplot-overlay {
width: 100%;
aspect-ratio: 32 / 22;
background: #0f172a;
border: 1px solid #334155;
border-radius: 6px;
overflow: hidden;
user-select: none;
touch-action: none;
}
.lineplot-svg {
display: block;
width: 100%;
height: 100%;
}
.lineplot-marker {
fill: #ffd700;
stroke: #fff;
stroke-width: 2px;
cursor: grab;
filter: drop-shadow(0 0 4px rgba(0, 0, 0, 0.45));
}
.lineplot-marker:active {
cursor: grabbing;
}
.lineplot-marker-locked {
fill: #e91e63;
stroke: #e91e63;
cursor: default;
}
/* ── 3D surface view ──────────────────────────────────────────────── */
.surface-view-container {
width: 100%;

View File

@@ -0,0 +1,52 @@
function mergeDefinition(nodeData, defs) {
const savedData = nodeData || {};
const savedDefinition = savedData.definition && typeof savedData.definition === 'object'
? savedData.definition
: null;
const registryDefinition = savedData.className ? defs[savedData.className] : null;
const definition = registryDefinition || savedDefinition;
if (!definition) return null;
const output = Array.isArray(savedData.output)
? savedData.output
: (Array.isArray(savedDefinition?.output) ? savedDefinition.output : null);
const outputName = Array.isArray(savedData.output_name)
? savedData.output_name
: (Array.isArray(savedDefinition?.output_name) ? savedDefinition.output_name : null);
return {
...definition,
...(output ? { output } : {}),
...(outputName ? { output_name: outputName } : {}),
};
}
export function hydrateWorkflowState(data, defs = {}) {
const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : [];
const loadedEdges = Array.isArray(data?.edges) ? data.edges : [];
const nodes = loadedNodes.map((node) => ({
...node,
type: node.type || 'custom',
dragHandle: node.dragHandle || '.drag-handle',
data: {
...node.data,
label: node.data?.label || node.data?.className || 'Node',
widgetValues: node.data?.widgetValues || {},
definition: mergeDefinition(node.data, defs),
previewImage: null,
tableRows: null,
meshData: null,
overlay: null,
},
}));
const nextNodeId = Math.max(0, ...loadedNodes.map((node) => parseInt(node.id, 10) || 0)) + 1;
return {
nodes,
edges: loadedEdges,
nextNodeId,
};
}

View File

@@ -10,6 +10,8 @@ export function serializeWorkflowState(nodes, edges) {
label: node.data?.label || node.data?.className || 'Node',
className: node.data?.className || '',
widgetValues: node.data?.widgetValues || {},
output: node.data?.definition?.output || [],
output_name: node.data?.definition?.output_name || [],
},
})),
edges: edges.map((edge) => ({
@@ -18,7 +20,7 @@ export function serializeWorkflowState(nodes, edges) {
sourceHandle: edge.sourceHandle,
target: edge.target,
targetHandle: edge.targetHandle,
style: edge.style,
...(edge.style ? { style: edge.style } : {}),
})),
};
}

View File

@@ -1,6 +1,7 @@
import test from 'node:test';
import assert from 'node:assert/strict';
import { hydrateWorkflowState } from '../src/workflowHydration.js';
import { serializeWorkflowState } from '../src/workflowSerialization.js';
test('serializeWorkflowState keeps only stable workflow fields needed for reload', () => {
@@ -59,6 +60,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
label: 'Demo Label',
className: 'DemoNode',
widgetValues: { threshold: 0.42, mode: 'fast' },
output: [],
output_name: [],
},
},
{
@@ -70,6 +73,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
label: 'NoLabelNode',
className: 'NoLabelNode',
widgetValues: {},
output: [],
output_name: [],
},
},
],
@@ -89,3 +94,99 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
assert.equal('previewImage' in serialized.nodes[0].data, false);
assert.equal('selected' in serialized.edges[0], false);
});
test('hydrateWorkflowState restores saved dynamic outputs on top of current node definitions', () => {
const saved = {
version: 1,
nodes: [
{
id: '12',
position: { x: 40, y: 80 },
data: {
className: 'LoadFile',
widgetValues: { filename: 'scan.ibw', colormap: 'viridis' },
output: ['DATA_FIELD', 'DATA_FIELD'],
output_name: ['Height', 'Phase'],
previewImage: 'stale',
},
},
],
edges: [
{
id: 'e12-3',
source: '12',
sourceHandle: 'output::1::DATA_FIELD',
target: '3',
targetHandle: 'input::field::DATA_FIELD',
},
],
};
const defs = {
LoadFile: {
category: 'io',
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['viridis', 'gray'], {}] } },
output: ['DATA_FIELD'],
output_name: ['field'],
manual_trigger: false,
},
};
const hydrated = hydrateWorkflowState(saved, defs);
assert.equal(hydrated.nextNodeId, 13);
assert.deepEqual(hydrated.edges, saved.edges);
assert.equal(hydrated.nodes[0].type, 'custom');
assert.equal(hydrated.nodes[0].dragHandle, '.drag-handle');
assert.equal(hydrated.nodes[0].data.label, 'LoadFile');
assert.equal(hydrated.nodes[0].data.previewImage, null);
assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']);
assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']);
assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.LoadFile.input);
});
test('serializeWorkflowState and hydrateWorkflowState preserve reload-critical metadata for dynamic nodes', () => {
const nodes = [
{
id: '7',
position: { x: 10, y: 20 },
data: {
label: 'Load File',
className: 'LoadFile',
widgetValues: { filename: 'scan.gwy', colormap: 'gray' },
definition: {
category: 'io',
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } },
output: ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD'],
output_name: ['Topography', 'Error', 'Mask'],
},
previewImage: 'data:image/png;base64,stale',
},
},
];
const edges = [
{
id: 'e7-9',
source: '7',
sourceHandle: 'output::2::DATA_FIELD',
target: '9',
targetHandle: 'input::field::DATA_FIELD',
},
];
const defs = {
LoadFile: {
category: 'io',
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } },
output: ['DATA_FIELD'],
output_name: ['field'],
},
};
const serialized = serializeWorkflowState(nodes, edges);
const hydrated = hydrateWorkflowState(serialized, defs);
assert.deepEqual(hydrated.nodes[0].data.widgetValues, nodes[0].data.widgetValues);
assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD']);
assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Topography', 'Error', 'Mask']);
assert.deepEqual(hydrated.edges, edges);
});

View File

@@ -564,31 +564,57 @@ def test_load_file():
def test_save_image():
print("=== Test: SaveImage ===")
print("=== Test: SaveImage (Save Layers) ===")
from backend.nodes.io import SaveImage
node = SaveImage()
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
with tempfile.TemporaryDirectory() as tmpdir:
# Monkey-patch OUTPUT_DIR for testing
from pathlib import Path
import backend.nodes.io as io_mod
orig_dir = io_mod.OUTPUT_DIR
io_mod.OUTPUT_DIR = Path(tmpdir)
# Save single layer as TIFF
tiff_path = os.path.join(tmpdir, "out.tiff")
node.save(filename=tiff_path, format="TIFF", field_0=field_a)
assert os.path.exists(tiff_path), "TIFF file not created"
from PIL import Image
im = Image.open(tiff_path)
assert im.n_frames == 1
arr_back = np.array(im)
assert arr_back.shape == (32, 32)
# Save multi-layer as TIFF
tiff_path2 = os.path.join(tmpdir, "multi.tiff")
node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b)
im2 = Image.open(tiff_path2)
assert im2.n_frames == 2
# Save as NPZ
npz_path = os.path.join(tmpdir, "out.npz")
node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b)
assert os.path.exists(npz_path)
npz = np.load(npz_path)
assert len(npz.files) == 2
assert np.allclose(npz["layer_0"], field_a.data)
assert np.allclose(npz["layer_1"], field_b.data)
# Extension is forced to match format
wrong_ext = os.path.join(tmpdir, "output.png")
node.save(filename=wrong_ext, format="TIFF", field_0=field_a)
assert os.path.exists(os.path.join(tmpdir, "output.tiff"))
# No fields connected → error
try:
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)
node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF")
assert False, "Should have raised ValueError"
except ValueError:
pass
# Save as PNG
node.save(image=arr, filename_prefix="test", format="PNG")
saved = os.listdir(tmpdir)
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}"
# Save as NPY
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
saved = os.listdir(tmpdir)
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
finally:
io_mod.OUTPUT_DIR = orig_dir
# No filename → error
try:
node.save(filename="", format="TIFF", field_0=field_a)
assert False, "Should have raised ValueError"
except ValueError:
pass
print(" PASS\n")
@@ -896,8 +922,11 @@ def test_line_cursors():
# Overlay should have been broadcast
assert len(overlays) == 1
assert "image" in overlays[0]
assert overlays[0]["image"].startswith("data:image/png;base64,")
assert overlays[0]["kind"] == "line_plot"
assert len(overlays[0]["line"]) == len(line)
assert len(overlays[0]["x_axis"]) == len(line)
assert 0.0 <= overlays[0]["x1"] <= 1.0
assert 0.0 <= overlays[0]["x2"] <= 1.0
# With x_axis provided
x_axis = np.linspace(0, 1, 100).astype(np.float64)

View File

@@ -0,0 +1,20 @@
from pathlib import Path
import pytest
from backend.server import PNG_SIGNATURE, save_png_bytes
def test_save_png_bytes_writes_exact_png_payload(tmp_path: Path):
target = tmp_path / "workflow"
payload = PNG_SIGNATURE + b"argonode-test-payload"
saved_path = save_png_bytes(str(target), payload)
assert saved_path == tmp_path / "workflow.png"
assert saved_path.read_bytes() == payload
def test_save_png_bytes_rejects_invalid_payload(tmp_path: Path):
with pytest.raises(ValueError, match="valid PNG"):
save_png_bytes(str(tmp_path / "workflow.png"), b"not-a-png")