fix preview and save on native
This commit is contained in:
@@ -194,18 +194,17 @@ class ExecutionEngine:
|
|||||||
CrossSection._broadcast_overlay_fn = on_overlay
|
CrossSection._broadcast_overlay_fn = on_overlay
|
||||||
LineCursors._broadcast_overlay_fn = on_overlay
|
LineCursors._broadcast_overlay_fn = on_overlay
|
||||||
LoadFile._broadcast_warning_fn = on_warning
|
LoadFile._broadcast_warning_fn = on_warning
|
||||||
SaveImage._broadcast_preview = (
|
SaveImage._broadcast_warning_fn = on_warning
|
||||||
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
def _set_node_id_on_display(self, cls: type, node_id: str) -> None:
|
||||||
"""Inform display nodes of their current node_id for WS tagging."""
|
"""Inform display nodes of their current node_id for WS tagging."""
|
||||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||||
from backend.nodes.analysis import CrossSection, LineCursors
|
from backend.nodes.analysis import CrossSection, LineCursors
|
||||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||||
from backend.nodes.io import LoadFile
|
from backend.nodes.io import LoadFile, SaveImage
|
||||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
|
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
|
||||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile):
|
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
|
||||||
|
LoadFile, SaveImage):
|
||||||
cls._current_node_id = node_id
|
cls._current_node_id = node_id
|
||||||
|
|
||||||
def _auto_preview(
|
def _auto_preview(
|
||||||
@@ -262,11 +261,9 @@ class ExecutionEngine:
|
|||||||
cls: type,
|
cls: type,
|
||||||
slot: int,
|
slot: int,
|
||||||
result: tuple,
|
result: tuple,
|
||||||
) -> str | None:
|
) -> dict | None:
|
||||||
"""Render a LINE output as a small matplotlib plot, returned as a data URI."""
|
"""Return structured LINE preview data for responsive frontend rendering."""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import base64
|
|
||||||
import io as _io
|
|
||||||
|
|
||||||
return_types = getattr(cls, "RETURN_TYPES", ())
|
return_types = getattr(cls, "RETURN_TYPES", ())
|
||||||
|
|
||||||
@@ -281,17 +278,22 @@ class ExecutionEngine:
|
|||||||
return None # the first LINE already plotted both
|
return None # the first LINE already plotted both
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
import base64
|
||||||
|
import io as _io
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
y = np.asarray(y, dtype=np.float64).ravel()
|
||||||
|
if x is None:
|
||||||
|
x = np.arange(len(y), dtype=np.float64)
|
||||||
|
else:
|
||||||
|
x = np.asarray(x, dtype=np.float64).ravel()[:len(y)]
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
|
fig, ax = plt.subplots(figsize=(3.2, 1.8), dpi=100)
|
||||||
fig.patch.set_facecolor("#1e293b")
|
fig.patch.set_facecolor("#1e293b")
|
||||||
ax.set_facecolor("#0f172a")
|
ax.set_facecolor("#0f172a")
|
||||||
if x is not None:
|
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
||||||
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
|
||||||
else:
|
|
||||||
ax.plot(y, color="#ff9800", linewidth=1.2)
|
|
||||||
ax.tick_params(colors="#94a3b8", labelsize=7)
|
ax.tick_params(colors="#94a3b8", labelsize=7)
|
||||||
for spine in ax.spines.values():
|
for spine in ax.spines.values():
|
||||||
spine.set_color("#334155")
|
spine.set_color("#334155")
|
||||||
@@ -301,8 +303,15 @@ class ExecutionEngine:
|
|||||||
buf = _io.BytesIO()
|
buf = _io.BytesIO()
|
||||||
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
fallback_image = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
|
||||||
return f"data:image/png;base64,{b64}"
|
|
||||||
|
return {
|
||||||
|
"kind": "line_plot",
|
||||||
|
"line": y.tolist(),
|
||||||
|
"x_axis": x.tolist(),
|
||||||
|
"interactive": False,
|
||||||
|
"fallback_image": fallback_image,
|
||||||
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ def get_node_info(class_name: str) -> dict[str, Any]:
|
|||||||
"output": list(cls.RETURN_TYPES),
|
"output": list(cls.RETURN_TYPES),
|
||||||
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
|
"output_name": list(getattr(cls, "RETURN_NAMES", cls.RETURN_TYPES)),
|
||||||
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
||||||
|
"manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)),
|
||||||
"description": getattr(cls, "DESCRIPTION", ""),
|
"description": getattr(cls, "DESCRIPTION", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Gwyddion equivalents:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from backend.node_registry import register_node
|
from backend.node_registry import register_node
|
||||||
from backend.data_types import DataField
|
from backend.data_types import DataField, datafield_to_uint8, encode_preview
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -131,88 +131,49 @@ class LineCursors:
|
|||||||
self, line, x1: float, y1: float, x2: float, y2: float,
|
self, line, x1: float, y1: float, x2: float, y2: float,
|
||||||
x_axis=None,
|
x_axis=None,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
import io as _io
|
|
||||||
import base64
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use("Agg")
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
y = np.asarray(line, dtype=np.float64).ravel()
|
y = np.asarray(line, dtype=np.float64).ravel()
|
||||||
n = len(y)
|
n = len(y)
|
||||||
if x_axis is not None:
|
if x_axis is not None:
|
||||||
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
|
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
|
||||||
else:
|
else:
|
||||||
x = np.arange(n, dtype=np.float64)
|
x = np.arange(n, dtype=np.float64)
|
||||||
|
x1 = float(np.clip(x1, 0.0, 1.0))
|
||||||
|
x2 = float(np.clip(x2, 0.0, 1.0))
|
||||||
|
|
||||||
# --- Render the base plot first to determine axes bounds ---
|
xmin = float(np.min(x)) if len(x) else 0.0
|
||||||
fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100)
|
xmax = float(np.max(x)) if len(x) else 1.0
|
||||||
fig.patch.set_facecolor("#1e293b")
|
|
||||||
ax.set_facecolor("#0f172a")
|
|
||||||
ax.plot(x, y, color="#ff9800", linewidth=1.2)
|
|
||||||
ax.tick_params(colors="#94a3b8", labelsize=7)
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_color("#334155")
|
|
||||||
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
|
|
||||||
fig.tight_layout(pad=0.4)
|
|
||||||
|
|
||||||
# Force a draw so transforms are valid
|
def x_frac_to_idx(frac):
|
||||||
fig.canvas.draw()
|
if n <= 1:
|
||||||
|
return 0
|
||||||
|
if xmax == xmin:
|
||||||
|
return 0
|
||||||
|
target_x = xmin + frac * (xmax - xmin)
|
||||||
|
return int(np.argmin(np.abs(x - target_x)))
|
||||||
|
|
||||||
# Get axes position in figure-fraction coordinates
|
idx_a = x_frac_to_idx(x1)
|
||||||
ax_pos = ax.get_position()
|
idx_b = x_frac_to_idx(x2)
|
||||||
ax_l, ax_b = ax_pos.x0, ax_pos.y0
|
|
||||||
ax_w, ax_h = ax_pos.width, ax_pos.height
|
|
||||||
|
|
||||||
# x1/y1 arrive as image-fraction from the frontend drag.
|
|
||||||
# Convert image-fraction x → axes-fraction → nearest data index.
|
|
||||||
def img_x_to_idx(ix):
|
|
||||||
axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
|
|
||||||
return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
|
|
||||||
|
|
||||||
idx_a = img_x_to_idx(x1)
|
|
||||||
idx_b = img_x_to_idx(x2)
|
|
||||||
|
|
||||||
xa, ya = float(x[idx_a]), float(y[idx_a])
|
xa, ya = float(x[idx_a]), float(y[idx_a])
|
||||||
xb, yb = float(x[idx_b]), float(y[idx_b])
|
xb, yb = float(x[idx_b]), float(y[idx_b])
|
||||||
|
|
||||||
# --- Draw cursor lines and markers on the plot ---
|
|
||||||
ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
|
|
||||||
ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
|
|
||||||
ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
|
|
||||||
ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
|
|
||||||
ax.annotate(
|
|
||||||
"", xy=(xb, yb), xytext=(xa, ya),
|
|
||||||
arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Broadcast overlay ---
|
# --- Broadcast overlay ---
|
||||||
if LineCursors._broadcast_overlay_fn is not None:
|
if LineCursors._broadcast_overlay_fn is not None:
|
||||||
# Convert data-space positions back to image-fraction for markers
|
|
||||||
fig.canvas.draw()
|
|
||||||
inv = fig.transFigure.inverted()
|
|
||||||
fig_a = inv.transform(ax.transData.transform([xa, ya]))
|
|
||||||
fig_b = inv.transform(ax.transData.transform([xb, yb]))
|
|
||||||
|
|
||||||
buf = _io.BytesIO()
|
|
||||||
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
|
|
||||||
buf.seek(0)
|
|
||||||
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
|
|
||||||
|
|
||||||
LineCursors._broadcast_overlay_fn(
|
LineCursors._broadcast_overlay_fn(
|
||||||
LineCursors._current_node_id,
|
LineCursors._current_node_id,
|
||||||
{
|
{
|
||||||
"image": image_uri,
|
"kind": "line_plot",
|
||||||
"x1": float(fig_a[0]),
|
"line": y.tolist(),
|
||||||
"y1": float(1.0 - fig_a[1]), # flip: image y=0 is top
|
"x_axis": x.tolist(),
|
||||||
"x2": float(fig_b[0]),
|
"x1": x1,
|
||||||
"y2": float(1.0 - fig_b[1]),
|
"x2": x2,
|
||||||
|
"y1": float(y1),
|
||||||
|
"y2": float(y2),
|
||||||
"a_locked": False,
|
"a_locked": False,
|
||||||
"b_locked": False,
|
"b_locked": False,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
# --- Output table ---
|
# --- Output table ---
|
||||||
table = [
|
table = [
|
||||||
{"quantity": "A position", "value": xa, "unit": ""},
|
{"quantity": "A position", "value": xa, "unit": ""},
|
||||||
@@ -414,8 +375,6 @@ class CrossSection:
|
|||||||
point_a=None, point_b=None,
|
point_a=None, point_b=None,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
from scipy.ndimage import map_coordinates
|
from scipy.ndimage import map_coordinates
|
||||||
import io, base64
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
|
|
||||||
# COORD inputs override widget values
|
# COORD inputs override widget values
|
||||||
if point_a is not None:
|
if point_a is not None:
|
||||||
@@ -453,14 +412,9 @@ class CrossSection:
|
|||||||
|
|
||||||
# Broadcast overlay image with marker positions
|
# Broadcast overlay image with marker positions
|
||||||
if CrossSection._broadcast_overlay_fn is not None:
|
if CrossSection._broadcast_overlay_fn is not None:
|
||||||
fig = Figure(figsize=(3, 3), dpi=100)
|
# Use the field's native pixel grid for the overlay preview so enlarging
|
||||||
ax = fig.add_axes([0, 0, 1, 1])
|
# the panel keeps the image as sharp as the source data allows.
|
||||||
ax.imshow(field.data, cmap="viridis", aspect="auto")
|
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
|
||||||
ax.axis("off")
|
|
||||||
buf = io.BytesIO()
|
|
||||||
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
|
||||||
buf.seek(0)
|
|
||||||
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
|
|
||||||
|
|
||||||
CrossSection._broadcast_overlay_fn(
|
CrossSection._broadcast_overlay_fn(
|
||||||
CrossSection._current_node_id,
|
CrossSection._current_node_id,
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class View3D:
|
|||||||
"required": {
|
"required": {
|
||||||
"field": ("DATA_FIELD",),
|
"field": ("DATA_FIELD",),
|
||||||
"colormap": (["auto"] + list(COLORMAPS),),
|
"colormap": (["auto"] + list(COLORMAPS),),
|
||||||
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
|
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
|
||||||
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ class View3D:
|
|||||||
"colors": colors_b64,
|
"colors": colors_b64,
|
||||||
"z_min": zmin,
|
"z_min": zmin,
|
||||||
"z_max": zmax,
|
"z_max": zmax,
|
||||||
"z_scale": float(z_scale),
|
"z_scale": float(z_scale * 0.1),
|
||||||
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
|
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
|
||||||
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
|
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -399,54 +399,86 @@ class Coordinate:
|
|||||||
# SaveImage
|
# SaveImage
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@register_node(display_name="Save Image")
|
_MAX_SAVE_FIELDS = 8
|
||||||
|
|
||||||
|
@register_node(display_name="Save Layers")
|
||||||
class SaveImage:
|
class SaveImage:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
|
optional = {}
|
||||||
|
for i in range(_MAX_SAVE_FIELDS):
|
||||||
|
optional[f"field_{i}"] = ("DATA_FIELD",)
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"image": ("IMAGE",),
|
"filename": ("FILE_PICKER", {"default": ""}),
|
||||||
"filename_prefix": ("STRING", {"default": "output"}),
|
"format": (["TIFF", "NPZ"],),
|
||||||
"format": (["PNG", "TIFF", "NPY"],),
|
},
|
||||||
}
|
"optional": optional,
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
RETURN_TYPES = ()
|
||||||
FUNCTION = "save"
|
FUNCTION = "save"
|
||||||
CATEGORY = "io"
|
CATEGORY = "io"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
DESCRIPTION = "Save an image or array to the output folder."
|
MANUAL_TRIGGER = True
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Save one or more DATA_FIELD layers to a single file. "
|
||||||
|
"Connect fields to the inputs — a new slot appears as each is filled. "
|
||||||
|
"TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
|
||||||
|
"Click Save to write (does not auto-run)."
|
||||||
|
)
|
||||||
|
|
||||||
# Injected by server.py before execution begins
|
_broadcast_warning_fn = None
|
||||||
_broadcast_preview = None
|
_current_node_id = None
|
||||||
|
|
||||||
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
|
def save(self, filename: str, format: str = "TIFF", **kwargs):
|
||||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
# Collect connected fields in order
|
||||||
|
fields = []
|
||||||
|
for i in range(_MAX_SAVE_FIELDS):
|
||||||
|
f = kwargs.get(f"field_{i}")
|
||||||
|
if f is not None:
|
||||||
|
fields.append(f)
|
||||||
|
|
||||||
# Find next available filename
|
if not fields:
|
||||||
idx = 1
|
raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
|
||||||
while True:
|
|
||||||
name = f"{filename_prefix}_{idx:04d}"
|
|
||||||
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
|
|
||||||
if not candidate.exists():
|
|
||||||
break
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
if format == "NPY":
|
if not filename or not filename.strip():
|
||||||
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
|
raise ValueError("No output path selected — use Browse to pick a location.")
|
||||||
|
|
||||||
|
path = Path(filename)
|
||||||
|
# Ensure parent directory exists
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Force correct extension
|
||||||
|
ext = ".tiff" if format == "TIFF" else ".npz"
|
||||||
|
if path.suffix.lower() != ext:
|
||||||
|
path = path.with_suffix(ext)
|
||||||
|
|
||||||
|
if format == "TIFF":
|
||||||
|
self._save_tiff(path, fields)
|
||||||
else:
|
else:
|
||||||
from PIL import Image
|
self._save_npz(path, fields)
|
||||||
arr = image_to_uint8(image)
|
|
||||||
if arr.ndim == 2:
|
|
||||||
pil_img = Image.fromarray(arr, mode="L")
|
|
||||||
else:
|
|
||||||
pil_img = Image.fromarray(arr, mode="RGB")
|
|
||||||
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
|
|
||||||
|
|
||||||
# Emit preview over WebSocket if callback is set
|
self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
|
||||||
if SaveImage._broadcast_preview is not None:
|
return ()
|
||||||
arr_u8 = image_to_uint8(image)
|
|
||||||
data_uri = encode_preview(arr_u8)
|
def _save_tiff(self, path: Path, fields: list[DataField]):
|
||||||
SaveImage._broadcast_preview(data_uri)
|
from PIL import Image
|
||||||
|
images = []
|
||||||
|
for f in fields:
|
||||||
|
images.append(Image.fromarray(f.data.astype(np.float32)))
|
||||||
|
images[0].save(str(path), save_all=True, append_images=images[1:])
|
||||||
|
|
||||||
|
def _save_npz(self, path: Path, fields: list[DataField]):
|
||||||
|
arrays = {}
|
||||||
|
for i, f in enumerate(fields):
|
||||||
|
arrays[f"layer_{i}"] = f.data
|
||||||
|
np.savez(str(path), **arrays)
|
||||||
|
|
||||||
|
def _send_warning(self, message: str):
|
||||||
|
fn = SaveImage._broadcast_warning_fn
|
||||||
|
nid = SaveImage._current_node_id
|
||||||
|
if fn and nid:
|
||||||
|
fn(nid, message)
|
||||||
|
|
||||||
return ()
|
return ()
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ FRONTEND_DIR = frontend_dir()
|
|||||||
DIST_DIR = frontend_dist_dir()
|
DIST_DIR = frontend_dist_dir()
|
||||||
INPUT_DIR = input_dir()
|
INPUT_DIR = input_dir()
|
||||||
OUTPUT_DIR = output_dir()
|
OUTPUT_DIR = output_dir()
|
||||||
|
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -63,6 +64,18 @@ def _dumps(obj) -> str:
|
|||||||
return json.dumps(obj, cls=_SafeEncoder)
|
return json.dumps(obj, cls=_SafeEncoder)
|
||||||
|
|
||||||
|
|
||||||
|
def save_png_bytes(target_path: str, payload: bytes) -> Path:
|
||||||
|
path = Path(target_path).expanduser()
|
||||||
|
if not target_path.strip():
|
||||||
|
raise ValueError("Missing save path")
|
||||||
|
if path.suffix.lower() != ".png":
|
||||||
|
path = path.with_suffix(".png")
|
||||||
|
if not payload.startswith(PNG_SIGNATURE):
|
||||||
|
raise ValueError("Payload is not a valid PNG")
|
||||||
|
path.write_bytes(payload)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Application factory
|
# Application factory
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -196,6 +209,20 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def save_workflow_png(request: web.Request) -> web.Response:
|
||||||
|
body = await request.read()
|
||||||
|
target_path = request.query.get("path", "")
|
||||||
|
if not target_path:
|
||||||
|
raise web.HTTPBadRequest(reason="Missing path")
|
||||||
|
try:
|
||||||
|
saved_path = save_png_bytes(target_path, body)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise web.HTTPBadRequest(reason=str(exc)) from exc
|
||||||
|
return web.Response(
|
||||||
|
text=_dumps({"path": str(saved_path)}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
async def get_channels(request: web.Request) -> web.Response:
|
async def get_channels(request: web.Request) -> web.Response:
|
||||||
"""Return available channels for a given file path."""
|
"""Return available channels for a given file path."""
|
||||||
from backend.nodes.io import list_channels
|
from backend.nodes.io import list_channels
|
||||||
@@ -278,6 +305,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
|||||||
app.router.add_get("/browse", browse_dir)
|
app.router.add_get("/browse", browse_dir)
|
||||||
app.router.add_post("/upload", upload_file)
|
app.router.add_post("/upload", upload_file)
|
||||||
app.router.add_post("/download", download_file)
|
app.router.add_post("/download", download_file)
|
||||||
|
app.router.add_post("/save-workflow-png", save_workflow_png)
|
||||||
app.router.add_get("/channels", get_channels)
|
app.router.add_get("/channels", get_channels)
|
||||||
app.router.add_post("/prompt", submit_prompt)
|
app.router.add_post("/prompt", submit_prompt)
|
||||||
app.router.add_get("/ws", websocket_handler)
|
app.router.add_get("/ws", websocket_handler)
|
||||||
|
|||||||
11
desktop.py
11
desktop.py
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import threading
|
import threading
|
||||||
@@ -45,8 +44,8 @@ class _Api:
|
|||||||
return result[0]
|
return result[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def save_workflow_png(self, data_url: str, default_filename: str = "workflow.png") -> str | None:
|
def choose_save_workflow_png_path(self, default_filename: str = "workflow.png") -> str | None:
|
||||||
"""Open a native save dialog, write the PNG bytes, and return the saved path."""
|
"""Open a native save dialog and return the chosen PNG path (or None)."""
|
||||||
win = self._window_ref[0]
|
win = self._window_ref[0]
|
||||||
if win is None:
|
if win is None:
|
||||||
return None
|
return None
|
||||||
@@ -65,12 +64,6 @@ class _Api:
|
|||||||
path = Path(result[0] if isinstance(result, (list, tuple)) else result).expanduser()
|
path = Path(result[0] if isinstance(result, (list, tuple)) else result).expanduser()
|
||||||
if path.suffix.lower() != ".png":
|
if path.suffix.lower() != ".png":
|
||||||
path = path.with_suffix(".png")
|
path = path.with_suffix(".png")
|
||||||
|
|
||||||
_, _, encoded = data_url.partition(",")
|
|
||||||
if not encoded:
|
|
||||||
raise ValueError("Invalid data URL payload")
|
|
||||||
|
|
||||||
path.write_bytes(base64.b64decode(encoded))
|
|
||||||
return str(path)
|
return str(path)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import FileBrowser from './FileBrowser';
|
|||||||
import * as api from './api';
|
import * as api from './api';
|
||||||
import { toBlob } from 'html-to-image';
|
import { toBlob } from 'html-to-image';
|
||||||
import { embedWorkflow, extractWorkflow } from './pngMetadata';
|
import { embedWorkflow, extractWorkflow } from './pngMetadata';
|
||||||
|
import { hydrateWorkflowState } from './workflowHydration';
|
||||||
import { serializeWorkflowState } from './workflowSerialization';
|
import { serializeWorkflowState } from './workflowSerialization';
|
||||||
|
|
||||||
// ── Constants ─────────────────────────────────────────────────────────
|
// ── Constants ─────────────────────────────────────────────────────────
|
||||||
@@ -43,15 +44,6 @@ function getOutputSlot(handleId) {
|
|||||||
return parseInt(handleId.split('::')[1], 10);
|
return parseInt(handleId.split('::')[1], 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
function blobToDataUrl(blob) {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
const reader = new FileReader();
|
|
||||||
reader.onloadend = () => resolve(reader.result);
|
|
||||||
reader.onerror = () => reject(reader.error || new Error('Failed to read file'));
|
|
||||||
reader.readAsDataURL(blob);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async function waitForImageElement(img) {
|
async function waitForImageElement(img) {
|
||||||
if (img.complete && img.naturalWidth > 0) return;
|
if (img.complete && img.naturalWidth > 0) return;
|
||||||
if (typeof img.decode === 'function') {
|
if (typeof img.decode === 'function') {
|
||||||
@@ -73,6 +65,31 @@ async function waitForImageElement(img) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function getCaptureImageDataUrl(img) {
|
||||||
|
const src = img.currentSrc || img.src;
|
||||||
|
if (!src) return null;
|
||||||
|
if (!src.startsWith('data:')) return src;
|
||||||
|
|
||||||
|
const rect = img.getBoundingClientRect();
|
||||||
|
const width = Math.max(1, Math.round(img.clientWidth || rect.width));
|
||||||
|
const height = Math.max(1, Math.round(img.clientHeight || rect.height));
|
||||||
|
const scale = Math.min(2, window.devicePixelRatio || 1);
|
||||||
|
|
||||||
|
const canvas = document.createElement('canvas');
|
||||||
|
canvas.width = Math.max(1, Math.round(width * scale));
|
||||||
|
canvas.height = Math.max(1, Math.round(height * scale));
|
||||||
|
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
if (!ctx) return src;
|
||||||
|
|
||||||
|
try {
|
||||||
|
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
||||||
|
return canvas.toDataURL('image/png');
|
||||||
|
} catch {
|
||||||
|
return src;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function createCapturePlaceholder(el, dataUrl) {
|
function createCapturePlaceholder(el, dataUrl) {
|
||||||
const rect = el.getBoundingClientRect();
|
const rect = el.getBoundingClientRect();
|
||||||
const style = window.getComputedStyle(el);
|
const style = window.getComputedStyle(el);
|
||||||
@@ -101,8 +118,9 @@ async function captureViewportBlob(viewportEl, options) {
|
|||||||
await Promise.all(images.map(waitForImageElement));
|
await Promise.all(images.map(waitForImageElement));
|
||||||
|
|
||||||
for (const img of images) {
|
for (const img of images) {
|
||||||
const dataUrl = img.currentSrc || img.src;
|
if (!img.parentNode) continue;
|
||||||
if (!dataUrl || !img.parentNode) continue;
|
const dataUrl = await getCaptureImageDataUrl(img);
|
||||||
|
if (!dataUrl) continue;
|
||||||
const placeholder = createCapturePlaceholder(img, dataUrl);
|
const placeholder = createCapturePlaceholder(img, dataUrl);
|
||||||
img.parentNode.replaceChild(placeholder, img);
|
img.parentNode.replaceChild(placeholder, img);
|
||||||
restorers.push(() => {
|
restorers.push(() => {
|
||||||
@@ -144,12 +162,13 @@ async function captureViewportBlob(viewportEl, options) {
|
|||||||
|
|
||||||
// ── Graph serialisation → backend prompt format ───────────────────────
|
// ── Graph serialisation → backend prompt format ───────────────────────
|
||||||
|
|
||||||
function serializeGraph(nodes, edges) {
|
function serializeGraph(nodes, edges, { excludeManualTrigger = false } = {}) {
|
||||||
const prompt = {};
|
const prompt = {};
|
||||||
|
|
||||||
for (const node of nodes) {
|
for (const node of nodes) {
|
||||||
const { className, definition, widgetValues } = node.data;
|
const { className, definition, widgetValues } = node.data;
|
||||||
if (!definition) continue;
|
if (!definition) continue;
|
||||||
|
if (excludeManualTrigger && definition.manual_trigger) continue;
|
||||||
|
|
||||||
const inputs = {};
|
const inputs = {};
|
||||||
|
|
||||||
@@ -551,10 +570,23 @@ function Flow() {
|
|||||||
|
|
||||||
// ── Node context value (stable) ─────────────────────────────────────
|
// ── Node context value (stable) ─────────────────────────────────────
|
||||||
|
|
||||||
|
const onManualTrigger = useCallback((nodeId) => {
|
||||||
|
const currentNodes = reactFlow.getNodes();
|
||||||
|
const currentEdges = reactFlow.getEdges();
|
||||||
|
// Include ALL nodes (no excludeManualTrigger) so the save node is in the prompt
|
||||||
|
const prompt = serializeGraph(currentNodes, currentEdges);
|
||||||
|
if (!prompt || Object.keys(prompt).length === 0) return;
|
||||||
|
setStatus({ text: 'Saving…', level: 'info' });
|
||||||
|
api.runPrompt(prompt).catch((err) => {
|
||||||
|
setStatus({ text: 'Save failed: ' + err.message, level: 'error' });
|
||||||
|
});
|
||||||
|
}, [reactFlow]);
|
||||||
|
|
||||||
const contextValue = useMemo(() => ({
|
const contextValue = useMemo(() => ({
|
||||||
onWidgetChange,
|
onWidgetChange,
|
||||||
openFileBrowser,
|
openFileBrowser,
|
||||||
}), [onWidgetChange, openFileBrowser]);
|
onManualTrigger,
|
||||||
|
}), [onWidgetChange, openFileBrowser, onManualTrigger]);
|
||||||
|
|
||||||
// ── Add node from context menu ──────────────────────────────────────
|
// ── Add node from context menu ──────────────────────────────────────
|
||||||
|
|
||||||
@@ -687,13 +719,18 @@ function Flow() {
|
|||||||
const currentNodes = reactFlow.getNodes();
|
const currentNodes = reactFlow.getNodes();
|
||||||
const currentEdges = reactFlow.getEdges();
|
const currentEdges = reactFlow.getEdges();
|
||||||
|
|
||||||
// Don't run if any node has unconnected required data inputs
|
// Don't run if any non-manual node has unconnected required data inputs
|
||||||
|
// or any FILE_PICKER widget is empty
|
||||||
for (const node of currentNodes) {
|
for (const node of currentNodes) {
|
||||||
const def = node.data?.definition;
|
const def = node.data?.definition;
|
||||||
if (!def) continue;
|
if (!def || def.manual_trigger) continue; // skip manual-trigger nodes
|
||||||
const required = def.input.required || {};
|
const required = def.input.required || {};
|
||||||
for (const [name, spec] of Object.entries(required)) {
|
for (const [name, spec] of Object.entries(required)) {
|
||||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
if (type === 'FILE_PICKER') {
|
||||||
|
if (!node.data.widgetValues?.[name]) return; // no file selected, skip
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (!DATA_TYPES.has(type)) continue;
|
if (!DATA_TYPES.has(type)) continue;
|
||||||
const hasEdge = currentEdges.some(
|
const hasEdge = currentEdges.some(
|
||||||
(e) => e.target === node.id && getInputName(e.targetHandle) === name
|
(e) => e.target === node.id && getInputName(e.targetHandle) === name
|
||||||
@@ -702,7 +739,7 @@ function Flow() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const prompt = serializeGraph(currentNodes, currentEdges);
|
const prompt = serializeGraph(currentNodes, currentEdges, { excludeManualTrigger: true });
|
||||||
if (!prompt || Object.keys(prompt).length === 0) return;
|
if (!prompt || Object.keys(prompt).length === 0) return;
|
||||||
setStatus({ text: 'Running…', level: 'info' });
|
setStatus({ text: 'Running…', level: 'info' });
|
||||||
api.runPrompt(prompt).catch((err) => {
|
api.runPrompt(prompt).catch((err) => {
|
||||||
@@ -723,25 +760,10 @@ function Flow() {
|
|||||||
}, [setNodes, setEdges]);
|
}, [setNodes, setEdges]);
|
||||||
|
|
||||||
const applyWorkflowData = useCallback((data) => {
|
const applyWorkflowData = useCallback((data) => {
|
||||||
const loadedNodes = data.nodes || [];
|
const hydrated = hydrateWorkflowState(data, nodeDefsRef.current);
|
||||||
const loadedEdges = data.edges || [];
|
setNodes(hydrated.nodes);
|
||||||
const defs = nodeDefsRef.current;
|
setEdges(hydrated.edges);
|
||||||
const hydrated = loadedNodes.map((n) => ({
|
nextIdRef.current = hydrated.nextNodeId;
|
||||||
...n,
|
|
||||||
type: n.type || 'custom',
|
|
||||||
dragHandle: n.dragHandle || '.drag-handle',
|
|
||||||
data: {
|
|
||||||
...n.data,
|
|
||||||
label: n.data?.label || n.data?.className || 'Node',
|
|
||||||
widgetValues: n.data?.widgetValues || {},
|
|
||||||
definition: defs[n.data.className] || n.data.definition,
|
|
||||||
previewImage: null, tableRows: null, meshData: null, overlay: null,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
setNodes(hydrated);
|
|
||||||
setEdges(loadedEdges);
|
|
||||||
const maxId = Math.max(0, ...loadedNodes.map((n) => parseInt(n.id, 10) || 0));
|
|
||||||
nextIdRef.current = maxId + 1;
|
|
||||||
}, [setNodes, setEdges]);
|
}, [setNodes, setEdges]);
|
||||||
|
|
||||||
const getWorkflowBlob = useCallback(async () => {
|
const getWorkflowBlob = useCallback(async () => {
|
||||||
@@ -778,9 +800,23 @@ function Flow() {
|
|||||||
try {
|
try {
|
||||||
const finalBlob = await getWorkflowBlob();
|
const finalBlob = await getWorkflowBlob();
|
||||||
|
|
||||||
if (window.pywebview?.api?.save_workflow_png) {
|
if (window.pywebview?.api?.choose_save_workflow_png_path) {
|
||||||
const dataUrl = await blobToDataUrl(finalBlob);
|
const requestedPath = await window.pywebview.api.choose_save_workflow_png_path('workflow.png');
|
||||||
const savedPath = await window.pywebview.api.save_workflow_png(dataUrl, 'workflow.png');
|
if (!requestedPath) {
|
||||||
|
setStatus({ text: 'Save cancelled.', level: 'info' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const resp = await fetch(`/save-workflow-png?path=${encodeURIComponent(requestedPath)}`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'image/png',
|
||||||
|
},
|
||||||
|
body: finalBlob,
|
||||||
|
});
|
||||||
|
if (!resp.ok) {
|
||||||
|
throw new Error(await resp.text() || `Save failed (${resp.status})`);
|
||||||
|
}
|
||||||
|
const { path: savedPath } = await resp.json();
|
||||||
if (!savedPath) {
|
if (!savedPath) {
|
||||||
setStatus({ text: 'Save cancelled.', level: 'info' });
|
setStatus({ text: 'Save cancelled.', level: 'info' });
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
|
import React, { useContext, useRef, useCallback, useState, memo, lazy, Suspense } from 'react';
|
||||||
import { Handle, Position } from '@xyflow/react';
|
import { Handle, Position, useStore } from '@xyflow/react';
|
||||||
|
import LinePlotOverlay from './LinePlotOverlay';
|
||||||
|
|
||||||
const SurfaceView = lazy(() => import('./SurfaceView'));
|
const SurfaceView = lazy(() => import('./SurfaceView'));
|
||||||
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
|
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
|
||||||
@@ -29,6 +30,47 @@ const CAT_COLORS = {
|
|||||||
|
|
||||||
export const NodeContext = React.createContext(null);
|
export const NodeContext = React.createContext(null);
|
||||||
|
|
||||||
|
class PreviewBoundary extends React.Component {
|
||||||
|
constructor(props) {
|
||||||
|
super(props);
|
||||||
|
this.state = { hasError: false };
|
||||||
|
}
|
||||||
|
|
||||||
|
static getDerivedStateFromError() {
|
||||||
|
return { hasError: true };
|
||||||
|
}
|
||||||
|
|
||||||
|
componentDidCatch(error) {
|
||||||
|
console.error('[argonode] preview render failed', error);
|
||||||
|
}
|
||||||
|
|
||||||
|
componentDidUpdate(prevProps) {
|
||||||
|
if (prevProps.resetKey !== this.props.resetKey && this.state.hasError) {
|
||||||
|
this.setState({ hasError: false });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
render() {
|
||||||
|
if (!this.state.hasError) {
|
||||||
|
return this.props.children;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.props.fallbackImage) {
|
||||||
|
return (
|
||||||
|
<div className="node-preview">
|
||||||
|
<img src={this.props.fallbackImage} alt="preview fallback" draggable={false} />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="node-preview" style={{ color: '#94a3b8', padding: 8 }}>
|
||||||
|
Preview unavailable.
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Draggable number input ────────────────────────────────────────────
|
// ── Draggable number input ────────────────────────────────────────────
|
||||||
|
|
||||||
function DraggableNumber({ value, step, min, max, precision, onChange }) {
|
function DraggableNumber({ value, step, min, max, precision, onChange }) {
|
||||||
@@ -151,8 +193,39 @@ function CustomNode({ id, data }) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For manual-trigger nodes (Save), show progressive optional inputs:
|
||||||
|
// show field_N only if field_(N-1) is connected (or N==0).
|
||||||
|
const isProgressive = def.manual_trigger;
|
||||||
|
const connectedInputs = useStore(
|
||||||
|
useCallback(
|
||||||
|
(s) => {
|
||||||
|
if (!isProgressive) return null;
|
||||||
|
const set = new Set();
|
||||||
|
for (const e of s.edges) {
|
||||||
|
if (e.target === id) {
|
||||||
|
const parts = e.targetHandle?.split('::');
|
||||||
|
if (parts) set.add(parts[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return set;
|
||||||
|
},
|
||||||
|
[id, isProgressive],
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
for (const [name, spec] of Object.entries(optional)) {
|
for (const [name, spec] of Object.entries(optional)) {
|
||||||
const [type] = Array.isArray(spec) ? spec : [spec];
|
const [type] = Array.isArray(spec) ? spec : [spec];
|
||||||
|
if (isProgressive && DATA_TYPES.has(type)) {
|
||||||
|
// Progressive: show this slot only if it's the first or the previous is connected
|
||||||
|
const match = name.match(/^field_(\d+)$/);
|
||||||
|
if (match) {
|
||||||
|
const idx = parseInt(match[1], 10);
|
||||||
|
if (idx === 0 || (connectedInputs && connectedInputs.has(`field_${idx - 1}`))) {
|
||||||
|
dataInputs.push({ name, type });
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
dataInputs.push({ name, type });
|
dataInputs.push({ name, type });
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,6 +302,19 @@ function CustomNode({ id, data }) {
|
|||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
|
||||||
|
{/* Manual trigger button (Save) */}
|
||||||
|
{def.manual_trigger && (
|
||||||
|
<div className="widget-row">
|
||||||
|
<button
|
||||||
|
className="nodrag btn btn-primary"
|
||||||
|
style={{ flex: 1 }}
|
||||||
|
onClick={() => ctx.onManualTrigger?.(id)}
|
||||||
|
>
|
||||||
|
Save to Disk
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Interactive 3D surface view */}
|
{/* Interactive 3D surface view */}
|
||||||
{data.meshData && (
|
{data.meshData && (
|
||||||
<CollapsibleSection title="3D View" defaultOpen={true}>
|
<CollapsibleSection title="3D View" defaultOpen={true}>
|
||||||
@@ -241,9 +327,21 @@ function CustomNode({ id, data }) {
|
|||||||
{/* Collapsible preview image */}
|
{/* Collapsible preview image */}
|
||||||
{data.previewImage && (
|
{data.previewImage && (
|
||||||
<CollapsibleSection title="Preview" defaultOpen={true}>
|
<CollapsibleSection title="Preview" defaultOpen={true}>
|
||||||
<div className="node-preview">
|
<PreviewBoundary
|
||||||
<img src={data.previewImage} alt="preview" draggable={false} />
|
resetKey={typeof data.previewImage === 'string' ? data.previewImage : JSON.stringify({
|
||||||
</div>
|
kind: data.previewImage.kind,
|
||||||
|
len: data.previewImage.line?.length,
|
||||||
|
})}
|
||||||
|
fallbackImage={typeof data.previewImage === 'object' ? data.previewImage.fallback_image : null}
|
||||||
|
>
|
||||||
|
{typeof data.previewImage === 'string' ? (
|
||||||
|
<div className="node-preview">
|
||||||
|
<img src={data.previewImage} alt="preview" draggable={false} />
|
||||||
|
</div>
|
||||||
|
) : data.previewImage.kind === 'line_plot' ? (
|
||||||
|
<LinePlotOverlay overlay={data.previewImage} interactive={false} />
|
||||||
|
) : null}
|
||||||
|
</PreviewBoundary>
|
||||||
</CollapsibleSection>
|
</CollapsibleSection>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
@@ -251,17 +349,29 @@ function CustomNode({ id, data }) {
|
|||||||
{data.overlay && hiddenWidgets.has('x1') && (
|
{data.overlay && hiddenWidgets.has('x1') && (
|
||||||
<CollapsibleSection title="Cross Section" defaultOpen={true}>
|
<CollapsibleSection title="Cross Section" defaultOpen={true}>
|
||||||
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
|
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
|
||||||
<CrossSectionOverlay
|
{data.overlay.kind === 'line_plot' ? (
|
||||||
image={data.overlay.image}
|
<LinePlotOverlay
|
||||||
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
|
overlay={data.overlay}
|
||||||
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
|
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
|
||||||
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
|
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
|
||||||
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
|
aLocked={data.overlay.a_locked}
|
||||||
aLocked={data.overlay.a_locked}
|
bLocked={data.overlay.b_locked}
|
||||||
bLocked={data.overlay.b_locked}
|
nodeId={id}
|
||||||
nodeId={id}
|
onWidgetChange={ctx.onWidgetChange}
|
||||||
onWidgetChange={ctx.onWidgetChange}
|
/>
|
||||||
/>
|
) : (
|
||||||
|
<CrossSectionOverlay
|
||||||
|
image={data.overlay.image}
|
||||||
|
x1={data.overlay.a_locked ? data.overlay.x1 : (data.widgetValues.x1 ?? data.overlay.x1)}
|
||||||
|
y1={data.overlay.a_locked ? data.overlay.y1 : (data.widgetValues.y1 ?? data.overlay.y1)}
|
||||||
|
x2={data.overlay.b_locked ? data.overlay.x2 : (data.widgetValues.x2 ?? data.overlay.x2)}
|
||||||
|
y2={data.overlay.b_locked ? data.overlay.y2 : (data.widgetValues.y2 ?? data.overlay.y2)}
|
||||||
|
aLocked={data.overlay.a_locked}
|
||||||
|
bLocked={data.overlay.b_locked}
|
||||||
|
nodeId={id}
|
||||||
|
onWidgetChange={ctx.onWidgetChange}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</Suspense>
|
</Suspense>
|
||||||
</CollapsibleSection>
|
</CollapsibleSection>
|
||||||
)}
|
)}
|
||||||
|
|||||||
271
frontend/src/LinePlotOverlay.jsx
Normal file
271
frontend/src/LinePlotOverlay.jsx
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
import React, { useEffect, useRef, useState, useCallback } from 'react';
|
||||||
|
|
||||||
|
const ASPECT_RATIO = 3.2 / 2.2;
|
||||||
|
const MARGINS = { top: 18, right: 16, bottom: 34, left: 56 };
|
||||||
|
|
||||||
|
function clamp(v, min, max) {
|
||||||
|
return Math.max(min, Math.min(max, v));
|
||||||
|
}
|
||||||
|
|
||||||
|
function round3(v) {
|
||||||
|
return parseFloat(v.toFixed(3));
|
||||||
|
}
|
||||||
|
|
||||||
|
function trimZeros(text) {
|
||||||
|
return text.replace(/(?:\.0+|(\.\d+?)0+)$/, '$1');
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatTick(value) {
|
||||||
|
const abs = Math.abs(value);
|
||||||
|
if (abs === 0) return '0';
|
||||||
|
if (abs >= 1e4 || abs < 1e-3) {
|
||||||
|
return value.toExponential(1).replace('e+', 'e');
|
||||||
|
}
|
||||||
|
if (abs >= 100) return trimZeros(value.toFixed(0));
|
||||||
|
if (abs >= 10) return trimZeros(value.toFixed(1));
|
||||||
|
if (abs >= 1) return trimZeros(value.toFixed(2));
|
||||||
|
return trimZeros(value.toFixed(3));
|
||||||
|
}
|
||||||
|
|
||||||
|
function makeTicks(min, max, count = 5) {
|
||||||
|
if (!Number.isFinite(min) || !Number.isFinite(max)) return [];
|
||||||
|
if (min === max) return [min];
|
||||||
|
const ticks = [];
|
||||||
|
for (let i = 0; i < count; i += 1) {
|
||||||
|
ticks.push(min + ((max - min) * i) / (count - 1));
|
||||||
|
}
|
||||||
|
return ticks;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getExtent(values, fallbackMin = 0, fallbackMax = 1) {
|
||||||
|
if (!Array.isArray(values) || values.length === 0) {
|
||||||
|
return [fallbackMin, fallbackMax];
|
||||||
|
}
|
||||||
|
|
||||||
|
let min = Infinity;
|
||||||
|
let max = -Infinity;
|
||||||
|
for (const value of values) {
|
||||||
|
if (!Number.isFinite(value)) continue;
|
||||||
|
if (value < min) min = value;
|
||||||
|
if (value > max) max = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Number.isFinite(min) || !Number.isFinite(max)) {
|
||||||
|
return [fallbackMin, fallbackMax];
|
||||||
|
}
|
||||||
|
return [min, max];
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function LinePlotOverlay({
|
||||||
|
overlay,
|
||||||
|
x1,
|
||||||
|
x2,
|
||||||
|
aLocked,
|
||||||
|
bLocked,
|
||||||
|
nodeId,
|
||||||
|
onWidgetChange,
|
||||||
|
interactive = true,
|
||||||
|
}) {
|
||||||
|
const containerRef = useRef(null);
|
||||||
|
const [dragging, setDragging] = useState(null);
|
||||||
|
const [size, setSize] = useState({ width: 0, height: 0 });
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!containerRef.current) return undefined;
|
||||||
|
const updateSize = () => {
|
||||||
|
if (!containerRef.current) return;
|
||||||
|
setSize({
|
||||||
|
width: Math.max(1, Math.round(containerRef.current.clientWidth || 320)),
|
||||||
|
height: Math.max(1, Math.round(containerRef.current.clientHeight || (containerRef.current.clientWidth / ASPECT_RATIO) || 220)),
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
updateSize();
|
||||||
|
|
||||||
|
if (typeof ResizeObserver === 'function') {
|
||||||
|
const observer = new ResizeObserver((entries) => {
|
||||||
|
const entry = entries[0];
|
||||||
|
if (!entry) return;
|
||||||
|
const { width, height } = entry.contentRect;
|
||||||
|
setSize({
|
||||||
|
width: Math.max(1, Math.round(width)),
|
||||||
|
height: Math.max(1, Math.round(height)),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
observer.observe(containerRef.current);
|
||||||
|
return () => observer.disconnect();
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener('resize', updateSize);
|
||||||
|
return () => window.removeEventListener('resize', updateSize);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const xValues = Array.isArray(overlay?.x_axis) && overlay.x_axis.length === overlay.line?.length
|
||||||
|
? overlay.x_axis
|
||||||
|
: overlay?.line?.map((_, i) => i) || [];
|
||||||
|
const yValues = Array.isArray(overlay?.line) ? overlay.line : [];
|
||||||
|
|
||||||
|
const width = size.width || 320;
|
||||||
|
const height = size.height || Math.round(width / ASPECT_RATIO);
|
||||||
|
const plotLeft = MARGINS.left;
|
||||||
|
const plotTop = MARGINS.top;
|
||||||
|
const plotWidth = Math.max(1, width - MARGINS.left - MARGINS.right);
|
||||||
|
const plotHeight = Math.max(1, height - MARGINS.top - MARGINS.bottom);
|
||||||
|
|
||||||
|
const [xMin, xMax] = getExtent(xValues, 0, 1);
|
||||||
|
const [yMinRaw, yMaxRaw] = getExtent(yValues, 0, 1);
|
||||||
|
const yPad = yMinRaw === yMaxRaw ? 1 : (yMaxRaw - yMinRaw) * 0.08;
|
||||||
|
const yMin = yMinRaw - yPad;
|
||||||
|
const yMax = yMaxRaw + yPad;
|
||||||
|
|
||||||
|
const scaleX = useCallback((value) => {
|
||||||
|
if (xMax === xMin) return plotLeft + plotWidth / 2;
|
||||||
|
return plotLeft + ((value - xMin) / (xMax - xMin)) * plotWidth;
|
||||||
|
}, [plotLeft, plotWidth, xMin, xMax]);
|
||||||
|
|
||||||
|
const scaleY = useCallback((value) => {
|
||||||
|
if (yMax === yMin) return plotTop + plotHeight / 2;
|
||||||
|
return plotTop + (1 - ((value - yMin) / (yMax - yMin))) * plotHeight;
|
||||||
|
}, [plotTop, plotHeight, yMin, yMax]);
|
||||||
|
|
||||||
|
const pickCursorPoint = useCallback((fraction) => {
|
||||||
|
if (!xValues.length || !yValues.length) {
|
||||||
|
return {
|
||||||
|
x: plotLeft,
|
||||||
|
y: plotTop + plotHeight / 2,
|
||||||
|
yFraction: 0.5,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const frac = clamp(fraction ?? 0.5, 0, 1);
|
||||||
|
const targetX = xMin + frac * (xMax - xMin || 1);
|
||||||
|
|
||||||
|
let idx = 0;
|
||||||
|
let best = Infinity;
|
||||||
|
for (let i = 0; i < xValues.length; i += 1) {
|
||||||
|
const dist = Math.abs(xValues[i] - targetX);
|
||||||
|
if (dist < best) {
|
||||||
|
best = dist;
|
||||||
|
idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const x = xValues[idx];
|
||||||
|
const y = yValues[idx];
|
||||||
|
const yFraction = yMax === yMin ? 0.5 : clamp((y - yMin) / (yMax - yMin), 0, 1);
|
||||||
|
return {
|
||||||
|
x: scaleX(x),
|
||||||
|
y: scaleY(y),
|
||||||
|
yFraction,
|
||||||
|
};
|
||||||
|
}, [plotLeft, plotTop, plotHeight, scaleX, scaleY, xValues, yValues, xMin, xMax, yMin, yMax]);
|
||||||
|
|
||||||
|
const cursorA = pickCursorPoint(x1 ?? overlay?.x1 ?? 0.25);
|
||||||
|
const cursorB = pickCursorPoint(x2 ?? overlay?.x2 ?? 0.75);
|
||||||
|
|
||||||
|
const path = yValues.map((y, i) => `${i === 0 ? 'M' : 'L'} ${scaleX(xValues[i])} ${scaleY(y)}`).join(' ');
|
||||||
|
const xTicks = makeTicks(xMin, xMax);
|
||||||
|
const yTicks = makeTicks(yMin, yMax);
|
||||||
|
const plotStroke = clamp(plotWidth / 240, 1.4, 2.6);
|
||||||
|
const gridStroke = clamp(plotWidth / 900, 0.6, 1.1);
|
||||||
|
const cursorStroke = clamp(plotWidth / 220, 1.4, 2.2);
|
||||||
|
const measureStroke = clamp(plotWidth / 180, 1.6, 2.8);
|
||||||
|
const markerRadius = clamp(plotWidth / 42, 5.5, 9);
|
||||||
|
|
||||||
|
const updateCursor = useCallback((point, event) => {
|
||||||
|
if (!interactive || !onWidgetChange || !nodeId) return;
|
||||||
|
if (!containerRef.current) return;
|
||||||
|
const rect = containerRef.current.getBoundingClientRect();
|
||||||
|
const xFrac = clamp((event.clientX - rect.left - plotLeft) / plotWidth, 0, 1);
|
||||||
|
const sample = pickCursorPoint(xFrac);
|
||||||
|
if (point === 'p1') {
|
||||||
|
onWidgetChange(nodeId, 'x1', round3(xFrac));
|
||||||
|
onWidgetChange(nodeId, 'y1', round3(sample.yFraction));
|
||||||
|
} else {
|
||||||
|
onWidgetChange(nodeId, 'x2', round3(xFrac));
|
||||||
|
onWidgetChange(nodeId, 'y2', round3(sample.yFraction));
|
||||||
|
}
|
||||||
|
}, [interactive, nodeId, onWidgetChange, pickCursorPoint, plotLeft, plotWidth]);
|
||||||
|
|
||||||
|
const onPointerDown = useCallback((point) => (event) => {
|
||||||
|
if (!interactive) return;
|
||||||
|
if ((point === 'p1' && aLocked) || (point === 'p2' && bLocked)) return;
|
||||||
|
event.preventDefault();
|
||||||
|
event.stopPropagation();
|
||||||
|
event.currentTarget.setPointerCapture(event.pointerId);
|
||||||
|
setDragging(point);
|
||||||
|
}, [interactive, aLocked, bLocked]);
|
||||||
|
|
||||||
|
const onPointerMove = useCallback((event) => {
|
||||||
|
if (!dragging) return;
|
||||||
|
updateCursor(dragging, event);
|
||||||
|
}, [dragging, updateCursor]);
|
||||||
|
|
||||||
|
const onPointerUp = useCallback(() => {
|
||||||
|
setDragging(null);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
ref={containerRef}
|
||||||
|
className="nodrag nowheel lineplot-overlay"
|
||||||
|
onPointerMove={onPointerMove}
|
||||||
|
onPointerUp={onPointerUp}
|
||||||
|
onLostPointerCapture={onPointerUp}
|
||||||
|
>
|
||||||
|
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} className="lineplot-svg">
|
||||||
|
<rect x="0" y="0" width={width} height={height} fill="#0f172a" />
|
||||||
|
|
||||||
|
{xTicks.map((tick) => {
|
||||||
|
const x = scaleX(tick);
|
||||||
|
return (
|
||||||
|
<g key={`x-${tick}`}>
|
||||||
|
<line x1={x} y1={plotTop} x2={x} y2={plotTop + plotHeight} stroke="#334155" strokeWidth={gridStroke} opacity="0.45" />
|
||||||
|
<text x={x} y={height - 10} textAnchor="middle" fontSize="11" fill="#94a3b8">
|
||||||
|
{formatTick(tick)}
|
||||||
|
</text>
|
||||||
|
</g>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
{yTicks.map((tick) => {
|
||||||
|
const y = scaleY(tick);
|
||||||
|
return (
|
||||||
|
<g key={`y-${tick}`}>
|
||||||
|
<line x1={plotLeft} y1={y} x2={plotLeft + plotWidth} y2={y} stroke="#334155" strokeWidth={gridStroke} opacity="0.45" />
|
||||||
|
<text x={plotLeft - 10} y={y + 4} textAnchor="end" fontSize="11" fill="#94a3b8">
|
||||||
|
{formatTick(tick)}
|
||||||
|
</text>
|
||||||
|
</g>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
<rect x={plotLeft} y={plotTop} width={plotWidth} height={plotHeight} fill="none" stroke="#334155" strokeWidth={gridStroke + 0.3} />
|
||||||
|
<path d={path} fill="none" stroke="#ff9800" strokeWidth={plotStroke} strokeLinecap="round" strokeLinejoin="round" />
|
||||||
|
|
||||||
|
{interactive && (
|
||||||
|
<>
|
||||||
|
<line x1={cursorA.x} y1={plotTop} x2={cursorA.x} y2={plotTop + plotHeight} stroke="#ffd700" strokeWidth={cursorStroke} strokeDasharray="10 6" opacity="0.95" />
|
||||||
|
<line x1={cursorB.x} y1={plotTop} x2={cursorB.x} y2={plotTop + plotHeight} stroke="#ffd700" strokeWidth={cursorStroke} strokeDasharray="10 6" opacity="0.95" />
|
||||||
|
<line x1={cursorA.x} y1={cursorA.y} x2={cursorB.x} y2={cursorB.y} stroke="#90caf9" strokeWidth={measureStroke} opacity="0.95" />
|
||||||
|
|
||||||
|
<circle
|
||||||
|
cx={cursorA.x}
|
||||||
|
cy={cursorA.y}
|
||||||
|
r={markerRadius}
|
||||||
|
className={`lineplot-marker ${aLocked ? 'lineplot-marker-locked' : ''}`}
|
||||||
|
onPointerDown={onPointerDown('p1')}
|
||||||
|
/>
|
||||||
|
<circle
|
||||||
|
cx={cursorB.x}
|
||||||
|
cy={cursorB.y}
|
||||||
|
r={markerRadius}
|
||||||
|
className={`lineplot-marker ${bLocked ? 'lineplot-marker-locked' : ''}`}
|
||||||
|
onPointerDown={onPointerDown('p2')}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -367,6 +367,41 @@ html, body, #root {
|
|||||||
opacity: 0.9;
|
opacity: 0.9;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.lineplot-overlay {
|
||||||
|
width: 100%;
|
||||||
|
aspect-ratio: 32 / 22;
|
||||||
|
background: #0f172a;
|
||||||
|
border: 1px solid #334155;
|
||||||
|
border-radius: 6px;
|
||||||
|
overflow: hidden;
|
||||||
|
user-select: none;
|
||||||
|
touch-action: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lineplot-svg {
|
||||||
|
display: block;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lineplot-marker {
|
||||||
|
fill: #ffd700;
|
||||||
|
stroke: #fff;
|
||||||
|
stroke-width: 2px;
|
||||||
|
cursor: grab;
|
||||||
|
filter: drop-shadow(0 0 4px rgba(0, 0, 0, 0.45));
|
||||||
|
}
|
||||||
|
|
||||||
|
.lineplot-marker:active {
|
||||||
|
cursor: grabbing;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lineplot-marker-locked {
|
||||||
|
fill: #e91e63;
|
||||||
|
stroke: #e91e63;
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
|
|
||||||
/* ── 3D surface view ──────────────────────────────────────────────── */
|
/* ── 3D surface view ──────────────────────────────────────────────── */
|
||||||
.surface-view-container {
|
.surface-view-container {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
|
|||||||
52
frontend/src/workflowHydration.js
Normal file
52
frontend/src/workflowHydration.js
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
function mergeDefinition(nodeData, defs) {
|
||||||
|
const savedData = nodeData || {};
|
||||||
|
const savedDefinition = savedData.definition && typeof savedData.definition === 'object'
|
||||||
|
? savedData.definition
|
||||||
|
: null;
|
||||||
|
const registryDefinition = savedData.className ? defs[savedData.className] : null;
|
||||||
|
const definition = registryDefinition || savedDefinition;
|
||||||
|
|
||||||
|
if (!definition) return null;
|
||||||
|
|
||||||
|
const output = Array.isArray(savedData.output)
|
||||||
|
? savedData.output
|
||||||
|
: (Array.isArray(savedDefinition?.output) ? savedDefinition.output : null);
|
||||||
|
const outputName = Array.isArray(savedData.output_name)
|
||||||
|
? savedData.output_name
|
||||||
|
: (Array.isArray(savedDefinition?.output_name) ? savedDefinition.output_name : null);
|
||||||
|
|
||||||
|
return {
|
||||||
|
...definition,
|
||||||
|
...(output ? { output } : {}),
|
||||||
|
...(outputName ? { output_name: outputName } : {}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function hydrateWorkflowState(data, defs = {}) {
|
||||||
|
const loadedNodes = Array.isArray(data?.nodes) ? data.nodes : [];
|
||||||
|
const loadedEdges = Array.isArray(data?.edges) ? data.edges : [];
|
||||||
|
|
||||||
|
const nodes = loadedNodes.map((node) => ({
|
||||||
|
...node,
|
||||||
|
type: node.type || 'custom',
|
||||||
|
dragHandle: node.dragHandle || '.drag-handle',
|
||||||
|
data: {
|
||||||
|
...node.data,
|
||||||
|
label: node.data?.label || node.data?.className || 'Node',
|
||||||
|
widgetValues: node.data?.widgetValues || {},
|
||||||
|
definition: mergeDefinition(node.data, defs),
|
||||||
|
previewImage: null,
|
||||||
|
tableRows: null,
|
||||||
|
meshData: null,
|
||||||
|
overlay: null,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
const nextNodeId = Math.max(0, ...loadedNodes.map((node) => parseInt(node.id, 10) || 0)) + 1;
|
||||||
|
|
||||||
|
return {
|
||||||
|
nodes,
|
||||||
|
edges: loadedEdges,
|
||||||
|
nextNodeId,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -10,6 +10,8 @@ export function serializeWorkflowState(nodes, edges) {
|
|||||||
label: node.data?.label || node.data?.className || 'Node',
|
label: node.data?.label || node.data?.className || 'Node',
|
||||||
className: node.data?.className || '',
|
className: node.data?.className || '',
|
||||||
widgetValues: node.data?.widgetValues || {},
|
widgetValues: node.data?.widgetValues || {},
|
||||||
|
output: node.data?.definition?.output || [],
|
||||||
|
output_name: node.data?.definition?.output_name || [],
|
||||||
},
|
},
|
||||||
})),
|
})),
|
||||||
edges: edges.map((edge) => ({
|
edges: edges.map((edge) => ({
|
||||||
@@ -18,7 +20,7 @@ export function serializeWorkflowState(nodes, edges) {
|
|||||||
sourceHandle: edge.sourceHandle,
|
sourceHandle: edge.sourceHandle,
|
||||||
target: edge.target,
|
target: edge.target,
|
||||||
targetHandle: edge.targetHandle,
|
targetHandle: edge.targetHandle,
|
||||||
style: edge.style,
|
...(edge.style ? { style: edge.style } : {}),
|
||||||
})),
|
})),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import test from 'node:test';
|
import test from 'node:test';
|
||||||
import assert from 'node:assert/strict';
|
import assert from 'node:assert/strict';
|
||||||
|
|
||||||
|
import { hydrateWorkflowState } from '../src/workflowHydration.js';
|
||||||
import { serializeWorkflowState } from '../src/workflowSerialization.js';
|
import { serializeWorkflowState } from '../src/workflowSerialization.js';
|
||||||
|
|
||||||
test('serializeWorkflowState keeps only stable workflow fields needed for reload', () => {
|
test('serializeWorkflowState keeps only stable workflow fields needed for reload', () => {
|
||||||
@@ -59,6 +60,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
|
|||||||
label: 'Demo Label',
|
label: 'Demo Label',
|
||||||
className: 'DemoNode',
|
className: 'DemoNode',
|
||||||
widgetValues: { threshold: 0.42, mode: 'fast' },
|
widgetValues: { threshold: 0.42, mode: 'fast' },
|
||||||
|
output: [],
|
||||||
|
output_name: [],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -70,6 +73,8 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
|
|||||||
label: 'NoLabelNode',
|
label: 'NoLabelNode',
|
||||||
className: 'NoLabelNode',
|
className: 'NoLabelNode',
|
||||||
widgetValues: {},
|
widgetValues: {},
|
||||||
|
output: [],
|
||||||
|
output_name: [],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -89,3 +94,99 @@ test('serializeWorkflowState keeps only stable workflow fields needed for reload
|
|||||||
assert.equal('previewImage' in serialized.nodes[0].data, false);
|
assert.equal('previewImage' in serialized.nodes[0].data, false);
|
||||||
assert.equal('selected' in serialized.edges[0], false);
|
assert.equal('selected' in serialized.edges[0], false);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('hydrateWorkflowState restores saved dynamic outputs on top of current node definitions', () => {
|
||||||
|
const saved = {
|
||||||
|
version: 1,
|
||||||
|
nodes: [
|
||||||
|
{
|
||||||
|
id: '12',
|
||||||
|
position: { x: 40, y: 80 },
|
||||||
|
data: {
|
||||||
|
className: 'LoadFile',
|
||||||
|
widgetValues: { filename: 'scan.ibw', colormap: 'viridis' },
|
||||||
|
output: ['DATA_FIELD', 'DATA_FIELD'],
|
||||||
|
output_name: ['Height', 'Phase'],
|
||||||
|
previewImage: 'stale',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
id: 'e12-3',
|
||||||
|
source: '12',
|
||||||
|
sourceHandle: 'output::1::DATA_FIELD',
|
||||||
|
target: '3',
|
||||||
|
targetHandle: 'input::field::DATA_FIELD',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
const defs = {
|
||||||
|
LoadFile: {
|
||||||
|
category: 'io',
|
||||||
|
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['viridis', 'gray'], {}] } },
|
||||||
|
output: ['DATA_FIELD'],
|
||||||
|
output_name: ['field'],
|
||||||
|
manual_trigger: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const hydrated = hydrateWorkflowState(saved, defs);
|
||||||
|
|
||||||
|
assert.equal(hydrated.nextNodeId, 13);
|
||||||
|
assert.deepEqual(hydrated.edges, saved.edges);
|
||||||
|
assert.equal(hydrated.nodes[0].type, 'custom');
|
||||||
|
assert.equal(hydrated.nodes[0].dragHandle, '.drag-handle');
|
||||||
|
assert.equal(hydrated.nodes[0].data.label, 'LoadFile');
|
||||||
|
assert.equal(hydrated.nodes[0].data.previewImage, null);
|
||||||
|
assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']);
|
||||||
|
assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']);
|
||||||
|
assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.LoadFile.input);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('serializeWorkflowState and hydrateWorkflowState preserve reload-critical metadata for dynamic nodes', () => {
|
||||||
|
const nodes = [
|
||||||
|
{
|
||||||
|
id: '7',
|
||||||
|
position: { x: 10, y: 20 },
|
||||||
|
data: {
|
||||||
|
label: 'Load File',
|
||||||
|
className: 'LoadFile',
|
||||||
|
widgetValues: { filename: 'scan.gwy', colormap: 'gray' },
|
||||||
|
definition: {
|
||||||
|
category: 'io',
|
||||||
|
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } },
|
||||||
|
output: ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD'],
|
||||||
|
output_name: ['Topography', 'Error', 'Mask'],
|
||||||
|
},
|
||||||
|
previewImage: 'data:image/png;base64,stale',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const edges = [
|
||||||
|
{
|
||||||
|
id: 'e7-9',
|
||||||
|
source: '7',
|
||||||
|
sourceHandle: 'output::2::DATA_FIELD',
|
||||||
|
target: '9',
|
||||||
|
targetHandle: 'input::field::DATA_FIELD',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const defs = {
|
||||||
|
LoadFile: {
|
||||||
|
category: 'io',
|
||||||
|
input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } },
|
||||||
|
output: ['DATA_FIELD'],
|
||||||
|
output_name: ['field'],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const serialized = serializeWorkflowState(nodes, edges);
|
||||||
|
const hydrated = hydrateWorkflowState(serialized, defs);
|
||||||
|
|
||||||
|
assert.deepEqual(hydrated.nodes[0].data.widgetValues, nodes[0].data.widgetValues);
|
||||||
|
assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD', 'DATA_FIELD']);
|
||||||
|
assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Topography', 'Error', 'Mask']);
|
||||||
|
assert.deepEqual(hydrated.edges, edges);
|
||||||
|
});
|
||||||
|
|||||||
@@ -564,31 +564,57 @@ def test_load_file():
|
|||||||
|
|
||||||
|
|
||||||
def test_save_image():
|
def test_save_image():
|
||||||
print("=== Test: SaveImage ===")
|
print("=== Test: SaveImage (Save Layers) ===")
|
||||||
from backend.nodes.io import SaveImage
|
from backend.nodes.io import SaveImage
|
||||||
node = SaveImage()
|
node = SaveImage()
|
||||||
|
|
||||||
|
field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
|
||||||
|
field_b = make_field(data=np.random.default_rng(5).random((32, 32)))
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
# Monkey-patch OUTPUT_DIR for testing
|
# Save single layer as TIFF
|
||||||
from pathlib import Path
|
tiff_path = os.path.join(tmpdir, "out.tiff")
|
||||||
import backend.nodes.io as io_mod
|
node.save(filename=tiff_path, format="TIFF", field_0=field_a)
|
||||||
orig_dir = io_mod.OUTPUT_DIR
|
assert os.path.exists(tiff_path), "TIFF file not created"
|
||||||
io_mod.OUTPUT_DIR = Path(tmpdir)
|
from PIL import Image
|
||||||
|
im = Image.open(tiff_path)
|
||||||
|
assert im.n_frames == 1
|
||||||
|
arr_back = np.array(im)
|
||||||
|
assert arr_back.shape == (32, 32)
|
||||||
|
|
||||||
|
# Save multi-layer as TIFF
|
||||||
|
tiff_path2 = os.path.join(tmpdir, "multi.tiff")
|
||||||
|
node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b)
|
||||||
|
im2 = Image.open(tiff_path2)
|
||||||
|
assert im2.n_frames == 2
|
||||||
|
|
||||||
|
# Save as NPZ
|
||||||
|
npz_path = os.path.join(tmpdir, "out.npz")
|
||||||
|
node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b)
|
||||||
|
assert os.path.exists(npz_path)
|
||||||
|
npz = np.load(npz_path)
|
||||||
|
assert len(npz.files) == 2
|
||||||
|
assert np.allclose(npz["layer_0"], field_a.data)
|
||||||
|
assert np.allclose(npz["layer_1"], field_b.data)
|
||||||
|
|
||||||
|
# Extension is forced to match format
|
||||||
|
wrong_ext = os.path.join(tmpdir, "output.png")
|
||||||
|
node.save(filename=wrong_ext, format="TIFF", field_0=field_a)
|
||||||
|
assert os.path.exists(os.path.join(tmpdir, "output.tiff"))
|
||||||
|
|
||||||
|
# No fields connected → error
|
||||||
try:
|
try:
|
||||||
arr = np.random.default_rng(4).integers(0, 256, (32, 32), dtype=np.uint8)
|
node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF")
|
||||||
|
assert False, "Should have raised ValueError"
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Save as PNG
|
# No filename → error
|
||||||
node.save(image=arr, filename_prefix="test", format="PNG")
|
try:
|
||||||
saved = os.listdir(tmpdir)
|
node.save(filename="", format="TIFF", field_0=field_a)
|
||||||
assert any(f.endswith(".png") for f in saved), f"No PNG file found in {saved}"
|
assert False, "Should have raised ValueError"
|
||||||
|
except ValueError:
|
||||||
# Save as NPY
|
pass
|
||||||
node.save(image=arr.astype(np.float64), filename_prefix="test", format="NPY")
|
|
||||||
saved = os.listdir(tmpdir)
|
|
||||||
assert any(f.endswith(".npy") for f in saved), f"No NPY file found in {saved}"
|
|
||||||
finally:
|
|
||||||
io_mod.OUTPUT_DIR = orig_dir
|
|
||||||
|
|
||||||
print(" PASS\n")
|
print(" PASS\n")
|
||||||
|
|
||||||
@@ -896,8 +922,11 @@ def test_line_cursors():
|
|||||||
|
|
||||||
# Overlay should have been broadcast
|
# Overlay should have been broadcast
|
||||||
assert len(overlays) == 1
|
assert len(overlays) == 1
|
||||||
assert "image" in overlays[0]
|
assert overlays[0]["kind"] == "line_plot"
|
||||||
assert overlays[0]["image"].startswith("data:image/png;base64,")
|
assert len(overlays[0]["line"]) == len(line)
|
||||||
|
assert len(overlays[0]["x_axis"]) == len(line)
|
||||||
|
assert 0.0 <= overlays[0]["x1"] <= 1.0
|
||||||
|
assert 0.0 <= overlays[0]["x2"] <= 1.0
|
||||||
|
|
||||||
# With x_axis provided
|
# With x_axis provided
|
||||||
x_axis = np.linspace(0, 1, 100).astype(np.float64)
|
x_axis = np.linspace(0, 1, 100).astype(np.float64)
|
||||||
|
|||||||
20
tests/test_workflow_save.py
Normal file
20
tests/test_workflow_save.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.server import PNG_SIGNATURE, save_png_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_png_bytes_writes_exact_png_payload(tmp_path: Path):
|
||||||
|
target = tmp_path / "workflow"
|
||||||
|
payload = PNG_SIGNATURE + b"argonode-test-payload"
|
||||||
|
|
||||||
|
saved_path = save_png_bytes(str(target), payload)
|
||||||
|
|
||||||
|
assert saved_path == tmp_path / "workflow.png"
|
||||||
|
assert saved_path.read_bytes() == payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_png_bytes_rejects_invalid_payload(tmp_path: Path):
|
||||||
|
with pytest.raises(ValueError, match="valid PNG"):
|
||||||
|
save_png_bytes(str(tmp_path / "workflow.png"), b"not-a-png")
|
||||||
Reference in New Issue
Block a user