Files
tono/backend/nodes/io.py

525 lines
18 KiB
Python

"""
I/O nodes: load and save images and SPM data.
"""
from __future__ import annotations
import os
import numpy as np
from pathlib import Path
from backend.node_registry import register_node
from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8
from backend.runtime_paths import demo_dir, input_dir, output_dir
# Resolved at server startup so nodes know where to look
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()

# Extension groups (lower-case, leading dot) used to route a file to a loader.
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
_ARRAY_EXTENSIONS = {".npy", ".npz"}
# Extensions listed by the Load Demo node: every image type except .bmp,
# plus all array and SPM types.
_DEMO_EXTENSIONS = (_IMAGE_EXTENSIONS - {".bmp"}) | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
# ---------------------------------------------------------------------------
# Channel listing helper (used by the /channels endpoint)
# ---------------------------------------------------------------------------
def _resolve_path(filepath: str) -> Path:
    """Map a user-supplied file name to a location on disk.

    Absolute paths are returned as given. A relative path is looked up under
    INPUT_DIR first and then under DEMO_DIR, and the first existing candidate
    wins. If neither directory contains it, the INPUT_DIR candidate is
    returned anyway, so the caller's existence check reports a path under
    the input directory.
    """
    requested = Path(filepath)
    if requested.is_absolute():
        return requested
    for base_dir in (INPUT_DIR, DEMO_DIR):
        candidate = base_dir / filepath
        if candidate.exists():
            return candidate
    # Nothing found: point at the input dir so the later FileNotFoundError is meaningful.
    return INPUT_DIR / filepath
def _gwy_channel_names(path: Path):
    """Return the mapping from gwyfile.util.get_datafields; its keys are used as channel names."""
    import gwyfile
    gwy_obj = gwyfile.load(str(path))
    return gwyfile.util.get_datafields(gwy_obj)


def _sxm_channel_names(path: Path):
    """Return the ``signals`` mapping nanonispy reads from a .sxm file; its keys are used as names."""
    import nanonispy as nap
    scan = nap.read.Scan(str(path))
    return scan.signals


def _ibw_channel_names(path: Path) -> list[str]:
    """Return channel names for a multi-layer Igor wave, or [] when there is only one layer.

    Names come from the dimension labels (index ``min(2, len(labels) - 1)``),
    cut at the first NUL byte and stripped, with empty ones skipped. If no
    usable label exists but axis 2 has more than one layer, numeric names
    ``ch0, ch1, ...`` are generated.
    """
    from igor.binarywave import load as load_ibw
    wave = load_ibw(str(path))["wave"]
    raw = wave["wData"]
    if raw.ndim < 3:
        return []
    labels = wave.get("labels", None)
    if labels:
        dim_idx = min(2, len(labels) - 1)
        if dim_idx >= 0 and labels[dim_idx]:
            named = []
            for lbl in labels[dim_idx]:
                if not lbl:
                    continue
                text = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
                if text:
                    named.append(text)
            if named:
                return named
    # Multi-channel without labels — use numeric names
    if raw.shape[2] > 1:
        return [f"ch{i}" for i in range(raw.shape[2])]
    return []


# Per-extension channel probes; any extension not listed here has a single channel.
_CHANNEL_PROBES = {
    ".gwy": _gwy_channel_names,
    ".sxm": _sxm_channel_names,
    ".ibw": _ibw_channel_names,
}


def list_channels(filepath: str) -> list[dict]:
    """Describe the channels available in a file (used by the /channels endpoint).

    Returns a list of ``{"name": str, "type": "DATA_FIELD"}`` dicts. SPM
    formats (.gwy, .sxm, .ibw) are opened and inspected. Images, arrays,
    missing files, and SPM files that cannot be read all report one channel
    named "field".
    """
    single = [{"name": "field", "type": "DATA_FIELD"}]
    path = _resolve_path(filepath)
    if not path.exists():
        return single
    probe = _CHANNEL_PROBES.get(path.suffix.lower())
    if probe is None:
        # Image or array — single channel
        return single
    try:
        names = probe(path)
        if names:
            return [{"name": name, "type": "DATA_FIELD"} for name in names]
    except Exception:
        # Best effort: a missing reader package or an unreadable file falls back to one channel.
        pass
    return single
# ---------------------------------------------------------------------------
# LoadFile (unified loader — replaces LoadImage + LoadSPM)
# ---------------------------------------------------------------------------
@register_node(display_name="Load File")
class LoadFile:
    """Unified loader node (replaces LoadImage + LoadSPM).

    SPM files (.gwy, .sxm, .ibw) yield one calibrated ``DataField`` per
    channel. Images and ``.npy``/``.npz`` arrays yield a single uncalibrated
    field.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the file-picker and colormap widgets."""
        return {
            "required": {
                "filename": ("FILE_PICKER", {"default": ""}),
                "colormap": (list(COLORMAPS),),
            }
        }

    # Default outputs — overridden dynamically by the frontend for multi-channel files
    RETURN_TYPES = ("DATA_FIELD",)
    RETURN_NAMES = ("field",)
    FUNCTION = "load"
    CATEGORY = "io"
    DESCRIPTION = (
        "Load any supported file. "
        "SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
        "each channel gets its own output. "
        "Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
    )

    # Set by execution engine for warning broadcast
    _broadcast_warning_fn = None
    _current_node_id = None

    def load(self, filename: str, colormap: str = "viridis"):
        """Load *filename* and return a tuple of ``DataField`` objects.

        Relative names are resolved by ``_resolve_path`` (input dir first, then
        demo dir). Every returned field is given *colormap*. Non-SPM files also
        trigger an "uncalibrated" warning broadcast.

        Raises:
            ValueError: no filename was given, or the file holds no data.
            FileNotFoundError: the resolved path does not exist.
            IsADirectoryError: the resolved path is a directory.
            ImportError: the optional reader package for an SPM format is missing.
        """
        if not filename or not filename.strip():
            raise ValueError("No file selected — use Browse to pick a file.")
        path = _resolve_path(filename)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")
        if path.is_dir():
            raise IsADirectoryError(f"Expected a file, got a directory: {path}")
        ext = path.suffix.lower()
        if ext in _SPM_EXTENSIONS:
            fields = self._load_spm_all(path, ext)
            for f in fields:
                f.colormap = colormap
            return tuple(fields)
        # Image or array — uncalibrated, single output
        field = self._load_image_or_array(path, ext)
        field.colormap = colormap
        self._send_warning("Uncalibrated data — no physical dimensions.")
        return (field,)

    def _send_warning(self, message: str):
        """Forward *message* through the engine-installed hook.

        This is a no-op unless both the broadcast function and the current
        node id are set (truthy).
        """
        fn = LoadFile._broadcast_warning_fn
        nid = LoadFile._current_node_id
        if fn and nid:
            fn(nid, message)

    # -- SPM: load all channels ---------------------------------------------
    def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
        """Dispatch to the per-format SPM loader for *ext*."""
        if ext == ".gwy":
            return self._load_gwy_all(path)
        elif ext == ".sxm":
            return self._load_sxm_all(path)
        elif ext == ".ibw":
            return self._load_ibw_all(path)
        else:
            raise ValueError(f"Unsupported SPM format: {ext}")

    # -- GWY ----------------------------------------------------------------
    def _load_gwy_all(self, path: Path) -> list[DataField]:
        """Load every data channel of a Gwyddion (.gwy) file.

        Raises:
            ImportError: ``gwyfile`` is not installed.
            ValueError: the file contains no data channels.
        """
        try:
            import gwyfile
        except ImportError as exc:
            raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile") from exc
        obj = gwyfile.load(str(path))
        channels = gwyfile.util.get_datafields(obj)
        if not channels:
            raise ValueError(f"No data channels found in {path.name}")
        fields = []
        for ch in channels.values():
            data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
            fields.append(DataField(
                data=data,
                xreal=float(ch.xreal),
                yreal=float(ch.yreal),
                xoff=float(getattr(ch, "xoff", 0.0)),
                yoff=float(getattr(ch, "yoff", 0.0)),
                # NOTE(review): units are hard-coded to metres for every channel,
                # so non-height channels (phase, current, ...) are mislabelled.
                # Consider reading each channel's own SI units — confirm the
                # gwyfile attribute names first.
                si_unit_xy="m",
                si_unit_z="m",
            ))
        return fields

    # -- SXM ----------------------------------------------------------------
    def _load_sxm_all(self, path: Path) -> list[DataField]:
        """Load every signal of a Nanonis (.sxm) scan, preferring the forward direction.

        The lateral size comes from the header's ``scan_range``. When that key
        is absent it falls back to 1e-6 x 1e-6.

        Raises:
            ImportError: ``nanonispy`` is not installed.
            ValueError: the scan has no signals, or a signal holds no data.
        """
        try:
            import nanonispy as nap
        except ImportError as exc:
            raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy") from exc
        sxm = nap.read.Scan(str(path))
        signals = sxm.signals
        if not signals:
            raise ValueError(f"No signals found in {path.name}")
        header = sxm.header
        scan_range = header.get("scan_range", [1e-6, 1e-6])
        fields = []
        for sig_name, sig in signals.items():
            # Prefer the forward scan. Otherwise take whichever direction comes
            # first, without assuming the mapping is non-empty.
            data = sig.get("forward")
            if data is None:
                data = next(iter(sig.values()), None)
            if data is None:
                raise ValueError(f"Signal {sig_name!r} in {path.name} contains no data")
            data = np.asarray(data, dtype=np.float64)
            if data.ndim != 2:
                data = data.reshape(data.shape[-2], data.shape[-1])
            fields.append(DataField(
                data=data,
                xreal=float(scan_range[0]),
                yreal=float(scan_range[1]),
                si_unit_xy="m",
                si_unit_z="m",
            ))
        return fields

    # -- IBW ----------------------------------------------------------------
    def _load_ibw_all(self, path: Path) -> list[DataField]:
        """Load every layer of an Igor binary wave (.ibw) as a calibrated field.

        A wave with 3 or more dimensions is split along axis 2, one field per
        layer. 1-D and 2-D waves give a single field. Pixel spacing comes from
        the header's ``sfA`` when present. Otherwise ``hsA`` is used and square
        pixels are assumed. Units are read from ``dimUnits``/``dataUnits`` and
        default to "m" when blank.

        Raises:
            ImportError: the ``igor`` package is not installed.
        """
        try:
            from igor.binarywave import load as load_ibw
        except ImportError as exc:
            raise ImportError("Install 'igor' package to load .ibw files: pip install igor") from exc
        wave = load_ibw(str(path))
        wdata = wave["wave"]
        header = wdata["wave_header"]
        raw = wdata["wData"]
        n_channels = raw.shape[2] if raw.ndim >= 3 else 1
        # Physical scaling: per-point spacing for each dimension, when available
        sfA = header.get("sfA", None)

        def _decode_unit(raw_unit):
            """Turn a NUL-padded unit field into text, defaulting to "m" when empty/absent."""
            if raw_unit is None:
                return "m"
            if isinstance(raw_unit, bytes):
                return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
            if isinstance(raw_unit, np.ndarray):
                return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
            return str(raw_unit).strip() or "m"

        dim_units_raw = header.get("dimUnits", None)
        data_units_raw = header.get("dataUnits", None)
        # dimUnits may be a per-dimension table; the first row is the x-axis unit.
        if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
            si_unit_xy = _decode_unit(dim_units_raw[0])
        elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
            si_unit_xy = _decode_unit(dim_units_raw[0])
        else:
            si_unit_xy = _decode_unit(dim_units_raw)
        si_unit_z = _decode_unit(data_units_raw)

        fields = []
        for ch_idx in range(n_channels):
            if raw.ndim >= 3:
                ch_data = raw[:, :, ch_idx]
            elif raw.ndim == 1:
                ch_data = raw.reshape(-1, 1)
            else:
                ch_data = raw
            # Transpose from (xres, yres) Igor order to (yres, xres) DataField order,
            # then flip vertically to match gwyddion
            data = np.flipud(ch_data.T).astype(np.float64)
            yres, xres = data.shape
            if sfA is not None and len(sfA) >= 2:
                xreal = abs(float(sfA[0]) * xres) or 1e-6
                yreal = abs(float(sfA[1]) * yres) or 1e-6
            else:
                hsA = header.get("hsA", 0.0)
                xreal = abs(float(hsA) * xres) or 1e-6
                yreal = xreal * (yres / xres) if xres else 1e-6
            fields.append(DataField(
                data=data, xreal=xreal, yreal=yreal,
                si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
            ))
        return fields

    # -- Image / array (uncalibrated) --------------------------------------
    def _load_image_or_array(self, path: Path, ext: str) -> DataField:
        """Load a raster image or NumPy array as one uncalibrated field.

        ``.npy`` loads the stored array. ``.npz`` uses the first array in the
        archive. Any other extension is opened with Pillow. A 3-D result is
        collapsed to grayscale by averaging over axis 2.

        Raises:
            ValueError: the ``.npz`` archive contains no arrays.
        """
        if ext == ".npy":
            arr = np.load(str(path))
        elif ext == ".npz":
            # NpzFile holds the archive open until closed. Indexing reads the
            # member fully into memory, so the array stays valid after the block.
            with np.load(str(path)) as npz:
                if not npz.files:
                    raise ValueError(f"No arrays found in {path.name}")
                arr = npz[npz.files[0]]
        else:
            arr = self._read_image_pixels(path)
        arr = np.asarray(arr, dtype=np.float64)
        gray = arr.mean(axis=2) if arr.ndim == 3 else arr
        return DataField(data=gray)

    @staticmethod
    def _read_image_pixels(path: Path) -> np.ndarray:
        """Read an image with Pillow into an array holding only intensity bands.

        - Palette images are expanded to RGBA, because palette pixels are table
          indices rather than intensities.
        - CMYK is converted to RGB.
        - Alpha bands are dropped so transparency does not bias the grayscale
          average.

        The file is closed before returning.
        """
        from PIL import Image
        with Image.open(str(path)) as img:
            if img.mode == "P":
                img = img.convert("RGBA")
            elif img.mode == "CMYK":
                img = img.convert("RGB")
            bands = img.getbands()
            arr = np.array(img)
        if arr.ndim == 3 and len(bands) == arr.shape[2]:
            # "A" is straight alpha, "a" is premultiplied alpha (RGBa / La modes).
            keep = [i for i, band in enumerate(bands) if band not in ("A", "a")]
            if keep and len(keep) < len(bands):
                arr = arr[..., keep]
        return arr
# ---------------------------------------------------------------------------
# LoadDemo
# ---------------------------------------------------------------------------
def _list_demo_files() -> list[str]:
    """List the bundled demo files offered by the Load Demo node.

    Returns the names (not full paths) of regular, non-hidden files directly
    inside DEMO_DIR whose lower-cased suffix is in _DEMO_EXTENSIONS, in
    sorted order. Returns an empty list when DEMO_DIR does not exist.
    """
    if not DEMO_DIR.exists():
        return []
    names = []
    for entry in DEMO_DIR.iterdir():
        if not entry.is_file():
            continue
        if entry.name.startswith("."):
            continue
        if entry.suffix.lower() not in _DEMO_EXTENSIONS:
            continue
        names.append(entry.name)
    names.sort()
    return names
@register_node(display_name="Load Demo File")
class LoadDemo:
    """Node that loads one of the files bundled in DEMO_DIR via LoadFile."""

    @classmethod
    def INPUT_TYPES(cls):
        """Offer the demo files as a dropdown, with a placeholder when none exist."""
        demo_files = _list_demo_files()
        if not demo_files:
            demo_files = ["(no demo files found)"]
        colormaps = list(COLORMAPS)
        return {"required": {"name": (demo_files,), "colormap": (colormaps,)}}

    RETURN_TYPES = ("DATA_FIELD",)
    RETURN_NAMES = ("field",)
    FUNCTION = "load"
    CATEGORY = "io"
    DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."

    def load(self, name: str, colormap: str = "viridis"):
        """Load DEMO_DIR / *name* with LoadFile and return its output tuple.

        Raises:
            FileNotFoundError: *name* does not exist inside DEMO_DIR.
        """
        demo_path = DEMO_DIR / name
        if demo_path.exists():
            # NOTE(review): LoadFile.load may broadcast its "Uncalibrated data"
            # warning through LoadFile's class-level hooks, which identify
            # LoadFile's current node rather than this one — confirm with the engine.
            return LoadFile().load(filename=str(demo_path), colormap=colormap)
        raise FileNotFoundError(f"Demo file not found: {name}")
# ---------------------------------------------------------------------------
# Coordinate
# ---------------------------------------------------------------------------
@register_node(display_name="Coordinate")
class Coordinate:
    """Provide a fractional (x, y) point for use with Cross Section or other nodes."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the x/y float widgets, each bounded to [0, 1] in the UI."""
        return {
            "required": {
                "x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("COORD",)
    RETURN_NAMES = ("point",)
    FUNCTION = "process"
    CATEGORY = "io"
    DESCRIPTION = "Output a fractional (x, y) coordinate pair in [0, 1]."

    @staticmethod
    def _to_unit_interval(value) -> float:
        """Convert *value* to float and clamp it into [0, 1]; NaN is passed through."""
        v = float(value)
        return min(max(v, 0.0), 1.0)

    def process(self, x: float, y: float) -> tuple:
        """Return ``((x, y),)`` with both components clamped into [0, 1].

        The widget min/max only constrain values typed into the UI. Values that
        arrive another way could fall outside the range DESCRIPTION promises,
        for example (presumably) a connected FLOAT output such as a Float
        Slider, or a hand-edited workflow. So the range is enforced here.
        """
        return ((self._to_unit_interval(x), self._to_unit_interval(y)),)
# ---------------------------------------------------------------------------
# RangeSlider
# ---------------------------------------------------------------------------
@register_node(display_name="Float Slider")
class RangeSlider:
    """Interactive float control node with min/max bounds and a slider value."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the bound widgets and a slider whose range follows them."""
        slider_spec = {
            "default": 0.5,
            "step": 0.01,
            "slider": True,
            "min_widget": "min_value",
            "max_widget": "max_value",
        }
        return {
            "required": {
                "min_value": ("FLOAT", {"default": 0.0, "step": 0.01}),
                "max_value": ("FLOAT", {"default": 1.0, "step": 0.01}),
                "value": ("FLOAT", slider_spec),
            }
        }

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("value",)
    FUNCTION = "process"
    CATEGORY = "io"
    DESCRIPTION = (
        "Interactive float slider. Set min and max bounds, then drag the slider to output a FLOAT value."
    )

    def process(self, min_value: float, max_value: float, value: float) -> tuple:
        """Return ``(v,)`` where v is *value* clamped to the bounds.

        The bounds may be given in either order. When the two bounds are equal,
        that bound is returned and *value* is never converted.
        """
        first, second = float(min_value), float(max_value)
        lo, hi = min(first, second), max(first, second)
        if lo != hi:
            return (float(np.clip(float(value), lo, hi)),)
        return (lo,)
# ---------------------------------------------------------------------------
# SaveImage
# ---------------------------------------------------------------------------
_MAX_SAVE_FIELDS = 8  # number of optional ``field_N`` input slots on the Save Layers node


@register_node(display_name="Save Layers")
class SaveImage:
    """Output node that writes one or more DATA_FIELD layers into a single file."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the path/format widgets plus ``_MAX_SAVE_FIELDS`` optional field inputs."""
        optional = {f"field_{i}": ("DATA_FIELD",) for i in range(_MAX_SAVE_FIELDS)}
        return {
            "required": {
                "filename": ("FILE_PICKER", {"default": ""}),
                "format": (["TIFF", "NPZ"],),
            },
            "optional": optional,
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    CATEGORY = "io"
    OUTPUT_NODE = True
    MANUAL_TRIGGER = True
    DESCRIPTION = (
        "Save one or more DATA_FIELD layers to a single file. "
        "Connect fields to the inputs — a new slot appears as each is filled. "
        "TIFF writes float32 multi-page; NPZ writes float64 named arrays. "
        "Click Save to write (does not auto-run)."
    )

    # Presumably installed by the execution engine, like LoadFile's hooks;
    # used to broadcast the save status.
    _broadcast_warning_fn = None
    _current_node_id = None

    def save(self, filename: str, format: str = "TIFF", **kwargs):
        """Write the connected ``field_N`` inputs, in slot order, to *filename*.

        Relative paths are resolved under OUTPUT_DIR, mirroring how loads
        resolve relative names under the input directory. The suffix is forced
        to match *format*: ``.tiff``/``.tif`` are kept for TIFF and ``.npz`` for
        NPZ; any other suffix is replaced. Any *format* other than "TIFF" is
        written as NPZ.

        Raises:
            ValueError: no field is connected, or no filename was given.
        """
        fields = [
            f for f in (kwargs.get(f"field_{i}") for i in range(_MAX_SAVE_FIELDS))
            if f is not None
        ]
        if not fields:
            raise ValueError("No fields connected — connect at least one DATA_FIELD input.")
        if not filename or not filename.strip():
            raise ValueError("No output path selected — use Browse to pick a location.")
        path = Path(filename)
        if not path.is_absolute():
            # Don't depend on the server process's CWD for relative names.
            path = OUTPUT_DIR / path
        is_tiff = format == "TIFF"
        # The first suffix is the one forced when the path's suffix is not accepted.
        accepted = (".tiff", ".tif") if is_tiff else (".npz",)
        if path.suffix.lower() not in accepted:
            path = path.with_suffix(accepted[0])
        # Ensure parent directory exists
        path.parent.mkdir(parents=True, exist_ok=True)
        if is_tiff:
            self._save_tiff(path, fields)
        else:
            self._save_npz(path, fields)
        self._send_warning(f"Saved {len(fields)} layer(s) to {path.name}")
        return ()

    def _save_tiff(self, path: Path, fields: list[DataField]):
        """Write each field as one float32 page of a multi-page TIFF."""
        from PIL import Image
        images = [Image.fromarray(f.data.astype(np.float32)) for f in fields]
        images[0].save(str(path), save_all=True, append_images=images[1:])

    def _save_npz(self, path: Path, fields: list[DataField]):
        """Write each field as float64 array ``layer_<i>`` in an uncompressed .npz."""
        arrays = {
            f"layer_{i}": np.asarray(f.data, dtype=np.float64)
            for i, f in enumerate(fields)
        }
        np.savez(str(path), **arrays)

    def _send_warning(self, message: str):
        """Forward *message* through the broadcast hook when both hook and node id are set."""
        fn = SaveImage._broadcast_warning_fn
        nid = SaveImage._current_node_id
        if fn and nid:
            fn(nid, message)