rename grains to particle, add colormap adjust, table math

This commit is contained in:
2026-03-24 23:48:03 -07:00
parent edfdead4c1
commit 44de72d31b
12 changed files with 512 additions and 109 deletions

View File

@@ -29,7 +29,7 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|---|---------|---------------|-------------| |---|---------|---------------|-------------|
| 15 | Correlation / Pattern Matching | crosscor.c, maskcor.c | Find repeated features or align images via cross-correlation. | | 15 | Correlation / Pattern Matching | crosscor.c, maskcor.c | Find repeated features or align images via cross-correlation. |
| 16 | Slope Distribution | slope_dist.c | Angular histogram of surface slopes. Characterizes surface texture directionality. | | 16 | Slope Distribution | slope_dist.c | Angular histogram of surface slopes. Characterizes surface texture directionality. |
| 17 | Grain Filtering | grain_filter.c | Remove grains by size, height, or border contact. Refine grain masks post-detection. | | 17 | Grain Filtering | grain_filter.c | Remove particles by size, height, or border contact. Refine grain masks post-detection. |
| 18 | Field Arithmetic | arithmetic.c | Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization. | | 18 | Field Arithmetic | arithmetic.c | Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization. |
| 19 | Spot Removal | spotremove.c | Interpolate over selected point defects (dust, spikes). | | 19 | Spot Removal | spotremove.c | Interpolate over selected point defects (dust, spikes). |
| 20 | Tip Modeling / Deconvolution | tip_blind.c, tip_model.c | Estimate tip shape from image, deconvolve to recover true surface. | | 20 | Tip Modeling / Deconvolution | tip_blind.c, tip_model.c | Estimate tip shape from image, deconvolve to recover true surface. |
@@ -88,5 +88,5 @@ For reference, these Gwyddion equivalents are already covered:
| Mask Morphology | mask | mask_morph.c (erode, dilate, open, close) | | Mask Morphology | mask | mask_morph.c (erode, dilate, open, close) |
| Mask Invert | mask | — | | Mask Invert | mask | — |
| Mask Combine | mask | — (boolean AND, OR, XOR, subtract) | | Mask Combine | mask | — (boolean AND, OR, XOR, subtract) |
| Particle Analysis | grains | grain_stat.c | | Particle Analysis | particles | grain_stat.c |
| Preview / 3D View / Print Table | display | Presentation, 3D view | | Preview / 3D View / Print Table | display | Presentation, 3D view |

View File

@@ -32,6 +32,8 @@ class DataField:
si_unit_z: str = "m" si_unit_z: str = "m"
domain: str = "spatial" # "spatial" or "frequency" domain: str = "spatial" # "spatial" or "frequency"
colormap: str = "viridis" colormap: str = "viridis"
display_offset: float = 0.0
display_scale: float = 1.0
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.data = np.asarray(self.data, dtype=np.float64) self.data = np.asarray(self.data, dtype=np.float64)
@@ -53,6 +55,8 @@ class DataField:
si_unit_z=self.si_unit_z, si_unit_z=self.si_unit_z,
domain=self.domain, domain=self.domain,
colormap=self.colormap, colormap=self.colormap,
display_offset=self.display_offset,
display_scale=self.display_scale,
) )
def replace(self, **kwargs) -> "DataField": def replace(self, **kwargs) -> "DataField":
@@ -69,6 +73,8 @@ class DataField:
"si_unit_z": self.si_unit_z, "si_unit_z": self.si_unit_z,
"domain": self.domain, "domain": self.domain,
"colormap": self.colormap, "colormap": self.colormap,
"display_offset": self.display_offset,
"display_scale": self.display_scale,
} }
base.update(kwargs) base.update(kwargs)
return DataField(**base) return DataField(**base)
@@ -88,20 +94,51 @@ class DataField:
# Utility helpers shared across nodes # Utility helpers shared across nodes
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def normalize_for_colormap(
data: np.ndarray,
*,
offset: float = 0.0,
scale: float = 1.0,
data_min: float | None = None,
data_max: float | None = None,
) -> np.ndarray:
"""
Normalize an array to [0, 1] for colormap lookup, then apply a display window.
offset/scale operate in normalized data coordinates:
output = clip((base_norm - offset) / scale, 0, 1)
So offset=0, scale=1 maps the full data range 1:1 into the colormap.
"""
data = np.asarray(data, dtype=np.float64)
dmin = float(data.min()) if data_min is None else float(data_min)
dmax = float(data.max()) if data_max is None else float(data_max)
if dmax > dmin:
base_norm = (data - dmin) / (dmax - dmin)
else:
base_norm = np.zeros_like(data)
offset = float(offset)
scale = float(scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
return np.clip((base_norm - offset) / scale, 0.0, 1.0)
def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray: def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray:
""" """
Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap. Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap.
Returns shape (H, W, 3) uint8. Returns shape (H, W, 3) uint8.
""" """
import matplotlib.cm as cm import matplotlib.cm as cm
import matplotlib.colors as mcolors normalized = normalize_for_colormap(
df.data,
data = df.data offset=df.display_offset,
dmin, dmax = data.min(), data.max() scale=df.display_scale,
if dmax > dmin: )
normalized = (data - dmin) / (dmax - dmin)
else:
normalized = np.zeros_like(data)
cmap = cm.get_cmap(colormap) cmap = cm.get_cmap(colormap)
rgba = cmap(normalized) # (H, W, 4) float [0,1] rgba = cmap(normalized) # (H, W, 4) float [0,1]

View File

@@ -1,2 +1,7 @@
# Import all node modules to trigger @register_node decorators. # Import all node modules to trigger @register_node decorators.
from . import io, filters, modify, level, analysis, grains, mask, display from . import io, filters, modify, level, analysis, mask, display
try:
from . import particle
except ImportError:
from . import grains

View File

@@ -10,6 +10,7 @@ Gwyddion equivalents:
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from typing import Callable
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview from backend.data_types import DataField, datafield_to_uint8, encode_preview
@@ -562,3 +563,103 @@ class LineMath:
value = fn(z) value = fn(z)
table = [{"quantity": operation, "value": value, "unit": unit}] table = [{"quantity": operation, "value": value, "unit": unit}]
return (table,) return (table,)
# ---------------------------------------------------------------------------
# TableMath — scalar measurement from a numeric TABLE column
# ---------------------------------------------------------------------------
TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = {
"min": lambda values: float(np.min(values)),
"max": lambda values: float(np.max(values)),
"avg": lambda values: float(np.mean(values)),
"mean": lambda values: float(np.mean(values)),
"median": lambda values: float(np.median(values)),
"sum": lambda values: float(np.sum(values)),
"range": lambda values: float(np.max(values) - np.min(values)),
"std": lambda values: float(np.std(values)),
"variance": lambda values: float(np.var(values)),
"count": lambda values: float(len(values)),
}
@register_node(display_name="Table Math")
class TableMath:
"""Compute a scalar reduction over one numeric column in a TABLE."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"table": ("TABLE",),
"column": ("STRING", {"default": "value"}),
"operation": (list(TABLE_OPS.keys()),),
}
}
RETURN_TYPES = ("FLOAT",)
RETURN_NAMES = ("value",)
FUNCTION = "process"
CATEGORY = "analysis"
DESCRIPTION = (
"Compute a scalar reduction over one numeric TABLE column. "
"Useful for max, min, avg, median, sum, range, std, variance, and count."
)
def process(self, table: list, column: str, operation: str) -> tuple:
if not isinstance(table, list) or not table:
raise ValueError("Table Math requires a non-empty TABLE input.")
column_name = self._resolve_column_name(table, column)
values = self._extract_numeric_values(table, column_name)
if not values:
raise ValueError(f"Column '{column_name}' has no numeric values.")
op = TABLE_OPS.get(operation)
if op is None:
raise ValueError(f"Unsupported table operation: {operation}")
return (op(np.asarray(values, dtype=np.float64)),)
def _resolve_column_name(self, table: list, column: str) -> str:
requested = str(column or "").strip()
if requested:
return requested
if self._extract_numeric_values(table, "value"):
return "value"
numeric_columns = []
seen = set()
for row in table:
if not isinstance(row, dict):
continue
for key in row.keys():
if key in seen:
continue
seen.add(key)
if self._extract_numeric_values(table, key):
numeric_columns.append(key)
if len(numeric_columns) == 1:
return numeric_columns[0]
if not numeric_columns:
raise ValueError("Table Math could not find any numeric columns in the input table.")
raise ValueError(
"Table Math found multiple numeric columns; set the column name explicitly."
)
def _extract_numeric_values(self, table: list, column: str) -> list[float]:
values = []
for row in table:
if not isinstance(row, dict) or column not in row:
continue
value = row[column]
if isinstance(value, bool):
continue
try:
numeric = float(value)
except (TypeError, ValueError):
continue
if np.isfinite(numeric):
values.append(numeric)
return values

View File

@@ -9,7 +9,9 @@ before execution begins.
from __future__ import annotations from __future__ import annotations
import numpy as np import numpy as np
from backend.node_registry import register_node from backend.node_registry import register_node
from backend.data_types import DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview from backend.data_types import (
DataField, COLORMAPS, datafield_to_uint8, image_to_uint8, encode_preview, normalize_for_colormap,
)
@register_node(display_name="Preview") @register_node(display_name="Preview")
@@ -113,10 +115,13 @@ class View3D:
# Normalize for colormap # Normalize for colormap
zmin, zmax = float(z.min()), float(z.max()) zmin, zmax = float(z.min()), float(z.max())
if zmax > zmin: z_norm = normalize_for_colormap(
z_norm = (z - zmin) / (zmax - zmin) z,
else: offset=field.display_offset,
z_norm = np.zeros_like(z) scale=field.display_scale,
data_min=float(field.data.min()),
data_max=float(field.data.max()),
)
cmap_name = field.colormap if colormap == "auto" else colormap cmap_name = field.colormap if colormap == "auto" else colormap
cmap = cm.get_cmap(cmap_name) cmap = cm.get_cmap(cmap_name)

View File

@@ -10,6 +10,39 @@ from backend.node_registry import register_node
from backend.data_types import DataField, datafield_to_uint8, encode_preview from backend.data_types import DataField, datafield_to_uint8, encode_preview
# ---------------------------------------------------------------------------
# ColormapAdjust
# ---------------------------------------------------------------------------
@register_node(display_name="Colormap Adjust")
class ColormapAdjust:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"field": ("DATA_FIELD",),
"offset": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.05, "max": 4.0, "step": 0.01}),
"auto": ("BUTTON", {"label": "Auto", "set_widgets": {"offset": 0.0, "scale": 1.0}}),
}
}
RETURN_TYPES = ("DATA_FIELD",)
RETURN_NAMES = ("field",)
FUNCTION = "process"
CATEGORY = "modify"
DESCRIPTION = (
"Adjust how a DATA_FIELD maps into its colormap without changing the underlying data. "
"offset and scale operate in normalized display coordinates; Auto resets to the full data range."
)
def process(self, field: DataField, offset: float, scale: float) -> tuple:
scale = float(scale)
if not np.isfinite(scale) or scale <= 0.0:
raise ValueError("Scale must be a positive number.")
return (field.replace(display_offset=float(offset), display_scale=scale),)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# CropResizeField # CropResizeField
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -2,7 +2,7 @@
Particle detection nodes. Particle detection nodes.
Gwyddion equivalents: Gwyddion equivalents:
ParticleAnalysis gwy_data_field_grains_get_values (grains-values.c) ParticleAnalysis gwy_data_field_particles_get_values (particles-values.c)
""" """
from __future__ import annotations from __future__ import annotations
@@ -30,11 +30,11 @@ class ParticleAnalysis:
RETURN_TYPES = ("TABLE",) RETURN_TYPES = ("TABLE",)
RETURN_NAMES = ("particle_stats",) RETURN_NAMES = ("particle_stats",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "grains" CATEGORY = "particles"
DESCRIPTION = ( DESCRIPTION = (
"Label connected particle regions in a binary mask and compute per-particle " "Label connected particle regions in a binary mask and compute per-particle "
"statistics: area, equivalent diameter, mean/max height, bounding box. " "statistics: area, equivalent diameter, mean/max height, bounding box. "
"Equivalent to gwy_data_field_grains_get_values." "Equivalent to gwy_data_field_particles_get_values."
) )
def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple: def process(self, field: DataField, mask: np.ndarray, min_size: int) -> tuple:

View File

@@ -178,6 +178,7 @@ function serializeGraph(nodes, edges, { excludeManualTrigger = false } = {}) {
for (const [name, spec] of Object.entries(required)) { for (const [name, spec] of Object.entries(required)) {
const [type] = Array.isArray(spec) ? spec : [spec]; const [type] = Array.isArray(spec) ? spec : [spec];
if (DATA_TYPES.has(type)) continue; // socket, handled via edges if (DATA_TYPES.has(type)) continue; // socket, handled via edges
if (type === 'BUTTON') continue; // UI-only widget, not a backend input
if (widgetValues[name] !== undefined) { if (widgetValues[name] !== undefined) {
inputs[name] = widgetValues[name]; inputs[name] = widgetValues[name];
} }
@@ -604,6 +605,7 @@ function Flow() {
for (const [name, spec] of Object.entries(required)) { for (const [name, spec] of Object.entries(required)) {
const [type, opts] = Array.isArray(spec) ? spec : [spec, {}]; const [type, opts] = Array.isArray(spec) ? spec : [spec, {}];
if (DATA_TYPES.has(type)) continue; if (DATA_TYPES.has(type)) continue;
if (type === 'BUTTON') continue;
if (Array.isArray(type)) { if (Array.isArray(type)) {
widgetValues[name] = type[0]; // combo default = first option widgetValues[name] = type[0]; // combo default = first option
} else { } else {
@@ -1026,7 +1028,7 @@ function Flow() {
const cat = n.data?.definition?.category; const cat = n.data?.definition?.category;
const colors = { const colors = {
io: '#37474f', filters: '#1a237e', level: '#1b5e20', io: '#37474f', filters: '#1a237e', level: '#1b5e20',
analysis: '#4a148c', grains: '#bf360c', display: '#212121', analysis: '#4a148c', particles: '#bf360c', display: '#212121',
}; };
return colors[cat] || '#333'; return colors[cat] || '#333';
}} }}

View File

@@ -26,7 +26,7 @@ const CAT_COLORS = {
modify: '#0f766e', modify: '#0f766e',
level: '#1b5e20', level: '#1b5e20',
analysis: '#4a148c', analysis: '#4a148c',
grains: '#bf360c', particles:'#bf360c',
display: '#212121', display: '#212121',
}; };
@@ -171,6 +171,69 @@ function CollapsibleSection({ title, defaultOpen, children }) {
); );
} }
function getTableColumns(rows) {
const columns = [];
for (const row of rows) {
if (!row || typeof row !== 'object') continue;
for (const key of Object.keys(row)) {
if (!columns.includes(key)) columns.push(key);
}
}
return columns;
}
function formatTableCell(value) {
if (value == null) return '';
if (typeof value === 'number') {
if (!Number.isFinite(value)) return String(value);
const abs = Math.abs(value);
if (Number.isInteger(value) && abs < 1e6) return String(value);
if ((abs > 0 && abs < 1e-3) || abs >= 1e4) return value.toExponential(3);
return value.toFixed(4).replace(/\.?0+$/, '');
}
if (Array.isArray(value)) return value.join(', ');
return String(value);
}
function NodeTable({ rows }) {
const columns = getTableColumns(rows);
if (columns.length === 0) return null;
return (
<div className="node-table-wrap">
<div className="node-table-scroll">
<table className="node-table-grid">
<thead>
<tr>
{columns.map((column) => (
<th key={column} scope="col">{column}</th>
))}
</tr>
</thead>
<tbody>
{rows.map((row, rowIndex) => (
<tr key={row.id ?? row.quantity ?? rowIndex}>
{columns.map((column) => {
const value = row?.[column];
return (
<td
key={`${rowIndex}-${column}`}
className={typeof value === 'number' ? 'node-table-num' : ''}
title={formatTableCell(value)}
>
{formatTableCell(value)}
</td>
);
})}
</tr>
))}
</tbody>
</table>
</div>
</div>
);
}
// ── CustomNode component ────────────────────────────────────────────── // ── CustomNode component ──────────────────────────────────────────────
function CustomNode({ id, data }) { function CustomNode({ id, data }) {
@@ -411,21 +474,7 @@ function CustomNode({ id, data }) {
{/* Collapsible table data */} {/* Collapsible table data */}
{data.tableRows && data.tableRows.length > 0 && ( {data.tableRows && data.tableRows.length > 0 && (
<CollapsibleSection title="Table" defaultOpen={true}> <CollapsibleSection title="Table" defaultOpen={true}>
<div className="node-table"> <NodeTable rows={data.tableRows} />
{data.tableRows.map((row, i) => {
let line;
if (row.quantity !== undefined) {
const val = typeof row.value === 'number' ? row.value.toExponential(3) : row.value;
line = `${row.quantity}: ${val} ${row.unit || ''}`;
} else {
line = Object.entries(row)
.slice(0, 3)
.map(([k, v]) => `${k}: ${typeof v === 'number' ? v.toExponential(2) : v}`)
.join(' ');
}
return <div key={i} className="table-line">{line}</div>;
})}
</div>
</CollapsibleSection> </CollapsibleSection>
)} )}
</div> </div>
@@ -480,6 +529,26 @@ function WidgetControl({ widget, nodeId, value, widgetValues, onChange, openFile
); );
} }
if (type === 'BUTTON') {
const updates = opts?.set_widgets && typeof opts.set_widgets === 'object'
? Object.entries(opts.set_widgets)
: [];
return (
<button
className="nodrag widget-button"
type="button"
onClick={() => {
for (const [targetName, targetValue] of updates) {
onChange(nodeId, targetName, targetValue);
}
}}
>
{opts?.label || name}
</button>
);
}
if (type === 'FLOAT') { if (type === 'FLOAT') {
if (opts?.slider) { if (opts?.slider) {
const rawMin = opts?.min_widget ? widgetValues?.[opts.min_widget] : opts?.min; const rawMin = opts?.min_widget ? widgetValues?.[opts.min_widget] : opts?.min;

View File

@@ -227,6 +227,23 @@ html, body, #root {
accent-color: #3a7abf; accent-color: #3a7abf;
} }
.widget-button {
flex: 1;
min-width: 0;
background: #0f3460;
color: #e0e0e0;
border: 1px solid #334155;
border-radius: 3px;
padding: 4px 8px;
font-size: 11px;
cursor: pointer;
}
.widget-button:hover {
background: #1a4a8a;
border-color: #3a7abf;
}
.slider-control { .slider-control {
display: flex; display: flex;
align-items: center; align-items: center;
@@ -496,18 +513,56 @@ html, body, #root {
} }
/* ── Node table ────────────────────────────────────────────────────── */ /* ── Node table ────────────────────────────────────────────────────── */
.node-table { .node-table-wrap {
padding: 4px 10px; padding: 4px 10px 8px;
}
.node-table-scroll {
max-height: 220px;
overflow: auto;
border: 1px solid #334155;
border-radius: 6px;
background: #0f172a;
}
.node-table-grid {
width: 100%;
border-collapse: collapse;
font-family: "SF Mono", "Fira Code", monospace; font-family: "SF Mono", "Fira Code", monospace;
font-size: 10px; font-size: 10px;
color: #cbd5e1; color: #cbd5e1;
} }
.table-line { .node-table-grid th,
.node-table-grid td {
padding: 6px 8px;
border-bottom: 1px solid rgba(51, 65, 85, 0.75);
white-space: nowrap; white-space: nowrap;
overflow: hidden; text-align: left;
text-overflow: ellipsis; vertical-align: top;
line-height: 1.5; }
.node-table-grid thead th {
position: sticky;
top: 0;
z-index: 1;
background: #16213e;
color: #94a3b8;
font-size: 9px;
letter-spacing: 0.04em;
text-transform: uppercase;
}
.node-table-grid tbody tr:nth-child(even) {
background: rgba(30, 41, 59, 0.38);
}
.node-table-grid tbody tr:last-child td {
border-bottom: none;
}
.node-table-num {
text-align: right !important;
} }
/* ── Node resize handles ───────────────────────────────────────────── */ /* ── Node resize handles ───────────────────────────────────────────── */

View File

@@ -1,12 +1,12 @@
""" """
Thorough tests for the grain/particle analysis pipeline: Thorough tests for the particles/particle analysis pipeline:
ThresholdMask → GrainAnalysis ThresholdMask → GrainAnalysis
Covers synthetic geometry (known answers), the demo nanoparticles image, Covers synthetic geometry (known answers), the demo nanoparticles image,
edge cases, and physical-unit correctness. edge cases, and physical-unit correctness.
Run from project root: Run from project root:
.venv/bin/python -m tests.test_grains .venv/bin/python -m tests.test_particles
""" """
import sys import sys
@@ -28,7 +28,7 @@ def make_field(data, xreal=1e-6, yreal=1e-6):
def test_threshold_otsu_bimodal(): def test_threshold_otsu_bimodal():
"""Otsu on a clean bimodal image should separate the two populations.""" """Otsu on a clean bimodal image should separate the two populations."""
print("=== Test: Otsu on bimodal image ===") print("=== Test: Otsu on bimodal image ===")
from backend.nodes.grains import ThresholdMask from backend.nodes.particle import ThresholdMask
node = ThresholdMask() node = ThresholdMask()
data = np.zeros((128, 128)) data = np.zeros((128, 128))
@@ -50,7 +50,7 @@ def test_threshold_otsu_bimodal():
def test_threshold_relative_range(): def test_threshold_relative_range():
"""Relative threshold at 0.5 should be the midpoint of [min, max].""" """Relative threshold at 0.5 should be the midpoint of [min, max]."""
print("=== Test: Relative threshold at midpoint ===") print("=== Test: Relative threshold at midpoint ===")
from backend.nodes.grains import ThresholdMask from backend.nodes.particle import ThresholdMask
node = ThresholdMask() node = ThresholdMask()
data = np.full((64, 64), 2.0) data = np.full((64, 64), 2.0)
@@ -68,7 +68,7 @@ def test_threshold_relative_range():
def test_threshold_empty_mask(): def test_threshold_empty_mask():
"""Very high absolute threshold on low data should produce an empty mask.""" """Very high absolute threshold on low data should produce an empty mask."""
print("=== Test: Empty mask from high threshold ===") print("=== Test: Empty mask from high threshold ===")
from backend.nodes.grains import ThresholdMask from backend.nodes.particle import ThresholdMask
node = ThresholdMask() node = ThresholdMask()
data = np.ones((64, 64)) data = np.ones((64, 64))
@@ -82,7 +82,7 @@ def test_threshold_empty_mask():
def test_threshold_full_mask(): def test_threshold_full_mask():
"""Very low absolute threshold should produce an all-white mask.""" """Very low absolute threshold should produce an all-white mask."""
print("=== Test: Full mask from low threshold ===") print("=== Test: Full mask from low threshold ===")
from backend.nodes.grains import ThresholdMask from backend.nodes.particle import ThresholdMask
node = ThresholdMask() node = ThresholdMask()
data = np.ones((64, 64)) * 5.0 data = np.ones((64, 64)) * 5.0
@@ -100,7 +100,7 @@ def test_threshold_full_mask():
def test_single_circle_area(): def test_single_circle_area():
"""A single filled circle — verify pixel count and physical area.""" """A single filled circle — verify pixel count and physical area."""
print("=== Test: Single circle area ===") print("=== Test: Single circle area ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
N = 200 N = 200
@@ -118,35 +118,35 @@ def test_single_circle_area():
field = make_field(data, xreal=XREAL, yreal=XREAL) field = make_field(data, xreal=XREAL, yreal=XREAL)
table, = node.process(field, mask=mask, min_size=1) table, = node.process(field, mask=mask, min_size=1)
assert len(table) == 1, f"Expected 1 grain, got {len(table)}" assert len(table) == 1, f"Expected 1 particles, got {len(table)}"
grain = table[0] particles = table[0]
# Pixel area of a discrete circle: should be close to π r² # Pixel area of a discrete circle: should be close to π r²
expected_px = np.pi * r ** 2 expected_px = np.pi * r ** 2
assert abs(grain["area_px"] - expected_px) / expected_px < 0.02, \ assert abs(particles["area_px"] - expected_px) / expected_px < 0.02, \
f"area_px={grain['area_px']}, expected≈{expected_px:.0f}" f"area_px={particles['area_px']}, expected≈{expected_px:.0f}"
# Physical area # Physical area
pixel_area = (XREAL / N) ** 2 pixel_area = (XREAL / N) ** 2
expected_m2 = grain["area_px"] * pixel_area expected_m2 = particles["area_px"] * pixel_area
assert abs(grain["area_m2"] - expected_m2) < 1e-20, \ assert abs(particles["area_m2"] - expected_m2) < 1e-20, \
f"area_m2 mismatch: {grain['area_m2']} vs {expected_m2}" f"area_m2 mismatch: {particles['area_m2']} vs {expected_m2}"
# Equivalent diameter should be close to 2r in physical units # Equivalent diameter should be close to 2r in physical units
expected_diam = 2 * r * (XREAL / N) expected_diam = 2 * r * (XREAL / N)
assert abs(grain["equiv_diam_m"] - expected_diam) / expected_diam < 0.02, \ assert abs(particles["equiv_diam_m"] - expected_diam) / expected_diam < 0.02, \
f"equiv_diam={grain['equiv_diam_m']:.3e}, expected≈{expected_diam:.3e}" f"equiv_diam={particles['equiv_diam_m']:.3e}, expected≈{expected_diam:.3e}"
# Heights # Heights
assert abs(grain["mean_height"] - 5.0) < 1e-10 assert abs(particles["mean_height"] - 5.0) < 1e-10
assert abs(grain["max_height"] - 5.0) < 1e-10 assert abs(particles["max_height"] - 5.0) < 1e-10
print(" PASS\n") print(" PASS\n")
def test_multiple_grains_separation(): def test_multiple_particles_separation():
"""Three well-separated grains of different sizes — check each is reported.""" """Three well-separated particles of different sizes — check each is reported."""
print("=== Test: Multiple grain separation ===") print("=== Test: Multiple particles separation ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
N = 128 N = 128
@@ -168,7 +168,7 @@ def test_multiple_grains_separation():
field = make_field(data) field = make_field(data)
table, = node.process(field, mask=mask, min_size=1) table, = node.process(field, mask=mask, min_size=1)
assert len(table) == 3, f"Expected 3 grains, got {len(table)}" assert len(table) == 3, f"Expected 3 particles, got {len(table)}"
table.sort(key=lambda r: r["area_px"], reverse=True) table.sort(key=lambda r: r["area_px"], reverse=True)
assert table[0]["area_px"] == 400 # 20×20 assert table[0]["area_px"] == 400 # 20×20
@@ -182,24 +182,24 @@ def test_multiple_grains_separation():
def test_min_size_filtering(): def test_min_size_filtering():
"""min_size should exclude grains smaller than the threshold.""" """min_size should exclude particles smaller than the threshold."""
print("=== Test: min_size filtering ===") print("=== Test: min_size filtering ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
N = 64 N = 64
data = np.zeros((N, N)) data = np.zeros((N, N))
mask = np.zeros((N, N), dtype=np.uint8) mask = np.zeros((N, N), dtype=np.uint8)
# Large grain: 15×15 = 225 px # Large particles: 15×15 = 225 px
data[5:20, 5:20] = 1.0 data[5:20, 5:20] = 1.0
mask[5:20, 5:20] = 255 mask[5:20, 5:20] = 255
# Medium grain: 8×8 = 64 px # Medium particles: 8×8 = 64 px
data[30:38, 30:38] = 1.0 data[30:38, 30:38] = 1.0
mask[30:38, 30:38] = 255 mask[30:38, 30:38] = 255
# Tiny grain: 3×3 = 9 px # Tiny particles: 3×3 = 9 px
data[50:53, 50:53] = 1.0 data[50:53, 50:53] = 1.0
mask[50:53, 50:53] = 255 mask[50:53, 50:53] = 255
@@ -224,16 +224,16 @@ def test_min_size_filtering():
print(" PASS\n") print(" PASS\n")
def test_grain_bounding_box(): def test_particles_bounding_box():
"""Bounding box should match the grain extents.""" """Bounding box should match the particles extents."""
print("=== Test: Grain bounding box ===") print("=== Test: Grain bounding box ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
N = 64 N = 64
data = np.zeros((N, N)) data = np.zeros((N, N))
mask = np.zeros((N, N), dtype=np.uint8) mask = np.zeros((N, N), dtype=np.uint8)
# Place a grain at rows 20:35, cols 10:45 # Place a particles at rows 20:35, cols 10:45
data[20:35, 10:45] = 2.0 data[20:35, 10:45] = 2.0
mask[20:35, 10:45] = 255 mask[20:35, 10:45] = 255
@@ -247,10 +247,10 @@ def test_grain_bounding_box():
print(" PASS\n") print(" PASS\n")
def test_empty_mask_produces_no_grains(): def test_empty_mask_produces_no_particles():
"""An all-zero mask should yield zero grains.""" """An all-zero mask should yield zero particles."""
print("=== Test: Empty mask → no grains ===") print("=== Test: Empty mask → no particles ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
field = make_field(np.ones((64, 64))) field = make_field(np.ones((64, 64)))
@@ -261,10 +261,10 @@ def test_empty_mask_produces_no_grains():
print(" PASS\n") print(" PASS\n")
def test_grain_at_image_edge(): def test_particles_at_image_edge():
"""A grain touching the image border should still be detected.""" """A particles touching the image border should still be detected."""
print("=== Test: Grain at image edge ===") print("=== Test: Grain at image edge ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
N = 64 N = 64
@@ -282,11 +282,11 @@ def test_grain_at_image_edge():
print(" PASS\n") print(" PASS\n")
def test_adjacent_grains_connectivity(): def test_adjacent_particles_connectivity():
"""Two diagonally-touching blocks should be separate grains """Two diagonally-touching blocks should be separate particles
(scipy.ndimage.label uses 4-connectivity by default).""" (scipy.ndimage.label uses 4-connectivity by default)."""
print("=== Test: Diagonal adjacency → separate grains ===") print("=== Test: Diagonal adjacency → separate particles ===")
from backend.nodes.grains import GrainAnalysis from backend.nodes.particle import GrainAnalysis
node = GrainAnalysis() node = GrainAnalysis()
N = 32 N = 32
@@ -305,7 +305,7 @@ def test_adjacent_grains_connectivity():
table, = node.process(field, mask=mask, min_size=1) table, = node.process(field, mask=mask, min_size=1)
# Default label() uses structure that connects diagonals? Let's verify. # Default label() uses structure that connects diagonals? Let's verify.
# scipy.ndimage.label default is cross-shaped (no diagonals) for 2D # scipy.ndimage.label default is cross-shaped (no diagonals) for 2D
assert len(table) == 2, f"Expected 2 separate grains, got {len(table)}" assert len(table) == 2, f"Expected 2 separate particles, got {len(table)}"
print(" PASS\n") print(" PASS\n")
@@ -316,7 +316,7 @@ def test_adjacent_grains_connectivity():
def test_pipeline_synthetic(): def test_pipeline_synthetic():
"""Full pipeline on a synthetic image with known geometry.""" """Full pipeline on a synthetic image with known geometry."""
print("=== Test: Full pipeline on synthetic particles ===") print("=== Test: Full pipeline on synthetic particles ===")
from backend.nodes.grains import ThresholdMask, GrainAnalysis from backend.nodes.particle import ThresholdMask, GrainAnalysis
N = 200 N = 200
XREAL = 10e-6 # 10 µm XREAL = 10e-6 # 10 µm
@@ -349,20 +349,20 @@ def test_pipeline_synthetic():
# Particles are well above noise, so mask should capture all 5 # Particles are well above noise, so mask should capture all 5
assert mask.max() == 255, "No particles detected" assert mask.max() == 255, "No particles detected"
# Step 2: grain analysis # Step 2: particles analysis
ga = GrainAnalysis() ga = GrainAnalysis()
table, = ga.process(field, mask=mask, min_size=5) table, = ga.process(field, mask=mask, min_size=5)
assert len(table) == 5, f"Expected 5 grains, got {len(table)}" assert len(table) == 5, f"Expected 5 particles, got {len(table)}"
# Verify that detected areas are in the right ballpark # Verify that detected areas are in the right ballpark
table.sort(key=lambda r: r["area_px"], reverse=True) table.sort(key=lambda r: r["area_px"], reverse=True)
expected_areas = sorted([np.pi * r ** 2 for _, _, r, _ in specs], reverse=True) expected_areas = sorted([np.pi * r ** 2 for _, _, r, _ in specs], reverse=True)
for grain, expected_px in zip(table, expected_areas): for particles, expected_px in zip(table, expected_areas):
ratio = grain["area_px"] / expected_px ratio = particles["area_px"] / expected_px
assert 0.85 < ratio < 1.15, \ assert 0.85 < ratio < 1.15, \
f"grain area_px={grain['area_px']}, expected≈{expected_px:.0f}, ratio={ratio:.2f}" f"particles area_px={particles['area_px']}, expected≈{expected_px:.0f}, ratio={ratio:.2f}"
print(" PASS\n") print(" PASS\n")
@@ -371,7 +371,7 @@ def test_pipeline_demo_image():
"""Run the full pipeline on the bundled demo nanoparticles image.""" """Run the full pipeline on the bundled demo nanoparticles image."""
print("=== Test: Full pipeline on demo nanoparticles.npy ===") print("=== Test: Full pipeline on demo nanoparticles.npy ===")
from pathlib import Path from pathlib import Path
from backend.nodes.grains import ThresholdMask, GrainAnalysis from backend.nodes.particle import ThresholdMask, GrainAnalysis
from backend.runtime_paths import demo_dir from backend.runtime_paths import demo_dir
npy_path = demo_dir() / "nanoparticles.npy" npy_path = demo_dir() / "nanoparticles.npy"
@@ -398,16 +398,16 @@ def test_pipeline_demo_image():
ga = GrainAnalysis() ga = GrainAnalysis()
table, = ga.process(field, mask=mask, min_size=20) table, = ga.process(field, mask=mask, min_size=20)
assert len(table) > 0, "No grains detected" assert len(table) > 0, "No particles detected"
print(f" Found {len(table)} grains (min_size=20)") print(f" Found {len(table)} particles (min_size=20)")
# Sanity checks on grain properties # Sanity checks on particles properties
for grain in table: for particles in table:
assert grain["area_px"] >= 20 assert particles["area_px"] >= 20
assert grain["area_m2"] > 0 assert particles["area_m2"] > 0
assert grain["equiv_diam_m"] > 0 assert particles["equiv_diam_m"] > 0
assert grain["max_height"] >= grain["mean_height"] assert particles["max_height"] >= particles["mean_height"]
assert grain["mean_height"] > 0 assert particles["mean_height"] > 0
# Physical size sanity: equivalent diameters should be in the nmµm range # Physical size sanity: equivalent diameters should be in the nmµm range
diams_nm = [g["equiv_diam_m"] * 1e9 for g in table] diams_nm = [g["equiv_diam_m"] * 1e9 for g in table]
@@ -431,15 +431,15 @@ if __name__ == "__main__":
# GrainAnalysis # GrainAnalysis
test_single_circle_area() test_single_circle_area()
test_multiple_grains_separation() test_multiple_particles_separation()
test_min_size_filtering() test_min_size_filtering()
test_grain_bounding_box() test_particles_bounding_box()
test_empty_mask_produces_no_grains() test_empty_mask_produces_no_particles()
test_grain_at_image_edge() test_particles_at_image_edge()
test_adjacent_grains_connectivity() test_adjacent_particles_connectivity()
# End-to-end pipeline # End-to-end pipeline
test_pipeline_synthetic() test_pipeline_synthetic()
test_pipeline_demo_image() test_pipeline_demo_image()
print("All grain tests passed!") print("All particles tests passed!")

View File

@@ -10,7 +10,7 @@ import tempfile
import numpy as np import numpy as np
sys.path.insert(0, ".") sys.path.insert(0, ".")
from backend.data_types import DataField from backend.data_types import DataField, datafield_to_uint8
def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6): def make_field(data=None, shape=(64, 64), xreal=1e-6, yreal=1e-6):
@@ -223,6 +223,47 @@ def test_rotate_field():
print(" PASS\n") print(" PASS\n")
def test_colormap_adjust():
print("=== Test: ColormapAdjust ===")
from backend.nodes.modify import ColormapAdjust
node = ColormapAdjust()
field = DataField(
data=np.array([[0.0, 0.25, 0.5, 0.75, 1.0]], dtype=np.float64),
xreal=5.0,
yreal=1.0,
colormap="gray",
)
adjusted, = node.process(field, offset=0.25, scale=0.5)
assert np.array_equal(adjusted.data, field.data)
assert adjusted.display_offset == 0.25
assert adjusted.display_scale == 0.5
assert adjusted.colormap == field.colormap
rgb = datafield_to_uint8(adjusted, "gray")
intensities = rgb[0, :, 0]
assert intensities[0] == 0
assert intensities[1] == 0
assert 110 <= intensities[2] <= 145
assert intensities[3] == 255
assert intensities[4] == 255
auto_like, = node.process(field, offset=0.0, scale=1.0)
auto_rgb = datafield_to_uint8(auto_like, "gray")
auto_intensities = auto_rgb[0, :, 0]
assert auto_intensities[0] == 0
assert auto_intensities[-1] == 255
try:
node.process(field, offset=0.0, scale=0.0)
raise AssertionError("Expected non-positive scale to raise ValueError")
except ValueError:
pass
print(" PASS\n")
def test_edge_detect(): def test_edge_detect():
print("=== Test: EdgeDetect ===") print("=== Test: EdgeDetect ===")
from backend.nodes.filters import EdgeDetect from backend.nodes.filters import EdgeDetect
@@ -1263,6 +1304,59 @@ def test_line_math():
print(" PASS\n") print(" PASS\n")
# =========================================================================
# Analysis — TableMath
# =========================================================================
def test_table_math():
print("=== Test: TableMath ===")
from backend.nodes.analysis import TableMath
node = TableMath()
table = [
{"label": "a", "value": 1.0, "other": 10},
{"label": "b", "value": 5.0, "other": 20},
{"label": "c", "value": "3.0", "other": 30},
{"label": "d", "value": "bad", "other": 40},
]
result, = node.process(table, column="value", operation="max")
assert result == 5.0
result, = node.process(table, column="value", operation="min")
assert result == 1.0
result, = node.process(table, column="value", operation="avg")
assert np.isclose(result, 3.0)
result, = node.process(table, column="value", operation="median")
assert np.isclose(result, 3.0)
result, = node.process(table, column="other", operation="sum")
assert result == 100.0
result, = node.process(table, column="other", operation="count")
assert result == 4.0
# Blank column name should fall back to the common "value" column.
result, = node.process(table, column="", operation="range")
assert result == 4.0
try:
node.process(table, column="missing", operation="max")
raise AssertionError("Expected missing numeric column to raise ValueError")
except ValueError:
pass
try:
node.process([{"label": "only text"}], column="label", operation="max")
raise AssertionError("Expected non-numeric column to raise ValueError")
except ValueError:
pass
print(" PASS\n")
# ========================================================================= # =========================================================================
# Display — View3D # Display — View3D
# ========================================================================= # =========================================================================
@@ -1322,6 +1416,7 @@ if __name__ == "__main__":
test_median_filter() test_median_filter()
test_crop_resize_field() test_crop_resize_field()
test_rotate_field() test_rotate_field()
test_colormap_adjust()
test_edge_detect() test_edge_detect()
test_fft_filter_1d() test_fft_filter_1d()
test_fft_filter_2d() test_fft_filter_2d()
@@ -1338,6 +1433,7 @@ if __name__ == "__main__":
test_line_cursors() test_line_cursors()
test_fft2d() test_fft2d()
test_line_math() test_line_math()
test_table_math()
# Mask # Mask
test_threshold_mask() test_threshold_mask()