multichannel support + colormap inherit

This commit is contained in:
2026-03-24 21:01:58 -07:00
parent 53e2fc7746
commit a60b0c15ca
12 changed files with 889 additions and 220 deletions

View File

@@ -523,41 +523,42 @@ def test_particle_analysis():
# I/O
# =========================================================================
def test_load_image():
print("=== Test: LoadImage ===")
from backend.nodes.io import LoadImage
def test_load_file():
print("=== Test: LoadFile ===")
from backend.nodes.io import LoadFile
from PIL import Image
node = LoadImage()
node = LoadFile()
with tempfile.TemporaryDirectory() as tmpdir:
# Test loading a grayscale PNG
# Test loading a grayscale PNG → single DataField output
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
img = Image.fromarray(arr, mode="L")
path = os.path.join(tmpdir, "test_gray.png")
img.save(path)
image, field = node.load(filename=path)
assert image.shape == (48, 64)
result = node.load(filename=path)
assert len(result) == 1
field = result[0]
assert field.data.shape == (48, 64)
assert field.data.dtype == np.float64
# Test loading an RGB PNG (should average to grayscale for field)
# Test loading an RGB PNG (should average to grayscale)
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
path_rgb = os.path.join(tmpdir, "test_rgb.png")
img_rgb.save(path_rgb)
image_rgb, field_rgb = node.load(filename=path_rgb)
assert image_rgb.shape == (32, 32, 3)
assert field_rgb.data.shape == (32, 32)
result_rgb = node.load(filename=path_rgb)
assert len(result_rgb) == 1
assert result_rgb[0].data.shape == (32, 32)
# Test loading a .npy file
data_npy = np.random.default_rng(3).standard_normal((50, 60))
path_npy = os.path.join(tmpdir, "test.npy")
np.save(path_npy, data_npy)
image_npy, field_npy = node.load(filename=path_npy)
assert np.allclose(field_npy.data, data_npy)
result_npy = node.load(filename=path_npy)
assert np.allclose(result_npy[0].data, data_npy)
print(" PASS\n")
@@ -641,6 +642,464 @@ def test_print_table():
print(" PASS\n")
# =========================================================================
# I/O — IBW multi-channel loading
# =========================================================================
def test_load_file_ibw():
print("=== Test: LoadFile IBW multi-channel ===")
from backend.nodes.io import LoadFile
node = LoadFile()
ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
ibw_path = os.path.abspath(ibw_path)
if not os.path.exists(ibw_path):
print(" SKIP (demo IBW file not found)\n")
return
result = node.load(filename=ibw_path)
# BR_New20012.ibw has 4 channels
assert len(result) == 4, f"Expected 4 channels, got {len(result)}"
for i, field in enumerate(result):
assert isinstance(field, DataField), f"Channel {i} is not a DataField"
assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}"
assert field.data.dtype == np.float64
# Physical dimensions should be populated (not default 1e-6)
assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}"
assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}"
assert field.si_unit_xy == "m"
assert field.si_unit_z == "m"
# All channels should share the same physical dimensions
assert result[0].xreal == result[1].xreal
assert result[0].yreal == result[1].yreal
# Different channels should have different data
assert not np.array_equal(result[0].data, result[1].data)
print(" PASS\n")
def test_load_file_npz():
print("=== Test: LoadFile .npz ===")
from backend.nodes.io import LoadFile
node = LoadFile()
with tempfile.TemporaryDirectory() as tmpdir:
data = np.random.default_rng(99).standard_normal((30, 40))
path = os.path.join(tmpdir, "test.npz")
np.savez(path, my_array=data)
result = node.load(filename=path)
assert len(result) == 1
assert np.allclose(result[0].data, data)
print(" PASS\n")
def test_load_file_not_found():
print("=== Test: LoadFile not found ===")
from backend.nodes.io import LoadFile
node = LoadFile()
try:
node.load(filename="/nonexistent/path/file.png")
assert False, "Should have raised FileNotFoundError"
except FileNotFoundError:
pass
print(" PASS\n")
def test_load_file_unsupported():
print("=== Test: LoadFile unsupported format ===")
from backend.nodes.io import LoadFile
node = LoadFile()
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "test.xyz")
with open(path, "w") as f:
f.write("hello")
try:
node.load(filename=path)
assert False, "Should have raised an error for .xyz"
except Exception:
pass
print(" PASS\n")
def test_load_file_warning():
print("=== Test: LoadFile warning for uncalibrated data ===")
from backend.nodes.io import LoadFile
from PIL import Image
node = LoadFile()
warnings = []
LoadFile._broadcast_warning_fn = lambda nid, msg: warnings.append(msg)
LoadFile._current_node_id = "test"
with tempfile.TemporaryDirectory() as tmpdir:
arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8)
img = Image.fromarray(arr)
path = os.path.join(tmpdir, "test.png")
img.save(path)
result = node.load(filename=path)
assert len(result) == 1
assert len(warnings) == 1
assert "Uncalibrated" in warnings[0]
LoadFile._broadcast_warning_fn = None
print(" PASS\n")
# =========================================================================
# I/O — list_channels helper
# =========================================================================
def test_list_channels():
print("=== Test: list_channels ===")
from backend.nodes.io import list_channels
# Non-existent file → default
ch = list_channels("/nonexistent/file.ibw")
assert len(ch) == 1
assert ch[0]["name"] == "field"
# IBW with channels
ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw"))
if os.path.exists(ibw_path):
ch = list_channels(ibw_path)
assert len(ch) == 4
names = [c["name"] for c in ch]
assert "HeightRetrace" in names
assert "AmplitudeRetrace" in names
assert all(c["type"] == "DATA_FIELD" for c in ch)
# Plain image → single default channel
with tempfile.TemporaryDirectory() as tmpdir:
from PIL import Image
img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
path = os.path.join(tmpdir, "test.png")
img.save(path)
ch = list_channels(path)
assert len(ch) == 1
assert ch[0]["name"] == "field"
# .npy → single default channel
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "test.npy")
np.save(path, np.zeros((4, 4)))
ch = list_channels(path)
assert len(ch) == 1
print(" PASS\n")
# =========================================================================
# I/O — LoadDemo
# =========================================================================
def test_load_demo():
print("=== Test: LoadDemo ===")
from backend.nodes.io import LoadDemo
node = LoadDemo()
# Should be able to load a demo file by name
result = node.load(name="nanoparticles.npy")
assert len(result) >= 1
assert isinstance(result[0], DataField)
assert result[0].data.ndim == 2
# IBW demo should return multiple channels
result_ibw = node.load(name="whiskers.ibw")
assert len(result_ibw) == 4
for field in result_ibw:
assert isinstance(field, DataField)
# Non-existent demo should raise
try:
node.load(name="nonexistent_file.png")
assert False, "Should have raised FileNotFoundError"
except FileNotFoundError:
pass
print(" PASS\n")
# =========================================================================
# I/O — Coordinate
# =========================================================================
def test_coordinate():
print("=== Test: Coordinate ===")
from backend.nodes.io import Coordinate
node = Coordinate()
result = node.process(x=0.3, y=0.7)
assert len(result) == 1
assert result[0] == (0.3, 0.7)
# Edge values
result_zero = node.process(x=0.0, y=0.0)
assert result_zero[0] == (0.0, 0.0)
result_one = node.process(x=1.0, y=1.0)
assert result_one[0] == (1.0, 1.0)
print(" PASS\n")
# =========================================================================
# Analysis — LineCursors
# =========================================================================
def test_line_cursors():
print("=== Test: LineCursors ===")
from backend.nodes.analysis import LineCursors
node = LineCursors()
# Create a simple linear ramp
line = np.linspace(0, 10, 100).astype(np.float64)
# Capture overlay
overlays = []
LineCursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
LineCursors._current_node_id = "test"
table, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
# Should produce a 6-row table
assert len(table) == 6
quantities = {row["quantity"] for row in table}
assert "A position" in quantities
assert "B position" in quantities
assert "delta X" in quantities
assert "delta Y" in quantities
# B should be at a later position than A
a_pos = next(r["value"] for r in table if r["quantity"] == "A position")
b_pos = next(r["value"] for r in table if r["quantity"] == "B position")
assert b_pos > a_pos
# Delta Y should reflect the height difference along the ramp
dy = next(r["value"] for r in table if r["quantity"] == "delta Y")
assert dy > 0 # ramp goes upward
# Overlay should have been broadcast
assert len(overlays) == 1
assert "image" in overlays[0]
assert overlays[0]["image"].startswith("data:image/png;base64,")
# With x_axis provided
x_axis = np.linspace(0, 1, 100).astype(np.float64)
table2, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5, x_axis=x_axis)
assert len(table2) == 6
LineCursors._broadcast_overlay_fn = None
print(" PASS\n")
# =========================================================================
# Analysis — FFT2D
# =========================================================================
def test_fft2d():
print("=== Test: FFT2D ===")
from backend.nodes.analysis import FFT2D
node = FFT2D()
# Pure single-frequency signal: peak should appear at the right location
N = 64
y, x = np.mgrid[0:N, 0:N] / N
freq = 5
data = np.sin(2 * np.pi * freq * x)
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
# log_magnitude
spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude")
assert spectrum.data.shape == (N, N)
assert spectrum.domain == "frequency"
assert spectrum.si_unit_xy == "1/m"
# Peak should be symmetric about centre
centre = N // 2
row = spectrum.data[centre, :]
peak_idx = np.argmax(row[centre + 1:]) + centre + 1
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
# magnitude output
spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude")
assert spec_mag.data.shape == (N, N)
assert np.all(spec_mag.data >= 0)
# phase output
spec_phase, = node.process(field, windowing="none", level="none", output="phase")
assert spec_phase.data.shape == (N, N)
assert spec_phase.data.min() >= -np.pi - 0.01
assert spec_phase.data.max() <= np.pi + 0.01
# psdf output — units should reflect PSDF calibration
spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf")
assert spec_psdf.data.shape == (N, N)
assert np.all(spec_psdf.data >= 0)
assert "^2" in spec_psdf.si_unit_z
# Constant field should have all energy at DC
const_field = make_field(data=np.ones((32, 32)) * 3.0)
spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude")
centre32 = 16
dc_val = spec_const.data[centre32, centre32]
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
# Blackman windowing should also work without error
spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude")
assert spec_bk.data.shape == (N, N)
print(" PASS\n")
# =========================================================================
# Analysis — LineMath
# =========================================================================
def test_line_math():
print("=== Test: LineMath ===")
from backend.nodes.analysis import LineMath
node = LineMath()
line = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Basic stats
table, = node.process(line, operation="min")
assert table[0]["value"] == 1.0
table, = node.process(line, operation="max")
assert table[0]["value"] == 5.0
table, = node.process(line, operation="mean")
assert table[0]["value"] == 3.0
table, = node.process(line, operation="median")
assert table[0]["value"] == 3.0
table, = node.process(line, operation="sum")
assert table[0]["value"] == 15.0
table, = node.process(line, operation="range")
assert table[0]["value"] == 4.0
table, = node.process(line, operation="length")
assert table[0]["value"] == 5.0
# RMS of [1,2,3,4,5]
table, = node.process(line, operation="rms")
expected_rms = np.sqrt(np.mean(line ** 2))
assert abs(table[0]["value"] - expected_rms) < 1e-10
# Roughness parameters
table, = node.process(line, operation="Ra")
d = line - line.mean()
expected_ra = float(np.mean(np.abs(d)))
assert abs(table[0]["value"] - expected_ra) < 1e-10
table, = node.process(line, operation="Rq")
expected_rq = float(np.sqrt(np.mean(d ** 2)))
assert abs(table[0]["value"] - expected_rq) < 1e-10
# Rp = max of (z - mean)
table, = node.process(line, operation="Rp")
assert abs(table[0]["value"] - d.max()) < 1e-10
# Rv = -(min of (z - mean))
table, = node.process(line, operation="Rv")
assert abs(table[0]["value"] - (-d.min())) < 1e-10
# Rt = Rp + Rv = range of (z - mean)
table, = node.process(line, operation="Rt")
assert abs(table[0]["value"] - (d.max() - d.min())) < 1e-10
# Constant line: roughness parameters should all be zero
const_line = np.ones(10) * 7.0
table, = node.process(const_line, operation="Ra")
assert table[0]["value"] == 0.0
table, = node.process(const_line, operation="Rq")
assert table[0]["value"] == 0.0
table, = node.process(const_line, operation="Rsk")
assert table[0]["value"] == 0.0
table, = node.process(const_line, operation="Rku")
assert table[0]["value"] == 0.0
# Slope-based: Dq and Da
table, = node.process(line, operation="Dq")
dz = np.diff(line)
expected_dq = float(np.sqrt(np.mean(dz * dz)))
assert abs(table[0]["value"] - expected_dq) < 1e-10
table, = node.process(line, operation="Da")
expected_da = float(np.mean(np.abs(dz)))
assert abs(table[0]["value"] - expected_da) < 1e-10
print(" PASS\n")
# =========================================================================
# Display — View3D
# =========================================================================
def test_view3d():
print("=== Test: View3D ===")
from backend.nodes.display import View3D
node = View3D()
field = make_field()
captured = []
View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh)
View3D._current_node_id = "test"
result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64)
assert result == ()
assert len(captured) == 1
mesh = captured[0]
assert "width" in mesh
assert "height" in mesh
assert "z_data" in mesh
assert "colors" in mesh
assert mesh["z_scale"] == 2.0
assert mesh["width"] <= 64
assert mesh["height"] <= 64
# z_min < z_max for non-constant data
assert mesh["z_min"] < mesh["z_max"]
# Verify base64 data can be decoded
import base64
z_bytes = base64.b64decode(mesh["z_data"])
assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32
colors_bytes = base64.b64decode(mesh["colors"])
assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 # uint8 RGB
# High-res input should be downsampled
big_field = make_field(shape=(256, 256))
captured.clear()
node.render(big_field, colormap="hot", z_scale=1.0, resolution=64)
assert captured[0]["width"] <= 64
assert captured[0]["height"] <= 64
View3D._broadcast_mesh_fn = None
print(" PASS\n")
# =========================================================================
# Run all tests
# =========================================================================
@@ -662,6 +1121,9 @@ if __name__ == "__main__":
test_statistics()
test_height_histogram()
test_cross_section()
test_line_cursors()
test_fft2d()
test_line_math()
# Mask
test_threshold_mask()
@@ -673,11 +1135,20 @@ if __name__ == "__main__":
test_particle_analysis()
# I/O
test_load_image()
test_load_file()
test_load_file_ibw()
test_load_file_npz()
test_load_file_not_found()
test_load_file_unsupported()
test_load_file_warning()
test_list_channels()
test_load_demo()
test_coordinate()
test_save_image()
# Display
test_preview_image()
test_print_table()
test_view3d()
print("All tests passed!")