add draw mask node

This commit is contained in:
matei jordache
2026-03-25 15:44:09 -07:00
parent bce11590c7
commit ca59bac478
7 changed files with 576 additions and 7 deletions

View File

@@ -188,7 +188,7 @@ class ExecutionEngine:
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import SaveImage, LoadFile
PreviewImage._broadcast_fn = on_preview
@@ -196,6 +196,7 @@ class ExecutionEngine:
MaskMorphology._broadcast_fn = on_preview
MaskInvert._broadcast_fn = on_preview
MaskCombine._broadcast_fn = on_preview
DrawMask._broadcast_overlay_fn = on_overlay
View3D._broadcast_mesh_fn = on_mesh
PrintTable._broadcast_table_fn = on_table
ValueDisplay._broadcast_value_fn = on_value
@@ -213,10 +214,10 @@ class ExecutionEngine:
from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay
from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram
from backend.nodes.modify import CropResizeField
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine
from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask
from backend.nodes.io import LoadFile, SaveImage
if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine,
ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask,
LoadFile, SaveImage):
cls._current_node_id = node_id

View File

@@ -9,6 +9,7 @@ Gwyddion equivalents:
"""
from __future__ import annotations
import json
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview
@@ -29,6 +30,157 @@ def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray:
return np.clip(overlay, 0, 255).astype(np.uint8)
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray:
from PIL import Image, ImageDraw
width = max(1, int(width))
height = max(1, int(height))
default_pen_size = max(1, int(default_pen_size))
mask_image = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
for stroke in strokes:
points = stroke.get("points") or []
if not points:
continue
size = stroke.get("size", default_pen_size)
try:
size = max(1, int(round(float(size))))
except (TypeError, ValueError):
size = default_pen_size
pixel_points = []
for point in points:
px = int(round(_clamp_fraction(point.get("x")) * (width - 1)))
py = int(round(_clamp_fraction(point.get("y")) * (height - 1)))
pixel_points.append((px, py))
radius = max(0.5, size / 2.0)
if len(pixel_points) == 1:
x, y = pixel_points[0]
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
continue
draw.line(pixel_points, fill=255, width=size)
for x, y in pixel_points:
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# DrawMask
# ---------------------------------------------------------------------------
@register_node(display_name="Draw Mask")
class DrawMask:
_CUSTOM_PREVIEW = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}),
"invert": ("BOOLEAN", {"default": False}),
"clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}),
"mask_paths": ("STRING", {"default": "[]", "hidden": True}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("mask",)
FUNCTION = "process"
CATEGORY = "mask"
DESCRIPTION = (
"Paint a binary mask directly over an image preview. "
"Pen size controls newly drawn strokes, the overlay lets you clear the mask, "
"and invert flips the final binary output."
)
_broadcast_overlay_fn = None
_current_node_id: str = ""
def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple:
strokes = _parse_mask_strokes(mask_paths)
mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size)
if invert:
mask = np.where(mask > 127, np.uint8(0), np.uint8(255))
if DrawMask._broadcast_overlay_fn is not None:
DrawMask._broadcast_overlay_fn(
DrawMask._current_node_id,
{
"kind": "mask_paint",
"section_title": "Mask",
"image": encode_preview(datafield_to_uint8(field, "gray")),
"image_width": field.xres,
"image_height": field.yres,
"invert": bool(invert),
},
)
return (mask,)
# ---------------------------------------------------------------------------
# ThresholdMask
# ---------------------------------------------------------------------------

View File

@@ -530,7 +530,12 @@ function Flow() {
updateNodeData(msg.data.node_id, { meshData: msg.data.mesh });
break;
case 'overlay':
updateNodeData(msg.data.node_id, { overlay: msg.data.overlay });
updateNodeData(
msg.data.node_id,
msg.data.overlay?.kind === 'mask_paint'
? { overlay: msg.data.overlay, previewImage: null }
: { overlay: msg.data.overlay },
);
break;
case 'node_warning':
updateNodeData(msg.data.node_id, { warning: msg.data.message });

View File

@@ -5,6 +5,7 @@ import LinePlotOverlay from './LinePlotOverlay';
const SurfaceView = lazy(() => import('./SurfaceView'));
const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay'));
const CropBoxOverlay = lazy(() => import('./CropBoxOverlay'));
const MaskPaintOverlay = lazy(() => import('./MaskPaintOverlay'));
// ── Constants ─────────────────────────────────────────────────────────
@@ -525,8 +526,12 @@ function CustomNode({ id, data }) {
const catColor = CAT_COLORS[def.category] || '#333';
const maxIORows = Math.max(dataInputs.length, outputs.length);
const hasInteractiveLineOverlay = data.overlay?.kind === 'line_plot' && hiddenWidgets.has('x1');
const hasInteractiveOverlay = !!data.overlay && (hiddenWidgets.has('x1') || data.overlay.kind === 'mask_paint');
const hidePreviewForInteractiveMask = data.overlay?.kind === 'mask_paint';
const overlayTitle = data.overlay?.section_title
|| (data.overlay?.kind === 'crop_box'
|| (data.overlay?.kind === 'mask_paint'
? 'Mask'
: data.overlay?.kind === 'crop_box'
? 'Crop'
: data.overlay?.kind === 'line_plot'
? 'Line Plot'
@@ -641,7 +646,9 @@ function CustomNode({ id, data }) {
)}
{/* Collapsible preview image */}
{data.previewImage && !(hasInteractiveLineOverlay && typeof data.previewImage === 'object' && data.previewImage.kind === 'line_plot') && (
{data.previewImage
&& !hidePreviewForInteractiveMask
&& !(hasInteractiveLineOverlay && typeof data.previewImage === 'object' && data.previewImage.kind === 'line_plot') && (
<CollapsibleSection title="Preview" defaultOpen={true}>
<PreviewBoundary
resetKey={typeof data.previewImage === 'string' ? data.previewImage : JSON.stringify({
@@ -662,7 +669,7 @@ function CustomNode({ id, data }) {
)}
{/* Interactive cross-section overlay */}
{data.overlay && hiddenWidgets.has('x1') && (
{hasInteractiveOverlay && (
<CollapsibleSection title={overlayTitle} defaultOpen={true}>
<Suspense fallback={<div className="node-preview" style={{color:'#64748b',padding:4}}>Loading...</div>}>
{data.overlay.kind === 'line_plot' ? (
@@ -687,6 +694,16 @@ function CustomNode({ id, data }) {
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
) : data.overlay.kind === 'mask_paint' ? (
<MaskPaintOverlay
image={data.overlay.image}
imageWidth={data.overlay.image_width}
imageHeight={data.overlay.image_height}
penSize={data.widgetValues.pen_size}
maskPaths={data.widgetValues.mask_paths}
nodeId={id}
onWidgetChange={ctx.onWidgetChange}
/>
) : (
<CrossSectionOverlay
image={data.overlay.image}

View File

@@ -0,0 +1,304 @@
import React, { useEffect, useRef, useState, useCallback } from 'react';
function clampFraction(value) {
const numeric = Number(value);
if (!Number.isFinite(numeric)) return 0;
return Math.max(0, Math.min(1, numeric));
}
function sanitizeStroke(stroke, fallbackPenSize) {
if (!stroke || typeof stroke !== 'object' || !Array.isArray(stroke.points) || stroke.points.length === 0) {
return null;
}
const size = Math.max(1, Math.round(Number(stroke.size) || fallbackPenSize || 1));
const points = stroke.points
.map((point) => {
if (!point || typeof point !== 'object') return null;
return {
x: Number(clampFraction(point.x).toFixed(4)),
y: Number(clampFraction(point.y).toFixed(4)),
};
})
.filter(Boolean);
if (points.length === 0) return null;
return { size, points };
}
function parseMaskPaths(maskPaths, fallbackPenSize) {
if (Array.isArray(maskPaths)) {
return maskPaths.map((stroke) => sanitizeStroke(stroke, fallbackPenSize)).filter(Boolean);
}
if (typeof maskPaths !== 'string' || !maskPaths.trim()) return [];
try {
const parsed = JSON.parse(maskPaths);
if (!Array.isArray(parsed)) return [];
return parsed.map((stroke) => sanitizeStroke(stroke, fallbackPenSize)).filter(Boolean);
} catch {
return [];
}
}
function drawStroke(ctx, stroke, width, height, imageWidth, imageHeight, styles = {}) {
if (!stroke || !Array.isArray(stroke.points) || stroke.points.length === 0) return;
const scaleX = imageWidth > 0 ? width / imageWidth : 1;
const scaleY = imageHeight > 0 ? height / imageHeight : 1;
const brushScale = Math.max(0.5, Math.min(scaleX, scaleY));
const lineWidth = Math.max(1, stroke.size * brushScale);
ctx.save();
ctx.lineCap = 'round';
ctx.lineJoin = 'round';
ctx.strokeStyle = styles.strokeStyle || '#ffffff';
ctx.fillStyle = styles.fillStyle || '#ffffff';
ctx.lineWidth = lineWidth;
const points = stroke.points.map((point) => ({
x: clampFraction(point.x) * width,
y: clampFraction(point.y) * height,
}));
if (points.length === 1) {
const radius = lineWidth / 2;
ctx.beginPath();
ctx.arc(points[0].x, points[0].y, radius, 0, Math.PI * 2);
ctx.fill();
ctx.restore();
return;
}
ctx.beginPath();
ctx.moveTo(points[0].x, points[0].y);
for (let i = 1; i < points.length; i += 1) {
ctx.lineTo(points[i].x, points[i].y);
}
ctx.stroke();
for (const point of points) {
ctx.beginPath();
ctx.arc(point.x, point.y, lineWidth / 2, 0, Math.PI * 2);
ctx.fill();
}
ctx.restore();
}
export default function MaskPaintOverlay({
image,
imageWidth,
imageHeight,
penSize,
maskPaths,
nodeId,
onWidgetChange,
}) {
const containerRef = useRef(null);
const canvasRef = useRef(null);
const strokesRef = useRef([]);
const draftStrokeRef = useRef(null);
const [strokes, setStrokes] = useState(() => parseMaskPaths(maskPaths, penSize));
const [draftStroke, setDraftStroke] = useState(null);
const [drawing, setDrawing] = useState(false);
const [cursorPoint, setCursorPoint] = useState(null);
useEffect(() => {
const parsed = parseMaskPaths(maskPaths, penSize);
strokesRef.current = parsed;
setStrokes(parsed);
setDraftStroke(null);
setDrawing(false);
}, [maskPaths, penSize]);
useEffect(() => {
strokesRef.current = strokes;
}, [strokes]);
useEffect(() => {
draftStrokeRef.current = draftStroke;
}, [draftStroke]);
const redrawCanvas = useCallback((committedStrokes, activeStroke) => {
const canvas = canvasRef.current;
const container = containerRef.current;
if (!canvas || !container) return;
const rect = container.getBoundingClientRect();
const cssWidth = Math.max(1, Math.round(rect.width));
const cssHeight = Math.max(1, Math.round(rect.height));
const dpr = Math.max(1, window.devicePixelRatio || 1);
if (canvas.width !== Math.round(cssWidth * dpr) || canvas.height !== Math.round(cssHeight * dpr)) {
canvas.width = Math.round(cssWidth * dpr);
canvas.height = Math.round(cssHeight * dpr);
canvas.style.width = `${cssWidth}px`;
canvas.style.height = `${cssHeight}px`;
}
const ctx = canvas.getContext('2d');
if (!ctx) return;
ctx.setTransform(1, 0, 0, 1, 0, 0);
ctx.clearRect(0, 0, canvas.width, canvas.height);
const maskCanvas = document.createElement('canvas');
maskCanvas.width = canvas.width;
maskCanvas.height = canvas.height;
const maskCtx = maskCanvas.getContext('2d');
if (!maskCtx) return;
maskCtx.setTransform(1, 0, 0, 1, 0, 0);
maskCtx.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
maskCtx.scale(dpr, dpr);
const drawMaskStroke = (stroke) => drawStroke(
maskCtx,
stroke,
cssWidth,
cssHeight,
imageWidth,
imageHeight,
{ strokeStyle: '#ffffff', fillStyle: '#ffffff' },
);
for (const stroke of committedStrokes) {
drawMaskStroke(stroke);
}
if (activeStroke) {
drawMaskStroke(activeStroke);
}
ctx.drawImage(maskCanvas, 0, 0);
ctx.globalCompositeOperation = 'source-in';
ctx.fillStyle = 'rgba(255, 59, 59, 0.16)';
ctx.fillRect(0, 0, canvas.width, canvas.height);
ctx.globalCompositeOperation = 'source-over';
}, [imageHeight, imageWidth]);
useEffect(() => {
redrawCanvas(strokes, draftStroke);
}, [draftStroke, redrawCanvas, strokes]);
useEffect(() => {
const container = containerRef.current;
if (!container || typeof ResizeObserver === 'undefined') return undefined;
const observer = new ResizeObserver(() => {
redrawCanvas(strokesRef.current, draftStroke);
});
observer.observe(container);
return () => observer.disconnect();
}, [draftStroke, redrawCanvas]);
const getPoint = useCallback((event) => {
const rect = containerRef.current?.getBoundingClientRect();
if (!rect) return null;
return {
x: clampFraction((event.clientX - rect.left) / rect.width),
y: clampFraction((event.clientY - rect.top) / rect.height),
};
}, []);
const getBrushDisplaySize = useCallback(() => {
const rect = containerRef.current?.getBoundingClientRect();
if (!rect) return Math.max(1, Math.round(Number(penSize) || 1));
const scaleX = imageWidth > 0 ? rect.width / imageWidth : 1;
const scaleY = imageHeight > 0 ? rect.height / imageHeight : 1;
const brushScale = Math.max(0.5, Math.min(scaleX, scaleY));
return Math.max(1, (Math.max(1, Math.round(Number(penSize) || 1)) * brushScale));
}, [imageHeight, imageWidth, penSize]);
const appendPoint = useCallback((stroke, point) => {
if (!stroke || !point) return stroke;
const lastPoint = stroke.points[stroke.points.length - 1];
if (lastPoint && Math.abs(lastPoint.x - point.x) < 0.001 && Math.abs(lastPoint.y - point.y) < 0.001) {
return stroke;
}
return {
...stroke,
points: [...stroke.points, point],
};
}, []);
const commitStroke = useCallback((stroke) => {
const normalizedStroke = sanitizeStroke(stroke, penSize);
setDraftStroke(null);
setDrawing(false);
if (!normalizedStroke || !nodeId || !onWidgetChange) return;
const nextStrokes = [...strokesRef.current, normalizedStroke];
strokesRef.current = nextStrokes;
setStrokes(nextStrokes);
onWidgetChange(nodeId, 'mask_paths', JSON.stringify(nextStrokes));
}, [nodeId, onWidgetChange, penSize]);
const handlePointerDown = useCallback((event) => {
if (event.target.closest('button')) return;
const point = getPoint(event);
if (!point) return;
event.preventDefault();
event.stopPropagation();
event.currentTarget.setPointerCapture(event.pointerId);
setCursorPoint(point);
setDrawing(true);
setDraftStroke({
size: Math.max(1, Math.round(Number(penSize) || 1)),
points: [point],
});
}, [getPoint, penSize]);
const handlePointerMove = useCallback((event) => {
const point = getPoint(event);
if (!point) return;
setCursorPoint(point);
if (!drawing) return;
setDraftStroke((current) => appendPoint(current, point));
}, [appendPoint, drawing, getPoint]);
const handlePointerUp = useCallback(() => {
if (!drawing) return;
commitStroke(draftStrokeRef.current);
}, [commitStroke, drawing]);
const handlePointerLeave = useCallback(() => {
if (!drawing) {
setCursorPoint(null);
}
}, [drawing]);
return (
<div
ref={containerRef}
className={`nodrag nowheel mask-paint-overlay${drawing ? ' mask-paint-overlay-drawing' : ''}`}
onPointerDown={handlePointerDown}
onPointerMove={handlePointerMove}
onPointerUp={handlePointerUp}
onLostPointerCapture={handlePointerUp}
onPointerLeave={handlePointerLeave}
>
<img
src={image}
alt="mask source"
draggable={false}
className="mask-paint-image"
onLoad={() => redrawCanvas(strokesRef.current, draftStroke)}
/>
<canvas ref={canvasRef} className="mask-paint-canvas" />
{cursorPoint && (
<div
className="mask-paint-cursor"
style={{
left: `${cursorPoint.x * 100}%`,
top: `${cursorPoint.y * 100}%`,
width: `${getBrushDisplaySize()}px`,
height: `${getBrushDisplaySize()}px`,
}}
/>
)}
</div>
);
}

View File

@@ -568,6 +568,47 @@ html, body, #root {
opacity: 0.9;
}
.mask-paint-overlay {
position: relative;
overflow: hidden;
user-select: none;
touch-action: none;
background: #0f172a;
border: 1px solid #334155;
border-radius: 6px;
cursor: crosshair;
}
.mask-paint-overlay-drawing {
cursor: crosshair;
}
.mask-paint-image {
width: 100%;
display: block;
}
.mask-paint-canvas {
position: absolute;
inset: 0;
width: 100%;
height: 100%;
pointer-events: none;
}
.mask-paint-cursor {
position: absolute;
border: 1.5px solid rgba(255, 255, 255, 0.95);
border-radius: 50%;
background: rgba(255, 255, 255, 0.08);
box-shadow:
0 0 0 1px rgba(239, 68, 68, 0.85),
0 0 10px rgba(15, 23, 42, 0.35);
transform: translate(-50%, -50%);
pointer-events: none;
z-index: 2;
}
/* ── 3D surface view ──────────────────────────────────────────────── */
.surface-view-container {
width: 100%;

View File

@@ -4,6 +4,7 @@ Tests for all argonode backend nodes (excluding FFT2D which has its own test fil
Run from project root:
.venv/bin/python -m tests.test_nodes
"""
import json
import sys
import os
import tempfile
@@ -708,6 +709,53 @@ def test_mask_combine():
print(" PASS\n")
def test_draw_mask():
print("=== Test: DrawMask ===")
from backend.nodes.mask import DrawMask
node = DrawMask()
field = make_field(data=np.zeros((32, 32), dtype=np.float64))
overlays = []
DrawMask._broadcast_overlay_fn = lambda nid, data: overlays.append(data)
DrawMask._current_node_id = "test"
mask_paths = [
{
"size": 5,
"points": [
{"x": 0.2, "y": 0.5},
{"x": 0.8, "y": 0.5},
],
}
]
mask, = node.process(field, pen_size=2, invert=False, mask_paths=json.dumps(mask_paths))
assert mask.dtype == np.uint8
assert mask.shape == (32, 32)
assert mask[16, 16] == 255
assert mask[14, 16] == 255
assert mask[0, 0] == 0
assert len(overlays) == 1
assert overlays[0]["kind"] == "mask_paint"
assert overlays[0]["section_title"] == "Mask"
assert overlays[0]["image"].startswith("data:image/png;base64,")
assert overlays[0]["image_width"] == field.xres
assert overlays[0]["image_height"] == field.yres
assert overlays[0]["invert"] is False
inverted, = node.process(field, pen_size=2, invert=True, mask_paths=json.dumps(mask_paths))
assert inverted[16, 16] == 0
assert inverted[0, 0] == 255
assert overlays[-1]["invert"] is True
cleared, = node.process(field, pen_size=12, invert=False, mask_paths="[]")
assert np.count_nonzero(cleared) == 0
DrawMask._broadcast_overlay_fn = None
print(" PASS\n")
def test_particle_analysis():
print("=== Test: ParticleAnalysis ===")
from backend.nodes.grains import ParticleAnalysis
@@ -1563,6 +1611,7 @@ if __name__ == "__main__":
test_mask_morphology()
test_mask_invert()
test_mask_combine()
test_draw_mask()
# Grains
test_particle_analysis()