add draw mask node
This commit is contained in:
@@ -188,7 +188,7 @@ class ExecutionEngine:
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
|
||||
from backend.nodes.io import SaveImage, LoadFile
|
||||
|
||||
PreviewImage._broadcast_fn = on_preview
|
||||
@@ -196,6 +196,7 @@ class ExecutionEngine:
|
||||
MaskMorphology._broadcast_fn = on_preview
|
||||
MaskInvert._broadcast_fn = on_preview
|
||||
MaskCombine._broadcast_fn = on_preview
|
||||
DrawMask._broadcast_overlay_fn = on_overlay
|
||||
View3D._broadcast_mesh_fn = on_mesh
|
||||
PrintTable._broadcast_table_fn = on_table
|
||||
ValueDisplay._broadcast_value_fn = on_value
|
||||
@@ -213,10 +214,10 @@ class ExecutionEngine:
|
||||
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
|
||||
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
|
||||
from backend.nodes.modify import CropResizeField
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
|
||||
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
|
||||
from backend.nodes.io import LoadFile, SaveImage
|
||||
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
|
||||
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
|
||||
LoadFile, SaveImage):
|
||||
cls._current_node_id = node_id
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ Gwyddion equivalents:
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, datafield_to_uint8, encode_preview
|
||||
@@ -29,6 +30,157 @@ def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
|
||||
return np.clip(overlay, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _clamp_fraction(value) -> float:
|
||||
try:
|
||||
numeric = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
return max(0.0, min(1.0, numeric))
|
||||
|
||||
|
||||
def _parse_mask_strokes(mask_paths) -> list[dict]:
|
||||
if isinstance(mask_paths, list):
|
||||
raw_strokes = mask_paths
|
||||
elif isinstance(mask_paths, str) and mask_paths.strip():
|
||||
try:
|
||||
parsed = json.loads(mask_paths)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
raw_strokes = parsed if isinstance(parsed, list) else []
|
||||
else:
|
||||
return []
|
||||
|
||||
strokes = []
|
||||
for stroke in raw_strokes:
|
||||
if not isinstance(stroke, dict):
|
||||
continue
|
||||
raw_points = stroke.get("points")
|
||||
if not isinstance(raw_points, list):
|
||||
continue
|
||||
|
||||
points = []
|
||||
for point in raw_points:
|
||||
if not isinstance(point, dict):
|
||||
continue
|
||||
if "x" not in point or "y" not in point:
|
||||
continue
|
||||
points.append({
|
||||
"x": _clamp_fraction(point.get("x")),
|
||||
"y": _clamp_fraction(point.get("y")),
|
||||
})
|
||||
|
||||
if not points:
|
||||
continue
|
||||
|
||||
try:
|
||||
size = max(1, int(round(float(stroke.get("size", 1)))))
|
||||
except (TypeError, ValueError):
|
||||
size = 1
|
||||
|
||||
strokes.append({
|
||||
"size": size,
|
||||
"points": points,
|
||||
})
|
||||
|
||||
return strokes
|
||||
|
||||
|
||||
def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray:
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
width = max(1, int(width))
|
||||
height = max(1, int(height))
|
||||
default_pen_size = max(1, int(default_pen_size))
|
||||
|
||||
mask_image = Image.new("L", (width, height), 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
|
||||
for stroke in strokes:
|
||||
points = stroke.get("points") or []
|
||||
if not points:
|
||||
continue
|
||||
|
||||
size = stroke.get("size", default_pen_size)
|
||||
try:
|
||||
size = max(1, int(round(float(size))))
|
||||
except (TypeError, ValueError):
|
||||
size = default_pen_size
|
||||
|
||||
pixel_points = []
|
||||
for point in points:
|
||||
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
|
||||
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
|
||||
pixel_points.append((px, py))
|
||||
|
||||
radius = max(0.5, size / 2.0)
|
||||
|
||||
if len(pixel_points) == 1:
|
||||
x, y = pixel_points[0]
|
||||
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
|
||||
continue
|
||||
|
||||
draw.line(pixel_points, fill=255, width=size)
|
||||
for x, y in pixel_points:
|
||||
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
|
||||
|
||||
return np.asarray(mask_image, dtype=np.uint8)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DrawMask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Draw Mask")
|
||||
class DrawMask:
|
||||
_CUSTOM_PREVIEW = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
|
||||
"invert": ("BOOLEAN", {"default": False}),
|
||||
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
|
||||
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("mask",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "mask"
|
||||
DESCRIPTION = (
|
||||
"Paint a binary mask directly over an image preview. "
|
||||
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
|
||||
"and invert flips the final binary output."
|
||||
)
|
||||
|
||||
_broadcast_overlay_fn = None
|
||||
_current_node_id: str = ""
|
||||
|
||||
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
|
||||
strokes = _parse_mask_strokes(mask_paths)
|
||||
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
|
||||
if invert:
|
||||
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
|
||||
|
||||
if DrawMask._broadcast_overlay_fn is not None:
|
||||
DrawMask._broadcast_overlay_fn(
|
||||
DrawMask._current_node_id,
|
||||
{
|
||||
"kind": "mask_paint",
|
||||
"section_title": "Mask",
|
||||
"image": encode_preview(datafield_to_uint8(field, "gray")),
|
||||
"image_width": field.xres,
|
||||
"image_height": field.yres,
|
||||
"invert": bool(invert),
|
||||
},
|
||||
)
|
||||
|
||||
return (mask,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThresholdMask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user