diff --git a/backend/execution.py b/backend/execution.py index 918f4d9..68a9e7e 100644 --- a/backend/execution.py +++ b/backend/execution.py @@ -219,10 +219,10 @@ class ExecutionEngine: ) -> None: """Wire up broadcast callbacks on display node classes.""" from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup - from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram + from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, Histogram from backend.nodes.modify import CropResizeField, RotateField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import SaveImage, LoadFile, LoadDemo + from backend.nodes.io import SaveImage, Image, ImageDemo PreviewImage._broadcast_fn = on_preview ThresholdMask._broadcast_fn = on_preview @@ -235,26 +235,26 @@ class ExecutionEngine: ValueDisplay._broadcast_value_fn = on_value TableMath._broadcast_value_fn = on_value Stats._broadcast_value_fn = on_value - HeightHistogram._broadcast_overlay_fn = on_overlay + Histogram._broadcast_overlay_fn = on_overlay CrossSection._broadcast_overlay_fn = on_overlay LineCursors._broadcast_overlay_fn = on_overlay CropResizeField._broadcast_overlay_fn = on_overlay RotateField._broadcast_warning_fn = on_warning Markup._broadcast_overlay_fn = on_overlay - LoadFile._broadcast_warning_fn = on_warning - LoadDemo._broadcast_warning_fn = on_warning + Image._broadcast_warning_fn = on_warning + ImageDemo._broadcast_warning_fn = on_warning SaveImage._broadcast_warning_fn = on_warning def _set_node_id_on_display(self, cls: type, node_id: str) -> None: """Inform display nodes of their current node_id for WS tagging.""" from backend.nodes.display import PreviewImage, PrintTable, View3D, ValueDisplay, Markup - from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, HeightHistogram + from backend.nodes.analysis import CrossSection, LineCursors, TableMath, Stats, Histogram from backend.nodes.modify import CropResizeField, RotateField from backend.nodes.mask import ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask - from backend.nodes.io import LoadFile, LoadDemo, SaveImage - if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, HeightHistogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup, + from backend.nodes.io import Image, ImageDemo, SaveImage + if cls in (PreviewImage, PrintTable, View3D, ValueDisplay, TableMath, Stats, Histogram, CrossSection, LineCursors, CropResizeField, RotateField, Markup, ThresholdMask, MaskMorphology, MaskInvert, MaskCombine, DrawMask, - LoadFile, LoadDemo, SaveImage): + Image, ImageDemo, SaveImage): cls._current_node_id = node_id def _auto_preview( @@ -275,12 +275,12 @@ class ExecutionEngine: from backend.data_types import ( DataField, image_to_uint8, encode_preview, render_datafield_preview, ) - from backend.nodes.io import LoadFile, LoadDemo + from backend.nodes.io import Image, ImageDemo if getattr(cls, "_CUSTOM_PREVIEW", False): return - if cls in (LoadFile, LoadDemo) and on_preview: + if cls in (Image, ImageDemo) and on_preview: preview = self._render_load_node_preview(result, inputs or {}) if preview: on_preview(node_id, preview) diff --git a/backend/node_menu.py b/backend/node_menu.py new file mode 100644 index 0000000..71da1f2 --- /dev/null +++ b/backend/node_menu.py @@ -0,0 +1,98 @@ +""" +Central Add Node menu manifest. + +Edit MENU_LAYOUT to rearrange which nodes appear under each menu leaf and +their order within that leaf. Node classes not listed here fall back to their +class CATEGORY. +""" + +from __future__ import annotations + +from typing import Any + + +MENU_LAYOUT: dict[str, list[str]] = { + "Add": [ + "Image", + "ImageDemo", + "Folder", + "ColorMap", + "Number", + "RangeSlider", + "Coordinate", + "Font", + ], + "Output": [ + "PreviewImage", + "SaveImage", + "View3D", + "PrintTable", + "ValueDisplay", + ], + "Overlay": [ + "Markup", + "Annotations", + ], + "Modify": [ + "ColormapAdjust", + "CropResizeField", + "RotateField", + ], + "Filter": [ + "GaussianFilter", + "MedianFilter", + "EdgeDetect", + "FFTFilter1D", + "FFTFilter2D", + ], + "Frequency": [ + "FFT2D", + "InverseFFT2D", + ], + "Flatten": [ + "PlaneLevelField", + "PolyLevelField", + "FixZero", + ], + "Measure": [ + "Statistics", + "Histogram", + "LineCursors", + "CrossSection", + "Stats", + ], + "Mask": [ + "DrawMask", + "ThresholdMask", + "MaskMorphology", + "MaskInvert", + "MaskCombine", + ], + "Particles": [ + "ParticleAnalysis", + ], +} + + +_CATEGORY_ORDER = {category: index for index, category in enumerate(MENU_LAYOUT)} +_NODE_METADATA: dict[str, dict[str, Any]] = {} +for category, class_names in MENU_LAYOUT.items(): + for node_order, class_name in enumerate(class_names): + _NODE_METADATA[class_name] = { + "category": category, + "category_order": _CATEGORY_ORDER[category], + "menu_order": node_order, + } + + +def get_menu_metadata(class_name: str, fallback_category: str = "uncategorized") -> dict[str, Any]: + metadata = _NODE_METADATA.get(class_name) + if metadata is not None: + return dict(metadata) + + fallback_order = _CATEGORY_ORDER.get(fallback_category, len(_CATEGORY_ORDER)) + return { + "category": fallback_category, + "category_order": fallback_order, + "menu_order": 10_000, + } diff --git a/backend/node_registry.py b/backend/node_registry.py index 594d2e5..8b2a861 100644 --- a/backend/node_registry.py +++ b/backend/node_registry.py @@ -9,6 +9,8 @@ the execution engine and the /nodes REST endpoint. from __future__ import annotations from typing import Any +from backend.node_menu import get_menu_metadata + NODE_CLASS_MAPPINGS: dict[str, type] = {} NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {} @@ -37,11 +39,14 @@ def get_node_info(class_name: str) -> dict[str, Any]: """ cls = NODE_CLASS_MAPPINGS[class_name] input_types: dict = cls.INPUT_TYPES() + menu_metadata = get_menu_metadata(class_name, getattr(cls, "CATEGORY", "uncategorized")) return { "name": class_name, "display_name": NODE_DISPLAY_NAME_MAPPINGS.get(class_name, class_name), - "category": getattr(cls, "CATEGORY", "uncategorized"), + "category": menu_metadata["category"], + "category_order": menu_metadata["category_order"], + "menu_order": menu_metadata["menu_order"], "input": input_types, "input_order": {k: list(v.keys()) for k, v in input_types.items()}, "output": list(cls.RETURN_TYPES), diff --git a/backend/nodes/analysis.py b/backend/nodes/analysis.py index 6229226..1e0240f 100644 --- a/backend/nodes/analysis.py +++ b/backend/nodes/analysis.py @@ -2,8 +2,8 @@ Analysis nodes — statistics, histograms, FFT, cross sections. Gwyddion equivalents: - StatisticsNode → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h) - HeightHistogram → DH (height distribution), gwy_data_field_dh + Statistics → gwy_data_field_get_min/max/avg/rms (libprocess/stats.h) + Histogram → DH (height distribution), gwy_data_field_dh FFT2D → gwy_data_field_2dfft + gwy_data_field_2dpsdf CrossSection → gwy_data_field_get_profile (libprocess/datafield.c) """ @@ -16,11 +16,11 @@ from backend.data_types import DataField, MeasureTable, RecordTable, datafield_t # --------------------------------------------------------------------------- -# StatisticsNode +# Statistics # --------------------------------------------------------------------------- @register_node(display_name="Statistics") -class StatisticsNode: +class Statistics: @classmethod def INPUT_TYPES(cls): return { @@ -59,11 +59,11 @@ class StatisticsNode: # --------------------------------------------------------------------------- -# HeightHistogram +# Histogram # --------------------------------------------------------------------------- @register_node(display_name="Height Histogram") -class HeightHistogram: +class Histogram: @classmethod def INPUT_TYPES(cls): return { @@ -130,9 +130,9 @@ class HeightHistogram: yb = float(counts[idx_b]) if len(counts) else 0.0 count_unit = "count" if y_scale == "linear" else "log10(1+count)" - if HeightHistogram._broadcast_overlay_fn is not None: - HeightHistogram._broadcast_overlay_fn( - HeightHistogram._current_node_id, + if Histogram._broadcast_overlay_fn is not None: + Histogram._broadcast_overlay_fn( + Histogram._current_node_id, { "kind": "line_plot", "section_title": "Histogram", @@ -754,36 +754,6 @@ def _op_da(z): return float(np.mean(np.abs(np.diff(z)))) -@register_node(display_name="Line Math") -class LineMath: - """Compute a single scalar value from a LINE profile.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "line": ("LINE",), - "operation": (list(LINE_OPS.keys()),), - } - } - - RETURN_TYPES = ("MEASURE_TABLE",) - RETURN_NAMES = ("result",) - FUNCTION = "process" - CATEGORY = "analysis" - DESCRIPTION = ( - "Compute a single scalar measurement from a LINE profile. " - "Includes basic stats and Gwyddion-convention roughness parameters." - ) - - def process(self, line, operation: str) -> tuple: - z = np.asarray(line, dtype=np.float64).ravel() - fn, unit = LINE_OPS[operation] - value = fn(z) - table = MeasureTable([{"quantity": operation, "value": value, "unit": unit}]) - return (table,) - - # --------------------------------------------------------------------------- # TableMath — scalar measurement from a numeric record-table column # --------------------------------------------------------------------------- @@ -869,56 +839,6 @@ def _scalar_payload(value: float, unit: str = "") -> dict: return payload -@register_node(display_name="Table Math") -class TableMath: - """Compute a scalar reduction over one numeric column in a record table.""" - - _broadcast_value_fn = None - _current_node_id: str = "" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "table": ("RECORD_TABLE",), - "column": ("STRING", { - "default": "value", - "choices_from_table_input": "table", - }), - "operation": (list(TABLE_OPS.keys()),), - } - } - - RETURN_TYPES = ("FLOAT",) - RETURN_NAMES = ("value",) - FUNCTION = "process" - CATEGORY = "analysis" - DESCRIPTION = ( - "Compute a scalar reduction over one numeric record-table column. " - "Useful for max, min, avg, median, sum, range, std, variance, and count." - ) - - def process(self, table: list, column: str, operation: str) -> tuple: - if isinstance(table, MeasureTable): - raise ValueError("Table Math only accepts record tables, not measurement tables.") - if not isinstance(table, list) or not table: - raise ValueError("Table Math requires a non-empty record table input.") - - column_name = resolve_table_column_name(table, column) - values = extract_numeric_table_values(table, column_name) - if not values: - raise ValueError(f"Column '{column_name}' has no numeric values.") - - op = TABLE_OPS.get(operation) - if op is None: - raise ValueError(f"Unsupported table operation: {operation}") - - result = op(np.asarray(values, dtype=np.float64)) - if TableMath._broadcast_value_fn is not None: - TableMath._broadcast_value_fn(TableMath._current_node_id, result) - return (result,) - - def extract_numeric_table_values(table: list, column: str) -> list[float]: values = [] for row in table: diff --git a/backend/nodes/io.py b/backend/nodes/io.py index ca9b7d4..26ad6f4 100644 --- a/backend/nodes/io.py +++ b/backend/nodes/io.py @@ -125,11 +125,11 @@ def list_folder_paths(folderpath: str) -> list[dict]: # --------------------------------------------------------------------------- -# LoadFile (unified loader — replaces LoadImage + LoadSPM) +# Image (unified loader — replaces LoadImage + LoadSPM) # --------------------------------------------------------------------------- -@register_node(display_name="Load File") -class LoadFile: +@register_node(display_name="Image") +class Image: @classmethod def INPUT_TYPES(cls): return { @@ -185,8 +185,8 @@ class LoadFile: return (field,) def _send_warning(self, message: str): - fn = LoadFile._broadcast_warning_fn - nid = LoadFile._current_node_id + fn = Image._broadcast_warning_fn + nid = Image._current_node_id if fn and nid: fn(nid, message) @@ -353,7 +353,7 @@ class LoadFile: # --------------------------------------------------------------------------- -# LoadDemo +# ImageDemo # --------------------------------------------------------------------------- def _list_demo_files() -> list[str]: @@ -366,8 +366,8 @@ def _list_demo_files() -> list[str]: ) -@register_node(display_name="Load Demo File") -class LoadDemo: +@register_node(display_name="Image (Demo)") +class ImageDemo: @classmethod def INPUT_TYPES(cls): choices = _list_demo_files() or ["(no demo files found)"] @@ -388,7 +388,7 @@ class LoadDemo: DESCRIPTION = "Load a bundled demo file so you can try the app without providing your own data." def load(self, name: str = "", colormap: str = "viridis", colormap_map=None): - loader = LoadFile() + loader = Image() demo_path = DEMO_DIR / name if not demo_path.exists(): raise FileNotFoundError(f"Demo file not found: {name}") diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index bdb0f27..f2cd427 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -80,6 +80,23 @@ function sameStringArray(a = [], b = []) { return a.every((item, index) => item === b[index]); } +function compareMenuNodes(a, b) { + const orderA = Number.isFinite(a?.def?.menu_order) ? a.def.menu_order : Number.MAX_SAFE_INTEGER; + const orderB = Number.isFinite(b?.def?.menu_order) ? b.def.menu_order : Number.MAX_SAFE_INTEGER; + if (orderA !== orderB) return orderA - orderB; + + const nameA = (a?.def?.display_name || a?.className || '').toLowerCase(); + const nameB = (b?.def?.display_name || b?.className || '').toLowerCase(); + return nameA.localeCompare(nameB); +} + +function compareMenuCategories(a, b) { + const orderA = Number.isFinite(a?.order) ? a.order : Number.MAX_SAFE_INTEGER; + const orderB = Number.isFinite(b?.order) ? b.order : Number.MAX_SAFE_INTEGER; + if (orderA !== orderB) return orderA - orderB; + return String(a?.name || '').localeCompare(String(b?.name || '')); +} + function socketTypesCompatible(sourceType, targetType) { if (sourceType === targetType) return true; const accepted = SOCKET_COMPATIBILITY[targetType]; @@ -272,10 +289,25 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti } } const cat = def.category || 'uncategorized'; - if (!cats[cat]) cats[cat] = []; - cats[cat].push({ className, def }); + if (!cats[cat]) { + cats[cat] = { + name: cat, + order: Number.isFinite(def.category_order) ? def.category_order : Number.MAX_SAFE_INTEGER, + items: [], + }; + } + cats[cat].order = Math.min( + cats[cat].order, + Number.isFinite(def.category_order) ? def.category_order : Number.MAX_SAFE_INTEGER, + ); + cats[cat].items.push({ className, def }); } - return cats; + return Object.values(cats) + .map((category) => ({ + ...category, + items: [...category.items].sort(compareMenuNodes), + })) + .sort(compareMenuCategories); }, [nodeDefs, filterType, filterDirection]); // Flat filtered list for search @@ -283,8 +315,8 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti if (!search.trim()) return null; const q = search.toLowerCase(); const results = []; - for (const items of Object.values(categories)) { - for (const { className, def } of items) { + for (const category of categories) { + for (const { className, def } of category.items) { const name = (def.display_name || className).toLowerCase(); if (name.includes(q)) results.push({ className, def }); } @@ -341,7 +373,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti setOpenCat(cat); }, []); - if (Object.keys(categories).length === 0) { + if (categories.length === 0) { return (
e.stopPropagation()}>
No compatible nodes
@@ -349,7 +381,8 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti ); } - const catNames = Object.keys(categories).sort(); + const catNames = categories.map((category) => category.name); + const categoryMap = Object.fromEntries(categories.map((category) => [category.name, category.items])); return ( <> @@ -411,7 +444,7 @@ function ContextMenu({ x, y, nodeDefs, onAdd, onClose, filterType, filterDirecti
{/* Submenu rendered as a sibling, positioned at computed screen coords */} - {openCat && categories[openCat] && ( + {openCat && categoryMap[openCat] && (
- {categories[openCat].map(({ className, def }) => ( + {categoryMap[openCat].map(({ className, def }) => (
{ - if (node.data.className === 'LoadFile' || node.data.className === 'LoadDemo') { + if (node.data.className === 'Image' || node.data.className === 'ImageDemo') { refreshLoadNodeOutputs(node.id); } }); diff --git a/frontend/src/CustomNode.jsx b/frontend/src/CustomNode.jsx index 5c3336b..6a6bb36 100644 --- a/frontend/src/CustomNode.jsx +++ b/frontend/src/CustomNode.jsx @@ -868,10 +868,10 @@ function CustomNode({ id, data }) { if (data.className === 'Folder') { return getBasename(data.widgetValues?.folder); } - if (data.className === 'LoadFile') { + if (data.className === 'Image') { return getBasename(connectedPathInfo?.path || data.widgetValues?.filename); } - if (data.className === 'LoadDemo') { + if (data.className === 'ImageDemo') { return getBasename(data.widgetValues?.name); } return ''; diff --git a/frontend/src/executionGraph.js b/frontend/src/executionGraph.js index 693c917..06a0bc2 100644 --- a/frontend/src/executionGraph.js +++ b/frontend/src/executionGraph.js @@ -21,14 +21,14 @@ export function getConnectedNodeIds(edges) { } function isPreviewLoadNode(node) { - return ['LoadFile', 'LoadDemo'].includes(node?.data?.className); + return ['Image', 'ImageDemo'].includes(node?.data?.className); } function hasPreviewLoadSelection(node) { - if (node?.data?.className === 'LoadFile') { + if (node?.data?.className === 'Image') { return !!String(node.data?.widgetValues?.filename || '').trim(); } - if (node?.data?.className === 'LoadDemo') { + if (node?.data?.className === 'ImageDemo') { return !!String(node.data?.widgetValues?.name || '').trim(); } return false; diff --git a/frontend/tests/executionGraph.test.mjs b/frontend/tests/executionGraph.test.mjs index d5558d5..1c4c763 100644 --- a/frontend/tests/executionGraph.test.mjs +++ b/frontend/tests/executionGraph.test.mjs @@ -12,7 +12,7 @@ test('serializeExecutionGraph excludes isolated nodes from the backend prompt', { id: '1', data: { - className: 'LoadFile', + className: 'Image', definition: { input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, manual_trigger: false, @@ -34,7 +34,7 @@ test('serializeExecutionGraph excludes isolated nodes from the backend prompt', { id: '3', data: { - className: 'LoadFile', + className: 'Image', definition: { input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, manual_trigger: false, @@ -56,7 +56,7 @@ test('serializeExecutionGraph excludes isolated nodes from the backend prompt', assert.deepEqual(prompt, { '1': { - class_type: 'LoadFile', + class_type: 'Image', inputs: { filename: 'scan.gwy' }, }, '2': { @@ -72,7 +72,7 @@ test('serializeExecutionGraph includes isolated preview-load nodes alongside con { id: '1', data: { - className: 'LoadFile', + className: 'Image', definition: { input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, manual_trigger: false, @@ -94,7 +94,7 @@ test('serializeExecutionGraph includes isolated preview-load nodes alongside con { id: '3', data: { - className: 'LoadDemo', + className: 'ImageDemo', definition: { input: { required: { name: [['demo.npy'], {}] }, optional: {} }, manual_trigger: false, @@ -105,7 +105,7 @@ test('serializeExecutionGraph includes isolated preview-load nodes alongside con { id: '4', data: { - className: 'LoadFile', + className: 'Image', definition: { input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, manual_trigger: false, @@ -127,7 +127,7 @@ test('serializeExecutionGraph includes isolated preview-load nodes alongside con assert.deepEqual(prompt, { '1': { - class_type: 'LoadFile', + class_type: 'Image', inputs: { filename: 'first.gwy' }, }, '2': { @@ -135,19 +135,19 @@ test('serializeExecutionGraph includes isolated preview-load nodes alongside con inputs: { field: ['1', 0] }, }, '3': { - class_type: 'LoadDemo', + class_type: 'ImageDemo', inputs: { name: 'demo.npy' }, }, }); assert.equal('4' in prompt, false); }); -test('serializeExecutionGraph allows a singleton LoadFile graph so previews can run', () => { +test('serializeExecutionGraph allows a singleton Image graph so previews can run', () => { const nodes = [ { id: '1', data: { - className: 'LoadFile', + className: 'Image', definition: { input: { required: { filename: ['FILE_PICKER', {}] }, optional: {} }, manual_trigger: false, @@ -161,18 +161,18 @@ test('serializeExecutionGraph allows a singleton LoadFile graph so previews can assert.deepEqual(prompt, { '1': { - class_type: 'LoadFile', + class_type: 'Image', inputs: { filename: 'scan.gwy' }, }, }); }); -test('serializeExecutionGraph allows a singleton LoadDemo graph so previews can run', () => { +test('serializeExecutionGraph allows a singleton ImageDemo graph so previews can run', () => { const nodes = [ { id: '1', data: { - className: 'LoadDemo', + className: 'ImageDemo', definition: { input: { required: { name: [['demo.npy'], {}] }, optional: {} }, manual_trigger: false, @@ -186,7 +186,7 @@ test('serializeExecutionGraph allows a singleton LoadDemo graph so previews can assert.deepEqual(prompt, { '1': { - class_type: 'LoadDemo', + class_type: 'ImageDemo', inputs: { name: 'demo.npy' }, }, }); @@ -214,10 +214,10 @@ test('getAutoRunnableNodes ignores disconnected nodes when deciding what can aut test('getAutoRunnableNodes includes isolated preview-load nodes with selections', () => { const nodes = [ - { id: '1', data: { className: 'LoadFile', definition: {}, widgetValues: { filename: 'first.gwy' } } }, + { id: '1', data: { className: 'Image', definition: {}, widgetValues: { filename: 'first.gwy' } } }, { id: '2', data: { className: 'PreviewImage', definition: {}, widgetValues: {} } }, - { id: '3', data: { className: 'LoadDemo', definition: {}, widgetValues: { name: 'demo.npy' } } }, - { id: '4', data: { className: 'LoadFile', definition: {}, widgetValues: { filename: '' } } }, + { id: '3', data: { className: 'ImageDemo', definition: {}, widgetValues: { name: 'demo.npy' } } }, + { id: '4', data: { className: 'Image', definition: {}, widgetValues: { filename: '' } } }, ]; const edges = [ { @@ -233,12 +233,12 @@ test('getAutoRunnableNodes includes isolated preview-load nodes with selections' assert.deepEqual(runnable.map((node) => node.id), ['1', '2', '3']); }); -test('getAutoRunnableNodes allows a singleton LoadFile graph', () => { +test('getAutoRunnableNodes allows a singleton Image graph', () => { const nodes = [ { id: '1', data: { - className: 'LoadFile', + className: 'Image', definition: {}, widgetValues: { filename: 'scan.gwy' }, }, @@ -250,12 +250,12 @@ test('getAutoRunnableNodes allows a singleton LoadFile graph', () => { assert.deepEqual(runnable.map((node) => node.id), ['1']); }); -test('getAutoRunnableNodes allows a singleton LoadDemo graph', () => { +test('getAutoRunnableNodes allows a singleton ImageDemo graph', () => { const nodes = [ { id: '1', data: { - className: 'LoadDemo', + className: 'ImageDemo', definition: {}, widgetValues: { name: 'demo.npy' }, }, diff --git a/frontend/tests/workflowSerialization.test.mjs b/frontend/tests/workflowSerialization.test.mjs index f691437..0e0e154 100644 --- a/frontend/tests/workflowSerialization.test.mjs +++ b/frontend/tests/workflowSerialization.test.mjs @@ -103,7 +103,7 @@ test('hydrateWorkflowState clears shared path widgets while restoring saved dyna id: '12', position: { x: 40, y: 80 }, data: { - className: 'LoadFile', + className: 'Image', widgetValues: { filename: 'scan.ibw', colormap: 'viridis' }, output: ['DATA_FIELD', 'DATA_FIELD'], output_name: ['Height', 'Phase'], @@ -123,7 +123,7 @@ test('hydrateWorkflowState clears shared path widgets while restoring saved dyna }; const defs = { - LoadFile: { + Image: { category: 'io', input: { required: { filename: ['FILE_PICKER', {}], colormap: [['viridis', 'gray'], {}] } }, output: ['DATA_FIELD'], @@ -138,13 +138,13 @@ test('hydrateWorkflowState clears shared path widgets while restoring saved dyna assert.deepEqual(hydrated.edges, saved.edges); assert.equal(hydrated.nodes[0].type, 'custom'); assert.equal(hydrated.nodes[0].dragHandle, '.drag-handle'); - assert.equal(hydrated.nodes[0].data.label, 'LoadFile'); + assert.equal(hydrated.nodes[0].data.label, 'Image'); assert.equal(hydrated.nodes[0].data.previewImage, null); assert.equal(hydrated.nodes[0].data.widgetValues.filename, ''); assert.equal(hydrated.nodes[0].data.widgetValues.colormap, 'viridis'); assert.deepEqual(hydrated.nodes[0].data.definition.output, ['DATA_FIELD', 'DATA_FIELD']); assert.deepEqual(hydrated.nodes[0].data.definition.output_name, ['Height', 'Phase']); - assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.LoadFile.input); + assert.deepEqual(hydrated.nodes[0].data.definition.input, defs.Image.input); }); test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets but preserve other metadata', () => { @@ -153,8 +153,8 @@ test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets bu id: '7', position: { x: 10, y: 20 }, data: { - label: 'Load File', - className: 'LoadFile', + label: 'Image', + className: 'Image', widgetValues: { filename: 'scan.gwy', colormap: 'gray' }, definition: { category: 'io', @@ -176,7 +176,7 @@ test('serializeWorkflowState and hydrateWorkflowState clear path-like widgets bu }, ]; const defs = { - LoadFile: { + Image: { category: 'io', input: { required: { filename: ['FILE_PICKER', {}], colormap: [['gray', 'viridis'], {}] } }, output: ['DATA_FIELD'], diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5a7d463..dd0ab5c 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -476,9 +476,9 @@ def test_fix_zero(): # ========================================================================= def test_statistics(): - print("=== Test: StatisticsNode ===") - from backend.nodes.analysis import StatisticsNode - node = StatisticsNode() + print("=== Test: Statistics ===") + from backend.nodes.analysis import Statistics + node = Statistics() data = np.array([[1, 2], [3, 4]], dtype=np.float64) field = make_field(data=data) @@ -506,17 +506,17 @@ def test_statistics(): def test_height_histogram(): - print("=== Test: HeightHistogram ===") - from backend.nodes.analysis import HeightHistogram - node = HeightHistogram() + print("=== Test: Histogram ===") + from backend.nodes.analysis import Histogram + node = Histogram() # Uniform data should give a roughly flat histogram data = np.linspace(0, 1, 1000).reshape(25, 40) field = make_field(data=data) overlays = [] - HeightHistogram._broadcast_overlay_fn = lambda nid, data: overlays.append(data) - HeightHistogram._current_node_id = "test" + Histogram._broadcast_overlay_fn = lambda nid, data: overlays.append(data) + Histogram._current_node_id = "test" table, = node.process( field, @@ -549,7 +549,7 @@ def test_height_histogram(): measurements["B count"]["value"] - measurements["A count"]["value"], ) - HeightHistogram._broadcast_overlay_fn = None + Histogram._broadcast_overlay_fn = None print(" PASS\n") @@ -829,10 +829,10 @@ def test_particle_analysis(): # ========================================================================= def test_load_file(): - print("=== Test: LoadFile ===") - from backend.nodes.io import LoadFile + print("=== Test: Image ===") + from backend.nodes.io import Image from PIL import Image - node = LoadFile() + node = Image() with tempfile.TemporaryDirectory() as tmpdir: # Test loading a grayscale PNG → single DataField output @@ -1247,10 +1247,10 @@ def test_value_display(): # ========================================================================= def test_load_file_ibw(): - print("=== Test: LoadFile IBW multi-channel ===") - from backend.nodes.io import LoadFile + print("=== Test: Image IBW multi-channel ===") + from backend.nodes.io import Image - node = LoadFile() + node = Image() ibw_path = os.path.join(os.path.dirname(__file__), "..", "demo", "BR_New20012.ibw") ibw_path = os.path.abspath(ibw_path) if not os.path.exists(ibw_path): @@ -1283,10 +1283,10 @@ def test_load_file_ibw(): def test_load_file_npz(): - print("=== Test: LoadFile .npz ===") - from backend.nodes.io import LoadFile + print("=== Test: Image .npz ===") + from backend.nodes.io import Image - node = LoadFile() + node = Image() with tempfile.TemporaryDirectory() as tmpdir: data = np.random.default_rng(99).standard_normal((30, 40)) path = os.path.join(tmpdir, "test.npz") @@ -1300,10 +1300,10 @@ def test_load_file_npz(): def test_load_file_not_found(): - print("=== Test: LoadFile not found ===") - from backend.nodes.io import LoadFile + print("=== Test: Image not found ===") + from backend.nodes.io import Image - node = LoadFile() + node = Image() try: node.load(filename="/nonexistent/path/file.png") assert False, "Should have raised FileNotFoundError" @@ -1314,10 +1314,10 @@ def test_load_file_not_found(): def test_load_file_unsupported(): - print("=== Test: LoadFile unsupported format ===") - from backend.nodes.io import LoadFile + print("=== Test: Image unsupported format ===") + from backend.nodes.io import Image - node = LoadFile() + node = Image() with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "test.xyz") with open(path, "w") as f: @@ -1332,14 +1332,14 @@ def test_load_file_unsupported(): def test_load_file_warning(): - print("=== Test: LoadFile warning for uncalibrated data ===") - from backend.nodes.io import LoadFile + print("=== Test: Image warning for uncalibrated data ===") + from backend.nodes.io import Image from PIL import Image - node = LoadFile() + node = Image() warnings = [] - LoadFile._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) - LoadFile._current_node_id = "test" + Image._broadcast_warning_fn = lambda nid, msg: warnings.append(msg) + Image._current_node_id = "test" with tempfile.TemporaryDirectory() as tmpdir: arr = np.random.default_rng(10).integers(0, 256, (16, 16), dtype=np.uint8) @@ -1352,7 +1352,7 @@ def test_load_file_warning(): assert len(warnings) == 1 assert "Uncalibrated" in warnings[0] - LoadFile._broadcast_warning_fn = None + Image._broadcast_warning_fn = None print(" PASS\n") @@ -1428,14 +1428,14 @@ def test_list_channels(): # ========================================================================= -# I/O — LoadDemo +# I/O — ImageDemo # ========================================================================= def test_load_demo(): - print("=== Test: LoadDemo ===") - from backend.nodes.io import LoadDemo + print("=== Test: ImageDemo ===") + from backend.nodes.io import ImageDemo - node = LoadDemo() + node = ImageDemo() # Should be able to load a demo file by name result = node.load(name="nanoparticles.npy") @@ -1460,14 +1460,14 @@ def test_load_demo(): def test_load_demo_multi_layer_preview_payload(): - print("=== Test: LoadDemo multi-layer preview payload ===") + print("=== Test: ImageDemo multi-layer preview payload ===") from backend.execution import ExecutionEngine import backend.nodes # noqa: F401 previews = [] prompt = { "1": { - "class_type": "LoadDemo", + "class_type": "ImageDemo", "inputs": { "name": "whiskers.ibw", "colormap": "viridis",