fix preview and save on native

This commit is contained in:
2026-03-24 22:52:24 -07:00
parent a60b0c15ca
commit 6959c62c8f
16 changed files with 875 additions and 202 deletions

View File

@@ -11,7 +11,7 @@ Gwyddion equivalents:
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
from backend.data_types import DataField, datafield_to_uint8, encode_preview
# ---------------------------------------------------------------------------
@@ -131,88 +131,49 @@ class LineCursors:
self, line, x1: float, y1: float, x2: float, y2: float,
x_axis=None,
) -> tuple:
import io as _io
import base64
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
y = np.asarray(line, dtype=np.float64).ravel()
n = len(y)
if x_axis is not None:
x = np.asarray(x_axis, dtype=np.float64).ravel()[:n]
else:
x = np.arange(n, dtype=np.float64)
x1 = float(np.clip(x1, 0.0, 1.0))
x2 = float(np.clip(x2, 0.0, 1.0))
# --- Render the base plot first to determine axes bounds ---
fig, ax = plt.subplots(figsize=(3.2, 2.2), dpi=100)
fig.patch.set_facecolor("#1e293b")
ax.set_facecolor("#0f172a")
ax.plot(x, y, color="#ff9800", linewidth=1.2)
ax.tick_params(colors="#94a3b8", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#334155")
ax.grid(True, color="#334155", linewidth=0.3, alpha=0.5)
fig.tight_layout(pad=0.4)
xmin = float(np.min(x)) if len(x) else 0.0
xmax = float(np.max(x)) if len(x) else 1.0
# Force a draw so transforms are valid
fig.canvas.draw()
def x_frac_to_idx(frac):
if n <= 1:
return 0
if xmax == xmin:
return 0
target_x = xmin + frac * (xmax - xmin)
return int(np.argmin(np.abs(x - target_x)))
# Get axes position in figure-fraction coordinates
ax_pos = ax.get_position()
ax_l, ax_b = ax_pos.x0, ax_pos.y0
ax_w, ax_h = ax_pos.width, ax_pos.height
# x1/y1 arrive as image-fraction from the frontend drag.
# Convert image-fraction x → axes-fraction → nearest data index.
def img_x_to_idx(ix):
axes_frac = np.clip((ix - ax_l) / ax_w, 0, 1)
return int(np.clip(round(axes_frac * (n - 1)), 0, n - 1))
idx_a = img_x_to_idx(x1)
idx_b = img_x_to_idx(x2)
idx_a = x_frac_to_idx(x1)
idx_b = x_frac_to_idx(x2)
xa, ya = float(x[idx_a]), float(y[idx_a])
xb, yb = float(x[idx_b]), float(y[idx_b])
# --- Draw cursor lines and markers on the plot ---
ax.axvline(xa, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.axvline(xb, color="#ffd700", linewidth=1.5, linestyle="--", alpha=0.9)
ax.plot(xa, ya, "o", color="#ffd700", markersize=6, zorder=5)
ax.plot(xb, yb, "o", color="#ffd700", markersize=6, zorder=5)
ax.annotate(
"", xy=(xb, yb), xytext=(xa, ya),
arrowprops=dict(arrowstyle="<->", color="#90caf9", lw=1.5),
)
# --- Broadcast overlay ---
if LineCursors._broadcast_overlay_fn is not None:
# Convert data-space positions back to image-fraction for markers
fig.canvas.draw()
inv = fig.transFigure.inverted()
fig_a = inv.transform(ax.transData.transform([xa, ya]))
fig_b = inv.transform(ax.transData.transform([xb, yb]))
buf = _io.BytesIO()
fig.savefig(buf, format="png", facecolor=fig.get_facecolor())
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
LineCursors._broadcast_overlay_fn(
LineCursors._current_node_id,
{
"image": image_uri,
"x1": float(fig_a[0]),
"y1": float(1.0 - fig_a[1]), # flip: image y=0 is top
"x2": float(fig_b[0]),
"y2": float(1.0 - fig_b[1]),
"kind": "line_plot",
"line": y.tolist(),
"x_axis": x.tolist(),
"x1": x1,
"x2": x2,
"y1": float(y1),
"y2": float(y2),
"a_locked": False,
"b_locked": False,
},
)
plt.close(fig)
# --- Output table ---
table = [
{"quantity": "A position", "value": xa, "unit": ""},
@@ -414,8 +375,6 @@ class CrossSection:
point_a=None, point_b=None,
) -> tuple:
from scipy.ndimage import map_coordinates
import io, base64
from matplotlib.figure import Figure
# COORD inputs override widget values
if point_a is not None:
@@ -453,14 +412,9 @@ class CrossSection:
# Broadcast overlay image with marker positions
if CrossSection._broadcast_overlay_fn is not None:
fig = Figure(figsize=(3, 3), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])
ax.imshow(field.data, cmap="viridis", aspect="auto")
ax.axis("off")
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
image_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
# Use the field's native pixel grid for the overlay preview so enlarging
# the panel keeps the image as sharp as the source data allows.
image_uri = encode_preview(datafield_to_uint8(field, field.colormap))
CrossSection._broadcast_overlay_fn(
CrossSection._current_node_id,

View File

@@ -78,7 +78,7 @@ class View3D:
"required": {
"field": ("DATA_FIELD",),
"colormap": (["auto"] + list(COLORMAPS),),
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
"z_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10.0, "step": 0.05}),
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
}
}
@@ -134,7 +134,7 @@ class View3D:
"colors": colors_b64,
"z_min": zmin,
"z_max": zmax,
"z_scale": float(z_scale),
"z_scale": float(z_scale * 0.1),
"x_range": [float(field.xoff), float(field.xoff + field.xreal)],
"y_range": [float(field.yoff), float(field.yoff + field.yreal)],
}

View File

@@ -399,54 +399,86 @@ class Coordinate:
# SaveImage
# ---------------------------------------------------------------------------
@register_node(display_name="Save Image")
_MAX_SAVE_FIELDS = 8
@register_node(display_name="Save Layers")
class SaveImage:
@classmethod
def INPUT_TYPES(cls):
optional = {}
for i in range(_MAX_SAVE_FIELDS):
optional[f"field_{i}"] = ("DATA_FIELD",)
return {
"required": {
"image": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "output"}),
"format": (["PNG", "TIFF", "NPY"],),
}
"filename": ("FILE_PICKER", {"default": ""}),
"format": (["TIFF", "NPZ"],),
},
"optional": optional,
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "io"
OUTPUT_NODE = True
DESCRIPTION = "Save an image or array to the output folder."
MANUAL_TRIGGER = True
DESCRIPTION = (
"Save one or more DATA_FIELD layers to a single file. "
"Connect fields to the inputs — a new slot appears as each is filled. "
"TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
"Click Save to write (does not auto-run)."
)
# Injected by server.py before execution begins
_broadcast_preview = None
_broadcast_warning_fn = None
_current_node_id = None
def save(self, image: np.ndarray, filename_prefix: str = "output", format: str = "PNG"):
OUTPUT_DIR.mkdir(exist_ok=True)
def save(self, filename: str, format: str = "TIFF", **kwargs):
# Collect connected fields in order
fields = []
for i in range(_MAX_SAVE_FIELDS):
f = kwargs.get(f"field_{i}")
if f is not None:
fields.append(f)
# Find next available filename
idx = 1
while True:
name = f"{filename_prefix}_{idx:04d}"
candidate = OUTPUT_DIR / f"{name}.{format.lower()}"
if not candidate.exists():
break
idx += 1
if not fields:
raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
if format == "NPY":
np.save(str(OUTPUT_DIR / f"{name}.npy"), image)
if not filename or not filename.strip():
raise ValueError("No output path selected — use Browse to pick a location.")
path = Path(filename)
# Ensure parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)
# Force correct extension
ext = ".tiff" if format == "TIFF" else ".npz"
if path.suffix.lower() != ext:
path = path.with_suffix(ext)
if format == "TIFF":
self._save_tiff(path, fields)
else:
from PIL import Image
arr = image_to_uint8(image)
if arr.ndim == 2:
pil_img = Image.fromarray(arr, mode="L")
else:
pil_img = Image.fromarray(arr, mode="RGB")
pil_img.save(str(OUTPUT_DIR / f"{name}.{format.lower()}"))
self._save_npz(path, fields)
# Emit preview over WebSocket if callback is set
if SaveImage._broadcast_preview is not None:
arr_u8 = image_to_uint8(image)
data_uri = encode_preview(arr_u8)
SaveImage._broadcast_preview(data_uri)
self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
return ()
def _save_tiff(self, path: Path, fields: list[DataField]):
from PIL import Image
images = []
for f in fields:
images.append(Image.fromarray(f.data.astype(np.float32)))
images[0].save(str(path), save_all=True, append_images=images[1:])
def _save_npz(self, path: Path, fields: list[DataField]):
arrays = {}
for i, f in enumerate(fields):
arrays[f"layer_{i}"] = f.data
np.savez(str(path), **arrays)
def _send_warning(self, message: str):
fn = SaveImage._broadcast_warning_fn
nid = SaveImage._current_node_id
if fn and nid:
fn(nid, message)
return ()