Files
tono/backend/nodes/grain_visualization.py

241 lines
8.5 KiB
Python

"""Grain visualization — visualize grains as geometric shapes."""
from __future__ import annotations
import numpy as np
from scipy.ndimage import label, find_objects, distance_transform_edt
from backend.node_registry import register_node
from backend.data_types import DataField
from backend.nodes.helpers import mask_to_bool, bool_to_mask
def _grain_centroid(grain_mask: np.ndarray, slc: tuple[slice, slice]) -> tuple[float, float]:
"""Return (cy, cx) centroid of a grain within its bounding slice."""
ys, xs = np.where(grain_mask[slc])
cy = float(ys.mean()) + slc[0].start
cx = float(xs.mean()) + slc[1].start
return cy, cx
def _grain_inscribed_radius(grain_mask: np.ndarray, slc: tuple[slice, slice]) -> float:
"""Return the inscribed disc radius for a grain region."""
region = grain_mask[slc]
if not np.any(region):
return 0.0
dt = distance_transform_edt(region)
return float(dt.max())
def _grain_inertia(grain_mask: np.ndarray, slc: tuple[slice, slice]) -> tuple[float, float, float]:
"""Return (semi_major, semi_minor, angle_rad) from the inertia tensor."""
ys, xs = np.where(grain_mask[slc])
cy_local = ys.mean()
cx_local = xs.mean()
dy = ys - cy_local
dx = xs - cx_local
n = len(ys)
if n < 2:
return 1.0, 1.0, 0.0
Ixx = np.sum(dy * dy) / n
Iyy = np.sum(dx * dx) / n
Ixy = -np.sum(dx * dy) / n
# Eigenvalues of the 2x2 inertia tensor
mean_I = (Ixx + Iyy) / 2.0
diff_I = (Ixx - Iyy) / 2.0
discriminant = max(0.0, diff_I ** 2 + Ixy ** 2)
sqrt_disc = np.sqrt(discriminant)
lambda1 = mean_I + sqrt_disc
lambda2 = mean_I - sqrt_disc
# Semi-axes proportional to sqrt of eigenvalues, scaled by 2 for visual size
semi_major = 2.0 * np.sqrt(max(lambda1, 0.0))
semi_minor = 2.0 * np.sqrt(max(lambda2, 0.0))
# Angle of the major axis
angle = 0.5 * np.arctan2(2.0 * Ixy, Iyy - Ixx)
return float(semi_major), float(semi_minor), float(angle)
def _draw_circle_filled(canvas: np.ndarray, cy: float, cx: float, r: float) -> None:
h, w = canvas.shape
y_lo = max(0, int(cy - r - 1))
y_hi = min(h, int(cy + r + 2))
x_lo = max(0, int(cx - r - 1))
x_hi = min(w, int(cx + r + 2))
yy, xx = np.ogrid[y_lo:y_hi, x_lo:x_hi]
dist_sq = (yy - cy) ** 2 + (xx - cx) ** 2
canvas[y_lo:y_hi, x_lo:x_hi] |= (dist_sq <= r * r)
def _draw_circle_outline(canvas: np.ndarray, cy: float, cx: float, r: float, thickness: float = 1.5) -> None:
h, w = canvas.shape
y_lo = max(0, int(cy - r - thickness - 1))
y_hi = min(h, int(cy + r + thickness + 2))
x_lo = max(0, int(cx - r - thickness - 1))
x_hi = min(w, int(cx + r + thickness + 2))
yy, xx = np.ogrid[y_lo:y_hi, x_lo:x_hi]
dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
canvas[y_lo:y_hi, x_lo:x_hi] |= (np.abs(dist - r) < thickness)
def _draw_rect_filled(canvas: np.ndarray, y0: int, y1: int, x0: int, x1: int) -> None:
h, w = canvas.shape
y0c, y1c = max(0, y0), min(h, y1)
x0c, x1c = max(0, x0), min(w, x1)
canvas[y0c:y1c, x0c:x1c] = True
def _draw_rect_outline(canvas: np.ndarray, y0: int, y1: int, x0: int, x1: int, thickness: int = 1) -> None:
h, w = canvas.shape
y0c, y1c = max(0, y0), min(h, y1)
x0c, x1c = max(0, x0), min(w, x1)
# Top edge
canvas[y0c:min(h, y0c + thickness), x0c:x1c] = True
# Bottom edge
canvas[max(0, y1c - thickness):y1c, x0c:x1c] = True
# Left edge
canvas[y0c:y1c, x0c:min(w, x0c + thickness)] = True
# Right edge
canvas[y0c:y1c, max(0, x1c - thickness):x1c] = True
def _draw_cross(canvas: np.ndarray, cy: float, cx: float, arm: int = 3) -> None:
h, w = canvas.shape
iy, ix = int(round(cy)), int(round(cx))
for d in range(-arm, arm + 1):
if 0 <= iy + d < h and 0 <= ix < w:
canvas[iy + d, ix] = True
if 0 <= iy < h and 0 <= ix + d < w:
canvas[iy, ix + d] = True
def _draw_ellipse_filled(canvas: np.ndarray, cy: float, cx: float,
semi_major: float, semi_minor: float, angle: float) -> None:
h, w = canvas.shape
r_max = max(semi_major, semi_minor, 1.0)
y_lo = max(0, int(cy - r_max - 1))
y_hi = min(h, int(cy + r_max + 2))
x_lo = max(0, int(cx - r_max - 1))
x_hi = min(w, int(cx + r_max + 2))
yy, xx = np.ogrid[y_lo:y_hi, x_lo:x_hi]
cos_a, sin_a = np.cos(angle), np.sin(angle)
dy = yy - cy
dx = xx - cx
# Rotate into ellipse-aligned coordinates
u = cos_a * dx + sin_a * dy
v = -sin_a * dx + cos_a * dy
a = max(semi_major, 0.5)
b = max(semi_minor, 0.5)
canvas[y_lo:y_hi, x_lo:x_hi] |= ((u / a) ** 2 + (v / b) ** 2 <= 1.0)
def _draw_ellipse_outline(canvas: np.ndarray, cy: float, cx: float,
semi_major: float, semi_minor: float, angle: float,
thickness: float = 1.5) -> None:
h, w = canvas.shape
r_max = max(semi_major, semi_minor, 1.0)
y_lo = max(0, int(cy - r_max - thickness - 1))
y_hi = min(h, int(cy + r_max + thickness + 2))
x_lo = max(0, int(cx - r_max - thickness - 1))
x_hi = min(w, int(cx + r_max + thickness + 2))
yy, xx = np.ogrid[y_lo:y_hi, x_lo:x_hi]
cos_a, sin_a = np.cos(angle), np.sin(angle)
dy = yy - cy
dx = xx - cx
u = cos_a * dx + sin_a * dy
v = -sin_a * dx + cos_a * dy
a = max(semi_major, 0.5)
b = max(semi_minor, 0.5)
ellipse_val = (u / a) ** 2 + (v / b) ** 2
canvas[y_lo:y_hi, x_lo:x_hi] |= (np.abs(np.sqrt(ellipse_val) - 1.0) < thickness / max(a, b))
# Styles accepted by GrainVisualization.process, in UI order.
_STYLES = ("inscribed_disc", "bounding_box", "centroid", "ellipse")


def _grain_inscribed_disc(local_mask: np.ndarray) -> tuple[float, float, float]:
    """Return ``(cy, cx, r)`` of the largest disc inscribed in a cropped grain.

    ``local_mask`` is the boolean mask of one grain cropped to its bounding
    box; the returned centre is in the crop's coordinates. The crop is padded
    with one background pixel on every side before the Euclidean distance
    transform, because a grain always touches its bounding box and the
    transform only measures distances to background pixels present in its
    input. The centre is the first (row-major) pixel of maximal distance,
    and ``r`` is that distance in pixels.
    """
    padded = np.pad(local_mask, 1)
    dist = distance_transform_edt(padded)
    py, px = np.unravel_index(int(np.argmax(dist)), dist.shape)
    return float(py - 1), float(px - 1), float(dist[py, px])


@register_node(display_name="Grain Visualization")
class GrainVisualization:
    """Node drawing one geometric marker per connected grain of a mask.

    ``process`` returns the drawn shapes as a mask image (``bool_to_mask``)
    and ``field`` with its data replaced by the integer grain labels as
    float64 (0 marks background, as assigned by ``scipy.ndimage.label``).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "mask": ("IMAGE",),
                "style": (list(_STYLES), {"default": "inscribed_disc"}),
                "fill": ("BOOLEAN", {"default": False}),
            }
        }

    OUTPUTS = (
        ('IMAGE', 'result'),
        ('DATA_FIELD', 'labeled'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Visualize labeled grains as geometric shapes — inscribed discs, bounding boxes, "
        "centroid markers, or fitted ellipses. Produces a mask image with the chosen shapes "
        "and a labeled field where each grain has a unique integer value. "
        "Equivalent to Gwyddion's grain selection visualization (grain_makesel)."
    )
    KEYWORDS = ("ellipse", "disc", "bounding box", "centroid", "inscribed", "label")

    def process(self, field: DataField, mask: np.ndarray, style: str, fill: bool) -> tuple:
        """Draw the chosen shape for every connected grain of ``mask``.

        Args:
            field: data field whose ``replace(data=...)`` builds the labeled output.
            mask: grain mask image, converted with ``mask_to_bool``.
            style: one of "inscribed_disc", "bounding_box", "centroid", "ellipse".
            fill: draw filled shapes instead of outlines (unused for "centroid").

        Returns:
            ``(result, labeled_field)``.

        Raises:
            ValueError: for an unknown ``style``. The check runs before any work,
                so it also fires for masks that contain no grains.
        """
        if style not in _STYLES:
            raise ValueError(f"Unknown visualization style: {style!r}")

        mask_bool = mask_to_bool(mask)
        labels, _ = label(mask_bool.astype(np.int32))
        slices = find_objects(labels)
        h, w = mask_bool.shape[:2]
        canvas = np.zeros((h, w), dtype=bool)
        # The centroid marker size depends only on the image size.
        arm = max(3, int(round(min(h, w) * 0.01)))

        for gid, slc in enumerate(slices, start=1):
            if slc is None:
                continue
            y0, y1 = slc[0].start, slc[0].stop
            x0, x1 = slc[1].start, slc[1].stop

            if style == "bounding_box":
                if fill:
                    _draw_rect_filled(canvas, y0, y1, x0, x1)
                else:
                    _draw_rect_outline(canvas, y0, y1, x0, x1)
                continue

            # Restrict per-grain work to the bounding box; comparing the full
            # label image for every grain would cost O(n_grains * h * w).
            local = labels[slc] == gid

            if style == "inscribed_disc":
                # Centre the disc where it is actually inscribed; the centroid
                # can lie near the edge of (or outside) a concave grain.
                ly, lx, r = _grain_inscribed_disc(local)
                cy, cx = ly + y0, lx + x0
                r = max(r, 0.5)
                if fill:
                    _draw_circle_filled(canvas, cy, cx, r)
                else:
                    _draw_circle_outline(canvas, cy, cx, r)
                continue

            # Helpers take (mask, slice); the crop spans the whole local array,
            # so local results are shifted back by the bounding-box origin.
            local_slc = (slice(0, y1 - y0), slice(0, x1 - x0))
            ly, lx = _grain_centroid(local, local_slc)
            cy, cx = ly + y0, lx + x0

            if style == "centroid":
                _draw_cross(canvas, cy, cx, arm)
            else:  # "ellipse"
                semi_major, semi_minor, angle = _grain_inertia(local, local_slc)
                semi_major = max(semi_major, 0.5)
                semi_minor = max(semi_minor, 0.5)
                if fill:
                    _draw_ellipse_filled(canvas, cy, cx, semi_major, semi_minor, angle)
                else:
                    _draw_ellipse_outline(canvas, cy, cx, semi_major, semi_minor, angle)

        result = bool_to_mask(canvas)
        labeled_field = field.replace(data=labels.astype(np.float64))
        return (result, labeled_field)