# Source listing: tono/backend/nodes/curvature.py
# Last modified: 2026-03-28 22:34:10 -07:00 (385 lines, 12 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from scipy.ndimage import map_coordinates
from backend.data_types import (
DataField,
LineData,
RecordTable,
_apply_markup_overlay,
encode_preview,
render_datafield_preview,
)
from backend.execution_context import emit_preview, emit_table, emit_warning
from backend.node_registry import register_node
from backend.nodes.surface_common import require_compatible_xy_z_units
_CURVATURE_COLOR = "#ff9800"
_CENTER_COLOR = "#8bd3ff"
@dataclass(frozen=True)
class _Intersection:
t: float
x: float
y: float
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
if mask is None:
return None
mask_array = np.asarray(mask)
if mask_array.shape[:2] != shape:
raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
return mask_array > 127
def _canonicalize_half_pi(angle: float) -> float:
wrapped = (float(angle) + 0.5 * np.pi) % np.pi - 0.5 * np.pi
if wrapped <= -0.5 * np.pi + 1e-15:
wrapped += np.pi
return float(wrapped)
def _fit_quadratic_surface(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray | None:
yres, xres = data.shape
yy, xx = np.mgrid[0:yres, 0:xres]
x = 2.0 * xx.astype(np.float64) / max(xres - 1, 1) - 1.0
y = 2.0 * yy.astype(np.float64) / max(yres - 1, 1) - 1.0
valid = np.ones(data.shape, dtype=bool)
if mask is not None and masking != "ignore":
valid = mask if masking == "include" else ~mask
if np.count_nonzero(valid) < 6:
return None
design = np.column_stack([
np.ones(int(np.count_nonzero(valid)), dtype=np.float64),
x[valid],
x[valid] ** 2,
y[valid],
x[valid] * y[valid],
y[valid] ** 2,
])
coeffs, _, _, _ = np.linalg.lstsq(design, np.asarray(data, dtype=np.float64)[valid], rcond=None)
return np.asarray(coeffs, dtype=np.float64)
def _curvature_at_apex(coeffs: np.ndarray) -> tuple[int, float, float, float, float, float, float, float]:
    """Find the apex and principal curvatures of a quadratic surface.

    Args:
        coeffs: ``(a, bx, by, cxx, cxy, cyy)`` of
            ``z = a + bx*x + by*y + cxx*x**2 + cxy*x*y + cyy*y**2``.

    Returns:
        ``(degree, kappa1, kappa2, phi1, phi2, x_center, y_center, z_center)``:

        * ``degree`` - 0 for an (almost) planar surface, 1 when only one
          principal curvature is non-negligible, 2 otherwise.
        * ``kappa1 <= kappa2`` - principal curvatures, i.e. the eigenvalues of
          the Hessian.
        * ``phi1``, ``phi2`` - orthogonal principal directions in radians,
          folded into (-pi/2, pi/2]. The rotation angle is negated before
          folding.
        * ``x_center``, ``y_center``, ``z_center`` - the stationary point and
          the surface height there, in the coordinates of *coeffs*.
    """
    a, bx, by, cxx, cxy, cyy = [float(value) for value in coeffs]
    # Quadratic terms negligible relative to the slope: treat as a plane.
    if abs(cxx) + abs(cxy) + abs(cyy) <= 1e-14 * (abs(bx) + abs(by)):
        return 0, 0.0, 0.0, 0.0, float(0.5 * np.pi), 0.0, 0.0, a
    # Rotate by phi to remove the mixed xy term. cx and cy are then the Hessian
    # eigenvalues (cxx + cyy) +/- hypot(cxx - cyy, cxy).
    cm = cxx - cyy
    cp = cxx + cyy
    phi = 0.5 * float(np.arctan2(cxy, cm))
    radius = float(np.hypot(cm, cxy))
    cx = cp + radius
    cy = cp - radius
    cos_phi = float(np.cos(phi))
    sin_phi = float(np.sin(phi))
    # Linear coefficients expressed in the rotated frame.
    bx1 = bx * cos_phi + by * sin_phi
    by1 = -bx * sin_phi + by * cos_phi
    # In the rotated frame z = a + bx1*x' + by1*y' + (cx*x'**2 + cy*y'**2) / 2.
    # Its stationary point is x' = -bx1/cx, y' = -by1/cy. A vanishing curvature
    # has no stationary point along that axis, so that coordinate stays at 0.
    if abs(cx) < 1e-14 * abs(cy):
        xc = 0.0
        yc = -by1 / cy
        degree = 1
    elif abs(cy) < 1e-14 * abs(cx):
        xc = -bx1 / cx
        yc = 0.0
        degree = 1
    else:
        xc = -bx1 / cx
        yc = -by1 / cy
        degree = 2
    x_center = xc * cos_phi - yc * sin_phi
    y_center = xc * sin_phi + yc * cos_phi
    # The quadratic terms carry a factor 1/2 (cx, cy are second derivatives);
    # without it the apex height always collapsed back to ``a``.
    z_center = a + xc * bx1 + yc * by1 + 0.5 * (xc * xc * cx + yc * yc * cy)
    # Order the curvatures so kappa1 <= kappa2, rotating the first direction
    # by 90 degrees together with the swap.
    if cx > cy:
        cx, cy = cy, cx
        phi += 0.5 * np.pi
    phi = -phi
    phi1 = _canonicalize_half_pi(phi)
    phi2 = _canonicalize_half_pi(phi + 0.5 * np.pi)
    return degree, float(cx), float(cy), phi1, phi2, float(x_center), float(y_center), float(z_center)
def _compute_curvature_results(
    field: DataField,
    mask: np.ndarray | None,
    masking: str,
) -> dict[str, float] | None:
    """Fit a quadratic surface to *field* and summarise its overall curvature.

    Args:
        field: Height field. Reads ``data``, ``xres``, ``yres``, ``xreal``,
            ``yreal``, ``xoff`` and ``yoff``.
        mask: Optional boolean mask with the field's shape.
        masking: ``"ignore"``, ``"include"`` or ``"exclude"``; forwarded to the fit.

    Returns:
        ``None`` when the fit has too few usable pixels. Otherwise a dict with:

        * ``degree`` - 0, 1 or 2 (as float).
        * ``x0``, ``y0`` - apex position in physical coordinates, offsets included.
        * ``z0`` - fitted height at the apex.
        * ``r1``, ``r2`` - curvature radii; ``inf`` when a curvature is negligible.
        * ``phi1``, ``phi2`` - principal directions in radians.
    """
    coeffs = _fit_quadratic_surface(np.asarray(field.data, dtype=np.float64), mask, masking)
    if coeffs is None:
        return None
    xres = field.xres
    yres = field.yres
    xreal = float(field.xreal)
    yreal = float(field.yreal)
    # qx, qy convert physical lengths into the fit's normalized coordinates
    # (first..last pixel -> -1..+1). q is their geometric mean; rescaling by
    # mx / my makes both axes share that single isotropic scale.
    qx = 2.0 / xreal * xres / max(xres - 1.0, 1.0)
    qy = 2.0 / yreal * yres / max(yres - 1.0, 1.0)
    q = float(np.sqrt(qx * qy))
    mx = float(np.sqrt(qx / qy))
    my = float(np.sqrt(qy / qx))
    # Reorder the fit's (1, x, x^2, y, xy, y^2) coefficients into the
    # (a, bx, by, cxx, cxy, cyy) layout of _curvature_at_apex, expressed in the
    # isotropic coordinates u = x / mx, v = y / my. Since mx * my == 1, the mixed
    # term is unchanged.
    ccoeffs = np.array([
        coeffs[0],
        mx * coeffs[1],
        my * coeffs[3],
        mx * mx * coeffs[2],
        coeffs[4],
        my * my * coeffs[5],
    ], dtype=np.float64)
    degree, kappa1, kappa2, phi1, phi2, xc, yc, _ = _curvature_at_apex(ccoeffs)
    # Height at the apex, evaluated directly from the fitted polynomial in
    # normalized coordinates.
    x_norm = xc * mx
    y_norm = yc * my
    z0 = float(
        coeffs[0]
        + coeffs[1] * x_norm
        + coeffs[2] * x_norm * x_norm
        + coeffs[3] * y_norm
        + coeffs[4] * x_norm * y_norm
        + coeffs[5] * y_norm * y_norm
    )
    # Curvatures are per isotropic unit squared; q**2 converts them to physical units.
    r1 = float("inf") if abs(kappa1) <= 1e-14 else float(1.0 / (q * q * kappa1))
    r2 = float("inf") if abs(kappa2) <= 1e-14 else float(1.0 / (q * q * kappa2))
    # Isotropic coordinates are centred on the image; shift back to physical ones.
    x0 = float(xc / q + 0.5 * xreal + field.xoff)
    y0 = float(yc / q + 0.5 * yreal + field.yoff)
    return {
        "degree": float(degree),
        "x0": x0,
        "y0": y0,
        "z0": z0,
        "r1": r1,
        "r2": r2,
        "phi1": float(phi1),
        "phi2": float(phi2),
    }
def _line_intersections(
    x0: float,
    y0: float,
    phi: float,
    x_min: float,
    y_min: float,
    width: float,
    height: float,
) -> tuple[_Intersection, _Intersection] | None:
    """Clip the infinite line through ``(x0, y0)`` at angle *phi* to a rectangle.

    The line is ``(x0, y0) + t * (cos(phi), sin(phi))``, so ``t`` is a signed
    distance along the line. The rectangle is ``[x_min, x_min + width] x
    [y_min, y_min + height]``.

    Returns:
        The boundary hits with the smallest and the largest ``t``, with
        coordinates clipped into the rectangle. Returns ``None`` when fewer than
        two distinct hits exist (the line misses or only touches a corner).
    """
    dx = float(np.cos(phi))
    dy = float(np.sin(phi))
    # Direction components are dimensionless, so an absolute threshold is fine.
    direction_eps = 1e-12
    # Positions are in the field's physical units, which may be tiny (e.g.
    # metres for nanoscale scans). Scale the tolerances with the rectangle so
    # they do not swallow whole pixels. Unit-sized rectangles keep the former
    # absolute values 1e-12 / 1e-9.
    scale = max(abs(width), abs(height)) or 1.0
    edge_eps = 1e-12 * scale
    merge_eps = 1e-9 * scale
    x_max = x_min + width
    y_max = y_min + height
    points: list[_Intersection] = []

    def add_hit(t: float, x: float, y: float) -> None:
        points.append(_Intersection(float(t), float(np.clip(x, x_min, x_max)), float(np.clip(y, y_min, y_max))))

    if abs(dx) > direction_eps:
        for x in (x_min, x_max):
            t = (x - x0) / dx
            y = y0 + t * dy
            if y_min - edge_eps <= y <= y_max + edge_eps:
                add_hit(t, x, y)
    if abs(dy) > direction_eps:
        for y in (y_min, y_max):
            t = (y - y0) / dy
            x = x0 + t * dx
            if x_min - edge_eps <= x <= x_max + edge_eps:
                add_hit(t, x, y)
    # A line through a corner hits two edges at the same point; keep one of them.
    unique: list[_Intersection] = []
    for point in sorted(points, key=lambda item: item.t):
        if unique and abs(point.x - unique[-1].x) < merge_eps and abs(point.y - unique[-1].y) < merge_eps:
            continue
        unique.append(point)
    if len(unique) < 2:
        return None
    return unique[0], unique[-1]
def _profile_from_intersections(field: DataField, start: _Intersection, end: _Intersection) -> LineData:
    """Sample *field* along the straight segment from *start* to *end*.

    Physical positions are mapped to fractional pixel indices so that the
    field origin lands on the first pixel and ``xreal``/``yreal`` on the last.
    The field is bilinearly interpolated (edge values are extended) at about
    one sample per pixel of segment length, with at least two samples. The
    returned abscissa runs linearly from ``start.t`` to ``end.t``.
    """
    real_w = max(field.xreal, 1e-30)
    real_h = max(field.yreal, 1e-30)
    last_col = max(field.xres - 1, 0)
    last_row = max(field.yres - 1, 0)

    def to_pixels(point: _Intersection) -> tuple[float, float]:
        # (column, row) fractional pixel index of a physical point.
        return (
            (point.x - field.xoff) / real_w * last_col,
            (point.y - field.yoff) / real_h * last_row,
        )

    col_a, row_a = to_pixels(start)
    col_b, row_b = to_pixels(end)
    count = max(2, int(np.ceil(np.hypot(col_b - col_a, row_b - row_a))))
    frac = np.linspace(0.0, 1.0, count, dtype=np.float64)
    rows = row_a + frac * (row_b - row_a)
    cols = col_a + frac * (col_b - col_a)
    sampled = map_coordinates(field.data, [rows, cols], order=1, mode="nearest")
    return LineData(
        data=np.asarray(sampled, dtype=np.float64),
        x_axis=np.linspace(start.t, end.t, count, dtype=np.float64),
        x_unit=field.si_unit_xy,
        y_unit=field.si_unit_z,
    )
def _curvature_markup(
field: DataField,
center_x: float,
center_y: float,
intersections: list[tuple[_Intersection, _Intersection]],
) -> dict[str, object]:
shapes: list[dict[str, object]] = []
for start, end in intersections:
shapes.append({
"kind": "line",
"x1": (start.x - field.xoff) / max(field.xreal, 1e-30),
"y1": (start.y - field.yoff) / max(field.yreal, 1e-30),
"x2": (end.x - field.xoff) / max(field.xreal, 1e-30),
"y2": (end.y - field.yoff) / max(field.yreal, 1e-30),
"width": 3,
"color": _CURVATURE_COLOR,
})
if np.isfinite(center_x) and np.isfinite(center_y):
radius = 0.015
fx = (center_x - field.xoff) / max(field.xreal, 1e-30)
fy = (center_y - field.yoff) / max(field.yreal, 1e-30)
shapes.append({
"kind": "circle",
"x1": fx - radius,
"y1": fy - radius,
"x2": fx + radius,
"y2": fy + radius,
"width": 2,
"color": _CENTER_COLOR,
})
return {"kind": "markup", "shapes": shapes}
def _empty_profile(unit_xy: str, unit_z: str) -> LineData:
    """Return a zero-length profile that still carries the axis units.

    Used as a placeholder output when no cross-section could be extracted.
    """
    no_values = np.zeros(0, dtype=np.float64)
    no_positions = np.zeros(0, dtype=np.float64)
    return LineData(x_unit=unit_xy, y_unit=unit_z, data=no_values, x_axis=no_positions)
@register_node(display_name="Curvature")
class Curvature:
    """Node reporting the overall principal curvatures of a height field.

    It fits a quadratic surface, optionally restricted by a mask, and tabulates:

    * the curvature radii,
    * the apex position and height,
    * the principal directions.

    It also draws the principal cross-sections on an annotated copy of the
    field and returns the height profiles along both principal axes.
    """

    _CUSTOM_PREVIEW = True

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "masking": (["ignore", "include", "exclude"], {"default": "ignore"}),
            },
            "optional": {
                "mask": ("IMAGE",),
            },
        }

    OUTPUTS = (
        ('ANNOTATION_SOURCE', 'output'),
        ('RECORD_TABLE', 'measurements'),
        ('LINE', 'profile_x'),
        ('LINE', 'profile_y'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Fit a quadratic surface and report the overall principal curvature radii and directions, matching "
        "Gwyddion's curvature feature. The output annotation marks the principal cross-sections and the node "
        "also returns the two corresponding height profiles."
    )

    def process(
        self,
        field: DataField,
        masking: str,
        mask: np.ndarray | None = None,
    ) -> tuple:
        """Run the curvature analysis on *field*.

        Args:
            field: Height field to analyse.
            masking: ``"ignore"``, ``"include"`` or ``"exclude"``; selects how
                *mask* restricts the fit.
            mask: Optional mask image whose first two dimensions match the field.

        Returns:
            ``(output, table, profile_x, profile_y)``:

            * ``output`` - the field with the curvature markup appended to its
              overlays.
            * ``table`` - the measurement table; directions are in degrees.
            * ``profile_x``, ``profile_y`` - the profiles along the first and
              second principal axes. Each is empty when its axis misses the image.

        Raises:
            ValueError: If the mask shape does not match the field.
        """
        require_compatible_xy_z_units(field, "Curvature")
        mask_array = _normalize_mask(mask, field.data.shape)
        results = _compute_curvature_results(field, mask_array, masking)
        unit_xy = field.si_unit_xy
        unit_z = field.si_unit_z
        if results is None:
            emit_warning("Curvature requires at least six usable pixels for the quadratic fit.")
            table = RecordTable([])
            emit_table(table)
            emit_preview(encode_preview(render_datafield_preview(field, field.colormap)))
            # Separate objects so the two profile outputs never alias each other.
            return (field.replace(), table, _empty_profile(unit_xy, unit_z), _empty_profile(unit_xy, unit_z))

        # Clip each principal axis to the image. Keep one slot per axis so that a
        # missing first axis cannot shift the second axis into profile_x.
        # NOTE(review): the angle is negated before clipping (the apex routine
        # negates it too), presumably to map between the fit's and the image's
        # y orientation - confirm.
        hits: list[tuple[_Intersection, _Intersection] | None] = [
            _line_intersections(
                results["x0"],
                results["y0"],
                -results[angle_key],
                field.xoff,
                field.yoff,
                field.xreal,
                field.yreal,
            )
            for angle_key in ("phi1", "phi2")
        ]
        intersections = [hit for hit in hits if hit is not None]
        # Each profile is sampled from the larger-t end of its segment to the smaller-t end.
        profiles = [
            _empty_profile(unit_xy, unit_z) if hit is None else _profile_from_intersections(field, hit[1], hit[0])
            for hit in hits
        ]
        markup_spec = _curvature_markup(field, results["x0"], results["y0"], intersections)
        output = field.replace(overlays=[*field.overlays, markup_spec])
        table = RecordTable([
            {"quantity": "Curvature radius 1", "value": results["r1"], "unit": unit_xy},
            {"quantity": "Curvature radius 2", "value": results["r2"], "unit": unit_xy},
            {"quantity": "Center x position", "value": results["x0"], "unit": unit_xy},
            {"quantity": "Center y position", "value": results["y0"], "unit": unit_xy},
            {"quantity": "Center value", "value": results["z0"], "unit": unit_z},
            # The fit yields radians; convert so the values agree with the "deg" unit.
            {"quantity": "Direction 1", "value": float(np.degrees(results["phi1"])), "unit": "deg"},
            {"quantity": "Direction 2", "value": float(np.degrees(results["phi2"])), "unit": "deg"},
        ])
        preview_base = render_datafield_preview(field, field.colormap)
        panels: list[dict[str, object]] = [
            {
                "title": title,
                "kind": "line_plot",
                "line": profile.data.tolist(),
                "x_axis": profile.x_axis.tolist(),
                "x_unit": unit_xy,
            }
            for profile, title in zip(profiles, ("X Principal Axis", "Y Principal Axis"))
            if len(profile.data) > 0
        ]
        panels.append({
            "title": "Overview",
            "kind": "image",
            "image": encode_preview(_apply_markup_overlay(preview_base, field, markup_spec)),
        })
        emit_preview({"kind": "panels", "panels": panels})
        # NOTE(review): the failure path above emits its (empty) table, but this
        # path only returns the table (emit_table was commented out here).
        # Confirm the asymmetry is intended.
        if any(hit is None for hit in hits):
            emit_warning("Principal axes are outside the image.")
        return (output, table, profiles[0], profiles[1])