diff --git a/backend/nodes/straighten_path.py b/backend/nodes/straighten_path.py index 33ed766..801dcae 100644 --- a/backend/nodes/straighten_path.py +++ b/backend/nodes/straighten_path.py @@ -3,10 +3,12 @@ from __future__ import annotations import numpy as np +from scipy.interpolate import CubicSpline from scipy.ndimage import map_coordinates from backend.node_registry import register_node -from backend.data_types import DataField +from backend.data_types import DataField, LineData, datafield_to_uint8, encode_preview +from backend.execution_context import emit_overlay @register_node(display_name="Straighten Path") @@ -16,8 +18,8 @@ class StraightenPath: return { "required": { "field": ("DATA_FIELD",), - "points_x": ("STRING", {"default": "0.25, 0.5, 0.75"}), - "points_y": ("STRING", {"default": "0.5, 0.3, 0.5"}), + "points_x": ("STRING", {"default": "0.25, 0.5, 0.75", "hidden": True}), + "points_y": ("STRING", {"default": "0.5, 0.3, 0.5", "hidden": True}), "thickness": ("INT", {"default": 1, "min": 1, "max": 100, "step": 1}), "n_samples": ("INT", {"default": 256, "min": 10, "max": 2048, "step": 1}), } @@ -25,14 +27,15 @@ class StraightenPath: OUTPUTS = ( ('DATA_FIELD', 'straightened'), + ('LINE', 'profile'), ) FUNCTION = "process" DESCRIPTION = ( "Extract a cross-section along an arbitrary curved path defined by " - "control points. Points are given as fractional coordinates (0-1). " - "The path is interpolated with cubic splines, and data is sampled " - "along it with configurable thickness. " + "control points. The path is a natural cubic spline through the " + "points. Drag the points on the preview to reshape the path; the " + "shaded band shows the sampling thickness. " ) KEYWORDS = ("unbend", "unroll", "spline", "curved profile", "extract path") @@ -42,36 +45,46 @@ class StraightenPath: data = np.asarray(field.data, dtype=np.float64) yres, xres = data.shape - # Parse control points - px = [float(v.strip()) * (xres - 1) for v in points_x.split(",") if v.strip()] - py = [float(v.strip()) * (yres - 1) for v in points_y.split(",") if v.strip()] + fx = [float(v.strip()) for v in points_x.split(",") if v.strip()] + fy = [float(v.strip()) for v in points_y.split(",") if v.strip()] + n_pts = min(len(fx), len(fy)) + fx, fy = fx[:n_pts], fy[:n_pts] - if len(px) < 2 or len(py) < 2: - # Need at least 2 points - return (field,) + emit_overlay({ + "kind": "straighten_path", + "section_title": "Path", + "image": encode_preview(datafield_to_uint8(field, field.colormap)), + "points": [{"x": float(fx[i]), "y": float(fy[i])} for i in range(n_pts)], + "thickness": int(thickness), + "xres": int(xres), + "yres": int(yres), + }) - n_pts = min(len(px), len(py)) - px, py = px[:n_pts], py[:n_pts] + if n_pts < 2: + empty_line = LineData( + data=np.zeros(0, dtype=np.float64), + x_axis=np.zeros(0, dtype=np.float64), + x_unit=field.si_unit_xy, + y_unit=field.si_unit_z, + ) + return (field, empty_line) + + px = [f * (xres - 1) for f in fx] + py = [f * (yres - 1) for f in fy] - # Parameterize path and interpolate t_ctrl = np.linspace(0, 1, n_pts) t_sample = np.linspace(0, 1, n_samples) - - # Simple cubic interpolation via numpy - if n_pts >= 4: - from numpy.polynomial.polynomial import Polynomial - cx = np.interp(t_sample, t_ctrl, px) - cy = np.interp(t_sample, t_ctrl, py) + if n_pts >= 3: + cx = CubicSpline(t_ctrl, px, bc_type="natural")(t_sample) + cy = CubicSpline(t_ctrl, py, bc_type="natural")(t_sample) else: cx = np.interp(t_sample, t_ctrl, px) cy = np.interp(t_sample, t_ctrl, py) - # Sample along path with thickness if thickness <= 1: values = map_coordinates(data, [cy, cx], order=1, mode='nearest') result = values.reshape(1, -1) else: - # Compute normals dcx = np.gradient(cx) dcy = np.gradient(cy) length = np.sqrt(dcx**2 + dcy**2) @@ -86,12 +99,22 @@ class StraightenPath: sy = cy + off * ny result[i] = map_coordinates(data, [sy, sx], order=1, mode='nearest') - # Physical dimensions total_length = 0.0 for i in range(1, len(cx)): dx_phys = (cx[i] - cx[i - 1]) * field.dx dy_phys = (cy[i] - cy[i - 1]) * field.dy total_length += np.sqrt(dx_phys**2 + dy_phys**2) - return (field.replace(data=result, xreal=total_length, - yreal=thickness * max(field.dx, field.dy)),) + center_values = map_coordinates(data, [cy, cx], order=1, mode='nearest') + profile = LineData( + data=center_values, + x_axis=np.linspace(0.0, total_length, n_samples), + x_unit=field.si_unit_xy, + y_unit=field.si_unit_z, + ) + + straightened = field.replace( + data=result, xreal=total_length, + yreal=thickness * max(field.dx, field.dy), + ) + return (straightened, profile) diff --git a/docs/nodes/Straighten Path.md b/docs/nodes/Straighten Path.md index 1449b35..1853df8 100644 --- a/docs/nodes/Straighten Path.md +++ b/docs/nodes/Straighten Path.md @@ -13,20 +13,25 @@ Extract a cross-section along an arbitrary curved path defined by control points | Name | Type | Description | |------|------|-------------| | straightened | DATA_FIELD | Straightened cross-section; width = n_samples, height = thickness | +| profile | LINE | 1-pixel-wide profile sampled along the centerline of the path | ## Controls | Name | Type | Default | Description | |------|------|---------|-------------| -| points_x | STRING | "0.25, 0.5, 0.75" | Comma-separated fractional x-coordinates of control points (0.0-1.0) | -| points_y | STRING | "0.5, 0.3, 0.5" | Comma-separated fractional y-coordinates of control points (0.0-1.0) | | thickness | INT | 1 | Width of the sampled strip perpendicular to the path, in pixels (1-100) | | n_samples | INT | 256 | Number of sample points along the path (10-2048) | +## Interactive preview + +The node renders the input field with the control points and a smooth curve through them. Drag any point to reshape the path. Double-click anywhere on the image to add a new point at that location. Shift-click a point to delete it (a minimum of two points is kept). The shaded band along the curve previews the sampling thickness. + +The straightened result is shown in the regular preview section below. + ## Notes - Control points are specified as fractions of the image dimensions (0 = left/top edge, 1 = right/bottom edge). At least 2 points are required. -- Points are connected by linear interpolation; the path is sampled at n_samples evenly spaced positions. +- With 3 or more points, the path is a natural cubic spline (C² continuous) passing through each control point, matching the smooth curve drawn on the preview. With exactly 2 points the path is a straight line. - When thickness > 1, samples are taken along the local normal direction at each path position, producing a 2D strip rather than a single line. - The output xreal equals the physical path length (computed from pixel spacing), and yreal equals thickness times the pixel size. - Bilinear interpolation (order=1) is used with nearest-edge boundary handling. diff --git a/frontend/src/CustomNode.tsx b/frontend/src/CustomNode.tsx index 5471e32..e07710f 100644 --- a/frontend/src/CustomNode.tsx +++ b/frontend/src/CustomNode.tsx @@ -13,6 +13,7 @@ const MarkupOverlay = lazy(() => import('./MarkupOverlay')); const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay')); const ThresholdHistogram = lazy(() => import('./ThresholdHistogram')); const RadialProfileOverlay = lazy(() => import('./RadialProfileOverlay')); +const StraightenPathOverlay = lazy(() => import('./StraightenPathOverlay')); import TextNoteNode from './TextNoteNode'; @@ -1196,6 +1197,7 @@ function CustomNode({ id, data }: { id: string; data: NodeData }) { || data.overlay.kind === 'markup' || data.overlay.kind === 'threshold_histogram' || data.overlay.kind === 'radial_profile' + || data.overlay.kind === 'straighten_path' ); const hidePreviewForInteractiveMask = data.overlay?.kind === 'mask_paint' || data.overlay?.kind === 'markup'; const overlayTitle = data.overlay?.section_title @@ -1213,6 +1215,8 @@ function CustomNode({ id, data }: { id: string; data: NodeData }) { ? 'Line Plot' : data.overlay?.kind === 'radial_profile' ? 'Radial Profile' + : data.overlay?.kind === 'straighten_path' + ? 'Path' : 'Cross Section'); const headerMeta = (() => { if (data.className === 'Folder') { @@ -1558,6 +1562,16 @@ function CustomNode({ id, data }: { id: string; data: NodeData }) { nodeId={id} onWidgetChange={ctx!.onWidgetChange} /> + ) : data.overlay!.kind === 'straighten_path' ? ( + } + thickness={(data.widgetValues.thickness ?? data.overlay!.thickness ?? 1) as number} + xres={(data.overlay!.xres ?? 1) as number} + yres={(data.overlay!.yres ?? 1) as number} + nodeId={id} + onWidgetChange={ctx!.onWidgetChange} + /> ) : data.overlay!.kind === 'angle_measure' ? ( void; +} + +const round3 = (v: number) => parseFloat(v.toFixed(3)); + +function pointsToStrings(points: Point[]) { + return { + points_x: points.map(p => round3(p.x)).join(', '), + points_y: points.map(p => round3(p.y)).join(', '), + }; +} + +// Solve a 1-D natural cubic spline (matches scipy.interpolate.CubicSpline with +// bc_type="natural") and return a function that evaluates it at any t. +function naturalCubicSpline(t: number[], y: number[]): (tq: number) => number { + const n = t.length; + if (n < 2) return () => y[0] ?? 0; + if (n === 2) { + return (tq) => y[0] + (y[1] - y[0]) * (tq - t[0]) / (t[1] - t[0]); + } + const h = new Array(n - 1); + for (let i = 0; i < n - 1; i++) h[i] = t[i + 1] - t[i]; + + // Tridiagonal system for second derivatives M[1..n-2] (M[0] = M[n-1] = 0). + const a = new Array(n).fill(0); + const b = new Array(n).fill(0); + const c = new Array(n).fill(0); + const d = new Array(n).fill(0); + for (let i = 1; i < n - 1; i++) { + a[i] = h[i - 1]; + b[i] = 2 * (h[i - 1] + h[i]); + c[i] = h[i]; + d[i] = 6 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1]); + } + for (let i = 2; i < n - 1; i++) { + const w = a[i] / b[i - 1]; + b[i] -= w * c[i - 1]; + d[i] -= w * d[i - 1]; + } + const M = new Array(n).fill(0); + if (n >= 3) { + M[n - 2] = d[n - 2] / b[n - 2]; + for (let i = n - 3; i >= 1; i--) { + M[i] = (d[i] - c[i] * M[i + 1]) / b[i]; + } + } + + return (tq) => { + let i = 0; + while (i < n - 2 && tq > t[i + 1]) i++; + const dx = h[i]; + const A = (t[i + 1] - tq) / dx; + const B = (tq - t[i]) / dx; + return A * y[i] + B * y[i + 1] + + ((A ** 3 - A) * M[i] + (B ** 3 - B) * M[i + 1]) * (dx * dx) / 6; + }; +} + +const CURVE_SAMPLES_PER_SEGMENT = 24; + +function buildCurvePath(points: Point[]): string { + if (points.length === 0) return ''; + if (points.length === 1) return `M ${points[0].x * 100} ${points[0].y * 100}`; + if (points.length === 2) { + return `M ${points[0].x * 100} ${points[0].y * 100} L ${points[1].x * 100} ${points[1].y * 100}`; + } + const n = points.length; + const t = Array.from({ length: n }, (_, i) => i / (n - 1)); + const xs = points.map(p => p.x); + const ys = points.map(p => p.y); + const fx = naturalCubicSpline(t, xs); + const fy = naturalCubicSpline(t, ys); + + const total = (n - 1) * CURVE_SAMPLES_PER_SEGMENT; + const segs: string[] = [`M ${points[0].x * 100} ${points[0].y * 100}`]; + for (let i = 1; i <= total; i++) { + const tq = i / total; + segs.push(`L ${fx(tq) * 100} ${fy(tq) * 100}`); + } + return segs.join(' '); +} + +function pointsKey(points: Point[]) { + return points.map(p => `${round3(p.x)},${round3(p.y)}`).join('|'); +} + +export default function StraightenPathOverlay({ + image, points, thickness, xres, yres, + nodeId, onWidgetChange, +}: StraightenPathOverlayProps) { + const containerRef = useRef(null); + const [draft, setDraft] = useState(null); + const draggingRef = useRef(null); + const pendingCommitRef = useRef(null); + + useEffect(() => { + if (pendingCommitRef.current !== null + && pointsKey(points) === pendingCommitRef.current) { + pendingCommitRef.current = null; + setDraft(null); + } + }, [points]); + + const commit = useCallback((next: Point[]) => { + pendingCommitRef.current = pointsKey(next); + const { points_x, points_y } = pointsToStrings(next); + onWidgetChange(nodeId, 'points_x', points_x); + onWidgetChange(nodeId, 'points_y', points_y); + }, [nodeId, onWidgetChange]); + + const displayPoints = draft ?? points; + + const onPointerDownMarker = useCallback((idx: number) => (e: React.PointerEvent) => { + e.stopPropagation(); + e.preventDefault(); + if (e.shiftKey && displayPoints.length > 2) { + const next = displayPoints.filter((_, i) => i !== idx); + setDraft(next); + commit(next); + return; + } + (e.target as HTMLElement).setPointerCapture(e.pointerId); + draggingRef.current = idx; + setDraft(displayPoints); + }, [displayPoints, commit]); + + const onPointerMove = useCallback((e: React.PointerEvent) => { + const idx = draggingRef.current; + if (idx === null || !containerRef.current) return; + const { fx, fy } = pointerToFraction(e, containerRef.current); + setDraft(prev => { + const base = prev ?? points; + return base.map((p, i) => i === idx + ? { x: clampFraction(fx), y: clampFraction(fy) } + : p); + }); + }, [points]); + + const onPointerUp = useCallback(() => { + if (draggingRef.current !== null && draft) { + commit(draft); + } + draggingRef.current = null; + }, [draft, commit]); + + const onDoubleClick = useCallback((e: React.MouseEvent) => { + if (!containerRef.current) return; + const rect = containerRef.current.getBoundingClientRect(); + const fx = clampFraction((e.clientX - rect.left) / rect.width); + const fy = clampFraction((e.clientY - rect.top) / rect.height); + const next = [...displayPoints, { x: fx, y: fy }]; + setDraft(next); + commit(next); + }, [displayPoints, commit]); + + const curveD = buildCurvePath(displayPoints); + + const refRes = Math.max(xres, yres) || 1; + const bandWidthPct = (thickness / refRes) * 100; + + return ( +
+ field + + + {curveD && bandWidthPct > 0 && ( + + )} + {curveD && ( + + )} + + + {displayPoints.map((p, i) => ( +
+ ))} +
+ ); +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index 9772764..27d0f67 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -1689,6 +1689,53 @@ html, body, #root { border-radius: 2px; } +/* ── Straighten Path overlay ──────────────────────────────────────── */ +.straighten-overlay { + position: relative; + user-select: none; + touch-action: none; + overflow: hidden; + cursor: crosshair; +} +.straighten-image { + width: 100%; + display: block; +} +.straighten-svg { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + pointer-events: none; +} +.straighten-band { + stroke: var(--accent-lighter); + opacity: 0.25; +} +.straighten-curve { + stroke: var(--marker); + stroke-width: 1.4; + vector-effect: non-scaling-stroke; +} +.straighten-marker { + position: absolute; + width: 12px; + height: 12px; + border-radius: 50%; + background: var(--marker); + border: 1px solid var(--marker-border); + transform: translate(-50%, -50%); + cursor: grab; + box-shadow: 0 0 4px var(--marker-shadow); + z-index: 1; +} +.straighten-marker:active { + cursor: grabbing; + background: var(--marker-active); + transform: translate(-50%, -50%) scale(1.2); +} + .angle-overlay { --angle-line-color: #ff9800; --angle-arc-color: rgb(255, 166, 77); @@ -1944,7 +1991,8 @@ html, body, #root { .is-panning .crop-overlay, .is-panning .mask-paint-overlay, .is-panning .markup-overlay, -.is-panning .radial-overlay { +.is-panning .radial-overlay, +.is-panning .straighten-overlay { pointer-events: none; } diff --git a/frontend/src/types.ts b/frontend/src/types.ts index a8de79b..d14e3af 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -75,6 +75,10 @@ export interface OverlayData { square?: boolean; a_locked?: boolean; b_locked?: boolean; + points?: Array<{ x: number; y: number }>; + thickness?: number; + xres?: number; + yres?: number; section_title?: string; line?: number[]; shape?: string; diff --git a/frontend/src/workflowCapture.ts b/frontend/src/workflowCapture.ts index 3002cfa..5c40bd0 100644 --- a/frontend/src/workflowCapture.ts +++ b/frontend/src/workflowCapture.ts @@ -12,6 +12,7 @@ export const OVERLAY_CAPTURE_SELECTORS = [ '.markup-overlay', // MarkupOverlay '.angle-overlay', // AngleMeasureOverlay '.radial-overlay', // RadialProfileOverlay + '.straighten-overlay', // StraightenPathOverlay ]; function encodeBase64(bytes: Uint8Array) { diff --git a/tests/node_tests/straighten_path.py b/tests/node_tests/straighten_path.py index 1be386c..bddc566 100644 --- a/tests/node_tests/straighten_path.py +++ b/tests/node_tests/straighten_path.py @@ -8,10 +8,11 @@ def test_basic_extraction(): node = StraightenPath() field = make_field(shape=(64, 64)) - (result,) = node.process(field, points_x="0.25, 0.5, 0.75", - points_y="0.5, 0.3, 0.5", - thickness=1, n_samples=256) + result, profile = node.process(field, points_x="0.25, 0.5, 0.75", + points_y="0.5, 0.3, 0.5", + thickness=1, n_samples=256) assert result.data.shape[1] == 256, f"Output width should be n_samples=256, got {result.data.shape[1]}" + assert profile.data.shape == (256,) def test_thickness(): @@ -19,10 +20,14 @@ def test_thickness(): node = StraightenPath() field = make_field(shape=(64, 64)) - (result,) = node.process(field, points_x="0.2, 0.8", - points_y="0.5, 0.5", - thickness=5, n_samples=100) + result, profile = node.process(field, points_x="0.2, 0.8", + points_y="0.5, 0.5", + thickness=5, n_samples=100) assert result.data.shape[0] == 5, f"Output height should be thickness=5, got {result.data.shape[0]}" + # Profile is the 1-pixel-wide centerline regardless of thickness. + assert profile.data.shape == (100,) + # For a horizontal line, the centerline equals the middle row of the strip. + assert np.allclose(profile.data, result.data[2]) def test_single_point_returns_input(): @@ -30,8 +35,33 @@ def test_single_point_returns_input(): node = StraightenPath() field = make_field(shape=(64, 64)) - (result,) = node.process(field, points_x="0.5", - points_y="0.5", - thickness=1, n_samples=100) - # With only 1 point, node returns the original field unchanged + result, profile = node.process(field, points_x="0.5", + points_y="0.5", + thickness=1, n_samples=100) + # With only 1 point, node returns the original field unchanged + empty profile. assert np.array_equal(result.data, field.data) + assert profile.data.shape == (0,) + + +def test_emits_overlay_with_points_and_thickness(): + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.straighten_path import StraightenPath + + node = StraightenPath() + field = make_field(shape=(64, 64)) + + overlays = [] + with execution_callbacks(overlay=lambda nid, d: overlays.append(d)), active_node("test"): + node.process(field, points_x="0.25, 0.5, 0.75", + points_y="0.5, 0.3, 0.5", + thickness=4, n_samples=128) + + assert len(overlays) == 1 + ov = overlays[0] + assert ov["kind"] == "straighten_path" + assert ov["section_title"] == "Path" + assert ov["image"].startswith("data:image/png;base64,") + assert ov["thickness"] == 4 + assert ov["xres"] == 64 and ov["yres"] == 64 + assert [p["x"] for p in ov["points"]] == [0.25, 0.5, 0.75] + assert [p["y"] for p in ov["points"]] == [0.5, 0.3, 0.5]