diff --git a/backend/execution.py b/backend/execution.py index e058664..7f6a44f 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -188,7 +188,7 @@ class ExecutionEngine: from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram from backend.nodes.modify import CropResizeField - from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine + from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask from backend.nodes.io import SaveImage, LoadFile PreviewImage._broadcast_fn = on_preview @@ -196,6 +196,7 @@ class ExecutionEngine: MaskMorphology._broadcast_fn = on_preview MaskInvert._broadcast_fn = on_preview MaskCombine._broadcast_fn = on_preview + DrawMask._broadcast_overlay_fn = on_overlay View3D._broadcast_mesh_fn = on_mesh PrintTable._broadcast_table_fn = on_table ValueDisplay._broadcast_value_fn = on_value @@ -213,10 +214,10 @@ class ExecutionEngine: from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram from backend.nodes.modify import CropResizeField - from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine + from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask from backend.nodes.io import LoadFile, SaveImage if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, - ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, + ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, LoadFile, SaveImage): cls._current_node_id = node_id diff --git a/backend/nodes/mask.py b/backend/nodes/mask.py index 9203488..4222af3 100644 --- a/backend/nodes/mask.py +++ b/backend/nodes/mask.py @@ -9,6 +9,7 @@ Gwyddion equivalents: """ from __future__ import annotations +import json import numpy as np from backend.node_registry import register_node from backend.data_types import DataField, datafield_to_uint8, encode_preview @@ -29,6 +30,157 @@ def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray: return np.clip(overlay, 0, 255).astype(np.uint8) +def _clamp_fraction(value) -> float: + try: + numeric = float(value) + except (TypeError, ValueError): + return 0.0 + return max(0.0, min(1.0, numeric)) + + +def _parse_mask_strokes(mask_paths) -> list[dict]: + if isinstance(mask_paths, list): + raw_strokes = mask_paths + elif isinstance(mask_paths, str) and mask_paths.strip(): + try: + parsed = json.loads(mask_paths) + except json.JSONDecodeError: + return [] + raw_strokes = parsed if isinstance(parsed, list) else [] + else: + return [] + + strokes = [] + for stroke in raw_strokes: + if not isinstance(stroke, dict): + continue + raw_points = stroke.get("points") + if not isinstance(raw_points, list): + continue + + points = [] + for point in raw_points: + if not isinstance(point, dict): + continue + if "x" not in point or "y" not in point: + continue + points.append({ + "x": _clamp_fraction(point.get("x")), + "y": _clamp_fraction(point.get("y")), + }) + + if not points: + continue + + try: + size = max(1, int(round(float(stroke.get("size", 1))))) + except (TypeError, ValueError): + size = 1 + + strokes.append({ + "size": size, + "points": points, + }) + + return strokes + + +def _rasterize_mask(width: int, height: int, strokes: list[dict], default_pen_size: int) -> np.ndarray: + from PIL import Image, ImageDraw + + width = max(1, int(width)) + height = max(1, int(height)) + default_pen_size = max(1, int(default_pen_size)) + + mask_image = Image.new("L", (width, height), 0) + draw = ImageDraw.Draw(mask_image) + + for stroke in strokes: + points = stroke.get("points") or [] + if not points: + continue + + size = stroke.get("size", default_pen_size) + try: + size = max(1, int(round(float(size)))) + except (TypeError, ValueError): + size = default_pen_size + + pixel_points = [] + for point in points: + px = int(round(_clamp_fraction(point.get("x")) * (width - 1))) + py = int(round(_clamp_fraction(point.get("y")) * (height - 1))) + pixel_points.append((px, py)) + + radius = max(0.5, size / 2.0) + + if len(pixel_points) == 1: + x, y = pixel_points[0] + draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) + continue + + draw.line(pixel_points, fill=255, width=size) + for x, y in pixel_points: + draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255) + + return np.asarray(mask_image, dtype=np.uint8) + + +# --------------------------------------------------------------------------- +# DrawMask +# --------------------------------------------------------------------------- + +@register_node(display_name="Draw Mask") +class DrawMask: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "pen_size": ("INT", {"default": 12, "min": 1, "max": 128, "step": 1}), + "invert": ("BOOLEAN", {"default": False}), + "clear_mask": ("BUTTON", {"label": "Clear Mask", "set_widgets": {"mask_paths": "[]"}}), + "mask_paths": ("STRING", {"default": "[]", "hidden": True}), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + CATEGORY = "mask" + DESCRIPTION = ( + "Paint a binary mask directly over an image preview. " + "Pen size controls newly drawn strokes, the overlay lets you clear the mask, " + "and invert flips the final binary output." + ) + + _broadcast_overlay_fn = None + _current_node_id: str = "" + + def process(self, field: DataField, pen_size: int, invert: bool, mask_paths: str) -> tuple: + strokes = _parse_mask_strokes(mask_paths) + mask = _rasterize_mask(field.xres, field.yres, strokes, pen_size) + if invert: + mask = np.where(mask > 127, np.uint8(0), np.uint8(255)) + + if DrawMask._broadcast_overlay_fn is not None: + DrawMask._broadcast_overlay_fn( + DrawMask._current_node_id, + { + "kind": "mask_paint", + "section_title": "Mask", + "image": encode_preview(datafield_to_uint8(field, "gray")), + "image_width": field.xres, + "image_height": field.yres, + "invert": bool(invert), + }, + ) + + return (mask,) + + # --------------------------------------------------------------------------- # ThresholdMask # --------------------------------------------------------------------------- diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 89a2bea..b803089 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -530,7 +530,12 @@ function Flow() { updateNodeData(msg.data.node_id, { meshData: msg.data.mesh }); break; case 'overlay': - updateNodeData(msg.data.node_id, { overlay: msg.data.overlay }); + updateNodeData( + msg.data.node_id, + msg.data.overlay?.kind === 'mask_paint' + ? { overlay: msg.data.overlay, previewImage: null } + : { overlay: msg.data.overlay }, + ); break; case 'node_warning': updateNodeData(msg.data.node_id, { warning: msg.data.message }); diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index ed2f0d4..1752dc0 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -5,6 +5,7 @@ import LinePlotOverlay from './LinePlotOverlay'; const SurfaceView = lazy(() => import('./SurfaceView')); const CrossSectionOverlay = lazy(() => import('./CrossSectionOverlay')); const CropBoxOverlay = lazy(() => import('./CropBoxOverlay')); +const MaskPaintOverlay = lazy(() => import('./MaskPaintOverlay')); // ── Constants ───────────────────────────────────────────────────────── @@ -525,8 +526,12 @@ function CustomNode({ id, data }) { const catColor = CAT_COLORS[def.category] || '#333'; const maxIORows = Math.max(dataInputs.length, outputs.length); const hasInteractiveLineOverlay = data.overlay?.kind === 'line_plot' && hiddenWidgets.has('x1'); + const hasInteractiveOverlay = !!data.overlay && (hiddenWidgets.has('x1') || data.overlay.kind === 'mask_paint'); + const hidePreviewForInteractiveMask = data.overlay?.kind === 'mask_paint'; const overlayTitle = data.overlay?.section_title - || (data.overlay?.kind === 'crop_box' + || (data.overlay?.kind === 'mask_paint' + ? 'Mask' + : data.overlay?.kind === 'crop_box' ? 'Crop' : data.overlay?.kind === 'line_plot' ? 'Line Plot' @@ -641,7 +646,9 @@ function CustomNode({ id, data }) { )} {/* Collapsible preview image */} - {data.previewImage && !(hasInteractiveLineOverlay && typeof data.previewImage === 'object' && data.previewImage.kind === 'line_plot') && ( + {data.previewImage + && !hidePreviewForInteractiveMask + && !(hasInteractiveLineOverlay && typeof data.previewImage === 'object' && data.previewImage.kind === 'line_plot') && ( Loading...}> {data.overlay.kind === 'line_plot' ? ( @@ -687,6 +694,16 @@ function CustomNode({ id, data }) { nodeId={id} onWidgetChange={ctx.onWidgetChange} /> + ) : data.overlay.kind === 'mask_paint' ? ( + ) : ( { + if (!point || typeof point !== 'object') return null; + return { + x: Number(clampFraction(point.x).toFixed(4)), + y: Number(clampFraction(point.y).toFixed(4)), + }; + }) + .filter(Boolean); + + if (points.length === 0) return null; + return { size, points }; +} + +function parseMaskPaths(maskPaths, fallbackPenSize) { + if (Array.isArray(maskPaths)) { + return maskPaths.map((stroke) => sanitizeStroke(stroke, fallbackPenSize)).filter(Boolean); + } + if (typeof maskPaths !== 'string' || !maskPaths.trim()) return []; + + try { + const parsed = JSON.parse(maskPaths); + if (!Array.isArray(parsed)) return []; + return parsed.map((stroke) => sanitizeStroke(stroke, fallbackPenSize)).filter(Boolean); + } catch { + return []; + } +} + +function drawStroke(ctx, stroke, width, height, imageWidth, imageHeight, styles = {}) { + if (!stroke || !Array.isArray(stroke.points) || stroke.points.length === 0) return; + + const scaleX = imageWidth > 0 ? width / imageWidth : 1; + const scaleY = imageHeight > 0 ? height / imageHeight : 1; + const brushScale = Math.max(0.5, Math.min(scaleX, scaleY)); + const lineWidth = Math.max(1, stroke.size * brushScale); + + ctx.save(); + ctx.lineCap = 'round'; + ctx.lineJoin = 'round'; + ctx.strokeStyle = styles.strokeStyle || '#ffffff'; + ctx.fillStyle = styles.fillStyle || '#ffffff'; + ctx.lineWidth = lineWidth; + + const points = stroke.points.map((point) => ({ + x: clampFraction(point.x) * width, + y: clampFraction(point.y) * height, + })); + + if (points.length === 1) { + const radius = lineWidth / 2; + ctx.beginPath(); + ctx.arc(points[0].x, points[0].y, radius, 0, Math.PI * 2); + ctx.fill(); + ctx.restore(); + return; + } + + ctx.beginPath(); + ctx.moveTo(points[0].x, points[0].y); + for (let i = 1; i < points.length; i += 1) { + ctx.lineTo(points[i].x, points[i].y); + } + ctx.stroke(); + + for (const point of points) { + ctx.beginPath(); + ctx.arc(point.x, point.y, lineWidth / 2, 0, Math.PI * 2); + ctx.fill(); + } + + ctx.restore(); +} + +export default function MaskPaintOverlay({ + image, + imageWidth, + imageHeight, + penSize, + maskPaths, + nodeId, + onWidgetChange, +}) { + const containerRef = useRef(null); + const canvasRef = useRef(null); + const strokesRef = useRef([]); + const draftStrokeRef = useRef(null); + const [strokes, setStrokes] = useState(() => parseMaskPaths(maskPaths, penSize)); + const [draftStroke, setDraftStroke] = useState(null); + const [drawing, setDrawing] = useState(false); + const [cursorPoint, setCursorPoint] = useState(null); + + useEffect(() => { + const parsed = parseMaskPaths(maskPaths, penSize); + strokesRef.current = parsed; + setStrokes(parsed); + setDraftStroke(null); + setDrawing(false); + }, [maskPaths, penSize]); + + useEffect(() => { + strokesRef.current = strokes; + }, [strokes]); + + useEffect(() => { + draftStrokeRef.current = draftStroke; + }, [draftStroke]); + + const redrawCanvas = useCallback((committedStrokes, activeStroke) => { + const canvas = canvasRef.current; + const container = containerRef.current; + if (!canvas || !container) return; + + const rect = container.getBoundingClientRect(); + const cssWidth = Math.max(1, Math.round(rect.width)); + const cssHeight = Math.max(1, Math.round(rect.height)); + const dpr = Math.max(1, window.devicePixelRatio || 1); + + if (canvas.width !== Math.round(cssWidth * dpr) || canvas.height !== Math.round(cssHeight * dpr)) { + canvas.width = Math.round(cssWidth * dpr); + canvas.height = Math.round(cssHeight * dpr); + canvas.style.width = `${cssWidth}px`; + canvas.style.height = `${cssHeight}px`; + } + + const ctx = canvas.getContext('2d'); + if (!ctx) return; + + ctx.setTransform(1, 0, 0, 1, 0, 0); + ctx.clearRect(0, 0, canvas.width, canvas.height); + + const maskCanvas = document.createElement('canvas'); + maskCanvas.width = canvas.width; + maskCanvas.height = canvas.height; + const maskCtx = maskCanvas.getContext('2d'); + if (!maskCtx) return; + + maskCtx.setTransform(1, 0, 0, 1, 0, 0); + maskCtx.clearRect(0, 0, maskCanvas.width, maskCanvas.height); + maskCtx.scale(dpr, dpr); + + const drawMaskStroke = (stroke) => drawStroke( + maskCtx, + stroke, + cssWidth, + cssHeight, + imageWidth, + imageHeight, + { strokeStyle: '#ffffff', fillStyle: '#ffffff' }, + ); + + for (const stroke of committedStrokes) { + drawMaskStroke(stroke); + } + if (activeStroke) { + drawMaskStroke(activeStroke); + } + + ctx.drawImage(maskCanvas, 0, 0); + ctx.globalCompositeOperation = 'source-in'; + ctx.fillStyle = 'rgba(255, 59, 59, 0.16)'; + ctx.fillRect(0, 0, canvas.width, canvas.height); + ctx.globalCompositeOperation = 'source-over'; + }, [imageHeight, imageWidth]); + + useEffect(() => { + redrawCanvas(strokes, draftStroke); + }, [draftStroke, redrawCanvas, strokes]); + + useEffect(() => { + const container = containerRef.current; + if (!container || typeof ResizeObserver === 'undefined') return undefined; + + const observer = new ResizeObserver(() => { + redrawCanvas(strokesRef.current, draftStroke); + }); + observer.observe(container); + return () => observer.disconnect(); + }, [draftStroke, redrawCanvas]); + + const getPoint = useCallback((event) => { + const rect = containerRef.current?.getBoundingClientRect(); + if (!rect) return null; + return { + x: clampFraction((event.clientX - rect.left) / rect.width), + y: clampFraction((event.clientY - rect.top) / rect.height), + }; + }, []); + + const getBrushDisplaySize = useCallback(() => { + const rect = containerRef.current?.getBoundingClientRect(); + if (!rect) return Math.max(1, Math.round(Number(penSize) || 1)); + const scaleX = imageWidth > 0 ? rect.width / imageWidth : 1; + const scaleY = imageHeight > 0 ? rect.height / imageHeight : 1; + const brushScale = Math.max(0.5, Math.min(scaleX, scaleY)); + return Math.max(1, (Math.max(1, Math.round(Number(penSize) || 1)) * brushScale)); + }, [imageHeight, imageWidth, penSize]); + + const appendPoint = useCallback((stroke, point) => { + if (!stroke || !point) return stroke; + const lastPoint = stroke.points[stroke.points.length - 1]; + if (lastPoint && Math.abs(lastPoint.x - point.x) < 0.001 && Math.abs(lastPoint.y - point.y) < 0.001) { + return stroke; + } + return { + ...stroke, + points: [...stroke.points, point], + }; + }, []); + + const commitStroke = useCallback((stroke) => { + const normalizedStroke = sanitizeStroke(stroke, penSize); + setDraftStroke(null); + setDrawing(false); + if (!normalizedStroke || !nodeId || !onWidgetChange) return; + + const nextStrokes = [...strokesRef.current, normalizedStroke]; + strokesRef.current = nextStrokes; + setStrokes(nextStrokes); + onWidgetChange(nodeId, 'mask_paths', JSON.stringify(nextStrokes)); + }, [nodeId, onWidgetChange, penSize]); + + const handlePointerDown = useCallback((event) => { + if (event.target.closest('button')) return; + const point = getPoint(event); + if (!point) return; + + event.preventDefault(); + event.stopPropagation(); + event.currentTarget.setPointerCapture(event.pointerId); + setCursorPoint(point); + setDrawing(true); + setDraftStroke({ + size: Math.max(1, Math.round(Number(penSize) || 1)), + points: [point], + }); + }, [getPoint, penSize]); + + const handlePointerMove = useCallback((event) => { + const point = getPoint(event); + if (!point) return; + setCursorPoint(point); + if (!drawing) return; + + setDraftStroke((current) => appendPoint(current, point)); + }, [appendPoint, drawing, getPoint]); + + const handlePointerUp = useCallback(() => { + if (!drawing) return; + commitStroke(draftStrokeRef.current); + }, [commitStroke, drawing]); + + const handlePointerLeave = useCallback(() => { + if (!drawing) { + setCursorPoint(null); + } + }, [drawing]); + + return ( +
+ mask source redrawCanvas(strokesRef.current, draftStroke)} + /> + + {cursorPoint && ( +
+ )} +
+ ); +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index 1592822..d68b4ae 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -568,6 +568,47 @@ html, body, #root { opacity: 0.9; } +.mask-paint-overlay { + position: relative; + overflow: hidden; + user-select: none; + touch-action: none; + background: #0f172a; + border: 1px solid #334155; + border-radius: 6px; + cursor: crosshair; +} + +.mask-paint-overlay-drawing { + cursor: crosshair; +} + +.mask-paint-image { + width: 100%; + display: block; +} + +.mask-paint-canvas { + position: absolute; + inset: 0; + width: 100%; + height: 100%; + pointer-events: none; +} + +.mask-paint-cursor { + position: absolute; + border: 1.5px solid rgba(255, 255, 255, 0.95); + border-radius: 50%; + background: rgba(255, 255, 255, 0.08); + box-shadow: + 0 0 0 1px rgba(239, 68, 68, 0.85), + 0 0 10px rgba(15, 23, 42, 0.35); + transform: translate(-50%, -50%); + pointer-events: none; + z-index: 2; +} + /* ── 3D surface view ──────────────────────────────────────────────── */ .surface-view-container { width: 100%; diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 821bb28..909eb78 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -4,6 +4,7 @@ Tests for all argonode backend nodes (excluding FFT2D which has its own test fil Run from project root: .venv/bin/python -m tests.test_nodes """ +import json import sys import os import tempfile @@ -708,6 +709,53 @@ def test_mask_combine(): print(" PASS\n") +def test_draw_mask(): + print("=== Test: DrawMask ===") + from backend.nodes.mask import DrawMask + node = DrawMask() + + field = make_field(data=np.zeros((32, 32), dtype=np.float64)) + overlays = [] + DrawMask._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + DrawMask._current_node_id = "test" + + mask_paths = [ + { + "size": 5, + "points": [ + {"x": 0.2, "y": 0.5}, + {"x": 0.8, "y": 0.5}, + ], + } + ] + + mask, = node.process(field, pen_size=2, invert=False, mask_paths=json.dumps(mask_paths)) + assert mask.dtype == np.uint8 + assert mask.shape == (32, 32) + assert mask[16, 16] == 255 + assert mask[14, 16] == 255 + assert mask[0, 0] == 0 + + assert len(overlays) == 1 + assert overlays[0]["kind"] == "mask_paint" + assert overlays[0]["section_title"] == "Mask" + assert overlays[0]["image"].startswith("data:image/png;base64,") + assert overlays[0]["image_width"] == field.xres + assert overlays[0]["image_height"] == field.yres + assert overlays[0]["invert"] is False + + inverted, = node.process(field, pen_size=2, invert=True, mask_paths=json.dumps(mask_paths)) + assert inverted[16, 16] == 0 + assert inverted[0, 0] == 255 + assert overlays[-1]["invert"] is True + + cleared, = node.process(field, pen_size=12, invert=False, mask_paths="[]") + assert np.count_nonzero(cleared) == 0 + + DrawMask._broadcast_overlay_fn = None + print(" PASS\n") + + def test_particle_analysis(): print("=== Test: ParticleAnalysis ===") from backend.nodes.grains import ParticleAnalysis @@ -1563,6 +1611,7 @@ if __name__ == "__main__": test_mask_morphology() test_mask_invert() test_mask_combine() + test_draw_mask() # Grains test_particle_analysis()