62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
from __future__ import annotations
|
|
import numpy as np
|
|
from backend.node_registry import register_node
|
|
from backend.data_types import DataField, encode_preview
|
|
from backend.nodes.helpers import _mask_overlay
|
|
|
|
|
|
@register_node(display_name="Threshold Mask")
class ThresholdMask:
    """Node that converts a data field into a binary mask by thresholding.

    The threshold is chosen by one of three methods ("otsu", "absolute",
    "relative") and the mask marks the samples lying on the requested side
    of that threshold.
    """

    # Read by code outside this class; presumably signals that the node
    # pushes its own preview image (see the broadcast in ``process``)
    # -- TODO confirm against the node framework.
    _CUSTOM_PREVIEW = True

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's required inputs.

        ``method`` and ``direction`` are offered as fixed choice lists;
        ``threshold`` is a float whose meaning depends on ``method``
        (see ``process``).
        """
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "method": (["otsu", "absolute", "relative"],),
                "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}),
                "direction": (["above", "below"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "process"

    DESCRIPTION = (
        "Create a binary mask by thresholding data. "
        "Otsu automatically finds the optimal threshold. "
        "Equivalent to Gwyddion's threshold and otsu_threshold modules."
    )

    # Class-level hooks, not assigned anywhere in this class; presumably the
    # execution framework sets them before calling ``process`` -- TODO confirm.
    # When ``_broadcast_fn`` is None no preview is sent.
    _broadcast_fn = None
    _current_node_id: str = ""

    def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
        """Threshold ``field.data`` and return the resulting mask.

        Args:
            field: Input data field; only ``field.data`` is read here (the
                field itself is also handed to the preview overlay helper).
            method: How the cut-off value is chosen:
                ``"otsu"`` -- computed by ``skimage.filters.threshold_otsu``
                (``threshold`` is ignored);
                ``"absolute"`` -- ``threshold`` is used as-is;
                ``"relative"`` -- ``threshold`` is a fraction of the data
                range, i.e. ``min + threshold * (max - min)``.
            threshold: Numeric parameter interpreted according to ``method``.
            direction: ``"above"`` selects samples ``>=`` the cut-off,
                ``"below"`` selects samples strictly ``<`` the cut-off.

        Returns:
            A one-element tuple holding the mask: a ``uint8`` array with 255
            at selected samples and 0 elsewhere.

        Raises:
            ValueError: If ``method`` or ``direction`` is not one of the
                supported values.
        """
        values = field.data

        if method == "otsu":
            # Imported lazily so scikit-image is only needed when Otsu is used.
            from skimage.filters import threshold_otsu

            cutoff = threshold_otsu(values)
        elif method == "absolute":
            cutoff = float(threshold)
        elif method == "relative":
            lowest, highest = values.min(), values.max()
            cutoff = lowest + float(threshold) * (highest - lowest)
        else:
            raise ValueError(f"Unknown threshold method: {method}")

        if direction == "above":
            selected = values >= cutoff
        elif direction == "below":
            selected = values < cutoff
        else:
            # Previously any unrecognised value silently behaved like "below";
            # reject it explicitly, mirroring the handling of ``method``.
            raise ValueError(f"Unknown threshold direction: {direction}")

        # Boolean selection -> 0/255 image in uint8.
        mask = selected.astype(np.uint8) * 255

        # Optional live preview: only sent when a broadcast hook is installed.
        if ThresholdMask._broadcast_fn is not None:
            overlay = _mask_overlay(field, mask)
            ThresholdMask._broadcast_fn(
                ThresholdMask._current_node_id, encode_preview(overlay),
            )

        return (mask,)
|