diff --git a/backend/execution.py b/backend/execution.py index 7fbcf12..ba79fca 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -585,13 +585,16 @@ class ExecutionEngine: plt.close(fig) fallback_image = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" - return { + result_dict = { "kind": "line_plot", "line": y.tolist(), "x_axis": x.tolist(), "interactive": False, "fallback_image": fallback_image, } + if y_meta is not None and y_meta.x_unit: + result_dict["x_unit"] = y_meta.x_unit + return result_dict except Exception: return None diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index 0773ba0..79fcc0e 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -1,63 +1,60 @@ # Import all node modules to trigger @register_node decorators. from backend.nodes import ( # IO + colormap, + crop_resize, + fft_2d_invert, + filter_fft_1d, + filter_fft_2d, + filter_gaussian, + filter_median, + flip, + font, image, - ibw_note, image_demo, folder, coordinate, coordinate_pair, + level_facet, + level_plane, + level_poly, + mask_draw, + mask_threshold, + note, number, range_slider, + rotate, save, - save_image, - # Filters - gaussian_filter, - median_filter, edge_detect, - fft_filter_1d, - fft_filter_2d, - # Modify colormap_adjust, - crop_resize_field, - rotate_field, - flip_field, - # Level - plane_level_field, - facet_level_field, - poly_level_field, fix_zero, line_correction, # Mask - draw_mask, - threshold_mask, mask_morphology, mask_invert, mask_operations, grain_distance_transform, + save_layers, # Correction scar_removal, # Display - color_map, - font_node, annotations, angle_measure, markup, preview_image, + statistics, view_3d, print_table, value_display, # Analysis curvature, fractal_dimension, - statistics_node, histogram, acf_2d, acf_1d, cursors, fft_2d, psdf, - inverse_fft_2d, cross_section, stats, watershed_segmentation, diff --git a/backend/nodes/color_map.py b/backend/nodes/colormap.py similarity index 100% rename from backend/nodes/color_map.py rename to backend/nodes/colormap.py diff --git a/backend/nodes/crop_resize_field.py b/backend/nodes/crop_resize.py similarity index 100% rename from backend/nodes/crop_resize_field.py rename to backend/nodes/crop_resize.py diff --git a/backend/nodes/cursors.py b/backend/nodes/cursors.py index 38e891e..b84b7a3 100644 --- a/backend/nodes/cursors.py +++ b/backend/nodes/cursors.py @@ -98,6 +98,7 @@ class Cursors: "section_title": "Cursors", "line": y.tolist(), "x_axis": x.tolist(), + "x_unit": x_unit, "x1": x1, "x2": x2, "y1": float(y1), diff --git a/backend/nodes/fft_1d.py b/backend/nodes/fft_1d.py index 7a29a14..f927ef8 100644 --- a/backend/nodes/fft_1d.py +++ b/backend/nodes/fft_1d.py @@ -21,7 +21,7 @@ class FFT1D: OUTPUTS = ( ("LINE", "frequency_plot"), - ('RECORD_TABLE', 'measurement'), + ('RECORD_TABLE', 'max'), ) FUNCTION = "process" diff --git a/backend/nodes/fft_2d.py b/backend/nodes/fft_2d.py index 7df03bc..796f3c2 100644 --- a/backend/nodes/fft_2d.py +++ b/backend/nodes/fft_2d.py @@ -9,7 +9,7 @@ from backend.nodes.spectral_common import ( ) -@register_node(display_name="2D FFT") +@register_node(display_name="FFT 2D") class FFT2D: @classmethod def INPUT_TYPES(cls): diff --git a/backend/nodes/inverse_fft_2d.py b/backend/nodes/fft_2d_invert.py similarity index 100% rename from backend/nodes/inverse_fft_2d.py rename to backend/nodes/fft_2d_invert.py diff --git a/backend/nodes/fft_filter_1d.py b/backend/nodes/filter_fft_1d.py similarity index 97% rename from backend/nodes/fft_filter_1d.py rename to backend/nodes/filter_fft_1d.py index 202afbf..1ed0754 100644 --- a/backend/nodes/fft_filter_1d.py +++ b/backend/nodes/filter_fft_1d.py @@ -5,7 +5,7 @@ from backend.data_types import LineData from backend.nodes.helpers import _cached_1d_transfer -@register_node(display_name="1D FFT Filter") +@register_node(display_name="FFT Filter 1D") class FFTFilter1D: """Bandpass / lowpass / highpass / notch filtering of 1-D line profiles. diff --git a/backend/nodes/fft_filter_2d.py b/backend/nodes/filter_fft_2d.py similarity index 97% rename from backend/nodes/fft_filter_2d.py rename to backend/nodes/filter_fft_2d.py index 509c182..336b2d2 100644 --- a/backend/nodes/fft_filter_2d.py +++ b/backend/nodes/filter_fft_2d.py @@ -5,7 +5,7 @@ from backend.data_types import DataField from backend.nodes.helpers import _cached_2d_transfer -@register_node(display_name="2D FFT Filter") +@register_node(display_name="FFT Filter 2D") class FFTFilter2D: """Frequency-domain filtering of 2-D data fields (images). diff --git a/backend/nodes/gaussian_filter.py b/backend/nodes/filter_gaussian.py similarity index 100% rename from backend/nodes/gaussian_filter.py rename to backend/nodes/filter_gaussian.py diff --git a/backend/nodes/median_filter.py b/backend/nodes/filter_median.py similarity index 100% rename from backend/nodes/median_filter.py rename to backend/nodes/filter_median.py diff --git a/backend/nodes/flip_field.py b/backend/nodes/flip.py similarity index 100% rename from backend/nodes/flip_field.py rename to backend/nodes/flip.py diff --git a/backend/nodes/font_node.py b/backend/nodes/font.py similarity index 100% rename from backend/nodes/font_node.py rename to backend/nodes/font.py diff --git a/backend/nodes/histogram.py b/backend/nodes/histogram.py index 4f610c4..1e7baf6 100644 --- a/backend/nodes/histogram.py +++ b/backend/nodes/histogram.py @@ -80,6 +80,7 @@ class Histogram: "section_title": "Histogram", "line": counts.tolist(), "x_axis": bin_centers.astype(np.float64).tolist(), + "x_unit": field.si_unit_z, "x1": float(np.clip(x1, 0.0, 1.0)), "x2": float(np.clip(x2, 0.0, 1.0)), "y1": float(y1), diff --git a/backend/nodes/facet_level_field.py b/backend/nodes/level_facet.py similarity index 100% rename from backend/nodes/facet_level_field.py rename to backend/nodes/level_facet.py diff --git a/backend/nodes/plane_level_field.py b/backend/nodes/level_plane.py similarity index 100% rename from backend/nodes/plane_level_field.py rename to backend/nodes/level_plane.py diff --git a/backend/nodes/poly_level_field.py b/backend/nodes/level_poly.py similarity index 100% rename from backend/nodes/poly_level_field.py rename to backend/nodes/level_poly.py diff --git a/backend/nodes/draw_mask.py b/backend/nodes/mask_draw.py similarity index 100% rename from backend/nodes/draw_mask.py rename to backend/nodes/mask_draw.py diff --git a/backend/nodes/threshold_mask.py b/backend/nodes/mask_threshold.py similarity index 53% rename from backend/nodes/threshold_mask.py rename to backend/nodes/mask_threshold.py index 8b407fa..47f94a4 100644 --- a/backend/nodes/threshold_mask.py +++ b/backend/nodes/mask_threshold.py @@ -1,8 +1,8 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node -from backend.execution_context import emit_preview -from backend.data_types import DataField, encode_preview +from backend.execution_context import emit_preview, emit_overlay +from backend.data_types import DataField, encode_preview, RecordTable from backend.nodes.helpers import _mask_overlay @@ -15,14 +15,15 @@ class ThresholdMask: return { "required": { "field": ("DATA_FIELD",), - "method": (["otsu", "absolute", "relative"],), - "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001}), + "method": (["absolute", "relative", "otsu"],), + "threshold": ("FLOAT", {"default": 0.0, "min": -1e9, "max": 1e9, "step": 0.001, "socket_only": True}), "direction": (["above", "below"],), } } OUTPUTS = ( ('IMAGE', 'mask'), + ('RECORD_TABLE', 'threshold'), ) FUNCTION = "process" @@ -38,6 +39,12 @@ class ThresholdMask: def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple: data = field.data + raw_counts, bin_edges = np.histogram(data.ravel(), bins=256) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + counts = raw_counts.astype(np.float64) + xmin = float(bin_centers[0]) if len(bin_centers) else 0.0 + xmax = float(bin_centers[-1]) if len(bin_centers) else 1.0 + if method == "otsu": from skimage.filters import threshold_otsu t = threshold_otsu(data) @@ -49,12 +56,31 @@ class ThresholdMask: else: raise ValueError(f"Unknown threshold method: {method}") + span = xmax - xmin if xmax != xmin else 1.0 + threshold_frac = float(np.clip((t - xmin) / span, 0.0, 1.0)) + + emit_overlay({ + "kind": "threshold_histogram", + "section_title": "Histogram", + "line": counts.tolist(), + "x_axis": bin_centers.tolist(), + "x_unit": field.si_unit_z, + "threshold_frac": threshold_frac, + "x_min": xmin, + "x_max": xmax, + "method": method, + "locked": method == "otsu", + }) + if direction == "above": mask = (data >= t).astype(np.uint8) * 255 else: mask = (data < t).astype(np.uint8) * 255 - overlay = _mask_overlay(field, mask) - emit_preview(encode_preview(overlay)) + emit_preview(encode_preview(_mask_overlay(field, mask))) - return (mask,) + table = RecordTable([ + {"quantity": "threshold", "value": threshold, "unit": field.si_unit_xy}, + ]) + + return (mask, table) diff --git a/backend/nodes/ibw_note.py b/backend/nodes/note.py similarity index 100% rename from backend/nodes/ibw_note.py rename to backend/nodes/note.py diff --git a/backend/nodes/rotate_field.py b/backend/nodes/rotate.py similarity index 100% rename from backend/nodes/rotate_field.py rename to backend/nodes/rotate.py diff --git a/backend/nodes/save_image.py b/backend/nodes/save_layers.py similarity index 100% rename from backend/nodes/save_image.py rename to backend/nodes/save_layers.py diff --git a/backend/nodes/statistics_node.py b/backend/nodes/statistics.py similarity index 100% rename from backend/nodes/statistics_node.py rename to backend/nodes/statistics.py diff --git a/demo b/demo index 124b84c..0e24a1e 160000 --- a/demo +++ b/demo @@ -1 +1 @@ -Subproject commit 124b84ca7c79895dd9937e1d5ec553b29a9d5552 +Subproject commit 0e24a1eb540283bea7a087bec41b4de411e4d657 diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 40e961f..3ac0880 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -8,6 +8,7 @@ const CropBoxOverlay = lazy(() => import('./CropBoxOverlay')); const MaskPaintOverlay = lazy(() => import('./MaskPaintOverlay')); const MarkupOverlay = lazy(() => import('./MarkupOverlay')); const AngleMeasureOverlay = lazy(() => import('./AngleMeasureOverlay')); +const ThresholdHistogram = lazy(() => import('./ThresholdHistogram')); import { getSpecTypeAndOptions, isDataSocketSpec, SOCKET_WIDGET_TYPES, TYPE_COLORS, CAT_COLORS, @@ -971,6 +972,9 @@ function CustomNode({ id, data }) { visibleInputNames.add(name); } else if (opts?.hidden) { hiddenWidgets.add(name); + } else if (opts?.socket_only) { + dataInputs.push({ name, type, label: formatUiLabel(opts?.label || name) }); + visibleInputNames.add(name); } else { widgets.push({ name, type, opts: opts || {}, socketType: SOCKET_WIDGET_TYPES.has(type) ? type : null }); } @@ -1079,6 +1083,7 @@ function CustomNode({ id, data }) { hiddenWidgets.has('x1') || data.overlay.kind === 'mask_paint' || data.overlay.kind === 'markup' + || data.overlay.kind === 'threshold_histogram' ); const hidePreviewForInteractiveMask = data.overlay?.kind === 'mask_paint' || data.overlay?.kind === 'markup'; const overlayTitle = data.overlay?.section_title @@ -1286,6 +1291,21 @@ function CustomNode({ id, data }) { )} + {/* Threshold histogram — rendered before preview so it sits above the mask image */} + {data.overlay?.kind === 'threshold_histogram' && ( + + Loading...}> + + + + )} + {/* Collapsible preview image */} {data.previewImage && !hidePreviewForInteractiveMask @@ -1313,7 +1333,7 @@ function CustomNode({ id, data }) { )} {/* Interactive cross-section overlay */} - {hasInteractiveOverlay && ( + {hasInteractiveOverlay && data.overlay?.kind !== 'threshold_histogram' && ( Loading...}> {data.overlay.kind === 'line_plot' ? ( @@ -1371,6 +1391,13 @@ function CustomNode({ id, data }) { nodeId={id} onWidgetChange={ctx.onWidgetChange} /> + ) : data.overlay.kind === 'threshold_histogram' ? ( + ) : data.overlay.kind === 'angle_measure' ? ( - {formatTick(tick)} + {formatTick(tick / xScale)} ); })} + {xUnitLabel && ( + {xUnitLabel} + )} {yTicks.map((tick) => { const y = scaleY(tick); diff --git a/frontend/src/SurfaceView.jsx b/frontend/src/SurfaceView.jsx index 7527fd6..537eaf2 100644 --- a/frontend/src/SurfaceView.jsx +++ b/frontend/src/SurfaceView.jsx @@ -462,9 +462,59 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal } }, [applyCameraState, decode, meshData, scheduleViewportSync, updateDiagnostics]); - // Prevent scroll events from propagating to React Flow - const onWheel = useCallback((e) => { - e.stopPropagation(); + // Gesture-aware wheel handling: only capture scroll when it started inside the view. + // Uses capture phase to disable OrbitControls zoom before it fires when gesture started outside. + useEffect(() => { + const container = containerRef.current; + if (!container) return; + + const onEnter = () => { + isInsideRef.current = true; + pointerEnteredAtRef.current = Date.now(); + }; + const onLeave = () => { + isInsideRef.current = false; + }; + + // Capture phase: fires before OrbitControls on renderer.domElement + const onWheelCapture = () => { + const now = Date.now(); + const msSinceLastWheel = now - lastWheelAtRef.current; + const msSinceEnter = now - pointerEnteredAtRef.current; + lastWheelAtRef.current = now; + + if (msSinceLastWheel > 300) { + gestureStartedInsideRef.current = isInsideRef.current && msSinceEnter > 100; + } + + // Gesture started outside — disable OrbitControls zoom so it doesn't intercept + if (!gestureStartedInsideRef.current && threeRef.current) { + threeRef.current.controls.enableZoom = false; + } + }; + + // Bubble phase: fires after OrbitControls has already run (or skipped due to enableZoom=false) + const onWheelBubble = (e) => { + if (threeRef.current) { + threeRef.current.controls.enableZoom = true; + } + if (gestureStartedInsideRef.current) { + e.stopPropagation(); // prevent React Flow from panning when interacting with the 3D view + } + // else: let event propagate to React Flow so canvas panning continues + }; + + container.addEventListener('wheel', onWheelCapture, { capture: true, passive: true }); + container.addEventListener('wheel', onWheelBubble, { passive: false }); + container.addEventListener('pointerenter', onEnter); + container.addEventListener('pointerleave', onLeave); + + return () => { + container.removeEventListener('wheel', onWheelCapture, { capture: true }); + container.removeEventListener('wheel', onWheelBubble); + container.removeEventListener('pointerenter', onEnter); + container.removeEventListener('pointerleave', onLeave); + }; }, []); const onContextMenu = useCallback((e) => { @@ -476,8 +526,7 @@ export default function SurfaceView({ meshData, nodeId, widgetValues, runtimeVal
{showDiagnostics ? ( diff --git a/frontend/src/ThresholdHistogram.jsx b/frontend/src/ThresholdHistogram.jsx new file mode 100644 index 0000000..29e254b --- /dev/null +++ b/frontend/src/ThresholdHistogram.jsx @@ -0,0 +1,220 @@ +import React, { useEffect, useRef, useState, useCallback } from 'react'; +import { getAxisScale } from './valueFormatting'; + +const ASPECT_RATIO = 3.2 / 2.2; +const MARGINS = { top: 18, right: 16, bottom: 34, left: 56 }; + +function clamp(v, min, max) { return Math.max(min, Math.min(max, v)); } +function round4(v) { return parseFloat(v.toFixed(4)); } +function trimZeros(t) { return t.replace(/(?:\.0+|(\.\d+?)0+)$/, '$1'); } + +function formatTick(value) { + const abs = Math.abs(value); + if (abs === 0) return '0'; + if (abs >= 1e4 || abs < 1e-3) return value.toExponential(1).replace('e+', 'e'); + if (abs >= 100) return trimZeros(value.toFixed(0)); + if (abs >= 10) return trimZeros(value.toFixed(1)); + if (abs >= 1) return trimZeros(value.toFixed(2)); + return trimZeros(value.toFixed(3)); +} + +function makeTicks(min, max, count = 5) { + if (!Number.isFinite(min) || !Number.isFinite(max) || min === max) return [min]; + return Array.from({ length: count }, (_, i) => min + (max - min) * i / (count - 1)); +} + +function getExtent(values, fallbackMin = 0, fallbackMax = 1) { + if (!Array.isArray(values) || !values.length) return [fallbackMin, fallbackMax]; + let min = Infinity, max = -Infinity; + for (const v of values) { if (Number.isFinite(v)) { if (v < min) min = v; if (v > max) max = v; } } + return (Number.isFinite(min) && Number.isFinite(max)) ? [min, max] : [fallbackMin, fallbackMax]; +} + +export default function ThresholdHistogram({ overlay, threshold, thresholdConnected, nodeId, onWidgetChange }) { + const containerRef = useRef(null); + const [dragging, setDragging] = useState(false); + const [size, setSize] = useState({ width: 0 }); + + useEffect(() => { + if (!containerRef.current) return undefined; + const update = () => { + if (!containerRef.current) return; + setSize({ width: Math.max(1, Math.round(containerRef.current.clientWidth || 320)) }); + }; + update(); + if (typeof ResizeObserver === 'function') { + const ro = new ResizeObserver((entries) => { + const e = entries[0]; + if (e) setSize({ width: Math.max(1, Math.round(e.contentRect.width)) }); + }); + ro.observe(containerRef.current); + return () => ro.disconnect(); + } + window.addEventListener('resize', update); + return () => window.removeEventListener('resize', update); + }, []); + + const xValues = Array.isArray(overlay?.x_axis) && overlay.x_axis.length === overlay.line?.length + ? overlay.x_axis : overlay?.line?.map((_, i) => i) || []; + const yValues = Array.isArray(overlay?.line) ? overlay.line : []; + const method = overlay?.method ?? 'absolute'; + const locked = (overlay?.locked ?? false) || !!thresholdConnected; + const xMin = overlay?.x_min ?? 0; + const xMax = overlay?.x_max ?? 1; + + const width = size.width || 320; + const height = Math.round(width / ASPECT_RATIO); + const plotLeft = MARGINS.left; + const plotTop = MARGINS.top; + const plotWidth = Math.max(1, width - MARGINS.left - MARGINS.right); + const plotHeight = Math.max(1, height - MARGINS.top - MARGINS.bottom); + + const [xExtMin, xExtMax] = getExtent(xValues, 0, 1); + const [yMinRaw, yMaxRaw] = getExtent(yValues, 0, 1); + const yPad = yMinRaw === yMaxRaw ? 1 : (yMaxRaw - yMinRaw) * 0.08; + const yMin = yMinRaw - yPad; + const yMax = yMaxRaw + yPad; + + const scaleX = useCallback((v) => { + if (xExtMax === xExtMin) return plotLeft + plotWidth / 2; + return plotLeft + (v - xExtMin) / (xExtMax - xExtMin) * plotWidth; + }, [plotLeft, plotWidth, xExtMin, xExtMax]); + + const scaleY = useCallback((v) => { + if (yMax === yMin) return plotTop + plotHeight / 2; + return plotTop + (1 - (v - yMin) / (yMax - yMin)) * plotHeight; + }, [plotTop, plotHeight, yMin, yMax]); + + // Compute marker x-fraction from current threshold widget value + const markerFrac = (() => { + if (locked) return clamp(overlay?.threshold_frac ?? 0.5, 0, 1); + const t = threshold ?? 0; + if (method === 'relative') return clamp(t, 0, 1); + return (xMax === xMin) ? 0.5 : clamp((t - xMin) / (xMax - xMin), 0, 1); + })(); + + const markerX = plotLeft + markerFrac * plotWidth; + + // Snap marker circle to histogram line height + const markerY = (() => { + if (!xValues.length || !yValues.length) return plotTop + plotHeight / 2; + const targetX = xExtMin + markerFrac * (xExtMax - xExtMin); + let best = 0, bestDist = Infinity; + for (let i = 0; i < xValues.length; i++) { + const d = Math.abs(xValues[i] - targetX); + if (d < bestDist) { bestDist = d; best = i; } + } + return scaleY(yValues[best]); + })(); + + const handleDrag = useCallback((e) => { + if (!onWidgetChange || !nodeId || locked || !containerRef.current) return; + const rect = containerRef.current.getBoundingClientRect(); + const frac = clamp((e.clientX - rect.left - plotLeft) / plotWidth, 0, 1); + // Relative threshold is a 0-1 fraction — round to 4 dp is fine. + // Absolute threshold is in SI units (could be nm/m scale) — keep full float precision. + const newThreshold = method === 'relative' + ? round4(frac) + : xMin + frac * (xMax - xMin); + onWidgetChange(nodeId, 'threshold', newThreshold); + }, [onWidgetChange, nodeId, locked, plotLeft, plotWidth, method, xMin, xMax]); + + const onPointerDown = useCallback((e) => { + if (locked) return; + e.preventDefault(); + e.stopPropagation(); + e.currentTarget.setPointerCapture(e.pointerId); + setDragging(true); + }, [locked]); + + const onPointerMove = useCallback((e) => { + if (dragging) handleDrag(e); + }, [dragging, handleDrag]); + + const onPointerUp = useCallback(() => setDragging(false), []); + + const path = yValues.map((y, i) => `${i === 0 ? 'M' : 'L'} ${scaleX(xValues[i])} ${scaleY(y)}`).join(' '); + const xTickCount = Math.max(2, Math.min(5, Math.floor(plotWidth / 70))); + const yTickCount = Math.max(2, Math.min(5, Math.floor(plotHeight / 40))); + const xTicks = makeTicks(xExtMin, xExtMax, xTickCount); + const yTicks = makeTicks(yMin, yMax, yTickCount); + const xRepresentative = Math.max(Math.abs(xExtMin), Math.abs(xExtMax)); + const { scale: xScale, unitLabel: xUnitLabel } = getAxisScale(xRepresentative, overlay?.x_unit); + const plotStroke = clamp(plotWidth / 240, 1.4, 2.6); + const gridStroke = clamp(plotWidth / 900, 0.6, 1.1); + const cursorStroke = clamp(plotWidth / 220, 1.4, 2.2); + const markerRadius = clamp(plotWidth / 42, 5.5, 9); + const markerLabelSize = clamp(plotWidth / 34, 8, 11); + + return ( +
+ {locked && ( +
+ {thresholdConnected ? 'Locked — driven by socket' : 'Locked — Otsu auto-threshold'} +
+ )} + + + + {xTicks.map((tick) => { + const x = scaleX(tick); + return ( + + + {formatTick(tick / xScale)} + + ); + })} + {xUnitLabel && ( + {xUnitLabel} + )} + + {yTicks.map((tick) => { + const y = scaleY(tick); + return ( + + + {formatTick(tick)} + + ); + })} + + + + + {/* Threshold marker line */} + + + {/* Threshold marker circle */} + + + + T + + + +
+ ); +} diff --git a/frontend/src/valueFormatting.js b/frontend/src/valueFormatting.js index e1356e8..2c84aec 100644 --- a/frontend/src/valueFormatting.js +++ b/frontend/src/valueFormatting.js @@ -104,6 +104,21 @@ function choosePrefixExponent(value, power) { return candidates.reduce((best, candidate) => (candidate.absScaled > best.absScaled ? candidate : best)); } +/** + * Given a representative axis value and a unit string, returns the scale factor + * and prefixed unit label to use for a whole axis. + * All tick values should be divided by `scale` before display, and `unitLabel` shown once. + */ +export function getAxisScale(representativeValue, unit) { + if (!unit || typeof representativeValue !== 'number' || !Number.isFinite(representativeValue) || representativeValue === 0) { + return { scale: 1, unitLabel: unit || '' }; + } + const { valueText, unitText } = applySIPrefix(representativeValue, unit); + const scaled = parseFloat(valueText); + if (!Number.isFinite(scaled) || scaled === 0) return { scale: 1, unitLabel: unit }; + return { scale: representativeValue / scaled, unitLabel: unitText }; +} + export function applySIPrefix(value, unit) { const formattedUnit = formatDisplayUnit(unit); if (typeof value !== 'number' || !Number.isFinite(value)) { diff --git a/tests/test_fft.py b/tests/test_fft.py index c69d593..406a8e6 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -10,7 +10,7 @@ import numpy as np sys.path.insert(0, ".") from backend.data_types import DataField from backend.nodes.fft_2d import FFT2D -from backend.nodes.inverse_fft_2d import InverseFFT2D +from backend.nodes.fft_2d_invert import InverseFFT2D def make_field(data, xreal=1e-6, yreal=1e-6): diff --git a/tests/test_grains.py b/tests/test_grains.py index 8aca46f..39c3c75 100644 --- a/tests/test_grains.py +++ b/tests/test_grains.py @@ -28,7 +28,7 @@ def make_field(data, xreal=1e-6, yreal=1e-6): def test_threshold_otsu_bimodal(): """Otsu on a clean bimodal image should separate the two populations.""" print("=== Test: Otsu on bimodal image ===") - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask node = ThresholdMask() data = np.zeros((128, 128)) @@ -50,7 +50,7 @@ def test_threshold_otsu_bimodal(): def test_threshold_relative_range(): """Relative threshold at 0.5 should be the midpoint of [min, max].""" print("=== Test: Relative threshold at midpoint ===") - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask node = ThresholdMask() data = np.full((64, 64), 2.0) @@ -68,7 +68,7 @@ def test_threshold_relative_range(): def test_threshold_empty_mask(): """Very high absolute threshold on low data should produce an empty mask.""" print("=== Test: Empty mask from high threshold ===") - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask node = ThresholdMask() data = np.ones((64, 64)) @@ -82,7 +82,7 @@ def test_threshold_empty_mask(): def test_threshold_full_mask(): """Very low absolute threshold should produce an all-white mask.""" print("=== Test: Full mask from low threshold ===") - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask node = ThresholdMask() data = np.ones((64, 64)) * 5.0 @@ -316,7 +316,7 @@ def test_adjacent_grains_connectivity(): def test_pipeline_synthetic(): """Full pipeline on a synthetic image with known geometry.""" print("=== Test: Full pipeline on synthetic grains ===") - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask from backend.nodes.grain_analysis import GrainAnalysis N = 200 @@ -372,7 +372,7 @@ def test_pipeline_demo_image(): """Run the full pipeline on the bundled demo nanoparticles image.""" print("=== Test: Full pipeline on demo nanoparticles.npy ===") from pathlib import Path - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask from backend.nodes.grain_analysis import GrainAnalysis from backend.runtime_paths import demo_dir diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 945258a..e48e4a0 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -28,7 +28,7 @@ def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): def test_gaussian_filter(): print("=== Test: GaussianFilter ===") - from backend.nodes.gaussian_filter import GaussianFilter + from backend.nodes.filter_gaussian import GaussianFilter node = GaussianFilter() field = make_field() @@ -46,7 +46,7 @@ def test_gaussian_filter(): def test_median_filter(): print("=== Test: MedianFilter ===") - from backend.nodes.median_filter import MedianFilter + from backend.nodes.filter_median import MedianFilter node = MedianFilter() # Median filter should remove salt-and-pepper noise @@ -68,7 +68,7 @@ def test_median_filter(): def test_crop_resize_field(): print("=== Test: CropResizeField ===") - from backend.nodes.crop_resize_field import CropResizeField + from backend.nodes.crop_resize import CropResizeField node = CropResizeField() data = np.arange(32, dtype=np.float64).reshape(4, 8) @@ -167,7 +167,7 @@ def test_crop_resize_field(): def test_rotate_field(): print("=== Test: RotateField ===") - from backend.nodes.rotate_field import RotateField + from backend.nodes.rotate import RotateField node = RotateField() data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64) @@ -230,7 +230,7 @@ def test_rotate_field(): def test_rotate_field_overlay_warning(): print("=== Test: RotateField overlay warning ===") - from backend.nodes.rotate_field import RotateField + from backend.nodes.rotate import RotateField node = RotateField() warnings = [] @@ -258,7 +258,7 @@ def test_rotate_field_overlay_warning(): def test_flip_field(): print("=== Test: FlipField ===") - from backend.nodes.flip_field import FlipField + from backend.nodes.flip import FlipField from backend.node_registry import get_node_info node = FlipField() @@ -420,7 +420,7 @@ def test_edge_detect(): def test_fft_filter_1d(): print("=== Test: FFTFilter1D ===") - from backend.nodes.fft_filter_1d import FFTFilter1D + from backend.nodes.filter_fft_1d import FFTFilter1D node = FFTFilter1D() # Signal: low-frequency sine + high-frequency sine @@ -464,7 +464,7 @@ def test_fft_filter_1d(): def test_fft_filter_2d(): print("=== Test: FFTFilter2D ===") - from backend.nodes.fft_filter_2d import FFTFilter2D + from backend.nodes.filter_fft_2d import FFTFilter2D node = FFTFilter2D() N = 128 @@ -506,7 +506,7 @@ def test_fft_filter_2d(): def test_plane_level(): print("=== Test: PlaneLevelField ===") - from backend.nodes.plane_level_field import PlaneLevelField + from backend.nodes.level_plane import PlaneLevelField node = PlaneLevelField() # Create a tilted plane + small signal @@ -554,8 +554,8 @@ def test_plane_level(): def test_facet_level(): print("=== Test: FacetLevelField ===") from backend.node_registry import get_node_info - from backend.nodes.facet_level_field import FacetLevelField - from backend.nodes.plane_level_field import PlaneLevelField + from backend.nodes.level_facet import FacetLevelField + from backend.nodes.level_plane import PlaneLevelField def fit_pixel_plane(data: np.ndarray, region: np.ndarray) -> tuple[float, float, float]: yy, xx = np.mgrid[0:data.shape[0], 0:data.shape[1]] @@ -628,7 +628,7 @@ def test_facet_level(): def test_poly_level(): print("=== Test: PolyLevelField ===") - from backend.nodes.poly_level_field import PolyLevelField + from backend.nodes.level_poly import PolyLevelField node = PolyLevelField() N = 64 @@ -966,7 +966,7 @@ def test_angle_measure(): def test_statistics(): print("=== Test: Statistics ===") - from backend.nodes.statistics_node import Statistics + from backend.nodes.statistics import Statistics node = Statistics() data = np.array([[1, 2], [3, 4]], dtype=np.float64) @@ -1194,7 +1194,7 @@ def test_cross_section(): def test_threshold_mask(): print("=== Test: ThresholdMask ===") - from backend.nodes.threshold_mask import ThresholdMask + from backend.nodes.mask_threshold import ThresholdMask node = ThresholdMask() # Clear bimodal data: left half = 0, right half = 1 @@ -1346,7 +1346,7 @@ def test_mask_operations(): def test_draw_mask(): print("=== Test: DrawMask ===") - from backend.nodes.draw_mask import DrawMask + from backend.nodes.mask_draw import DrawMask node = DrawMask() field = make_field(data=np.zeros((32, 32), dtype=np.float64)) @@ -1582,7 +1582,7 @@ def test_load_file(): def test_save_image(): print("=== Test: SaveImage (Save Layers) ===") - from backend.nodes.save_image import SaveImage + from backend.nodes.save_layers import SaveImage import tifffile node = SaveImage() input_types = SaveImage.INPUT_TYPES() @@ -1686,7 +1686,7 @@ def test_save_image(): def test_color_map_node(): print("=== Test: ColorMap ===") - from backend.nodes.color_map import ColorMap + from backend.nodes.colormap import ColorMap node = ColorMap() @@ -1712,7 +1712,7 @@ def test_color_map_node(): def test_font_node(): print("=== Test: Font ===") - from backend.nodes.font_node import Font + from backend.nodes.font import Font from backend.data_types import CUSTOM_FILE_FONT, SYSTEM_DEFAULT_FONT node = Font() @@ -1796,7 +1796,7 @@ def test_preview_image(): def test_annotations(): print("=== Test: Annotations ===") from backend.nodes.annotations import Annotations - from backend.nodes.font_node import Font + from backend.nodes.font import Font from backend.data_types import ImageData from backend.execution_context import active_node, execution_callbacks