diff --git a/.coverage b/.coverage new file mode 100644 index 0000000..26c5629 Binary files /dev/null and b/.coverage differ diff --git a/backend/server.py b/backend/server.py index 193f6a7..095194f 100644 --- a/backend/server.py +++ b/backend/server.py @@ -74,6 +74,20 @@ class _SafeEncoder(json.JSONEncoder): return super().default(obj) +def _sanitize_non_finite(obj): + """Recursively replace non-finite floats so they survive JSON serialization.""" + if isinstance(obj, float): + if math.isnan(obj): + return "NaN" + if math.isinf(obj): + return "∞" if obj > 0 else "-∞" + elif isinstance(obj, dict): + return {k: _sanitize_non_finite(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_sanitize_non_finite(v) for v in obj] + return obj + + def _dumps(obj) -> str: return json.dumps(obj, cls=_SafeEncoder) @@ -190,7 +204,7 @@ def create_app( broadcast(session_id, {"type": "preview", "data": {"node_id": node_id, "image": data_uri}}) def on_table(session_id: str, node_id: str, rows: list) -> None: - broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": rows}}) + broadcast(session_id, {"type": "table", "data": {"node_id": node_id, "rows": _sanitize_non_finite(rows)}}) def on_mesh(session_id: str, node_id: str, mesh_data: dict) -> None: broadcast(session_id, {"type": "mesh3d", "data": {"node_id": node_id, "mesh": mesh_data}}) diff --git a/pyproject.toml b/pyproject.toml index e83b80c..5e45366 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest>=8,<9", + "pytest-cov>=7,<8", ] desktop = [ "pyinstaller>=6,<7", @@ -31,3 +32,11 @@ desktop = [ [tool.setuptools.packages.find] include = ["backend*"] + +[tool.coverage.run] +source = ["backend"] +omit = ["backend/nodes/__init__.py"] + +[tool.coverage.report] +show_missing = true +skip_covered = false diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 45491f6..26731f7 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -704,7 +704,7 @@ def test_curvature(): recovered_radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) expected_radii = sorted([rx, ry]) assert len(previews) == 1 - assert previews[0].startswith("data:image/png;base64,") + assert isinstance(previews[0], dict) and previews[0].get("kind") == "panels" assert len(tables) == 1 assert abs(rows["Center x position"]["value"] - x0) < xreal * 0.02 assert abs(rows["Center y position"]["value"] - y0) < yreal * 0.02 @@ -743,6 +743,127 @@ def test_curvature(): print(" PASS\n") +def test_curvature_flat_surface(): + """A perfectly flat surface has zero curvature — both radii must be float('inf').""" + print("=== Test: Curvature (flat surface → inf radii) ===") + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + data = np.zeros((64, 64), dtype=np.float64) + field = DataField(data=data, xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m") + + warnings = [] + tables = [] + with execution_callbacks( + preview=lambda nid, v: None, + table=lambda nid, rows: tables.append(rows), + warning=lambda nid, msg: warnings.append(msg), + ), active_node("test"): + _, table, _, _ = node.process(field, masking="ignore") + + rows = {row["quantity"]: row for row in table} + assert rows["Curvature radius 1"]["value"] == float("inf") + assert rows["Curvature radius 2"]["value"] == float("inf") + # No warnings expected for a valid (flat) surface + assert len(warnings) == 0 + print(" PASS\n") + + +def test_curvature_cylindrical(): + """A cylindrical surface is curved in one direction only — one radius finite, one inf.""" + print("=== Test: Curvature (cylindrical → one inf radius) ===") + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + N = 64 + xreal = yreal = 1e-6 + x = np.linspace(-xreal / 2, xreal / 2, N, dtype=np.float64) + xx = np.broadcast_to(x, (N, N)) + r_x = 0.8e-6 + # Curved parabolically in x, flat in y + data = xx**2 / (2.0 * r_x) + field = DataField(data=data, xreal=xreal, yreal=yreal, si_unit_xy="m", si_unit_z="m") + + tables = [] + with execution_callbacks( + preview=lambda nid, v: None, + table=lambda nid, rows: tables.append(rows), + ), active_node("test"): + _, table, _, _ = node.process(field, masking="ignore") + + rows = {row["quantity"]: row for row in table} + radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) + # One radius should be finite (≈ r_x), the other infinite + finite = [r for r in radii if np.isfinite(r)] + infinite = [r for r in radii if not np.isfinite(r)] + assert len(finite) == 1, f"Expected 1 finite radius, got {radii}" + assert len(infinite) == 1, f"Expected 1 inf radius, got {radii}" + assert abs(finite[0] - r_x) < r_x * 0.1, f"Finite radius {finite[0]} far from expected {r_x}" + print(" PASS\n") + + +def test_curvature_too_few_pixels(): + """Curvature with fewer than 6 valid pixels emits a warning and returns an empty table.""" + print("=== Test: Curvature (too few valid pixels) ===") + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + N = 16 + data = np.random.default_rng(0).standard_normal((N, N)) + field = DataField(data=data, xreal=1e-6, yreal=1e-6, si_unit_xy="m", si_unit_z="m") + + # Mask with only 4 'include' pixels — below the 6-pixel minimum + mask = np.zeros((N, N), dtype=np.uint8) + mask[N // 2, N // 2:N // 2 + 4] = 255 + + warnings = [] + tables = [] + with execution_callbacks( + preview=lambda nid, v: None, + table=lambda nid, rows: tables.append(rows), + warning=lambda nid, msg: warnings.append(msg), + ), active_node("test"): + _, table, profile1, profile2 = node.process(field, masking="include", mask=mask) + + assert len(warnings) == 1 + assert "six" in warnings[0].lower() or "6" in warnings[0] + assert len(list(table)) == 0 + # Empty profiles are returned + assert len(profile1.data) == 0 + assert len(profile2.data) == 0 + print(" PASS\n") + + +def test_curvature_inf_json_safe(): + """inf radii from curvature must not produce invalid JSON when sent over the wire.""" + print("=== Test: Curvature (inf radii → valid JSON via server sanitizer) ===") + import json + from backend.server import _sanitize_non_finite, _dumps + + # Simulate a table row as produced by the curvature node for a flat surface + rows = [ + {"quantity": "Curvature radius 1", "value": float("inf"), "unit": "m"}, + {"quantity": "Curvature radius 2", "value": float("-inf"), "unit": "m"}, + {"quantity": "Center value", "value": float("nan"), "unit": "m"}, + {"quantity": "Center x position", "value": 1.5e-7, "unit": "m"}, + ] + + sanitized = _sanitize_non_finite(rows) + assert sanitized[0]["value"] == "∞" + assert sanitized[1]["value"] == "-∞" + assert sanitized[2]["value"] == "NaN" + assert sanitized[3]["value"] == 1.5e-7 # finite float unchanged + + # Must not raise and must produce parseable JSON + payload = _dumps({"type": "table", "data": {"node_id": "n1", "rows": sanitized}}) + decoded = json.loads(payload) + assert decoded["data"]["rows"][0]["value"] == "∞" + print(" PASS\n") + + def test_line_correction(): print("=== Test: LineCorrection ===") from backend.node_registry import get_node_info