""" Mask operation nodes — creation, morphology, and boolean combination. Gwyddion equivalents: ThresholdMask → threshold.c / otsu_threshold.c MaskMorphology → mask_morph.c (erode, dilate, open, close) MaskInvert → (bitwise NOT on mask) MaskCombine → (boolean ops between two masks) """ from __future__ import annotations import json import numpy as np from backend.node_registry import register_node from backend.data_types import DataField, datafield_to_uint8, encode_preview def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray: """Render greyscale base image with red shadow on masked (255) pixels. Returns (H, W, 3) uint8 array. """ grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8 overlay = grey.astype(np.float64) mask_bool = mask == 255 alpha = 0.45 overlay[mask_bool, 0] = overlay[mask_bool, 0] * (1 - alpha) + 255 * alpha overlay[mask_bool, 1] = overlay[mask_bool, 1] * (1 - alpha) overlay[mask_bool, 2] = overlay[mask_bool, 2] * (1 - alpha) return np.clip(overlay, 0, 255).astype(np.uint8) def _clamp_fraction(value) -> float: try: numeric = float(value) except (TypeError, ValueError): return 0.0 return max(0.0, min(1.0, numeric)) def _parse_mask_strokes(mask_paths) -> list[dict]: if isinstance(mask_paths, list): raw_strokes = mask_paths elif isinstance(mask_paths, str) and mask_paths.strip(): try: parsed = json.loads(mask_paths) except json.JSONDecodeError: return [] raw_strokes = parsed if isinstance(parsed, list) else [] else: return [] strokes = [] for stroke in raw_strokes: if not isinstance(stroke, dict): continue raw_points = stroke.get("points") if not isinstance(raw_points, list): continue points = [] for point in raw_points: if not isinstance(point, dict): continue if "x" not in point or "y" not in point: continue points.append({ "x": _clamp_fraction(point.get("x")), "y": _clamp_fraction(point.get("y")), }) if not points: continue try: size = max(1, int(round(float(stroke.get("size", 1))))) except (TypeError, ValueError): size = 1 strokes.append({ "size": size, "points": points, }) return strokes def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray: from PIL import Image, ImageDraw width = max(1, int(width)) height = max(1, int(height)) default_pen_size = max(1, int(default_pen_size)) mask_image = Image.new("L", (width, height), 0) draw = ImageDraw.Draw(mask_image) for stroke in strokes: points = stroke.get("points") or [] if not points: continue size = stroke.get("size", default_pen_size) try: size = max(1, int(round(float(size)))) except (TypeError, ValueError): size = default_pen_size pixel_points = [] for point in points: px = int(round(_clamp_fraction(point.get("x")) * (width - 1))) py = int(round(_clamp_fraction(point.get("y")) * (height - 1))) pixel_points.append((px, py)) radius = max(0.5, size / 2.0) if len(pixel_points) == 1: x, y = pixel_points[0] draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) continue draw.line(pixel_points, fill=255, width=size) for x, y in pixel_points: draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) return np.asarray(mask_image, dtype=np.uint8) # --------------------------------------------------------------------------- # DrawMask # --------------------------------------------------------------------------- @register_node(display_name="Draw Mask") class DrawMask: _CUSTOM_PREVIEW = True @classmethod def INPUT_TYPES(cls): return { "required": { "field": ("DATA_FIELD",), "pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}), "invert": ("BOOLEAN", {"default": False}), "clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}), "mask_paths": ("STRING", {"default": "[]", "hidden": True}), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("mask",) FUNCTION = "process" CATEGORY = "mask" DESCRIPTION = ( "Paint a binary mask directly over an image preview. " "Pen size controls newly drawn strokes, the overlay lets you clear the mask, " "and invert flips the final binary output." ) _broadcast_overlay_fn = None _current_node_id: str = "" def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple: strokes = _parse_mask_strokes(mask_paths) mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size) if invert: mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) if DrawMask._broadcast_overlay_fn is not None: DrawMask._broadcast_overlay_fn( DrawMask._current_node_id, { "kind": "mask_paint", "section_title": "Mask", "image": encode_preview(datafield_to_uint8(field, "gray")), "image_width": field.xres, "image_height": field.yres, "invert": bool(invert), }, ) return (mask,) # --------------------------------------------------------------------------- # ThresholdMask # --------------------------------------------------------------------------- @register_node(display_name="Threshold Mask") class ThresholdMask: _CUSTOM_PREVIEW = True @classmethod def INPUT_TYPES(cls): return { "required": { "field": ("DATA_FIELD",), "method": (["otsu", "absolute", "relative"],), "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}), "direction": (["above", "below"],), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("mask",) FUNCTION = "process" CATEGORY = "mask" DESCRIPTION = ( "Create a binary mask by thresholding data. " "Otsu automatically finds the optimal threshold. " "Equivalent to Gwyddion's threshold and otsu_threshold modules." ) _broadcast_fn = None _current_node_id: str = "" def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple: data = field.data if method == "otsu": from skimage.filters import threshold_otsu t = threshold_otsu(data) elif method == "absolute": t = float(threshold) elif method == "relative": # threshold is a fraction [0, 1] of the data range dmin, dmax = data.min(), data.max() t = dmin + float(threshold) * (dmax - dmin) else: raise ValueError(f"Unknown threshold method: {method}") if direction == "above": mask = (data >= t).astype(np.uint8) * 255 else: mask = (data < t).astype(np.uint8) * 255 if ThresholdMask._broadcast_fn is not None: overlay = _mask_overlay(field, mask) ThresholdMask._broadcast_fn( ThresholdMask._current_node_id, encode_preview(overlay), ) return (mask,) # --------------------------------------------------------------------------- # MaskMorphology # --------------------------------------------------------------------------- @register_node(display_name="Mask Morphology") class MaskMorphology: """Morphological operations on binary masks. Equivalent to Gwyddion's mask_morph.c (erode, dilate, open, close). """ _CUSTOM_PREVIEW = True @classmethod def INPUT_TYPES(cls): return { "required": { "mask": ("IMAGE",), "operation": (["dilate", "erode", "open", "close"],), "radius": ("INT", {"default": 1, "min": 1, "max": 50, "step": 1}), "shape": (["disk", "square"],), }, "optional": { "field": ("DATA_FIELD",), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("mask",) FUNCTION = "process" CATEGORY = "mask" DESCRIPTION = ( "Apply morphological operations to a binary mask. " "Dilate expands regions, erode shrinks them, " "open (erode then dilate) removes small spots, " "close (dilate then erode) fills small holes. " "Equivalent to Gwyddion mask_morph." ) _broadcast_fn = None _current_node_id: str = "" def process(self, mask: np.ndarray, operation: str, radius: int, shape: str, field: DataField | None = None) -> tuple: from scipy.ndimage import binary_dilation, binary_erosion binary = mask > 127 if shape == "disk": y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1] struct = (x * x + y * y) <= radius * radius else: size = 2 * radius + 1 struct = np.ones((size, size), dtype=bool) if operation == "dilate": result = binary_dilation(binary, structure=struct) elif operation == "erode": result = binary_erosion(binary, structure=struct) elif operation == "open": result = binary_dilation( binary_erosion(binary, structure=struct), structure=struct, ) elif operation == "close": result = binary_erosion( binary_dilation(binary, structure=struct), structure=struct, ) else: raise ValueError(f"Unknown morphological operation: {operation}") out = result.astype(np.uint8) * 255 if field is not None and MaskMorphology._broadcast_fn is not None: overlay = _mask_overlay(field, out) MaskMorphology._broadcast_fn( MaskMorphology._current_node_id, encode_preview(overlay), ) return (out,) # --------------------------------------------------------------------------- # MaskInvert # --------------------------------------------------------------------------- @register_node(display_name="Mask Invert") class MaskInvert: _CUSTOM_PREVIEW = True @classmethod def INPUT_TYPES(cls): return { "required": { "mask": ("IMAGE",), }, "optional": { "field": ("DATA_FIELD",), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("mask",) FUNCTION = "process" CATEGORY = "mask" DESCRIPTION = "Invert a binary mask — swap masked and unmasked regions." _broadcast_fn = None _current_node_id: str = "" def process(self, mask: np.ndarray, field: DataField | None = None) -> tuple: out = np.where(mask > 127, np.uint8(0), np.uint8(255)) if field is not None and MaskInvert._broadcast_fn is not None: overlay = _mask_overlay(field, out) MaskInvert._broadcast_fn( MaskInvert._current_node_id, encode_preview(overlay), ) return (out,) # --------------------------------------------------------------------------- # MaskCombine # --------------------------------------------------------------------------- @register_node(display_name="Mask Combine") class MaskCombine: _CUSTOM_PREVIEW = True @classmethod def INPUT_TYPES(cls): return { "required": { "mask_a": ("IMAGE",), "mask_b": ("IMAGE",), "operation": (["and", "or", "xor", "subtract"],), }, "optional": { "field": ("DATA_FIELD",), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("mask",) FUNCTION = "process" CATEGORY = "mask" DESCRIPTION = ( "Combine two binary masks with a boolean operation. " "AND keeps overlap, OR merges, XOR keeps non-overlapping regions, " "subtract removes mask_b from mask_a." ) _broadcast_fn = None _current_node_id: str = "" def process(self, mask_a: np.ndarray, mask_b: np.ndarray, operation: str, field: DataField | None = None) -> tuple: a = mask_a > 127 b = mask_b > 127 if operation == "and": result = a & b elif operation == "or": result = a | b elif operation == "xor": result = a ^ b elif operation == "subtract": result = a & ~b else: raise ValueError(f"Unknown mask operation: {operation}") out = result.astype(np.uint8) * 255 if field is not None and MaskCombine._broadcast_fn is not None: overlay = _mask_overlay(field, out) MaskCombine._broadcast_fn( MaskCombine._current_node_id, encode_preview(overlay), ) return (out,)