dimensioned export (gwy, HDF5)
This commit is contained in:
@@ -1,26 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.execution_context import emit_warning, emit_file_download
|
||||
from backend.data_types import (
|
||||
DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8,
|
||||
_SI_PREFIXES, _PREFIXABLE_UNITS,
|
||||
from backend.exporters import (
|
||||
available_formats,
|
||||
get_exporter,
|
||||
resolve_path,
|
||||
type_name_for_value,
|
||||
)
|
||||
|
||||
DOWNLOAD_DIR = Path(tempfile.gettempdir()) / "tono-downloads"
|
||||
|
||||
|
||||
def _choices_by_source_type() -> dict[str, list[str]]:
|
||||
"""Build the format dropdown's source-type map from the exporter registry.
|
||||
|
||||
Centralising this here means adding a new exporter module (or a new format
|
||||
inside an existing one) automatically surfaces in the UI — no parallel
|
||||
list to keep in sync.
|
||||
"""
|
||||
return {
|
||||
"DATA_FIELD": available_formats("DATA_FIELD"),
|
||||
"IMAGE": available_formats("IMAGE"),
|
||||
"ANNOTATION_SOURCE": available_formats("ANNOTATION_SOURCE"),
|
||||
"LINE": available_formats("LINE"),
|
||||
"RECORD_TABLE": available_formats("RECORD_TABLE"),
|
||||
"DATA_TABLE": available_formats("DATA_TABLE"),
|
||||
"FLOAT": available_formats("FLOAT"),
|
||||
"MESH_MODEL": available_formats("MESH_MODEL"),
|
||||
}
|
||||
|
||||
|
||||
@register_node(display_name="Save")
|
||||
class Save:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
choices = _choices_by_source_type()
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("STRING", {
|
||||
@@ -41,17 +59,8 @@ class Save:
|
||||
],
|
||||
}),
|
||||
"format": ("STRING", {
|
||||
"default": "TIFF",
|
||||
"choices_by_source_type": {
|
||||
"DATA_FIELD": ["TIFF", "PNG", "NPZ"],
|
||||
"IMAGE": ["PNG", "TIFF", "NPZ"],
|
||||
"ANNOTATION_SOURCE": ["PNG", "TIFF", "NPZ"],
|
||||
"LINE": ["PNG", "TIFF", "CSV", "NPZ", "JSON"],
|
||||
"RECORD_TABLE": ["CSV", "JSON"],
|
||||
"DATA_TABLE": ["CSV", "JSON"],
|
||||
"FLOAT": ["TXT", "JSON"],
|
||||
"MESH_MODEL": ["OBJ", "STL"],
|
||||
},
|
||||
"default": choices["DATA_FIELD"][0] if choices["DATA_FIELD"] else "",
|
||||
"choices_by_source_type": choices,
|
||||
"source_type_input": "value",
|
||||
}),
|
||||
},
|
||||
@@ -71,10 +80,12 @@ class Save:
|
||||
OUTPUT_NODE = True
|
||||
MANUAL_TRIGGER = True
|
||||
DESCRIPTION = (
|
||||
"Save a single graph value to disk. Supports fields, images, lines, tables, scalars, and 3D meshes."
|
||||
"Save a single graph value to disk. Supports fields, images, lines, tables, scalars, "
|
||||
"and 3D meshes. Use 'GWY' or 'TIFF (data)' for DataFields you want to re-open later "
|
||||
"with their physical units preserved."
|
||||
)
|
||||
|
||||
KEYWORDS = ("export", "write", "download", "png", "tiff", "csv", "json", "npz", "obj", "stl")
|
||||
KEYWORDS = ("export", "write", "download", "png", "tiff", "csv", "json", "npz", "obj", "stl", "gwy")
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -83,295 +94,11 @@ class Save:
|
||||
value,
|
||||
plot_title: str = "",
|
||||
):
|
||||
path = self._resolve_save_path(filename, format)
|
||||
|
||||
if isinstance(value, MeshModel):
|
||||
self._save_mesh(path, value, format)
|
||||
elif isinstance(value, DataField):
|
||||
self._save_datafield(path, value, format)
|
||||
elif isinstance(value, np.ndarray):
|
||||
if value.ndim == 1:
|
||||
self._save_line(path, LineData(data=value), format, title=plot_title)
|
||||
else:
|
||||
self._save_image_or_array(path, value, format)
|
||||
elif isinstance(value, LineData):
|
||||
self._save_line(path, value, format, title=plot_title)
|
||||
elif isinstance(value, list):
|
||||
self._save_table(path, value, format)
|
||||
elif isinstance(value, (int, float, np.floating, np.integer)):
|
||||
self._save_scalar(path, float(value), format)
|
||||
else:
|
||||
raise ValueError(f"Save does not support input type: {type(value).__name__}")
|
||||
type_name = type_name_for_value(value)
|
||||
module, spec = get_exporter(type_name, format)
|
||||
path = resolve_path(filename, spec, DOWNLOAD_DIR)
|
||||
module.save(path, value, format, plot_title=plot_title)
|
||||
|
||||
emit_warning(f"Saved to {path.name}")
|
||||
emit_file_download(str(path))
|
||||
return ()
|
||||
|
||||
def _resolve_save_path(self, filename: str, format_name: str) -> Path:
|
||||
ext_map = {
|
||||
"PNG": ".png",
|
||||
"TIFF": ".tiff",
|
||||
"NPZ": ".npz",
|
||||
"CSV": ".csv",
|
||||
"JSON": ".json",
|
||||
"OBJ": ".obj",
|
||||
"STL": ".stl",
|
||||
"TXT": ".txt",
|
||||
}
|
||||
ext = ext_map[format_name]
|
||||
|
||||
raw_filename = str(filename).strip() if filename is not None else ""
|
||||
if not raw_filename:
|
||||
raise ValueError("No output filename selected — enter a file name.")
|
||||
|
||||
candidate = Path(raw_filename).expanduser()
|
||||
if candidate.is_absolute():
|
||||
candidate.parent.mkdir(parents=True, exist_ok=True)
|
||||
path = candidate
|
||||
else:
|
||||
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = DOWNLOAD_DIR / candidate.name
|
||||
|
||||
if path.suffix.lower() != ext:
|
||||
path = path.with_suffix(ext)
|
||||
return path
|
||||
|
||||
def _save_datafield(self, path: Path, field: DataField, format_name: str):
|
||||
if format_name == "TIFF":
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), datafield_to_uint8(field, field.colormap))
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), field=np.asarray(field.data))
|
||||
return
|
||||
if format_name == "PNG":
|
||||
from PIL import Image
|
||||
Image.fromarray(datafield_to_uint8(field, field.colormap)).save(str(path))
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for DATA_FIELD.")
|
||||
|
||||
def _save_image_or_array(self, path: Path, image: np.ndarray, format_name: str):
|
||||
arr = np.asarray(image)
|
||||
if format_name == "PNG":
|
||||
from PIL import Image
|
||||
Image.fromarray(image_to_uint8(arr)).save(str(path))
|
||||
return
|
||||
if format_name == "TIFF":
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), image_to_uint8(arr))
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), image=arr)
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for IMAGE.")
|
||||
|
||||
def _save_line(self, path: Path, line: LineData, format_name: str, title: str = ""):
|
||||
y = np.asarray(line.data, dtype=np.float64).ravel()
|
||||
x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)] if line.x_axis is not None else np.arange(len(y), dtype=np.float64)
|
||||
if format_name in ("PNG", "TIFF"):
|
||||
self._save_line_plot(path, x, y, line.x_unit, line.y_unit, title, format_name)
|
||||
return
|
||||
if format_name == "CSV":
|
||||
with path.open("w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.writer(fh)
|
||||
writer.writerow(["x", "y", "x_unit", "y_unit"])
|
||||
for xv, yv in zip(x, y):
|
||||
writer.writerow([xv, yv, line.x_unit, line.y_unit])
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), x=x, y=y)
|
||||
return
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps({
|
||||
"x": x.tolist(),
|
||||
"y": y.tolist(),
|
||||
"x_unit": line.x_unit,
|
||||
"y_unit": line.y_unit,
|
||||
}, indent=2), encoding="utf-8")
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for LINE.")
|
||||
|
||||
def _save_line_plot(
|
||||
self,
|
||||
path: Path,
|
||||
x: np.ndarray,
|
||||
y: np.ndarray,
|
||||
x_unit: str,
|
||||
y_unit: str,
|
||||
title: str,
|
||||
format_name: str,
|
||||
):
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
w, h = 1200, 750
|
||||
bg = (255, 255, 255)
|
||||
line_color = (79, 142, 247) # #4f8ef7
|
||||
grid_color = (200, 200, 200)
|
||||
text_color = (60, 60, 60)
|
||||
margin = {"left": 80, "right": 30, "top": 50, "bottom": 60}
|
||||
|
||||
img = Image.new("RGB", (w, h), bg)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("DejaVuSans.ttf", 14)
|
||||
font_small = ImageFont.truetype("DejaVuSans.ttf", 11)
|
||||
font_title = ImageFont.truetype("DejaVuSans.ttf", 16)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
font_small = font
|
||||
font_title = font
|
||||
|
||||
pw = w - margin["left"] - margin["right"]
|
||||
ph = h - margin["top"] - margin["bottom"]
|
||||
|
||||
def _si_scale(unit: str, vmin: float, vmax: float) -> tuple[float, str]:
|
||||
"""Pick the best SI prefix for an axis range. Returns (divisor, prefixed_unit)."""
|
||||
unit = (unit or "").strip()
|
||||
if not unit or unit not in _PREFIXABLE_UNITS:
|
||||
return 1.0, unit if unit else ""
|
||||
peak = max(abs(vmin), abs(vmax))
|
||||
if peak == 0:
|
||||
return 1.0, unit
|
||||
for scale, prefix in _SI_PREFIXES:
|
||||
if peak / scale >= 1.0:
|
||||
return scale, f"{prefix}{unit}"
|
||||
return _SI_PREFIXES[-1][0], f"{_SI_PREFIXES[-1][1]}{unit}"
|
||||
|
||||
xmin, xmax = float(np.nanmin(x)), float(np.nanmax(x))
|
||||
ymin, ymax = float(np.nanmin(y)), float(np.nanmax(y))
|
||||
|
||||
x_scale, x_label = _si_scale(x_unit, xmin, xmax)
|
||||
y_scale, y_label = _si_scale(y_unit, ymin, ymax)
|
||||
if not x_label:
|
||||
x_label = "x"
|
||||
if not y_label:
|
||||
y_label = "y"
|
||||
|
||||
# Scale data into prefixed units
|
||||
x = x / x_scale
|
||||
y = y / y_scale
|
||||
xmin, xmax = xmin / x_scale, xmax / x_scale
|
||||
ymin, ymax = ymin / y_scale, ymax / y_scale
|
||||
|
||||
if ymax == ymin:
|
||||
ymin, ymax = ymin - 1, ymax + 1
|
||||
if xmax == xmin:
|
||||
xmax = xmin + 1
|
||||
# Add 5% padding to y range
|
||||
ypad = (ymax - ymin) * 0.05
|
||||
ymin -= ypad
|
||||
ymax += ypad
|
||||
|
||||
def to_px(xv: float, yv: float) -> tuple[float, float]:
|
||||
px = margin["left"] + (xv - xmin) / (xmax - xmin) * pw
|
||||
py = margin["top"] + (1.0 - (yv - ymin) / (ymax - ymin)) * ph
|
||||
return px, py
|
||||
|
||||
# Grid lines (5 horizontal, 5 vertical)
|
||||
for i in range(6):
|
||||
gy = ymin + (ymax - ymin) * i / 5
|
||||
_, py = to_px(xmin, gy)
|
||||
draw.line([(margin["left"], py), (margin["left"] + pw, py)], fill=grid_color, width=1)
|
||||
label = f"{gy:.4g}"
|
||||
draw.text((margin["left"] - 8, py - 6), label, fill=text_color, font=font_small, anchor="rm")
|
||||
|
||||
gx = xmin + (xmax - xmin) * i / 5
|
||||
px, _ = to_px(gx, ymin)
|
||||
draw.line([(px, margin["top"]), (px, margin["top"] + ph)], fill=grid_color, width=1)
|
||||
label = f"{gx:.4g}"
|
||||
draw.text((px, margin["top"] + ph + 6), label, fill=text_color, font=font_small, anchor="mt")
|
||||
|
||||
# Plot line
|
||||
n = len(y)
|
||||
step = max(1, n // pw)
|
||||
xs, ys = x[::step], y[::step]
|
||||
pts = [to_px(float(xs[i]), float(ys[i])) for i in range(len(xs))]
|
||||
if len(pts) > 1:
|
||||
draw.line(pts, fill=line_color, width=2)
|
||||
|
||||
# Border
|
||||
draw.rectangle(
|
||||
[margin["left"], margin["top"], margin["left"] + pw, margin["top"] + ph],
|
||||
outline=(100, 100, 100), width=1,
|
||||
)
|
||||
draw.text((margin["left"] + pw // 2, h - 10), x_label, fill=text_color, font=font, anchor="mb")
|
||||
# Vertical y label — draw rotated
|
||||
y_label_img = Image.new("RGBA", (200, 20), (0, 0, 0, 0))
|
||||
y_draw = ImageDraw.Draw(y_label_img)
|
||||
y_draw.text((100, 10), y_label, fill=text_color, font=font, anchor="mm")
|
||||
y_label_img = y_label_img.rotate(90, expand=True)
|
||||
img.paste(y_label_img, (2, margin["top"] + ph // 2 - y_label_img.height // 2), y_label_img)
|
||||
|
||||
# Title
|
||||
if title and title.strip():
|
||||
draw.text((w // 2, 10), title.strip(), fill=text_color, font=font_title, anchor="mt")
|
||||
|
||||
ext = ".png" if format_name == "PNG" else ".tiff"
|
||||
img.save(str(path.with_suffix(ext)))
|
||||
|
||||
def _save_table(self, path: Path, rows: list, format_name: str):
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps(rows, indent=2), encoding="utf-8")
|
||||
return
|
||||
if format_name == "CSV":
|
||||
columns: list[str] = []
|
||||
for row in rows:
|
||||
if isinstance(row, dict):
|
||||
for key in row.keys():
|
||||
if key not in columns:
|
||||
columns.append(str(key))
|
||||
with path.open("w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.DictWriter(fh, fieldnames=columns)
|
||||
writer.writeheader()
|
||||
for row in rows:
|
||||
writer.writerow(row if isinstance(row, dict) else {"value": row})
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for table inputs.")
|
||||
|
||||
def _save_scalar(self, path: Path, value: float, format_name: str):
|
||||
if format_name == "TXT":
|
||||
path.write_text(f"{value}\n", encoding="utf-8")
|
||||
return
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps({"value": value}, indent=2), encoding="utf-8")
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for scalar values.")
|
||||
|
||||
def _save_mesh(self, path: Path, mesh: MeshModel, format_name: str):
|
||||
if format_name == "OBJ":
|
||||
self._save_obj(path, mesh)
|
||||
return
|
||||
if format_name == "STL":
|
||||
self._save_stl(path, mesh)
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for MESH_MODEL.")
|
||||
|
||||
def _save_obj(self, path: Path, mesh: MeshModel):
|
||||
lines = []
|
||||
for vertex in mesh.vertices:
|
||||
lines.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}")
|
||||
for face in mesh.faces:
|
||||
lines.append(f"f {int(face[0]) + 1} {int(face[1]) + 1} {int(face[2]) + 1}")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
def _save_stl(self, path: Path, mesh: MeshModel):
|
||||
def normal(a, b, c):
|
||||
n = np.cross(b - a, c - a)
|
||||
length = float(np.linalg.norm(n))
|
||||
return n / length if length > 0 else np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||
|
||||
lines = ["solid tono"]
|
||||
vertices = np.asarray(mesh.vertices, dtype=np.float32)
|
||||
for face in np.asarray(mesh.faces, dtype=np.int32):
|
||||
a, b, c = vertices[int(face[0])], vertices[int(face[1])], vertices[int(face[2])]
|
||||
n = normal(a, b, c)
|
||||
lines.append(f" facet normal {n[0]} {n[1]} {n[2]}")
|
||||
lines.append(" outer loop")
|
||||
lines.append(f" vertex {a[0]} {a[1]} {a[2]}")
|
||||
lines.append(f" vertex {b[0]} {b[1]} {b[2]}")
|
||||
lines.append(f" vertex {c[0]} {c[1]} {c[2]}")
|
||||
lines.append(" endloop")
|
||||
lines.append(" endfacet")
|
||||
lines.append("endsolid tono")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
Reference in New Issue
Block a user