From 08aff81f02002ca455ddd013faac8af010f7051e Mon Sep 17 00:00:00 2001 From: matei jordache Date: Sun, 5 Apr 2026 13:28:26 -0700 Subject: [PATCH] dimensioned export (gwy, HDF5) --- backend/exporters/__init__.py | 128 ++++++++++++ backend/exporters/_base.py | 60 ++++++ backend/exporters/datafield.py | 215 ++++++++++++++++++++ backend/exporters/image.py | 41 ++++ backend/exporters/line.py | 182 +++++++++++++++++ backend/exporters/mesh.py | 60 ++++++ backend/exporters/scalar.py | 28 +++ backend/exporters/table.py | 44 +++++ backend/importers/gwy.py | 29 ++- backend/nodes/save.py | 347 ++++----------------------------- tests/node_tests/exporters.py | 300 ++++++++++++++++++++++++++++ 11 files changed, 1121 insertions(+), 313 deletions(-) create mode 100644 backend/exporters/__init__.py create mode 100644 backend/exporters/_base.py create mode 100644 backend/exporters/datafield.py create mode 100644 backend/exporters/image.py create mode 100644 backend/exporters/line.py create mode 100644 backend/exporters/mesh.py create mode 100644 backend/exporters/scalar.py create mode 100644 backend/exporters/table.py create mode 100644 tests/node_tests/exporters.py diff --git a/backend/exporters/__init__.py b/backend/exporters/__init__.py new file mode 100644 index 0000000..21d634b --- /dev/null +++ b/backend/exporters/__init__.py @@ -0,0 +1,128 @@ +""" +Exporter registry. + +Each module in this package exports a tuple of tono type names it can handle +(`accepted_types`), a FORMATS map of format name → FormatSpec, and a `save()` +function. This registry walks those modules and builds lookup tables the +Save node uses to dispatch. + +Usage:: + + from backend.exporters import get_exporter, resolve_path, type_name_for_value + + type_name = type_name_for_value(value) # e.g. "DATA_FIELD" + exporter, spec = get_exporter(type_name, "GWY") # raises on unknown combo + path = resolve_path(filename, spec) + exporter.save(path, value, "GWY") +""" + +from __future__ import annotations + +from pathlib import Path +from types import ModuleType +from typing import Any + +import numpy as np + +from backend.data_types import ( + DataField, + DataTable, + ImageData, + LineData, + MeshModel, + RecordTable, +) +from backend.exporters import datafield, image, line, mesh, scalar, table +from backend.exporters._base import Exporter, FormatSpec + +_EXPORTER_MODULES: list[ModuleType] = [datafield, image, line, mesh, scalar, table] + +# (type_name, format_name) → (module, FormatSpec) +_REGISTRY: dict[tuple[str, str], tuple[ModuleType, FormatSpec]] = {} +for _mod in _EXPORTER_MODULES: + for _type_name in _mod.accepted_types: + for _format_name, _spec in _mod.FORMATS.items(): + _REGISTRY[(_type_name, _format_name)] = (_mod, _spec) + + +def get_exporter(type_name: str, format_name: str) -> tuple[ModuleType, FormatSpec]: + """Return the (module, FormatSpec) for a type + format combination. + + Raises ValueError with a user-readable message when the combination is + unknown. That message gets propagated straight to the UI status toast, + so keep it actionable. + """ + entry = _REGISTRY.get((type_name, format_name)) + if entry is None: + raise ValueError(f"Format {format_name!r} is not supported for {type_name}.") + return entry + + +def available_formats(type_name: str) -> list[str]: + """Format names available for a given tono type, in registration order.""" + return [fmt for (t, fmt) in _REGISTRY if t == type_name] + + +def type_name_for_value(value: Any) -> str: + """Classify a runtime Python value into a tono type name. + + The ordering matters: ImageData is a subclass of ndarray, and RecordTable / + DataTable are subclasses of list, so check the more specific classes first. + """ + if isinstance(value, MeshModel): + return "MESH_MODEL" + if isinstance(value, DataField): + return "DATA_FIELD" + if isinstance(value, LineData): + return "LINE" + if isinstance(value, ImageData): + # Annotation outputs carry context in ``.metadata``; regardless, image + # formats are the right set. + return "IMAGE" + if isinstance(value, np.ndarray): + if value.ndim == 1: + return "LINE" + return "IMAGE" + if isinstance(value, RecordTable): + return "RECORD_TABLE" + if isinstance(value, DataTable): + return "DATA_TABLE" + if isinstance(value, list): + # Plain list — treat as a data table; the table exporter handles both. + return "DATA_TABLE" + if isinstance(value, (int, float, np.floating, np.integer)): + return "FLOAT" + raise ValueError(f"Save does not support input type: {type(value).__name__}") + + +def resolve_path(filename: str, spec: FormatSpec, default_dir: Path) -> Path: + """Expand *filename* into an absolute Path with the correct extension. + + Relative names are written under *default_dir* (the session download dir); + absolute paths are honored as-is, with parent directories created. + """ + raw_filename = str(filename).strip() if filename is not None else "" + if not raw_filename: + raise ValueError("No output filename selected — enter a file name.") + + candidate = Path(raw_filename).expanduser() + if candidate.is_absolute(): + candidate.parent.mkdir(parents=True, exist_ok=True) + path = candidate + else: + default_dir.mkdir(parents=True, exist_ok=True) + path = default_dir / candidate.name + + if path.suffix.lower() != spec.ext: + path = path.with_suffix(spec.ext) + return path + + +__all__ = [ + "Exporter", + "FormatSpec", + "available_formats", + "get_exporter", + "resolve_path", + "type_name_for_value", +] diff --git a/backend/exporters/_base.py b/backend/exporters/_base.py new file mode 100644 index 0000000..bf0aa76 --- /dev/null +++ b/backend/exporters/_base.py @@ -0,0 +1,60 @@ +""" +Base protocol for file exporters. + +Each exporter module handles one tono value type (DATA_FIELD, IMAGE, LINE, …) +and implements one or more output formats. Registration is discovered via the +module-level attributes declared below, so adding a new exporter is a matter +of dropping a new file in this package and importing it from __init__. + +A single file per value type (rather than per format) keeps format choices +that share plumbing — PNG & TIFF previews for DATA_FIELD, CSV & JSON for +tables — co-located, which is where most of the shared logic lives. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Protocol, runtime_checkable + + +@dataclass(frozen=True) +class FormatSpec: + """One output format supported by an exporter module.""" + + #: File extension (leading dot), e.g. ".tiff". + ext: str + #: True if the format preserves enough information to reload the value + #: via the matching importer. Advertised in the UI so users can tell + #: "save a preview" and "save for later" apart. + round_trip: bool + #: Short human-readable label. The enum key used in the format dropdown + #: is the dict key in each module's FORMATS map; `label` is what we'd + #: surface in tooltips or docs. Leave empty to fall back to the key. + label: str = "" + + +@runtime_checkable +class Exporter(Protocol): + """Structural protocol satisfied by every module in backend.exporters.""" + + #: Tono type names this exporter handles. Must match the upper-case names + #: used in node INPUT_TYPES / OUTPUTS (e.g. "DATA_FIELD", "IMAGE", "LINE"). + accepted_types: tuple[str, ...] + + #: Format name → spec. Format names are what users pick in the Save node's + #: format dropdown, so they should be short and recognizable. + FORMATS: dict[str, FormatSpec] + + def save(self, path: Path, value: Any, format_name: str, **opts: Any) -> None: + """Write *value* to *path* in *format_name*. + + The caller is responsible for ensuring ``path`` has the correct + extension (see registry.resolve_path) and that ``value`` is of a type + listed in ``accepted_types``. + """ + ... + + +# Re-exported so modules can write `from backend.exporters._base import FormatSpec`. +__all__ = ["FormatSpec", "Exporter"] diff --git a/backend/exporters/datafield.py b/backend/exporters/datafield.py new file mode 100644 index 0000000..23d515d --- /dev/null +++ b/backend/exporters/datafield.py @@ -0,0 +1,215 @@ +""" +Exporter for DATA_FIELD values. + +Format choices: + +* **TIFF** — 8-bit RGB colormap preview. *Not* round-trippable. Useful for + figures and sharing; opening it back gives you pixels, not physics. +* **TIFF (data)** — float64 array with tono metadata JSON-embedded in the + TIFF ImageDescription tag. Round-trips via the array_image importer once + that importer learns to read the tag (see tests/node_tests/exporters.py). +* **PNG** — 8-bit RGB colormap preview. Not round-trippable. +* **NPZ** — raw ``data`` array only. Not round-trippable (units are dropped). +* **GWY** — Gwyddion native format via the ``gwyfile`` package. Round-trips + and opens directly in Gwyddion. Recommended for "save and come back later". +* **HDF5** — generic HDF5 with one ``/data`` dataset and physical dimensions + as dataset attrs. Round-trips via our generic ``hdf5`` importer. +* **HDF5 (Ergo)** — Asylum Research / Ergo layout with the dataset at + ``Image/DataSet/Resolution 0/Frame 0//Image`` and a sidecar group + ``Image/DataSetInfo/Global/Channels/<title>/ImageDims`` carrying + ``DimScaling`` / ``DimUnits`` / ``DataUnits``. Round-trips via our + ``ergo_hdf5`` importer and opens in Asylum Ergo / Igor. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField, datafield_to_uint8 +from backend.exporters._base import FormatSpec + +accepted_types: tuple[str, ...] = ("DATA_FIELD",) + +FORMATS: dict[str, FormatSpec] = { + "TIFF": FormatSpec(ext=".tiff", round_trip=False, label="TIFF (preview)"), + "TIFF (data)": FormatSpec(ext=".tiff", round_trip=True, label="TIFF (calibrated data)"), + "PNG": FormatSpec(ext=".png", round_trip=False, label="PNG (preview)"), + "NPZ": FormatSpec(ext=".npz", round_trip=False, label="NumPy (.npz)"), + "GWY": FormatSpec(ext=".gwy", round_trip=True, label="Gwyddion (.gwy)"), + "HDF5": FormatSpec(ext=".h5", round_trip=True, label="HDF5 (generic)"), + "HDF5 (Ergo)": FormatSpec(ext=".h5", round_trip=True, label="HDF5 (Asylum Research / Ergo)"), +} + + +def save(path: Path, value: DataField, format_name: str, **_opts) -> None: + if format_name == "TIFF": + _save_tiff_preview(path, value) + return + if format_name == "TIFF (data)": + _save_tiff_data(path, value) + return + if format_name == "PNG": + _save_png_preview(path, value) + return + if format_name == "NPZ": + _save_npz(path, value) + return + if format_name == "GWY": + _save_gwy(path, value) + return + if format_name == "HDF5": + _save_hdf5_generic(path, value) + return + if format_name == "HDF5 (Ergo)": + _save_hdf5_ergo(path, value) + return + raise ValueError(f"Format {format_name!r} is not supported for DATA_FIELD.") + + +def _save_tiff_preview(path: Path, field: DataField) -> None: + import tifffile + tifffile.imwrite(str(path), datafield_to_uint8(field, field.colormap)) + + +def _save_tiff_data(path: Path, field: DataField) -> None: + """Write the raw float64 data with tono metadata in the ImageDescription tag. + + The description is a JSON document of shape ``{"tono": {...}}`` so future + schema extensions can coexist with other tools' TIFF metadata. Only the + fields needed to reconstruct physical coordinates and z-scaling are + embedded; display state (colormap, display_scale) is intentionally out of + scope — this format is for data, not styling. + """ + import tifffile + + meta = { + "tono": { + "version": 1, + "xreal": float(field.xreal), + "yreal": float(field.yreal), + "xoff": float(field.xoff), + "yoff": float(field.yoff), + "si_unit_xy": str(field.si_unit_xy), + "si_unit_z": str(field.si_unit_z), + "domain": str(field.domain), + "colormap": field.colormap if isinstance(field.colormap, str) else "viridis", + } + } + tifffile.imwrite( + str(path), + np.ascontiguousarray(field.data, dtype=np.float64), + description=json.dumps(meta, separators=(",", ":")), + ) + + +def _save_png_preview(path: Path, field: DataField) -> None: + from PIL import Image + Image.fromarray(datafield_to_uint8(field, field.colormap)).save(str(path)) + + +def _save_npz(path: Path, field: DataField) -> None: + np.savez(str(path), field=np.asarray(field.data)) + + +def _save_gwy(path: Path, field: DataField) -> None: + """Write a single-channel .gwy file via the gwyfile package.""" + from gwyfile.objects import GwyContainer, GwyDataField, GwySIUnit + + # gwyfile's GwyDataField ctor expects the data array and physical extents. + # si_unit_xy / si_unit_z accept a GwySIUnit wrapper with a .unitstr field. + gwy_field = GwyDataField( + np.ascontiguousarray(field.data, dtype=np.float64), + xreal=float(field.xreal), + yreal=float(field.yreal), + xoff=float(field.xoff), + yoff=float(field.yoff), + si_unit_xy=GwySIUnit(unitstr=str(field.si_unit_xy or "")), + si_unit_z=GwySIUnit(unitstr=str(field.si_unit_z or "")), + ) + title = path.stem or "field" + container = GwyContainer({ + "/0/data": gwy_field, + "/0/data/title": title, + }) + container.tofile(str(path)) + + +def _save_hdf5_generic(path: Path, field: DataField) -> None: + """Write a single dataset ``/data`` with physical dimensions as dataset attrs. + + The layout is the mirror of :mod:`backend.importers.hdf5`: any 2-D numeric + dataset is picked up and its attrs (``xreal``, ``yreal``, ``xoff``, ``yoff``, + ``si_unit_xy``, ``si_unit_z``) reconstruct the DataField. + """ + import h5py + + arr = np.ascontiguousarray(field.data, dtype=np.float64) + with h5py.File(str(path), "w") as f: + ds = f.create_dataset("data", data=arr) + ds.attrs["xreal"] = float(field.xreal) + ds.attrs["yreal"] = float(field.yreal) + ds.attrs["xoff"] = float(field.xoff) + ds.attrs["yoff"] = float(field.yoff) + ds.attrs["si_unit_xy"] = str(field.si_unit_xy or "") + ds.attrs["si_unit_z"] = str(field.si_unit_z or "") + + +def _save_hdf5_ergo(path: Path, field: DataField) -> None: + """Write an Asylum Research / Ergo-compatible HDF5 file. + + The layout mirrors :mod:`backend.importers.ergo_hdf5`: + + * The image dataset lives at + ``Image/DataSet/Resolution 0/Frame 0/<title>/Image`` — the second-to-last + path component is the channel name that the importer keys off. + * A sidecar group at + ``Image/DataSetInfo/Global/Channels/<title>/ImageDims`` carries + ``DimScaling`` (a (2, 2) array of absolute physical ranges, Y-first), + ``DimUnits`` (``[Y_unit, X_unit]``), and ``DataUnits`` (Z unit string). + + This makes the file openable by Asylum Ergo / Igor and round-trippable + through our ergo_hdf5 importer. + """ + import h5py + + arr = np.ascontiguousarray(field.data, dtype=np.float64) + title = path.stem or "field" + + x_start = float(field.xoff) + x_end = float(field.xoff) + float(field.xreal) + y_start = float(field.yoff) + y_end = float(field.yoff) + float(field.yreal) + # DimScaling is stored Y-first to match the importer's expectations + # (see ergo_hdf5.py:110-113). + dim_scaling = np.array( + [[y_start, y_end], [x_start, x_end]], + dtype=np.float64, + ) + # DimUnits is [Y_unit, X_unit]; the importer takes the X (second) entry + # as the canonical lateral unit (see ergo_hdf5.py:129-135). + xy_unit = str(field.si_unit_xy or "m") + z_unit = str(field.si_unit_z or "") + dim_units = np.array([xy_unit, xy_unit], dtype=h5py.string_dtype()) + + with h5py.File(str(path), "w") as f: + ds = f.create_dataset( + f"Image/DataSet/Resolution 0/Frame 0/{title}/Image", + data=arr, + ) + # Also write the generic attrs so non-Ergo readers still see physics. + ds.attrs["xreal"] = float(field.xreal) + ds.attrs["yreal"] = float(field.yreal) + ds.attrs["xoff"] = float(field.xoff) + ds.attrs["yoff"] = float(field.yoff) + ds.attrs["si_unit_xy"] = xy_unit + ds.attrs["si_unit_z"] = z_unit + + dims_grp = f.create_group( + f"Image/DataSetInfo/Global/Channels/{title}/ImageDims" + ) + dims_grp.attrs["DimScaling"] = dim_scaling + dims_grp.attrs["DimUnits"] = dim_units + dims_grp.attrs["DataUnits"] = z_unit diff --git a/backend/exporters/image.py b/backend/exporters/image.py new file mode 100644 index 0000000..2557ece --- /dev/null +++ b/backend/exporters/image.py @@ -0,0 +1,41 @@ +""" +Exporter for IMAGE values (numpy arrays, ImageData annotation sources). + +Images are raw pixel arrays — no physical calibration by design — so none of +the formats here round-trip dimensions. PNG/TIFF convert to uint8 via the +same image_to_uint8 helper the preview pipeline uses; NPZ preserves the raw +array. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import image_to_uint8 +from backend.exporters._base import FormatSpec + +accepted_types: tuple[str, ...] = ("IMAGE", "ANNOTATION_SOURCE") + +FORMATS: dict[str, FormatSpec] = { + "PNG": FormatSpec(ext=".png", round_trip=False, label="PNG"), + "TIFF": FormatSpec(ext=".tiff", round_trip=False, label="TIFF"), + "NPZ": FormatSpec(ext=".npz", round_trip=False, label="NumPy (.npz)"), +} + + +def save(path: Path, value: np.ndarray, format_name: str, **_opts) -> None: + arr = np.asarray(value) + if format_name == "PNG": + from PIL import Image + Image.fromarray(image_to_uint8(arr)).save(str(path)) + return + if format_name == "TIFF": + import tifffile + tifffile.imwrite(str(path), image_to_uint8(arr)) + return + if format_name == "NPZ": + np.savez(str(path), image=arr) + return + raise ValueError(f"Format {format_name!r} is not supported for IMAGE.") diff --git a/backend/exporters/line.py b/backend/exporters/line.py new file mode 100644 index 0000000..496c768 --- /dev/null +++ b/backend/exporters/line.py @@ -0,0 +1,182 @@ +""" +Exporter for LINE values (1-D profiles as LineData or bare ndarrays). + +PNG / TIFF render a plot image via Pillow; CSV / JSON / NPZ save the raw +(x, y, unit) arrays. The plot renderer is self-contained (no matplotlib +dependency) and handles SI-prefix axis labels. +""" + +from __future__ import annotations + +import csv +import json +from pathlib import Path + +import numpy as np + +from backend.data_types import LineData, _PREFIXABLE_UNITS, _SI_PREFIXES +from backend.exporters._base import FormatSpec + +accepted_types: tuple[str, ...] = ("LINE",) + +FORMATS: dict[str, FormatSpec] = { + "PNG": FormatSpec(ext=".png", round_trip=False, label="PNG plot"), + "TIFF": FormatSpec(ext=".tiff", round_trip=False, label="TIFF plot"), + "CSV": FormatSpec(ext=".csv", round_trip=True, label="CSV"), + "NPZ": FormatSpec(ext=".npz", round_trip=False, label="NumPy (.npz)"), + "JSON": FormatSpec(ext=".json", round_trip=True, label="JSON"), +} + + +def save(path: Path, value, format_name: str, *, plot_title: str = "", **_opts) -> None: + line = value if isinstance(value, LineData) else LineData(data=np.asarray(value).ravel()) + + y = np.asarray(line.data, dtype=np.float64).ravel() + if line.x_axis is not None: + x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)] + else: + x = np.arange(len(y), dtype=np.float64) + + if format_name in ("PNG", "TIFF"): + _save_line_plot(path, x, y, line.x_unit, line.y_unit, plot_title, format_name) + return + if format_name == "CSV": + with path.open("w", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + writer.writerow(["x", "y", "x_unit", "y_unit"]) + for xv, yv in zip(x, y): + writer.writerow([xv, yv, line.x_unit, line.y_unit]) + return + if format_name == "NPZ": + np.savez(str(path), x=x, y=y) + return + if format_name == "JSON": + path.write_text( + json.dumps({ + "x": x.tolist(), + "y": y.tolist(), + "x_unit": line.x_unit, + "y_unit": line.y_unit, + }, indent=2), + encoding="utf-8", + ) + return + raise ValueError(f"Format {format_name!r} is not supported for LINE.") + + +def _save_line_plot( + path: Path, + x: np.ndarray, + y: np.ndarray, + x_unit: str, + y_unit: str, + title: str, + format_name: str, +) -> None: + """Render a simple PNG/TIFF line plot with SI-prefixed axes. + + Intentionally self-contained (Pillow only, no matplotlib) so that builds + stay lean. Layout is fixed 1200×750 with 5×5 grid and a single blue line. + """ + from PIL import Image, ImageDraw, ImageFont + + w, h = 1200, 750 + bg = (255, 255, 255) + line_color = (79, 142, 247) # #4f8ef7 + grid_color = (200, 200, 200) + text_color = (60, 60, 60) + margin = {"left": 80, "right": 30, "top": 50, "bottom": 60} + + img = Image.new("RGB", (w, h), bg) + draw = ImageDraw.Draw(img) + + try: + font = ImageFont.truetype("DejaVuSans.ttf", 14) + font_small = ImageFont.truetype("DejaVuSans.ttf", 11) + font_title = ImageFont.truetype("DejaVuSans.ttf", 16) + except (OSError, IOError): + font = ImageFont.load_default() + font_small = font + font_title = font + + pw = w - margin["left"] - margin["right"] + ph = h - margin["top"] - margin["bottom"] + + def _si_scale(unit: str, vmin: float, vmax: float) -> tuple[float, str]: + """Pick the best SI prefix for an axis range. Returns (divisor, prefixed_unit).""" + unit = (unit or "").strip() + if not unit or unit not in _PREFIXABLE_UNITS: + return 1.0, unit if unit else "" + peak = max(abs(vmin), abs(vmax)) + if peak == 0: + return 1.0, unit + for scale, prefix in _SI_PREFIXES: + if peak / scale >= 1.0: + return scale, f"{prefix}{unit}" + return _SI_PREFIXES[-1][0], f"{_SI_PREFIXES[-1][1]}{unit}" + + xmin, xmax = float(np.nanmin(x)), float(np.nanmax(x)) + ymin, ymax = float(np.nanmin(y)), float(np.nanmax(y)) + + x_scale, x_label = _si_scale(x_unit, xmin, xmax) + y_scale, y_label = _si_scale(y_unit, ymin, ymax) + if not x_label: + x_label = "x" + if not y_label: + y_label = "y" + + x = x / x_scale + y = y / y_scale + xmin, xmax = xmin / x_scale, xmax / x_scale + ymin, ymax = ymin / y_scale, ymax / y_scale + + if ymax == ymin: + ymin, ymax = ymin - 1, ymax + 1 + if xmax == xmin: + xmax = xmin + 1 + ypad = (ymax - ymin) * 0.05 + ymin -= ypad + ymax += ypad + + def to_px(xv: float, yv: float) -> tuple[float, float]: + px = margin["left"] + (xv - xmin) / (xmax - xmin) * pw + py = margin["top"] + (1.0 - (yv - ymin) / (ymax - ymin)) * ph + return px, py + + for i in range(6): + gy = ymin + (ymax - ymin) * i / 5 + _, py = to_px(xmin, gy) + draw.line([(margin["left"], py), (margin["left"] + pw, py)], fill=grid_color, width=1) + label = f"{gy:.4g}" + draw.text((margin["left"] - 8, py - 6), label, fill=text_color, font=font_small, anchor="rm") + + gx = xmin + (xmax - xmin) * i / 5 + px, _ = to_px(gx, ymin) + draw.line([(px, margin["top"]), (px, margin["top"] + ph)], fill=grid_color, width=1) + label = f"{gx:.4g}" + draw.text((px, margin["top"] + ph + 6), label, fill=text_color, font=font_small, anchor="mt") + + n = len(y) + step = max(1, n // pw) + xs, ys = x[::step], y[::step] + pts = [to_px(float(xs[i]), float(ys[i])) for i in range(len(xs))] + if len(pts) > 1: + draw.line(pts, fill=line_color, width=2) + + draw.rectangle( + [margin["left"], margin["top"], margin["left"] + pw, margin["top"] + ph], + outline=(100, 100, 100), width=1, + ) + draw.text((margin["left"] + pw // 2, h - 10), x_label, fill=text_color, font=font, anchor="mb") + + y_label_img = Image.new("RGBA", (200, 20), (0, 0, 0, 0)) + y_draw = ImageDraw.Draw(y_label_img) + y_draw.text((100, 10), y_label, fill=text_color, font=font, anchor="mm") + y_label_img = y_label_img.rotate(90, expand=True) + img.paste(y_label_img, (2, margin["top"] + ph // 2 - y_label_img.height // 2), y_label_img) + + if title and title.strip(): + draw.text((w // 2, 10), title.strip(), fill=text_color, font=font_title, anchor="mt") + + ext = ".png" if format_name == "PNG" else ".tiff" + img.save(str(path.with_suffix(ext))) diff --git a/backend/exporters/mesh.py b/backend/exporters/mesh.py new file mode 100644 index 0000000..cc271d4 --- /dev/null +++ b/backend/exporters/mesh.py @@ -0,0 +1,60 @@ +""" +Exporter for MESH_MODEL values (Wavefront OBJ, ASCII STL). +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import MeshModel +from backend.exporters._base import FormatSpec + +accepted_types: tuple[str, ...] = ("MESH_MODEL",) + +FORMATS: dict[str, FormatSpec] = { + "OBJ": FormatSpec(ext=".obj", round_trip=True, label="Wavefront OBJ"), + "STL": FormatSpec(ext=".stl", round_trip=True, label="STL (ASCII)"), +} + + +def save(path: Path, value: MeshModel, format_name: str, **_opts) -> None: + if format_name == "OBJ": + _save_obj(path, value) + return + if format_name == "STL": + _save_stl(path, value) + return + raise ValueError(f"Format {format_name!r} is not supported for MESH_MODEL.") + + +def _save_obj(path: Path, mesh: MeshModel) -> None: + lines: list[str] = [] + for vertex in mesh.vertices: + lines.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}") + for face in mesh.faces: + lines.append(f"f {int(face[0]) + 1} {int(face[1]) + 1} {int(face[2]) + 1}") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _save_stl(path: Path, mesh: MeshModel) -> None: + def normal(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: + n = np.cross(b - a, c - a) + length = float(np.linalg.norm(n)) + return n / length if length > 0 else np.array([0.0, 1.0, 0.0], dtype=np.float32) + + lines = ["solid tono"] + vertices = np.asarray(mesh.vertices, dtype=np.float32) + for face in np.asarray(mesh.faces, dtype=np.int32): + a, b, c = vertices[int(face[0])], vertices[int(face[1])], vertices[int(face[2])] + n = normal(a, b, c) + lines.append(f" facet normal {n[0]} {n[1]} {n[2]}") + lines.append(" outer loop") + lines.append(f" vertex {a[0]} {a[1]} {a[2]}") + lines.append(f" vertex {b[0]} {b[1]} {b[2]}") + lines.append(f" vertex {c[0]} {c[1]} {c[2]}") + lines.append(" endloop") + lines.append(" endfacet") + lines.append("endsolid tono") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") diff --git a/backend/exporters/scalar.py b/backend/exporters/scalar.py new file mode 100644 index 0000000..23af0b0 --- /dev/null +++ b/backend/exporters/scalar.py @@ -0,0 +1,28 @@ +""" +Exporter for FLOAT scalars (also handles Python int and numpy scalar types). +""" + +from __future__ import annotations + +import json +from pathlib import Path + +from backend.exporters._base import FormatSpec + +accepted_types: tuple[str, ...] = ("FLOAT",) + +FORMATS: dict[str, FormatSpec] = { + "TXT": FormatSpec(ext=".txt", round_trip=True, label="Text"), + "JSON": FormatSpec(ext=".json", round_trip=True, label="JSON"), +} + + +def save(path: Path, value: float, format_name: str, **_opts) -> None: + numeric = float(value) + if format_name == "TXT": + path.write_text(f"{numeric}\n", encoding="utf-8") + return + if format_name == "JSON": + path.write_text(json.dumps({"value": numeric}, indent=2), encoding="utf-8") + return + raise ValueError(f"Format {format_name!r} is not supported for scalar values.") diff --git a/backend/exporters/table.py b/backend/exporters/table.py new file mode 100644 index 0000000..040f2e0 --- /dev/null +++ b/backend/exporters/table.py @@ -0,0 +1,44 @@ +""" +Exporter for RECORD_TABLE and DATA_TABLE values. + +Both types are list-of-dict; the Save node currently accepts plain lists in +this slot too, which is preserved here. CSV auto-derives its column set from +the first row's keys (and any additional keys that appear later), matching +the prior behavior. +""" + +from __future__ import annotations + +import csv +import json +from pathlib import Path + +from backend.exporters._base import FormatSpec + +accepted_types: tuple[str, ...] = ("RECORD_TABLE", "DATA_TABLE") + +FORMATS: dict[str, FormatSpec] = { + "CSV": FormatSpec(ext=".csv", round_trip=True, label="CSV"), + "JSON": FormatSpec(ext=".json", round_trip=True, label="JSON"), +} + + +def save(path: Path, value: list, format_name: str, **_opts) -> None: + rows = list(value) + if format_name == "JSON": + path.write_text(json.dumps(rows, indent=2), encoding="utf-8") + return + if format_name == "CSV": + columns: list[str] = [] + for row in rows: + if isinstance(row, dict): + for key in row.keys(): + if key not in columns: + columns.append(str(key)) + with path.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=columns) + writer.writeheader() + for row in rows: + writer.writerow(row if isinstance(row, dict) else {"value": row}) + return + raise ValueError(f"Format {format_name!r} is not supported for table inputs.") diff --git a/backend/importers/gwy.py b/backend/importers/gwy.py index 209f9fb..f223dc6 100644 --- a/backend/importers/gwy.py +++ b/backend/importers/gwy.py @@ -20,19 +20,42 @@ def load(path: Path) -> list[DataField]: fields = [] for ch in channels.values(): - data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) + # gwyfile.objects.GwyDataField exposes .data as an already-2D ndarray + # (no xres/yres attributes — those were removed in gwyfile 0.3+). + data = np.asarray(ch.data, dtype=np.float64) + if data.ndim != 2: + # Defensive: if a future gwyfile version yields a flat buffer, the + # dimensions live in the serialized object's xres/yres keys. + xres = int(ch.get("xres", data.size)) + yres = int(ch.get("yres", 1)) + data = data.reshape(yres, xres) fields.append(DataField( data=data, xreal=float(ch.xreal), yreal=float(ch.yreal), xoff=float(getattr(ch, "xoff", 0.0)), yoff=float(getattr(ch, "yoff", 0.0)), - si_unit_xy="m", - si_unit_z="m", + si_unit_xy=_unit_str(getattr(ch, "si_unit_xy", None)) or "m", + si_unit_z=_unit_str(getattr(ch, "si_unit_z", None)) or "m", )) return fields +def _unit_str(si_unit: object) -> str: + """Extract the unit string from a GwySIUnit without importing gwyfile. + + Loaded GwySIUnit objects behave like dicts with a ``unitstr`` key. + """ + if si_unit is None: + return "" + if hasattr(si_unit, "unitstr"): + return str(getattr(si_unit, "unitstr") or "") + try: + return str(si_unit["unitstr"] or "") + except (KeyError, TypeError): + return "" + + def channel_names(path: Path) -> list[str]: import gwyfile try: diff --git a/backend/nodes/save.py b/backend/nodes/save.py index 2b64119..e885976 100644 --- a/backend/nodes/save.py +++ b/backend/nodes/save.py @@ -1,26 +1,44 @@ from __future__ import annotations -import csv -import json -from pathlib import Path - -import numpy as np - import tempfile +from pathlib import Path from backend.node_registry import register_node from backend.execution_context import emit_warning, emit_file_download -from backend.data_types import ( - DataField, LineData, MeshModel, datafield_to_uint8, image_to_uint8, - _SI_PREFIXES, _PREFIXABLE_UNITS, +from backend.exporters import ( + available_formats, + get_exporter, + resolve_path, + type_name_for_value, ) DOWNLOAD_DIR = Path(tempfile.gettempdir()) / "tono-downloads" + +def _choices_by_source_type() -> dict[str, list[str]]: + """Build the format dropdown's source-type map from the exporter registry. + + Centralising this here means adding a new exporter module (or a new format + inside an existing one) automatically surfaces in the UI — no parallel + list to keep in sync. + """ + return { + "DATA_FIELD": available_formats("DATA_FIELD"), + "IMAGE": available_formats("IMAGE"), + "ANNOTATION_SOURCE": available_formats("ANNOTATION_SOURCE"), + "LINE": available_formats("LINE"), + "RECORD_TABLE": available_formats("RECORD_TABLE"), + "DATA_TABLE": available_formats("DATA_TABLE"), + "FLOAT": available_formats("FLOAT"), + "MESH_MODEL": available_formats("MESH_MODEL"), + } + + @register_node(display_name="Save") class Save: @classmethod def INPUT_TYPES(cls): + choices = _choices_by_source_type() return { "required": { "filename": ("STRING", { @@ -41,17 +59,8 @@ class Save: ], }), "format": ("STRING", { - "default": "TIFF", - "choices_by_source_type": { - "DATA_FIELD": ["TIFF", "PNG", "NPZ"], - "IMAGE": ["PNG", "TIFF", "NPZ"], - "ANNOTATION_SOURCE": ["PNG", "TIFF", "NPZ"], - "LINE": ["PNG", "TIFF", "CSV", "NPZ", "JSON"], - "RECORD_TABLE": ["CSV", "JSON"], - "DATA_TABLE": ["CSV", "JSON"], - "FLOAT": ["TXT", "JSON"], - "MESH_MODEL": ["OBJ", "STL"], - }, + "default": choices["DATA_FIELD"][0] if choices["DATA_FIELD"] else "", + "choices_by_source_type": choices, "source_type_input": "value", }), }, @@ -71,10 +80,12 @@ class Save: OUTPUT_NODE = True MANUAL_TRIGGER = True DESCRIPTION = ( - "Save a single graph value to disk. Supports fields, images, lines, tables, scalars, and 3D meshes." + "Save a single graph value to disk. Supports fields, images, lines, tables, scalars, " + "and 3D meshes. Use 'GWY' or 'TIFF (data)' for DataFields you want to re-open later " + "with their physical units preserved." ) - KEYWORDS = ("export", "write", "download", "png", "tiff", "csv", "json", "npz", "obj", "stl") + KEYWORDS = ("export", "write", "download", "png", "tiff", "csv", "json", "npz", "obj", "stl", "gwy") def save( self, @@ -83,295 +94,11 @@ class Save: value, plot_title: str = "", ): - path = self._resolve_save_path(filename, format) - - if isinstance(value, MeshModel): - self._save_mesh(path, value, format) - elif isinstance(value, DataField): - self._save_datafield(path, value, format) - elif isinstance(value, np.ndarray): - if value.ndim == 1: - self._save_line(path, LineData(data=value), format, title=plot_title) - else: - self._save_image_or_array(path, value, format) - elif isinstance(value, LineData): - self._save_line(path, value, format, title=plot_title) - elif isinstance(value, list): - self._save_table(path, value, format) - elif isinstance(value, (int, float, np.floating, np.integer)): - self._save_scalar(path, float(value), format) - else: - raise ValueError(f"Save does not support input type: {type(value).__name__}") + type_name = type_name_for_value(value) + module, spec = get_exporter(type_name, format) + path = resolve_path(filename, spec, DOWNLOAD_DIR) + module.save(path, value, format, plot_title=plot_title) emit_warning(f"Saved to {path.name}") emit_file_download(str(path)) return () - - def _resolve_save_path(self, filename: str, format_name: str) -> Path: - ext_map = { - "PNG": ".png", - "TIFF": ".tiff", - "NPZ": ".npz", - "CSV": ".csv", - "JSON": ".json", - "OBJ": ".obj", - "STL": ".stl", - "TXT": ".txt", - } - ext = ext_map[format_name] - - raw_filename = str(filename).strip() if filename is not None else "" - if not raw_filename: - raise ValueError("No output filename selected — enter a file name.") - - candidate = Path(raw_filename).expanduser() - if candidate.is_absolute(): - candidate.parent.mkdir(parents=True, exist_ok=True) - path = candidate - else: - DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True) - path = DOWNLOAD_DIR / candidate.name - - if path.suffix.lower() != ext: - path = path.with_suffix(ext) - return path - - def _save_datafield(self, path: Path, field: DataField, format_name: str): - if format_name == "TIFF": - import tifffile - tifffile.imwrite(str(path), datafield_to_uint8(field, field.colormap)) - return - if format_name == "NPZ": - np.savez(str(path), field=np.asarray(field.data)) - return - if format_name == "PNG": - from PIL import Image - Image.fromarray(datafield_to_uint8(field, field.colormap)).save(str(path)) - return - raise ValueError(f"Format {format_name} is not supported for DATA_FIELD.") - - def _save_image_or_array(self, path: Path, image: np.ndarray, format_name: str): - arr = np.asarray(image) - if format_name == "PNG": - from PIL import Image - Image.fromarray(image_to_uint8(arr)).save(str(path)) - return - if format_name == "TIFF": - import tifffile - tifffile.imwrite(str(path), image_to_uint8(arr)) - return - if format_name == "NPZ": - np.savez(str(path), image=arr) - return - raise ValueError(f"Format {format_name} is not supported for IMAGE.") - - def _save_line(self, path: Path, line: LineData, format_name: str, title: str = ""): - y = np.asarray(line.data, dtype=np.float64).ravel() - x = np.asarray(line.x_axis, dtype=np.float64).ravel()[: len(y)] if line.x_axis is not None else np.arange(len(y), dtype=np.float64) - if format_name in ("PNG", "TIFF"): - self._save_line_plot(path, x, y, line.x_unit, line.y_unit, title, format_name) - return - if format_name == "CSV": - with path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.writer(fh) - writer.writerow(["x", "y", "x_unit", "y_unit"]) - for xv, yv in zip(x, y): - writer.writerow([xv, yv, line.x_unit, line.y_unit]) - return - if format_name == "NPZ": - np.savez(str(path), x=x, y=y) - return - if format_name == "JSON": - path.write_text(json.dumps({ - "x": x.tolist(), - "y": y.tolist(), - "x_unit": line.x_unit, - "y_unit": line.y_unit, - }, indent=2), encoding="utf-8") - return - raise ValueError(f"Format {format_name} is not supported for LINE.") - - def _save_line_plot( - self, - path: Path, - x: np.ndarray, - y: np.ndarray, - x_unit: str, - y_unit: str, - title: str, - format_name: str, - ): - from PIL import Image, ImageDraw, ImageFont - - w, h = 1200, 750 - bg = (255, 255, 255) - line_color = (79, 142, 247) # #4f8ef7 - grid_color = (200, 200, 200) - text_color = (60, 60, 60) - margin = {"left": 80, "right": 30, "top": 50, "bottom": 60} - - img = Image.new("RGB", (w, h), bg) - draw = ImageDraw.Draw(img) - - try: - font = ImageFont.truetype("DejaVuSans.ttf", 14) - font_small = ImageFont.truetype("DejaVuSans.ttf", 11) - font_title = ImageFont.truetype("DejaVuSans.ttf", 16) - except (OSError, IOError): - font = ImageFont.load_default() - font_small = font - font_title = font - - pw = w - margin["left"] - margin["right"] - ph = h - margin["top"] - margin["bottom"] - - def _si_scale(unit: str, vmin: float, vmax: float) -> tuple[float, str]: - """Pick the best SI prefix for an axis range. Returns (divisor, prefixed_unit).""" - unit = (unit or "").strip() - if not unit or unit not in _PREFIXABLE_UNITS: - return 1.0, unit if unit else "" - peak = max(abs(vmin), abs(vmax)) - if peak == 0: - return 1.0, unit - for scale, prefix in _SI_PREFIXES: - if peak / scale >= 1.0: - return scale, f"{prefix}{unit}" - return _SI_PREFIXES[-1][0], f"{_SI_PREFIXES[-1][1]}{unit}" - - xmin, xmax = float(np.nanmin(x)), float(np.nanmax(x)) - ymin, ymax = float(np.nanmin(y)), float(np.nanmax(y)) - - x_scale, x_label = _si_scale(x_unit, xmin, xmax) - y_scale, y_label = _si_scale(y_unit, ymin, ymax) - if not x_label: - x_label = "x" - if not y_label: - y_label = "y" - - # Scale data into prefixed units - x = x / x_scale - y = y / y_scale - xmin, xmax = xmin / x_scale, xmax / x_scale - ymin, ymax = ymin / y_scale, ymax / y_scale - - if ymax == ymin: - ymin, ymax = ymin - 1, ymax + 1 - if xmax == xmin: - xmax = xmin + 1 - # Add 5% padding to y range - ypad = (ymax - ymin) * 0.05 - ymin -= ypad - ymax += ypad - - def to_px(xv: float, yv: float) -> tuple[float, float]: - px = margin["left"] + (xv - xmin) / (xmax - xmin) * pw - py = margin["top"] + (1.0 - (yv - ymin) / (ymax - ymin)) * ph - return px, py - - # Grid lines (5 horizontal, 5 vertical) - for i in range(6): - gy = ymin + (ymax - ymin) * i / 5 - _, py = to_px(xmin, gy) - draw.line([(margin["left"], py), (margin["left"] + pw, py)], fill=grid_color, width=1) - label = f"{gy:.4g}" - draw.text((margin["left"] - 8, py - 6), label, fill=text_color, font=font_small, anchor="rm") - - gx = xmin + (xmax - xmin) * i / 5 - px, _ = to_px(gx, ymin) - draw.line([(px, margin["top"]), (px, margin["top"] + ph)], fill=grid_color, width=1) - label = f"{gx:.4g}" - draw.text((px, margin["top"] + ph + 6), label, fill=text_color, font=font_small, anchor="mt") - - # Plot line - n = len(y) - step = max(1, n // pw) - xs, ys = x[::step], y[::step] - pts = [to_px(float(xs[i]), float(ys[i])) for i in range(len(xs))] - if len(pts) > 1: - draw.line(pts, fill=line_color, width=2) - - # Border - draw.rectangle( - [margin["left"], margin["top"], margin["left"] + pw, margin["top"] + ph], - outline=(100, 100, 100), width=1, - ) - draw.text((margin["left"] + pw // 2, h - 10), x_label, fill=text_color, font=font, anchor="mb") - # Vertical y label — draw rotated - y_label_img = Image.new("RGBA", (200, 20), (0, 0, 0, 0)) - y_draw = ImageDraw.Draw(y_label_img) - y_draw.text((100, 10), y_label, fill=text_color, font=font, anchor="mm") - y_label_img = y_label_img.rotate(90, expand=True) - img.paste(y_label_img, (2, margin["top"] + ph // 2 - y_label_img.height // 2), y_label_img) - - # Title - if title and title.strip(): - draw.text((w // 2, 10), title.strip(), fill=text_color, font=font_title, anchor="mt") - - ext = ".png" if format_name == "PNG" else ".tiff" - img.save(str(path.with_suffix(ext))) - - def _save_table(self, path: Path, rows: list, format_name: str): - if format_name == "JSON": - path.write_text(json.dumps(rows, indent=2), encoding="utf-8") - return - if format_name == "CSV": - columns: list[str] = [] - for row in rows: - if isinstance(row, dict): - for key in row.keys(): - if key not in columns: - columns.append(str(key)) - with path.open("w", newline="", encoding="utf-8") as fh: - writer = csv.DictWriter(fh, fieldnames=columns) - writer.writeheader() - for row in rows: - writer.writerow(row if isinstance(row, dict) else {"value": row}) - return - raise ValueError(f"Format {format_name} is not supported for table inputs.") - - def _save_scalar(self, path: Path, value: float, format_name: str): - if format_name == "TXT": - path.write_text(f"{value}\n", encoding="utf-8") - return - if format_name == "JSON": - path.write_text(json.dumps({"value": value}, indent=2), encoding="utf-8") - return - raise ValueError(f"Format {format_name} is not supported for scalar values.") - - def _save_mesh(self, path: Path, mesh: MeshModel, format_name: str): - if format_name == "OBJ": - self._save_obj(path, mesh) - return - if format_name == "STL": - self._save_stl(path, mesh) - return - raise ValueError(f"Format {format_name} is not supported for MESH_MODEL.") - - def _save_obj(self, path: Path, mesh: MeshModel): - lines = [] - for vertex in mesh.vertices: - lines.append(f"v {vertex[0]} {vertex[1]} {vertex[2]}") - for face in mesh.faces: - lines.append(f"f {int(face[0]) + 1} {int(face[1]) + 1} {int(face[2]) + 1}") - path.write_text("\n".join(lines) + "\n", encoding="utf-8") - - def _save_stl(self, path: Path, mesh: MeshModel): - def normal(a, b, c): - n = np.cross(b - a, c - a) - length = float(np.linalg.norm(n)) - return n / length if length > 0 else np.array([0.0, 1.0, 0.0], dtype=np.float32) - - lines = ["solid tono"] - vertices = np.asarray(mesh.vertices, dtype=np.float32) - for face in np.asarray(mesh.faces, dtype=np.int32): - a, b, c = vertices[int(face[0])], vertices[int(face[1])], vertices[int(face[2])] - n = normal(a, b, c) - lines.append(f" facet normal {n[0]} {n[1]} {n[2]}") - lines.append(" outer loop") - lines.append(f" vertex {a[0]} {a[1]} {a[2]}") - lines.append(f" vertex {b[0]} {b[1]} {b[2]}") - lines.append(f" vertex {c[0]} {c[1]} {c[2]}") - lines.append(" endloop") - lines.append(" endfacet") - lines.append("endsolid tono") - path.write_text("\n".join(lines) + "\n", encoding="utf-8") diff --git a/tests/node_tests/exporters.py b/tests/node_tests/exporters.py new file mode 100644 index 0000000..1f56704 --- /dev/null +++ b/tests/node_tests/exporters.py @@ -0,0 +1,300 @@ +""" +Tests for the exporter registry and the round-trippable DataField formats. + +The Save node's format-specific behavior is covered in test_save_generic +(tests/node_tests/save.py). This module focuses on: + + 1. Registry contract — every exporter module satisfies the protocol. + 2. Dispatch — type_name_for_value classifies values correctly and + get_exporter returns a matching module. + 3. Round-trip — GWY and TIFF (data) preserve xreal/yreal/units/data. +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path + +import numpy as np + +from backend.data_types import ( + DataField, + DataTable, + ImageData, + LineData, + MeshModel, + RecordTable, +) + + +def test_exporter_registry_contract(): + """Every registered exporter module must expose the required attributes.""" + from backend.exporters import _REGISTRY + from backend.exporters._base import FormatSpec + + assert _REGISTRY, "Registry must not be empty" + seen_modules = {mod for (mod, _) in _REGISTRY.values()} + for module in seen_modules: + assert hasattr(module, "accepted_types") + assert hasattr(module, "FORMATS") + assert hasattr(module, "save") + assert isinstance(module.accepted_types, tuple) + assert all(isinstance(t, str) and t.isupper() for t in module.accepted_types) + assert isinstance(module.FORMATS, dict) + for name, spec in module.FORMATS.items(): + assert isinstance(name, str) and name + assert isinstance(spec, FormatSpec) + assert spec.ext.startswith(".") + + +def test_type_name_for_value_classification(): + from backend.exporters import type_name_for_value + + assert type_name_for_value(DataField(data=np.zeros((4, 4)))) == "DATA_FIELD" + assert type_name_for_value(np.zeros((4, 4))) == "IMAGE" + assert type_name_for_value(np.zeros((4, 4, 3), dtype=np.uint8)) == "IMAGE" + assert type_name_for_value(ImageData(np.zeros((4, 4), dtype=np.uint8))) == "IMAGE" + assert type_name_for_value(np.zeros(8)) == "LINE" + assert type_name_for_value(LineData(data=np.zeros(8))) == "LINE" + assert type_name_for_value(RecordTable([{"a": 1}])) == "RECORD_TABLE" + assert type_name_for_value(DataTable([{"a": 1}])) == "DATA_TABLE" + assert type_name_for_value(1.25) == "FLOAT" + assert type_name_for_value(np.float64(0.5)) == "FLOAT" + mesh = MeshModel( + vertices=np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32), + faces=np.array([[0, 1, 2]], dtype=np.int32), + ) + assert type_name_for_value(mesh) == "MESH_MODEL" + + try: + type_name_for_value(object()) + assert False, "Expected ValueError for unsupported type" + except ValueError: + pass + + +def test_get_exporter_known_and_unknown(): + from backend.exporters import get_exporter + + mod, spec = get_exporter("DATA_FIELD", "GWY") + assert spec.ext == ".gwy" + assert spec.round_trip is True + + mod, spec = get_exporter("DATA_FIELD", "TIFF") + assert spec.ext == ".tiff" + # Legacy preview path — not round-trippable. + assert spec.round_trip is False + + mod, spec = get_exporter("DATA_FIELD", "TIFF (data)") + assert spec.round_trip is True + + try: + get_exporter("DATA_FIELD", "DOES_NOT_EXIST") + assert False, "Expected ValueError for unknown format" + except ValueError: + pass + + try: + get_exporter("FLOAT", "GWY") + assert False, "Expected ValueError for type/format mismatch" + except ValueError: + pass + + +def test_available_formats_includes_new_datafield_formats(): + from backend.exporters import available_formats + + formats = available_formats("DATA_FIELD") + assert "TIFF" in formats + assert "TIFF (data)" in formats + assert "GWY" in formats + assert "PNG" in formats + assert "NPZ" in formats + assert "HDF5" in formats + assert "HDF5 (Ergo)" in formats + + +def test_datafield_gwy_round_trip(): + """Writing a DataField to .gwy and reloading via the importer preserves everything.""" + from backend.importers import gwy as gwy_importer + from backend.nodes.save import Save + + rng = np.random.default_rng(7) + data = rng.standard_normal((32, 48)).astype(np.float64) * 1e-9 + field = DataField( + data=data, + xreal=3.2e-6, + yreal=2.4e-6, + xoff=1.1e-7, + yoff=-5.5e-7, + si_unit_xy="m", + si_unit_z="m", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "topo" + Save().save(filename=str(path), format="GWY", value=field) + out_path = path.with_suffix(".gwy") + assert out_path.exists() + + reloaded = gwy_importer.load(out_path) + assert len(reloaded) == 1 + rf = reloaded[0] + assert rf.data.shape == field.data.shape + assert np.allclose(rf.data, field.data) + assert np.isclose(rf.xreal, field.xreal) + assert np.isclose(rf.yreal, field.yreal) + assert np.isclose(rf.xoff, field.xoff) + assert np.isclose(rf.yoff, field.yoff) + assert rf.si_unit_xy == "m" + assert rf.si_unit_z == "m" + + # channel_names() should return the stem we used as the title + names = gwy_importer.channel_names(out_path) + assert names == ["topo"] + + +def test_datafield_tiff_data_round_trip(): + """TIFF (data) writes float64 pixels + JSON metadata; we verify both.""" + import tifffile + + from backend.nodes.save import Save + + rng = np.random.default_rng(11) + data = rng.standard_normal((24, 36)).astype(np.float64) * 1e-8 + field = DataField( + data=data, + xreal=5e-6, + yreal=3e-6, + xoff=0.0, + yoff=0.0, + si_unit_xy="m", + si_unit_z="V", + colormap="viridis", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "field" + Save().save(filename=str(path), format="TIFF (data)", value=field) + out_path = path.with_suffix(".tiff") + assert out_path.exists() + + with tifffile.TiffFile(out_path) as tif: + arr = tif.asarray() + desc = tif.pages[0].tags["ImageDescription"].value + + assert arr.dtype == np.float64 + assert arr.shape == field.data.shape + assert np.allclose(arr, field.data) + + meta = json.loads(desc)["tono"] + assert meta["xreal"] == field.xreal + assert meta["yreal"] == field.yreal + assert meta["si_unit_xy"] == "m" + assert meta["si_unit_z"] == "V" + assert meta["domain"] == "spatial" + + +def test_datafield_hdf5_generic_round_trip(): + """HDF5 (generic) writes /data + attrs that our hdf5 importer reads back.""" + from backend.importers import hdf5 as hdf5_importer + from backend.nodes.save import Save + + rng = np.random.default_rng(23) + data = rng.standard_normal((20, 28)).astype(np.float64) * 1e-7 + field = DataField( + data=data, + xreal=4.8e-6, + yreal=3.2e-6, + xoff=1.5e-7, + yoff=-2.5e-7, + si_unit_xy="m", + si_unit_z="V", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "topo" + Save().save(filename=str(path), format="HDF5", value=field) + out_path = path.with_suffix(".h5") + assert out_path.exists() + + reloaded = hdf5_importer.load(out_path) + assert len(reloaded) == 1 + rf = reloaded[0] + assert rf.data.shape == field.data.shape + assert np.allclose(rf.data, field.data) + assert np.isclose(rf.xreal, field.xreal) + assert np.isclose(rf.yreal, field.yreal) + assert np.isclose(rf.xoff, field.xoff) + assert np.isclose(rf.yoff, field.yoff) + assert rf.si_unit_xy == "m" + assert rf.si_unit_z == "V" + + +def test_datafield_hdf5_ergo_round_trip(): + """HDF5 (Ergo) writes the Asylum sidecar layout and round-trips via ergo_hdf5.""" + import h5py + + from backend.importers import ergo_hdf5 as ergo_importer + from backend.nodes.save import Save + + rng = np.random.default_rng(29) + data = rng.standard_normal((16, 24)).astype(np.float64) * 1e-9 + field = DataField( + data=data, + xreal=2.5e-6, + yreal=1.8e-6, + xoff=0.5e-7, + yoff=-1.1e-7, + si_unit_xy="m", + si_unit_z="N", + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "topo" + Save().save(filename=str(path), format="HDF5 (Ergo)", value=field) + out_path = path.with_suffix(".h5") + assert out_path.exists() + + # Sanity-check the layout: the dataset lives under + # Image/DataSet/Resolution 0/Frame 0/<title>/Image, and the sidecar + # group under Image/DataSetInfo/Global/Channels/<title>/ImageDims. + with h5py.File(str(out_path), "r") as f: + assert "Image/DataSet/Resolution 0/Frame 0/topo/Image" in f + dims = f["Image/DataSetInfo/Global/Channels/topo/ImageDims"] + scaling = np.asarray(dims.attrs["DimScaling"]) + assert scaling.shape == (2, 2) + # DimScaling is Y-first: [[y_start, y_end], [x_start, x_end]] + assert np.isclose(scaling[1, 1] - scaling[1, 0], field.xreal) + assert np.isclose(scaling[0, 1] - scaling[0, 0], field.yreal) + + reloaded = ergo_importer.load(out_path) + assert len(reloaded) == 1 + rf = reloaded[0] + assert rf.data.shape == field.data.shape + assert np.allclose(rf.data, field.data) + assert np.isclose(rf.xreal, field.xreal) + assert np.isclose(rf.yreal, field.yreal) + assert np.isclose(rf.xoff, field.xoff) + assert np.isclose(rf.yoff, field.yoff) + assert rf.si_unit_xy == "m" + assert rf.si_unit_z == "N" + + +def test_tiff_preview_is_still_rgb_uint8(): + """The legacy TIFF format for DATA_FIELD must keep producing 8-bit RGB.""" + import tifffile + + from backend.nodes.save import Save + + field = DataField( + data=np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float64), + xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m", + ) + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "preview" + Save().save(filename=str(path), format="TIFF", value=field) + arr = tifffile.imread(str(path.with_suffix(".tiff"))) + assert arr.dtype == np.uint8 + assert arr.shape == (2, 2, 3)