add draw mask node

This commit is contained in:
matei jordache
2026-03-25 15:44:09 -07:00
parent bce11590c7
commit ca59bac478
7 changed files with 576 additions and 7 deletions

View File

@@ -188,7 +188,7 @@ class ExecutionEngine:
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import SaveImage, LoadFile
PreviewImage._broadcast_fn = on_preview
@@ -196,6 +196,7 @@ class ExecutionEngine:
MaskMorphology._broadcast_fn = on_preview
MaskInvert._broadcast_fn = on_preview
MaskCombine._broadcast_fn = on_preview
DrawMask._broadcast_overlay_fn = on_overlay
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
ValueDisplay._broadcast_value_fn = on_value
@@ -213,10 +214,10 @@ class ExecutionEngine:
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
LoadFile, SaveImage):
cls._current_node_id = node_id

View File

@@ -9,6 +9,7 @@ Gwyddion equivalents:
"""
from __future__ import annotations
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
@@ -29,6 +30,157 @@ def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
return np.clip(overlay, 0, 255).astype(np.uint8)
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray:
from PIL import Image, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# DrawMask
# ---------------------------------------------------------------------------
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------