Files
tono/backend/nodes/grain_distance_transform.py

148 lines
5.1 KiB
Python

from __future__ import annotations
from functools import lru_cache
import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt
from backend.data_types import DataField
from backend.node_registry import register_node
def _normalize_mask(mask: np.ndarray) -> np.ndarray:
    """Convert a 2-D mask image into a boolean foreground array.

    Boolean input is taken as-is (a copy is returned so callers never share
    storage with the result). Any other dtype is thresholded with ``> 127``,
    which suits 0-255 style images.

    Args:
        mask: Two-dimensional mask image.

    Returns:
        A new ``bool`` array with the same shape as ``mask``.

    Raises:
        ValueError: If ``mask`` is not two-dimensional.
    """
    data = np.asarray(mask)
    if data.ndim != 2:
        raise ValueError("Grain Distance Transform requires a 2-D mask.")
    if data.dtype == np.bool_:
        # ``True > 127`` is False, so thresholding would erase a boolean mask.
        return data.copy()
    return data > 127
def _prepare_mask(binary: np.ndarray, from_border: bool) -> tuple[np.ndarray, tuple[slice, slice]]:
    """Build the work array for a distance transform plus the slices to crop it back.

    With ``from_border`` true, the mask is returned unchanged and the crop is the
    full extent. The transforms in this module then treat the outside of the
    image as background, so image edges act as grain boundaries.

    With ``from_border`` false, image edges must not act as boundaries. The mask
    is embedded in a margin of foreground (``True``) pixels whose outermost
    one-pixel ring is background (``False``). That ring gives every transform a
    background to reach, and it is placed farther away than any real background
    pixel inside the image, so it never changes in-image distances.

    Args:
        binary: 2-D mask; converted to ``bool``.
        from_border: Whether image edges count as grain boundaries.

    Returns:
        ``(work, crop)``: the array to transform, and the pair of slices that
        recovers the original extent from ``work`` (or from any array shaped
        like ``work``).
    """
    binary = np.asarray(binary, dtype=bool)
    if from_border:
        return binary, (slice(None), slice(None))
    rows, cols = binary.shape
    # Every in-image pixel is at least ``pad`` away from the ring in one
    # coordinate, i.e. at least ``pad`` in chessboard, octagonal, Euclidean and
    # city-block metrics alike. The largest possible in-image distance is the
    # city-block diagonal (rows - 1) + (cols - 1), so pad = rows + cols keeps
    # the ring out of reach; max(shape) alone is too small for diagonal
    # distances. max(..., 1) keeps a valid ring for an empty (0x0) mask.
    pad = max(rows + cols, 1)
    padded = np.pad(binary, pad, mode="constant", constant_values=True)
    padded[0, :] = False
    padded[-1, :] = False
    padded[:, 0] = False
    padded[:, -1] = False
    return padded, (slice(pad, pad + rows), slice(pad, pad + cols))
@lru_cache(maxsize=32)
def _distance_structures() -> tuple[np.ndarray, np.ndarray]:
    """Return the cached, read-only 3x3 structuring elements.

    The first element is the 4-connected cross (city-block step); the second is
    the full 8-connected square (chessboard step). Both are frozen so no caller
    can corrupt the copies shared through the cache.
    """
    full = np.ones((3, 3), dtype=bool)
    plus = full.copy()
    # Knock out the four corners to turn the square into a cross.
    plus[[0, 0, 2, 2], [0, 2, 0, 2]] = False
    for element in (plus, full):
        element.setflags(write=False)
    return plus, full
def _simple_distance_transform(binary: np.ndarray, distance_type: str, from_border: bool) -> np.ndarray:
    """Compute an erosion-based distance transform of a boolean mask.

    Each binary erosion peels one layer off the foreground. A pixel removed by
    the k-th erosion receives distance ``k``. The structuring element used at
    each step depends on ``distance_type``:

    * ``"cityblock"`` -- always the 4-connected cross;
    * ``"chess"`` -- always the 8-connected square;
    * ``"octagonal48"`` -- cross, square, cross, ... (4-neighbourhood first);
    * ``"octagonal84"`` -- square, cross, square, ... (8-neighbourhood first).

    Args:
        binary: 2-D mask; ``True`` pixels receive distances.
        distance_type: One of the four names above.
        from_border: Whether image edges count as background (handled by
            ``_prepare_mask``).

    Returns:
        ``float64`` array with the input's shape; background pixels are 0.

    Raises:
        ValueError: If ``distance_type`` is not one of the supported names.
    """
    cross, square = _distance_structures()
    if distance_type == "cityblock":
        sequence = (cross,)
    elif distance_type == "chess":
        sequence = (square,)
    elif distance_type == "octagonal48":
        sequence = (cross, square)
    elif distance_type == "octagonal84":
        sequence = (square, cross)
    else:
        raise ValueError(f"Unsupported simple distance type: {distance_type}")

    work, crop = _prepare_mask(binary, from_border)
    result = np.zeros(work.shape, dtype=np.float64)
    current = work.copy()
    step = 0
    # Only result[crop] is returned. When from_border is false, ``work`` carries
    # a wide padding margin; stopping as soon as the cropped window is fully
    # assigned skips eroding the rest of that margin without changing the output.
    # Each pass removes at least one pixel (border_value=0), so the loop ends.
    while current[crop].any():
        structure = sequence[step % len(sequence)]
        step += 1
        eroded = binary_erosion(current, structure=structure, border_value=0)
        # Pixels removed by this erosion receive distance ``step``.
        result[current & ~eroded] = step
        current = eroded
    return result[crop]
def _euclidean_distance_transform(binary: np.ndarray, from_border: bool) -> np.ndarray:
    """Return the exact Euclidean distance of each foreground pixel to background.

    With ``from_border``, the mask gets a one-pixel ``False`` frame so image
    edges act as background. Otherwise ``_prepare_mask`` builds a padded work
    array in which edges do not count. The result is ``float64`` with the
    input's shape.
    """
    if not from_border:
        padded, window = _prepare_mask(binary, False)
    else:
        padded = np.pad(np.asarray(binary, dtype=bool), 1, mode="constant", constant_values=False)
        window = (slice(1, -1), slice(1, -1))
    distances = distance_transform_edt(padded)
    return np.asarray(distances, dtype=np.float64)[window]
def _distance_transform(binary: np.ndarray, distance_type: str, from_border: bool) -> np.ndarray:
    """Dispatch to the transform implementing ``distance_type``.

    * ``"euclidean"`` uses the exact EDT.
    * ``"octagonal"`` is the mean of the ``"octagonal48"`` and ``"octagonal84"``
      erosion transforms.
    * Every other name goes straight to ``_simple_distance_transform``, which
      rejects unknown names with ``ValueError``.
    """
    if distance_type == "euclidean":
        return _euclidean_distance_transform(binary, from_border)
    if distance_type != "octagonal":
        return _simple_distance_transform(binary, distance_type, from_border)
    # Mean of both alternation orders (4-neighbourhood first, 8-neighbourhood first).
    first, second = (
        _simple_distance_transform(binary, variant, from_border)
        for variant in ("octagonal48", "octagonal84")
    )
    return 0.5 * (first + second)
@register_node(display_name="Grain Distance Transform")
class GrainDistanceTransform:
    """Node that computes the distance transform of a mask over a data field.

    Distances are computed in pixels by the module-level helpers, then scaled by
    ``sqrt(field.dx * field.dy)``. The returned field takes ``field.si_unit_xy``
    as its value unit.
    """

    # Accepted ``output_type`` values, shared by INPUT_TYPES and process().
    _OUTPUT_TYPES = ("interior", "exterior", "signed")

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification (all inputs are required)."""
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "mask": ("IMAGE",),
                "distance_type": (
                    ["euclidean", "cityblock", "chess", "octagonal48", "octagonal84", "octagonal"],
                    {"default": "euclidean"},
                ),
                "output_type": (list(cls._OUTPUT_TYPES), {"default": "interior"}),
                "from_border": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("DATA_FIELD",)
    RETURN_NAMES = ("distance",)
    FUNCTION = "process"
    DESCRIPTION = (
        "Compute the mask distance transform using Gwyddion-style interior, exterior, or signed output. "
        "Supports Euclidean, city-block, chessboard, and octagonal distance variants, with optional "
        "image-boundary handling matching mask_edt."
    )

    def process(
        self,
        field: DataField,
        mask: np.ndarray,
        distance_type: str,
        output_type: str,
        from_border: bool,
    ) -> tuple:
        """Compute the requested distance map and return it as a new field.

        Args:
            field: Source field. It supplies ``dx``, ``dy`` and ``si_unit_xy``
                and is the template for the result via ``field.replace``.
            mask: 2-D mask image, binarised by ``_normalize_mask``.
            distance_type: Metric name understood by ``_distance_transform``.
            output_type: ``"interior"`` (distances of mask pixels to the nearest
                non-mask pixel), ``"exterior"`` (distances of non-mask pixels
                to the nearest mask pixel), or ``"signed"`` (interior minus
                exterior).
            from_border: Whether image edges count as boundaries.

        Returns:
            A one-element tuple holding the scaled distance field.

        Raises:
            ValueError: For an unsupported ``output_type`` or ``distance_type``,
                or a mask that is not 2-D.
        """
        # Fail fast, before any (potentially expensive) transform work.
        if output_type not in self._OUTPUT_TYPES:
            raise ValueError(f"Unsupported output type: {output_type}")
        binary = _normalize_mask(mask)
        from_border = bool(from_border)

        def masked_distance(region: np.ndarray) -> np.ndarray:
            # Zero every pixel outside ``region`` so only its pixels carry distances.
            values = _distance_transform(region, distance_type, from_border)
            values *= region
            return values

        # Compute only the transform(s) the requested output actually uses.
        if output_type == "interior":
            distance = masked_distance(binary)
        elif output_type == "exterior":
            distance = masked_distance(~binary)
        else:  # "signed"
            distance = masked_distance(binary) - masked_distance(~binary)

        scale = float(np.sqrt(field.dx * field.dy))
        result = field.replace(
            data=np.asarray(distance, dtype=np.float64) * scale,
            si_unit_z=field.si_unit_xy,
        )
        return (result,)