269 lines
9.0 KiB
Python
269 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
|
|
import numpy as np
|
|
from scipy.ndimage import label
|
|
|
|
from backend.execution_context import emit_preview
|
|
from backend.data_types import DataField, encode_preview
|
|
from backend.node_registry import register_node
|
|
from backend.nodes.helpers import _mask_overlay
|
|
|
|
|
|
def _working_height(field: DataField, invert_height: bool) -> np.ndarray:
    """Return a float64 working copy of ``field.data`` with non-finite samples neutralised.

    The drop passes always treat local *maxima* as grain tops, so valley
    detection is obtained by negating the heights (``invert_height=True``).

    NaN / +-inf samples are replaced by the lowest finite value of the working
    surface (0.0 when there is no finite sample at all). Left as-is, a single
    non-finite sample makes the drop scale computed from ``max - min``
    non-finite (which then spreads to every pixel through
    ``data -= dropsize * counts``). NaN comparisons also make the
    steepest-ascent successor search ill-defined.

    Args:
        field: Source height field; only its ``data`` attribute is read.
        invert_height: Negate heights so that valleys become peaks.

    Returns:
        A new float64 array; ``field.data`` itself is never modified.
    """
    data = np.asarray(field.data, dtype=np.float64)
    working = -data if invert_height else data.copy()
    invalid = ~np.isfinite(working)
    if invalid.any():
        finite = working[~invalid]
        # Missing samples become the lowest point so they can never be a grain top.
        working[invalid] = finite.min() if finite.size else 0.0
    return working
|
|
|
|
|
|
def _next_indices(data: np.ndarray) -> np.ndarray:
    """Return, for every pixel, the flat index of the pixel a drop moves to next.

    A drop climbs to its highest 4-connected neighbour. Off-grid neighbours
    count as ``-inf``. A pixel that is at least as high as all four neighbours
    points to itself. Ties between neighbours are resolved in the fixed order
    right, left, down, up.

    Args:
        data: 2-D height array.

    Returns:
        1-D int64 array of length ``data.size`` holding row-major flat indices.
    """
    n_rows, n_cols = data.shape
    own_index = np.arange(n_rows * n_cols, dtype=np.int64).reshape(n_rows, n_cols)

    def _span(step: int) -> tuple[slice, slice]:
        # (destination, source) slices along one axis for a neighbour `step` away.
        if step > 0:
            return slice(None, -step), slice(step, None)
        if step < 0:
            return slice(-step, None), slice(None, step)
        return slice(None), slice(None)

    def _neighbour(d_row: int, d_col: int) -> tuple[np.ndarray, np.ndarray]:
        # Height and flat index of the pixel at (row + d_row, col + d_col);
        # off-grid positions read as -inf and keep their own index.
        (dst_r, src_r), (dst_c, src_c) = _span(d_row), _span(d_col)
        values = np.full_like(data, -np.inf, dtype=np.float64)
        values[dst_r, dst_c] = data[src_r, src_c]
        indices = own_index.copy()
        indices[dst_r, dst_c] = own_index[src_r, src_c]
        return values, indices

    # List order is the tie-break priority: right, left, down, up.
    candidates = [_neighbour(0, 1), _neighbour(0, -1), _neighbour(1, 0), _neighbour(-1, 0)]

    is_peak = np.logical_and.reduce([data >= values for values, _ in candidates])
    successor = own_index.copy()
    undecided = ~is_peak

    for rank, (values, indices) in enumerate(candidates[:-1]):
        wins = undecided & (values >= data)
        for other_rank, (other_values, _) in enumerate(candidates):
            if other_rank != rank:
                wins &= values >= other_values
        successor[wins] = indices[wins]
        undecided &= ~wins

    # Anything still undecided falls through to the last direction (up).
    up_indices = candidates[-1][1]
    successor[undecided] = up_indices[undecided]
    return successor.ravel()
|
|
|
|
|
|
def _terminal_indices(data: np.ndarray) -> np.ndarray:
    """Return, for every pixel, the flat index of the local maximum its drop ends on.

    Follows the successor graph from ``_next_indices`` to its fixed points by
    pointer jumping. Each round doubles the number of steps every pointer has
    taken, so the cost is O(N log N) instead of O(N * path length).

    Args:
        data: 2-D height array.

    Returns:
        1-D int64 array of flat indices, length ``data.size``.

    Raises:
        ValueError: if the pointers do not reach fixed points, i.e. the
            successor graph contains a cycle. This can happen when ``data``
            holds NaN values, because NaN comparisons break the strictly
            ascending successor rule.
    """
    terminals = _next_indices(np.asarray(data, dtype=np.float64))
    # On an acyclic graph every path has fewer than N steps. After
    # bit_length(N) doublings all pointers are therefore fixed, and one more
    # round detects that. The bound turns a potential infinite loop into an error.
    for _ in range(terminals.size.bit_length() + 1):
        jumped = terminals[terminals]
        if np.array_equal(jumped, terminals):
            return terminals
        terminals = jumped
    raise ValueError(
        "Drop paths did not converge to local maxima; the height field probably contains NaN values."
    )
|
|
|
|
|
|
@lru_cache(maxsize=32)
def _source_order(shape: tuple[int, int]) -> np.ndarray:
    """Return the flat indices of all interior pixels in column-major visiting order.

    Interior means everything except the one-pixel frame. Indices are listed
    column by column, top to bottom within each column. The non-empty result
    is cached per shape and marked read-only so callers cannot corrupt it.
    Grids smaller than 3x3 have no interior and yield an empty int64 array.
    """
    n_rows, n_cols = shape
    if min(n_rows, n_cols) < 3:
        return np.zeros(0, dtype=np.int64)
    inner_rows = np.arange(1, n_rows - 1, dtype=np.int64)
    inner_cols = np.arange(1, n_cols - 1, dtype=np.int64)
    # Outer axis = column, inner axis = row, so C-order flattening walks down each column.
    visit = (inner_rows[np.newaxis, :] * n_cols + inner_cols[:, np.newaxis]).ravel()
    visit.setflags(write=False)
    return visit
|
|
|
|
|
|
def _location_step(data: np.ndarray, water: np.ndarray, dropsize: float) -> None:
    """Run one seed-location pass, modifying ``data`` and ``water`` in place.

    Every interior pixel releases one drop, which climbs to the local maximum
    found by ``_terminal_indices``. The number of drops landing on each pixel is
    added to ``water``. The surface at that pixel is lowered by ``dropsize`` per
    drop.
    """
    landing = _terminal_indices(data)[_source_order(data.shape)]
    hits = np.bincount(landing, minlength=data.size).astype(np.float64).reshape(data.shape)
    water += hits
    data -= dropsize * hits
|
|
|
|
|
|
def _seed_labels(water: np.ndarray, threshold: int) -> np.ndarray:
    """Turn accumulated water into one seed pixel per sufficiently large wet region.

    Pixels with positive water are grouped into 4-connected regions via
    ``scipy.ndimage.label``. Every region with more than ``threshold`` pixels
    contributes a single seed at its wettest pixel (the first one in row-major
    order on ties). Seeds are numbered 1, 2, ... following the label order of
    the kept regions.

    Args:
        water: 2-D array of accumulated drop counts.
        threshold: Regions with ``size <= threshold`` pixels are discarded.

    Returns:
        int32 array shaped like ``water``: 0 everywhere except at seed pixels.
    """
    structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.int8)
    labeled, ngrains = label(water > 0.0, structure=structure)
    seeds = np.zeros_like(labeled, dtype=np.int32)
    if ngrains <= 0:
        return seeds

    flat_water = water.ravel()
    flat_labeled = labeled.ravel()
    sizes = np.bincount(flat_labeled, minlength=ngrains + 1)

    # Group pixel indices by region with one stable sort (O(N log N)) instead of
    # rescanning the whole image for every region (O(N * ngrains)). Stability
    # keeps indices ascending inside each group, so argmax still returns the
    # first maximum in row-major order, as before.
    by_grain = np.argsort(flat_labeled, kind="stable")
    ends = np.cumsum(sizes)
    starts = ends - sizes

    flat_seeds = seeds.reshape(-1)  # view: writes land in `seeds`
    next_label = 1
    for grain_id in range(1, ngrains + 1):
        if int(sizes[grain_id]) <= int(threshold):
            continue
        members = by_grain[starts[grain_id]:ends[grain_id]]
        flat_seeds[members[np.argmax(flat_water[members])]] = next_label
        next_label += 1
    return seeds
|
|
|
|
|
|
def _process_mask(labels: np.ndarray, row: int, col: int) -> None:
    """Update the label at (row, col) in place after a drop lands there.

    Conventions used here: positive values are grain ids, -1 marks a boundary,
    and 0 means not yet assigned.

    * Frame pixels (first/last row or column) are always set to -1.
    * A pixel that already holds a non-zero label is left untouched.
    * Otherwise the pixel adopts the id of its positive 4-neighbours. If they
      carry two or more different ids it becomes a boundary (-1). With no
      positive neighbour it stays 0.
    """
    height, width = labels.shape
    if row in (0, height - 1) or col in (0, width - 1):
        labels[row, col] = -1
        return
    if labels[row, col] != 0:
        return

    around = (
        int(labels[row, col - 1]),
        int(labels[row - 1, col]),
        int(labels[row, col + 1]),
        int(labels[row + 1, col]),
    )
    if not any(around):
        return

    grain_ids = {value for value in around if value > 0}
    if len(grain_ids) > 1:
        labels[row, col] = -1
    else:
        labels[row, col] = grain_ids.pop() if grain_ids else 0
|
|
|
|
|
|
def _watershed_step(
    data: np.ndarray,
    water: np.ndarray,
    labels: np.ndarray,
    seeds: np.ndarray,
    dropsize: float,
) -> None:
    """Run one watershed pass; ``data``, ``water`` and ``labels`` are modified in place.

    The pass has three stages:

    1. Seed pixels are (re)stamped into ``labels``.
    2. Every interior pixel releases a drop that climbs to a local maximum of
       the current surface. The landing pixels are handed to ``_process_mask``
       one at a time in source order, since each update can affect the next one.
    3. Drop counts are added to ``water``, and each landing pixel is lowered by
       ``dropsize`` per drop.
    """
    is_seed = seeds > 0
    labels[is_seed] = seeds[is_seed]

    landing = _terminal_indices(data)[_source_order(data.shape)]
    width = data.shape[1]
    for flat_index in landing:
        r, c = divmod(int(flat_index), width)
        _process_mask(labels, r, c)

    hits = np.bincount(landing, minlength=data.size).astype(np.float64).reshape(data.shape)
    water += hits
    data -= dropsize * hits
|
|
|
|
|
|
def _mark_boundaries(labels: np.ndarray) -> np.ndarray:
    """Return a copy of ``labels`` with grain edges cleared to 0.

    An interior pixel is zeroed when its label differs from its right or its
    lower neighbour. All comparisons read the unmodified input, so clearing one
    pixel never influences the test for another. The one-pixel frame is copied
    unchanged. Inputs smaller than 3x3 are returned as a plain copy.
    """
    marked = labels.copy()
    if labels.shape[0] >= 3 and labels.shape[1] >= 3:
        centre = labels[1:-1, 1:-1]
        edge = (centre != labels[1:-1, 2:]) | (centre != labels[2:, 1:-1])
        marked[1:-1, 1:-1][edge] = 0
    return marked
|
|
|
|
|
|
def _combine_masks(result_mask: np.ndarray, existing_mask: np.ndarray | None, combine_mode: str) -> np.ndarray:
|
|
if existing_mask is None or combine_mode == "replace":
|
|
return result_mask
|
|
|
|
existing = np.asarray(existing_mask) > 127
|
|
current = np.asarray(result_mask, dtype=bool)
|
|
if existing.shape != current.shape:
|
|
raise ValueError("Existing mask must have the same shape as the watershed output.")
|
|
|
|
if combine_mode == "union":
|
|
merged = current | existing
|
|
elif combine_mode == "intersection":
|
|
merged = current & existing
|
|
else:
|
|
raise ValueError(f"Unsupported combine mode: {combine_mode}")
|
|
|
|
return merged.astype(np.uint8) * 255
|
|
|
|
|
|
@register_node(display_name="Watershed Segmentation")
class WatershedSegmentation:
    """Node that segments a height field into grains with a drop-based watershed.

    Stage 1 (seed location): drops repeatedly lower a copy of the surface, and
    water accumulates where they land. Each sufficiently large wet region yields
    one seed.

    Stage 2 (watershed): a fresh copy is lowered the same way while labels grow
    outward from the seeds. Grain boundaries are then cleared, and the
    remaining positive labels form the output mask.
    """

    _CUSTOM_PREVIEW = True

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's required and optional input specification."""
        required = {
            "field": ("DATA_FIELD",),
            "invert_height": ("BOOLEAN", {"default": False}),
            "locate_steps": ("INT", {"default": 10, "min": 1, "max": 200, "step": 1}),
            "locate_threshold": ("INT", {"default": 10, "min": 0, "max": 100000, "step": 1}),
            "locate_drop_size": ("FLOAT", {"default": 0.1, "min": 0.0001, "max": 1.0, "step": 0.01}),
            "watershed_steps": ("INT", {"default": 20, "min": 1, "max": 2000, "step": 1}),
            "watershed_drop_size": ("FLOAT", {"default": 0.1, "min": 0.0001, "max": 1.0, "step": 0.01}),
            "combine_mode": (["replace", "union", "intersection"], {"default": "replace"}),
        }
        optional = {
            "mask": ("IMAGE",),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "process"

    DESCRIPTION = (
        "Segment a height field into grains using the two-stage Gwyddion watershed workflow: "
        "drop-based seed location followed by watershed growth. Supports hill or valley detection "
        "and optional union/intersection with an existing mask."
    )

    def process(
        self,
        field: DataField,
        invert_height: bool,
        locate_steps: int,
        locate_threshold: int,
        locate_drop_size: float,
        watershed_steps: int,
        watershed_drop_size: float,
        combine_mode: str,
        mask: np.ndarray | None = None,
    ) -> tuple:
        """Segment ``field``, emit a preview overlay and return ``(mask,)``.

        Args:
            field: Height field to segment.
            invert_height: Detect valleys instead of hills.
            locate_steps: Number of seed-location passes.
            locate_threshold: Wet regions need more than this many pixels to seed.
            locate_drop_size: Drop size for stage 1, as a multiple of 1/50 of
                the height range.
            watershed_steps: Number of watershed passes.
            watershed_drop_size: Drop size for stage 2, same unit as above.
            combine_mode: "replace", "union" or "intersection".
            mask: Optional existing mask. Presumably a 0-255 image (values
                > 127 count as set); confirm against the IMAGE type.

        Returns:
            One-element tuple holding a uint8 0/255 mask.
        """
        surface = _working_height(field, bool(invert_height))
        water = np.zeros_like(surface, dtype=np.float64)

        # Drop sizes are expressed relative to 1/50 of the surface's height range.
        scale = float((np.max(surface) - np.min(surface)) / 50.0)
        locate_drop, watershed_drop = (
            float(locate_drop_size) * scale,
            float(watershed_drop_size) * scale,
        )

        # Stage 1: find where water pools and turn the pools into seeds.
        lowered = surface.copy()
        for _ in range(int(locate_steps)):
            _location_step(lowered, water, locate_drop)
        seeds = _seed_labels(water, int(locate_threshold))

        # Stage 2: grow labels from the seeds on a fresh copy of the surface.
        labels = np.zeros_like(seeds, dtype=np.int32)
        lowered = surface.copy()
        for _ in range(int(watershed_steps)):
            _watershed_step(lowered, water, labels, seeds, watershed_drop)

        grains = _mark_boundaries(labels) > 0
        result_mask = _combine_masks(grains.astype(np.uint8) * 255, mask, combine_mode)

        emit_preview(encode_preview(_mask_overlay(field, result_mask)))
        return (result_mask,)
|