multichannel support + colormap inherit
This commit is contained in:
@@ -16,6 +16,9 @@ from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
|
||||
|
||||
COLORMAPS = ("viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain",
|
||||
"cividis", "magma", "copper", "afmhot")
|
||||
|
||||
@dataclass
|
||||
class DataField:
|
||||
data: np.ndarray # shape (yres, xres), dtype float64
|
||||
@@ -28,6 +31,7 @@ class DataField:
|
||||
si_unit_xy: str = "m"
|
||||
si_unit_z: str = "m"
|
||||
domain: str = "spatial" # "spatial" or "frequency"
|
||||
colormap: str = "viridis"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.data = np.asarray(self.data, dtype=np.float64)
|
||||
@@ -48,6 +52,7 @@ class DataField:
|
||||
si_unit_xy=self.si_unit_xy,
|
||||
si_unit_z=self.si_unit_z,
|
||||
domain=self.domain,
|
||||
colormap=self.colormap,
|
||||
)
|
||||
|
||||
def replace(self, **kwargs) -> "DataField":
|
||||
@@ -63,6 +68,7 @@ class DataField:
|
||||
"si_unit_xy": self.si_unit_xy,
|
||||
"si_unit_z": self.si_unit_z,
|
||||
"domain": self.domain,
|
||||
"colormap": self.colormap,
|
||||
}
|
||||
base.update(kwargs)
|
||||
return DataField(**base)
|
||||
|
||||
@@ -50,6 +50,7 @@ class ExecutionEngine:
|
||||
on_table: Callable[[str, list], None] | None = None,
|
||||
on_mesh: Callable[[str, dict], None] | None = None,
|
||||
on_overlay: Callable[[str, str], None] | None = None,
|
||||
on_warning: Callable[[str, str], None] | None = None,
|
||||
) -> dict[str, tuple]:
|
||||
"""
|
||||
Execute the workflow described by `prompt`.
|
||||
@@ -62,6 +63,7 @@ class ExecutionEngine:
|
||||
on_preview : called with (node_id, data_uri) when a display node runs
|
||||
on_table : called with (node_id, table_list) when PrintTable runs
|
||||
on_overlay : called with (node_id, data_uri) for interactive overlays
|
||||
on_warning : called with (node_id, message) for node warnings
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -71,7 +73,7 @@ class ExecutionEngine:
|
||||
node_outputs: dict[str, tuple] = {}
|
||||
|
||||
# Inject display callbacks before execution
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay)
|
||||
self._inject_display_callbacks(on_preview, on_table, on_mesh, on_overlay, on_warning)
|
||||
|
||||
for node_id in order:
|
||||
node_def = prompt[node_id]
|
||||
@@ -174,12 +176,13 @@ class ExecutionEngine:
|
||||
on_table: Callable | None,
|
||||
on_mesh: Callable | None = None,
|
||||
on_overlay: Callable | None = None,
|
||||
on_warning: Callable | None = None,
|
||||
) -> None:
|
||||
"""Wire up broadcast callbacks on display node classes."""
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection, LineCursors
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.io import SaveImage
|
||||
from backend.nodes.io import SaveImage, LoadFile
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
ThresholdMask._broadcast_fn = on_preview
|
||||
@@ -190,6 +193,7 @@ class ExecutionEngine:
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
CrossSection._broadcast_overlay_fn = on_overlay
|
||||
LineCursors._broadcast_overlay_fn = on_overlay
|
||||
LoadFile._broadcast_warning_fn = on_warning
|
||||
SaveImage._broadcast_preview = (
|
||||
(lambda data_uri: on_preview("save", data_uri)) if on_preview else None
|
||||
)
|
||||
@@ -199,8 +203,9 @@ class ExecutionEngine:
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D
|
||||
from backend.nodes.analysis import CrossSection, LineCursors
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.io import LoadFile
|
||||
if cls in (PreviewImage, PrintTable, View3D, CrossSection, LineCursors,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine):
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, LoadFile):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
def _auto_preview(
|
||||
@@ -232,7 +237,7 @@ class ExecutionEngine:
|
||||
value = result[slot]
|
||||
|
||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||
arr = datafield_to_uint8(value, "viridis")
|
||||
arr = datafield_to_uint8(value, value.colormap)
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return # one preview per node is enough
|
||||
|
||||
|
||||
@@ -326,6 +326,7 @@ class FFT2D:
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=z_unit,
|
||||
domain="frequency",
|
||||
colormap=field.colormap,
|
||||
)
|
||||
return (out_field,)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ before execution begins.
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, datafield_to_uint8, image_to_uint8, encode_preview
|
||||
from backend.data_types import DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview
|
||||
|
||||
|
||||
@register_node(display_name="Preview")
|
||||
@@ -18,7 +18,7 @@ class PreviewImage:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"colormap": (["gray", "hot", "jet", "viridis", "plasma", "inferno"],),
|
||||
"colormap": (["auto"] + list(COLORMAPS),),
|
||||
},
|
||||
"optional": {
|
||||
"image": ("IMAGE",),
|
||||
@@ -36,6 +36,10 @@ class PreviewImage:
|
||||
_current_node_id: str = ""
|
||||
|
||||
def preview(self, colormap: str, image: np.ndarray | None = None, field=None) -> tuple:
|
||||
# Resolve "auto" — use field's colormap if available, else fall back to gray
|
||||
if colormap == "auto":
|
||||
colormap = field.colormap if field is not None else "gray"
|
||||
|
||||
# Prefer field if both are connected; accept whichever is provided
|
||||
if field is not None:
|
||||
arr_u8 = datafield_to_uint8(field, colormap)
|
||||
@@ -73,7 +77,7 @@ class View3D:
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"colormap": (["viridis", "gray", "hot", "jet", "plasma", "inferno", "terrain"],),
|
||||
"colormap": (["auto"] + list(COLORMAPS),),
|
||||
"z_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.1}),
|
||||
"resolution": ("INT", {"default": 128, "min": 32, "max": 512, "step": 16}),
|
||||
}
|
||||
@@ -114,7 +118,8 @@ class View3D:
|
||||
else:
|
||||
z_norm = np.zeros_like(z)
|
||||
|
||||
cmap = cm.get_cmap(colormap)
|
||||
cmap_name = field.colormap if colormap == "auto" else colormap
|
||||
cmap = cm.get_cmap(cmap_name)
|
||||
rgba = cmap(z_norm) # (ny, nx, 4) float [0,1]
|
||||
colors_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, encode_preview, image_to_uint8
|
||||
from backend.data_types import DataField, COLORMAPS, encode_preview, image_to_uint8
|
||||
from backend.runtime_paths import demo_dir, input_dir, output_dir
|
||||
|
||||
# Resolved at server startup so nodes know where to look
|
||||
@@ -19,112 +19,293 @@ OUTPUT_DIR = output_dir()
|
||||
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
|
||||
".gwy", ".sxm", ".ibw"}
|
||||
|
||||
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
|
||||
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
|
||||
_ARRAY_EXTENSIONS = {".npy", ".npz"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadImage
|
||||
# Channel listing helper (used by the /channels endpoint)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load Image")
|
||||
class LoadImage:
|
||||
def _resolve_path(filepath: str) -> Path:
|
||||
path = Path(filepath)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
# Try input dir first, then demo dir
|
||||
candidate = INPUT_DIR / filepath
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
candidate = DEMO_DIR / filepath
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
# Fall back to input dir (will trigger FileNotFoundError later)
|
||||
return INPUT_DIR / filepath
|
||||
|
||||
|
||||
def list_channels(filepath: str) -> list[dict]:
|
||||
"""Return available channel info for a file.
|
||||
|
||||
Returns a list of {"name": str, "type": "DATA_FIELD"} dicts.
|
||||
For SPM formats this inspects the file header.
|
||||
For images / arrays, returns a single unnamed channel.
|
||||
"""
|
||||
path = _resolve_path(filepath)
|
||||
if not path.exists():
|
||||
return [{"name": "field", "type": "DATA_FIELD"}]
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
if ext == ".gwy":
|
||||
try:
|
||||
import gwyfile
|
||||
obj = gwyfile.load(str(path))
|
||||
channels = gwyfile.util.get_datafields(obj)
|
||||
if channels:
|
||||
return [{"name": k, "type": "DATA_FIELD"} for k in channels]
|
||||
except Exception:
|
||||
pass
|
||||
return [{"name": "field", "type": "DATA_FIELD"}]
|
||||
|
||||
if ext == ".sxm":
|
||||
try:
|
||||
import nanonispy as nap
|
||||
sxm = nap.read.Scan(str(path))
|
||||
if sxm.signals:
|
||||
return [{"name": k, "type": "DATA_FIELD"} for k in sxm.signals]
|
||||
except Exception:
|
||||
pass
|
||||
return [{"name": "field", "type": "DATA_FIELD"}]
|
||||
|
||||
if ext == ".ibw":
|
||||
try:
|
||||
from igor.binarywave import load as load_ibw
|
||||
wave = load_ibw(str(path))
|
||||
raw = wave["wave"]["wData"]
|
||||
labels = wave["wave"].get("labels", None)
|
||||
if raw.ndim >= 3 and labels:
|
||||
dim_idx = min(2, len(labels) - 1)
|
||||
if dim_idx >= 0 and labels[dim_idx]:
|
||||
decoded = []
|
||||
for lbl in labels[dim_idx]:
|
||||
if lbl:
|
||||
name = lbl.split(b"\x00")[0].decode("ascii", errors="replace").strip()
|
||||
if name:
|
||||
decoded.append(name)
|
||||
if decoded:
|
||||
return [{"name": n, "type": "DATA_FIELD"} for n in decoded]
|
||||
# Multi-channel without labels — use numeric names
|
||||
if raw.ndim >= 3 and raw.shape[2] > 1:
|
||||
return [{"name": f"ch{i}", "type": "DATA_FIELD"} for i in range(raw.shape[2])]
|
||||
except Exception:
|
||||
pass
|
||||
return [{"name": "field", "type": "DATA_FIELD"}]
|
||||
|
||||
# Image or array — single channel
|
||||
return [{"name": "field", "type": "DATA_FIELD"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadFile (unified loader — replaces LoadImage + LoadSPM)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load File")
|
||||
class LoadFile:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
"colormap": (list(COLORMAPS),),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
|
||||
RETURN_NAMES = ("image", "field")
|
||||
# Default outputs — overridden dynamically by the frontend for multi-channel files
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("field",)
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load a PNG, TIFF, JPG image or .npy/.npz array from the input folder. Outputs both IMAGE and DATA_FIELD."
|
||||
|
||||
def load(self, filename: str):
|
||||
# Accept absolute paths or filenames relative to input/
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = INPUT_DIR / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
if ext in (".npy",):
|
||||
arr = np.load(str(path)).astype(np.float64)
|
||||
elif ext in (".npz",):
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
arr = npz[key].astype(np.float64)
|
||||
else:
|
||||
from PIL import Image
|
||||
img = Image.open(str(path))
|
||||
arr = np.array(img)
|
||||
if arr.dtype != np.uint8:
|
||||
arr = arr.astype(np.float64)
|
||||
|
||||
# Convert to float64 grayscale for the DATA_FIELD output
|
||||
if arr.ndim == 3:
|
||||
gray = np.mean(arr.astype(np.float64), axis=2)
|
||||
else:
|
||||
gray = arr.astype(np.float64)
|
||||
|
||||
field = DataField(data=gray)
|
||||
return (arr, field)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadDemo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _list_demo_files() -> list[str]:
|
||||
"""Return sorted list of demo filenames available in the demo/ directory."""
|
||||
if not DEMO_DIR.exists():
|
||||
return []
|
||||
return sorted(
|
||||
f.name for f in DEMO_DIR.iterdir()
|
||||
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
|
||||
DESCRIPTION = (
|
||||
"Load any supported file. "
|
||||
"SPM formats (.gwy, .sxm, .ibw) provide calibrated dimensions; "
|
||||
"each channel gets its own output. "
|
||||
"Images (.png, .tiff, .jpg) and arrays (.npy, .npz) are loaded as uncalibrated fields."
|
||||
)
|
||||
|
||||
# Set by execution engine for warning broadcast
|
||||
_broadcast_warning_fn = None
|
||||
_current_node_id = None
|
||||
|
||||
@register_node(display_name="Load Demo Image")
|
||||
class LoadDemo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
choices = _list_demo_files() or ["(no demo images found)"]
|
||||
return {
|
||||
"required": {
|
||||
"name": (choices,),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "DATA_FIELD")
|
||||
RETURN_NAMES = ("image", "field")
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load a bundled demo image so you can try the app without providing your own data."
|
||||
|
||||
def load(self, name: str):
|
||||
path = DEMO_DIR / name
|
||||
def load(self, filename: str, colormap: str = "viridis"):
|
||||
if not filename or not filename.strip():
|
||||
raise ValueError("No file selected — use Browse to pick a file.")
|
||||
path = _resolve_path(filename)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Demo image not found: {name}")
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
if path.is_dir():
|
||||
raise IsADirectoryError(f"Expected a file, got a directory: {path}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
# SPM formats → delegate to LoadSPM-style loading, return as IMAGE + DATA_FIELD
|
||||
if ext == ".gwy":
|
||||
field = LoadSPM()._load_gwy(path, "Z")
|
||||
arr = field.data
|
||||
return (arr, field)
|
||||
elif ext == ".sxm":
|
||||
field = LoadSPM()._load_sxm(path, "Z")
|
||||
arr = field.data
|
||||
return (arr, field)
|
||||
elif ext == ".ibw":
|
||||
field = LoadSPM()._load_ibw(path)
|
||||
arr = field.data
|
||||
return (arr, field)
|
||||
if ext in _SPM_EXTENSIONS:
|
||||
fields = self._load_spm_all(path, ext)
|
||||
for f in fields:
|
||||
f.colormap = colormap
|
||||
return tuple(fields)
|
||||
|
||||
# npy / npz
|
||||
# Image or array — uncalibrated, single output
|
||||
field = self._load_image_or_array(path, ext)
|
||||
field.colormap = colormap
|
||||
self._send_warning("Uncalibrated data — no physical dimensions.")
|
||||
return (field,)
|
||||
|
||||
def _send_warning(self, message: str):
|
||||
fn = LoadFile._broadcast_warning_fn
|
||||
nid = LoadFile._current_node_id
|
||||
if fn and nid:
|
||||
fn(nid, message)
|
||||
|
||||
# -- SPM: load all channels ---------------------------------------------
|
||||
|
||||
def _load_spm_all(self, path: Path, ext: str) -> list[DataField]:
|
||||
if ext == ".gwy":
|
||||
return self._load_gwy_all(path)
|
||||
elif ext == ".sxm":
|
||||
return self._load_sxm_all(path)
|
||||
elif ext == ".ibw":
|
||||
return self._load_ibw_all(path)
|
||||
else:
|
||||
raise ValueError(f"Unsupported SPM format: {ext}")
|
||||
|
||||
# -- GWY ----------------------------------------------------------------
|
||||
|
||||
def _load_gwy_all(self, path: Path) -> list[DataField]:
|
||||
try:
|
||||
import gwyfile
|
||||
except ImportError:
|
||||
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
|
||||
|
||||
obj = gwyfile.load(str(path))
|
||||
channels = gwyfile.util.get_datafields(obj)
|
||||
if not channels:
|
||||
raise ValueError(f"No data channels found in {path.name}")
|
||||
|
||||
fields = []
|
||||
for ch in channels.values():
|
||||
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
|
||||
fields.append(DataField(
|
||||
data=data,
|
||||
xreal=float(ch.xreal),
|
||||
yreal=float(ch.yreal),
|
||||
xoff=float(getattr(ch, "xoff", 0.0)),
|
||||
yoff=float(getattr(ch, "yoff", 0.0)),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
))
|
||||
return fields
|
||||
|
||||
# -- SXM ----------------------------------------------------------------
|
||||
|
||||
def _load_sxm_all(self, path: Path) -> list[DataField]:
|
||||
try:
|
||||
import nanonispy as nap
|
||||
except ImportError:
|
||||
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
|
||||
|
||||
sxm = nap.read.Scan(str(path))
|
||||
signals = sxm.signals
|
||||
if not signals:
|
||||
raise ValueError(f"No signals found in {path.name}")
|
||||
|
||||
header = sxm.header
|
||||
scan_range = header.get("scan_range", [1e-6, 1e-6])
|
||||
|
||||
fields = []
|
||||
for sig in signals.values():
|
||||
data = sig.get("forward", list(sig.values())[0])
|
||||
data = np.asarray(data, dtype=np.float64)
|
||||
if data.ndim != 2:
|
||||
data = data.reshape(data.shape[-2], data.shape[-1])
|
||||
fields.append(DataField(
|
||||
data=data,
|
||||
xreal=float(scan_range[0]),
|
||||
yreal=float(scan_range[1]),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
))
|
||||
return fields
|
||||
|
||||
# -- IBW ----------------------------------------------------------------
|
||||
|
||||
def _load_ibw_all(self, path: Path) -> list[DataField]:
|
||||
try:
|
||||
from igor.binarywave import load as load_ibw
|
||||
except ImportError:
|
||||
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
|
||||
|
||||
wave = load_ibw(str(path))
|
||||
wdata = wave["wave"]
|
||||
header = wdata["wave_header"]
|
||||
raw = wdata["wData"]
|
||||
|
||||
n_channels = raw.shape[2] if raw.ndim >= 3 else 1
|
||||
|
||||
# Physical scaling
|
||||
sfA = header.get("sfA", None)
|
||||
|
||||
def _decode_unit(raw_unit):
|
||||
if raw_unit is None:
|
||||
return "m"
|
||||
if isinstance(raw_unit, bytes):
|
||||
return raw_unit.split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
|
||||
if isinstance(raw_unit, np.ndarray):
|
||||
return bytes(raw_unit).split(b"\x00", 1)[0].decode("ascii", errors="replace").strip() or "m"
|
||||
return str(raw_unit).strip() or "m"
|
||||
|
||||
dim_units_raw = header.get("dimUnits", None)
|
||||
data_units_raw = header.get("dataUnits", None)
|
||||
|
||||
if isinstance(dim_units_raw, np.ndarray) and dim_units_raw.ndim == 2:
|
||||
si_unit_xy = _decode_unit(dim_units_raw[0])
|
||||
elif isinstance(dim_units_raw, (list, np.ndarray)) and len(dim_units_raw) > 0:
|
||||
si_unit_xy = _decode_unit(dim_units_raw[0])
|
||||
else:
|
||||
si_unit_xy = _decode_unit(dim_units_raw)
|
||||
|
||||
si_unit_z = _decode_unit(data_units_raw)
|
||||
|
||||
fields = []
|
||||
for ch_idx in range(n_channels):
|
||||
if raw.ndim >= 3:
|
||||
ch_data = raw[:, :, ch_idx]
|
||||
elif raw.ndim == 1:
|
||||
ch_data = raw.reshape(-1, 1)
|
||||
else:
|
||||
ch_data = raw
|
||||
|
||||
# Transpose from (xres, yres) Igor order to (yres, xres) DataField order,
|
||||
# then flip vertically to match gwyddion
|
||||
data = np.flipud(ch_data.T).astype(np.float64)
|
||||
yres, xres = data.shape
|
||||
|
||||
if sfA is not None and len(sfA) >= 2:
|
||||
xreal = abs(float(sfA[0]) * xres) or 1e-6
|
||||
yreal = abs(float(sfA[1]) * yres) or 1e-6
|
||||
else:
|
||||
hsA = header.get("hsA", 0.0)
|
||||
xreal = abs(float(hsA) * xres) or 1e-6
|
||||
yreal = xreal * (yres / xres) if xres else 1e-6
|
||||
|
||||
fields.append(DataField(
|
||||
data=data, xreal=xreal, yreal=yreal,
|
||||
si_unit_xy=si_unit_xy, si_unit_z=si_unit_z,
|
||||
))
|
||||
|
||||
return fields
|
||||
|
||||
# -- Image / array (uncalibrated) --------------------------------------
|
||||
|
||||
def _load_image_or_array(self, path: Path, ext: str) -> DataField:
|
||||
if ext == ".npy":
|
||||
arr = np.load(str(path)).astype(np.float64)
|
||||
elif ext == ".npz":
|
||||
@@ -143,22 +324,32 @@ class LoadDemo:
|
||||
else:
|
||||
gray = arr.astype(np.float64)
|
||||
|
||||
field = DataField(data=gray)
|
||||
return (arr, field)
|
||||
return DataField(data=gray)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoadSPM
|
||||
# LoadDemo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Load SPM File")
|
||||
class LoadSPM:
|
||||
def _list_demo_files() -> list[str]:
|
||||
"""Return sorted list of demo filenames available in the demo/ directory."""
|
||||
if not DEMO_DIR.exists():
|
||||
return []
|
||||
return sorted(
|
||||
f.name for f in DEMO_DIR.iterdir()
|
||||
if f.is_file() and not f.name.startswith(".") and f.suffix.lower() in _DEMO_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
@register_node(display_name="Load Demo File")
|
||||
class LoadDemo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
choices = _list_demo_files() or ["(no demo files found)"]
|
||||
return {
|
||||
"required": {
|
||||
"filename": ("FILE_PICKER", {"default": ""}),
|
||||
"channel": ("STRING", {"default": "Z"}),
|
||||
"name": (choices,),
|
||||
"colormap": (list(COLORMAPS),),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,111 +357,15 @@ class LoadSPM:
|
||||
RETURN_NAMES = ("field",)
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "io"
|
||||
DESCRIPTION = "Load SPM/AFM data from .gwy, .sxm, or .ibw files into a calibrated DataField."
|
||||
DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data."
|
||||
|
||||
def load(self, filename: str, channel: str = "Z"):
|
||||
path = Path(filename)
|
||||
if not path.is_absolute():
|
||||
path = INPUT_DIR / filename
|
||||
def load(self, name: str, colormap: str = "viridis"):
|
||||
path = DEMO_DIR / name
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
raise FileNotFoundError(f"Demo file not found: {name}")
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
if ext == ".gwy":
|
||||
return (self._load_gwy(path, channel),)
|
||||
elif ext == ".sxm":
|
||||
return (self._load_sxm(path, channel),)
|
||||
elif ext in (".ibw",):
|
||||
return (self._load_ibw(path),)
|
||||
elif ext in (".npy",):
|
||||
data = np.load(str(path)).astype(np.float64)
|
||||
return (DataField(data=data),)
|
||||
elif ext in (".npz",):
|
||||
npz = np.load(str(path))
|
||||
key = list(npz.files)[0]
|
||||
return (DataField(data=npz[key].astype(np.float64)),)
|
||||
else:
|
||||
raise ValueError(f"Unsupported SPM format: {ext}. Supported: .gwy, .sxm, .ibw, .npy, .npz")
|
||||
|
||||
def _load_gwy(self, path: Path, channel: str) -> DataField:
|
||||
try:
|
||||
import gwyfile
|
||||
except ImportError:
|
||||
raise ImportError("Install 'gwyfile' package to load .gwy files: pip install gwyfile")
|
||||
|
||||
obj = gwyfile.load(str(path))
|
||||
channels = gwyfile.util.get_datafields(obj)
|
||||
if not channels:
|
||||
raise ValueError(f"No data channels found in {path.name}")
|
||||
|
||||
# Try requested channel name, fall back to first available
|
||||
ch = None
|
||||
for key, df in channels.items():
|
||||
if channel.lower() in key.lower():
|
||||
ch = df
|
||||
break
|
||||
if ch is None:
|
||||
ch = next(iter(channels.values()))
|
||||
|
||||
data = np.array(ch.data, dtype=np.float64).reshape(ch.yres, ch.xres)
|
||||
return DataField(
|
||||
data=data,
|
||||
xreal=float(ch.xreal),
|
||||
yreal=float(ch.yreal),
|
||||
xoff=float(getattr(ch, "xoff", 0.0)),
|
||||
yoff=float(getattr(ch, "yoff", 0.0)),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
def _load_sxm(self, path: Path, channel: str) -> DataField:
|
||||
try:
|
||||
import nanonispy as nap
|
||||
except ImportError:
|
||||
raise ImportError("Install 'nanonispy' package to load .sxm files: pip install nanonispy")
|
||||
|
||||
sxm = nap.read.Scan(str(path))
|
||||
signals = sxm.signals
|
||||
|
||||
# Pick channel
|
||||
ch_key = None
|
||||
for key in signals:
|
||||
if channel.upper() in key.upper():
|
||||
ch_key = key
|
||||
break
|
||||
if ch_key is None:
|
||||
ch_key = next(iter(signals))
|
||||
|
||||
data = signals[ch_key].get("forward", list(signals[ch_key].values())[0])
|
||||
data = np.asarray(data, dtype=np.float64)
|
||||
if data.ndim != 2:
|
||||
data = data.reshape(data.shape[-2], data.shape[-1])
|
||||
|
||||
header = sxm.header
|
||||
scan_range = header.get("scan_range", [1e-6, 1e-6])
|
||||
return DataField(
|
||||
data=data,
|
||||
xreal=float(scan_range[0]),
|
||||
yreal=float(scan_range[1]),
|
||||
si_unit_xy="m",
|
||||
si_unit_z="m",
|
||||
)
|
||||
|
||||
def _load_ibw(self, path: Path) -> DataField:
|
||||
try:
|
||||
import igor.igorpy as igorpy
|
||||
wave = igorpy.load(str(path))
|
||||
data = wave.wave["wData"].squeeze().astype(np.float64)
|
||||
except ImportError:
|
||||
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
|
||||
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(1, -1)
|
||||
elif data.ndim != 2:
|
||||
data = data[:, :, 0] if data.ndim == 3 else data.reshape(data.shape[0], -1)
|
||||
|
||||
return DataField(data=data, si_unit_xy="m", si_unit_z="m")
|
||||
loader = LoadFile()
|
||||
return loader.load(filename=str(path), colormap=colormap)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -101,6 +101,9 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
def on_overlay(node_id: str, overlay_data) -> None:
|
||||
broadcast({"type": "overlay", "data": {"node_id": node_id, "overlay": overlay_data}})
|
||||
|
||||
def on_warning(node_id: str, message: str) -> None:
|
||||
broadcast({"type": "node_warning", "data": {"node_id": node_id, "message": message}})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -193,6 +196,18 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
},
|
||||
)
|
||||
|
||||
async def get_channels(request: web.Request) -> web.Response:
|
||||
"""Return available channels for a given file path."""
|
||||
from backend.nodes.io import list_channels
|
||||
filepath = request.query.get("file", "")
|
||||
if not filepath:
|
||||
return web.Response(
|
||||
text=_dumps([{"name": "field", "type": "DATA_FIELD"}]),
|
||||
content_type="application/json",
|
||||
)
|
||||
channels = await loop.run_in_executor(None, list_channels, filepath)
|
||||
return web.Response(text=_dumps(channels), content_type="application/json")
|
||||
|
||||
async def submit_prompt(request: web.Request) -> web.Response:
|
||||
body = await request.json()
|
||||
prompt = body.get("prompt")
|
||||
@@ -218,6 +233,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
on_table=on_table,
|
||||
on_mesh=on_mesh,
|
||||
on_overlay=on_overlay,
|
||||
on_warning=on_warning,
|
||||
),
|
||||
)
|
||||
broadcast({"type": "execution_complete", "data": {"prompt_id": prompt_id}})
|
||||
@@ -262,6 +278,7 @@ def create_app(loop: asyncio.AbstractEventLoop) -> web.Application:
|
||||
app.router.add_get("/browse", browse_dir)
|
||||
app.router.add_post("/upload", upload_file)
|
||||
app.router.add_post("/download", download_file)
|
||||
app.router.add_get("/channels", get_channels)
|
||||
app.router.add_post("/prompt", submit_prompt)
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
|
||||
|
||||
@@ -436,6 +436,9 @@ function Flow() {
|
||||
case 'overlay':
|
||||
updateNodeData(msg.data.node_id, { overlay: msg.data.overlay });
|
||||
break;
|
||||
case 'node_warning':
|
||||
updateNodeData(msg.data.node_id, { warning: msg.data.message });
|
||||
break;
|
||||
}
|
||||
});
|
||||
api.initWS();
|
||||
@@ -500,9 +503,36 @@ function Flow() {
|
||||
data: {
|
||||
...n.data,
|
||||
widgetValues: { ...n.data.widgetValues, [name]: value },
|
||||
// Clear warning when user changes a value
|
||||
warning: null,
|
||||
},
|
||||
};
|
||||
}));
|
||||
|
||||
// If this is a filename/name change on a LoadFile/LoadDemo node, fetch channels
|
||||
if ((name === 'filename' || name === 'name') && value) {
|
||||
const node = reactFlow.getNode(nodeId);
|
||||
if (node && (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo')) {
|
||||
api.getChannels(value).then((channels) => {
|
||||
setNodes((prev) => prev.map((n) => {
|
||||
if (n.id !== nodeId) return n;
|
||||
return {
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
definition: {
|
||||
...n.data.definition,
|
||||
output: channels.map((c) => c.type),
|
||||
output_name: channels.map((c) => c.name),
|
||||
},
|
||||
},
|
||||
};
|
||||
}));
|
||||
reactFlow.updateNodeInternals(nodeId);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
scheduleAutoRun();
|
||||
}, [setNodes]); // scheduleAutoRun is stable (no deps)
|
||||
|
||||
@@ -568,6 +598,27 @@ function Flow() {
|
||||
|
||||
setNodes((ns) => [...ns, newNode]);
|
||||
|
||||
// For LoadFile/LoadDemo, auto-fetch channels for the default value
|
||||
if (className === 'LoadDemo' && widgetValues.name) {
|
||||
api.getChannels(widgetValues.name).then((channels) => {
|
||||
setNodes((prev) => prev.map((n) => {
|
||||
if (n.id !== newNodeId) return n;
|
||||
return {
|
||||
...n,
|
||||
data: {
|
||||
...n.data,
|
||||
definition: {
|
||||
...n.data.definition,
|
||||
output: channels.map((c) => c.type),
|
||||
output_name: channels.map((c) => c.name),
|
||||
},
|
||||
},
|
||||
};
|
||||
}));
|
||||
reactFlow.updateNodeInternals(newNodeId);
|
||||
});
|
||||
}
|
||||
|
||||
// Auto-connect if this was triggered by dropping a connection on blank space
|
||||
if (contextMenu.pendingHandleId) {
|
||||
const filterType = contextMenu.filterType;
|
||||
|
||||
@@ -211,6 +211,11 @@ function CustomNode({ id, data }) {
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Warning notification */}
|
||||
{data.warning && (
|
||||
<div className="node-warning">{data.warning}</div>
|
||||
)}
|
||||
|
||||
{/* Widget rows */}
|
||||
{widgets.map((w) => (
|
||||
<div className="widget-row" key={w.name}>
|
||||
|
||||
@@ -34,6 +34,12 @@ export async function uploadFile(file) {
|
||||
return r.json();
|
||||
}
|
||||
|
||||
export async function getChannels(filepath) {
|
||||
const r = await fetch(`/channels?file=${encodeURIComponent(filepath)}`);
|
||||
if (!r.ok) return [{ name: 'field', type: 'DATA_FIELD' }];
|
||||
return r.json();
|
||||
}
|
||||
|
||||
export async function runPrompt(prompt) {
|
||||
const r = await fetch('/prompt', {
|
||||
method: 'POST',
|
||||
|
||||
@@ -82,10 +82,7 @@ html, body, #root {
|
||||
padding: 4px 10px;
|
||||
border-radius: 4px;
|
||||
font-size: 11px;
|
||||
max-width: 400px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 60%;
|
||||
flex-shrink: 1;
|
||||
}
|
||||
.status-bar.info { color: #90caf9; }
|
||||
@@ -155,6 +152,15 @@ html, body, #root {
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.node-warning {
|
||||
padding: 3px 10px;
|
||||
font-size: 10px;
|
||||
color: #fbbf24;
|
||||
background: rgba(251, 191, 36, 0.1);
|
||||
border-top: 1px solid rgba(251, 191, 36, 0.2);
|
||||
border-bottom: 1px solid rgba(251, 191, 36, 0.2);
|
||||
}
|
||||
|
||||
/* ── I/O rows ──────────────────────────────────────────────────────── */
|
||||
.io-row {
|
||||
display: flex;
|
||||
|
||||
@@ -9,6 +9,7 @@ export default defineConfig({
|
||||
'/nodes': 'http://127.0.0.1:8188',
|
||||
'/files': 'http://127.0.0.1:8188',
|
||||
'/browse': 'http://127.0.0.1:8188',
|
||||
'/channels': 'http://127.0.0.1:8188',
|
||||
'/upload': 'http://127.0.0.1:8188',
|
||||
'/download': 'http://127.0.0.1:8188',
|
||||
'/prompt': 'http://127.0.0.1:8188',
|
||||
|
||||
@@ -523,41 +523,42 @@ def test_particle_analysis():
|
||||
# I/O
|
||||
# =========================================================================
|
||||
|
||||
def test_load_image():
|
||||
print("=== Test: LoadImage ===")
|
||||
from backend.nodes.io import LoadImage
|
||||
def test_load_file():
|
||||
print("=== Test: LoadFile ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
from PIL import Image
|
||||
node = LoadImage()
|
||||
node = LoadFile()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Test loading a grayscale PNG
|
||||
# Test loading a grayscale PNG → single DataField output
|
||||
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
|
||||
img = Image.fromarray(arr, mode="L")
|
||||
path = os.path.join(tmpdir, "test_gray.png")
|
||||
img.save(path)
|
||||
|
||||
image, field = node.load(filename=path)
|
||||
assert image.shape == (48, 64)
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
field = result[0]
|
||||
assert field.data.shape == (48, 64)
|
||||
assert field.data.dtype == np.float64
|
||||
|
||||
# Test loading an RGB PNG (should average to grayscale for field)
|
||||
# Test loading an RGB PNG (should average to grayscale)
|
||||
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
|
||||
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
|
||||
path_rgb = os.path.join(tmpdir, "test_rgb.png")
|
||||
img_rgb.save(path_rgb)
|
||||
|
||||
image_rgb, field_rgb = node.load(filename=path_rgb)
|
||||
assert image_rgb.shape == (32, 32, 3)
|
||||
assert field_rgb.data.shape == (32, 32)
|
||||
result_rgb = node.load(filename=path_rgb)
|
||||
assert len(result_rgb) == 1
|
||||
assert result_rgb[0].data.shape == (32, 32)
|
||||
|
||||
# Test loading a .npy file
|
||||
data_npy = np.random.default_rng(3).standard_normal((50, 60))
|
||||
path_npy = os.path.join(tmpdir, "test.npy")
|
||||
np.save(path_npy, data_npy)
|
||||
|
||||
image_npy, field_npy = node.load(filename=path_npy)
|
||||
assert np.allclose(field_npy.data, data_npy)
|
||||
result_npy = node.load(filename=path_npy)
|
||||
assert np.allclose(result_npy[0].data, data_npy)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
@@ -641,6 +642,464 @@ def test_print_table():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — IBW multi-channel loading
|
||||
# =========================================================================
|
||||
|
||||
def test_load_file_ibw():
|
||||
print("=== Test: LoadFile IBW multi-channel ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
|
||||
ibw_path = os.path.abspath(ibw_path)
|
||||
if not os.path.exists(ibw_path):
|
||||
print(" SKIP (demo IBW file not found)\n")
|
||||
return
|
||||
|
||||
result = node.load(filename=ibw_path)
|
||||
|
||||
# BR_New20012.ibw has 4 channels
|
||||
assert len(result) == 4, f"Expected 4 channels, got {len(result)}"
|
||||
|
||||
for i, field in enumerate(result):
|
||||
assert isinstance(field, DataField), f"Channel {i} is not a DataField"
|
||||
assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}"
|
||||
assert field.data.dtype == np.float64
|
||||
# Physical dimensions should be populated (not default 1e-6)
|
||||
assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}"
|
||||
assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}"
|
||||
assert field.si_unit_xy == "m"
|
||||
assert field.si_unit_z == "m"
|
||||
|
||||
# All channels should share the same physical dimensions
|
||||
assert result[0].xreal == result[1].xreal
|
||||
assert result[0].yreal == result[1].yreal
|
||||
|
||||
# Different channels should have different data
|
||||
assert not np.array_equal(result[0].data, result[1].data)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_npz():
|
||||
print("=== Test: LoadFile .npz ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
data = np.random.default_rng(99).standard_normal((30, 40))
|
||||
path = os.path.join(tmpdir, "test.npz")
|
||||
np.savez(path, my_array=data)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
assert np.allclose(result[0].data, data)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_not_found():
|
||||
print("=== Test: LoadFile not found ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
try:
|
||||
node.load(filename="/nonexistent/path/file.png")
|
||||
assert False, "Should have raised FileNotFoundError"
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_unsupported():
|
||||
print("=== Test: LoadFile unsupported format ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "test.xyz")
|
||||
with open(path, "w") as f:
|
||||
f.write("hello")
|
||||
try:
|
||||
node.load(filename=path)
|
||||
assert False, "Should have raised an error for .xyz"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_warning():
|
||||
print("=== Test: LoadFile warning for uncalibrated data ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
from PIL import Image
|
||||
|
||||
node = LoadFile()
|
||||
warnings = []
|
||||
LoadFile._broadcast_warning_fn = lambda nid, msg: warnings.append(msg)
|
||||
LoadFile._current_node_id = "test"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8)
|
||||
img = Image.fromarray(arr)
|
||||
path = os.path.join(tmpdir, "test.png")
|
||||
img.save(path)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
assert len(warnings) == 1
|
||||
assert "Uncalibrated" in warnings[0]
|
||||
|
||||
LoadFile._broadcast_warning_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — list_channels helper
|
||||
# =========================================================================
|
||||
|
||||
def test_list_channels():
|
||||
print("=== Test: list_channels ===")
|
||||
from backend.nodes.io import list_channels
|
||||
|
||||
# Non-existent file → default
|
||||
ch = list_channels("/nonexistent/file.ibw")
|
||||
assert len(ch) == 1
|
||||
assert ch[0]["name"] == "field"
|
||||
|
||||
# IBW with channels
|
||||
ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw"))
|
||||
if os.path.exists(ibw_path):
|
||||
ch = list_channels(ibw_path)
|
||||
assert len(ch) == 4
|
||||
names = [c["name"] for c in ch]
|
||||
assert "HeightRetrace" in names
|
||||
assert "AmplitudeRetrace" in names
|
||||
assert all(c["type"] == "DATA_FIELD" for c in ch)
|
||||
|
||||
# Plain image → single default channel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
from PIL import Image
|
||||
img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
|
||||
path = os.path.join(tmpdir, "test.png")
|
||||
img.save(path)
|
||||
|
||||
ch = list_channels(path)
|
||||
assert len(ch) == 1
|
||||
assert ch[0]["name"] == "field"
|
||||
|
||||
# .npy → single default channel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "test.npy")
|
||||
np.save(path, np.zeros((4, 4)))
|
||||
|
||||
ch = list_channels(path)
|
||||
assert len(ch) == 1
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — LoadDemo
|
||||
# =========================================================================
|
||||
|
||||
def test_load_demo():
|
||||
print("=== Test: LoadDemo ===")
|
||||
from backend.nodes.io import LoadDemo
|
||||
|
||||
node = LoadDemo()
|
||||
|
||||
# Should be able to load a demo file by name
|
||||
result = node.load(name="nanoparticles.npy")
|
||||
assert len(result) >= 1
|
||||
assert isinstance(result[0], DataField)
|
||||
assert result[0].data.ndim == 2
|
||||
|
||||
# IBW demo should return multiple channels
|
||||
result_ibw = node.load(name="whiskers.ibw")
|
||||
assert len(result_ibw) == 4
|
||||
for field in result_ibw:
|
||||
assert isinstance(field, DataField)
|
||||
|
||||
# Non-existent demo should raise
|
||||
try:
|
||||
node.load(name="nonexistent_file.png")
|
||||
assert False, "Should have raised FileNotFoundError"
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — Coordinate
|
||||
# =========================================================================
|
||||
|
||||
def test_coordinate():
|
||||
print("=== Test: Coordinate ===")
|
||||
from backend.nodes.io import Coordinate
|
||||
|
||||
node = Coordinate()
|
||||
|
||||
result = node.process(x=0.3, y=0.7)
|
||||
assert len(result) == 1
|
||||
assert result[0] == (0.3, 0.7)
|
||||
|
||||
# Edge values
|
||||
result_zero = node.process(x=0.0, y=0.0)
|
||||
assert result_zero[0] == (0.0, 0.0)
|
||||
|
||||
result_one = node.process(x=1.0, y=1.0)
|
||||
assert result_one[0] == (1.0, 1.0)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — LineCursors
|
||||
# =========================================================================
|
||||
|
||||
def test_line_cursors():
|
||||
print("=== Test: LineCursors ===")
|
||||
from backend.nodes.analysis import LineCursors
|
||||
|
||||
node = LineCursors()
|
||||
|
||||
# Create a simple linear ramp
|
||||
line = np.linspace(0, 10, 100).astype(np.float64)
|
||||
|
||||
# Capture overlay
|
||||
overlays = []
|
||||
LineCursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
|
||||
LineCursors._current_node_id = "test"
|
||||
|
||||
table, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
|
||||
|
||||
# Should produce a 6-row table
|
||||
assert len(table) == 6
|
||||
quantities = {row["quantity"] for row in table}
|
||||
assert "A position" in quantities
|
||||
assert "B position" in quantities
|
||||
assert "delta X" in quantities
|
||||
assert "delta Y" in quantities
|
||||
|
||||
# B should be at a later position than A
|
||||
a_pos = next(r["value"] for r in table if r["quantity"] == "A position")
|
||||
b_pos = next(r["value"] for r in table if r["quantity"] == "B position")
|
||||
assert b_pos > a_pos
|
||||
|
||||
# Delta Y should reflect the height difference along the ramp
|
||||
dy = next(r["value"] for r in table if r["quantity"] == "delta Y")
|
||||
assert dy > 0 # ramp goes upward
|
||||
|
||||
# Overlay should have been broadcast
|
||||
assert len(overlays) == 1
|
||||
assert "image" in overlays[0]
|
||||
assert overlays[0]["image"].startswith("data:image/png;base64,")
|
||||
|
||||
# With x_axis provided
|
||||
x_axis = np.linspace(0, 1, 100).astype(np.float64)
|
||||
table2, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5, x_axis=x_axis)
|
||||
assert len(table2) == 6
|
||||
|
||||
LineCursors._broadcast_overlay_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — FFT2D
|
||||
# =========================================================================
|
||||
|
||||
def test_fft2d():
|
||||
print("=== Test: FFT2D ===")
|
||||
from backend.nodes.analysis import FFT2D
|
||||
|
||||
node = FFT2D()
|
||||
|
||||
# Pure single-frequency signal: peak should appear at the right location
|
||||
N = 64
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
freq = 5
|
||||
data = np.sin(2 * np.pi * freq * x)
|
||||
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
# log_magnitude
|
||||
spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude")
|
||||
assert spectrum.data.shape == (N, N)
|
||||
assert spectrum.domain == "frequency"
|
||||
assert spectrum.si_unit_xy == "1/m"
|
||||
# Peak should be symmetric about centre
|
||||
centre = N // 2
|
||||
row = spectrum.data[centre, :]
|
||||
peak_idx = np.argmax(row[centre + 1:]) + centre + 1
|
||||
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
|
||||
|
||||
# magnitude output
|
||||
spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude")
|
||||
assert spec_mag.data.shape == (N, N)
|
||||
assert np.all(spec_mag.data >= 0)
|
||||
|
||||
# phase output
|
||||
spec_phase, = node.process(field, windowing="none", level="none", output="phase")
|
||||
assert spec_phase.data.shape == (N, N)
|
||||
assert spec_phase.data.min() >= -np.pi - 0.01
|
||||
assert spec_phase.data.max() <= np.pi + 0.01
|
||||
|
||||
# psdf output — units should reflect PSDF calibration
|
||||
spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf")
|
||||
assert spec_psdf.data.shape == (N, N)
|
||||
assert np.all(spec_psdf.data >= 0)
|
||||
assert "^2" in spec_psdf.si_unit_z
|
||||
|
||||
# Constant field should have all energy at DC
|
||||
const_field = make_field(data=np.ones((32, 32)) * 3.0)
|
||||
spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude")
|
||||
centre32 = 16
|
||||
dc_val = spec_const.data[centre32, centre32]
|
||||
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
|
||||
|
||||
# Blackman windowing should also work without error
|
||||
spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude")
|
||||
assert spec_bk.data.shape == (N, N)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — LineMath
|
||||
# =========================================================================
|
||||
|
||||
def test_line_math():
|
||||
print("=== Test: LineMath ===")
|
||||
from backend.nodes.analysis import LineMath
|
||||
|
||||
node = LineMath()
|
||||
line = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
# Basic stats
|
||||
table, = node.process(line, operation="min")
|
||||
assert table[0]["value"] == 1.0
|
||||
|
||||
table, = node.process(line, operation="max")
|
||||
assert table[0]["value"] == 5.0
|
||||
|
||||
table, = node.process(line, operation="mean")
|
||||
assert table[0]["value"] == 3.0
|
||||
|
||||
table, = node.process(line, operation="median")
|
||||
assert table[0]["value"] == 3.0
|
||||
|
||||
table, = node.process(line, operation="sum")
|
||||
assert table[0]["value"] == 15.0
|
||||
|
||||
table, = node.process(line, operation="range")
|
||||
assert table[0]["value"] == 4.0
|
||||
|
||||
table, = node.process(line, operation="length")
|
||||
assert table[0]["value"] == 5.0
|
||||
|
||||
# RMS of [1,2,3,4,5]
|
||||
table, = node.process(line, operation="rms")
|
||||
expected_rms = np.sqrt(np.mean(line ** 2))
|
||||
assert abs(table[0]["value"] - expected_rms) < 1e-10
|
||||
|
||||
# Roughness parameters
|
||||
table, = node.process(line, operation="Ra")
|
||||
d = line - line.mean()
|
||||
expected_ra = float(np.mean(np.abs(d)))
|
||||
assert abs(table[0]["value"] - expected_ra) < 1e-10
|
||||
|
||||
table, = node.process(line, operation="Rq")
|
||||
expected_rq = float(np.sqrt(np.mean(d ** 2)))
|
||||
assert abs(table[0]["value"] - expected_rq) < 1e-10
|
||||
|
||||
# Rp = max of (z - mean)
|
||||
table, = node.process(line, operation="Rp")
|
||||
assert abs(table[0]["value"] - d.max()) < 1e-10
|
||||
|
||||
# Rv = -(min of (z - mean))
|
||||
table, = node.process(line, operation="Rv")
|
||||
assert abs(table[0]["value"] - (-d.min())) < 1e-10
|
||||
|
||||
# Rt = Rp + Rv = range of (z - mean)
|
||||
table, = node.process(line, operation="Rt")
|
||||
assert abs(table[0]["value"] - (d.max() - d.min())) < 1e-10
|
||||
|
||||
# Constant line: roughness parameters should all be zero
|
||||
const_line = np.ones(10) * 7.0
|
||||
table, = node.process(const_line, operation="Ra")
|
||||
assert table[0]["value"] == 0.0
|
||||
table, = node.process(const_line, operation="Rq")
|
||||
assert table[0]["value"] == 0.0
|
||||
table, = node.process(const_line, operation="Rsk")
|
||||
assert table[0]["value"] == 0.0
|
||||
table, = node.process(const_line, operation="Rku")
|
||||
assert table[0]["value"] == 0.0
|
||||
|
||||
# Slope-based: Dq and Da
|
||||
table, = node.process(line, operation="Dq")
|
||||
dz = np.diff(line)
|
||||
expected_dq = float(np.sqrt(np.mean(dz * dz)))
|
||||
assert abs(table[0]["value"] - expected_dq) < 1e-10
|
||||
|
||||
table, = node.process(line, operation="Da")
|
||||
expected_da = float(np.mean(np.abs(dz)))
|
||||
assert abs(table[0]["value"] - expected_da) < 1e-10
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Display — View3D
|
||||
# =========================================================================
|
||||
|
||||
def test_view3d():
|
||||
print("=== Test: View3D ===")
|
||||
from backend.nodes.display import View3D
|
||||
|
||||
node = View3D()
|
||||
field = make_field()
|
||||
|
||||
captured = []
|
||||
View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh)
|
||||
View3D._current_node_id = "test"
|
||||
|
||||
result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64)
|
||||
assert result == ()
|
||||
assert len(captured) == 1
|
||||
|
||||
mesh = captured[0]
|
||||
assert "width" in mesh
|
||||
assert "height" in mesh
|
||||
assert "z_data" in mesh
|
||||
assert "colors" in mesh
|
||||
assert mesh["z_scale"] == 2.0
|
||||
assert mesh["width"] <= 64
|
||||
assert mesh["height"] <= 64
|
||||
# z_min < z_max for non-constant data
|
||||
assert mesh["z_min"] < mesh["z_max"]
|
||||
|
||||
# Verify base64 data can be decoded
|
||||
import base64
|
||||
z_bytes = base64.b64decode(mesh["z_data"])
|
||||
assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32
|
||||
|
||||
colors_bytes = base64.b64decode(mesh["colors"])
|
||||
assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 # uint8 RGB
|
||||
|
||||
# High-res input should be downsampled
|
||||
big_field = make_field(shape=(256, 256))
|
||||
captured.clear()
|
||||
node.render(big_field, colormap="hot", z_scale=1.0, resolution=64)
|
||||
assert captured[0]["width"] <= 64
|
||||
assert captured[0]["height"] <= 64
|
||||
|
||||
View3D._broadcast_mesh_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Run all tests
|
||||
# =========================================================================
|
||||
@@ -662,6 +1121,9 @@ if __name__ == "__main__":
|
||||
test_statistics()
|
||||
test_height_histogram()
|
||||
test_cross_section()
|
||||
test_line_cursors()
|
||||
test_fft2d()
|
||||
test_line_math()
|
||||
|
||||
# Mask
|
||||
test_threshold_mask()
|
||||
@@ -673,11 +1135,20 @@ if __name__ == "__main__":
|
||||
test_particle_analysis()
|
||||
|
||||
# I/O
|
||||
test_load_image()
|
||||
test_load_file()
|
||||
test_load_file_ibw()
|
||||
test_load_file_npz()
|
||||
test_load_file_not_found()
|
||||
test_load_file_unsupported()
|
||||
test_load_file_warning()
|
||||
test_list_channels()
|
||||
test_load_demo()
|
||||
test_coordinate()
|
||||
test_save_image()
|
||||
|
||||
# Display
|
||||
test_preview_image()
|
||||
test_print_table()
|
||||
test_view3d()
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
Reference in New Issue
Block a user