dimensioned export (gwy, HDF5)
This commit is contained in:
128
backend/exporters/__init__.py
Normal file
128
backend/exporters/__init__.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Exporter registry.
|
||||
|
||||
Each module in this package exports a tuple of tono type names it can handle
|
||||
(`accepted_types`), a FORMATS map of format name → FormatSpec, and a `save()`
|
||||
function. This registry walks those modules and builds lookup tables the
|
||||
Save node uses to dispatch.
|
||||
|
||||
Usage::
|
||||
|
||||
from backend.exporters import get_exporter, resolve_path, type_name_for_value
|
||||
|
||||
type_name = type_name_for_value(value) # e.g. "DATA_FIELD"
|
||||
exporter, spec = get_exporter(type_name, "GWY") # raises on unknown combo
|
||||
path = resolve_path(filename, spec)
|
||||
exporter.save(path, value, "GWY")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import (
|
||||
DataField,
|
||||
DataTable,
|
||||
ImageData,
|
||||
LineData,
|
||||
MeshModel,
|
||||
RecordTable,
|
||||
)
|
||||
from backend.exporters import datafield, image, line, mesh, scalar, table
|
||||
from backend.exporters._base import Exporter, FormatSpec
|
||||
|
||||
_EXPORTER_MODULES: list[ModuleType] = [datafield, image, line, mesh, scalar, table]
|
||||
|
||||
# (type_name, format_name) → (module, FormatSpec)
|
||||
_REGISTRY: dict[tuple[str, str], tuple[ModuleType, FormatSpec]] = {}
|
||||
for _mod in _EXPORTER_MODULES:
|
||||
for _type_name in _mod.accepted_types:
|
||||
for _format_name, _spec in _mod.FORMATS.items():
|
||||
_REGISTRY[(_type_name, _format_name)] = (_mod, _spec)
|
||||
|
||||
|
||||
def get_exporter(type_name: str, format_name: str) -> tuple[ModuleType, FormatSpec]:
|
||||
"""Return the (module, FormatSpec) for a type + format combination.
|
||||
|
||||
Raises ValueError with a user-readable message when the combination is
|
||||
unknown. That message gets propagated straight to the UI status toast,
|
||||
so keep it actionable.
|
||||
"""
|
||||
entry = _REGISTRY.get((type_name, format_name))
|
||||
if entry is None:
|
||||
raise ValueError(f"Format {format_name!r} is not supported for {type_name}.")
|
||||
return entry
|
||||
|
||||
|
||||
def available_formats(type_name: str) -> list[str]:
|
||||
"""Format names available for a given tono type, in registration order."""
|
||||
return [fmt for (t, fmt) in _REGISTRY if t == type_name]
|
||||
|
||||
|
||||
def type_name_for_value(value: Any) -> str:
|
||||
"""Classify a runtime Python value into a tono type name.
|
||||
|
||||
The ordering matters: ImageData is a subclass of ndarray, and RecordTable /
|
||||
DataTable are subclasses of list, so check the more specific classes first.
|
||||
"""
|
||||
if isinstance(value, MeshModel):
|
||||
return "MESH_MODEL"
|
||||
if isinstance(value, DataField):
|
||||
return "DATA_FIELD"
|
||||
if isinstance(value, LineData):
|
||||
return "LINE"
|
||||
if isinstance(value, ImageData):
|
||||
# Annotation outputs carry context in ``.metadata``; regardless, image
|
||||
# formats are the right set.
|
||||
return "IMAGE"
|
||||
if isinstance(value, np.ndarray):
|
||||
if value.ndim == 1:
|
||||
return "LINE"
|
||||
return "IMAGE"
|
||||
if isinstance(value, RecordTable):
|
||||
return "RECORD_TABLE"
|
||||
if isinstance(value, DataTable):
|
||||
return "DATA_TABLE"
|
||||
if isinstance(value, list):
|
||||
# Plain list — treat as a data table; the table exporter handles both.
|
||||
return "DATA_TABLE"
|
||||
if isinstance(value, (int, float, np.floating, np.integer)):
|
||||
return "FLOAT"
|
||||
raise ValueError(f"Save does not support input type: {type(value).__name__}")
|
||||
|
||||
|
||||
def resolve_path(filename: str, spec: FormatSpec, default_dir: Path) -> Path:
|
||||
"""Expand *filename* into an absolute Path with the correct extension.
|
||||
|
||||
Relative names are written under *default_dir* (the session download dir);
|
||||
absolute paths are honored as-is, with parent directories created.
|
||||
"""
|
||||
raw_filename = str(filename).strip() if filename is not None else ""
|
||||
if not raw_filename:
|
||||
raise ValueError("No output filename selected — enter a file name.")
|
||||
|
||||
candidate = Path(raw_filename).expanduser()
|
||||
if candidate.is_absolute():
|
||||
candidate.parent.mkdir(parents=True, exist_ok=True)
|
||||
path = candidate
|
||||
else:
|
||||
default_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = default_dir / candidate.name
|
||||
|
||||
if path.suffix.lower() != spec.ext:
|
||||
path = path.with_suffix(spec.ext)
|
||||
return path
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Exporter",
|
||||
"FormatSpec",
|
||||
"available_formats",
|
||||
"get_exporter",
|
||||
"resolve_path",
|
||||
"type_name_for_value",
|
||||
]
|
||||
60
backend/exporters/_base.py
Normal file
60
backend/exporters/_base.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Base protocol for file exporters.
|
||||
|
||||
Each exporter module handles one tono value type (DATA_FIELD, IMAGE, LINE, …)
|
||||
and implements one or more output formats. Registration is discovered via the
|
||||
module-level attributes declared below, so adding a new exporter is a matter
|
||||
of dropping a new file in this package and importing it from __init__.
|
||||
|
||||
A single file per value type (rather than per format) keeps format choices
|
||||
that share plumbing — PNG & TIFF previews for DATA_FIELD, CSV & JSON for
|
||||
tables — co-located, which is where most of the shared logic lives.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FormatSpec:
|
||||
"""One output format supported by an exporter module."""
|
||||
|
||||
#: File extension (leading dot), e.g. ".tiff".
|
||||
ext: str
|
||||
#: True if the format preserves enough information to reload the value
|
||||
#: via the matching importer. Advertised in the UI so users can tell
|
||||
#: "save a preview" and "save for later" apart.
|
||||
round_trip: bool
|
||||
#: Short human-readable label. The enum key used in the format dropdown
|
||||
#: is the dict key in each module's FORMATS map; `label` is what we'd
|
||||
#: surface in tooltips or docs. Leave empty to fall back to the key.
|
||||
label: str = ""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Exporter(Protocol):
|
||||
"""Structural protocol satisfied by every module in backend.exporters."""
|
||||
|
||||
#: Tono type names this exporter handles. Must match the upper-case names
|
||||
#: used in node INPUT_TYPES / OUTPUTS (e.g. "DATA_FIELD", "IMAGE", "LINE").
|
||||
accepted_types: tuple[str, ...]
|
||||
|
||||
#: Format name → spec. Format names are what users pick in the Save node's
|
||||
#: format dropdown, so they should be short and recognizable.
|
||||
FORMATS: dict[str, FormatSpec]
|
||||
|
||||
def save(self, path: Path, value: Any, format_name: str, **opts: Any) -> None:
|
||||
"""Write *value* to *path* in *format_name*.
|
||||
|
||||
The caller is responsible for ensuring ``path`` has the correct
|
||||
extension (see registry.resolve_path) and that ``value`` is of a type
|
||||
listed in ``accepted_types``.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Re-exported so modules can write `from backend.exporters._base import FormatSpec`.
|
||||
__all__ = ["FormatSpec", "Exporter"]
|
||||
215
backend/exporters/datafield.py
Normal file
215
backend/exporters/datafield.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Exporter for DATA_FIELD values.
|
||||
|
||||
Format choices:
|
||||
|
||||
* **TIFF** — 8-bit RGB colormap preview. *Not* round-trippable. Useful for
|
||||
figures and sharing; opening it back gives you pixels, not physics.
|
||||
* **TIFF (data)** — float64 array with tono metadata JSON-embedded in the
|
||||
TIFF ImageDescription tag. Round-trips via the array_image importer once
|
||||
that importer learns to read the tag (see tests/node_tests/exporters.py).
|
||||
* **PNG** — 8-bit RGB colormap preview. Not round-trippable.
|
||||
* **NPZ** — raw ``data`` array only. Not round-trippable (units are dropped).
|
||||
* **GWY** — Gwyddion native format via the ``gwyfile`` package. Round-trips
|
||||
and opens directly in Gwyddion. Recommended for "save and come back later".
|
||||
* **HDF5** — generic HDF5 with one ``/data`` dataset and physical dimensions
|
||||
as dataset attrs. Round-trips via our generic ``hdf5`` importer.
|
||||
* **HDF5 (Ergo)** — Asylum Research / Ergo layout with the dataset at
|
||||
``Image/DataSet/Resolution 0/Frame 0/<title>/Image`` and a sidecar group
|
||||
``Image/DataSetInfo/Global/Channels/<title>/ImageDims`` carrying
|
||||
``DimScaling`` / ``DimUnits`` / ``DataUnits``. Round-trips via our
|
||||
``ergo_hdf5`` importer and opens in Asylum Ergo / Igor.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField, datafield_to_uint8
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
accepted_types: tuple[str, ...] = ("DATA_FIELD",)
|
||||
|
||||
FORMATS: dict[str, FormatSpec] = {
|
||||
"TIFF": FormatSpec(ext=".tiff", round_trip=False, label="TIFF (preview)"),
|
||||
"TIFF (data)": FormatSpec(ext=".tiff", round_trip=True, label="TIFF (calibrated data)"),
|
||||
"PNG": FormatSpec(ext=".png", round_trip=False, label="PNG (preview)"),
|
||||
"NPZ": FormatSpec(ext=".npz", round_trip=False, label="NumPy (.npz)"),
|
||||
"GWY": FormatSpec(ext=".gwy", round_trip=True, label="Gwyddion (.gwy)"),
|
||||
"HDF5": FormatSpec(ext=".h5", round_trip=True, label="HDF5 (generic)"),
|
||||
"HDF5 (Ergo)": FormatSpec(ext=".h5", round_trip=True, label="HDF5 (Asylum Research / Ergo)"),
|
||||
}
|
||||
|
||||
|
||||
def save(path: Path, value: DataField, format_name: str, **_opts) -> None:
|
||||
if format_name == "TIFF":
|
||||
_save_tiff_preview(path, value)
|
||||
return
|
||||
if format_name == "TIFF (data)":
|
||||
_save_tiff_data(path, value)
|
||||
return
|
||||
if format_name == "PNG":
|
||||
_save_png_preview(path, value)
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
_save_npz(path, value)
|
||||
return
|
||||
if format_name == "GWY":
|
||||
_save_gwy(path, value)
|
||||
return
|
||||
if format_name == "HDF5":
|
||||
_save_hdf5_generic(path, value)
|
||||
return
|
||||
if format_name == "HDF5 (Ergo)":
|
||||
_save_hdf5_ergo(path, value)
|
||||
return
|
||||
raise ValueError(f"Format {format_name!r} is not supported for DATA_FIELD.")
|
||||
|
||||
|
||||
def _save_tiff_preview(path: Path, field: DataField) -> None:
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), datafield_to_uint8(field, field.colormap))
|
||||
|
||||
|
||||
def _save_tiff_data(path: Path, field: DataField) -> None:
|
||||
"""Write the raw float64 data with tono metadata in the ImageDescription tag.
|
||||
|
||||
The description is a JSON document of shape ``{"tono": {...}}`` so future
|
||||
schema extensions can coexist with other tools' TIFF metadata. Only the
|
||||
fields needed to reconstruct physical coordinates and z-scaling are
|
||||
embedded; display state (colormap, display_scale) is intentionally out of
|
||||
scope — this format is for data, not styling.
|
||||
"""
|
||||
import tifffile
|
||||
|
||||
meta = {
|
||||
"tono": {
|
||||
"version": 1,
|
||||
"xreal": float(field.xreal),
|
||||
"yreal": float(field.yreal),
|
||||
"xoff": float(field.xoff),
|
||||
"yoff": float(field.yoff),
|
||||
"si_unit_xy": str(field.si_unit_xy),
|
||||
"si_unit_z": str(field.si_unit_z),
|
||||
"domain": str(field.domain),
|
||||
"colormap": field.colormap if isinstance(field.colormap, str) else "viridis",
|
||||
}
|
||||
}
|
||||
tifffile.imwrite(
|
||||
str(path),
|
||||
np.ascontiguousarray(field.data, dtype=np.float64),
|
||||
description=json.dumps(meta, separators=(",", ":")),
|
||||
)
|
||||
|
||||
|
||||
def _save_png_preview(path: Path, field: DataField) -> None:
|
||||
from PIL import Image
|
||||
Image.fromarray(datafield_to_uint8(field, field.colormap)).save(str(path))
|
||||
|
||||
|
||||
def _save_npz(path: Path, field: DataField) -> None:
|
||||
np.savez(str(path), field=np.asarray(field.data))
|
||||
|
||||
|
||||
def _save_gwy(path: Path, field: DataField) -> None:
|
||||
"""Write a single-channel .gwy file via the gwyfile package."""
|
||||
from gwyfile.objects import GwyContainer, GwyDataField, GwySIUnit
|
||||
|
||||
# gwyfile's GwyDataField ctor expects the data array and physical extents.
|
||||
# si_unit_xy / si_unit_z accept a GwySIUnit wrapper with a .unitstr field.
|
||||
gwy_field = GwyDataField(
|
||||
np.ascontiguousarray(field.data, dtype=np.float64),
|
||||
xreal=float(field.xreal),
|
||||
yreal=float(field.yreal),
|
||||
xoff=float(field.xoff),
|
||||
yoff=float(field.yoff),
|
||||
si_unit_xy=GwySIUnit(unitstr=str(field.si_unit_xy or "")),
|
||||
si_unit_z=GwySIUnit(unitstr=str(field.si_unit_z or "")),
|
||||
)
|
||||
title = path.stem or "field"
|
||||
container = GwyContainer({
|
||||
"/0/data": gwy_field,
|
||||
"/0/data/title": title,
|
||||
})
|
||||
container.tofile(str(path))
|
||||
|
||||
|
||||
def _save_hdf5_generic(path: Path, field: DataField) -> None:
|
||||
"""Write a single dataset ``/data`` with physical dimensions as dataset attrs.
|
||||
|
||||
The layout is the mirror of :mod:`backend.importers.hdf5`: any 2-D numeric
|
||||
dataset is picked up and its attrs (``xreal``, ``yreal``, ``xoff``, ``yoff``,
|
||||
``si_unit_xy``, ``si_unit_z``) reconstruct the DataField.
|
||||
"""
|
||||
import h5py
|
||||
|
||||
arr = np.ascontiguousarray(field.data, dtype=np.float64)
|
||||
with h5py.File(str(path), "w") as f:
|
||||
ds = f.create_dataset("data", data=arr)
|
||||
ds.attrs["xreal"] = float(field.xreal)
|
||||
ds.attrs["yreal"] = float(field.yreal)
|
||||
ds.attrs["xoff"] = float(field.xoff)
|
||||
ds.attrs["yoff"] = float(field.yoff)
|
||||
ds.attrs["si_unit_xy"] = str(field.si_unit_xy or "")
|
||||
ds.attrs["si_unit_z"] = str(field.si_unit_z or "")
|
||||
|
||||
|
||||
def _save_hdf5_ergo(path: Path, field: DataField) -> None:
|
||||
"""Write an Asylum Research / Ergo-compatible HDF5 file.
|
||||
|
||||
The layout mirrors :mod:`backend.importers.ergo_hdf5`:
|
||||
|
||||
* The image dataset lives at
|
||||
``Image/DataSet/Resolution 0/Frame 0/<title>/Image`` — the second-to-last
|
||||
path component is the channel name that the importer keys off.
|
||||
* A sidecar group at
|
||||
``Image/DataSetInfo/Global/Channels/<title>/ImageDims`` carries
|
||||
``DimScaling`` (a (2, 2) array of absolute physical ranges, Y-first),
|
||||
``DimUnits`` (``[Y_unit, X_unit]``), and ``DataUnits`` (Z unit string).
|
||||
|
||||
This makes the file openable by Asylum Ergo / Igor and round-trippable
|
||||
through our ergo_hdf5 importer.
|
||||
"""
|
||||
import h5py
|
||||
|
||||
arr = np.ascontiguousarray(field.data, dtype=np.float64)
|
||||
title = path.stem or "field"
|
||||
|
||||
x_start = float(field.xoff)
|
||||
x_end = float(field.xoff) + float(field.xreal)
|
||||
y_start = float(field.yoff)
|
||||
y_end = float(field.yoff) + float(field.yreal)
|
||||
# DimScaling is stored Y-first to match the importer's expectations
|
||||
# (see ergo_hdf5.py:110-113).
|
||||
dim_scaling = np.array(
|
||||
[[y_start, y_end], [x_start, x_end]],
|
||||
dtype=np.float64,
|
||||
)
|
||||
# DimUnits is [Y_unit, X_unit]; the importer takes the X (second) entry
|
||||
# as the canonical lateral unit (see ergo_hdf5.py:129-135).
|
||||
xy_unit = str(field.si_unit_xy or "m")
|
||||
z_unit = str(field.si_unit_z or "")
|
||||
dim_units = np.array([xy_unit, xy_unit], dtype=h5py.string_dtype())
|
||||
|
||||
with h5py.File(str(path), "w") as f:
|
||||
ds = f.create_dataset(
|
||||
f"Image/DataSet/Resolution 0/Frame 0/{title}/Image",
|
||||
data=arr,
|
||||
)
|
||||
# Also write the generic attrs so non-Ergo readers still see physics.
|
||||
ds.attrs["xreal"] = float(field.xreal)
|
||||
ds.attrs["yreal"] = float(field.yreal)
|
||||
ds.attrs["xoff"] = float(field.xoff)
|
||||
ds.attrs["yoff"] = float(field.yoff)
|
||||
ds.attrs["si_unit_xy"] = xy_unit
|
||||
ds.attrs["si_unit_z"] = z_unit
|
||||
|
||||
dims_grp = f.create_group(
|
||||
f"Image/DataSetInfo/Global/Channels/{title}/ImageDims"
|
||||
)
|
||||
dims_grp.attrs["DimScaling"] = dim_scaling
|
||||
dims_grp.attrs["DimUnits"] = dim_units
|
||||
dims_grp.attrs["DataUnits"] = z_unit
|
||||
41
backend/exporters/image.py
Normal file
41
backend/exporters/image.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Exporter for IMAGE values (numpy arrays, ImageData annotation sources).
|
||||
|
||||
Images are raw pixel arrays — no physical calibration by design — so none of
|
||||
the formats here round-trip dimensions. PNG/TIFF convert to uint8 via the
|
||||
same image_to_uint8 helper the preview pipeline uses; NPZ preserves the raw
|
||||
array.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import image_to_uint8
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
accepted_types: tuple[str, ...] = ("IMAGE", "ANNOTATION_SOURCE")
|
||||
|
||||
FORMATS: dict[str, FormatSpec] = {
|
||||
"PNG": FormatSpec(ext=".png", round_trip=False, label="PNG"),
|
||||
"TIFF": FormatSpec(ext=".tiff", round_trip=False, label="TIFF"),
|
||||
"NPZ": FormatSpec(ext=".npz", round_trip=False, label="NumPy (.npz)"),
|
||||
}
|
||||
|
||||
|
||||
def save(path: Path, value: np.ndarray, format_name: str, **_opts) -> None:
|
||||
arr = np.asarray(value)
|
||||
if format_name == "PNG":
|
||||
from PIL import Image
|
||||
Image.fromarray(image_to_uint8(arr)).save(str(path))
|
||||
return
|
||||
if format_name == "TIFF":
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), image_to_uint8(arr))
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), image=arr)
|
||||
return
|
||||
raise ValueError(f"Format {format_name!r} is not supported for IMAGE.")
|
||||
182
backend/exporters/line.py
Normal file
182
backend/exporters/line.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Exporter for LINE values (1-D profiles as LineData or bare ndarrays).
|
||||
|
||||
PNG / TIFF render a plot image via Pillow; CSV / JSON / NPZ save the raw
|
||||
(x, y, unit) arrays. The plot renderer is self-contained (no matplotlib
|
||||
dependency) and handles SI-prefix axis labels.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import LineData, _PREFIXABLE_UNITS, _SI_PREFIXES
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
accepted_types: tuple[str, ...] = ("LINE",)
|
||||
|
||||
FORMATS: dict[str, FormatSpec] = {
|
||||
"PNG": FormatSpec(ext=".png", round_trip=False, label="PNG plot"),
|
||||
"TIFF": FormatSpec(ext=".tiff", round_trip=False, label="TIFF plot"),
|
||||
"CSV": FormatSpec(ext=".csv", round_trip=True, label="CSV"),
|
||||
"NPZ": FormatSpec(ext=".npz", round_trip=False, label="NumPy (.npz)"),
|
||||
"JSON": FormatSpec(ext=".json", round_trip=True, label="JSON"),
|
||||
}
|
||||
|
||||
|
||||
def save(path: Path, value, format_name: str, *, plot_title: str = "", **_opts) -> None:
|
||||
line = value if isinstance(value, LineData) else LineData(data=np.asarray(value).ravel())
|
||||
|
||||
y = np.asarray(line.data, dtype=np.float64).ravel()
|
||||
if line.x_axis is not None:
|
||||
x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)]
|
||||
else:
|
||||
x = np.arange(len(y), dtype=np.float64)
|
||||
|
||||
if format_name in ("PNG", "TIFF"):
|
||||
_save_line_plot(path, x, y, line.x_unit, line.y_unit, plot_title, format_name)
|
||||
return
|
||||
if format_name == "CSV":
|
||||
with path.open("w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.writer(fh)
|
||||
writer.writerow(["x", "y", "x_unit", "y_unit"])
|
||||
for xv, yv in zip(x, y):
|
||||
writer.writerow([xv, yv, line.x_unit, line.y_unit])
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), x=x, y=y)
|
||||
return
|
||||
if format_name == "JSON":
|
||||
path.write_text(
|
||||
json.dumps({
|
||||
"x": x.tolist(),
|
||||
"y": y.tolist(),
|
||||
"x_unit": line.x_unit,
|
||||
"y_unit": line.y_unit,
|
||||
}, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return
|
||||
raise ValueError(f"Format {format_name!r} is not supported for LINE.")
|
||||
|
||||
|
||||
def _save_line_plot(
|
||||
path: Path,
|
||||
x: np.ndarray,
|
||||
y: np.ndarray,
|
||||
x_unit: str,
|
||||
y_unit: str,
|
||||
title: str,
|
||||
format_name: str,
|
||||
) -> None:
|
||||
"""Render a simple PNG/TIFF line plot with SI-prefixed axes.
|
||||
|
||||
Intentionally self-contained (Pillow only, no matplotlib) so that builds
|
||||
stay lean. Layout is fixed 1200×750 with 5×5 grid and a single blue line.
|
||||
"""
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
w, h = 1200, 750
|
||||
bg = (255, 255, 255)
|
||||
line_color = (79, 142, 247) # #4f8ef7
|
||||
grid_color = (200, 200, 200)
|
||||
text_color = (60, 60, 60)
|
||||
margin = {"left": 80, "right": 30, "top": 50, "bottom": 60}
|
||||
|
||||
img = Image.new("RGB", (w, h), bg)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("DejaVuSans.ttf", 14)
|
||||
font_small = ImageFont.truetype("DejaVuSans.ttf", 11)
|
||||
font_title = ImageFont.truetype("DejaVuSans.ttf", 16)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
font_small = font
|
||||
font_title = font
|
||||
|
||||
pw = w - margin["left"] - margin["right"]
|
||||
ph = h - margin["top"] - margin["bottom"]
|
||||
|
||||
def _si_scale(unit: str, vmin: float, vmax: float) -> tuple[float, str]:
|
||||
"""Pick the best SI prefix for an axis range. Returns (divisor, prefixed_unit)."""
|
||||
unit = (unit or "").strip()
|
||||
if not unit or unit not in _PREFIXABLE_UNITS:
|
||||
return 1.0, unit if unit else ""
|
||||
peak = max(abs(vmin), abs(vmax))
|
||||
if peak == 0:
|
||||
return 1.0, unit
|
||||
for scale, prefix in _SI_PREFIXES:
|
||||
if peak / scale >= 1.0:
|
||||
return scale, f"{prefix}{unit}"
|
||||
return _SI_PREFIXES[-1][0], f"{_SI_PREFIXES[-1][1]}{unit}"
|
||||
|
||||
xmin, xmax = float(np.nanmin(x)), float(np.nanmax(x))
|
||||
ymin, ymax = float(np.nanmin(y)), float(np.nanmax(y))
|
||||
|
||||
x_scale, x_label = _si_scale(x_unit, xmin, xmax)
|
||||
y_scale, y_label = _si_scale(y_unit, ymin, ymax)
|
||||
if not x_label:
|
||||
x_label = "x"
|
||||
if not y_label:
|
||||
y_label = "y"
|
||||
|
||||
x = x / x_scale
|
||||
y = y / y_scale
|
||||
xmin, xmax = xmin / x_scale, xmax / x_scale
|
||||
ymin, ymax = ymin / y_scale, ymax / y_scale
|
||||
|
||||
if ymax == ymin:
|
||||
ymin, ymax = ymin - 1, ymax + 1
|
||||
if xmax == xmin:
|
||||
xmax = xmin + 1
|
||||
ypad = (ymax - ymin) * 0.05
|
||||
ymin -= ypad
|
||||
ymax += ypad
|
||||
|
||||
def to_px(xv: float, yv: float) -> tuple[float, float]:
|
||||
px = margin["left"] + (xv - xmin) / (xmax - xmin) * pw
|
||||
py = margin["top"] + (1.0 - (yv - ymin) / (ymax - ymin)) * ph
|
||||
return px, py
|
||||
|
||||
for i in range(6):
|
||||
gy = ymin + (ymax - ymin) * i / 5
|
||||
_, py = to_px(xmin, gy)
|
||||
draw.line([(margin["left"], py), (margin["left"] + pw, py)], fill=grid_color, width=1)
|
||||
label = f"{gy:.4g}"
|
||||
draw.text((margin["left"] - 8, py - 6), label, fill=text_color, font=font_small, anchor="rm")
|
||||
|
||||
gx = xmin + (xmax - xmin) * i / 5
|
||||
px, _ = to_px(gx, ymin)
|
||||
draw.line([(px, margin["top"]), (px, margin["top"] + ph)], fill=grid_color, width=1)
|
||||
label = f"{gx:.4g}"
|
||||
draw.text((px, margin["top"] + ph + 6), label, fill=text_color, font=font_small, anchor="mt")
|
||||
|
||||
n = len(y)
|
||||
step = max(1, n // pw)
|
||||
xs, ys = x[::step], y[::step]
|
||||
pts = [to_px(float(xs[i]), float(ys[i])) for i in range(len(xs))]
|
||||
if len(pts) > 1:
|
||||
draw.line(pts, fill=line_color, width=2)
|
||||
|
||||
draw.rectangle(
|
||||
[margin["left"], margin["top"], margin["left"] + pw, margin["top"] + ph],
|
||||
outline=(100, 100, 100), width=1,
|
||||
)
|
||||
draw.text((margin["left"] + pw // 2, h - 10), x_label, fill=text_color, font=font, anchor="mb")
|
||||
|
||||
y_label_img = Image.new("RGBA", (200, 20), (0, 0, 0, 0))
|
||||
y_draw = ImageDraw.Draw(y_label_img)
|
||||
y_draw.text((100, 10), y_label, fill=text_color, font=font, anchor="mm")
|
||||
y_label_img = y_label_img.rotate(90, expand=True)
|
||||
img.paste(y_label_img, (2, margin["top"] + ph // 2 - y_label_img.height // 2), y_label_img)
|
||||
|
||||
if title and title.strip():
|
||||
draw.text((w // 2, 10), title.strip(), fill=text_color, font=font_title, anchor="mt")
|
||||
|
||||
ext = ".png" if format_name == "PNG" else ".tiff"
|
||||
img.save(str(path.with_suffix(ext)))
|
||||
60
backend/exporters/mesh.py
Normal file
60
backend/exporters/mesh.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Exporter for MESH_MODEL values (Wavefront OBJ, ASCII STL).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import MeshModel
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
accepted_types: tuple[str, ...] = ("MESH_MODEL",)
|
||||
|
||||
FORMATS: dict[str, FormatSpec] = {
|
||||
"OBJ": FormatSpec(ext=".obj", round_trip=True, label="Wavefront OBJ"),
|
||||
"STL": FormatSpec(ext=".stl", round_trip=True, label="STL (ASCII)"),
|
||||
}
|
||||
|
||||
|
||||
def save(path: Path, value: MeshModel, format_name: str, **_opts) -> None:
|
||||
if format_name == "OBJ":
|
||||
_save_obj(path, value)
|
||||
return
|
||||
if format_name == "STL":
|
||||
_save_stl(path, value)
|
||||
return
|
||||
raise ValueError(f"Format {format_name!r} is not supported for MESH_MODEL.")
|
||||
|
||||
|
||||
def _save_obj(path: Path, mesh: MeshModel) -> None:
|
||||
lines: list[str] = []
|
||||
for vertex in mesh.vertices:
|
||||
lines.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}")
|
||||
for face in mesh.faces:
|
||||
lines.append(f"f {int(face[0]) + 1} {int(face[1]) + 1} {int(face[2]) + 1}")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _save_stl(path: Path, mesh: MeshModel) -> None:
|
||||
def normal(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
|
||||
n = np.cross(b - a, c - a)
|
||||
length = float(np.linalg.norm(n))
|
||||
return n / length if length > 0 else np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||
|
||||
lines = ["solid tono"]
|
||||
vertices = np.asarray(mesh.vertices, dtype=np.float32)
|
||||
for face in np.asarray(mesh.faces, dtype=np.int32):
|
||||
a, b, c = vertices[int(face[0])], vertices[int(face[1])], vertices[int(face[2])]
|
||||
n = normal(a, b, c)
|
||||
lines.append(f" facet normal {n[0]} {n[1]} {n[2]}")
|
||||
lines.append(" outer loop")
|
||||
lines.append(f" vertex {a[0]} {a[1]} {a[2]}")
|
||||
lines.append(f" vertex {b[0]} {b[1]} {b[2]}")
|
||||
lines.append(f" vertex {c[0]} {c[1]} {c[2]}")
|
||||
lines.append(" endloop")
|
||||
lines.append(" endfacet")
|
||||
lines.append("endsolid tono")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
28
backend/exporters/scalar.py
Normal file
28
backend/exporters/scalar.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Exporter for FLOAT scalars (also handles Python int and numpy scalar types).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
accepted_types: tuple[str, ...] = ("FLOAT",)
|
||||
|
||||
FORMATS: dict[str, FormatSpec] = {
|
||||
"TXT": FormatSpec(ext=".txt", round_trip=True, label="Text"),
|
||||
"JSON": FormatSpec(ext=".json", round_trip=True, label="JSON"),
|
||||
}
|
||||
|
||||
|
||||
def save(path: Path, value: float, format_name: str, **_opts) -> None:
|
||||
numeric = float(value)
|
||||
if format_name == "TXT":
|
||||
path.write_text(f"{numeric}\n", encoding="utf-8")
|
||||
return
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps({"value": numeric}, indent=2), encoding="utf-8")
|
||||
return
|
||||
raise ValueError(f"Format {format_name!r} is not supported for scalar values.")
|
||||
44
backend/exporters/table.py
Normal file
44
backend/exporters/table.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Exporter for RECORD_TABLE and DATA_TABLE values.
|
||||
|
||||
Both types are list-of-dict; the Save node currently accepts plain lists in
|
||||
this slot too, which is preserved here. CSV auto-derives its column set from
|
||||
the first row's keys (and any additional keys that appear later), matching
|
||||
the prior behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
accepted_types: tuple[str, ...] = ("RECORD_TABLE", "DATA_TABLE")
|
||||
|
||||
FORMATS: dict[str, FormatSpec] = {
|
||||
"CSV": FormatSpec(ext=".csv", round_trip=True, label="CSV"),
|
||||
"JSON": FormatSpec(ext=".json", round_trip=True, label="JSON"),
|
||||
}
|
||||
|
||||
|
||||
def save(path: Path, value: list, format_name: str, **_opts) -> None:
|
||||
rows = list(value)
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps(rows, indent=2), encoding="utf-8")
|
||||
return
|
||||
if format_name == "CSV":
|
||||
columns: list[str] = []
|
||||
for row in rows:
|
||||
if isinstance(row, dict):
|
||||
for key in row.keys():
|
||||
if key not in columns:
|
||||
columns.append(str(key))
|
||||
with path.open("w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.DictWriter(fh, fieldnames=columns)
|
||||
writer.writeheader()
|
||||
for row in rows:
|
||||
writer.writerow(row if isinstance(row, dict) else {"value": row})
|
||||
return
|
||||
raise ValueError(f"Format {format_name!r} is not supported for table inputs.")
|
||||
@@ -20,19 +20,42 @@ def load(path: Path) -> list[DataField]:
|
||||
|
||||
fields = []
|
||||
for ch in channels.values():
|
||||
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
|
||||
# gwyfile.objects.GwyDataField exposes .data as an already-2D ndarray
|
||||
# (no xres/yres attributes — those were removed in gwyfile 0.3+).
|
||||
data = np.asarray(ch.data, dtype=np.float64)
|
||||
if data.ndim != 2:
|
||||
# Defensive: if a future gwyfile version yields a flat buffer, the
|
||||
# dimensions live in the serialized object's xres/yres keys.
|
||||
xres = int(ch.get("xres", data.size))
|
||||
yres = int(ch.get("yres", 1))
|
||||
data = data.reshape(yres, xres)
|
||||
fields.append(DataField(
|
||||
data=data,
|
||||
xreal=float(ch.xreal),
|
||||
yreal=float(ch.yreal),
|
||||
xoff=float(getattr(ch, "xoff", 0.0)),
|
||||
yoff=float(getattr(ch, "yoff", 0.0)),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
si_unit_xy=_unit_str(getattr(ch, "si_unit_xy", None)) or "m",
|
||||
si_unit_z=_unit_str(getattr(ch, "si_unit_z", None)) or "m",
|
||||
))
|
||||
return fields
|
||||
|
||||
|
||||
def _unit_str(si_unit: object) -> str:
|
||||
"""Extract the unit string from a GwySIUnit without importing gwyfile.
|
||||
|
||||
Loaded GwySIUnit objects behave like dicts with a ``unitstr`` key.
|
||||
"""
|
||||
if si_unit is None:
|
||||
return ""
|
||||
if hasattr(si_unit, "unitstr"):
|
||||
return str(getattr(si_unit, "unitstr") or "")
|
||||
try:
|
||||
return str(si_unit["unitstr"] or "")
|
||||
except (KeyError, TypeError):
|
||||
return ""
|
||||
|
||||
|
||||
def channel_names(path: Path) -> list[str]:
|
||||
import gwyfile
|
||||
try:
|
||||
|
||||
@@ -1,26 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.execution_context import emit_warning, emit_file_download
|
||||
from backend.data_types import (
|
||||
DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8,
|
||||
_SI_PREFIXES, _PREFIXABLE_UNITS,
|
||||
from backend.exporters import (
|
||||
available_formats,
|
||||
get_exporter,
|
||||
resolve_path,
|
||||
type_name_for_value,
|
||||
)
|
||||
|
||||
DOWNLOAD_DIR = Path(tempfile.gettempdir()) / "tono-downloads"
|
||||
|
||||
|
||||
def _choices_by_source_type() -> dict[str, list[str]]:
|
||||
"""Build the format dropdown's source-type map from the exporter registry.
|
||||
|
||||
Centralising this here means adding a new exporter module (or a new format
|
||||
inside an existing one) automatically surfaces in the UI — no parallel
|
||||
list to keep in sync.
|
||||
"""
|
||||
return {
|
||||
"DATA_FIELD": available_formats("DATA_FIELD"),
|
||||
"IMAGE": available_formats("IMAGE"),
|
||||
"ANNOTATION_SOURCE": available_formats("ANNOTATION_SOURCE"),
|
||||
"LINE": available_formats("LINE"),
|
||||
"RECORD_TABLE": available_formats("RECORD_TABLE"),
|
||||
"DATA_TABLE": available_formats("DATA_TABLE"),
|
||||
"FLOAT": available_formats("FLOAT"),
|
||||
"MESH_MODEL": available_formats("MESH_MODEL"),
|
||||
}
|
||||
|
||||
|
||||
@register_node(display_name="Save")
|
||||
class Save:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
choices = _choices_by_source_type()
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("STRING", {
|
||||
@@ -41,17 +59,8 @@ class Save:
|
||||
],
|
||||
}),
|
||||
"format": ("STRING", {
|
||||
"default": "TIFF",
|
||||
"choices_by_source_type": {
|
||||
"DATA_FIELD": ["TIFF", "PNG", "NPZ"],
|
||||
"IMAGE": ["PNG", "TIFF", "NPZ"],
|
||||
"ANNOTATION_SOURCE": ["PNG", "TIFF", "NPZ"],
|
||||
"LINE": ["PNG", "TIFF", "CSV", "NPZ", "JSON"],
|
||||
"RECORD_TABLE": ["CSV", "JSON"],
|
||||
"DATA_TABLE": ["CSV", "JSON"],
|
||||
"FLOAT": ["TXT", "JSON"],
|
||||
"MESH_MODEL": ["OBJ", "STL"],
|
||||
},
|
||||
"default": choices["DATA_FIELD"][0] if choices["DATA_FIELD"] else "",
|
||||
"choices_by_source_type": choices,
|
||||
"source_type_input": "value",
|
||||
}),
|
||||
},
|
||||
@@ -71,10 +80,12 @@ class Save:
|
||||
OUTPUT_NODE = True
|
||||
MANUAL_TRIGGER = True
|
||||
DESCRIPTION = (
|
||||
"Save a single graph value to disk. Supports fields, images, lines, tables, scalars, and 3D meshes."
|
||||
"Save a single graph value to disk. Supports fields, images, lines, tables, scalars, "
|
||||
"and 3D meshes. Use 'GWY' or 'TIFF (data)' for DataFields you want to re-open later "
|
||||
"with their physical units preserved."
|
||||
)
|
||||
|
||||
KEYWORDS = ("export", "write", "download", "png", "tiff", "csv", "json", "npz", "obj", "stl")
|
||||
KEYWORDS = ("export", "write", "download", "png", "tiff", "csv", "json", "npz", "obj", "stl", "gwy")
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -83,295 +94,11 @@ class Save:
|
||||
value,
|
||||
plot_title: str = "",
|
||||
):
|
||||
path = self._resolve_save_path(filename, format)
|
||||
|
||||
if isinstance(value, MeshModel):
|
||||
self._save_mesh(path, value, format)
|
||||
elif isinstance(value, DataField):
|
||||
self._save_datafield(path, value, format)
|
||||
elif isinstance(value, np.ndarray):
|
||||
if value.ndim == 1:
|
||||
self._save_line(path, LineData(data=value), format, title=plot_title)
|
||||
else:
|
||||
self._save_image_or_array(path, value, format)
|
||||
elif isinstance(value, LineData):
|
||||
self._save_line(path, value, format, title=plot_title)
|
||||
elif isinstance(value, list):
|
||||
self._save_table(path, value, format)
|
||||
elif isinstance(value, (int, float, np.floating, np.integer)):
|
||||
self._save_scalar(path, float(value), format)
|
||||
else:
|
||||
raise ValueError(f"Save does not support input type: {type(value).__name__}")
|
||||
type_name = type_name_for_value(value)
|
||||
module, spec = get_exporter(type_name, format)
|
||||
path = resolve_path(filename, spec, DOWNLOAD_DIR)
|
||||
module.save(path, value, format, plot_title=plot_title)
|
||||
|
||||
emit_warning(f"Saved to {path.name}")
|
||||
emit_file_download(str(path))
|
||||
return ()
|
||||
|
||||
def _resolve_save_path(self, filename: str, format_name: str) -> Path:
|
||||
ext_map = {
|
||||
"PNG": ".png",
|
||||
"TIFF": ".tiff",
|
||||
"NPZ": ".npz",
|
||||
"CSV": ".csv",
|
||||
"JSON": ".json",
|
||||
"OBJ": ".obj",
|
||||
"STL": ".stl",
|
||||
"TXT": ".txt",
|
||||
}
|
||||
ext = ext_map[format_name]
|
||||
|
||||
raw_filename = str(filename).strip() if filename is not None else ""
|
||||
if not raw_filename:
|
||||
raise ValueError("No output filename selected — enter a file name.")
|
||||
|
||||
candidate = Path(raw_filename).expanduser()
|
||||
if candidate.is_absolute():
|
||||
candidate.parent.mkdir(parents=True, exist_ok=True)
|
||||
path = candidate
|
||||
else:
|
||||
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = DOWNLOAD_DIR / candidate.name
|
||||
|
||||
if path.suffix.lower() != ext:
|
||||
path = path.with_suffix(ext)
|
||||
return path
|
||||
|
||||
def _save_datafield(self, path: Path, field: DataField, format_name: str):
|
||||
if format_name == "TIFF":
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), datafield_to_uint8(field, field.colormap))
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), field=np.asarray(field.data))
|
||||
return
|
||||
if format_name == "PNG":
|
||||
from PIL import Image
|
||||
Image.fromarray(datafield_to_uint8(field, field.colormap)).save(str(path))
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for DATA_FIELD.")
|
||||
|
||||
def _save_image_or_array(self, path: Path, image: np.ndarray, format_name: str):
|
||||
arr = np.asarray(image)
|
||||
if format_name == "PNG":
|
||||
from PIL import Image
|
||||
Image.fromarray(image_to_uint8(arr)).save(str(path))
|
||||
return
|
||||
if format_name == "TIFF":
|
||||
import tifffile
|
||||
tifffile.imwrite(str(path), image_to_uint8(arr))
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), image=arr)
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for IMAGE.")
|
||||
|
||||
def _save_line(self, path: Path, line: LineData, format_name: str, title: str = ""):
|
||||
y = np.asarray(line.data, dtype=np.float64).ravel()
|
||||
x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)] if line.x_axis is not None else np.arange(len(y), dtype=np.float64)
|
||||
if format_name in ("PNG", "TIFF"):
|
||||
self._save_line_plot(path, x, y, line.x_unit, line.y_unit, title, format_name)
|
||||
return
|
||||
if format_name == "CSV":
|
||||
with path.open("w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.writer(fh)
|
||||
writer.writerow(["x", "y", "x_unit", "y_unit"])
|
||||
for xv, yv in zip(x, y):
|
||||
writer.writerow([xv, yv, line.x_unit, line.y_unit])
|
||||
return
|
||||
if format_name == "NPZ":
|
||||
np.savez(str(path), x=x, y=y)
|
||||
return
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps({
|
||||
"x": x.tolist(),
|
||||
"y": y.tolist(),
|
||||
"x_unit": line.x_unit,
|
||||
"y_unit": line.y_unit,
|
||||
}, indent=2), encoding="utf-8")
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for LINE.")
|
||||
|
||||
def _save_line_plot(
|
||||
self,
|
||||
path: Path,
|
||||
x: np.ndarray,
|
||||
y: np.ndarray,
|
||||
x_unit: str,
|
||||
y_unit: str,
|
||||
title: str,
|
||||
format_name: str,
|
||||
):
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
w, h = 1200, 750
|
||||
bg = (255, 255, 255)
|
||||
line_color = (79, 142, 247) # #4f8ef7
|
||||
grid_color = (200, 200, 200)
|
||||
text_color = (60, 60, 60)
|
||||
margin = {"left": 80, "right": 30, "top": 50, "bottom": 60}
|
||||
|
||||
img = Image.new("RGB", (w, h), bg)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("DejaVuSans.ttf", 14)
|
||||
font_small = ImageFont.truetype("DejaVuSans.ttf", 11)
|
||||
font_title = ImageFont.truetype("DejaVuSans.ttf", 16)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
font_small = font
|
||||
font_title = font
|
||||
|
||||
pw = w - margin["left"] - margin["right"]
|
||||
ph = h - margin["top"] - margin["bottom"]
|
||||
|
||||
def _si_scale(unit: str, vmin: float, vmax: float) -> tuple[float, str]:
|
||||
"""Pick the best SI prefix for an axis range. Returns (divisor, prefixed_unit)."""
|
||||
unit = (unit or "").strip()
|
||||
if not unit or unit not in _PREFIXABLE_UNITS:
|
||||
return 1.0, unit if unit else ""
|
||||
peak = max(abs(vmin), abs(vmax))
|
||||
if peak == 0:
|
||||
return 1.0, unit
|
||||
for scale, prefix in _SI_PREFIXES:
|
||||
if peak / scale >= 1.0:
|
||||
return scale, f"{prefix}{unit}"
|
||||
return _SI_PREFIXES[-1][0], f"{_SI_PREFIXES[-1][1]}{unit}"
|
||||
|
||||
xmin, xmax = float(np.nanmin(x)), float(np.nanmax(x))
|
||||
ymin, ymax = float(np.nanmin(y)), float(np.nanmax(y))
|
||||
|
||||
x_scale, x_label = _si_scale(x_unit, xmin, xmax)
|
||||
y_scale, y_label = _si_scale(y_unit, ymin, ymax)
|
||||
if not x_label:
|
||||
x_label = "x"
|
||||
if not y_label:
|
||||
y_label = "y"
|
||||
|
||||
# Scale data into prefixed units
|
||||
x = x / x_scale
|
||||
y = y / y_scale
|
||||
xmin, xmax = xmin / x_scale, xmax / x_scale
|
||||
ymin, ymax = ymin / y_scale, ymax / y_scale
|
||||
|
||||
if ymax == ymin:
|
||||
ymin, ymax = ymin - 1, ymax + 1
|
||||
if xmax == xmin:
|
||||
xmax = xmin + 1
|
||||
# Add 5% padding to y range
|
||||
ypad = (ymax - ymin) * 0.05
|
||||
ymin -= ypad
|
||||
ymax += ypad
|
||||
|
||||
def to_px(xv: float, yv: float) -> tuple[float, float]:
|
||||
px = margin["left"] + (xv - xmin) / (xmax - xmin) * pw
|
||||
py = margin["top"] + (1.0 - (yv - ymin) / (ymax - ymin)) * ph
|
||||
return px, py
|
||||
|
||||
# Grid lines (5 horizontal, 5 vertical)
|
||||
for i in range(6):
|
||||
gy = ymin + (ymax - ymin) * i / 5
|
||||
_, py = to_px(xmin, gy)
|
||||
draw.line([(margin["left"], py), (margin["left"] + pw, py)], fill=grid_color, width=1)
|
||||
label = f"{gy:.4g}"
|
||||
draw.text((margin["left"] - 8, py - 6), label, fill=text_color, font=font_small, anchor="rm")
|
||||
|
||||
gx = xmin + (xmax - xmin) * i / 5
|
||||
px, _ = to_px(gx, ymin)
|
||||
draw.line([(px, margin["top"]), (px, margin["top"] + ph)], fill=grid_color, width=1)
|
||||
label = f"{gx:.4g}"
|
||||
draw.text((px, margin["top"] + ph + 6), label, fill=text_color, font=font_small, anchor="mt")
|
||||
|
||||
# Plot line
|
||||
n = len(y)
|
||||
step = max(1, n // pw)
|
||||
xs, ys = x[::step], y[::step]
|
||||
pts = [to_px(float(xs[i]), float(ys[i])) for i in range(len(xs))]
|
||||
if len(pts) > 1:
|
||||
draw.line(pts, fill=line_color, width=2)
|
||||
|
||||
# Border
|
||||
draw.rectangle(
|
||||
[margin["left"], margin["top"], margin["left"] + pw, margin["top"] + ph],
|
||||
outline=(100, 100, 100), width=1,
|
||||
)
|
||||
draw.text((margin["left"] + pw // 2, h - 10), x_label, fill=text_color, font=font, anchor="mb")
|
||||
# Vertical y label — draw rotated
|
||||
y_label_img = Image.new("RGBA", (200, 20), (0, 0, 0, 0))
|
||||
y_draw = ImageDraw.Draw(y_label_img)
|
||||
y_draw.text((100, 10), y_label, fill=text_color, font=font, anchor="mm")
|
||||
y_label_img = y_label_img.rotate(90, expand=True)
|
||||
img.paste(y_label_img, (2, margin["top"] + ph // 2 - y_label_img.height // 2), y_label_img)
|
||||
|
||||
# Title
|
||||
if title and title.strip():
|
||||
draw.text((w // 2, 10), title.strip(), fill=text_color, font=font_title, anchor="mt")
|
||||
|
||||
ext = ".png" if format_name == "PNG" else ".tiff"
|
||||
img.save(str(path.with_suffix(ext)))
|
||||
|
||||
def _save_table(self, path: Path, rows: list, format_name: str):
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps(rows, indent=2), encoding="utf-8")
|
||||
return
|
||||
if format_name == "CSV":
|
||||
columns: list[str] = []
|
||||
for row in rows:
|
||||
if isinstance(row, dict):
|
||||
for key in row.keys():
|
||||
if key not in columns:
|
||||
columns.append(str(key))
|
||||
with path.open("w", newline="", encoding="utf-8") as fh:
|
||||
writer = csv.DictWriter(fh, fieldnames=columns)
|
||||
writer.writeheader()
|
||||
for row in rows:
|
||||
writer.writerow(row if isinstance(row, dict) else {"value": row})
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for table inputs.")
|
||||
|
||||
def _save_scalar(self, path: Path, value: float, format_name: str):
|
||||
if format_name == "TXT":
|
||||
path.write_text(f"{value}\n", encoding="utf-8")
|
||||
return
|
||||
if format_name == "JSON":
|
||||
path.write_text(json.dumps({"value": value}, indent=2), encoding="utf-8")
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for scalar values.")
|
||||
|
||||
def _save_mesh(self, path: Path, mesh: MeshModel, format_name: str):
|
||||
if format_name == "OBJ":
|
||||
self._save_obj(path, mesh)
|
||||
return
|
||||
if format_name == "STL":
|
||||
self._save_stl(path, mesh)
|
||||
return
|
||||
raise ValueError(f"Format {format_name} is not supported for MESH_MODEL.")
|
||||
|
||||
def _save_obj(self, path: Path, mesh: MeshModel):
|
||||
lines = []
|
||||
for vertex in mesh.vertices:
|
||||
lines.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}")
|
||||
for face in mesh.faces:
|
||||
lines.append(f"f {int(face[0]) + 1} {int(face[1]) + 1} {int(face[2]) + 1}")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
def _save_stl(self, path: Path, mesh: MeshModel):
|
||||
def normal(a, b, c):
|
||||
n = np.cross(b - a, c - a)
|
||||
length = float(np.linalg.norm(n))
|
||||
return n / length if length > 0 else np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||
|
||||
lines = ["solid tono"]
|
||||
vertices = np.asarray(mesh.vertices, dtype=np.float32)
|
||||
for face in np.asarray(mesh.faces, dtype=np.int32):
|
||||
a, b, c = vertices[int(face[0])], vertices[int(face[1])], vertices[int(face[2])]
|
||||
n = normal(a, b, c)
|
||||
lines.append(f" facet normal {n[0]} {n[1]} {n[2]}")
|
||||
lines.append(" outer loop")
|
||||
lines.append(f" vertex {a[0]} {a[1]} {a[2]}")
|
||||
lines.append(f" vertex {b[0]} {b[1]} {b[2]}")
|
||||
lines.append(f" vertex {c[0]} {c[1]} {c[2]}")
|
||||
lines.append(" endloop")
|
||||
lines.append(" endfacet")
|
||||
lines.append("endsolid tono")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
300
tests/node_tests/exporters.py
Normal file
300
tests/node_tests/exporters.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Tests for the exporter registry and the round-trippable DataField formats.
|
||||
|
||||
The Save node's format-specific behavior is covered in test_save_generic
|
||||
(tests/node_tests/save.py). This module focuses on:
|
||||
|
||||
1. Registry contract — every exporter module satisfies the protocol.
|
||||
2. Dispatch — type_name_for_value classifies values correctly and
|
||||
get_exporter returns a matching module.
|
||||
3. Round-trip — GWY and TIFF (data) preserve xreal/yreal/units/data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import (
|
||||
DataField,
|
||||
DataTable,
|
||||
ImageData,
|
||||
LineData,
|
||||
MeshModel,
|
||||
RecordTable,
|
||||
)
|
||||
|
||||
|
||||
def test_exporter_registry_contract():
|
||||
"""Every registered exporter module must expose the required attributes."""
|
||||
from backend.exporters import _REGISTRY
|
||||
from backend.exporters._base import FormatSpec
|
||||
|
||||
assert _REGISTRY, "Registry must not be empty"
|
||||
seen_modules = {mod for (mod, _) in _REGISTRY.values()}
|
||||
for module in seen_modules:
|
||||
assert hasattr(module, "accepted_types")
|
||||
assert hasattr(module, "FORMATS")
|
||||
assert hasattr(module, "save")
|
||||
assert isinstance(module.accepted_types, tuple)
|
||||
assert all(isinstance(t, str) and t.isupper() for t in module.accepted_types)
|
||||
assert isinstance(module.FORMATS, dict)
|
||||
for name, spec in module.FORMATS.items():
|
||||
assert isinstance(name, str) and name
|
||||
assert isinstance(spec, FormatSpec)
|
||||
assert spec.ext.startswith(".")
|
||||
|
||||
|
||||
def test_type_name_for_value_classification():
|
||||
from backend.exporters import type_name_for_value
|
||||
|
||||
assert type_name_for_value(DataField(data=np.zeros((4, 4)))) == "DATA_FIELD"
|
||||
assert type_name_for_value(np.zeros((4, 4))) == "IMAGE"
|
||||
assert type_name_for_value(np.zeros((4, 4, 3), dtype=np.uint8)) == "IMAGE"
|
||||
assert type_name_for_value(ImageData(np.zeros((4, 4), dtype=np.uint8))) == "IMAGE"
|
||||
assert type_name_for_value(np.zeros(8)) == "LINE"
|
||||
assert type_name_for_value(LineData(data=np.zeros(8))) == "LINE"
|
||||
assert type_name_for_value(RecordTable([{"a": 1}])) == "RECORD_TABLE"
|
||||
assert type_name_for_value(DataTable([{"a": 1}])) == "DATA_TABLE"
|
||||
assert type_name_for_value(1.25) == "FLOAT"
|
||||
assert type_name_for_value(np.float64(0.5)) == "FLOAT"
|
||||
mesh = MeshModel(
|
||||
vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32),
|
||||
faces=np.array([[0, 1, 2]], dtype=np.int32),
|
||||
)
|
||||
assert type_name_for_value(mesh) == "MESH_MODEL"
|
||||
|
||||
try:
|
||||
type_name_for_value(object())
|
||||
assert False, "Expected ValueError for unsupported type"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_get_exporter_known_and_unknown():
|
||||
from backend.exporters import get_exporter
|
||||
|
||||
mod, spec = get_exporter("DATA_FIELD", "GWY")
|
||||
assert spec.ext == ".gwy"
|
||||
assert spec.round_trip is True
|
||||
|
||||
mod, spec = get_exporter("DATA_FIELD", "TIFF")
|
||||
assert spec.ext == ".tiff"
|
||||
# Legacy preview path — not round-trippable.
|
||||
assert spec.round_trip is False
|
||||
|
||||
mod, spec = get_exporter("DATA_FIELD", "TIFF (data)")
|
||||
assert spec.round_trip is True
|
||||
|
||||
try:
|
||||
get_exporter("DATA_FIELD", "DOES_NOT_EXIST")
|
||||
assert False, "Expected ValueError for unknown format"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
get_exporter("FLOAT", "GWY")
|
||||
assert False, "Expected ValueError for type/format mismatch"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_available_formats_includes_new_datafield_formats():
|
||||
from backend.exporters import available_formats
|
||||
|
||||
formats = available_formats("DATA_FIELD")
|
||||
assert "TIFF" in formats
|
||||
assert "TIFF (data)" in formats
|
||||
assert "GWY" in formats
|
||||
assert "PNG" in formats
|
||||
assert "NPZ" in formats
|
||||
assert "HDF5" in formats
|
||||
assert "HDF5 (Ergo)" in formats
|
||||
|
||||
|
||||
def test_datafield_gwy_round_trip():
|
||||
"""Writing a DataField to .gwy and reloading via the importer preserves everything."""
|
||||
from backend.importers import gwy as gwy_importer
|
||||
from backend.nodes.save import Save
|
||||
|
||||
rng = np.random.default_rng(7)
|
||||
data = rng.standard_normal((32, 48)).astype(np.float64) * 1e-9
|
||||
field = DataField(
|
||||
data=data,
|
||||
xreal=3.2e-6,
|
||||
yreal=2.4e-6,
|
||||
xoff=1.1e-7,
|
||||
yoff=-5.5e-7,
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = Path(tmpdir) / "topo"
|
||||
Save().save(filename=str(path), format="GWY", value=field)
|
||||
out_path = path.with_suffix(".gwy")
|
||||
assert out_path.exists()
|
||||
|
||||
reloaded = gwy_importer.load(out_path)
|
||||
assert len(reloaded) == 1
|
||||
rf = reloaded[0]
|
||||
assert rf.data.shape == field.data.shape
|
||||
assert np.allclose(rf.data, field.data)
|
||||
assert np.isclose(rf.xreal, field.xreal)
|
||||
assert np.isclose(rf.yreal, field.yreal)
|
||||
assert np.isclose(rf.xoff, field.xoff)
|
||||
assert np.isclose(rf.yoff, field.yoff)
|
||||
assert rf.si_unit_xy == "m"
|
||||
assert rf.si_unit_z == "m"
|
||||
|
||||
# channel_names() should return the stem we used as the title
|
||||
names = gwy_importer.channel_names(out_path)
|
||||
assert names == ["topo"]
|
||||
|
||||
|
||||
def test_datafield_tiff_data_round_trip():
|
||||
"""TIFF (data) writes float64 pixels + JSON metadata; we verify both."""
|
||||
import tifffile
|
||||
|
||||
from backend.nodes.save import Save
|
||||
|
||||
rng = np.random.default_rng(11)
|
||||
data = rng.standard_normal((24, 36)).astype(np.float64) * 1e-8
|
||||
field = DataField(
|
||||
data=data,
|
||||
xreal=5e-6,
|
||||
yreal=3e-6,
|
||||
xoff=0.0,
|
||||
yoff=0.0,
|
||||
si_unit_xy="m",
|
||||
si_unit_z="V",
|
||||
colormap="viridis",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = Path(tmpdir) / "field"
|
||||
Save().save(filename=str(path), format="TIFF (data)", value=field)
|
||||
out_path = path.with_suffix(".tiff")
|
||||
assert out_path.exists()
|
||||
|
||||
with tifffile.TiffFile(out_path) as tif:
|
||||
arr = tif.asarray()
|
||||
desc = tif.pages[0].tags["ImageDescription"].value
|
||||
|
||||
assert arr.dtype == np.float64
|
||||
assert arr.shape == field.data.shape
|
||||
assert np.allclose(arr, field.data)
|
||||
|
||||
meta = json.loads(desc)["tono"]
|
||||
assert meta["xreal"] == field.xreal
|
||||
assert meta["yreal"] == field.yreal
|
||||
assert meta["si_unit_xy"] == "m"
|
||||
assert meta["si_unit_z"] == "V"
|
||||
assert meta["domain"] == "spatial"
|
||||
|
||||
|
||||
def test_datafield_hdf5_generic_round_trip():
|
||||
"""HDF5 (generic) writes /data + attrs that our hdf5 importer reads back."""
|
||||
from backend.importers import hdf5 as hdf5_importer
|
||||
from backend.nodes.save import Save
|
||||
|
||||
rng = np.random.default_rng(23)
|
||||
data = rng.standard_normal((20, 28)).astype(np.float64) * 1e-7
|
||||
field = DataField(
|
||||
data=data,
|
||||
xreal=4.8e-6,
|
||||
yreal=3.2e-6,
|
||||
xoff=1.5e-7,
|
||||
yoff=-2.5e-7,
|
||||
si_unit_xy="m",
|
||||
si_unit_z="V",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = Path(tmpdir) / "topo"
|
||||
Save().save(filename=str(path), format="HDF5", value=field)
|
||||
out_path = path.with_suffix(".h5")
|
||||
assert out_path.exists()
|
||||
|
||||
reloaded = hdf5_importer.load(out_path)
|
||||
assert len(reloaded) == 1
|
||||
rf = reloaded[0]
|
||||
assert rf.data.shape == field.data.shape
|
||||
assert np.allclose(rf.data, field.data)
|
||||
assert np.isclose(rf.xreal, field.xreal)
|
||||
assert np.isclose(rf.yreal, field.yreal)
|
||||
assert np.isclose(rf.xoff, field.xoff)
|
||||
assert np.isclose(rf.yoff, field.yoff)
|
||||
assert rf.si_unit_xy == "m"
|
||||
assert rf.si_unit_z == "V"
|
||||
|
||||
|
||||
def test_datafield_hdf5_ergo_round_trip():
|
||||
"""HDF5 (Ergo) writes the Asylum sidecar layout and round-trips via ergo_hdf5."""
|
||||
import h5py
|
||||
|
||||
from backend.importers import ergo_hdf5 as ergo_importer
|
||||
from backend.nodes.save import Save
|
||||
|
||||
rng = np.random.default_rng(29)
|
||||
data = rng.standard_normal((16, 24)).astype(np.float64) * 1e-9
|
||||
field = DataField(
|
||||
data=data,
|
||||
xreal=2.5e-6,
|
||||
yreal=1.8e-6,
|
||||
xoff=0.5e-7,
|
||||
yoff=-1.1e-7,
|
||||
si_unit_xy="m",
|
||||
si_unit_z="N",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = Path(tmpdir) / "topo"
|
||||
Save().save(filename=str(path), format="HDF5 (Ergo)", value=field)
|
||||
out_path = path.with_suffix(".h5")
|
||||
assert out_path.exists()
|
||||
|
||||
# Sanity-check the layout: the dataset lives under
|
||||
# Image/DataSet/Resolution 0/Frame 0/<title>/Image, and the sidecar
|
||||
# group under Image/DataSetInfo/Global/Channels/<title>/ImageDims.
|
||||
with h5py.File(str(out_path), "r") as f:
|
||||
assert "Image/DataSet/Resolution 0/Frame 0/topo/Image" in f
|
||||
dims = f["Image/DataSetInfo/Global/Channels/topo/ImageDims"]
|
||||
scaling = np.asarray(dims.attrs["DimScaling"])
|
||||
assert scaling.shape == (2, 2)
|
||||
# DimScaling is Y-first: [[y_start, y_end], [x_start, x_end]]
|
||||
assert np.isclose(scaling[1, 1] - scaling[1, 0], field.xreal)
|
||||
assert np.isclose(scaling[0, 1] - scaling[0, 0], field.yreal)
|
||||
|
||||
reloaded = ergo_importer.load(out_path)
|
||||
assert len(reloaded) == 1
|
||||
rf = reloaded[0]
|
||||
assert rf.data.shape == field.data.shape
|
||||
assert np.allclose(rf.data, field.data)
|
||||
assert np.isclose(rf.xreal, field.xreal)
|
||||
assert np.isclose(rf.yreal, field.yreal)
|
||||
assert np.isclose(rf.xoff, field.xoff)
|
||||
assert np.isclose(rf.yoff, field.yoff)
|
||||
assert rf.si_unit_xy == "m"
|
||||
assert rf.si_unit_z == "N"
|
||||
|
||||
|
||||
def test_tiff_preview_is_still_rgb_uint8():
|
||||
"""The legacy TIFF format for DATA_FIELD must keep producing 8-bit RGB."""
|
||||
import tifffile
|
||||
|
||||
from backend.nodes.save import Save
|
||||
|
||||
field = DataField(
|
||||
data=np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float64),
|
||||
xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m",
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = Path(tmpdir) / "preview"
|
||||
Save().save(filename=str(path), format="TIFF", value=field)
|
||||
arr = tifffile.imread(str(path.with_suffix(".tiff")))
|
||||
assert arr.dtype == np.uint8
|
||||
assert arr.shape == (2, 2, 3)
|
||||
Reference in New Issue
Block a user