1639 lines
53 KiB
Python
1639 lines
53 KiB
Python
"""
|
|
Tests for all argonode backend nodes (excluding the FFT2D node, which has its own test file).
|
|
|
|
Run from project root:
|
|
.venv/bin/python -m tests.test_nodes
|
|
"""
|
|
import json
|
|
import sys
|
|
import os
|
|
import tempfile
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, ".")
|
|
from backend.data_types import DataField, MeasureTable, RecordTable, datafield_to_uint8
|
|
|
|
|
|
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
    """Build a test DataField whose lateral and height units are both metres.

    When *data* is omitted, a reproducible standard-normal array of the
    requested *shape* is generated from a fixed seed (42), so repeated calls
    yield identical fields. *shape* is ignored when *data* is supplied.
    """
    values = np.random.default_rng(42).standard_normal(shape) if data is None else data
    return DataField(
        data=values,
        xreal=xreal,
        yreal=yreal,
        si_unit_xy="m",
        si_unit_z="m",
    )
|
|
|
|
|
|
# =========================================================================
|
|
# Filters
|
|
# =========================================================================
|
|
|
|
def test_gaussian_filter():
    """GaussianFilter keeps geometry/units, smooths noise, and is ~identity for a tiny sigma."""
    print("=== Test: GaussianFilter ===")
    from backend.nodes.filters import GaussianFilter

    blur = GaussianFilter()
    source = make_field()

    (smoothed,) = blur.process(source, sigma=2.0)
    # Geometry and units are carried over unchanged.
    assert smoothed.data.shape == source.data.shape
    assert smoothed.xreal == source.xreal
    assert smoothed.si_unit_z == source.si_unit_z
    # Blurring white noise must lower its spread.
    assert smoothed.data.std() < source.data.std()

    # A vanishingly small kernel should leave the data essentially untouched.
    (barely_blurred,) = blur.process(source, sigma=0.01)
    assert np.allclose(barely_blurred.data, source.data, atol=1e-6)

    print(" PASS\n")
|
|
|
|
|
|
def test_median_filter():
    """MedianFilter suppresses impulse noise and is the identity for size=1."""
    print("=== Test: MedianFilter ===")
    from backend.nodes.filters import MedianFilter

    median = MedianFilter()

    # Zero image sprinkled with 100 "salt" pixels at distinct random positions.
    noisy = np.zeros((64, 64))
    rng = np.random.default_rng(7)
    salt_positions = rng.choice(64 * 64, size=100, replace=False)
    noisy.flat[salt_positions] = 1.0
    field = make_field(data=noisy)

    (cleaned,) = median.process(field, size=3)
    assert cleaned.data.shape == field.data.shape
    # Most impulses should be gone, so the total "mass" must drop.
    assert cleaned.data.sum() < field.data.sum()

    # A 1-pixel window cannot change anything.
    (unchanged,) = median.process(field, size=1)
    assert np.array_equal(unchanged.data, field.data)

    print(" PASS\n")
|
|
|
|
|
|
def test_crop_resize_field():
    """CropResizeField crops by fractional corners, resizes, and broadcasts a crop-box overlay.

    The class-level overlay hooks are patched only for the duration of the
    test and restored afterwards, even if an assertion fails. This stops a
    stale callback or node id from leaking into later tests.
    """
    print("=== Test: CropResizeField ===")
    from unittest import mock

    from backend.nodes.modify import CropResizeField

    node = CropResizeField()

    # 4 rows x 8 cols, 1 nm per pixel, origin at (10, 20) nm.
    data = np.arange(32, dtype=np.float64).reshape(4, 8)
    field = DataField(
        data=data,
        xreal=8.0,
        yreal=4.0,
        xoff=10.0,
        yoff=20.0,
        si_unit_xy="nm",
        si_unit_z="nm",
    )

    overlays = []
    with mock.patch.multiple(
        CropResizeField,
        create=True,
        _broadcast_overlay_fn=lambda nid, payload: overlays.append(payload),
        _current_node_id="test",
    ):
        # x 0.25..0.75 -> cols 2..6, y 0.25..1.0 -> rows 1..4; no resize.
        cropped, = node.process(
            field,
            x1=0.25,
            y1=0.25,
            x2=0.75,
            y2=1.0,
            target_width=0,
            target_height=0,
            interpolation="bilinear",
        )
        assert cropped.data.shape == (3, 4)
        assert np.array_equal(cropped.data, data[1:4, 2:6])
        assert cropped.xreal == 4.0
        assert cropped.yreal == 3.0
        assert cropped.xoff == 12.0
        assert cropped.yoff == 21.0
        assert cropped.si_unit_xy == field.si_unit_xy
        assert cropped.si_unit_z == field.si_unit_z
        assert len(overlays) == 1
        assert overlays[0]["kind"] == "crop_box"
        assert overlays[0]["image"].startswith("data:image/png;base64,")
        assert overlays[0]["a_locked"] is False
        assert overlays[0]["b_locked"] is False

        # Explicit corner_a/corner_b define the same crop region as above (and
        # are reported as locked); only the pixel count changes with the
        # requested width (3x4 -> 6x8 with target_height=0).
        resized, = node.process(
            field,
            x1=0.0,
            y1=0.0,
            x2=1.0,
            y2=1.0,
            target_width=8,
            target_height=0,
            interpolation="bilinear",
            corner_a=(0.25, 0.25),
            corner_b=(0.75, 1.0),
        )
        assert resized.data.shape == (6, 8)
        assert resized.xreal == cropped.xreal
        assert resized.yreal == cropped.yreal
        assert resized.xoff == cropped.xoff
        assert resized.yoff == cropped.yoff
        assert resized.domain == field.domain
        assert overlays[-1]["a_locked"] is True
        assert overlays[-1]["b_locked"] is True

        # Swapped corners must select the same region.
        reversed_crop, = node.process(
            field,
            x1=0.75,
            y1=1.0,
            x2=0.25,
            y2=0.25,
            target_width=0,
            target_height=0,
            interpolation="nearest",
        )
        assert np.array_equal(reversed_crop.data, cropped.data)

        # A zero-width region (x1 == x2) is rejected.
        try:
            node.process(
                field,
                x1=0.9,
                y1=0.0,
                x2=0.9,
                y2=1.0,
                target_width=0,
                target_height=0,
                interpolation="nearest",
            )
        except ValueError:
            pass
        else:
            raise AssertionError("Expected invalid crop bounds to raise ValueError")

    print(" PASS\n")
|
|
|
|
|
|
def test_rotate_field():
    """RotateField: exact quarter/half turns, and canvas growth for an arbitrary angle."""
    print("=== Test: RotateField ===")
    from backend.nodes.modify import RotateField

    rotator = RotateField()

    grid = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
    field = DataField(
        data=grid,
        xreal=6.0,
        yreal=4.0,
        xoff=10.0,
        yoff=20.0,
        si_unit_xy="nm",
        si_unit_z="nm",
    )

    def centre(f):
        """Physical centre (x, y) of a field from its offset and extent."""
        return f.xoff + f.xreal / 2.0, f.yoff + f.yreal / 2.0

    # 90 degrees: pixels follow np.rot90, the extents swap, and the offsets
    # shift so that the centre (13, 22) stays put.
    (quarter,) = rotator.process(field, angle=90.0, interpolation="nearest", expand_canvas=True)
    assert np.array_equal(quarter.data, np.rot90(grid))
    assert quarter.data.shape == (3, 2)
    assert quarter.xreal == 4.0
    assert quarter.yreal == 6.0
    assert quarter.xoff == 11.0
    assert quarter.yoff == 19.0
    assert quarter.si_unit_xy == field.si_unit_xy
    assert quarter.si_unit_z == field.si_unit_z

    # 180 degrees without canvas expansion: same footprint, data flipped both ways.
    (half,) = rotator.process(field, angle=180.0, interpolation="nearest", expand_canvas=False)
    assert np.array_equal(half.data, np.rot90(grid, 2))
    assert half.data.shape == grid.shape
    assert half.xreal == field.xreal
    assert half.yreal == field.yreal
    assert half.xoff == field.xoff
    assert half.yoff == field.yoff

    # 45 degrees with expansion: the canvas becomes the rotated bounding box.
    (diag,) = rotator.process(field, angle=45.0, interpolation="bilinear", expand_canvas=True)
    theta = np.deg2rad(45.0)
    cos_t, sin_t = abs(np.cos(theta)), abs(np.sin(theta))
    box_x = field.xreal * cos_t + field.yreal * sin_t
    box_y = field.xreal * sin_t + field.yreal * cos_t
    assert diag.data.shape[0] > field.data.shape[0]
    assert diag.data.shape[1] > field.data.shape[1]
    assert np.isclose(diag.xreal, box_x)
    assert np.isclose(diag.yreal, box_y)
    (cx_new, cy_new), (cx_old, cy_old) = centre(diag), centre(field)
    assert np.isclose(cx_new, cx_old)
    assert np.isclose(cy_new, cy_old)

    print(" PASS\n")
|
|
|
|
|
|
def test_colormap_adjust():
    """ColormapAdjust changes only display metadata; the raw data is untouched."""
    print("=== Test: ColormapAdjust ===")
    from backend.nodes.modify import ColormapAdjust

    adjuster = ColormapAdjust()
    ramp = DataField(
        data=np.array([[0.0, 0.25, 0.5, 0.75, 1.0]], dtype=np.float64),
        xreal=5.0,
        yreal=1.0,
        colormap="gray",
    )

    (windowed,) = adjuster.process(ramp, offset=0.25, scale=0.5)
    assert np.array_equal(windowed.data, ramp.data)
    assert windowed.display_offset == 0.25
    assert windowed.display_scale == 0.5
    assert windowed.colormap == ramp.colormap

    # With offset 0.25 / scale 0.5 the rendered grey levels clip at both ends
    # and the middle sample lands near mid-grey.
    grey = datafield_to_uint8(windowed, "gray")[0, :, 0]
    assert grey[0] == 0
    assert grey[1] == 0
    assert 110 <= grey[2] <= 145
    assert grey[3] == 255
    assert grey[4] == 255

    # Offset 0 / scale 1 spans the full data range.
    (full_range,) = adjuster.process(ramp, offset=0.0, scale=1.0)
    full_grey = datafield_to_uint8(full_range, "gray")[0, :, 0]
    assert full_grey[0] == 0
    assert full_grey[-1] == 255

    try:
        adjuster.process(ramp, offset=0.0, scale=0.0)
    except ValueError:
        pass
    else:
        raise AssertionError("Expected non-positive scale to raise ValueError")

    print(" PASS\n")
|
|
|
|
|
|
def test_edge_detect():
    """Every EdgeDetect method should respond most strongly at a vertical step."""
    print("=== Test: EdgeDetect ===")
    from backend.nodes.filters import EdgeDetect

    detector = EdgeDetect()

    # Step from 0 to 1 between columns 31 and 32.
    step = np.zeros((64, 64))
    step[:, 32:] = 1.0
    field = make_field(data=step)

    for method in ("sobel", "prewitt", "laplacian", "log"):
        (edges,) = detector.process(field, method=method, sigma=1.0)
        assert edges.data.shape == field.data.shape
        # The column with the largest summed |response| should sit on the step.
        peak_col = np.argmax(np.abs(edges.data).sum(axis=0))
        assert abs(peak_col - 32) <= 2, f"{method}: peak at col {peak_col}, expected ~32"

    print(" PASS\n")
|
|
|
|
|
|
def test_fft_filter_1d():
    """FFTFilter1D separates a 3-cycle and an 80-cycle sine with each filter type."""
    print("=== Test: FFTFilter1D ===")
    from backend.nodes.filters import FFTFilter1D

    fft_filter = FFTFilter1D()

    def corr(a, b):
        """Pearson correlation coefficient of two 1D signals."""
        return np.corrcoef(a, b)[0, 1]

    n = 256
    t = np.arange(n, dtype=np.float64) / n
    slow = np.sin(2 * np.pi * 3 * t)   # 3 cycles over the record
    fast = np.sin(2 * np.pi * 80 * t)  # 80 cycles over the record
    mixed = slow + fast

    # Lowpass: slow survives, fast is removed.
    (lp,) = fft_filter.process(mixed, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
    assert len(lp) == n
    r_slow, r_fast = corr(lp, slow), corr(lp, fast)
    assert r_slow > 0.95, f"Lowpass: correlation with low={r_slow}"
    assert abs(r_fast) < 0.3, f"Lowpass: correlation with high={r_fast}"

    # Highpass: fast survives, slow is removed.
    (hp,) = fft_filter.process(mixed, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
    r_slow, r_fast = corr(hp, slow), corr(hp, fast)
    assert abs(r_slow) < 0.3, f"Highpass: correlation with low={r_slow}"
    assert r_fast > 0.95, f"Highpass: correlation with high={r_fast}"

    # Bandpass around the fast component.
    (bp,) = fft_filter.process(mixed, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
    r_slow, r_fast = corr(bp, slow), corr(bp, fast)
    assert abs(r_slow) < 0.3, f"Bandpass: correlation with low={r_slow}"
    assert r_fast > 0.9, f"Bandpass: correlation with high={r_fast}"

    # Notch (band-reject) over the same band: the fast component is removed.
    (notch,) = fft_filter.process(mixed, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
    r_slow, r_fast = corr(notch, slow), corr(notch, fast)
    assert r_slow > 0.95, f"Notch: correlation with low={r_slow}"
    assert abs(r_fast) < 0.3, f"Notch: correlation with high={r_fast}"

    print(" PASS\n")
|
|
|
|
|
|
def test_fft_filter_2d():
    """FFTFilter2D separates low/high 2D sinusoids and preserves a constant field under lowpass."""
    print("=== Test: FFTFilter2D ===")
    from backend.nodes.filters import FFTFilter2D

    fft_filter = FFTFilter2D()

    def flat_corr(a, b):
        """Pearson correlation of two arrays after flattening."""
        return np.corrcoef(a.ravel(), b.ravel())[0, 1]

    size = 128
    rows, cols = np.mgrid[0:size, 0:size] / size
    # 3-cycle pattern plus 40-cycle pattern along both axes.
    slow = np.sin(2 * np.pi * 3 * cols) + np.sin(2 * np.pi * 3 * rows)
    fast = np.sin(2 * np.pi * 40 * cols) + np.sin(2 * np.pi * 40 * rows)
    field = make_field(data=slow + fast, xreal=1e-6, yreal=1e-6)

    # Lowpass keeps geometry and units, keeps the slow pattern, drops the fast one.
    (lp,) = fft_filter.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
    assert lp.data.shape == (size, size)
    assert lp.xreal == field.xreal
    assert lp.si_unit_z == field.si_unit_z
    r_slow, r_fast = flat_corr(lp.data, slow), flat_corr(lp.data, fast)
    assert r_slow > 0.9, f"2D lowpass: correlation with low={r_slow}"
    assert abs(r_fast) < 0.3, f"2D lowpass: correlation with high={r_fast}"

    # Highpass does the opposite.
    (hp,) = fft_filter.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
    r_slow, r_fast = flat_corr(hp.data, slow), flat_corr(hp.data, fast)
    assert abs(r_slow) < 0.3, f"2D highpass: correlation with low={r_slow}"
    assert r_fast > 0.9, f"2D highpass: correlation with high={r_fast}"

    # A constant field is pure DC, so lowpass must leave it unchanged.
    flat = make_field(data=np.full((32, 32), 7.0))
    (flat_lp,) = fft_filter.process(flat, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
    assert np.allclose(flat_lp.data, 7.0, atol=1e-10), "Lowpass should preserve constant field"

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Level
|
|
# =========================================================================
|
|
|
|
def test_plane_level():
    """PlaneLevelField removes a linear tilt while keeping the superimposed signal."""
    print("=== Test: PlaneLevelField ===")
    from backend.nodes.level import PlaneLevelField

    leveler = PlaneLevelField()

    size = 64
    ys, xs = np.mgrid[0:size, 0:size] / size
    ripple = np.sin(2 * np.pi * 5 * xs)
    tilted = make_field(data=100 * xs + 50 * ys + ripple)

    (flat,) = leveler.process(tilted)
    assert flat.data.shape == tilted.data.shape
    # Removing the fitted plane should leave a zero-mean surface ...
    assert abs(flat.data.mean()) < 1e-10
    # ... in which the sine ripple is still clearly present.
    corr = np.corrcoef(flat.data.ravel(), ripple.ravel())[0, 1]
    assert corr > 0.98, f"Signal correlation after leveling: {corr}"

    print(" PASS\n")
|
|
|
|
|
|
def test_poly_level():
    """PolyLevelField splits a field into leveled signal + fitted polynomial background."""
    print("=== Test: PolyLevelField ===")
    from backend.nodes.level import PolyLevelField

    leveler = PolyLevelField()

    size = 64
    ys, xs = np.mgrid[0:size, 0:size] / size
    # Quadratic background with a cross term, plus an 8-cycle ripple.
    bowl = 50 * xs**2 + 30 * ys**2 + 10 * xs * ys
    ripple = np.sin(2 * np.pi * 8 * xs)
    field = make_field(data=bowl + ripple)

    flattened, background = leveler.process(field, degree_x=2, degree_y=2)
    assert flattened.data.shape == field.data.shape
    assert background.data.shape == field.data.shape
    # The two outputs must add back up to the input.
    assert np.allclose(flattened.data + background.data, field.data, atol=1e-10)
    corr = np.corrcoef(flattened.data.ravel(), ripple.ravel())[0, 1]
    assert corr > 0.95, f"Signal correlation after poly leveling: {corr}"

    # Degree 0 in both directions only removes the mean.
    flattened_0, _ = leveler.process(field, degree_x=0, degree_y=0)
    assert abs(flattened_0.data.mean()) < 1e-10

    print(" PASS\n")
|
|
|
|
|
|
def test_fix_zero():
    """FixZero shifts the field so its min / mean / median becomes zero."""
    print("=== Test: FixZero ===")
    from backend.nodes.level import FixZero

    zeroer = FixZero()
    field = make_field(data=np.array([[10, 20], [30, 40]], dtype=np.float64))

    # "min": subtract 10, so the range becomes 0..30.
    (by_min,) = zeroer.process(field, method="min")
    assert by_min.data.min() == 0.0
    assert by_min.data.max() == 30.0

    (by_mean,) = zeroer.process(field, method="mean")
    assert abs(by_mean.data.mean()) < 1e-10

    (by_median,) = zeroer.process(field, method="median")
    assert abs(np.median(by_median.data)) < 1e-10

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis (non-FFT)
|
|
# =========================================================================
|
|
|
|
def test_statistics():
    """StatisticsNode reports basic moments; a constant field has zero RMS/skew/kurtosis."""
    print("=== Test: StatisticsNode ===")
    from backend.nodes.analysis import StatisticsNode

    stats_node = StatisticsNode()

    def as_mapping(rows):
        """Index a statistics table by its "quantity" column."""
        return {row["quantity"]: row["value"] for row in rows}

    values = np.array([[1, 2], [3, 4]], dtype=np.float64)
    (rows,) = stats_node.process(make_field(data=values))
    stats = as_mapping(rows)

    assert stats["min"] == 1.0
    assert stats["max"] == 4.0
    assert stats["mean"] == 2.5
    assert stats["median"] == 2.5
    assert stats["range"] == 3.0
    # RMS here is the root-mean-square deviation from the mean.
    rms_expected = np.sqrt(np.mean((values - 2.5) ** 2))
    assert abs(stats["RMS"] - rms_expected) < 1e-10

    (const_rows,) = stats_node.process(make_field(data=np.full((4, 4), 5.0)))
    const_stats = as_mapping(const_rows)
    assert const_stats["RMS"] == 0.0
    assert const_stats["skewness"] == 0.0
    assert const_stats["kurtosis"] == 0.0

    print(" PASS\n")
|
|
|
|
|
|
def test_height_histogram():
    """HeightHistogram emits A/B cursor measurements and a line-plot overlay.

    The class-level overlay hooks are patched only for the duration of the
    test and restored afterwards, even if an assertion fails.
    """
    print("=== Test: HeightHistogram ===")
    from unittest import mock

    from backend.nodes.analysis import HeightHistogram

    node = HeightHistogram()

    # Evenly spaced heights give a roughly flat histogram.
    data = np.linspace(0, 1, 1000).reshape(25, 40)
    field = make_field(data=data)

    overlays = []
    with mock.patch.multiple(
        HeightHistogram,
        create=True,
        _broadcast_overlay_fn=lambda nid, payload: overlays.append(payload),
        _current_node_id="test",
    ):
        table, = node.process(
            field,
            n_bins=10,
            y_scale="linear",
            x1=0.2,
            y1=0.5,
            x2=0.8,
            y2=0.5,
        )

    measurements = {row["quantity"]: row for row in table}
    for key in ("A position", "A count", "B position", "B count", "delta X", "delta Y"):
        assert key in measurements, key
    assert measurements["A count"]["unit"] == "count"
    assert measurements["B count"]["unit"] == "count"
    # Cursor B is placed to the right of cursor A.
    assert measurements["B position"]["value"] > measurements["A position"]["value"]

    assert len(overlays) == 1
    overlay = overlays[0]
    assert overlay["kind"] == "line_plot"
    assert overlay["section_title"] == "Histogram"
    assert len(overlay["line"]) == 10
    assert len(overlay["x_axis"]) == 10
    assert np.isclose(overlay["x1"], 0.2)
    assert np.isclose(overlay["x2"], 0.8)
    assert np.isclose(
        measurements["delta Y"]["value"],
        measurements["B count"]["value"] - measurements["A count"]["value"],
    )

    print(" PASS\n")
|
|
|
|
|
|
def test_cross_section():
    """CrossSection samples a height profile along a line, optionally extended to the edges."""
    print("=== Test: CrossSection ===")
    from backend.nodes.analysis import CrossSection

    sampler = CrossSection()

    # Horizontal ramp: height = 10 * (column fraction).
    size = 100
    _, cols = np.mgrid[0:size, 0:size] / size
    field = make_field(data=cols * 10.0, xreal=1e-6, yreal=1e-6)

    def profile_between(x1, y1, x2, y2, extend, n_samples):
        """Run the node on *field* and unwrap its single output."""
        (line,) = sampler.process(
            field, x1=x1, y1=y1, x2=x2, y2=y2,
            extend=extend, n_samples=n_samples,
        )
        return line

    # Full-width horizontal cut at mid-height: a ramp from ~0 to ~10.
    profile = profile_between(0.0, 0.5, 1.0, 0.5, "none", 100)
    assert len(profile) == 100
    assert profile[0] < 0.5, f"Start of profile: {profile[0]}"
    assert profile[-1] > 9.5, f"End of profile: {profile[-1]}"

    # n_samples=0 lets the node pick a sample count.
    assert len(profile_between(0.0, 0.5, 1.0, 0.5, "none", 0)) >= 2

    # A short central segment extended to the edges covers the whole ramp.
    extended = profile_between(0.3, 0.5, 0.7, 0.5, "to_edges", 100)
    assert extended[0] < 0.5
    assert extended[-1] > 9.5

    # Diagonal cut honours the requested sample count.
    assert len(profile_between(0.0, 0.0, 1.0, 1.0, "none", 50)) == 50

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Grains
|
|
# =========================================================================
|
|
|
|
def test_threshold_mask():
    """ThresholdMask: absolute/relative/Otsu thresholds, both directions, plus a PNG preview.

    The class-level preview hooks are patched only for the duration of the
    test and restored afterwards, even if an assertion fails.
    """
    print("=== Test: ThresholdMask ===")
    from unittest import mock

    from backend.nodes.mask import ThresholdMask

    node = ThresholdMask()

    # Bimodal data: left half = 0, right half = 1.
    data = np.zeros((64, 64))
    data[:, 32:] = 1.0
    field = make_field(data=data)

    previews = []
    with mock.patch.multiple(
        ThresholdMask,
        create=True,
        _broadcast_fn=lambda nid, uri: previews.append(uri),
        _current_node_id="test",
    ):
        # Absolute threshold at 0.5: the mask is uint8 with 0 / 255 values.
        mask, = node.process(field, method="absolute", threshold=0.5, direction="above")
        assert mask.dtype == np.uint8
        assert mask.shape == (64, 64)
        assert np.all(mask[:, :32] == 0)
        assert np.all(mask[:, 32:] == 255)

        # One PNG data-URI preview was broadcast.
        assert len(previews) == 1
        assert previews[0].startswith("data:image/png;base64,")

        mask_below, = node.process(field, method="absolute", threshold=0.5, direction="below")
        assert np.all(mask_below[:, :32] == 255)
        assert np.all(mask_below[:, 32:] == 0)

        # Relative threshold at 0.5 of the data range.
        mask_rel, = node.process(field, method="relative", threshold=0.5, direction="above")
        assert np.all(mask_rel[:, 32:] == 255)

        # Otsu should find a threshold between the two modes.
        mask_otsu, = node.process(field, method="otsu", threshold=0.0, direction="above")
        assert mask_otsu[:, 32:].sum() > mask_otsu[:, :32].sum()

    print(" PASS\n")
|
|
|
|
|
|
def test_mask_morphology():
    """MaskMorphology dilate/erode/open/close on an 8x8 square, with square and disk elements."""
    print("=== Test: MaskMorphology ===")
    from backend.nodes.mask import MaskMorphology

    morph = MaskMorphology()

    def apply(src, operation, radius, shape):
        """Run one morphology operation and unwrap the single output."""
        (out,) = morph.process(src, operation=operation, radius=radius, shape=shape)
        return out

    block = np.zeros((64, 64), dtype=np.uint8)
    block[28:36, 28:36] = 255  # 8x8 foreground square in the middle
    area = np.count_nonzero(block)

    grown = apply(block, "dilate", 1, "square")
    assert grown.dtype == np.uint8
    assert np.count_nonzero(grown) > area

    assert np.count_nonzero(apply(block, "erode", 1, "square")) < area

    # Opening a clean square is expected not to add pixels.
    assert np.count_nonzero(apply(block, "open", 1, "square")) <= area

    # Closing should fill a single-pixel hole.
    holed = block.copy()
    holed[32, 32] = 0
    assert np.count_nonzero(holed) == area - 1
    closed = apply(holed, "close", 1, "square")
    assert closed[32, 32] == 255, "Close should fill the 1-pixel hole"

    assert np.count_nonzero(apply(block, "dilate", 2, "disk")) > area

    print(" PASS\n")
|
|
|
|
|
|
def test_mask_invert():
    """MaskInvert swaps 0 <-> 255, and inverting twice is the identity."""
    print("=== Test: MaskInvert ===")
    from backend.nodes.mask import MaskInvert

    inverter = MaskInvert()

    original = np.zeros((64, 64), dtype=np.uint8)
    original[10:20, 10:20] = 255

    (flipped,) = inverter.process(original)
    assert flipped.dtype == np.uint8
    assert np.all(flipped[10:20, 10:20] == 0)
    assert np.all(flipped[0:10, 0:10] == 255)

    (restored,) = inverter.process(flipped)
    assert np.array_equal(restored, original)

    print(" PASS\n")
|
|
|
|
|
|
def test_mask_combine():
    """MaskCombine and/or/xor/subtract on two partially overlapping squares."""
    print("=== Test: MaskCombine ===")
    from backend.nodes.mask import MaskCombine

    combiner = MaskCombine()

    a = np.zeros((64, 64), dtype=np.uint8)
    a[10:30, 10:30] = 255  # 20x20
    b = np.zeros((64, 64), dtype=np.uint8)
    b[20:40, 20:40] = 255  # 20x20, sharing a 10x10 overlap with a

    # Probe locations: (15, 15) a-only, (35, 35) b-only, (25, 25) overlap,
    # (5, 5) outside both; the slice is the full overlap region.
    expectations = {
        "and": [(np.s_[20:30, 20:30], 255), ((15, 15), 0), ((35, 35), 0)],
        "or": [((15, 15), 255), ((35, 35), 255), ((25, 25), 255), ((5, 5), 0)],
        "xor": [((15, 15), 255), ((35, 35), 255), ((25, 25), 0)],
        "subtract": [((15, 15), 255), ((25, 25), 0), ((35, 35), 0)],
    }
    for operation, probes in expectations.items():
        (combined,) = combiner.process(a, b, operation=operation)
        for where, expected in probes:
            assert np.all(combined[where] == expected)

    print(" PASS\n")
|
|
|
|
|
|
def test_draw_mask():
    """DrawMask rasterises JSON pen paths into a uint8 mask and broadcasts a paint overlay.

    The class-level overlay hooks are patched only for the duration of the
    test and restored afterwards, even if an assertion fails.
    """
    print("=== Test: DrawMask ===")
    from unittest import mock

    from backend.nodes.mask import DrawMask

    node = DrawMask()

    field = make_field(data=np.zeros((32, 32), dtype=np.float64))
    overlays = []

    # One horizontal stroke across the middle row, in fractional coordinates.
    mask_paths = [
        {
            "size": 5,
            "points": [
                {"x": 0.2, "y": 0.5},
                {"x": 0.8, "y": 0.5},
            ],
        }
    ]

    with mock.patch.multiple(
        DrawMask,
        create=True,
        _broadcast_overlay_fn=lambda nid, payload: overlays.append(payload),
        _current_node_id="test",
    ):
        mask, = node.process(field, pen_size=2, invert=False, mask_paths=json.dumps(mask_paths))
        assert mask.dtype == np.uint8
        assert mask.shape == (32, 32)
        assert mask[16, 16] == 255
        # A pixel two rows above the stroke is covered (the stroke has width).
        assert mask[14, 16] == 255
        assert mask[0, 0] == 0

        assert len(overlays) == 1
        overlay = overlays[0]
        assert overlay["kind"] == "mask_paint"
        assert overlay["section_title"] == "Mask"
        assert overlay["image"].startswith("data:image/png;base64,")
        assert overlay["image_width"] == field.xres
        assert overlay["image_height"] == field.yres
        assert overlay["invert"] is False

        inverted, = node.process(field, pen_size=2, invert=True, mask_paths=json.dumps(mask_paths))
        assert inverted[16, 16] == 0
        assert inverted[0, 0] == 255
        assert overlays[-1]["invert"] is True

        # No paths -> empty mask regardless of pen size.
        cleared, = node.process(field, pen_size=12, invert=False, mask_paths="[]")
        assert np.count_nonzero(cleared) == 0

    print(" PASS\n")
|
|
|
|
|
|
def test_particle_analysis():
    """ParticleAnalysis measures each masked blob and filters by minimum pixel area.

    The result table is ordered with ``sorted()`` instead of ``table.sort()``.
    This avoids mutating the node's output and does not require it to be a
    mutable ``list``.
    """
    print("=== Test: ParticleAnalysis ===")
    # NOTE(review): module name "particless" looks like a typo of "particles";
    # confirm against the backend package before renaming.
    from backend.nodes.particless import ParticleAnalysis

    node = ParticleAnalysis()

    # Two flat-topped blocks of known size and height.
    N = 64
    data = np.zeros((N, N))
    data[5:15, 5:15] = 5.0    # particle 1: 10x10, height 5
    data[45:53, 45:53] = 3.0  # particle 2: 8x8, height 3
    field = make_field(data=data, xreal=1e-6, yreal=1e-6)

    # Mask that selects exactly those two blocks.
    mask = np.zeros((N, N), dtype=np.uint8)
    mask[5:15, 5:15] = 255
    mask[45:53, 45:53] = 255

    table, = node.process(field, mask=mask, min_size=10)
    assert len(table) == 2, f"Expected 2 particles, got {len(table)}"

    largest, smaller = sorted(table, key=lambda r: r["area_px"], reverse=True)
    assert largest["area_px"] == 100  # 10x10
    assert smaller["area_px"] == 64   # 8x8
    assert abs(largest["mean_height"] - 5.0) < 1e-10
    assert abs(smaller["mean_height"] - 3.0) < 1e-10

    # min_size=80 keeps only the 100-pixel particle.
    table_filtered, = node.process(field, mask=mask, min_size=80)
    assert len(table_filtered) == 1
    assert table_filtered[0]["area_px"] == 100

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# I/O
|
|
# =========================================================================
|
|
|
|
def test_load_file():
    """LoadFile turns greyscale/RGB PNGs and .npy arrays into a single float64 DataField.

    Test images are built with ``Image.fromarray`` without the ``mode``
    argument. Pillow infers "L" for 2-D uint8 input and "RGB" for
    (H, W, 3) uint8 input, and passing ``mode`` explicitly is deprecated as
    of Pillow 11.3.
    """
    print("=== Test: LoadFile ===")
    from PIL import Image

    from backend.nodes.io import LoadFile

    node = LoadFile()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Greyscale PNG -> one DataField.
        arr = np.random.default_rng(1).integers(0, 256, (48, 64), dtype=np.uint8)
        path = os.path.join(tmpdir, "test_gray.png")
        Image.fromarray(arr).save(path)

        result = node.load(filename=path)
        assert len(result) == 1
        field = result[0]
        assert field.data.shape == (48, 64)
        assert field.data.dtype == np.float64

        # RGB PNG -> collapsed to a single 2-D channel.
        arr_rgb = np.random.default_rng(2).integers(0, 256, (32, 32, 3), dtype=np.uint8)
        path_rgb = os.path.join(tmpdir, "test_rgb.png")
        Image.fromarray(arr_rgb).save(path_rgb)

        result_rgb = node.load(filename=path_rgb)
        assert len(result_rgb) == 1
        assert result_rgb[0].data.shape == (32, 32)

        # .npy -> values round-trip unchanged.
        data_npy = np.random.default_rng(3).standard_normal((50, 60))
        path_npy = os.path.join(tmpdir, "test.npy")
        np.save(path_npy, data_npy)

        result_npy = node.load(filename=path_npy)
        assert len(result_npy) == 1
        assert np.allclose(result_npy[0].data, data_npy)

    print(" PASS\n")
|
|
|
|
|
|
def test_save_image():
    """SaveImage writes one or more DataField layers as multi-page TIFF or NPZ.

    Every file opened for verification is closed with a context manager so
    the temporary directory can always be removed (open handles block the
    deletion on Windows). The "must raise" checks use ``raise AssertionError``
    so they are not stripped under ``python -O``, matching the other tests in
    this module.
    """
    print("=== Test: SaveImage (Save Layers) ===")
    from PIL import Image

    from backend.nodes.io import SaveImage

    node = SaveImage()

    field_a = make_field(data=np.random.default_rng(4).random((32, 32)))
    field_b = make_field(data=np.random.default_rng(5).random((32, 32)))

    with tempfile.TemporaryDirectory() as tmpdir:
        # One layer as TIFF -> a single 32x32 page.
        tiff_path = os.path.join(tmpdir, "out.tiff")
        node.save(filename=tiff_path, format="TIFF", field_0=field_a)
        assert os.path.exists(tiff_path), "TIFF file not created"
        with Image.open(tiff_path) as im:
            assert im.n_frames == 1
            assert np.array(im).shape == (32, 32)

        # Two layers as TIFF -> one page per layer.
        tiff_path2 = os.path.join(tmpdir, "multi.tiff")
        node.save(filename=tiff_path2, format="TIFF", field_0=field_a, field_1=field_b)
        with Image.open(tiff_path2) as im2:
            assert im2.n_frames == 2

        # NPZ stores each layer's data under layer_0, layer_1, ...
        npz_path = os.path.join(tmpdir, "out.npz")
        node.save(filename=npz_path, format="NPZ", field_0=field_a, field_1=field_b)
        assert os.path.exists(npz_path)
        with np.load(npz_path) as npz:
            assert len(npz.files) == 2
            assert np.allclose(npz["layer_0"], field_a.data)
            assert np.allclose(npz["layer_1"], field_b.data)

        # The file extension is rewritten to match the chosen format.
        wrong_ext = os.path.join(tmpdir, "output.png")
        node.save(filename=wrong_ext, format="TIFF", field_0=field_a)
        assert os.path.exists(os.path.join(tmpdir, "output.tiff"))

        # No fields connected -> ValueError.
        try:
            node.save(filename=os.path.join(tmpdir, "empty.tiff"), format="TIFF")
        except ValueError:
            pass
        else:
            raise AssertionError("Should have raised ValueError")

        # Empty filename -> ValueError.
        try:
            node.save(filename="", format="TIFF", field_0=field_a)
        except ValueError:
            pass
        else:
            raise AssertionError("Should have raised ValueError")

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Display (limited testing — these are output nodes with WS callbacks)
|
|
# =========================================================================
|
|
|
|
def test_preview_image():
    """PreviewImage broadcasts one PNG data URI per preview, for DataFields and raw images.

    The class-level broadcast hooks are patched only for the duration of the
    test and restored afterwards, even if an assertion fails.
    """
    print("=== Test: PreviewImage ===")
    from unittest import mock

    from backend.nodes.display import PreviewImage

    node = PreviewImage()
    captured = []

    with mock.patch.multiple(
        PreviewImage,
        create=True,
        _broadcast_fn=lambda node_id, data_uri: captured.append(data_uri),
        _current_node_id="test",
    ):
        # DataField input.
        node.preview(colormap="viridis", field=make_field())
        assert len(captured) == 1
        assert captured[0].startswith("data:image/png;base64,")

        # Raw uint8 IMAGE array input.
        captured.clear()
        arr = np.random.default_rng(5).integers(0, 256, (32, 32), dtype=np.uint8)
        node.preview(colormap="gray", image=arr)
        assert len(captured) == 1

    print(" PASS\n")
|
|
|
|
|
|
def test_print_table():
    """PrintTable forwards the incoming rows unchanged to the table broadcast hook."""
    print("=== Test: PrintTable ===")
    from backend.nodes.display import PrintTable

    node = PrintTable()

    captured = []
    PrintTable._broadcast_table_fn = lambda node_id, rows: captured.append(rows)
    PrintTable._current_node_id = "test"
    try:
        table = [{"quantity": "test", "value": 42.0, "unit": "m"}]
        node.print_table(table=table)
        assert len(captured) == 1
        assert captured[0] == table
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        PrintTable._broadcast_table_fn = None

    print(" PASS\n")
|
|
|
|
|
|
def test_value_display():
    """ValueDisplay returns and broadcasts a scalar, or one named MeasureTable row."""
    print("=== Test: ValueDisplay ===")
    from backend.nodes.display import ValueDisplay

    node = ValueDisplay()
    captured = []
    ValueDisplay._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
    ValueDisplay._current_node_id = "test"
    try:
        # Plain scalar: returned as a 1-tuple and broadcast without a unit.
        result = node.display_value(3.25)
        assert result == (3.25,)
        assert captured == [("test", {"value": 3.25})]

        # MeasureTable: the selected row's value is returned, and its unit is broadcast.
        measurements = MeasureTable([
            {"quantity": "delta X", "value": 1.7e-7, "unit": "m"},
            {"quantity": "delta Y", "value": 463, "unit": "count"},
        ])
        result = node.display_value(measurements, measurement="delta X")
        assert result == (1.7e-7,)
        assert captured[-1] == ("test", {"value": 1.7e-7, "unit": "m"})
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        ValueDisplay._broadcast_value_fn = None

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# I/O — IBW multi-channel loading
|
|
# =========================================================================
|
|
|
|
def test_load_file_ibw():
    """LoadFile splits the demo IBW file into four DataField channels with physical units."""
    print("=== Test: LoadFile IBW multi-channel ===")
    from backend.nodes.io import LoadFile

    node = LoadFile()
    ibw_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
    )
    if not os.path.exists(ibw_path):
        print(" SKIP (demo IBW file not found)\n")
        return

    result = node.load(filename=ibw_path)

    # BR_New20012.ibw has 4 channels
    assert len(result) == 4, f"Expected 4 channels, got {len(result)}"

    for i, field in enumerate(result):
        assert isinstance(field, DataField), f"Channel {i} is not a DataField"
        assert field.data.shape == (512, 1024), f"Channel {i} shape: {field.data.shape}"
        assert field.data.dtype == np.float64, f"Channel {i} dtype: {field.data.dtype}"
        # Physical extent must be populated with a non-degenerate size.
        # NOTE: this lower bound alone does not distinguish file metadata
        # from a 1e-6 m fallback value.
        assert field.xreal > 1e-8, f"Channel {i} xreal too small: {field.xreal}"
        assert field.yreal > 1e-8, f"Channel {i} yreal too small: {field.yreal}"
        assert field.si_unit_xy == "m", f"Channel {i} xy unit: {field.si_unit_xy}"
        assert field.si_unit_z == "m", f"Channel {i} z unit: {field.si_unit_z}"

    # All channels should share the same physical dimensions (checked for every
    # channel, not just the first pair).
    for i, field in enumerate(result[1:], start=1):
        assert field.xreal == result[0].xreal, f"Channel {i} xreal differs from channel 0"
        assert field.yreal == result[0].yreal, f"Channel {i} yreal differs from channel 0"

    # Different channels should have different data
    assert not np.array_equal(result[0].data, result[1].data)

    print(" PASS\n")
|
|
|
|
|
|
def test_load_file_npz():
    """A single-array .npz archive loads as exactly one field with identical values."""
    print("=== Test: LoadFile .npz ===")
    from backend.nodes.io import LoadFile

    loader = LoadFile()
    with tempfile.TemporaryDirectory() as workdir:
        expected = np.random.default_rng(99).standard_normal((30, 40))
        archive = os.path.join(workdir, "test.npz")
        np.savez(archive, my_array=expected)

        fields = loader.load(filename=archive)
        assert len(fields) == 1
        assert np.allclose(fields[0].data, expected)

    print(" PASS\n")
|
|
|
|
|
|
def test_load_file_not_found():
    """Loading a path that does not exist raises FileNotFoundError."""
    print("=== Test: LoadFile not found ===")
    from backend.nodes.io import LoadFile

    node = LoadFile()
    try:
        node.load(filename="/nonexistent/path/file.png")
    except FileNotFoundError:
        pass
    else:
        # Raise explicitly: a bare `assert False` is stripped under `python -O`,
        # which would let this test pass without checking anything.
        raise AssertionError("Should have raised FileNotFoundError")

    print(" PASS\n")
|
|
|
|
|
|
def test_load_file_unsupported():
    """Loading a file with an unrecognised extension (.xyz) raises some error."""
    print("=== Test: LoadFile unsupported format ===")
    from backend.nodes.io import LoadFile

    node = LoadFile()
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "test.xyz")
        with open(path, "w", encoding="utf-8") as f:
            f.write("hello")
        try:
            node.load(filename=path)
        except Exception:
            # Any error type is acceptable here; the loader just must not
            # silently accept the unsupported format.
            pass
        else:
            # Must live outside the `try`: an AssertionError raised inside it
            # would be swallowed by `except Exception`, making the test vacuous.
            raise AssertionError("Should have raised an error for .xyz")

    print(" PASS\n")
|
|
|
|
|
|
def test_load_file_warning():
    """Loading a plain PNG emits exactly one warning mentioning 'Uncalibrated'."""
    print("=== Test: LoadFile warning for uncalibrated data ===")
    from backend.nodes.io import LoadFile
    from PIL import Image

    node = LoadFile()
    # Named to avoid shadowing the stdlib `warnings` module.
    received_warnings = []
    LoadFile._broadcast_warning_fn = lambda nid, msg: received_warnings.append(msg)
    LoadFile._current_node_id = "test"
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8)
            img = Image.fromarray(arr)
            path = os.path.join(tmpdir, "test.png")
            img.save(path)

            result = node.load(filename=path)
            assert len(result) == 1
            assert len(received_warnings) == 1
            assert "Uncalibrated" in received_warnings[0]
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        LoadFile._broadcast_warning_fn = None

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# I/O — list_channels helper
|
|
# =========================================================================
|
|
|
|
def test_list_channels():
    """list_channels: per-channel entries for IBW, one default 'field' entry otherwise.

    Covers a missing file, the demo IBW (when present), a plain PNG and a bare .npy.
    """
    print("=== Test: list_channels ===")
    from backend.nodes.io import list_channels

    # A missing file falls back to the single default channel.
    channels = list_channels("/nonexistent/file.ibw")
    assert len(channels) == 1
    assert channels[0]["name"] == "field"

    # The demo IBW exposes one DATA_FIELD entry per stored channel.
    demo_ibw = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw")
    )
    if os.path.exists(demo_ibw):
        channels = list_channels(demo_ibw)
        assert len(channels) == 4
        channel_names = [entry["name"] for entry in channels]
        assert "HeightRetrace" in channel_names
        assert "AmplitudeRetrace" in channel_names
        assert all(entry["type"] == "DATA_FIELD" for entry in channels)

    # A plain PNG yields the single default channel.
    with tempfile.TemporaryDirectory() as workdir:
        from PIL import Image
        blank = Image.fromarray(np.zeros((8, 8), dtype=np.uint8))
        png_path = os.path.join(workdir, "test.png")
        blank.save(png_path)

        channels = list_channels(png_path)
        assert len(channels) == 1
        assert channels[0]["name"] == "field"

    # Likewise for a bare .npy array.
    with tempfile.TemporaryDirectory() as workdir:
        npy_path = os.path.join(workdir, "test.npy")
        np.save(npy_path, np.zeros((4, 4)))

        assert len(list_channels(npy_path)) == 1

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# I/O — LoadDemo
|
|
# =========================================================================
|
|
|
|
def test_load_demo():
    """LoadDemo loads bundled demo files by name and rejects unknown names."""
    print("=== Test: LoadDemo ===")
    from backend.nodes.io import LoadDemo

    node = LoadDemo()

    # Should be able to load a demo file by name
    result = node.load(name="nanoparticles.npy")
    assert len(result) >= 1
    assert isinstance(result[0], DataField)
    assert result[0].data.ndim == 2

    # IBW demo should return multiple channels
    result_ibw = node.load(name="whiskers.ibw")
    assert len(result_ibw) == 4
    for field in result_ibw:
        assert isinstance(field, DataField)

    # Non-existent demo should raise
    try:
        node.load(name="nonexistent_file.png")
    except FileNotFoundError:
        pass
    else:
        # Explicit raise: `assert False` would be stripped under `python -O`.
        raise AssertionError("Should have raised FileNotFoundError")

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# I/O — Coordinate
|
|
# =========================================================================
|
|
|
|
def test_coordinate():
    """Coordinate packs its x/y inputs into a single (x, y) tuple output."""
    print("=== Test: Coordinate ===")
    from backend.nodes.io import Coordinate

    coord = Coordinate()

    outputs = coord.process(x=0.3, y=0.7)
    assert len(outputs) == 1
    assert outputs[0] == (0.3, 0.7)

    # Edge values 0 and 1 pass through unchanged.
    for edge in (0.0, 1.0):
        assert coord.process(x=edge, y=edge)[0] == (edge, edge)

    print(" PASS\n")
|
|
|
|
|
|
def test_range_slider():
    """RangeSlider clamps its value into the given bounds, even if they are swapped."""
    print("=== Test: RangeSlider ===")
    from backend.nodes.io import RangeSlider

    slider = RangeSlider()

    # (min_value, max_value, value, expected single output)
    cases = [
        (0.0, 10.0, 3.25, 3.25),  # in range: unchanged
        (0.0, 10.0, 12.0, 10.0),  # above max: clamped
        (5.0, -1.0, 4.0, 4.0),    # reversed bounds still work
        (2.5, 2.5, 99.0, 2.5),    # equal bounds collapse to a fixed value
    ]
    for lo, hi, value, expected in cases:
        assert slider.process(min_value=lo, max_value=hi, value=value) == (expected,)

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis — LineCursors
|
|
# =========================================================================
|
|
|
|
def test_line_cursors():
    """LineCursors yields a 6-row measurement table and broadcasts a line_plot overlay."""
    print("=== Test: LineCursors ===")
    from backend.nodes.analysis import LineCursors

    node = LineCursors()

    # Create a simple linear ramp
    line = np.linspace(0, 10, 100).astype(np.float64)

    # Capture overlay
    overlays = []
    LineCursors._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
    LineCursors._current_node_id = "test"
    try:
        table, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5)

        # Should produce a 6-row table
        assert len(table) == 6
        quantities = {row["quantity"] for row in table}
        assert "A position" in quantities
        assert "B position" in quantities
        assert "delta X" in quantities
        assert "delta Y" in quantities

        # B should be at a later position than A
        a_pos = next(r["value"] for r in table if r["quantity"] == "A position")
        b_pos = next(r["value"] for r in table if r["quantity"] == "B position")
        assert b_pos > a_pos

        # Delta Y should reflect the height difference along the ramp
        dy = next(r["value"] for r in table if r["quantity"] == "delta Y")
        assert dy > 0  # ramp goes upward

        # Overlay should have been broadcast
        assert len(overlays) == 1
        assert overlays[0]["kind"] == "line_plot"
        assert len(overlays[0]["line"]) == len(line)
        assert len(overlays[0]["x_axis"]) == len(line)
        assert 0.0 <= overlays[0]["x1"] <= 1.0
        assert 0.0 <= overlays[0]["x2"] <= 1.0

        # With x_axis provided
        x_axis = np.linspace(0, 1, 100).astype(np.float64)
        table2, = node.process(line, x1=0.25, y1=0.5, x2=0.75, y2=0.5, x_axis=x_axis)
        assert len(table2) == 6
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        LineCursors._broadcast_overlay_fn = None

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis — FFT2D
|
|
# =========================================================================
|
|
|
|
def test_fft2d():
    """FFT2D: peak placement, output ranges/units, DC behaviour and window options."""
    print("=== Test: FFT2D ===")
    from backend.nodes.analysis import FFT2D

    node = FFT2D()

    # Pure single-frequency signal along x: peaks should appear at centre +/- freq.
    N = 64
    _, x = np.mgrid[0:N, 0:N] / N
    freq = 5
    data = np.sin(2 * np.pi * freq * x)
    field = make_field(data=data, xreal=1e-6, yreal=1e-6)

    # First output: log-magnitude spectrum
    spectrum, _, _, _ = node.process(field, windowing="none", level="none")
    assert spectrum.data.shape == (N, N)
    assert spectrum.domain == "frequency"
    assert spectrum.si_unit_xy == "1/m"

    # Peak should be symmetric about centre: the input is real, so the
    # spectrum magnitude is mirror-symmetric and both sides must peak.
    centre = N // 2
    row = spectrum.data[centre, :]
    peak_idx = np.argmax(row[centre + 1:]) + centre + 1
    assert abs(peak_idx - (centre + freq)) <= 1, f"Peak at {peak_idx}, expected ~{centre + freq}"
    mirror_idx = np.argmax(row[:centre])
    assert abs(mirror_idx - (centre - freq)) <= 1, (
        f"Mirror peak at {mirror_idx}, expected ~{centre - freq}"
    )

    # magnitude output
    _, spec_mag, _, _ = node.process(field, windowing="hann", level="mean")
    assert spec_mag.data.shape == (N, N)
    assert np.all(spec_mag.data >= 0)

    # phase output
    _, _, spec_phase, _ = node.process(field, windowing="none", level="none")
    assert spec_phase.data.shape == (N, N)
    assert spec_phase.data.min() >= -np.pi - 0.01
    assert spec_phase.data.max() <= np.pi + 0.01

    # psdf output — units should reflect PSDF calibration
    _, _, _, spec_psdf = node.process(field, windowing="hamming", level="plane")
    assert spec_psdf.data.shape == (N, N)
    assert np.all(spec_psdf.data >= 0)
    assert "^2" in spec_psdf.si_unit_z

    # Constant field should have all energy at DC
    const_field = make_field(data=np.ones((32, 32)) * 3.0)
    _, spec_const, _, _ = node.process(const_field, windowing="none", level="none")
    centre32 = 16
    dc_val = spec_const.data[centre32, centre32]
    assert dc_val == spec_const.data.max(), "DC should be the maximum for constant field"

    # Blackman windowing should also work without error
    spec_bk, _, _, _ = node.process(field, windowing="blackman", level="none")
    assert spec_bk.data.shape == (N, N)

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis — LineMath
|
|
# =========================================================================
|
|
|
|
def test_line_math():
    """LineMath: basic statistics, roughness (Ra..Rt), constant-line and slope metrics."""
    print("=== Test: LineMath ===")
    from backend.nodes.analysis import LineMath

    node = LineMath()
    line = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

    def first_value(samples, op):
        # Each operation yields a one-element output holding a table; read row 0.
        (result_table,) = node.process(samples, operation=op)
        return result_table[0]["value"]

    # Basic statistics, compared exactly.
    exact_expectations = {
        "min": 1.0,
        "max": 5.0,
        "mean": 3.0,
        "median": 3.0,
        "sum": 15.0,
        "range": 4.0,
        "length": 5.0,
    }
    for op, want in exact_expectations.items():
        assert first_value(line, op) == want

    # RMS, roughness (relative to the mean) and slope metrics, within 1e-10.
    centred = line - line.mean()
    slopes = np.diff(line)
    approx_expectations = {
        "rms": np.sqrt(np.mean(line ** 2)),
        "Ra": float(np.mean(np.abs(centred))),
        "Rq": float(np.sqrt(np.mean(centred ** 2))),
        "Rp": centred.max(),                     # highest peak above the mean
        "Rv": -centred.min(),                    # deepest valley below the mean
        "Rt": centred.max() - centred.min(),     # Rp + Rv
    }
    for op, want in approx_expectations.items():
        assert abs(first_value(line, op) - want) < 1e-10

    # A constant line has zero roughness, skewness and kurtosis.
    flat = np.ones(10) * 7.0
    for op in ("Ra", "Rq", "Rsk", "Rku"):
        assert first_value(flat, op) == 0.0

    # Slope-based parameters computed from first differences.
    slope_expectations = {
        "Dq": float(np.sqrt(np.mean(slopes * slopes))),
        "Da": float(np.mean(np.abs(slopes))),
    }
    for op, want in slope_expectations.items():
        assert abs(first_value(line, op) - want) < 1e-10

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis — TableMath
|
|
# =========================================================================
|
|
|
|
def test_table_math():
    """TableMath reduces one numeric column of a RecordTable and broadcasts the result."""
    print("=== Test: TableMath ===")
    from backend.nodes.analysis import TableMath

    node = TableMath()
    captured = []
    TableMath._broadcast_value_fn = lambda node_id, value: captured.append((node_id, value))
    TableMath._current_node_id = "test"
    try:
        # Mixed column: floats, a numeric string and a non-numeric string.
        table = RecordTable([
            {"label": "a", "value": 1.0, "other": 10},
            {"label": "b", "value": 5.0, "other": 20},
            {"label": "c", "value": "3.0", "other": 30},
            {"label": "d", "value": "bad", "other": 40},
        ])

        result, = node.process(table, column="value", operation="max")
        assert result == 5.0
        assert captured[-1] == ("test", 5.0)

        result, = node.process(table, column="value", operation="min")
        assert result == 1.0

        result, = node.process(table, column="value", operation="avg")
        assert np.isclose(result, 3.0)

        result, = node.process(table, column="value", operation="median")
        assert np.isclose(result, 3.0)

        result, = node.process(table, column="other", operation="sum")
        assert result == 100.0

        result, = node.process(table, column="other", operation="count")
        assert result == 4.0

        # Blank column name should fall back to the common "value" column.
        result, = node.process(table, column="", operation="range")
        assert result == 4.0

        try:
            node.process(table, column="missing", operation="max")
            raise AssertionError("Expected missing numeric column to raise ValueError")
        except ValueError:
            pass

        try:
            node.process([{"label": "only text"}], column="label", operation="max")
            raise AssertionError("Expected non-numeric column to raise ValueError")
        except ValueError:
            pass

        try:
            node.process(
                MeasureTable([{"quantity": "A position", "value": 1.0, "unit": "m"}]),
                column="value",
                operation="max",
            )
            raise AssertionError("Expected measurement table input to raise ValueError")
        except ValueError:
            pass
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        TableMath._broadcast_value_fn = None

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Analysis — Stats
|
|
# =========================================================================
|
|
|
|
def test_stats():
    """Stats reduces lines, record tables, DataFields and images, broadcasting value/unit."""
    print("=== Test: Stats ===")
    from backend.nodes.analysis import Stats

    node = Stats()
    captured = []
    Stats._broadcast_value_fn = lambda node_id, payload: captured.append((node_id, payload))
    Stats._current_node_id = "test"
    try:
        # 1-D line: broadcast payload has no unit.
        line = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64)
        result, = node.process(line, operation="mean", column="value")
        assert np.isclose(result, 2.5)
        assert captured[-1] == ("test", {"value": result})

        # Record table: the row unit is forwarded in the payload.
        table = RecordTable([
            {"name": "a", "value": 3.0, "unit": "m", "other": 10.0},
            {"name": "b", "value": 7.0, "unit": "m", "other": 20.0},
        ])
        result, = node.process(table, operation="max", column="value")
        assert result == 7.0
        assert captured[-1] == ("test", {"value": 7.0, "unit": "m"})

        # DataField: the field's z unit is forwarded in the payload.
        field = make_field(data=np.array([[1.0, 5.0], [2.0, 4.0]], dtype=np.float64))
        result, = node.process(field, operation="range", column="value")
        assert result == 4.0
        assert captured[-1] == ("test", {"value": 4.0, "unit": "m"})

        # Raw image array: no unit.
        image = np.array([[0, 10], [20, 30]], dtype=np.uint8)
        result, = node.process(image, operation="avg", column="value")
        assert np.isclose(result, 15.0)
        assert captured[-1] == ("test", {"value": 15.0})

        try:
            node.process(table, operation="Rq", column="value")
            raise AssertionError("Expected invalid TABLE operation to raise ValueError")
        except ValueError:
            pass

        try:
            node.process(
                MeasureTable([{"quantity": "min", "value": 1.0, "unit": "m"}]),
                operation="max",
                column="value",
            )
            raise AssertionError("Expected measurement table input to raise ValueError")
        except ValueError:
            pass
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        Stats._broadcast_value_fn = None

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Display — View3D
|
|
# =========================================================================
|
|
|
|
def test_view3d():
    """View3D broadcasts a base64-encoded mesh and downsamples to the requested resolution."""
    print("=== Test: View3D ===")
    import base64

    from backend.nodes.display import View3D

    node = View3D()
    field = make_field()

    captured = []
    View3D._broadcast_mesh_fn = lambda nid, mesh: captured.append(mesh)
    View3D._current_node_id = "test"
    try:
        result = node.render(field, colormap="viridis", z_scale=2.0, resolution=64)
        assert result == ()
        assert len(captured) == 1

        mesh = captured[0]
        assert "width" in mesh
        assert "height" in mesh
        assert "z_data" in mesh
        assert "colors" in mesh
        assert mesh["z_scale"] == 2.0
        assert mesh["width"] <= 64
        assert mesh["height"] <= 64
        # z_min < z_max for non-constant data
        assert mesh["z_min"] < mesh["z_max"]

        # Verify base64 data can be decoded
        z_bytes = base64.b64decode(mesh["z_data"])
        assert len(z_bytes) == mesh["width"] * mesh["height"] * 4  # float32

        colors_bytes = base64.b64decode(mesh["colors"])
        assert len(colors_bytes) == mesh["width"] * mesh["height"] * 3  # uint8 RGB

        # High-res input should be downsampled
        big_field = make_field(shape=(256, 256))
        captured.clear()
        node.render(big_field, colormap="hot", z_scale=1.0, resolution=64)
        assert captured[0]["width"] <= 64
        assert captured[0]["height"] <= 64
    finally:
        # Restore the hook even if an assertion fails so later tests start clean.
        View3D._broadcast_mesh_fn = None

    print(" PASS\n")
|
|
|
|
|
|
# =========================================================================
|
|
# Run all tests
|
|
# =========================================================================
|
|
|
|
if __name__ == "__main__":
    # Every test runs in this fixed order, grouped by node category.
    suites = [
        ("Filters", [
            test_gaussian_filter,
            test_median_filter,
            test_crop_resize_field,
            test_rotate_field,
            test_colormap_adjust,
            test_edge_detect,
            test_fft_filter_1d,
            test_fft_filter_2d,
        ]),
        ("Level", [
            test_plane_level,
            test_poly_level,
            test_fix_zero,
        ]),
        ("Analysis", [
            test_statistics,
            test_height_histogram,
            test_cross_section,
            test_line_cursors,
            test_fft2d,
            test_line_math,
            test_table_math,
            test_stats,
        ]),
        ("Mask", [
            test_threshold_mask,
            test_mask_morphology,
            test_mask_invert,
            test_mask_combine,
            test_draw_mask,
        ]),
        ("Grains", [
            test_particle_analysis,
        ]),
        ("I/O", [
            test_load_file,
            test_load_file_ibw,
            test_load_file_npz,
            test_load_file_not_found,
            test_load_file_unsupported,
            test_load_file_warning,
            test_list_channels,
            test_load_demo,
            test_coordinate,
            test_range_slider,
            test_save_image,
        ]),
        ("Display", [
            test_preview_image,
            test_print_table,
            test_value_display,
            test_view3d,
        ]),
    ]
    for _category, suite_tests in suites:
        for run_test in suite_tests:
            run_test()

    print("All tests passed!")
|