multichannel support + colormap inherit
This commit is contained in:
@@ -523,41 +523,42 @@ def test_particle_analysis():
|
||||
# I/O
|
||||
# =========================================================================
|
||||
|
||||
def test_load_image():
|
||||
print("=== Test: LoadImage ===")
|
||||
from backend.nodes.io import LoadImage
|
||||
def test_load_file():
|
||||
print("=== Test: LoadFile ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
from PIL import Image
|
||||
node = LoadImage()
|
||||
node = LoadFile()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Test loading a grayscale PNG
|
||||
# Test loading a grayscale PNG → single DataField output
|
||||
arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
|
||||
img = Image.fromarray(arr, mode="L")
|
||||
path = os.path.join(tmpdir, "test_gray.png")
|
||||
img.save(path)
|
||||
|
||||
image, field = node.load(filename=path)
|
||||
assert image.shape == (48, 64)
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
field = result[0]
|
||||
assert field.data.shape == (48, 64)
|
||||
assert field.data.dtype == np.float64
|
||||
|
||||
# Test loading an RGB PNG (should average to grayscale for field)
|
||||
# Test loading an RGB PNG (should average to grayscale)
|
||||
arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
|
||||
img_rgb = Image.fromarray(arr_rgb, mode="RGB")
|
||||
path_rgb = os.path.join(tmpdir, "test_rgb.png")
|
||||
img_rgb.save(path_rgb)
|
||||
|
||||
image_rgb, field_rgb = node.load(filename=path_rgb)
|
||||
assert image_rgb.shape == (32, 32, 3)
|
||||
assert field_rgb.data.shape == (32, 32)
|
||||
result_rgb = node.load(filename=path_rgb)
|
||||
assert len(result_rgb) == 1
|
||||
assert result_rgb[0].data.shape == (32, 32)
|
||||
|
||||
# Test loading a .npy file
|
||||
data_npy = np.random.default_rng(3).standard_normal((50, 60))
|
||||
path_npy = os.path.join(tmpdir, "test.npy")
|
||||
np.save(path_npy, data_npy)
|
||||
|
||||
image_npy, field_npy = node.load(filename=path_npy)
|
||||
assert np.allclose(field_npy.data, data_npy)
|
||||
result_npy = node.load(filename=path_npy)
|
||||
assert np.allclose(result_npy[0].data, data_npy)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
@@ -641,6 +642,464 @@ def test_print_table():
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — IBW multi-channel loading
|
||||
# =========================================================================
|
||||
|
||||
def test_load_file_ibw():
|
||||
print("=== Test: LoadFile IBW multi-channel ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
|
||||
ibw_path = os.path.abspath(ibw_path)
|
||||
if not os.path.exists(ibw_path):
|
||||
print(" SKIP (demo IBW file not found)\n")
|
||||
return
|
||||
|
||||
result = node.load(filename=ibw_path)
|
||||
|
||||
# BR_New20012.ibw has 4 channels
|
||||
assert len(result) == 4, f"Expected 4 channels, got {len(result)}"
|
||||
|
||||
for i, field in enumerate(result):
|
||||
assert isinstance(field, DataField), f"Channel {i} is not a DataField"
|
||||
assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}"
|
||||
assert field.data.dtype == np.float64
|
||||
# Physical dimensions should be populated (not default 1e-6)
|
||||
assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}"
|
||||
assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}"
|
||||
assert field.si_unit_xy == "m"
|
||||
assert field.si_unit_z == "m"
|
||||
|
||||
# All channels should share the same physical dimensions
|
||||
assert result[0].xreal == result[1].xreal
|
||||
assert result[0].yreal == result[1].yreal
|
||||
|
||||
# Different channels should have different data
|
||||
assert not np.array_equal(result[0].data, result[1].data)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_npz():
|
||||
print("=== Test: LoadFile .npz ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
data = np.random.default_rng(99).standard_normal((30, 40))
|
||||
path = os.path.join(tmpdir, "test.npz")
|
||||
np.savez(path, my_array=data)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
assert np.allclose(result[0].data, data)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_not_found():
|
||||
print("=== Test: LoadFile not found ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
try:
|
||||
node.load(filename="/nonexistent/path/file.png")
|
||||
assert False, "Should have raised FileNotFoundError"
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_unsupported():
|
||||
print("=== Test: LoadFile unsupported format ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
|
||||
node = LoadFile()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "test.xyz")
|
||||
with open(path, "w") as f:
|
||||
f.write("hello")
|
||||
try:
|
||||
node.load(filename=path)
|
||||
assert False, "Should have raised an error for .xyz"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
def test_load_file_warning():
|
||||
print("=== Test: LoadFile warning for uncalibrated data ===")
|
||||
from backend.nodes.io import LoadFile
|
||||
from PIL import Image
|
||||
|
||||
node = LoadFile()
|
||||
warnings = []
|
||||
LoadFile._broadcast_warning_fn = lambda nid, msg: warnings.append(msg)
|
||||
LoadFile._current_node_id = "test"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8)
|
||||
img = Image.fromarray(arr)
|
||||
path = os.path.join(tmpdir, "test.png")
|
||||
img.save(path)
|
||||
|
||||
result = node.load(filename=path)
|
||||
assert len(result) == 1
|
||||
assert len(warnings) == 1
|
||||
assert "Uncalibrated" in warnings[0]
|
||||
|
||||
LoadFile._broadcast_warning_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — list_channels helper
|
||||
# =========================================================================
|
||||
|
||||
def test_list_channels():
|
||||
print("=== Test: list_channels ===")
|
||||
from backend.nodes.io import list_channels
|
||||
|
||||
# Non-existent file → default
|
||||
ch = list_channels("/nonexistent/file.ibw")
|
||||
assert len(ch) == 1
|
||||
assert ch[0]["name"] == "field"
|
||||
|
||||
# IBW with channels
|
||||
ibw_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw"))
|
||||
if os.path.exists(ibw_path):
|
||||
ch = list_channels(ibw_path)
|
||||
assert len(ch) == 4
|
||||
names = [c["name"] for c in ch]
|
||||
assert "HeightRetrace" in names
|
||||
assert "AmplitudeRetrace" in names
|
||||
assert all(c["type"] == "DATA_FIELD" for c in ch)
|
||||
|
||||
# Plain image → single default channel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
from PIL import Image
|
||||
img = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
|
||||
path = os.path.join(tmpdir, "test.png")
|
||||
img.save(path)
|
||||
|
||||
ch = list_channels(path)
|
||||
assert len(ch) == 1
|
||||
assert ch[0]["name"] == "field"
|
||||
|
||||
# .npy → single default channel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "test.npy")
|
||||
np.save(path, np.zeros((4, 4)))
|
||||
|
||||
ch = list_channels(path)
|
||||
assert len(ch) == 1
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — LoadDemo
|
||||
# =========================================================================
|
||||
|
||||
def test_load_demo():
|
||||
print("=== Test: LoadDemo ===")
|
||||
from backend.nodes.io import LoadDemo
|
||||
|
||||
node = LoadDemo()
|
||||
|
||||
# Should be able to load a demo file by name
|
||||
result = node.load(name="nanoparticles.npy")
|
||||
assert len(result) >= 1
|
||||
assert isinstance(result[0], DataField)
|
||||
assert result[0].data.ndim == 2
|
||||
|
||||
# IBW demo should return multiple channels
|
||||
result_ibw = node.load(name="whiskers.ibw")
|
||||
assert len(result_ibw) == 4
|
||||
for field in result_ibw:
|
||||
assert isinstance(field, DataField)
|
||||
|
||||
# Non-existent demo should raise
|
||||
try:
|
||||
node.load(name="nonexistent_file.png")
|
||||
assert False, "Should have raised FileNotFoundError"
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# I/O — Coordinate
|
||||
# =========================================================================
|
||||
|
||||
def test_coordinate():
|
||||
print("=== Test: Coordinate ===")
|
||||
from backend.nodes.io import Coordinate
|
||||
|
||||
node = Coordinate()
|
||||
|
||||
result = node.process(x=0.3, y=0.7)
|
||||
assert len(result) == 1
|
||||
assert result[0] == (0.3, 0.7)
|
||||
|
||||
# Edge values
|
||||
result_zero = node.process(x=0.0, y=0.0)
|
||||
assert result_zero[0] == (0.0, 0.0)
|
||||
|
||||
result_one = node.process(x=1.0, y=1.0)
|
||||
assert result_one[0] == (1.0, 1.0)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — LineCursors
|
||||
# =========================================================================
|
||||
|
||||
def test_line_cursors():
|
||||
print("=== Test: LineCursors ===")
|
||||
from backend.nodes.analysis import LineCursors
|
||||
|
||||
node = LineCursors()
|
||||
|
||||
# Create a simple linear ramp
|
||||
line = np.linspace(0, 10, 100).astype(np.float64)
|
||||
|
||||
# Capture overlay
|
||||
overlays = []
|
||||
LineCursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
|
||||
LineCursors._current_node_id = "test"
|
||||
|
||||
table, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)
|
||||
|
||||
# Should produce a 6-row table
|
||||
assert len(table) == 6
|
||||
quantities = {row["quantity"] for row in table}
|
||||
assert "A position" in quantities
|
||||
assert "B position" in quantities
|
||||
assert "delta X" in quantities
|
||||
assert "delta Y" in quantities
|
||||
|
||||
# B should be at a later position than A
|
||||
a_pos = next(r["value"] for r in table if r["quantity"] == "A position")
|
||||
b_pos = next(r["value"] for r in table if r["quantity"] == "B position")
|
||||
assert b_pos > a_pos
|
||||
|
||||
# Delta Y should reflect the height difference along the ramp
|
||||
dy = next(r["value"] for r in table if r["quantity"] == "delta Y")
|
||||
assert dy > 0 # ramp goes upward
|
||||
|
||||
# Overlay should have been broadcast
|
||||
assert len(overlays) == 1
|
||||
assert "image" in overlays[0]
|
||||
assert overlays[0]["image"].startswith("data:image/png;base64,")
|
||||
|
||||
# With x_axis provided
|
||||
x_axis = np.linspace(0, 1, 100).astype(np.float64)
|
||||
table2, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5, x_axis=x_axis)
|
||||
assert len(table2) == 6
|
||||
|
||||
LineCursors._broadcast_overlay_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — FFT2D
|
||||
# =========================================================================
|
||||
|
||||
def test_fft2d():
|
||||
print("=== Test: FFT2D ===")
|
||||
from backend.nodes.analysis import FFT2D
|
||||
|
||||
node = FFT2D()
|
||||
|
||||
# Pure single-frequency signal: peak should appear at the right location
|
||||
N = 64
|
||||
y, x = np.mgrid[0:N, 0:N] / N
|
||||
freq = 5
|
||||
data = np.sin(2 * np.pi * freq * x)
|
||||
field = make_field(data=data, xreal=1e-6, yreal=1e-6)
|
||||
|
||||
# log_magnitude
|
||||
spectrum, = node.process(field, windowing="none", level="none", output="log_magnitude")
|
||||
assert spectrum.data.shape == (N, N)
|
||||
assert spectrum.domain == "frequency"
|
||||
assert spectrum.si_unit_xy == "1/m"
|
||||
# Peak should be symmetric about centre
|
||||
centre = N // 2
|
||||
row = spectrum.data[centre, :]
|
||||
peak_idx = np.argmax(row[centre + 1:]) + centre + 1
|
||||
assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
|
||||
|
||||
# magnitude output
|
||||
spec_mag, = node.process(field, windowing="hann", level="mean", output="magnitude")
|
||||
assert spec_mag.data.shape == (N, N)
|
||||
assert np.all(spec_mag.data >= 0)
|
||||
|
||||
# phase output
|
||||
spec_phase, = node.process(field, windowing="none", level="none", output="phase")
|
||||
assert spec_phase.data.shape == (N, N)
|
||||
assert spec_phase.data.min() >= -np.pi - 0.01
|
||||
assert spec_phase.data.max() <= np.pi + 0.01
|
||||
|
||||
# psdf output — units should reflect PSDF calibration
|
||||
spec_psdf, = node.process(field, windowing="hamming", level="plane", output="psdf")
|
||||
assert spec_psdf.data.shape == (N, N)
|
||||
assert np.all(spec_psdf.data >= 0)
|
||||
assert "^2" in spec_psdf.si_unit_z
|
||||
|
||||
# Constant field should have all energy at DC
|
||||
const_field = make_field(data=np.ones((32, 32)) * 3.0)
|
||||
spec_const, = node.process(const_field, windowing="none", level="none", output="magnitude")
|
||||
centre32 = 16
|
||||
dc_val = spec_const.data[centre32, centre32]
|
||||
assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"
|
||||
|
||||
# Blackman windowing should also work without error
|
||||
spec_bk, = node.process(field, windowing="blackman", level="none", output="log_magnitude")
|
||||
assert spec_bk.data.shape == (N, N)
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Analysis — LineMath
|
||||
# =========================================================================
|
||||
|
||||
def test_line_math():
|
||||
print("=== Test: LineMath ===")
|
||||
from backend.nodes.analysis import LineMath
|
||||
|
||||
node = LineMath()
|
||||
line = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
# Basic stats
|
||||
table, = node.process(line, operation="min")
|
||||
assert table[0]["value"] == 1.0
|
||||
|
||||
table, = node.process(line, operation="max")
|
||||
assert table[0]["value"] == 5.0
|
||||
|
||||
table, = node.process(line, operation="mean")
|
||||
assert table[0]["value"] == 3.0
|
||||
|
||||
table, = node.process(line, operation="median")
|
||||
assert table[0]["value"] == 3.0
|
||||
|
||||
table, = node.process(line, operation="sum")
|
||||
assert table[0]["value"] == 15.0
|
||||
|
||||
table, = node.process(line, operation="range")
|
||||
assert table[0]["value"] == 4.0
|
||||
|
||||
table, = node.process(line, operation="length")
|
||||
assert table[0]["value"] == 5.0
|
||||
|
||||
# RMS of [1,2,3,4,5]
|
||||
table, = node.process(line, operation="rms")
|
||||
expected_rms = np.sqrt(np.mean(line ** 2))
|
||||
assert abs(table[0]["value"] - expected_rms) < 1e-10
|
||||
|
||||
# Roughness parameters
|
||||
table, = node.process(line, operation="Ra")
|
||||
d = line - line.mean()
|
||||
expected_ra = float(np.mean(np.abs(d)))
|
||||
assert abs(table[0]["value"] - expected_ra) < 1e-10
|
||||
|
||||
table, = node.process(line, operation="Rq")
|
||||
expected_rq = float(np.sqrt(np.mean(d ** 2)))
|
||||
assert abs(table[0]["value"] - expected_rq) < 1e-10
|
||||
|
||||
# Rp = max of (z - mean)
|
||||
table, = node.process(line, operation="Rp")
|
||||
assert abs(table[0]["value"] - d.max()) < 1e-10
|
||||
|
||||
# Rv = -(min of (z - mean))
|
||||
table, = node.process(line, operation="Rv")
|
||||
assert abs(table[0]["value"] - (-d.min())) < 1e-10
|
||||
|
||||
# Rt = Rp + Rv = range of (z - mean)
|
||||
table, = node.process(line, operation="Rt")
|
||||
assert abs(table[0]["value"] - (d.max() - d.min())) < 1e-10
|
||||
|
||||
# Constant line: roughness parameters should all be zero
|
||||
const_line = np.ones(10) * 7.0
|
||||
table, = node.process(const_line, operation="Ra")
|
||||
assert table[0]["value"] == 0.0
|
||||
table, = node.process(const_line, operation="Rq")
|
||||
assert table[0]["value"] == 0.0
|
||||
table, = node.process(const_line, operation="Rsk")
|
||||
assert table[0]["value"] == 0.0
|
||||
table, = node.process(const_line, operation="Rku")
|
||||
assert table[0]["value"] == 0.0
|
||||
|
||||
# Slope-based: Dq and Da
|
||||
table, = node.process(line, operation="Dq")
|
||||
dz = np.diff(line)
|
||||
expected_dq = float(np.sqrt(np.mean(dz * dz)))
|
||||
assert abs(table[0]["value"] - expected_dq) < 1e-10
|
||||
|
||||
table, = node.process(line, operation="Da")
|
||||
expected_da = float(np.mean(np.abs(dz)))
|
||||
assert abs(table[0]["value"] - expected_da) < 1e-10
|
||||
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Display — View3D
|
||||
# =========================================================================
|
||||
|
||||
def test_view3d():
|
||||
print("=== Test: View3D ===")
|
||||
from backend.nodes.display import View3D
|
||||
|
||||
node = View3D()
|
||||
field = make_field()
|
||||
|
||||
captured = []
|
||||
View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh)
|
||||
View3D._current_node_id = "test"
|
||||
|
||||
result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64)
|
||||
assert result == ()
|
||||
assert len(captured) == 1
|
||||
|
||||
mesh = captured[0]
|
||||
assert "width" in mesh
|
||||
assert "height" in mesh
|
||||
assert "z_data" in mesh
|
||||
assert "colors" in mesh
|
||||
assert mesh["z_scale"] == 2.0
|
||||
assert mesh["width"] <= 64
|
||||
assert mesh["height"] <= 64
|
||||
# z_min < z_max for non-constant data
|
||||
assert mesh["z_min"] < mesh["z_max"]
|
||||
|
||||
# Verify base64 data can be decoded
|
||||
import base64
|
||||
z_bytes = base64.b64decode(mesh["z_data"])
|
||||
assert len(z_bytes) == mesh["width"] * mesh["height"] * 4 # float32
|
||||
|
||||
colors_bytes = base64.b64decode(mesh["colors"])
|
||||
assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3 # uint8 RGB
|
||||
|
||||
# High-res input should be downsampled
|
||||
big_field = make_field(shape=(256, 256))
|
||||
captured.clear()
|
||||
node.render(big_field, colormap="hot", z_scale=1.0, resolution=64)
|
||||
assert captured[0]["width"] <= 64
|
||||
assert captured[0]["height"] <= 64
|
||||
|
||||
View3D._broadcast_mesh_fn = None
|
||||
print(" PASS\n")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Run all tests
|
||||
# =========================================================================
|
||||
@@ -662,6 +1121,9 @@ if __name__ == "__main__":
|
||||
test_statistics()
|
||||
test_height_histogram()
|
||||
test_cross_section()
|
||||
test_line_cursors()
|
||||
test_fft2d()
|
||||
test_line_math()
|
||||
|
||||
# Mask
|
||||
test_threshold_mask()
|
||||
@@ -673,11 +1135,20 @@ if __name__ == "__main__":
|
||||
test_particle_analysis()
|
||||
|
||||
# I/O
|
||||
test_load_image()
|
||||
test_load_file()
|
||||
test_load_file_ibw()
|
||||
test_load_file_npz()
|
||||
test_load_file_not_found()
|
||||
test_load_file_unsupported()
|
||||
test_load_file_warning()
|
||||
test_list_channels()
|
||||
test_load_demo()
|
||||
test_coordinate()
|
||||
test_save_image()
|
||||
|
||||
# Display
|
||||
test_preview_image()
|
||||
test_print_table()
|
||||
test_view3d()
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
Reference in New Issue
Block a user