add remaining high value features
This commit is contained in:
268
backend/nodes/watershed_segmentation.py
Normal file
268
backend/nodes/watershed_segmentation.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage import label
|
||||
|
||||
from backend.execution_context import emit_preview
|
||||
from backend.data_types import DataField, encode_preview
|
||||
from backend.node_registry import register_node
|
||||
from backend.nodes.helpers import _mask_overlay
|
||||
|
||||
|
||||
def _working_height(field: DataField, invert_height: bool) -> np.ndarray:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
return -data if invert_height else data.copy()
|
||||
|
||||
|
||||
def _next_indices(data: np.ndarray) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
flat_idx = np.arange(yres * xres, dtype=np.int64).reshape(yres, xres)
|
||||
|
||||
right_val = np.full_like(data, -np.inf, dtype=np.float64)
|
||||
right_val[:, :-1] = data[:, 1:]
|
||||
left_val = np.full_like(data, -np.inf, dtype=np.float64)
|
||||
left_val[:, 1:] = data[:, :-1]
|
||||
down_val = np.full_like(data, -np.inf, dtype=np.float64)
|
||||
down_val[:-1, :] = data[1:, :]
|
||||
up_val = np.full_like(data, -np.inf, dtype=np.float64)
|
||||
up_val[1:, :] = data[:-1, :]
|
||||
|
||||
right_idx = flat_idx.copy()
|
||||
right_idx[:, :-1] = flat_idx[:, 1:]
|
||||
left_idx = flat_idx.copy()
|
||||
left_idx[:, 1:] = flat_idx[:, :-1]
|
||||
down_idx = flat_idx.copy()
|
||||
down_idx[:-1, :] = flat_idx[1:, :]
|
||||
up_idx = flat_idx.copy()
|
||||
up_idx[1:, :] = flat_idx[:-1, :]
|
||||
|
||||
next_idx = flat_idx.copy()
|
||||
local = (
|
||||
(data >= right_val)
|
||||
& (data >= left_val)
|
||||
& (data >= down_val)
|
||||
& (data >= up_val)
|
||||
)
|
||||
|
||||
right_mask = (~local) & (right_val >= data) & (right_val >= left_val) & (right_val >= down_val) & (right_val >= up_val)
|
||||
next_idx[right_mask] = right_idx[right_mask]
|
||||
|
||||
unresolved = (~local) & (~right_mask)
|
||||
left_mask = unresolved & (left_val >= data) & (left_val >= right_val) & (left_val >= down_val) & (left_val >= up_val)
|
||||
next_idx[left_mask] = left_idx[left_mask]
|
||||
|
||||
unresolved &= ~left_mask
|
||||
down_mask = unresolved & (down_val >= data) & (down_val >= right_val) & (down_val >= left_val) & (down_val >= up_val)
|
||||
next_idx[down_mask] = down_idx[down_mask]
|
||||
|
||||
unresolved &= ~down_mask
|
||||
next_idx[unresolved] = up_idx[unresolved]
|
||||
return next_idx.ravel()
|
||||
|
||||
|
||||
def _terminal_indices(data: np.ndarray) -> np.ndarray:
|
||||
terminals = _next_indices(np.asarray(data, dtype=np.float64))
|
||||
while True:
|
||||
jumped = terminals[terminals]
|
||||
if np.array_equal(jumped, terminals):
|
||||
return terminals
|
||||
terminals = jumped
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _source_order(shape: tuple[int, int]) -> np.ndarray:
|
||||
yres, xres = shape
|
||||
if yres < 3 or xres < 3:
|
||||
return np.zeros(0, dtype=np.int64)
|
||||
rows, cols = np.mgrid[1:yres - 1, 1:xres - 1]
|
||||
order = (rows.ravel(order="F") * xres + cols.ravel(order="F")).astype(np.int64)
|
||||
order.setflags(write=False)
|
||||
return order
|
||||
|
||||
|
||||
def _location_step(data: np.ndarray, water: np.ndarray, dropsize: float) -> None:
|
||||
terminals = _terminal_indices(data)
|
||||
ordered_sources = _source_order(data.shape)
|
||||
counts = np.bincount(terminals[ordered_sources], minlength=data.size).astype(np.float64)
|
||||
water += counts.reshape(data.shape)
|
||||
data -= dropsize * counts.reshape(data.shape)
|
||||
|
||||
|
||||
def _seed_labels(water: np.ndarray, threshold: int) -> np.ndarray:
|
||||
structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.int8)
|
||||
labeled, ngrains = label(water > 0.0, structure=structure)
|
||||
if ngrains <= 0:
|
||||
return np.zeros_like(labeled, dtype=np.int32)
|
||||
|
||||
sizes = np.bincount(labeled.ravel(), minlength=ngrains + 1)
|
||||
seeds = np.zeros_like(labeled, dtype=np.int32)
|
||||
next_label = 1
|
||||
flat_water = water.ravel()
|
||||
flat_labeled = labeled.ravel()
|
||||
|
||||
for grain_id in range(1, ngrains + 1):
|
||||
if int(sizes[grain_id]) <= int(threshold):
|
||||
continue
|
||||
indices = np.flatnonzero(flat_labeled == grain_id)
|
||||
if indices.size == 0:
|
||||
continue
|
||||
peak_index = int(indices[np.argmax(flat_water[indices])])
|
||||
seeds.ravel()[peak_index] = next_label
|
||||
next_label += 1
|
||||
|
||||
return seeds
|
||||
|
||||
|
||||
def _process_mask(labels: np.ndarray, row: int, col: int) -> None:
|
||||
yres, xres = labels.shape
|
||||
if col == 0 or row == 0 or col == xres - 1 or row == yres - 1:
|
||||
labels[row, col] = -1
|
||||
return
|
||||
|
||||
if labels[row, col] != 0:
|
||||
return
|
||||
|
||||
left = int(labels[row, col - 1])
|
||||
up = int(labels[row - 1, col])
|
||||
right = int(labels[row, col + 1])
|
||||
down = int(labels[row + 1, col])
|
||||
|
||||
if abs(left) + abs(up) + abs(right) + abs(down) == 0:
|
||||
return
|
||||
|
||||
value = 0
|
||||
boundary = False
|
||||
for candidate in (left, up, right, down):
|
||||
if value > 0 and candidate > 0 and candidate != value:
|
||||
boundary = True
|
||||
break
|
||||
if candidate > 0:
|
||||
value = candidate
|
||||
|
||||
labels[row, col] = -1 if boundary else value
|
||||
|
||||
|
||||
def _watershed_step(
|
||||
data: np.ndarray,
|
||||
water: np.ndarray,
|
||||
labels: np.ndarray,
|
||||
seeds: np.ndarray,
|
||||
dropsize: float,
|
||||
) -> None:
|
||||
labels[seeds > 0] = seeds[seeds > 0]
|
||||
|
||||
terminals = _terminal_indices(data)
|
||||
ordered_sources = _source_order(data.shape)
|
||||
ordered_terminals = terminals[ordered_sources]
|
||||
xres = data.shape[1]
|
||||
|
||||
for term in ordered_terminals:
|
||||
row = int(term // xres)
|
||||
col = int(term % xres)
|
||||
_process_mask(labels, row, col)
|
||||
|
||||
counts = np.bincount(ordered_terminals, minlength=data.size).astype(np.float64)
|
||||
water += counts.reshape(data.shape)
|
||||
data -= dropsize * counts.reshape(data.shape)
|
||||
|
||||
|
||||
def _mark_boundaries(labels: np.ndarray) -> np.ndarray:
|
||||
result = labels.copy()
|
||||
if result.shape[0] < 3 or result.shape[1] < 3:
|
||||
return result
|
||||
|
||||
interior = result[1:-1, 1:-1]
|
||||
right = result[1:-1, 2:]
|
||||
down = result[2:, 1:-1]
|
||||
interior[(interior != right) | (interior != down)] = 0
|
||||
return result
|
||||
|
||||
|
||||
def _combine_masks(result_mask: np.ndarray, existing_mask: np.ndarray | None, combine_mode: str) -> np.ndarray:
|
||||
if existing_mask is None or combine_mode == "replace":
|
||||
return result_mask
|
||||
|
||||
existing = np.asarray(existing_mask) > 127
|
||||
current = np.asarray(result_mask, dtype=bool)
|
||||
if existing.shape != current.shape:
|
||||
raise ValueError("Existing mask must have the same shape as the watershed output.")
|
||||
|
||||
if combine_mode == "union":
|
||||
merged = current | existing
|
||||
elif combine_mode == "intersection":
|
||||
merged = current & existing
|
||||
else:
|
||||
raise ValueError(f"Unsupported combine mode: {combine_mode}")
|
||||
|
||||
return merged.astype(np.uint8) * 255
|
||||
|
||||
|
||||
@register_node(display_name="Watershed Segmentation")
|
||||
class WatershedSegmentation:
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"invert_height": ("BOOLEAN", {"default": False}),
|
||||
"locate_steps": ("INT", {"default": 10, "min": 1, "max": 200, "step": 1}),
|
||||
"locate_threshold": ("INT", {"default": 10, "min": 0, "max": 100000, "step": 1}),
|
||||
"locate_drop_size": ("FLOAT", {"default": 0.1, "min": 0.0001, "max": 1.0, "step": 0.01}),
|
||||
"watershed_steps": ("INT", {"default": 20, "min": 1, "max": 2000, "step": 1}),
|
||||
"watershed_drop_size": ("FLOAT", {"default": 0.1, "min": 0.0001, "max": 1.0, "step": 0.01}),
|
||||
"combine_mode": (["replace", "union", "intersection"], {"default": "replace"}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Segment a height field into grains using the two-stage Gwyddion watershed workflow: "
|
||||
"drop-based seed location followed by watershed growth. Supports hill or valley detection "
|
||||
"and optional union/intersection with an existing mask."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
invert_height: bool,
|
||||
locate_steps: int,
|
||||
locate_threshold: int,
|
||||
locate_drop_size: float,
|
||||
watershed_steps: int,
|
||||
watershed_drop_size: float,
|
||||
combine_mode: str,
|
||||
mask: np.ndarray | None = None,
|
||||
) -> tuple:
|
||||
working = _working_height(field, bool(invert_height))
|
||||
water = np.zeros_like(working, dtype=np.float64)
|
||||
|
||||
q = float((np.max(working) - np.min(working)) / 50.0)
|
||||
locate_drop = float(locate_drop_size) * q
|
||||
watershed_drop = float(watershed_drop_size) * q
|
||||
|
||||
locate_field = working.copy()
|
||||
for _ in range(int(locate_steps)):
|
||||
_location_step(locate_field, water, locate_drop)
|
||||
|
||||
seeds = _seed_labels(water, int(locate_threshold))
|
||||
labels = np.zeros_like(seeds, dtype=np.int32)
|
||||
watershed_field = working.copy()
|
||||
for _ in range(int(watershed_steps)):
|
||||
_watershed_step(watershed_field, water, labels, seeds, watershed_drop)
|
||||
|
||||
labels = _mark_boundaries(labels)
|
||||
result_mask = (labels > 0).astype(np.uint8) * 255
|
||||
result_mask = _combine_masks(result_mask, mask, combine_mode)
|
||||
|
||||
emit_preview(encode_preview(_mask_overlay(field, result_mask)))
|
||||
return (result_mask,)
|
||||
Reference in New Issue
Block a user