hdf5 support

This commit is contained in:
2026-03-30 20:33:28 -07:00
parent 53e43e8761
commit 7b309a8b23
15 changed files with 1079 additions and 206 deletions

View File

@@ -0,0 +1,45 @@
"""
File importer registry.
Each module in this package exposes:
extensions frozenset[str] lower-case extensions it handles
calibrated bool True when physical dimensions are known
load(path) → list[DataField] load all channels
channel_names(path) → list[str] channel name strings (same order as load)
Usage::
from backend.importers import get_importer, all_extensions
importer = get_importer(".gwy") # returns the gwy module, or None
"""
from __future__ import annotations
from pathlib import Path
from types import ModuleType
from backend.importers import array_image, ergo_hdf5, gwy, ibw, sxm
# Importer modules in registration order. If two modules list the same
# extension, the one that appears later in this list is the one registered.
_IMPORTERS: list[ModuleType] = [gwy, sxm, ibw, ergo_hdf5, array_image]

# ext → importer module
_REGISTRY: dict[str, ModuleType] = {
    ext: mod for mod in _IMPORTERS for ext in mod.extensions
}
def get_importer(ext: str) -> ModuleType | None:
    """Look up the importer module registered for *ext*.

    *ext* is lower-cased before the lookup, so ``".GWY"`` and ``".gwy"``
    resolve to the same module. Returns ``None`` when nothing is registered
    under that key.
    """
    key = ext.lower()
    return _REGISTRY.get(key)
def all_extensions() -> frozenset[str]:
    """Return every extension that has a registered importer."""
    return frozenset(_REGISTRY.keys())
def calibrated_extensions() -> frozenset[str]:
    """Return the extensions whose importer module has a truthy ``calibrated``."""
    selected = set()
    for ext, mod in _REGISTRY.items():
        if mod.calibrated:
            selected.add(ext)
    return frozenset(selected)

View File

@@ -0,0 +1,31 @@
"""
Base protocol for file importers.
Each importer handles one or more file extensions and implements:
load(path) → list[DataField]
channel_names(path) → list[str] (optional, falls back to generic names)
"""
from __future__ import annotations
from pathlib import Path
from typing import Protocol, runtime_checkable
from backend.data_types import DataField
@runtime_checkable
class FileImporter(Protocol):
    """Structural interface of an importer.

    An object satisfies the protocol when it exposes the two data members
    and the two callables declared below. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, FileImporter)`` can be used,
    but that check only tests that the members exist — it does not
    validate their types or signatures.
    """

    #: Extensions handled by this importer, e.g. {".gwy"}
    extensions: frozenset[str]

    #: True when physical dimensions are known (suppresses "uncalibrated" warning)
    calibrated: bool

    def load(self, path: Path) -> list[DataField]:
        """Read every channel stored in *path* and return one DataField per channel."""
        ...

    def channel_names(self, path: Path) -> list[str]:
        """Return one name per channel, ordered the same way as load()."""
        ...

View File

@@ -0,0 +1,44 @@
"""
Importer for pixel images (PNG, TIFF, JPEG, BMP) and NumPy arrays (.npy, .npz).
These formats carry no physical calibration, so calibrated = False.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
from backend.data_types import DataField
# Lower-case extensions served by this module: raster images and NumPy files.
extensions = frozenset({".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp", ".npy", ".npz"})
# load() below never sets physical dimensions or units on the DataField,
# so this importer reports itself as uncalibrated.
calibrated = False
def load(path: Path) -> list[DataField]:
    """Load an image or NumPy file as a single uncalibrated DataField.

    * ``.npy``  — the stored array.
    * ``.npz``  — the first array in the archive (``files[0]``).
    * anything else — opened with Pillow.

    The data is converted to float64. A 3-D array is reduced to 2-D by
    averaging over its last axis (for images this averages all colour
    channels, including alpha if present).

    Raises:
        ValueError: if a ``.npz`` archive contains no arrays.
    """
    ext = path.suffix.lower()
    if ext == ".npy":
        arr = np.load(str(path))
    elif ext == ".npz":
        # NpzFile keeps the archive open until closed; read what we need
        # inside the context manager so the handle is always released.
        with np.load(str(path)) as npz:
            if not npz.files:
                raise ValueError(f"No arrays found in {path.name}")
            arr = npz[npz.files[0]]
    else:
        from PIL import Image as PILImage

        # Image.open is lazy and holds the file open; np.array() forces the
        # pixel data to load, after which the handle can be closed.
        with PILImage.open(str(path)) as img:
            arr = np.array(img)

    arr = np.asarray(arr, dtype=np.float64)
    gray = arr.mean(axis=2) if arr.ndim == 3 else arr
    return [DataField(data=gray)]
def channel_names(path: Path) -> list[str]:
    """Return the single fixed channel name; *path* is not inspected."""
    names = ["field"]
    return names

View File

@@ -0,0 +1,235 @@
"""
Importer for Asylum Research / Ergo HDF5 files (.h5, .hdf5, .he5).
Asylum Research instruments store scan metadata in a sidecar group rather
than as dataset attributes. This importer reads physical dimensions from:
Image/DataSetInfo/Global/Channels/<channel>/ImageDims
DimScaling (2,2) array: [[px_size_x, offset_x], [px_size_y, offset_y]]
DimExtents pixel counts [xres, yres] (stored in a child group)
DimUnits lateral unit strings
DataUnits Z unit string
If the sidecar group is absent (generic HDF5), standard dataset attributes
are used as a fallback:
xreal / yreal physical scan size in metres (fallback: 1e-6)
xoff / yoff position offset in metres (fallback: 0)
si_unit_xy lateral unit string (fallback: "m")
si_unit_z value unit string (fallback: "m")
Requires:
pip install h5py
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
from backend.data_types import DataField
# Lower-case extensions served by this module.
extensions = frozenset({".h5", ".hdf5", ".he5"})
# We attempt to read physical metadata.
# NOTE(review): when neither the sidecar group nor the dataset attributes are
# present, load() falls back to 1e-6 m defaults while this flag still says
# "calibrated" — confirm that is the intended behaviour.
calibrated = True
def _iter_2d_datasets(h5file):
    """Return a list of (name, dataset) for every 2-D numeric dataset in the file.

    The whole hierarchy is walked with ``visititems``; the result order is
    the order in which h5py visits the objects.
    """
    import h5py

    found: list = []

    def _collect(name, obj):
        # Returning None keeps visititems() walking the rest of the file.
        if not isinstance(obj, h5py.Dataset):
            return None
        if obj.ndim != 2 or not np.issubdtype(obj.dtype, np.number):
            return None
        found.append((name, obj))
        return None

    h5file.visititems(_collect)
    return found
def _attr_str(attrs, key: str, default: str) -> str:
val = attrs.get(key)
if val is None:
return default
if isinstance(val, bytes):
return val.decode("utf-8", errors="replace").strip() or default
return str(val).strip() or default
def _attr_float(attrs, key: str, default: float) -> float:
val = attrs.get(key)
if val is None:
return default
try:
return float(val)
except (TypeError, ValueError):
return default
def _ar_image_dims(f, ds_name: str) -> dict | None:
    """
    Look up Asylum Research ImageDims metadata for a dataset.

    AR .h5 files store scan dimensions in a sibling group rather than as
    dataset attributes. Given a dataset path like:
        "Image/DataSet/Resolution 0/Frame 0/Adhesion:Retrace/Image"
    the channel name is the second-to-last component ("Adhesion:Retrace"),
    and the metadata lives at:
        "Image/DataSetInfo/Global/Channels/<channel>/ImageDims"

    Returns a dict with xreal, yreal, xoff, yoff, si_unit_xy, si_unit_z,
    or None if the group isn't found or has no (2, 2) DimScaling attribute.
    """
    import h5py

    parts = ds_name.split("/")
    if len(parts) < 2:
        return None
    channel = parts[-2]
    grp = f.get(f"Image/DataSetInfo/Global/Channels/{channel}/ImageDims")
    if not isinstance(grp, h5py.Group):
        return None

    scaling = grp.attrs.get("DimScaling")    # shape (2, 2): [[px_x, off_x], [px_y, off_y]]
    dim_units = grp.attrs.get("DimUnits")    # unit string(s), e.g. ['m', 'm']
    data_units = grp.attrs.get("DataUnits")  # Z unit string, e.g. 'N'

    if scaling is None or np.asarray(scaling).shape != (2, 2):
        return None
    scaling = np.asarray(scaling, dtype=np.float64)
    px_x, off_x = float(scaling[0, 0]), float(scaling[0, 1])
    px_y, off_y = float(scaling[1, 0]), float(scaling[1, 1])

    # DimExtents gives pixel counts; it lives on the first child group that
    # carries the attribute. Without it the counts stay at 1, so xreal/yreal
    # then equal a single pixel's size.
    xres, yres = 1, 1
    for child_name in grp:
        child = grp[child_name]
        if isinstance(child, h5py.Group) and "DimExtents" in child.attrs:
            extents = np.asarray(child.attrs["DimExtents"]).ravel()
            if extents.size >= 2:
                xres, yres = int(extents[0]), int(extents[1])
            break

    def _decode(raw, default: str = "m") -> str:
        """Turn a scalar or array-valued unit attribute into a clean string."""
        if raw is None:
            return default
        if not isinstance(raw, (str, bytes)):
            # Array/list attribute: the first entry is the one we want.
            # Scalars (str/bytes) must NOT be indexed, or "nm" becomes "n".
            flat = np.asarray(raw).ravel()
            if flat.size == 0:
                return default
            raw = flat[0]
        if isinstance(raw, bytes):
            return raw.decode("utf-8", errors="replace").strip() or default
        return str(raw).strip() or default

    return {
        "xreal": abs(px_x * xres) or 1e-6,
        "yreal": abs(px_y * yres) or 1e-6,
        "xoff": off_x,
        "yoff": off_y,
        "si_unit_xy": _decode(dim_units),
        "si_unit_z": _decode(data_units),
    }
def load(path: Path) -> list[DataField]:
try:
import h5py
except ImportError:
raise ImportError("Install 'h5py' to load HDF5 files: pip install h5py")
with h5py.File(str(path), "r") as f:
datasets = _iter_2d_datasets(f)
if not datasets:
raise ValueError(f"No 2-D numeric datasets found in {path.name}")
fields = []
for name, ds in datasets:
data = np.asarray(ds, dtype=np.float64)
# Try Asylum Research sidecar metadata first, then dataset attrs.
ar = _ar_image_dims(f, name)
if ar:
fields.append(DataField(
data=data,
xreal=ar["xreal"], yreal=ar["yreal"],
xoff=ar["xoff"], yoff=ar["yoff"],
si_unit_xy=ar["si_unit_xy"],
si_unit_z=ar["si_unit_z"],
))
else:
attrs = ds.attrs
fields.append(DataField(
data=data,
xreal=_attr_float(attrs, "xreal", 1e-6),
yreal=_attr_float(attrs, "yreal", 1e-6),
xoff=_attr_float(attrs, "xoff", 0.0),
yoff=_attr_float(attrs, "yoff", 0.0),
si_unit_xy=_attr_str(attrs, "si_unit_xy", "m"),
si_unit_z=_attr_str(attrs, "si_unit_z", "m"),
))
return fields
def _display_names(full_names: list[str]) -> list[str]:
"""
Derive short display names from HDF5 dataset paths.
Rules (all comparisons case-insensitive):
1. All thumbnail datasets are filtered out.
2. Display name = second-to-last path component (drops the leaf like
"/image" or "/thumbnail").
3. "global" channels sort to the front.
4. If two kept datasets share the same second-to-last name, the leaf is
appended to disambiguate.
Returns a list in sorted order (not parallel to full_names).
"""
from collections import Counter
# Filter out all thumbnail datasets.
kept: list[tuple[int, str]] = [] # (original index, full name)
for i, name in enumerate(full_names):
if name.split("/")[-1].lower() == "thumbnail":
continue
kept.append((i, name))
# Sort: "global" second-to-last first, then alphabetical.
def _sort_key(item: tuple[int, str]) -> tuple[int, str]:
parts = item[1].split("/")
second_last = parts[-2].lower() if len(parts) >= 2 else parts[-1].lower()
return (0 if second_last == "global" else 1, second_last)
kept.sort(key=_sort_key)
# Build short names (second-to-last), disambiguate clashes.
short = [
(parts[-2] if len(parts := name.split("/")) >= 2 else parts[-1])
for _, name in kept
]
counts = Counter(short)
disambiguated = [
f"{s}/{name.split('/')[-1]}" if counts[s] > 1 else s
for s, (_, name) in zip(short, kept)
]
return disambiguated
def channel_names(path: Path) -> list[str]:
try:
import h5py
except ImportError:
return []
try:
with h5py.File(str(path), "r") as f:
datasets = _iter_2d_datasets(f)
full_names = [name for name, _ in datasets]
return [n for n in _display_names(full_names) if n is not None]
except Exception:
return []

45
backend/importers/gwy.py Normal file
View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from pathlib import Path
import numpy as np
from backend.data_types import DataField
# Lower-case extensions served by this module.
extensions = frozenset({".gwy"})
# load() fills xreal/yreal (and offsets) from each channel, so fields are calibrated.
calibrated = True
def load(path: Path) -> list[DataField]:
    """Load every data field of a .gwy file.

    Each channel's flat data is reshaped to (yres, xres). Lateral size and
    offsets come from the channel; offsets default to 0.0 when the channel
    has no such attribute. Both unit strings are fixed to "m".

    Raises:
        ValueError: the file contains no data channels.
    """
    import gwyfile

    container = gwyfile.load(str(path))
    channels = gwyfile.util.get_datafields(container)
    if not channels:
        raise ValueError(f"No data channels found in {path.name}")

    def _to_field(ch) -> DataField:
        grid = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
        return DataField(
            data=grid,
            xreal=float(ch.xreal),
            yreal=float(ch.yreal),
            xoff=float(getattr(ch, "xoff", 0.0)),
            yoff=float(getattr(ch, "yoff", 0.0)),
            si_unit_xy="m",
            si_unit_z="m",
        )

    return [_to_field(ch) for ch in channels.values()]
def channel_names(path: Path) -> list[str]:
    """Return the data-field keys of a .gwy file, or [] if it cannot be read.

    The keys come from the same mapping load() iterates, so the order matches.
    """
    import gwyfile

    names: list[str] = []
    try:
        channels = gwyfile.util.get_datafields(gwyfile.load(str(path)))
        if channels:
            names = list(channels.keys())
    except Exception:
        # Best-effort: any read/parse failure yields an empty list.
        pass
    return names

107
backend/importers/hdf5.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Generic HDF5 importer (.h5, .hdf5, .he5).
Each 2-D dataset found in the file is returned as a DataField. Physical
dimensions are read from standard dataset attributes if present:
xreal / yreal physical scan size in metres (fallback: 1e-6)
xoff / yoff position offset in metres (fallback: 0)
si_unit_xy lateral unit string (fallback: "m")
si_unit_z value unit string (fallback: "m")
For Asylum Research / Ergo format files (which store scan metadata in a
sidecar group rather than as dataset attributes), use the ergo_hdf5 importer.
Requires:
pip install h5py
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
from backend.data_types import DataField
# Lower-case extensions served by this module.
extensions = frozenset({".h5", ".hdf5", ".he5"})
# We attempt to read physical metadata.
# NOTE(review): datasets without the expected attributes get 1e-6 m defaults
# in load() while this flag still says "calibrated" — confirm that is intended.
calibrated = True
def _iter_2d_datasets(h5file):
    """Return a list of (name, dataset) for every 2-D numeric dataset in the file.

    The hierarchy is walked with ``visititems``; results keep visit order.
    """
    import h5py

    matches: list = []

    def _keep_if_2d_numeric(name, obj):
        is_dataset = isinstance(obj, h5py.Dataset)
        if is_dataset and obj.ndim == 2 and np.issubdtype(obj.dtype, np.number):
            matches.append((name, obj))
        # Implicit None return lets visititems() continue the walk.

    h5file.visititems(_keep_if_2d_numeric)
    return matches
def _attr_str(attrs, key: str, default: str) -> str:
val = attrs.get(key)
if val is None:
return default
if isinstance(val, bytes):
return val.decode("utf-8", errors="replace").strip() or default
return str(val).strip() or default
def _attr_float(attrs, key: str, default: float) -> float:
val = attrs.get(key)
if val is None:
return default
try:
return float(val)
except (TypeError, ValueError):
return default
def load(path: Path) -> list[DataField]:
    """Load every 2-D numeric dataset in *path* as a DataField.

    Physical size, offsets and units are taken from the dataset's
    attributes; a missing attribute falls back to 1e-6 (size), 0.0
    (offset) or "m" (units).

    Raises:
        ImportError: h5py is not installed.
        ValueError: the file contains no 2-D numeric dataset.
    """
    try:
        import h5py
    except ImportError:
        raise ImportError("Install 'h5py' to load HDF5 files: pip install h5py")

    def _field_from(ds) -> DataField:
        meta = ds.attrs
        return DataField(
            data=np.asarray(ds, dtype=np.float64),
            xreal=_attr_float(meta, "xreal", 1e-6),
            yreal=_attr_float(meta, "yreal", 1e-6),
            xoff=_attr_float(meta, "xoff", 0.0),
            yoff=_attr_float(meta, "yoff", 0.0),
            si_unit_xy=_attr_str(meta, "si_unit_xy", "m"),
            si_unit_z=_attr_str(meta, "si_unit_z", "m"),
        )

    with h5py.File(str(path), "r") as f:
        found = _iter_2d_datasets(f)
        if not found:
            raise ValueError(f"No 2-D numeric datasets found in {path.name}")
        # Materialise inside the with-block: datasets are unreadable once closed.
        fields = [_field_from(ds) for _name, ds in found]
    return fields
def channel_names(path: Path) -> list[str]:
    """Return one display name per 2-D numeric dataset, in load() order.

    The name is the second-to-last path component, or the dataset's own
    name when it sits at the top level. An empty list is returned when
    h5py is missing or the file cannot be read.
    """
    try:
        import h5py
    except ImportError:
        return []

    def _short(full_name: str) -> str:
        pieces = full_name.split("/")
        return pieces[-2] if len(pieces) >= 2 else pieces[-1]

    try:
        with h5py.File(str(path), "r") as f:
            return [_short(full_name) for full_name, _ in _iter_2d_datasets(f)]
    except Exception:
        return []

106
backend/importers/ibw.py Normal file
View File

@@ -0,0 +1,106 @@
from __future__ import annotations
from pathlib import Path
import numpy as np
from backend.data_types import DataField
# Lower-case extensions served by this module.
extensions = frozenset({".ibw"})
# load() derives xreal/yreal from the wave header's step sizes.
calibrated = True
def _load_ibw_raw(path: Path):
    """Parse *path* with ``igor.binarywave.load`` and return its raw result.

    Before importing ``igor``, a ``complex`` attribute is installed on the
    numpy module if it does not already have one (presumably because the
    igor package still references the removed ``np.complex`` alias — verify
    against the igor version in use).

    Raises:
        ImportError: the ``igor`` package is not installed.
    """
    import numpy as _np

    try:
        _np.complex
    except AttributeError:
        _np.complex = complex

    try:
        from igor.binarywave import load as load_ibw
    except ImportError:
        raise ImportError("Install 'igor' to load .ibw files: pip install igor")

    return load_ibw(str(path))
def _decode_unit(raw_unit) -> str:
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
def load(path: Path) -> list[DataField]:
    """Load every channel of an Igor binary wave as a DataField.

    A wave with three or more dimensions is split along axis 2 into
    channels; otherwise it is treated as a single channel. Each channel is
    transposed and flipped vertically before being stored as float64.

    Physical size is ``step * pixel count`` using the header's ``sfA``
    step sizes when at least two are present, otherwise the ``hsA`` step
    with the height scaled by the aspect ratio. A zero size falls back to
    1e-6. All channels share the same lateral and value units.
    """
    wave_dict = _load_ibw_raw(path)["wave"]
    header = wave_dict["wave_header"]
    raw = wave_dict["wData"]

    n_channels = raw.shape[2] if raw.ndim >= 3 else 1
    step_sizes = header.get("sfA", None)
    dim_units = header.get("dimUnits", None)
    data_units = header.get("dataUnits", None)

    # The lateral unit is the first entry of dimUnits when it is a 2-D array
    # or a non-empty list/array; otherwise the raw value is decoded as-is.
    use_first_entry = (
        isinstance(dim_units, np.ndarray) and dim_units.ndim == 2
    ) or (
        isinstance(dim_units, (list, np.ndarray)) and len(dim_units) > 0
    )
    si_unit_xy = _decode_unit(dim_units[0] if use_first_entry else dim_units)
    si_unit_z = _decode_unit(data_units)

    def _plane(idx: int):
        """Return the raw 2-D slab for channel *idx*."""
        if raw.ndim >= 3:
            return raw[:, :, idx]
        if raw.ndim == 1:
            return raw.reshape(-1, 1)
        return raw

    def _extent(xres: int, yres: int) -> tuple[float, float]:
        """Return (xreal, yreal) for an image of xres × yres pixels."""
        if step_sizes is not None and len(step_sizes) >= 2:
            width = abs(float(step_sizes[0]) * xres) or 1e-6
            height = abs(float(step_sizes[1]) * yres) or 1e-6
            return width, height
        width = abs(float(header.get("hsA", 0.0)) * xres) or 1e-6
        height = width * (yres / xres) if xres else 1e-6
        return width, height

    fields = []
    for idx in range(n_channels):
        data = np.flipud(_plane(idx).T).astype(np.float64)
        yres, xres = data.shape
        xreal, yreal = _extent(xres, yres)
        fields.append(DataField(
            data=data, xreal=xreal, yreal=yreal,
            si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
        ))
    return fields
def channel_names(path: Path) -> list[str]:
    """Return channel names for an Igor binary wave, or [] when unknown.

    For waves with three or more dimensions the names come from the wave's
    dimension labels (index 2, or the last available label set if there
    are fewer). If no usable label is found and there is more than one
    channel, generic ``ch0``, ``ch1``, … names are returned. Any failure
    while reading the file yields an empty list.
    """
    try:
        wave_dict = _load_ibw_raw(path)["wave"]
        raw = wave_dict["wData"]
        labels = wave_dict.get("labels", None)
        multi_dim = raw.ndim >= 3

        if multi_dim and labels:
            dim_idx = min(2, len(labels) - 1)
            if dim_idx >= 0 and labels[dim_idx]:
                decoded = [
                    name
                    for lbl in labels[dim_idx]
                    if lbl and (
                        name := lbl.split(b"\x00")[0]
                        .decode("ascii", errors="replace")
                        .strip()
                    )
                ]
                if decoded:
                    return decoded

        if multi_dim and raw.shape[2] > 1:
            return [f"ch{i}" for i in range(raw.shape[2])]
    except Exception:
        # Best-effort: unreadable files simply report no names.
        pass
    return []

47
backend/importers/sxm.py Normal file
View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from pathlib import Path
import numpy as np
from backend.data_types import DataField
# Lower-case extensions served by this module.
extensions = frozenset({".sxm"})
# load() takes xreal/yreal from the header's "scan_range" entry.
calibrated = True
def load(path: Path) -> list[DataField]:
    """Load every signal of a .sxm scan as a DataField.

    For each signal the "forward" entry is used, falling back to the
    signal's first entry. Data with a dimensionality other than 2 is
    reshaped to its last two axes. Lateral size comes from the header's
    "scan_range" (default 1e-6 × 1e-6); both unit strings are fixed to "m".

    Raises:
        ValueError: the scan contains no signals.
    """
    import nanonispy as nap

    scan = nap.read.Scan(str(path))
    signals = scan.signals
    if not signals:
        raise ValueError(f"No signals found in {path.name}")
    scan_range = scan.header.get("scan_range", [1e-6, 1e-6])

    def _to_field(sig) -> DataField:
        first_entry = list(sig.values())[0]
        image = np.asarray(sig.get("forward", first_entry), dtype=np.float64)
        if image.ndim != 2:
            image = image.reshape(image.shape[-2], image.shape[-1])
        return DataField(
            data=image,
            xreal=float(scan_range[0]),
            yreal=float(scan_range[1]),
            si_unit_xy="m",
            si_unit_z="m",
        )

    return [_to_field(sig) for sig in signals.values()]
def channel_names(path: Path) -> list[str]:
    """Return the signal names of a .sxm scan, or [] if it cannot be read.

    The names are the keys of the same mapping load() iterates, so the
    order matches.
    """
    import nanonispy as nap

    names: list[str] = []
    try:
        scan = nap.read.Scan(str(path))
        if scan.signals:
            names = list(scan.signals.keys())
    except Exception:
        # Best-effort: any read/parse failure yields an empty list.
        pass
    return names

View File

@@ -527,13 +527,9 @@ OUTPUT_DIR = output_dir()
_MAX_SAVE_FIELDS = 8
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
".gwy", ".sxm", ".ibw"}
from backend.importers import all_extensions, get_importer
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
_PATH_COMPATIBLE_EXTENSIONS = all_extensions()
def _resolve_path(filepath: str):
@@ -554,52 +550,16 @@ def list_channels(filepath: str) -> list[dict]:
if not path.exists():
return [{"name": "field", "type": "DATA_FIELD"}]
ext = path.suffix.lower()
if ext == ".gwy":
try:
import gwyfile
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if channels:
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".sxm":
try:
import nanonispy as nap
sxm = nap.read.Scan(str(path))
if sxm.signals:
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
if ext == ".ibw":
try:
load_ibw = _import_ibw_loader()
wave = load_ibw(str(path))
raw = wave["wave"]["wData"]
labels = wave["wave"].get("labels", None)
if raw.ndim >= 3 and labels:
dim_idx = min(2, len(labels) - 1)
if dim_idx >= 0 and labels[dim_idx]:
decoded = []
for lbl in labels[dim_idx]:
if lbl:
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
if name:
decoded.append(name)
if decoded:
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
if raw.ndim >= 3 and raw.shape[2] > 1:
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
except Exception:
pass
importer = get_importer(path.suffix.lower())
if importer is None:
return [{"name": "field", "type": "DATA_FIELD"}]
try:
names = importer.channel_names(path)
if names:
return [{"name": n, "type": "DATA_FIELD"} for n in names]
except Exception:
pass
return [{"name": "field", "type": "DATA_FIELD"}]
@@ -624,7 +584,7 @@ def _list_demo_files() -> list[str]:
return []
return sorted(
f.name for f in DEMO_DIR.iterdir()
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _PATH_COMPATIBLE_EXTENSIONS
)

View File

@@ -1,14 +1,12 @@
from __future__ import annotations
from functools import lru_cache
import numpy as np
from pathlib import Path
import nanonispy as nap
import gwyfile
from backend.node_registry import register_node
from backend.execution_context import emit_warning
from backend.data_types import COLORMAPS, DataField, resolve_colormap_input
from backend.nodes.helpers import _resolve_path, _SPM_EXTENSIONS, _import_ibw_loader
from backend.nodes.helpers import _resolve_path
from backend.importers import get_importer, calibrated_extensions
@register_node(display_name="Image")
@@ -34,7 +32,7 @@ class Image:
DESCRIPTION = (
"Load any supported file. "
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
"SPM formats (.gwy, .sxm, .ibw) and HDF5 (.h5, .hdf5) provide calibrated dimensions; "
"each channel gets its own output. "
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
)
@@ -65,158 +63,17 @@ class Image:
for field in fields:
field.colormap = resolved_colormap
if ext not in _SPM_EXTENSIONS:
self._send_warning("Uncalibrated data — no physical dimensions.")
if ext not in calibrated_extensions():
emit_warning("Uncalibrated data — no physical dimensions.")
return (str(path_obj.resolve()),) + fields
def _send_warning(self, message: str):
emit_warning(message)
@staticmethod
@lru_cache(maxsize=32)
def _load_fields_cached(path_str: str, mtime_ns: int, size_bytes: int) -> tuple[DataField, ...]:
path = Path(path_str)
ext = path.suffix.lower()
if ext in _SPM_EXTENSIONS:
return tuple(Image._load_spm_all(path, ext))
return (Image._load_image_or_array(path, ext),)
@staticmethod
def _load_spm_all(path: Path, ext: str) -> list[DataField]:
if ext == ".gwy":
return Image._load_gwy_all(path)
elif ext == ".sxm":
return Image._load_sxm_all(path)
elif ext == ".ibw":
return Image._load_ibw_all(path)
else:
raise ValueError(f"Unsupported SPM format: {ext}")
@staticmethod
def _load_gwy_all(path: Path) -> list[DataField]:
obj = gwyfile.load(str(path))
channels = gwyfile.util.get_datafields(obj)
if not channels:
raise ValueError(f"No data channels found in {path.name}")
fields = []
for ch in channels.values():
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
fields.append(DataField(
data=data,
xreal=float(ch.xreal),
yreal=float(ch.yreal),
xoff=float(getattr(ch, "xoff", 0.0)),
yoff=float(getattr(ch, "yoff", 0.0)),
si_unit_xy="m",
si_unit_z="m",
))
return fields
@staticmethod
def _load_sxm_all(path: Path) -> list[DataField]:
sxm = nap.read.Scan(str(path))
signals = sxm.signals
if not signals:
raise ValueError(f"No signals found in {path.name}")
header = sxm.header
scan_range = header.get("scan_range", [1e-6, 1e-6])
fields = []
for sig in signals.values():
data = sig.get("forward", list(sig.values())[0])
data = np.asarray(data, dtype=np.float64)
if data.ndim != 2:
data = data.reshape(data.shape[-2], data.shape[-1])
fields.append(DataField(
data=data,
xreal=float(scan_range[0]),
yreal=float(scan_range[1]),
si_unit_xy="m",
si_unit_z="m",
))
return fields
@staticmethod
def _load_ibw_all(path: Path) -> list[DataField]:
load_ibw = _import_ibw_loader()
wave = load_ibw(str(path))
wdata = wave["wave"]
header = wdata["wave_header"]
raw = wdata["wData"]
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
sfA = header.get("sfA", None)
def _decode_unit(raw_unit):
if raw_unit is None:
return "m"
if isinstance(raw_unit, bytes):
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
if isinstance(raw_unit, np.ndarray):
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
return str(raw_unit).strip() or "m"
dim_units_raw = header.get("dimUnits", None)
data_units_raw = header.get("dataUnits", None)
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
si_unit_xy = _decode_unit(dim_units_raw[0])
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
si_unit_xy = _decode_unit(dim_units_raw[0])
else:
si_unit_xy = _decode_unit(dim_units_raw)
si_unit_z = _decode_unit(data_units_raw)
fields = []
for ch_idx in range(n_channels):
if raw.ndim >= 3:
ch_data = raw[:, :, ch_idx]
elif raw.ndim == 1:
ch_data = raw.reshape(-1, 1)
else:
ch_data = raw
data = np.flipud(ch_data.T).astype(np.float64)
yres, xres = data.shape
if sfA is not None and len(sfA) >= 2:
xreal = abs(float(sfA[0]) * xres) or 1e-6
yreal = abs(float(sfA[1]) * yres) or 1e-6
else:
hsA = header.get("hsA", 0.0)
xreal = abs(float(hsA) * xres) or 1e-6
yreal = xreal * (yres / xres) if xres else 1e-6
fields.append(DataField(
data=data, xreal=xreal, yreal=yreal,
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
))
return fields
@staticmethod
def _load_image_or_array(path: Path, ext: str) -> DataField:
if ext == ".npy":
arr = np.load(str(path)).astype(np.float64)
elif ext == ".npz":
npz = np.load(str(path))
key = list(npz.files)[0]
arr = npz[key].astype(np.float64)
else:
from PIL import Image as PILImage
img = PILImage.open(str(path))
arr = np.array(img)
if arr.dtype != np.uint8:
arr = arr.astype(np.float64)
if arr.ndim == 3:
gray = np.mean(arr.astype(np.float64), axis=2)
else:
gray = arr.astype(np.float64)
return DataField(data=gray)
importer = get_importer(ext)
if importer is None:
raise ValueError(f"Unsupported file format: {ext}")
return tuple(importer.load(path))

2
demo

Submodule demo updated: 9c82c2ba8c...e3dcea633a

View File

@@ -11,6 +11,7 @@ requires-python = ">=3.10"
dependencies = [
"aiohttp>=3.9,<4",
"gwyfile>=0.2",
"h5py>=3.10,<4",
"igor>=0.3",
"matplotlib>=3.8,<4",
"nanonispy>=1.1",

View File

@@ -83,7 +83,8 @@ def test_load_file_cache():
path = os.path.join(tmpdir, "cached.npy")
np.save(path, data)
with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader:
import backend.importers.array_image as _ai
with patch.object(_ai, "load", wraps=_ai.load) as loader:
_, first = node.load(filename=path)
_, second = node.load(filename=path)
assert loader.call_count == 1

View File

@@ -53,7 +53,8 @@ def test_load_cache():
Image._load_fields_cached.cache_clear()
with patch("backend.nodes.image_demo.DEMO_DIR", FIXTURES):
with patch.object(Image, "_load_image_or_array", wraps=Image._load_image_or_array) as loader:
import backend.importers.array_image as _ai
with patch.object(_ai, "load", wraps=_ai.load) as loader:
_, first = node.load(name="nanoparticles.npy")
_, second = node.load(name="nanoparticles.npy")
assert loader.call_count == 1

View File

@@ -0,0 +1,393 @@
"""
Tests for backend/importers/ — the importer registry and each importer module.
"""
import os
import tempfile
from pathlib import Path
import numpy as np
import pytest
from backend.data_types import DataField
# Directory holding optional sample files (nanoparticles.npy, Bacteria.ibw);
# tests that need one of them skip themselves when the file is absent.
FIXTURES = Path(__file__).parent.parent / "output"
# ── Registry ─────────────────────────────────────────────────────────────────
class TestRegistry:
    """Checks of the extension → importer lookup in backend.importers."""

    def test_get_importer_known_extensions(self):
        from backend.importers import get_importer

        known = (
            ".gwy", ".sxm", ".ibw", ".npy", ".npz",
            ".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp",
            ".h5", ".hdf5", ".he5",
        )
        for ext in known:
            assert get_importer(ext) is not None, f"No importer registered for {ext}"

    def test_get_importer_unknown_extension(self):
        from backend.importers import get_importer

        for ext in (".xyz", ".csv"):
            assert get_importer(ext) is None

    def test_get_importer_case_insensitive(self):
        from backend.importers import get_importer

        for upper, lower in ((".NPY", ".npy"), (".GWY", ".gwy")):
            assert get_importer(upper) is get_importer(lower)

    def test_all_extensions_returns_frozenset(self):
        from backend.importers import all_extensions

        exts = all_extensions()
        assert isinstance(exts, frozenset)
        assert ".npy" in exts
        assert ".gwy" in exts

    def test_calibrated_extensions(self):
        from backend.importers import calibrated_extensions

        cal = calibrated_extensions()
        # SPM and HDF5 are calibrated
        for ext in (".gwy", ".sxm", ".ibw", ".h5"):
            assert ext in cal
        # Images/arrays are not
        for ext in (".png", ".npy"):
            assert ext not in cal

    def test_each_importer_has_required_interface(self):
        from backend.importers import _IMPORTERS

        for mod in _IMPORTERS:
            # Data members first, then the two callables, then their types.
            assert hasattr(mod, "extensions"), f"{mod.__name__} missing extensions"
            assert hasattr(mod, "calibrated"), f"{mod.__name__} missing calibrated"
            assert callable(getattr(mod, "load", None)), f"{mod.__name__} missing load()"
            assert callable(getattr(mod, "channel_names", None)), f"{mod.__name__} missing channel_names()"
            assert isinstance(mod.extensions, frozenset)
            assert isinstance(mod.calibrated, bool)
# ── array_image importer ──────────────────────────────────────────────────────
class TestArrayImageImporter:
    """Round-trip tests for the image / NumPy-array importer."""

    def setup_method(self):
        import backend.importers.array_image as mod
        self.mod = mod

    def test_npy_load(self):
        expected = np.random.default_rng(0).standard_normal((32, 48))
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "data.npy"
            np.save(target, expected)
            fields = self.mod.load(target)
        assert len(fields) == 1
        assert isinstance(fields[0], DataField)
        assert np.allclose(fields[0].data, expected)

    def test_npz_load(self):
        expected = np.random.default_rng(1).standard_normal((16, 16))
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "data.npz"
            np.savez(target, arr=expected)
            fields = self.mod.load(target)
        assert len(fields) == 1
        assert np.allclose(fields[0].data, expected)

    def test_png_grayscale(self):
        from PIL import Image as PILImage

        pixels = np.random.default_rng(2).integers(0, 256, (24, 32), dtype=np.uint8)
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "gray.png"
            PILImage.fromarray(pixels).save(target)
            fields = self.mod.load(target)
        assert len(fields) == 1
        assert fields[0].data.shape == (24, 32)
        assert fields[0].data.dtype == np.float64

    def test_png_rgb_converted_to_grayscale(self):
        from PIL import Image as PILImage

        pixels = np.random.default_rng(3).integers(0, 256, (16, 16, 3), dtype=np.uint8)
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "rgb.png"
            PILImage.fromarray(pixels).save(target)
            fields = self.mod.load(target)
        # The colour axis must have been collapsed.
        assert fields[0].data.shape == (16, 16)

    def test_not_calibrated(self):
        assert self.mod.calibrated is False

    def test_channel_names(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "x.npy"
            np.save(target, np.zeros((4, 4)))
            assert self.mod.channel_names(target) == ["field"]

    def test_fixture_npy(self):
        sample = FIXTURES / "nanoparticles.npy"
        if not sample.exists():
            pytest.skip("fixture not available")
        fields = self.mod.load(sample)
        assert len(fields) == 1
        assert fields[0].data.ndim == 2
# ── ibw importer ─────────────────────────────────────────────────────────────
class TestIBWImporter:
    """Tests for the Igor binary wave importer, driven by the Bacteria.ibw sample."""

    def setup_method(self):
        import backend.importers.ibw as mod
        self.mod = mod
        self.fixture = FIXTURES / "Bacteria.ibw"

    def _fixture_or_skip(self):
        """Return the sample path, skipping the calling test if it is missing."""
        if not self.fixture.exists():
            pytest.skip("Bacteria.ibw fixture not available")
        return self.fixture

    def test_calibrated(self):
        assert self.mod.calibrated is True

    def test_extensions(self):
        assert ".ibw" in self.mod.extensions

    def test_load_fixture(self):
        fields = self.mod.load(self._fixture_or_skip())
        assert len(fields) == 4
        for field in fields:
            assert isinstance(field, DataField)
            assert field.data.ndim == 2
            assert field.data.dtype == np.float64
            assert field.xreal > 0
            assert field.yreal > 0

    def test_channel_names_fixture(self):
        names = self.mod.channel_names(self._fixture_or_skip())
        assert len(names) == 4
        assert all(isinstance(n, str) for n in names)
# ── hdf5 importer (generic) ──────────────────────────────────────────────────
class TestHDF5Importer:
    """Tests for the generic HDF5 importer; skipped entirely without h5py."""

    def setup_method(self):
        pytest.importorskip("h5py")
        import backend.importers.hdf5 as mod
        self.mod = mod

    @staticmethod
    def _write_h5(path, datasets, attrs=None):
        """Create *path* with the given {name: array} datasets.

        *attrs* optionally maps a dataset name to the attributes to set on it.
        """
        import h5py

        attrs = attrs or {}
        with h5py.File(path, "w") as f:
            for name, values in datasets.items():
                ds = f.create_dataset(name, data=values)
                for key, value in attrs.get(name, {}).items():
                    ds.attrs[key] = value
        return path

    def test_calibrated(self):
        assert self.mod.calibrated is True

    def test_extensions(self):
        assert {".h5", ".hdf5", ".he5"} <= self.mod.extensions

    def test_load_single_channel(self):
        expected = np.random.default_rng(10).standard_normal((32, 32))
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(Path(workdir) / "test.h5", {"channel": expected})
            fields = self.mod.load(target)
        assert len(fields) == 1
        assert np.allclose(fields[0].data, expected)
        assert fields[0].data.dtype == np.float64

    def test_load_physical_attrs(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "cal.h5",
                {"topo": np.zeros((16, 16))},
                attrs={"topo": {"xreal": 5e-6, "yreal": 5e-6, "si_unit_z": "V"}},
            )
            fields = self.mod.load(target)
        assert fields[0].xreal == pytest.approx(5e-6)
        assert fields[0].si_unit_z == "V"

    def test_load_fallback_attrs(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(Path(workdir) / "fallback.h5", {"channel": np.zeros((8, 8))})
            fields = self.mod.load(target)
        assert fields[0].xreal == pytest.approx(1e-6)
        assert fields[0].si_unit_xy == "m"

    def test_empty_file_raises(self):
        # A file with only a 1-D dataset has nothing the importer accepts.
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(Path(workdir) / "empty.h5", {"vec": np.zeros((10,))})
            with pytest.raises(ValueError, match="No 2-D"):
                self.mod.load(target)

    def test_channel_names_top_level(self):
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "named.h5",
                {"height": np.zeros((8, 8)), "phase": np.zeros((8, 8))},
            )
            names = self.mod.channel_names(target)
        assert set(names) == {"height", "phase"}

    def test_channel_names_nested(self):
        # Nested datasets return second-to-last path component.
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "nested.h5",
                {
                    "scan/height/data": np.zeros((8, 8)),
                    "scan/phase/data": np.zeros((8, 8)),
                },
            )
            names = self.mod.channel_names(target)
        assert set(names) == {"height", "phase"}
# ── ergo_hdf5 importer (Asylum Research / Ergo format) ───────────────────────
class TestErgoHDF5Importer:
    """Tests for ``backend.importers.ergo_hdf5`` (Asylum Research / Ergo format).

    Every file-based test builds a small HDF5 file in a temporary directory
    and runs the importer against it.
    """

    def setup_method(self):
        # Skip when h5py is not installed; the file-based tests need it.
        pytest.importorskip("h5py")
        import backend.importers.ergo_hdf5 as ergo_module

        self.mod = ergo_module

    def _make_h5(self, tmp: Path, data: np.ndarray, name: str = "channel",
                 attrs: dict | None = None) -> Path:
        """Write *data* as dataset *name* into ``tmp / "test.h5"``; return the path.

        When *attrs* is given, each key/value pair is stored as an attribute
        on the dataset.
        """
        import h5py

        target = tmp / "test.h5"
        with h5py.File(target, "w") as handle:
            dataset = handle.create_dataset(name, data=data)
            if attrs:
                for key, value in attrs.items():
                    dataset.attrs[key] = value
        return target

    @staticmethod
    def _write_h5(path, datasets, attrs=None):
        """Create an HDF5 file at *path* and return *path*.

        *datasets* maps dataset name → array; datasets are created in mapping
        order.  *attrs* optionally maps a dataset name to a dict of attributes
        to set on that dataset, also in mapping order.
        """
        import h5py

        attrs_by_name = attrs or {}
        with h5py.File(path, "w") as handle:
            for name, array in datasets.items():
                dataset = handle.create_dataset(name, data=array)
                for key, value in attrs_by_name.get(name, {}).items():
                    dataset.attrs[key] = value
        return path

    def test_calibrated(self):
        """The module must flag itself as calibrated."""
        assert self.mod.calibrated is True

    def test_extensions(self):
        """``.h5``, ``.hdf5`` and ``.he5`` must all be claimed by the module."""
        required = {".h5", ".hdf5", ".he5"}
        assert required <= self.mod.extensions

    def test_load_single_channel(self):
        """One 2-D dataset loads as one float64 field with the same values."""
        with tempfile.TemporaryDirectory() as workdir:
            expected = np.random.default_rng(10).standard_normal((32, 32))
            target = self._make_h5(Path(workdir), expected)

            loaded = self.mod.load(target)

            assert len(loaded) == 1
            first = loaded[0]
            assert np.allclose(first.data, expected)
            assert first.data.dtype == np.float64

    def test_load_physical_attrs(self):
        """Calibration attributes stored on the dataset reach the loaded field."""
        with tempfile.TemporaryDirectory() as workdir:
            calibration = {
                "xreal": 5e-6,
                "yreal": 5e-6,
                "si_unit_xy": "m",
                "si_unit_z": "V",
            }
            target = self._write_h5(
                Path(workdir) / "cal.h5",
                {"topo": np.zeros((16, 16))},
                attrs={"topo": calibration},
            )

            first = self.mod.load(target)[0]

            assert first.xreal == pytest.approx(5e-6)
            assert first.yreal == pytest.approx(5e-6)
            assert first.si_unit_z == "V"

    def test_load_fallback_attrs(self):
        """Without attributes the field gets xreal 1e-6 and lateral unit "m"."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._make_h5(Path(workdir), np.zeros((8, 8)))

            first = self.mod.load(target)[0]

            # Default fallbacks when no AR sidecar is present
            assert first.xreal == pytest.approx(1e-6)
            assert first.si_unit_xy == "m"

    def test_load_multiple_channels(self):
        """Two 2-D datasets load as two fields that share the 16x16 shape."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "multi.h5",
                {
                    "height": np.ones((16, 16)),
                    "phase": np.zeros((16, 16)),
                },
            )

            loaded = self.mod.load(target)

            assert len(loaded) == 2
            distinct_shapes = {field.data.shape for field in loaded}
            assert distinct_shapes == {(16, 16)}

    def test_ignores_non_2d_datasets(self):
        """Only the 2-D dataset is loaded; 1-D and 3-D datasets are ignored."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "mixed.h5",
                {
                    "topo": np.zeros((16, 16)),
                    "vector": np.zeros((10,)),       # 1-D, ignored
                    "volume": np.zeros((4, 4, 4)),   # 3-D, ignored
                },
            )

            loaded = self.mod.load(target)

            assert len(loaded) == 1

    def test_empty_file_raises(self):
        """A file with no 2-D dataset (only a 1-D vector) raises ValueError."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "empty.h5",
                {"vec": np.zeros((10,))},
            )

            with pytest.raises(ValueError, match="No 2-D"):
                self.mod.load(target)

    def test_channel_names_top_level(self):
        """Top-level datasets are reported under their own names."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "named.h5",
                {
                    "height": np.zeros((8, 8)),
                    "phase": np.zeros((8, 8)),
                },
            )

            reported = self.mod.channel_names(target)

            assert set(reported) == {"height", "phase"}

    def test_channel_names_strips_leaf(self):
        """The "/image" leaf is dropped; the second-to-last component is shown."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "nested.h5",
                {
                    "scan/adhesion:retrace/image": np.zeros((8, 8)),
                    "scan/phase:retrace/image": np.zeros((8, 8)),
                },
            )

            reported = self.mod.channel_names(target)

            assert set(reported) == {"adhesion:retrace", "phase:retrace"}

    def test_channel_names_thumbnails_always_filtered(self):
        """All thumbnail datasets are hidden, including global ones."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "thumbs.h5",
                {
                    "data/adhesion/image": np.zeros((8, 8)),
                    "info/channels/adhesion/thumbnail": np.zeros((8, 8)),
                    "info/global/thumbnail": np.zeros((8, 8)),
                },
            )

            reported = self.mod.channel_names(target)

            assert reported == ["adhesion"]

    def test_channel_names_sorted_alphabetically(self):
        """Names come back alphabetically regardless of creation order."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "order.h5",
                {
                    "data/zzz/image": np.zeros((8, 8)),
                    "data/aaa/image": np.zeros((8, 8)),
                },
            )

            reported = self.mod.channel_names(target)

            assert reported == ["aaa", "zzz"]

    def test_channel_names_deduplication(self):
        """Datasets sharing a second-to-last name are disambiguated by the leaf."""
        with tempfile.TemporaryDirectory() as workdir:
            target = self._write_h5(
                Path(workdir) / "dup.h5",
                {
                    "dataset/adhesion/imageA": np.zeros((8, 8)),
                    "datasetinfo/adhesion/imageB": np.zeros((8, 8)),
                },
            )

            reported = self.mod.channel_names(target)

            assert set(reported) == {"adhesion/imageA", "adhesion/imageB"}

    def test_he5_extension_registered(self):
        """The importer registry must resolve ``.he5`` to this module."""
        from backend.importers import get_importer

        resolved = get_importer(".he5")
        assert resolved is self.mod