From 7b309a8b230bbd2f96eda920c0cd1c2e698d215b Mon Sep 17 00:00:00 2001 From: matei jordache Date: Mon, 30 Mar 2026 20:33:28 -0700 Subject: [PATCH] hdf5 support --- backend/importers/__init__.py | 45 ++++ backend/importers/_base.py | 31 +++ backend/importers/array_image.py | 44 ++++ backend/importers/ergo_hdf5.py | 235 ++++++++++++++++++ backend/importers/gwy.py | 45 ++++ backend/importers/hdf5.py | 107 +++++++++ backend/importers/ibw.py | 106 +++++++++ backend/importers/sxm.py | 47 ++++ backend/nodes/helpers.py | 62 +---- backend/nodes/image.py | 161 +------------ demo | 2 +- pyproject.toml | 1 + tests/node_tests/image.py | 3 +- tests/node_tests/image_demo.py | 3 +- tests/node_tests/importers.py | 393 +++++++++++++++++++++++++++++++ 15 files changed, 1079 insertions(+), 206 deletions(-) create mode 100644 backend/importers/__init__.py create mode 100644 backend/importers/_base.py create mode 100644 backend/importers/array_image.py create mode 100644 backend/importers/ergo_hdf5.py create mode 100644 backend/importers/gwy.py create mode 100644 backend/importers/hdf5.py create mode 100644 backend/importers/ibw.py create mode 100644 backend/importers/sxm.py create mode 100644 tests/node_tests/importers.py diff --git a/backend/importers/__init__.py b/backend/importers/__init__.py new file mode 100644 index 0000000..bfc7a40 --- /dev/null +++ b/backend/importers/__init__.py @@ -0,0 +1,45 @@ +""" +File importer registry. + +Each module in this package exposes: + extensions frozenset[str] – lower-case extensions it handles + calibrated bool – True when physical dimensions are known + load(path) → list[DataField] – load all channels + channel_names(path) → list[str] – channel name strings (same order as load) + +Usage:: + + from backend.importers import get_importer, all_extensions + + importer = get_importer(".gwy") # returns the gwy module, or None +""" + +from __future__ import annotations + +from pathlib import Path +from types import ModuleType + +from backend.importers import array_image, ergo_hdf5, gwy, ibw, sxm + +_IMPORTERS: list[ModuleType] = [gwy, sxm, ibw, ergo_hdf5, array_image] + +# ext → importer module +_REGISTRY: dict[str, ModuleType] = {} +for _mod in _IMPORTERS: + for _ext in _mod.extensions: + _REGISTRY[_ext] = _mod + + +def get_importer(ext: str) -> ModuleType | None: + """Return the importer module for *ext* (e.g. '.gwy'), or None.""" + return _REGISTRY.get(ext.lower()) + + +def all_extensions() -> frozenset[str]: + """All file extensions supported across every registered importer.""" + return frozenset(_REGISTRY) + + +def calibrated_extensions() -> frozenset[str]: + """Extensions whose importers report physical calibration.""" + return frozenset(ext for ext, mod in _REGISTRY.items() if mod.calibrated) diff --git a/backend/importers/_base.py b/backend/importers/_base.py new file mode 100644 index 0000000..5d65964 --- /dev/null +++ b/backend/importers/_base.py @@ -0,0 +1,31 @@ +""" +Base protocol for file importers. + +Each importer handles one or more file extensions and implements: + load(path) → list[DataField] + channel_names(path) → list[str] (optional, falls back to generic names) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Protocol, runtime_checkable + +from backend.data_types import DataField + + +@runtime_checkable +class FileImporter(Protocol): + #: File extensions this importer handles, e.g. {".gwy"} + extensions: frozenset[str] + + #: True when physical dimensions are known (suppresses "uncalibrated" warning) + calibrated: bool + + def load(self, path: Path) -> list[DataField]: + """Load all channels from *path* and return them as DataField objects.""" + ... + + def channel_names(self, path: Path) -> list[str]: + """Return channel name strings in the same order as load().""" + ... diff --git a/backend/importers/array_image.py b/backend/importers/array_image.py new file mode 100644 index 0000000..58bafc1 --- /dev/null +++ b/backend/importers/array_image.py @@ -0,0 +1,44 @@ +""" +Importer for pixel images (PNG, TIFF, JPEG, BMP) and NumPy arrays (.npy, .npz). +These formats carry no physical calibration, so calibrated = False. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField + + +extensions = frozenset({".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp", ".npy", ".npz"}) +calibrated = False + + +def load(path: Path) -> list[DataField]: + ext = path.suffix.lower() + + if ext == ".npy": + arr = np.load(str(path)).astype(np.float64) + elif ext == ".npz": + npz = np.load(str(path)) + key = list(npz.files)[0] + arr = npz[key].astype(np.float64) + else: + from PIL import Image as PILImage + img = PILImage.open(str(path)) + arr = np.array(img) + if arr.dtype != np.uint8: + arr = arr.astype(np.float64) + + if arr.ndim == 3: + gray = np.mean(arr.astype(np.float64), axis=2) + else: + gray = arr.astype(np.float64) + + return [DataField(data=gray)] + + +def channel_names(path: Path) -> list[str]: + return ["field"] diff --git a/backend/importers/ergo_hdf5.py b/backend/importers/ergo_hdf5.py new file mode 100644 index 0000000..5f7e01c --- /dev/null +++ b/backend/importers/ergo_hdf5.py @@ -0,0 +1,235 @@ +""" +Importer for Asylum Research / Ergo HDF5 files (.h5, .hdf5, .he5). + +Asylum Research instruments store scan metadata in a sidecar group rather +than as dataset attributes. This importer reads physical dimensions from: + + Image/DataSetInfo/Global/Channels//ImageDims + DimScaling – (2,2) array: [[px_size_x, offset_x], [px_size_y, offset_y]] + DimExtents – pixel counts [xres, yres] (stored in a child group) + DimUnits – lateral unit strings + DataUnits – Z unit string + +If the sidecar group is absent (generic HDF5), standard dataset attributes +are used as a fallback: + + xreal / yreal – physical scan size in metres (fallback: 1e-6) + xoff / yoff – position offset in metres (fallback: 0) + si_unit_xy – lateral unit string (fallback: "m") + si_unit_z – value unit string (fallback: "m") + +Requires: + pip install h5py +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField + + +extensions = frozenset({".h5", ".hdf5", ".he5"}) +calibrated = True # we attempt to read physical metadata + + +def _iter_2d_datasets(h5file): + """Yield (name, dataset) for every 2-D numeric dataset in the file.""" + import h5py + + def _visit(name, obj): + if isinstance(obj, h5py.Dataset) and obj.ndim == 2 and np.issubdtype(obj.dtype, np.number): + results.append((name, obj)) + + results: list = [] + h5file.visititems(_visit) + return results + + +def _attr_str(attrs, key: str, default: str) -> str: + val = attrs.get(key) + if val is None: + return default + if isinstance(val, bytes): + return val.decode("utf-8", errors="replace").strip() or default + return str(val).strip() or default + + +def _attr_float(attrs, key: str, default: float) -> float: + val = attrs.get(key) + if val is None: + return default + try: + return float(val) + except (TypeError, ValueError): + return default + + +def _ar_image_dims(f, ds_name: str) -> dict | None: + """ + Look up Asylum Research ImageDims metadata for a dataset. + + AR .h5 files store scan dimensions in a sibling group rather than as + dataset attributes. Given a dataset path like: + "Image/DataSet/Resolution 0/Frame 0/Adhesion:Retrace/Image" + the channel name is the second-to-last component ("Adhesion:Retrace"), + and the metadata lives at: + "Image/DataSetInfo/Global/Channels//ImageDims" + + Returns a dict with xreal, yreal, xoff, yoff, si_unit_xy, si_unit_z, + or None if the group isn't found. + """ + import h5py + + parts = ds_name.split("/") + if len(parts) < 2: + return None + channel = parts[-2] + + dims_path = f"Image/DataSetInfo/Global/Channels/{channel}/ImageDims" + grp = f.get(dims_path) + if not isinstance(grp, h5py.Group): + return None + + scaling = grp.attrs.get("DimScaling") # shape (2, 2): [[px_x, off_x], [px_y, off_y]] + dim_units = grp.attrs.get("DimUnits") # array of unit strings, e.g. ['m', 'm'] + data_units = grp.attrs.get("DataUnits") # Z unit string, e.g. 'N' + + if scaling is None or np.asarray(scaling).shape != (2, 2): + return None + + scaling = np.asarray(scaling, dtype=np.float64) + px_x, off_x = float(scaling[0, 0]), float(scaling[0, 1]) + px_y, off_y = float(scaling[1, 0]), float(scaling[1, 1]) + + # DimExtents gives pixel counts; use to compute total physical size. + extents_grp = None + for child_name in grp: + child = grp[child_name] + if isinstance(child, h5py.Group) and "DimExtents" in child.attrs: + extents_grp = child + break + + xres, yres = 1, 1 + if extents_grp is not None: + ext = np.asarray(extents_grp.attrs["DimExtents"]) + if ext.size >= 2: + xres, yres = int(ext[0]), int(ext[1]) + + def _decode(raw, default="m") -> str: + if raw is None: + return default + if hasattr(raw, "__iter__") and not isinstance(raw, (str, bytes)): + raw = list(raw)[0] if len(raw) else default + if isinstance(raw, bytes): + return raw.decode("utf-8", errors="replace").strip() or default + return str(raw).strip() or default + + return { + "xreal": abs(px_x * xres) or 1e-6, + "yreal": abs(px_y * yres) or 1e-6, + "xoff": off_x, + "yoff": off_y, + "si_unit_xy": _decode(dim_units[0] if dim_units is not None and len(dim_units) >= 1 else None), + "si_unit_z": _decode(data_units), + } + + +def load(path: Path) -> list[DataField]: + try: + import h5py + except ImportError: + raise ImportError("Install 'h5py' to load HDF5 files: pip install h5py") + + with h5py.File(str(path), "r") as f: + datasets = _iter_2d_datasets(f) + if not datasets: + raise ValueError(f"No 2-D numeric datasets found in {path.name}") + + fields = [] + for name, ds in datasets: + data = np.asarray(ds, dtype=np.float64) + + # Try Asylum Research sidecar metadata first, then dataset attrs. + ar = _ar_image_dims(f, name) + if ar: + fields.append(DataField( + data=data, + xreal=ar["xreal"], yreal=ar["yreal"], + xoff=ar["xoff"], yoff=ar["yoff"], + si_unit_xy=ar["si_unit_xy"], + si_unit_z=ar["si_unit_z"], + )) + else: + attrs = ds.attrs + fields.append(DataField( + data=data, + xreal=_attr_float(attrs, "xreal", 1e-6), + yreal=_attr_float(attrs, "yreal", 1e-6), + xoff=_attr_float(attrs, "xoff", 0.0), + yoff=_attr_float(attrs, "yoff", 0.0), + si_unit_xy=_attr_str(attrs, "si_unit_xy", "m"), + si_unit_z=_attr_str(attrs, "si_unit_z", "m"), + )) + return fields + + +def _display_names(full_names: list[str]) -> list[str]: + """ + Derive short display names from HDF5 dataset paths. + + Rules (all comparisons case-insensitive): + 1. All thumbnail datasets are filtered out. + 2. Display name = second-to-last path component (drops the leaf like + "/image" or "/thumbnail"). + 3. "global" channels sort to the front. + 4. If two kept datasets share the same second-to-last name, the leaf is + appended to disambiguate. + + Returns a list in sorted order (not parallel to full_names). + """ + from collections import Counter + + # Filter out all thumbnail datasets. + kept: list[tuple[int, str]] = [] # (original index, full name) + for i, name in enumerate(full_names): + if name.split("/")[-1].lower() == "thumbnail": + continue + kept.append((i, name)) + + # Sort: "global" second-to-last first, then alphabetical. + def _sort_key(item: tuple[int, str]) -> tuple[int, str]: + parts = item[1].split("/") + second_last = parts[-2].lower() if len(parts) >= 2 else parts[-1].lower() + return (0 if second_last == "global" else 1, second_last) + + kept.sort(key=_sort_key) + + # Build short names (second-to-last), disambiguate clashes. + short = [ + (parts[-2] if len(parts := name.split("/")) >= 2 else parts[-1]) + for _, name in kept + ] + counts = Counter(short) + disambiguated = [ + f"{s}/{name.split('/')[-1]}" if counts[s] > 1 else s + for s, (_, name) in zip(short, kept) + ] + + return disambiguated + + +def channel_names(path: Path) -> list[str]: + try: + import h5py + except ImportError: + return [] + try: + with h5py.File(str(path), "r") as f: + datasets = _iter_2d_datasets(f) + full_names = [name for name, _ in datasets] + return [n for n in _display_names(full_names) if n is not None] + except Exception: + return [] diff --git a/backend/importers/gwy.py b/backend/importers/gwy.py new file mode 100644 index 0000000..209f9fb --- /dev/null +++ b/backend/importers/gwy.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField + + +extensions = frozenset({".gwy"}) +calibrated = True + + +def load(path: Path) -> list[DataField]: + import gwyfile + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if not channels: + raise ValueError(f"No data channels found in {path.name}") + + fields = [] + for ch in channels.values(): + data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) + fields.append(DataField( + data=data, + xreal=float(ch.xreal), + yreal=float(ch.yreal), + xoff=float(getattr(ch, "xoff", 0.0)), + yoff=float(getattr(ch, "yoff", 0.0)), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + +def channel_names(path: Path) -> list[str]: + import gwyfile + try: + obj = gwyfile.load(str(path)) + channels = gwyfile.util.get_datafields(obj) + if channels: + return list(channels.keys()) + except Exception: + pass + return [] diff --git a/backend/importers/hdf5.py b/backend/importers/hdf5.py new file mode 100644 index 0000000..b670745 --- /dev/null +++ b/backend/importers/hdf5.py @@ -0,0 +1,107 @@ +""" +Generic HDF5 importer (.h5, .hdf5, .he5). + +Each 2-D dataset found in the file is returned as a DataField. Physical +dimensions are read from standard dataset attributes if present: + + xreal / yreal – physical scan size in metres (fallback: 1e-6) + xoff / yoff – position offset in metres (fallback: 0) + si_unit_xy – lateral unit string (fallback: "m") + si_unit_z – value unit string (fallback: "m") + +For Asylum Research / Ergo format files (which store scan metadata in a +sidecar group rather than as dataset attributes), use the ergo_hdf5 importer. + +Requires: + pip install h5py +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField + + +extensions = frozenset({".h5", ".hdf5", ".he5"}) +calibrated = True # we attempt to read physical metadata + + +def _iter_2d_datasets(h5file): + """Yield (name, dataset) for every 2-D numeric dataset in the file.""" + import h5py + + def _visit(name, obj): + if isinstance(obj, h5py.Dataset) and obj.ndim == 2 and np.issubdtype(obj.dtype, np.number): + results.append((name, obj)) + + results: list = [] + h5file.visititems(_visit) + return results + + +def _attr_str(attrs, key: str, default: str) -> str: + val = attrs.get(key) + if val is None: + return default + if isinstance(val, bytes): + return val.decode("utf-8", errors="replace").strip() or default + return str(val).strip() or default + + +def _attr_float(attrs, key: str, default: float) -> float: + val = attrs.get(key) + if val is None: + return default + try: + return float(val) + except (TypeError, ValueError): + return default + + +def load(path: Path) -> list[DataField]: + try: + import h5py + except ImportError: + raise ImportError("Install 'h5py' to load HDF5 files: pip install h5py") + + with h5py.File(str(path), "r") as f: + datasets = _iter_2d_datasets(f) + if not datasets: + raise ValueError(f"No 2-D numeric datasets found in {path.name}") + + fields = [] + for name, ds in datasets: + data = np.asarray(ds, dtype=np.float64) + attrs = ds.attrs + fields.append(DataField( + data=data, + xreal=_attr_float(attrs, "xreal", 1e-6), + yreal=_attr_float(attrs, "yreal", 1e-6), + xoff=_attr_float(attrs, "xoff", 0.0), + yoff=_attr_float(attrs, "yoff", 0.0), + si_unit_xy=_attr_str(attrs, "si_unit_xy", "m"), + si_unit_z=_attr_str(attrs, "si_unit_z", "m"), + )) + return fields + + +def channel_names(path: Path) -> list[str]: + try: + import h5py + except ImportError: + return [] + try: + with h5py.File(str(path), "r") as f: + datasets = _iter_2d_datasets(f) + # Return second-to-last component as display name, or full name for + # top-level datasets. + names = [] + for full_name, _ in datasets: + parts = full_name.split("/") + names.append(parts[-2] if len(parts) >= 2 else parts[-1]) + return names + except Exception: + return [] diff --git a/backend/importers/ibw.py b/backend/importers/ibw.py new file mode 100644 index 0000000..57177c5 --- /dev/null +++ b/backend/importers/ibw.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField + + +extensions = frozenset({".ibw"}) +calibrated = True + + +def _load_ibw_raw(path: Path): + import numpy as _np + if not hasattr(_np, "complex"): + setattr(_np, "complex", complex) + try: + from igor.binarywave import load as load_ibw + except ImportError: + raise ImportError("Install 'igor' to load .ibw files: pip install igor") + return load_ibw(str(path)) + + +def _decode_unit(raw_unit) -> str: + if raw_unit is None: + return "m" + if isinstance(raw_unit, bytes): + return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + if isinstance(raw_unit, np.ndarray): + return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" + return str(raw_unit).strip() or "m" + + +def load(path: Path) -> list[DataField]: + wave = _load_ibw_raw(path) + wdata = wave["wave"] + header = wdata["wave_header"] + raw = wdata["wData"] + + n_channels = raw.shape[2] if raw.ndim >= 3 else 1 + sfA = header.get("sfA", None) + + dim_units_raw = header.get("dimUnits", None) + data_units_raw = header.get("dataUnits", None) + + if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2: + si_unit_xy = _decode_unit(dim_units_raw[0]) + elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0: + si_unit_xy = _decode_unit(dim_units_raw[0]) + else: + si_unit_xy = _decode_unit(dim_units_raw) + + si_unit_z = _decode_unit(data_units_raw) + + fields = [] + for ch_idx in range(n_channels): + if raw.ndim >= 3: + ch_data = raw[:, :, ch_idx] + elif raw.ndim == 1: + ch_data = raw.reshape(-1, 1) + else: + ch_data = raw + + data = np.flipud(ch_data.T).astype(np.float64) + yres, xres = data.shape + + if sfA is not None and len(sfA) >= 2: + xreal = abs(float(sfA[0]) * xres) or 1e-6 + yreal = abs(float(sfA[1]) * yres) or 1e-6 + else: + hsA = header.get("hsA", 0.0) + xreal = abs(float(hsA) * xres) or 1e-6 + yreal = xreal * (yres / xres) if xres else 1e-6 + + fields.append(DataField( + data=data, xreal=xreal, yreal=yreal, + si_unit_xy=si_unit_xy, si_unit_z=si_unit_z, + )) + return fields + + +def channel_names(path: Path) -> list[str]: + try: + wave = _load_ibw_raw(path) + wdata = wave["wave"] + raw = wdata["wData"] + labels = wdata.get("labels", None) + + if raw.ndim >= 3 and labels: + dim_idx = min(2, len(labels) - 1) + if dim_idx >= 0 and labels[dim_idx]: + decoded = [] + for lbl in labels[dim_idx]: + if lbl: + name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip() + if name: + decoded.append(name) + if decoded: + return decoded + + if raw.ndim >= 3 and raw.shape[2] > 1: + return [f"ch{i}" for i in range(raw.shape[2])] + except Exception: + pass + return [] diff --git a/backend/importers/sxm.py b/backend/importers/sxm.py new file mode 100644 index 0000000..cd085fb --- /dev/null +++ b/backend/importers/sxm.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from backend.data_types import DataField + + +extensions = frozenset({".sxm"}) +calibrated = True + + +def load(path: Path) -> list[DataField]: + import nanonispy as nap + sxm = nap.read.Scan(str(path)) + signals = sxm.signals + if not signals: + raise ValueError(f"No signals found in {path.name}") + + scan_range = sxm.header.get("scan_range", [1e-6, 1e-6]) + + fields = [] + for sig in signals.values(): + data = sig.get("forward", list(sig.values())[0]) + data = np.asarray(data, dtype=np.float64) + if data.ndim != 2: + data = data.reshape(data.shape[-2], data.shape[-1]) + fields.append(DataField( + data=data, + xreal=float(scan_range[0]), + yreal=float(scan_range[1]), + si_unit_xy="m", + si_unit_z="m", + )) + return fields + + +def channel_names(path: Path) -> list[str]: + import nanonispy as nap + try: + sxm = nap.read.Scan(str(path)) + if sxm.signals: + return list(sxm.signals.keys()) + except Exception: + pass + return [] diff --git a/backend/nodes/helpers.py b/backend/nodes/helpers.py index 45548c8..70d480b 100644 --- a/backend/nodes/helpers.py +++ b/backend/nodes/helpers.py @@ -527,13 +527,9 @@ OUTPUT_DIR = output_dir() _MAX_SAVE_FIELDS = 8 -_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz", - ".gwy", ".sxm", ".ibw"} +from backend.importers import all_extensions, get_importer -_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"} -_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"} -_ARRAY_EXTENSIONS = {".npy", ".npz"} -_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS +_PATH_COMPATIBLE_EXTENSIONS = all_extensions() def _resolve_path(filepath: str): @@ -554,52 +550,16 @@ def list_channels(filepath: str) -> list[dict]: if not path.exists(): return [{"name": "field", "type": "DATA_FIELD"}] - ext = path.suffix.lower() - - if ext == ".gwy": - try: - import gwyfile - obj = gwyfile.load(str(path)) - channels = gwyfile.util.get_datafields(obj) - if channels: - return [{"name": k, "type": "DATA_FIELD"} for k in channels] - except Exception: - pass - return [{"name": "field", "type": "DATA_FIELD"}] - - if ext == ".sxm": - try: - import nanonispy as nap - sxm = nap.read.Scan(str(path)) - if sxm.signals: - return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals] - except Exception: - pass - return [{"name": "field", "type": "DATA_FIELD"}] - - if ext == ".ibw": - try: - load_ibw = _import_ibw_loader() - wave = load_ibw(str(path)) - raw = wave["wave"]["wData"] - labels = wave["wave"].get("labels", None) - if raw.ndim >= 3 and labels: - dim_idx = min(2, len(labels) - 1) - if dim_idx >= 0 and labels[dim_idx]: - decoded = [] - for lbl in labels[dim_idx]: - if lbl: - name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip() - if name: - decoded.append(name) - if decoded: - return [{"name": n, "type": "DATA_FIELD"} for n in decoded] - if raw.ndim >= 3 and raw.shape[2] > 1: - return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])] - except Exception: - pass + importer = get_importer(path.suffix.lower()) + if importer is None: return [{"name": "field", "type": "DATA_FIELD"}] + try: + names = importer.channel_names(path) + if names: + return [{"name": n, "type": "DATA_FIELD"} for n in names] + except Exception: + pass return [{"name": "field", "type": "DATA_FIELD"}] @@ -624,7 +584,7 @@ def _list_demo_files() -> list[str]: return [] return sorted( f.name for f in DEMO_DIR.iterdir() - if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS + if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _PATH_COMPATIBLE_EXTENSIONS ) diff --git a/backend/nodes/image.py b/backend/nodes/image.py index a072ba7..1ff610e 100644 --- a/backend/nodes/image.py +++ b/backend/nodes/image.py @@ -1,14 +1,12 @@ from __future__ import annotations from functools import lru_cache -import numpy as np from pathlib import Path -import nanonispy as nap -import gwyfile from backend.node_registry import register_node from backend.execution_context import emit_warning from backend.data_types import COLORMAPS, DataField, resolve_colormap_input -from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader +from backend.nodes.helpers import _resolve_path +from backend.importers import get_importer, calibrated_extensions @register_node(display_name="Image") @@ -34,7 +32,7 @@ class Image: DESCRIPTION = ( "Load any supported file. " - "SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; " + "SPM formats (.gwy, .sxm, .ibw) and HDF5 (.h5, .hdf5) provide calibrated dimensions; " "each channel gets its own output. " "Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields." ) @@ -65,158 +63,17 @@ class Image: for field in fields: field.colormap = resolved_colormap - if ext not in _SPM_EXTENSIONS: - self._send_warning("Uncalibrated data — no physical dimensions.") + if ext not in calibrated_extensions(): + emit_warning("Uncalibrated data — no physical dimensions.") return (str(path_obj.resolve()),) + fields - def _send_warning(self, message: str): - emit_warning(message) - @staticmethod @lru_cache(maxsize=32) def _load_fields_cached(path_str: str, mtime_ns: int, size_bytes: int) -> tuple[DataField, ...]: path = Path(path_str) ext = path.suffix.lower() - if ext in _SPM_EXTENSIONS: - return tuple(Image._load_spm_all(path, ext)) - return (Image._load_image_or_array(path, ext),) - - @staticmethod - def _load_spm_all(path: Path, ext: str) -> list[DataField]: - if ext == ".gwy": - return Image._load_gwy_all(path) - elif ext == ".sxm": - return Image._load_sxm_all(path) - elif ext == ".ibw": - return Image._load_ibw_all(path) - else: - raise ValueError(f"Unsupported SPM format: {ext}") - - @staticmethod - def _load_gwy_all(path: Path) -> list[DataField]: - obj = gwyfile.load(str(path)) - channels = gwyfile.util.get_datafields(obj) - if not channels: - raise ValueError(f"No data channels found in {path.name}") - - fields = [] - for ch in channels.values(): - data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres) - fields.append(DataField( - data=data, - xreal=float(ch.xreal), - yreal=float(ch.yreal), - xoff=float(getattr(ch, "xoff", 0.0)), - yoff=float(getattr(ch, "yoff", 0.0)), - si_unit_xy="m", - si_unit_z="m", - )) - return fields - - @staticmethod - def _load_sxm_all(path: Path) -> list[DataField]: - sxm = nap.read.Scan(str(path)) - signals = sxm.signals - if not signals: - raise ValueError(f"No signals found in {path.name}") - - header = sxm.header - scan_range = header.get("scan_range", [1e-6, 1e-6]) - - fields = [] - for sig in signals.values(): - data = sig.get("forward", list(sig.values())[0]) - data = np.asarray(data, dtype=np.float64) - if data.ndim != 2: - data = data.reshape(data.shape[-2], data.shape[-1]) - fields.append(DataField( - data=data, - xreal=float(scan_range[0]), - yreal=float(scan_range[1]), - si_unit_xy="m", - si_unit_z="m", - )) - return fields - - @staticmethod - def _load_ibw_all(path: Path) -> list[DataField]: - load_ibw = _import_ibw_loader() - wave = load_ibw(str(path)) - wdata = wave["wave"] - header = wdata["wave_header"] - raw = wdata["wData"] - - n_channels = raw.shape[2] if raw.ndim >= 3 else 1 - - sfA = header.get("sfA", None) - - def _decode_unit(raw_unit): - if raw_unit is None: - return "m" - if isinstance(raw_unit, bytes): - return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" - if isinstance(raw_unit, np.ndarray): - return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m" - return str(raw_unit).strip() or "m" - - dim_units_raw = header.get("dimUnits", None) - data_units_raw = header.get("dataUnits", None) - - if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2: - si_unit_xy = _decode_unit(dim_units_raw[0]) - elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0: - si_unit_xy = _decode_unit(dim_units_raw[0]) - else: - si_unit_xy = _decode_unit(dim_units_raw) - - si_unit_z = _decode_unit(data_units_raw) - - fields = [] - for ch_idx in range(n_channels): - if raw.ndim >= 3: - ch_data = raw[:, :, ch_idx] - elif raw.ndim == 1: - ch_data = raw.reshape(-1, 1) - else: - ch_data = raw - - data = np.flipud(ch_data.T).astype(np.float64) - yres, xres = data.shape - - if sfA is not None and len(sfA) >= 2: - xreal = abs(float(sfA[0]) * xres) or 1e-6 - yreal = abs(float(sfA[1]) * yres) or 1e-6 - else: - hsA = header.get("hsA", 0.0) - xreal = abs(float(hsA) * xres) or 1e-6 - yreal = xreal * (yres / xres) if xres else 1e-6 - - fields.append(DataField( - data=data, xreal=xreal, yreal=yreal, - si_unit_xy=si_unit_xy, si_unit_z=si_unit_z, - )) - - return fields - - @staticmethod - def _load_image_or_array(path: Path, ext: str) -> DataField: - if ext == ".npy": - arr = np.load(str(path)).astype(np.float64) - elif ext == ".npz": - npz = np.load(str(path)) - key = list(npz.files)[0] - arr = npz[key].astype(np.float64) - else: - from PIL import Image as PILImage - img = PILImage.open(str(path)) - arr = np.array(img) - if arr.dtype != np.uint8: - arr = arr.astype(np.float64) - - if arr.ndim == 3: - gray = np.mean(arr.astype(np.float64), axis=2) - else: - gray = arr.astype(np.float64) - - return DataField(data=gray) + importer = get_importer(ext) + if importer is None: + raise ValueError(f"Unsupported file format: {ext}") + return tuple(importer.load(path)) diff --git a/demo b/demo index 9c82c2b..e3dcea6 160000 --- a/demo +++ b/demo @@ -1 +1 @@ -Subproject commit 9c82c2ba8cf3b8e3ba3c8a4c31744674c1affed9 +Subproject commit e3dcea633a0a9e0bc0b51982676957be7792fae8 diff --git a/pyproject.toml b/pyproject.toml index 15d7bdd..0b445e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ requires-python = ">=3.10" dependencies = [ "aiohttp>=3.9,<4", "gwyfile>=0.2", + "h5py>=3.10,<4", "igor>=0.3", "matplotlib>=3.8,<4", "nanonispy>=1.1", diff --git a/tests/node_tests/image.py b/tests/node_tests/image.py index 05004a5..13d0f1b 100644 --- a/tests/node_tests/image.py +++ b/tests/node_tests/image.py @@ -83,7 +83,8 @@ def test_load_file_cache(): path = os.path.join(tmpdir, "cached.npy") np.save(path, data) - with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: + import backend.importers.array_image as _ai + with patch.object(_ai, "load", wraps=_ai.load) as loader: _, first = node.load(filename=path) _, second = node.load(filename=path) assert loader.call_count == 1 diff --git a/tests/node_tests/image_demo.py b/tests/node_tests/image_demo.py index be9da88..be88e3d 100644 --- a/tests/node_tests/image_demo.py +++ b/tests/node_tests/image_demo.py @@ -53,7 +53,8 @@ def test_load_cache(): Image._load_fields_cached.cache_clear() with patch("backend.nodes.image_demo.DEMO_DIR", FIXTURES): - with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader: + import backend.importers.array_image as _ai + with patch.object(_ai, "load", wraps=_ai.load) as loader: _, first = node.load(name="nanoparticles.npy") _, second = node.load(name="nanoparticles.npy") assert loader.call_count == 1 diff --git a/tests/node_tests/importers.py b/tests/node_tests/importers.py new file mode 100644 index 0000000..1deb421 --- /dev/null +++ b/tests/node_tests/importers.py @@ -0,0 +1,393 @@ +""" +Tests for backend/importers/ — the importer registry and each importer module. +""" + +import os +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from backend.data_types import DataField + +FIXTURES = Path(__file__).parent.parent / "output" + + +# ── Registry ───────────────────────────────────────────────────────────────── + +class TestRegistry: + def test_get_importer_known_extensions(self): + from backend.importers import get_importer + for ext in (".gwy", ".sxm", ".ibw", ".npy", ".npz", + ".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp", + ".h5", ".hdf5", ".he5"): + assert get_importer(ext) is not None, f"No importer registered for {ext}" + + def test_get_importer_unknown_extension(self): + from backend.importers import get_importer + assert get_importer(".xyz") is None + assert get_importer(".csv") is None + + def test_get_importer_case_insensitive(self): + from backend.importers import get_importer + assert get_importer(".NPY") is get_importer(".npy") + assert get_importer(".GWY") is get_importer(".gwy") + + def test_all_extensions_returns_frozenset(self): + from backend.importers import all_extensions + exts = all_extensions() + assert isinstance(exts, frozenset) + assert ".npy" in exts + assert ".gwy" in exts + + def test_calibrated_extensions(self): + from backend.importers import calibrated_extensions + cal = calibrated_extensions() + # SPM and HDF5 are calibrated + assert ".gwy" in cal + assert ".sxm" in cal + assert ".ibw" in cal + assert ".h5" in cal + # Images/arrays are not + assert ".png" not in cal + assert ".npy" not in cal + + def test_each_importer_has_required_interface(self): + from backend.importers import _IMPORTERS + for mod in _IMPORTERS: + assert hasattr(mod, "extensions"), f"{mod.__name__} missing extensions" + assert hasattr(mod, "calibrated"), f"{mod.__name__} missing calibrated" + assert callable(getattr(mod, "load", None)), f"{mod.__name__} missing load()" + assert callable(getattr(mod, "channel_names", None)), f"{mod.__name__} missing channel_names()" + assert isinstance(mod.extensions, frozenset) + assert isinstance(mod.calibrated, bool) + + +# ── array_image importer ────────────────────────────────────────────────────── + +class TestArrayImageImporter: + def setup_method(self): + import backend.importers.array_image as mod + self.mod = mod + + def test_npy_load(self): + with tempfile.TemporaryDirectory() as tmp: + data = np.random.default_rng(0).standard_normal((32, 48)) + path = Path(tmp) / "data.npy" + np.save(path, data) + fields = self.mod.load(path) + assert len(fields) == 1 + assert isinstance(fields[0], DataField) + assert np.allclose(fields[0].data, data) + + def test_npz_load(self): + with tempfile.TemporaryDirectory() as tmp: + data = np.random.default_rng(1).standard_normal((16, 16)) + path = Path(tmp) / "data.npz" + np.savez(path, arr=data) + fields = self.mod.load(path) + assert len(fields) == 1 + assert np.allclose(fields[0].data, data) + + def test_png_grayscale(self): + from PIL import Image as PILImage + with tempfile.TemporaryDirectory() as tmp: + arr = np.random.default_rng(2).integers(0, 256, (24, 32), dtype=np.uint8) + path = Path(tmp) / "gray.png" + PILImage.fromarray(arr).save(path) + fields = self.mod.load(path) + assert len(fields) == 1 + assert fields[0].data.shape == (24, 32) + assert fields[0].data.dtype == np.float64 + + def test_png_rgb_converted_to_grayscale(self): + from PIL import Image as PILImage + with tempfile.TemporaryDirectory() as tmp: + arr = np.random.default_rng(3).integers(0, 256, (16, 16, 3), dtype=np.uint8) + path = Path(tmp) / "rgb.png" + PILImage.fromarray(arr).save(path) + fields = self.mod.load(path) + assert fields[0].data.shape == (16, 16) + + def test_not_calibrated(self): + assert self.mod.calibrated is False + + def test_channel_names(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "x.npy" + np.save(path, np.zeros((4, 4))) + assert self.mod.channel_names(path) == ["field"] + + def test_fixture_npy(self): + path = FIXTURES / "nanoparticles.npy" + if not path.exists(): + pytest.skip("fixture not available") + fields = self.mod.load(path) + assert len(fields) == 1 + assert fields[0].data.ndim == 2 + + +# ── ibw importer ───────────────────────────────────────────────────────────── + +class TestIBWImporter: + def setup_method(self): + import backend.importers.ibw as mod + self.mod = mod + self.fixture = FIXTURES / "Bacteria.ibw" + + def test_calibrated(self): + assert self.mod.calibrated is True + + def test_extensions(self): + assert ".ibw" in self.mod.extensions + + def test_load_fixture(self): + if not self.fixture.exists(): + pytest.skip("Bacteria.ibw fixture not available") + fields = self.mod.load(self.fixture) + assert len(fields) == 4 + for f in fields: + assert isinstance(f, DataField) + assert f.data.ndim == 2 + assert f.data.dtype == np.float64 + assert f.xreal > 0 + assert f.yreal > 0 + + def test_channel_names_fixture(self): + if not self.fixture.exists(): + pytest.skip("Bacteria.ibw fixture not available") + names = self.mod.channel_names(self.fixture) + assert len(names) == 4 + assert all(isinstance(n, str) for n in names) + + +# ── hdf5 importer (generic) ────────────────────────────────────────────────── + +class TestHDF5Importer: + def setup_method(self): + pytest.importorskip("h5py") + import backend.importers.hdf5 as mod + self.mod = mod + + def test_calibrated(self): + assert self.mod.calibrated is True + + def test_extensions(self): + assert {".h5", ".hdf5", ".he5"} <= self.mod.extensions + + def test_load_single_channel(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + data = np.random.default_rng(10).standard_normal((32, 32)) + path = Path(tmp) / "test.h5" + with h5py.File(path, "w") as f: + f.create_dataset("channel", data=data) + fields = self.mod.load(path) + assert len(fields) == 1 + assert np.allclose(fields[0].data, data) + assert fields[0].data.dtype == np.float64 + + def test_load_physical_attrs(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "cal.h5" + with h5py.File(path, "w") as f: + ds = f.create_dataset("topo", data=np.zeros((16, 16))) + ds.attrs["xreal"] = 5e-6 + ds.attrs["yreal"] = 5e-6 + ds.attrs["si_unit_z"] = "V" + fields = self.mod.load(path) + assert fields[0].xreal == pytest.approx(5e-6) + assert fields[0].si_unit_z == "V" + + def test_load_fallback_attrs(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "fallback.h5" + with h5py.File(path, "w") as f: + f.create_dataset("channel", data=np.zeros((8, 8))) + fields = self.mod.load(path) + assert fields[0].xreal == pytest.approx(1e-6) + assert fields[0].si_unit_xy == "m" + + def test_empty_file_raises(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "empty.h5" + with h5py.File(path, "w") as f: + f.create_dataset("vec", data=np.zeros((10,))) + with pytest.raises(ValueError, match="No 2-D"): + self.mod.load(path) + + def test_channel_names_top_level(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "named.h5" + with h5py.File(path, "w") as f: + f.create_dataset("height", data=np.zeros((8, 8))) + f.create_dataset("phase", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert set(names) == {"height", "phase"} + + def test_channel_names_nested(self): + # Nested datasets return second-to-last path component. + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "nested.h5" + with h5py.File(path, "w") as f: + f.create_dataset("scan/height/data", data=np.zeros((8, 8))) + f.create_dataset("scan/phase/data", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert set(names) == {"height", "phase"} + + +# ── ergo_hdf5 importer (Asylum Research / Ergo format) ─────────────────────── + +class TestErgoHDF5Importer: + def setup_method(self): + pytest.importorskip("h5py") + import backend.importers.ergo_hdf5 as mod + self.mod = mod + + def _make_h5(self, tmp: Path, data: np.ndarray, name: str = "channel", + attrs: dict | None = None) -> Path: + import h5py + path = tmp / "test.h5" + with h5py.File(path, "w") as f: + ds = f.create_dataset(name, data=data) + for k, v in (attrs or {}).items(): + ds.attrs[k] = v + return path + + def test_calibrated(self): + assert self.mod.calibrated is True + + def test_extensions(self): + assert {".h5", ".hdf5", ".he5"} <= self.mod.extensions + + def test_load_single_channel(self): + with tempfile.TemporaryDirectory() as tmp: + data = np.random.default_rng(10).standard_normal((32, 32)) + path = self._make_h5(Path(tmp), data) + fields = self.mod.load(path) + assert len(fields) == 1 + assert np.allclose(fields[0].data, data) + assert fields[0].data.dtype == np.float64 + + def test_load_physical_attrs(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + data = np.zeros((16, 16)) + path = Path(tmp) / "cal.h5" + with h5py.File(path, "w") as f: + ds = f.create_dataset("topo", data=data) + ds.attrs["xreal"] = 5e-6 + ds.attrs["yreal"] = 5e-6 + ds.attrs["si_unit_xy"] = "m" + ds.attrs["si_unit_z"] = "V" + fields = self.mod.load(path) + assert fields[0].xreal == pytest.approx(5e-6) + assert fields[0].yreal == pytest.approx(5e-6) + assert fields[0].si_unit_z == "V" + + def test_load_fallback_attrs(self): + with tempfile.TemporaryDirectory() as tmp: + data = np.zeros((8, 8)) + path = self._make_h5(Path(tmp), data) + fields = self.mod.load(path) + # Default fallbacks when no AR sidecar is present + assert fields[0].xreal == pytest.approx(1e-6) + assert fields[0].si_unit_xy == "m" + + def test_load_multiple_channels(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "multi.h5" + data_a = np.ones((16, 16)) + data_b = np.zeros((16, 16)) + with h5py.File(path, "w") as f: + f.create_dataset("height", data=data_a) + f.create_dataset("phase", data=data_b) + fields = self.mod.load(path) + assert len(fields) == 2 + shapes = {f.data.shape for f in fields} + assert shapes == {(16, 16)} + + def test_ignores_non_2d_datasets(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "mixed.h5" + with h5py.File(path, "w") as f: + f.create_dataset("topo", data=np.zeros((16, 16))) + f.create_dataset("vector", data=np.zeros((10,))) # 1-D, ignored + f.create_dataset("volume", data=np.zeros((4, 4, 4))) # 3-D, ignored + fields = self.mod.load(path) + assert len(fields) == 1 + + def test_empty_file_raises(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "empty.h5" + with h5py.File(path, "w") as f: + f.create_dataset("vec", data=np.zeros((10,))) + with pytest.raises(ValueError, match="No 2-D"): + self.mod.load(path) + + def test_channel_names_top_level(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "named.h5" + with h5py.File(path, "w") as f: + f.create_dataset("height", data=np.zeros((8, 8))) + f.create_dataset("phase", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert set(names) == {"height", "phase"} + + def test_channel_names_strips_leaf(self): + # "/image" leaf is stripped; display name is second-to-last component. + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "nested.h5" + with h5py.File(path, "w") as f: + f.create_dataset("scan/adhesion:retrace/image", data=np.zeros((8, 8))) + f.create_dataset("scan/phase:retrace/image", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert set(names) == {"adhesion:retrace", "phase:retrace"} + + def test_channel_names_thumbnails_always_filtered(self): + # All thumbnail datasets are hidden, including global ones. + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "thumbs.h5" + with h5py.File(path, "w") as f: + f.create_dataset("data/adhesion/image", data=np.zeros((8, 8))) + f.create_dataset("info/channels/adhesion/thumbnail", data=np.zeros((8, 8))) + f.create_dataset("info/global/thumbnail", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert names == ["adhesion"] + + def test_channel_names_sorted_alphabetically(self): + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "order.h5" + with h5py.File(path, "w") as f: + f.create_dataset("data/zzz/image", data=np.zeros((8, 8))) + f.create_dataset("data/aaa/image", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert names == ["aaa", "zzz"] + + def test_channel_names_deduplication(self): + # Two kept datasets with the same second-to-last name get disambiguated. + import h5py + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "dup.h5" + with h5py.File(path, "w") as f: + f.create_dataset("dataset/adhesion/imageA", data=np.zeros((8, 8))) + f.create_dataset("datasetinfo/adhesion/imageB", data=np.zeros((8, 8))) + names = self.mod.channel_names(path) + assert set(names) == {"adhesion/imageA", "adhesion/imageB"} + + def test_he5_extension_registered(self): + from backend.importers import get_importer + assert get_importer(".he5") is self.mod