diff --git a/backend/execution.py b/backend/execution.py index 918f4d9..68a9e7e 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -219,10 +219,10 @@ class ExecutionEngine: ) -> None: """Wire up broadcast callbacks on display node classes.""" from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup - from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram + from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, Histogram from backend.nodes.modify import CropResizeField, RotateField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import SaveImage, LoadFile, LoadDemo + from backend.nodes.io import SaveImage, Image, ImageDemo PreviewImage._broadcast_fn = on_preview ThresholdMask._broadcast_fn = on_preview @@ -235,26 +235,26 @@ class ExecutionEngine: ValueDisplay._broadcast_value_fn = on_value TableMath._broadcast_value_fn = on_value Stats._broadcast_value_fn = on_value - HeightHistogram._broadcast_overlay_fn = on_overlay + Histogram._broadcast_overlay_fn = on_overlay CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay CropResizeField._broadcast_overlay_fn = on_overlay RotateField._broadcast_warning_fn = on_warning Markup._broadcast_overlay_fn = on_overlay - LoadFile._broadcast_warning_fn = on_warning - LoadDemo._broadcast_warning_fn = on_warning + Image._broadcast_warning_fn = on_warning + ImageDemo._broadcast_warning_fn = on_warning SaveImage._broadcast_warning_fn = on_warning def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup - from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram + from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, Histogram from backend.nodes.modify import CropResizeField, RotateField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import LoadFile, LoadDemo, SaveImage - if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup, + from backend.nodes.io import Image, ImageDemo, SaveImage + if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, Histogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, - LoadFile, LoadDemo, SaveImage): + Image, ImageDemo, SaveImage): cls._current_node_id = node_id def _auto_preview( @@ -275,12 +275,12 @@ class ExecutionEngine: from backend.data_types import ( DataField, image_to_uint8, encode_preview, render_datafield_preview, ) - from backend.nodes.io import LoadFile, LoadDemo + from backend.nodes.io import Image, ImageDemo if getattr(cls, "_CUSTOM_PREVIEW", False): return - if cls in (LoadFile, LoadDemo) and on_preview: + if cls in (Image, ImageDemo) and on_preview: preview = self._render_load_node_preview(result, inputs or {}) if preview: on_preview(node_id, preview) diff --git a/backend/node_menu.py b/backend/node_menu.py new file mode 100644 index 0000000..71da1f2 --- /dev/null +++ b/backend/node_menu.py @@ -0,0 +1,98 @@ +""" +Central Add Node menu manifest. + +Edit MENU_LAYOUT to rearrange which nodes appear under each menu leaf and +their order within that leaf. Node classes not listed here fall back to their +class CATEGORY. +""" + +from __future__ import annotations + +from typing import Any + + +MENU_LAYOUT: dict[str, list[str]] = { + "Add": [ + "Image", + "ImageDemo", + "Folder", + "ColorMap", + "Number", + "RangeSlider", + "Coordinate", + "Font", + ], + "Output": [ + "PreviewImage", + "SaveImage", + "View3D", + "PrintTable", + "ValueDisplay", + ], + "Overlay": [ + "Markup", + "Annotations", + ], + "Modify": [ + "ColormapAdjust", + "CropResizeField", + "RotateField", + ], + "Filter": [ + "GaussianFilter", + "MedianFilter", + "EdgeDetect", + "FFTFilter1D", + "FFTFilter2D", + ], + "Frequency": [ + "FFT2D", + "InverseFFT2D", + ], + "Flatten": [ + "PlaneLevelField", + "PolyLevelField", + "FixZero", + ], + "Measure": [ + "Statistics", + "Histogram", + "LineCursors", + "CrossSection", + "Stats", + ], + "Mask": [ + "DrawMask", + "ThresholdMask", + "MaskMorphology", + "MaskInvert", + "MaskCombine", + ], + "Particles": [ + "ParticleAnalysis", + ], +} + + +_CATEGORY_ORDER = {category: index for index, category in enumerate(MENU_LAYOUT)} +_NODE_METADATA: dict[str, dict[str, Any]] = {} +for category, class_names in MENU_LAYOUT.items(): + for node_order, class_name in enumerate(class_names): + _NODE_METADATA[class_name] = { + "category": category, + "category_order": _CATEGORY_ORDER[category], + "menu_order": node_order, + } + + +def get_menu_metadata(class_name: str, fallback_category: str = "uncategorized") -> dict[str, Any]: + metadata = _NODE_METADATA.get(class_name) + if metadata is not None: + return dict(metadata) + + fallback_order = _CATEGORY_ORDER.get(fallback_category, len(_CATEGORY_ORDER)) + return { + "category": fallback_category, + "category_order": fallback_order, + "menu_order": 10_000, + } diff --git a/backend/node_registry.py b/backend/node_registry.py index 594d2e5..8b2a861 100644 --- a/backend/node_registry.py +++ b/backend/node_registry.py @@ -9,6 +9,8 @@ the execution engine and the /nodes REST endpoint. from __future__ import annotations from typing import Any +from backend.node_menu import get_menu_metadata + NODE_CLASS_MAPPINGS: dict[str, type] = {} NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {} @@ -37,11 +39,14 @@ def get_node_info(class_name: str) -> dict[str, Any]: """ cls = NODE_CLASS_MAPPINGS[class_name] input_types: dict = cls.INPUT_TYPES() + menu_metadata = get_menu_metadata(class_name, getattr(cls, "CATEGORY", "uncategorized")) return { "name": class_name, "display_name": NODE_DISPLAY_NAME_MAPPINGS.get(class_name, class_name), - "category": getattr(cls, "CATEGORY", "uncategorized"), + "category": menu_metadata["category"], + "category_order": menu_metadata["category_order"], + "menu_order": menu_metadata["menu_order"], "input": input_types, "input_order": {k: list(v.keys()) for k, v in input_types.items()}, "output": list(cls.RETURN_TYPES), diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 6229226..1e0240f 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -2,8 +2,8 @@ Analysis nodes — statistics, histograms, FFT, cross sections. Gwyddion equivalents: - StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h) - HeightHistogram → DH (height distribution), gwy_data_field_dh + Statistics → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h) + Histogram → DH (height distribution), gwy_data_field_dh FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf CrossSection → gwy_data_field_get_profile (libprocess/datafield.c) """ @@ -16,11 +16,11 @@ from backend.data_types import DataField, MeasureTable, RecordTable, datafield_t # --------------------------------------------------------------------------- -# StatisticsNode +# Statistics # --------------------------------------------------------------------------- @register_node(display_name="Statistics") -class StatisticsNode: +class Statistics: @classmethod def INPUT_TYPES(cls): return { @@ -59,11 +59,11 @@ class StatisticsNode: # --------------------------------------------------------------------------- -# HeightHistogram +# Histogram # --------------------------------------------------------------------------- @register_node(display_name="Height Histogram") -class HeightHistogram: +class Histogram: @classmethod def INPUT_TYPES(cls): return { @@ -130,9 +130,9 @@ class HeightHistogram: yb = float(counts[idx_b]) if len(counts) else 0.0 count_unit = "count" if y_scale == "linear" else "log10(1+count)" - if HeightHistogram._broadcast_overlay_fn is not None: - HeightHistogram._broadcast_overlay_fn( - HeightHistogram._current_node_id, + if Histogram._broadcast_overlay_fn is not None: + Histogram._broadcast_overlay_fn( + Histogram._current_node_id, { "kind": "line_plot", "section_title": "Histogram", @@ -754,36 +754,6 @@ def _op_da(z): return float(np.mean(np.abs(np.diff(z)))) -@register_node(display_name="Line Math") -class LineMath: - """Compute a single scalar value from a LINE profile.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "line": ("LINE",), - "operation": (list(LINE_OPS.keys()),), - } - } - - RETURN_TYPES = ("MEASURE_TABLE",) - RETURN_NAMES = ("result",) - FUNCTION = "process" - CATEGORY = "analysis" - DESCRIPTION = ( - "Compute a single scalar measurement from a LINE profile. " - "Includes basic stats and Gwyddion-convention roughness parameters." - ) - - def process(self, line, operation: str) -> tuple: - z = np.asarray(line, dtype=np.float64).ravel() - fn, unit = LINE_OPS[operation] - value = fn(z) - table = MeasureTable([{"quantity": operation, "value": value, "unit": unit}]) - return (table,) - - # --------------------------------------------------------------------------- # TableMath — scalar measurement from a numeric record-table column # --------------------------------------------------------------------------- @@ -869,56 +839,6 @@ def _scalar_payload(value: float, unit: str = "") -> dict: return payload -@register_node(display_name="Table Math") -class TableMath: - """Compute a scalar reduction over one numeric column in a record table.""" - - _broadcast_value_fn = None - _current_node_id: str = "" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "table": ("RECORD_TABLE",), - "column": ("STRING", { - "default": "value", - "choices_from_table_input": "table", - }), - "operation": (list(TABLE_OPS.keys()),), - } - } - - RETURN_TYPES = ("FLOAT",) - RETURN_NAMES = ("value",) - FUNCTION = "process" - CATEGORY = "analysis" - DESCRIPTION = ( - "Compute a scalar reduction over one numeric record-table column. " - "Useful for max, min, avg, median, sum, range, std, variance, and count." - ) - - def process(self, table: list, column: str, operation: str) -> tuple: - if isinstance(table, MeasureTable): - raise ValueError("Table Math only accepts record tables, not measurement tables.") - if not isinstance(table, list) or not table: - raise ValueError("Table Math requires a non-empty record table input.") - - column_name = resolve_table_column_name(table, column) - values = extract_numeric_table_values(table, column_name) - if not values: - raise ValueError(f"Column '{column_name}' has no numeric values.") - - op = TABLE_OPS.get(operation) - if op is None: - raise ValueError(f"Unsupported table operation: {operation}") - - result = op(np.asarray(values, dtype=np.float64)) - if TableMath._broadcast_value_fn is not None: - TableMath._broadcast_value_fn(TableMath._current_node_id, result) - return (result,) - - def extract_numeric_table_values(table: list, column: str) -> list[float]: values = [] for row in table: diff --git a/backend/nodes/io.py b/backend/nodes/io.py index ca9b7d4..26ad6f4 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -125,11 +125,11 @@ def list_folder_paths(folderpath: str) -> list[dict]: # --------------------------------------------------------------------------- -# LoadFile (unified loader — replaces LoadImage + LoadSPM) +# Image (unified loader — replaces LoadImage + LoadSPM) # --------------------------------------------------------------------------- -@register_node(display_name="Load File") -class LoadFile: +@register_node(display_name="Image") +class Image: @classmethod def INPUT_TYPES(cls): return { @@ -185,8 +185,8 @@ class LoadFile: return (field,) def _send_warning(self, message: str): - fn = LoadFile._broadcast_warning_fn - nid = LoadFile._current_node_id + fn = Image._broadcast_warning_fn + nid = Image._current_node_id if fn and nid: fn(nid, message) @@ -353,7 +353,7 @@ class LoadFile: # --------------------------------------------------------------------------- -# LoadDemo +# ImageDemo # --------------------------------------------------------------------------- def _list_demo_files() -> list[str]: @@ -366,8 +366,8 @@ def _list_demo_files() -> list[str]: ) -@register_node(display_name="Load Demo File") -class LoadDemo: +@register_node(display_name="Image (Demo)") +class ImageDemo: @classmethod def INPUT_TYPES(cls): choices = _list_demo_files() or ["(no demo files found)"] @@ -388,7 +388,7 @@ class LoadDemo: DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." def load(self, name: str = "", colormap: str = "viridis", colormap_map=None): - loader = LoadFile() + loader = Image() demo_path = DEMO_DIR / name if not demo_path.exists(): raise FileNotFoundError(f"Demo file not found: {name}") diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index bdb0f27..f2cd427 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -80,6 +80,23 @@ function sameStringArray(a = [], b = []) { return a.every((item, index) => item === b[index]); } +function compareMenuNodes(a, b) { + const orderA = Number.isFinite(a?.def?.menu_order) ? a.def.menu_order : Number.MAX_SAFE_INTEGER; + const orderB = Number.isFinite(b?.def?.menu_order) ? b.def.menu_order : Number.MAX_SAFE_INTEGER; + if (orderA !== orderB) return orderA - orderB; + + const nameA = (a?.def?.display_name || a?.className || '').toLowerCase(); + const nameB = (b?.def?.display_name || b?.className || '').toLowerCase(); + return nameA.localeCompare(nameB); +} + +function compareMenuCategories(a, b) { + const orderA = Number.isFinite(a?.order) ? a.order : Number.MAX_SAFE_INTEGER; + const orderB = Number.isFinite(b?.order) ? b.order : Number.MAX_SAFE_INTEGER; + if (orderA !== orderB) return orderA - orderB; + return String(a?.name || '').localeCompare(String(b?.name || '')); +} + function socketTypesCompatible(sourceType, targetType) { if (sourceType === targetType) return true; const accepted = SOCKET_COMPATIBILITY[targetType]; @@ -272,10 +289,25 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti } } const cat = def.category || 'uncategorized'; - if (!cats[cat]) cats[cat] = []; - cats[cat].push({ className, def }); + if (!cats[cat]) { + cats[cat] = { + name: cat, + order: Number.isFinite(def.category_order) ? def.category_order : Number.MAX_SAFE_INTEGER, + items: [], + }; + } + cats[cat].order = Math.min( + cats[cat].order, + Number.isFinite(def.category_order) ? def.category_order : Number.MAX_SAFE_INTEGER, + ); + cats[cat].items.push({ className, def }); } - return cats; + return Object.values(cats) + .map((category) => ({ + ...category, + items: [...category.items].sort(compareMenuNodes), + })) + .sort(compareMenuCategories); }, [nodeDefs, filterType, filterDirection]); // Flat filtered list for search @@ -283,8 +315,8 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti if (!search.trim()) return null; const q = search.toLowerCase(); const results = []; - for (const items of Object.values(categories)) { - for (const { className, def } of items) { + for (const category of categories) { + for (const { className, def } of category.items) { const name = (def.display_name || className).toLowerCase(); if (name.includes(q)) results.push({ className, def }); } @@ -341,7 +373,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti setOpenCat(cat); }, []); - if (Object.keys(categories).length === 0) { + if (categories.length === 0) { return (