add draw mask node

This commit is contained in:
matei jordache
2026-03-25 15:44:09 -07:00
parent bce11590c7
commit ca59bac478
7 changed files with 576 additions and 7 deletions

View File

@@ -9,6 +9,7 @@ Gwyddion equivalents:
"""
from __future__ import annotations
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
@@ -29,6 +30,157 @@ def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
return np.clip(overlay, 0, 255).astype(np.uint8)
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray:
from PIL import Image, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# DrawMask
# ---------------------------------------------------------------------------
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------