diff --git a/GWYDDION_FEATURE_GAP.md b/GWYDDION_FEATURE_GAP.md index e56e29b..965d3db 100644 --- a/GWYDDION_FEATURE_GAP.md +++ b/GWYDDION_FEATURE_GAP.md @@ -10,17 +10,17 @@ Reference for future implementation. Grouped by value to typical SPM workflows. |---|---------|---------------|-------------| | ~~1~~ | ~~Line Correction~~ | ~~linecorrect.c, linematch.c~~ | ~~Row-by-row median/polynomial alignment. Essential for raw SPM data with scan-line artifacts.~~ **DONE** | | ~~2~~ | ~~Scar Removal~~ | ~~scars.c~~ | ~~Detect and interpolate scan-line defects (horizontal streaks).~~ **DONE** | -| 3 | Facet Leveling | facet-level.c | Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces. | +| ~~3~~ | ~~Facet Leveling~~ | ~~facet-level.c~~ | ~~Orient the dominant surface facet to horizontal. Better than plane level for terraced/stepped surfaces.~~ **DONE** | | ~~4~~ | ~~Morphological Mask Ops~~ | ~~mask_morph.c~~ | ~~Erode, dilate, open, close on grain masks. Needed to clean up thresholded masks.~~ **DONE** | | ~~5~~ | ~~1D FFT Filter~~ | ~~fft_filter_1d.c~~ | ~~Bandpass/lowpass/highpass filtering of LINE profiles.~~ **DONE** | | ~~6~~ | ~~2D FFT Filter~~ | ~~fft_filter_2d.c~~ | ~~Frequency-domain filtering of DATA_FIELDs (remove periodic noise, etc.).~~ **DONE** | | ~~7~~ | ~~Autocorrelation (ACF)~~ | ~~acf2d.c~~ | ~~2D autocorrelation function. Reveals periodic structures and correlation lengths.~~ **DONE** | | ~~8~~ | ~~PSDF~~ | ~~psdf2d.c~~ | ~~Radial/2D power spectral density function. Complementary to ACF for roughness characterization.~~ **DONE** | -| 9 | Fractal Dimension | fractal.c | Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity. | -| 10 | Curvature | curvature.c | Local mean/Gaussian curvature maps. Useful for feature identification. | -| 11 | Grain Distance Transform | mask_edt.c | Euclidean distance from grain boundaries. Useful for spatial distribution analysis. | -| 12 | Watershed Segmentation | grain_wshed.c | Automatic grain detection without manual threshold. More robust than simple thresholding. | -| 13 | Rotate / Flip | rotate.c, basicops.c | Basic geometric transforms (90°, arbitrary angle, mirror). | +| ~~9~~ | ~~Fractal Dimension~~ | ~~fractal.c~~ | ~~Multiple methods: partitioning, cube counting, triangulation, PSDF, HHCF. Quantifies surface complexity.~~ **DONE** | +| ~~10~~ | ~~Curvature~~ | ~~curvature.c~~ | ~~Quadratic-surface curvature fit with principal radii/directions. Useful for apex and dome characterization.~~ **DONE** | +| ~~11~~ | ~~Grain Distance Transform~~ | ~~mask_edt.c~~ | ~~Euclidean distance from grain boundaries. Useful for spatial distribution analysis.~~ **DONE** | +| ~~12~~ | ~~Watershed Segmentation~~ | ~~grain_wshed.c~~ | ~~Automatic grain detection without manual threshold. More robust than simple thresholding.~~ **DONE** | +| ~~13~~ | ~~Rotate / Flip~~ | ~~rotate.c, basicops.c~~ | ~~Basic geometric transforms (90°, arbitrary angle, mirror).~~ **DONE** | | ~~14~~ | ~~Crop~~ | ~~crop.c~~ | ~~Extract sub-region of a field.~~ **DONE** | ## Medium Value @@ -70,7 +70,10 @@ For reference, these Gwyddion equivalents are already covered: | Load Image / Load SPM File | io | File import (gwy, sxm, ibw) | | Save Image | io | File export | | Coordinate | io | — | +| Rotate Field | modify | rotate.c | +| Flip Field | modify | basicops.c | | Plane Level | level | level.c | +| Facet Level | level | facet-level.c | | Polynomial Level | level | polylevel.c | | Fix Zero | level | level.c (fix_zero) | | Line Correction | level | linecorrect.c, linematch.c | @@ -81,6 +84,8 @@ For reference, these Gwyddion equivalents are already covered: | 2D FFT Filter | filters | fft_filter_2d.c (lowpass, highpass, bandpass, notch) | | Scar Removal | filters | scars.c | | Statistics | analysis | stats.c | +| Curvature | analysis | curvature.c | +| Fractal Dimension | analysis | fractal.c | | Height Histogram | analysis | linestats.c (dh) | | 2D FFT | analysis | fft.c | | Cross Section | analysis | profile tool | @@ -90,5 +95,7 @@ For reference, these Gwyddion equivalents are already covered: | Mask Morphology | mask | mask_morph.c (erode, dilate, open, close) | | Mask Invert | mask | — | | Mask Combine | mask | — (boolean AND, OR, XOR, subtract) | +| Grain Distance Transform | mask | mask_edt.c | +| Watershed Segmentation | particles | grain_wshed.c | | Particle Analysis | particles | grain_stat.c | | Preview / 3D View / Print Table | display | Presentation, 3D view | diff --git a/backend/node_menu.py b/backend/node_menu.py index 631a112..e718e73 100644 --- a/backend/node_menu.py +++ b/backend/node_menu.py @@ -57,13 +57,16 @@ MENU_LAYOUT: dict[str, list[str]] = { ], "Flatten": [ "PlaneLevelField", + "FacetLevelField", "PolyLevelField", "FixZero", "LineCorrection", ], "Measure": [ "CrossSection", + "Curvature", "Histogram", + "FractalDimension", "ACF", "Cursors", "Statistics", @@ -75,8 +78,10 @@ MENU_LAYOUT: dict[str, list[str]] = { "MaskMorphology", "MaskInvert", "MaskCombine", + "GrainDistanceTransform", ], "Particles": [ + "WatershedSegmentation", "ParticleAnalysis", ], } diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index de5f40d..1e1d1a7 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -23,6 +23,7 @@ from backend.nodes import ( flip_field, # Level plane_level_field, + facet_level_field, poly_level_field, fix_zero, line_correction, @@ -32,6 +33,7 @@ from backend.nodes import ( mask_morphology, mask_invert, mask_combine, + grain_distance_transform, # Correction scar_removal, # Display @@ -45,6 +47,8 @@ from backend.nodes import ( print_table, value_display, # Analysis + curvature, + fractal_dimension, statistics_node, histogram, acf, @@ -54,6 +58,7 @@ from backend.nodes import ( inverse_fft_2d, cross_section, stats, + watershed_segmentation, ) try: diff --git a/backend/nodes/curvature.py b/backend/nodes/curvature.py new file mode 100644 index 0000000..3c5aecf --- /dev/null +++ b/backend/nodes/curvature.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +from scipy.ndimage import map_coordinates + +from backend.data_types import ( + DataField, + LineData, + MeasureTable, + _apply_markup_overlay, + encode_preview, + render_datafield_preview, +) +from backend.execution_context import emit_preview, emit_table, emit_warning +from backend.node_registry import register_node +from backend.nodes.surface_common import require_compatible_xy_z_units + +_CURVATURE_COLOR = "#ff9800" +_CENTER_COLOR = "#8bd3ff" + + +@dataclass(frozen=True) +class _Intersection: + t: float + x: float + y: float + + +def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None: + if mask is None: + return None + + mask_array = np.asarray(mask) + if mask_array.shape[:2] != shape: + raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.") + return mask_array > 127 + + +def _canonicalize_half_pi(angle: float) -> float: + wrapped = (float(angle) + 0.5 * np.pi) % np.pi - 0.5 * np.pi + if wrapped <= -0.5 * np.pi + 1e-15: + wrapped += np.pi + return float(wrapped) + + +def _fit_quadratic_surface(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray | None: + yres, xres = data.shape + yy, xx = np.mgrid[0:yres, 0:xres] + + x = 2.0 * xx.astype(np.float64) / max(xres - 1, 1) - 1.0 + y = 2.0 * yy.astype(np.float64) / max(yres - 1, 1) - 1.0 + + valid = np.ones(data.shape, dtype=bool) + if mask is not None and masking != "ignore": + valid = mask if masking == "include" else ~mask + + if np.count_nonzero(valid) < 6: + return None + + design = np.column_stack([ + np.ones(int(np.count_nonzero(valid)), dtype=np.float64), + x[valid], + x[valid] ** 2, + y[valid], + x[valid] * y[valid], + y[valid] ** 2, + ]) + coeffs, _, _, _ = np.linalg.lstsq(design, np.asarray(data, dtype=np.float64)[valid], rcond=None) + return np.asarray(coeffs, dtype=np.float64) + + +def _curvature_at_apex(coeffs: np.ndarray) -> tuple[int, float, float, float, float, float, float, float]: + a, bx, by, cxx, cxy, cyy = [float(value) for value in coeffs] + + if abs(cxx) + abs(cxy) + abs(cyy) <= 1e-14 * (abs(bx) + abs(by)): + return 0, 0.0, 0.0, 0.0, float(0.5 * np.pi), 0.0, 0.0, a + + cm = cxx - cyy + cp = cxx + cyy + phi = 0.5 * float(np.arctan2(cxy, cm)) + radius = float(np.hypot(cm, cxy)) + cx = cp + radius + cy = cp - radius + cos_phi = float(np.cos(phi)) + sin_phi = float(np.sin(phi)) + bx1 = bx * cos_phi + by * sin_phi + by1 = -bx * sin_phi + by * cos_phi + + if abs(cx) < 1e-14 * abs(cy): + xc = 0.0 + yc = -by1 / cy + degree = 1 + elif abs(cy) < 1e-14 * abs(cx): + xc = -bx1 / cx + yc = 0.0 + degree = 1 + else: + xc = -bx1 / cx + yc = -by1 / cy + degree = 2 + + x_center = xc * cos_phi - yc * sin_phi + y_center = xc * sin_phi + yc * cos_phi + z_center = a + xc * bx1 + yc * by1 + xc * xc * cx + yc * yc * cy + + if cx > cy: + cx, cy = cy, cx + phi += 0.5 * np.pi + phi = -phi + + phi1 = _canonicalize_half_pi(phi) + phi2 = _canonicalize_half_pi(phi + 0.5 * np.pi) + return degree, float(cx), float(cy), phi1, phi2, float(x_center), float(y_center), float(z_center) + + +def _compute_curvature_results( + field: DataField, + mask: np.ndarray | None, + masking: str, +) -> dict[str, float] | None: + coeffs = _fit_quadratic_surface(np.asarray(field.data, dtype=np.float64), mask, masking) + if coeffs is None: + return None + + xres = field.xres + yres = field.yres + xreal = float(field.xreal) + yreal = float(field.yreal) + qx = 2.0 / xreal * xres / max(xres - 1.0, 1.0) + qy = 2.0 / yreal * yres / max(yres - 1.0, 1.0) + q = float(np.sqrt(qx * qy)) + mx = float(np.sqrt(qx / qy)) + my = float(np.sqrt(qy / qx)) + + ccoeffs = np.array([ + coeffs[0], + mx * coeffs[1], + my * coeffs[3], + mx * mx * coeffs[2], + coeffs[4], + my * my * coeffs[5], + ], dtype=np.float64) + degree, kappa1, kappa2, phi1, phi2, xc, yc, zc = _curvature_at_apex(ccoeffs) + x_norm = xc * mx + y_norm = yc * my + zc = float( + coeffs[0] + + coeffs[1] * x_norm + + coeffs[2] * x_norm * x_norm + + coeffs[3] * y_norm + + coeffs[4] * x_norm * y_norm + + coeffs[5] * y_norm * y_norm + ) + + r1 = float("inf") if abs(kappa1) <= 1e-14 else float(1.0 / (q * q * kappa1)) + r2 = float("inf") if abs(kappa2) <= 1e-14 else float(1.0 / (q * q * kappa2)) + x0 = float(xc / q + 0.5 * xreal + field.xoff) + y0 = float(yc / q + 0.5 * yreal + field.yoff) + + return { + "degree": float(degree), + "x0": x0, + "y0": y0, + "z0": float(zc), + "r1": r1, + "r2": r2, + "phi1": float(phi1), + "phi2": float(phi2), + } + + +def _line_intersections( + x0: float, + y0: float, + phi: float, + x_min: float, + y_min: float, + width: float, + height: float, +) -> tuple[_Intersection, _Intersection] | None: + dx = float(np.cos(phi)) + dy = float(np.sin(phi)) + points: list[_Intersection] = [] + eps = 1e-12 + x_max = x_min + width + y_max = y_min + height + + if abs(dx) > eps: + for x in (x_min, x_max): + t = (x - x0) / dx + y = y0 + t * dy + if y_min - eps <= y <= y_max + eps: + points.append(_Intersection(float(t), float(np.clip(x, x_min, x_max)), float(np.clip(y, y_min, y_max)))) + if abs(dy) > eps: + for y in (y_min, y_max): + t = (y - y0) / dy + x = x0 + t * dx + if x_min - eps <= x <= x_max + eps: + points.append(_Intersection(float(t), float(np.clip(x, x_min, x_max)), float(np.clip(y, y_min, y_max)))) + + unique: list[_Intersection] = [] + for point in sorted(points, key=lambda item: item.t): + if unique and abs(point.x - unique[-1].x) < 1e-9 and abs(point.y - unique[-1].y) < 1e-9: + continue + unique.append(point) + + if len(unique) < 2: + return None + return unique[0], unique[-1] + + +def _profile_from_intersections(field: DataField, start: _Intersection, end: _Intersection) -> LineData: + x_start = start.x - field.xoff + y_start = start.y - field.yoff + x_end = end.x - field.xoff + y_end = end.y - field.yoff + + px1 = x_start / max(field.xreal, 1e-30) * max(field.xres - 1, 0) + py1 = y_start / max(field.yreal, 1e-30) * max(field.yres - 1, 0) + px2 = x_end / max(field.xreal, 1e-30) * max(field.xres - 1, 0) + py2 = y_end / max(field.yreal, 1e-30) * max(field.yres - 1, 0) + n_samples = max(2, int(np.ceil(np.hypot(px2 - px1, py2 - py1)))) + + t = np.linspace(0.0, 1.0, n_samples, dtype=np.float64) + coords_y = py1 + t * (py2 - py1) + coords_x = px1 + t * (px2 - px1) + profile = map_coordinates(field.data, [coords_y, coords_x], order=1, mode="nearest") + + axis = np.linspace(start.t, end.t, n_samples, dtype=np.float64) + return LineData(data=np.asarray(profile, dtype=np.float64), x_axis=axis, x_unit=field.si_unit_xy, y_unit=field.si_unit_z) + + +def _curvature_markup( + field: DataField, + center_x: float, + center_y: float, + intersections: list[tuple[_Intersection, _Intersection]], +) -> dict[str, object]: + shapes: list[dict[str, object]] = [] + for start, end in intersections: + shapes.append({ + "kind": "line", + "x1": (start.x - field.xoff) / max(field.xreal, 1e-30), + "y1": (start.y - field.yoff) / max(field.yreal, 1e-30), + "x2": (end.x - field.xoff) / max(field.xreal, 1e-30), + "y2": (end.y - field.yoff) / max(field.yreal, 1e-30), + "width": 3, + "color": _CURVATURE_COLOR, + }) + + if np.isfinite(center_x) and np.isfinite(center_y): + radius = 0.015 + fx = (center_x - field.xoff) / max(field.xreal, 1e-30) + fy = (center_y - field.yoff) / max(field.yreal, 1e-30) + shapes.append({ + "kind": "circle", + "x1": fx - radius, + "y1": fy - radius, + "x2": fx + radius, + "y2": fy + radius, + "width": 2, + "color": _CENTER_COLOR, + }) + + return {"kind": "markup", "shapes": shapes} + + +def _empty_profile(unit_xy: str, unit_z: str) -> LineData: + return LineData(data=np.zeros(0, dtype=np.float64), x_axis=np.zeros(0, dtype=np.float64), x_unit=unit_xy, y_unit=unit_z) + + +@register_node(display_name="Curvature") +class Curvature: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "masking": (["ignore", "include", "exclude"], {"default": "ignore"}), + }, + "optional": { + "mask": ("IMAGE",), + }, + } + + RETURN_TYPES = ("ANNOTATION_SOURCE", "MEASURE_TABLE", "LINE", "LINE") + RETURN_NAMES = ("output", "measurements", "profile 1", "profile 2") + FUNCTION = "process" + + DESCRIPTION = ( + "Fit a quadratic surface and report the overall principal curvature radii and directions, matching " + "Gwyddion's curvature feature. The output annotation marks the principal cross-sections and the node " + "also returns the two corresponding height profiles." + ) + + def process( + self, + field: DataField, + masking: str, + mask: np.ndarray | None = None, + ) -> tuple: + require_compatible_xy_z_units(field, "Curvature") + mask_array = _normalize_mask(mask, field.data.shape) + results = _compute_curvature_results(field, mask_array, masking) + + if results is None: + emit_warning("Curvature requires at least six usable pixels for the quadratic fit.") + table = MeasureTable([]) + emit_table(table) + emit_preview(encode_preview(render_datafield_preview(field, field.colormap))) + empty = _empty_profile(field.si_unit_xy, field.si_unit_z) + return (field.replace(), table, empty, empty) + + intersections: list[tuple[_Intersection, _Intersection]] = [] + warnings: list[str] = [] + for angle_key in ("phi1", "phi2"): + hit = _line_intersections( + results["x0"], + results["y0"], + -results[angle_key], + field.xoff, + field.yoff, + field.xreal, + field.yreal, + ) + if hit is None: + warnings.append("Principal axes are outside the image.") + else: + intersections.append(hit) + + profiles = [] + for pair in intersections[:2]: + profiles.append(_profile_from_intersections(field, pair[0], pair[1])) + while len(profiles) < 2: + profiles.append(_empty_profile(field.si_unit_xy, field.si_unit_z)) + + markup_spec = _curvature_markup(field, results["x0"], results["y0"], intersections) + output = field.replace(overlays=[*field.overlays, markup_spec]) + + table = MeasureTable([ + {"quantity": "Center x position", "value": float(results["x0"]), "unit": field.si_unit_xy}, + {"quantity": "Center y position", "value": float(results["y0"]), "unit": field.si_unit_xy}, + {"quantity": "Center value", "value": float(results["z0"]), "unit": field.si_unit_z}, + {"quantity": "Curvature radius 1", "value": float(results["r1"]), "unit": field.si_unit_xy}, + {"quantity": "Curvature radius 2", "value": float(results["r2"]), "unit": field.si_unit_xy}, + {"quantity": "Direction 1", "value": float(np.degrees(results["phi1"])), "unit": "deg"}, + {"quantity": "Direction 2", "value": float(np.degrees(results["phi2"])), "unit": "deg"}, + ]) + + preview_base = render_datafield_preview(field, field.colormap) + emit_preview(encode_preview(_apply_markup_overlay(preview_base, field, markup_spec))) + emit_table(table) + if warnings: + emit_warning(warnings[0]) + + return (output, table, profiles[0], profiles[1]) diff --git a/backend/nodes/facet_level_field.py b/backend/nodes/facet_level_field.py new file mode 100644 index 0000000..421bc56 --- /dev/null +++ b/backend/nodes/facet_level_field.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import numpy as np + +from backend.data_types import DataField +from backend.node_registry import register_node +from backend.nodes.surface_common import require_compatible_xy_z_units + + +def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None: + if mask is None: + return None + + mask_array = np.asarray(mask) + if mask_array.shape[:2] != shape: + raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.") + return mask_array > 127 + + +def _facet_cell_mask(mask: np.ndarray | None, masking: str, shape: tuple[int, int]) -> np.ndarray: + yres, xres = shape + if yres < 2 or xres < 2: + return np.zeros((0, 0), dtype=bool) + + if mask is None or masking == "ignore": + return np.ones((yres - 1, xres - 1), dtype=bool) + + m00 = mask[:-1, :-1] + m01 = mask[:-1, 1:] + m10 = mask[1:, :-1] + m11 = mask[1:, 1:] + + if masking == "include": + return m00 & m01 & m10 & m11 + if masking == "exclude": + return ~(m00 | m01 | m10 | m11) + raise ValueError(f"Unknown masking mode: {masking}") + + +def _fit_facet_plane( + data: np.ndarray, + dx: float, + dy: float, + mask: np.ndarray | None, + masking: str, +) -> tuple[bool, float, float, float]: + yres, xres = data.shape + if yres < 2 or xres < 2: + return False, 0.0, 0.0, 0.0 + + dx = float(dx) if float(dx) > 0.0 else 1.0 + dy = float(dy) if float(dy) > 0.0 else 1.0 + + valid = _facet_cell_mask(mask, masking, data.shape) + nvalid = int(np.count_nonzero(valid)) + if nvalid < 4: + return False, 0.0, 0.0, 0.0 + + z00 = data[:-1, :-1] + z01 = data[:-1, 1:] + z10 = data[1:, :-1] + z11 = data[1:, 1:] + + vx = 0.5 * (z11 + z01 - z10 - z00) / dx + vy = 0.5 * (z10 + z11 - z00 - z01) / dy + mag2 = vx * vx + vy * vy + + sigma2 = float((1.0 / 20.0) * np.mean(mag2[valid])) + if not np.isfinite(sigma2) or sigma2 <= 0.0: + return True, 0.0, 0.0, 0.0 + + weights = np.exp(-mag2[valid] / sigma2) + sumvz = float(np.sum(weights)) + if not np.isfinite(sumvz) or sumvz <= 0.0: + return True, 0.0, 0.0, 0.0 + + pbx = float(np.sum(vx[valid] * weights) / sumvz * dx) + pby = float(np.sum(vy[valid] * weights) / sumvz * dy) + pa = float(-0.5 * (pbx * xres + pby * yres)) + return True, pa, pbx, pby + + +def _subtract_plane(data: np.ndarray, a: float, bx: float, by: float) -> np.ndarray: + yy, xx = np.mgrid[0:data.shape[0], 0:data.shape[1]] + return np.asarray(data, dtype=np.float64) - (float(a) + float(bx) * xx + float(by) * yy) + + +def _facet_level_data( + field: DataField, + mask: np.ndarray | None, + masking: str, + *, + max_iterations: int = 100, + eps: float = 1e-9, +) -> np.ndarray: + working = np.asarray(field.data, dtype=np.float64).copy() + + for _ in range(max(1, int(max_iterations))): + ok, a, bx, by = _fit_facet_plane(working, field.dx, field.dy, mask, masking) + if not ok: + return np.asarray(field.data, dtype=np.float64).copy() + + working = _subtract_plane(working, a, bx, by) + slope_x = float(bx) / (field.dx if field.dx > 0.0 else 1.0) + slope_y = float(by) / (field.dy if field.dy > 0.0 else 1.0) + if slope_x * slope_x + slope_y * slope_y < float(eps): + break + + return working + + +@register_node(display_name="Facet Level") +class FacetLevelField: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "masking": (["exclude", "include", "ignore"], {"default": "exclude"}), + }, + "optional": { + "mask": ("IMAGE",), + }, + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("leveled",) + FUNCTION = "process" + + DESCRIPTION = ( + "Level a field by iteratively finding the dominant local facet orientation and subtracting the " + "corresponding plane, matching Gwyddion's facet-level behaviour. Supports mask include/exclude " + "selection and expects topographic data with compatible XY and Z units." + ) + + def process( + self, + field: DataField, + masking: str, + mask: np.ndarray | None = None, + ) -> tuple: + require_compatible_xy_z_units(field, "Facet Level") + mask_array = _normalize_mask(mask, field.data.shape) + leveled = _facet_level_data(field, mask_array, masking, max_iterations=100) + return (field.replace(data=leveled),) diff --git a/backend/nodes/fractal_dimension.py b/backend/nodes/fractal_dimension.py new file mode 100644 index 0000000..7090238 --- /dev/null +++ b/backend/nodes/fractal_dimension.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +from scipy.ndimage import map_coordinates + +from backend.data_types import LineData, MeasureTable +from backend.execution_context import emit_overlay, emit_table, emit_warning +from backend.node_registry import register_node + + +_LOG_TINY = float(np.finfo(np.float64).tiny) + + +@dataclass(frozen=True) +class _FractalMethod: + display_name: str + x_label: str + y_label: str + + +_METHODS: dict[str, _FractalMethod] = { + "partitioning": _FractalMethod("Partitioning", "log h", "log S"), + "cube_counting": _FractalMethod("Cube counting", "log h", "log N"), + "triangulation": _FractalMethod("Triangulation", "log h", "log A"), + "psdf": _FractalMethod("Power spectrum", "log k", "log W"), + "hhcf": _FractalMethod("Structure function", "log h", "log H"), +} + + +def _clamp01(value: float) -> float: + return float(np.clip(value, 0.0, 1.0)) + + +def _resample_square(data: np.ndarray, size: int, interpolation: str) -> np.ndarray: + source = np.asarray(data, dtype=np.float64) + if source.shape == (size, size): + return source.copy() + + order_map = {"nearest": 0, "linear": 1, "cubic": 3} + if interpolation not in order_map: + raise ValueError(f"Unknown interpolation mode: {interpolation}") + + yres, xres = source.shape + yy, xx = np.meshgrid( + np.linspace(0.0, max(yres - 1, 0), size, dtype=np.float64), + np.linspace(0.0, max(xres - 1, 0), size, dtype=np.float64), + indexing="ij", + ) + return np.asarray( + map_coordinates( + source, + [yy, xx], + order=order_map[interpolation], + mode="nearest", + prefilter=order_map[interpolation] > 1, + ), + dtype=np.float64, + ) + + +def _safe_log(values: np.ndarray | float) -> np.ndarray | float: + return np.log(np.clip(values, _LOG_TINY, None)) + + +def _fit_line(x: np.ndarray, y: np.ndarray) -> tuple[float, float]: + coeffs = np.polyfit(np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64), 1) + return float(coeffs[0]), float(coeffs[1]) + + +def _row_level2(row: np.ndarray) -> np.ndarray: + values = np.asarray(row, dtype=np.float64) + if values.size <= 1: + return values.copy() + x = np.linspace(-1.0, 1.0, values.size, dtype=np.float64) + A = np.column_stack((np.ones_like(x), x)) + coeffs, _, _, _ = np.linalg.lstsq(A, values, rcond=None) + return values - (coeffs[0] + coeffs[1] * x) + + +def _hann_window(size: int) -> np.ndarray: + if size <= 0: + return np.ones(0, dtype=np.float64) + t = (np.arange(size, dtype=np.float64) + 0.5) / float(size) + return 0.5 - 0.5 * np.cos(2.0 * np.pi * t) + + +def _window_with_rms_compensation(values: np.ndarray, window: np.ndarray) -> np.ndarray: + row = np.asarray(values, dtype=np.float64) + rms = float(np.sqrt(np.mean(row * row))) + weighted = row * window + new_rms = float(np.sqrt(np.mean(weighted * weighted))) + if rms > 0.0 and new_rms > 0.0: + weighted *= rms / new_rms + return weighted + + +def _fractal_partitioning(data: np.ndarray, interpolation: str) -> tuple[np.ndarray, np.ndarray]: + xres = int(data.shape[1]) + dimexp = int(np.floor(np.log(float(max(xres, 2))) / np.log(2.0) + 0.5)) + if dimexp < 2: + return np.zeros(0, dtype=np.float64), np.zeros(0, dtype=np.float64) + + size = (1 << dimexp) + 1 + buffer = _resample_square(data, size, interpolation) + xvals = np.empty(dimexp - 1, dtype=np.float64) + yvals = np.empty(dimexp - 1, dtype=np.float64) + + for l in range(1, dimexp): + rp = 1 << l + nx = (size - 1) // rp - 1 + ny = (size - 1) // rp - 1 + accum = 0.0 + for i in range(nx): + for j in range(ny): + block = buffer[j * rp:j * rp + rp, i * rp:i * rp + rp] + rms = float(np.std(block, ddof=0)) + accum += rms * rms + xvals[l - 1] = np.log(float(rp)) + denom = max(nx * ny, 1) + yvals[l - 1] = float(_safe_log(accum / denom)) + + return xvals, yvals + + +def _fractal_cube_counting(data: np.ndarray, interpolation: str) -> tuple[np.ndarray, np.ndarray]: + xres = int(data.shape[1]) + dimexp = int(np.floor(np.log(float(max(xres, 2))) / np.log(2.0) + 0.5)) + if dimexp < 1: + return np.zeros(0, dtype=np.float64), np.zeros(0, dtype=np.float64) + + size = (1 << dimexp) + 1 + buffer = _resample_square(data, size, interpolation) + imin = float(np.min(buffer)) + height = float(np.max(buffer) - imin) + if not np.isfinite(height) or height <= 0.0: + height = _LOG_TINY + + xvals = np.empty(dimexp, dtype=np.float64) + yvals = np.empty(dimexp, dtype=np.float64) + + for l in range(dimexp): + rp = 1 << (l + 1) + rp2 = (1 << dimexp) // rp + a = max(height / rp, _LOG_TINY) + accum = 0.0 + for i in range(rp): + for j in range(rp): + block = buffer[j * rp2:j * rp2 + rp2 + 1, i * rp2:i * rp2 + rp2 + 1] - imin + maxv = float(np.max(block)) + minv = float(np.min(block)) + accum += rp - np.floor(minv / a) - np.floor((height - maxv) / a) + xvals[l] = float((l + 1 - dimexp) * np.log(2.0)) + yvals[l] = float(_safe_log(accum)) + + return xvals, yvals + + +def _fractal_triangulation(data: np.ndarray, interpolation: str) -> tuple[np.ndarray, np.ndarray]: + xres = int(data.shape[1]) + dimexp = int(np.floor(np.log(float(max(xres, 2))) / np.log(2.0) + 0.5)) + if dimexp < 0: + return np.zeros(0, dtype=np.float64), np.zeros(0, dtype=np.float64) + + size = (1 << dimexp) + 1 + buffer = _resample_square(data, size, interpolation) + height = float(np.max(buffer) - np.min(buffer)) + if not np.isfinite(height) or height <= 0.0: + height = _LOG_TINY + dil = float((1 << dimexp) / height) + dil *= dil + + xvals = np.empty(dimexp + 1, dtype=np.float64) + yvals = np.empty(dimexp + 1, dtype=np.float64) + + for l in range(dimexp + 1): + rp = 1 << l + rp2 = (1 << dimexp) // rp + accum = 0.0 + for i in range(rp): + for j in range(rp): + z1 = float(buffer[j * rp2, i * rp2]) + z2 = float(buffer[j * rp2, (i + 1) * rp2]) + z3 = float(buffer[(j + 1) * rp2, i * rp2]) + z4 = float(buffer[(j + 1) * rp2, (i + 1) * rp2]) + + a = float(np.sqrt(rp2 * rp2 + dil * (z1 - z2) * (z1 - z2))) + b = float(np.sqrt(rp2 * rp2 + dil * (z1 - z3) * (z1 - z3))) + c = float(np.sqrt(rp2 * rp2 + dil * (z3 - z4) * (z3 - z4))) + d = float(np.sqrt(rp2 * rp2 + dil * (z2 - z4) * (z2 - z4))) + e = float(np.sqrt(2.0 * rp2 * rp2 + dil * (z3 - z2) * (z3 - z2))) + + s1 = 0.5 * (a + b + e) + s2 = 0.5 * (c + d + e) + term1 = max(s1 * (s1 - a) * (s1 - b) * (s1 - e), 0.0) + term2 = max(s2 * (s2 - c) * (s2 - d) * (s2 - e), 0.0) + accum += np.sqrt(term1) + np.sqrt(term2) + xvals[l] = float((l - dimexp) * np.log(2.0)) + yvals[l] = float(_safe_log(accum)) + + return xvals, yvals + + +def _fractal_psdf(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + rows, width = data.shape + if width < 2 or rows < 1: + return np.zeros(0, dtype=np.float64), np.zeros(0, dtype=np.float64) + + window = _hann_window(width) + accum = np.zeros(width // 2 + 1, dtype=np.float64) + for row in np.asarray(data, dtype=np.float64): + leveled = _row_level2(row) + weighted = _window_with_rms_compensation(leveled, window) + spectrum = np.fft.rfft(weighted) + accum += np.abs(spectrum) ** 2 + accum /= float(rows) + + indices = np.arange(1, accum.size, dtype=np.float64) + return np.log(indices), _safe_log(accum[1:]) + + +def _fractal_hhcf(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + rows, width = data.shape + if width < 2 or rows < 1: + return np.zeros(0, dtype=np.float64), np.zeros(0, dtype=np.float64) + + accum = np.zeros(width, dtype=np.float64) + for row in np.asarray(data, dtype=np.float64): + leveled = _row_level2(row) + for lag in range(width): + if lag == 0: + accum[lag] += 0.0 + else: + diffs = leveled[lag:] - leveled[:-lag] + accum[lag] += float(np.mean(diffs * diffs)) if diffs.size else 0.0 + accum /= float(rows) + + outres = min(width - 1, (width + 5) // 10 + int(np.rint(np.sqrt(width)))) + indices = np.arange(1, outres + 1, dtype=np.float64) + return np.log(indices), _safe_log(accum[1:outres + 1]) + + +def _compute_method(field_data: np.ndarray, method: str, interpolation: str) -> tuple[np.ndarray, np.ndarray]: + if method == "partitioning": + return _fractal_partitioning(field_data, interpolation) + if method == "cube_counting": + return _fractal_cube_counting(field_data, interpolation) + if method == "triangulation": + return _fractal_triangulation(field_data, interpolation) + if method == "psdf": + return _fractal_psdf(field_data) + if method == "hhcf": + return _fractal_hhcf(field_data) + raise ValueError(f"Unknown fractal method: {method}") + + +def _dimension_from_slope(method: str, slope: float) -> float: + if method == "partitioning": + return 3.0 - slope / 2.0 + if method == "cube_counting": + return slope + if method == "triangulation": + return 2.0 + slope + if method == "psdf": + return 3.5 + slope / 2.0 + if method == "hhcf": + return 3.0 - slope / 2.0 + raise ValueError(f"Unknown fractal method: {method}") + + +def _select_fit_range(xvals: np.ndarray, x1: float, x2: float) -> tuple[np.ndarray, float, float]: + if xvals.size == 0: + return np.zeros(0, dtype=bool), 0.0, 0.0 + + xmin = float(np.min(xvals)) + xmax = float(np.max(xvals)) + if abs(float(x1) - float(x2)) < 1e-9: + return np.ones(xvals.size, dtype=bool), xmin, xmax + + lo_frac = min(float(x1), float(x2)) + hi_frac = max(float(x1), float(x2)) + lo = xmin + lo_frac * (xmax - xmin) + hi = xmin + hi_frac * (xmax - xmin) + mask = (xvals >= lo) & (xvals <= hi) + return mask, float(lo), float(hi) + + +@register_node(display_name="Fractal Dimension") +class FractalDimension: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "method": (list(_METHODS.keys()), {"default": "partitioning"}), + "interpolation": (["linear", "nearest", "cubic"], {"default": "linear"}), + "x1": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "x2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + "y2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "hidden": True}), + } + } + + RETURN_TYPES = ("FLOAT", "LINE", "MEASURE_TABLE") + RETURN_NAMES = ("dimension", "curve", "measurements") + FUNCTION = "process" + + DESCRIPTION = ( + "Calculate the surface fractal dimension using Gwyddion's partitioning, cube counting, triangulation, " + "power-spectrum, or HHCF methods. The in-node graph shows the log-log curve and lets you drag the fit range." + ) + + def process( + self, + field, + method: str, + interpolation: str, + x1: float, + y1: float, + x2: float, + y2: float, + ) -> tuple: + xvals, yvals = _compute_method(np.asarray(field.data, dtype=np.float64), method, interpolation) + finite = np.isfinite(xvals) & np.isfinite(yvals) + xvals = np.asarray(xvals[finite], dtype=np.float64) + yvals = np.asarray(yvals[finite], dtype=np.float64) + + line = LineData(data=yvals, x_axis=xvals, x_unit="", y_unit="") + + x1 = _clamp01(x1) + x2 = _clamp01(x2) + y1 = _clamp01(y1) + y2 = _clamp01(y2) + + fit_mask, fit_from, fit_to = _select_fit_range(xvals, x1, x2) + if np.count_nonzero(fit_mask) >= 2: + slope, intercept = _fit_line(xvals[fit_mask], yvals[fit_mask]) + dimension = _dimension_from_slope(method, slope) + else: + slope = float("nan") + intercept = float("nan") + dimension = float("nan") + emit_warning("Fractal fit range contains fewer than two usable points.") + + table = MeasureTable([ + {"quantity": "Dimension", "value": float(dimension), "unit": ""}, + {"quantity": "Fit slope", "value": float(slope), "unit": ""}, + {"quantity": "Fit intercept", "value": float(intercept), "unit": ""}, + {"quantity": "Fit from", "value": float(fit_from), "unit": ""}, + {"quantity": "Fit to", "value": float(fit_to), "unit": ""}, + ]) + + method_info = _METHODS[method] + emit_overlay({ + "kind": "line_plot", + "section_title": "Fractal Dimension", + "line": yvals.tolist(), + "x_axis": xvals.tolist(), + "x_label": method_info.x_label, + "y_label": method_info.y_label, + "x1": x1, + "x2": x2, + "y1": y1, + "y2": y2, + "a_locked": False, + "b_locked": False, + }) + emit_table(table) + + return (float(dimension), line, table) diff --git a/backend/nodes/grain_distance_transform.py b/backend/nodes/grain_distance_transform.py new file mode 100644 index 0000000..f4780c8 --- /dev/null +++ b/backend/nodes/grain_distance_transform.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from functools import lru_cache + +import numpy as np +from scipy.ndimage import binary_erosion, distance_transform_edt + +from backend.data_types import DataField +from backend.node_registry import register_node + + +def _normalize_mask(mask: np.ndarray) -> np.ndarray: + data = np.asarray(mask) + if data.ndim != 2: + raise ValueError("Grain Distance Transform requires a 2-D mask.") + return data > 127 + + +def _prepare_mask(binary: np.ndarray, from_border: bool) -> tuple[np.ndarray, tuple[slice, slice]]: + binary = np.asarray(binary, dtype=bool) + if from_border: + return binary, (slice(None), slice(None)) + + pad = max(binary.shape) + padded = np.pad(binary, pad, mode="constant", constant_values=True) + padded[0, :] = False + padded[-1, :] = False + padded[:, 0] = False + padded[:, -1] = False + return padded, (slice(pad, pad + binary.shape[0]), slice(pad, pad + binary.shape[1])) + + +@lru_cache(maxsize=32) +def _distance_structures() -> tuple[np.ndarray, np.ndarray]: + cross = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=bool) + square = np.ones((3, 3), dtype=bool) + cross.setflags(write=False) + square.setflags(write=False) + return cross, square + + +def _simple_distance_transform(binary: np.ndarray, distance_type: str, from_border: bool) -> np.ndarray: + work, crop = _prepare_mask(binary, from_border) + result = np.zeros(work.shape, dtype=np.float64) + current = work.copy() + cross, square = _distance_structures() + + if distance_type == "cityblock": + sequence = (cross,) + elif distance_type == "chess": + sequence = (square,) + elif distance_type == "octagonal48": + sequence = (cross, square) + elif distance_type == "octagonal84": + sequence = (square, cross) + else: + raise ValueError(f"Unsupported simple distance type: {distance_type}") + + step = 1.0 + iteration = 0 + while np.any(current): + structure = sequence[iteration % len(sequence)] + eroded = binary_erosion(current, structure=structure, border_value=0) + removed = current & ~eroded + result[removed] = step + current = eroded + step += 1.0 + iteration += 1 + + return result[crop] + + +def _euclidean_distance_transform(binary: np.ndarray, from_border: bool) -> np.ndarray: + if from_border: + work = np.pad(np.asarray(binary, dtype=bool), 1, mode="constant", constant_values=False) + return np.asarray(distance_transform_edt(work), dtype=np.float64)[1:-1, 1:-1] + + work, crop = _prepare_mask(binary, False) + return np.asarray(distance_transform_edt(work), dtype=np.float64)[crop] + + +def _distance_transform(binary: np.ndarray, distance_type: str, from_border: bool) -> np.ndarray: + if distance_type == "euclidean": + return _euclidean_distance_transform(binary, from_border) + if distance_type == "octagonal": + d48 = _simple_distance_transform(binary, "octagonal48", from_border) + d84 = _simple_distance_transform(binary, "octagonal84", from_border) + return 0.5 * (d48 + d84) + return _simple_distance_transform(binary, distance_type, from_border) + + +@register_node(display_name="Grain Distance Transform") +class GrainDistanceTransform: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "mask": ("IMAGE",), + "distance_type": (["euclidean", "cityblock", "chess", "octagonal48", "octagonal84", "octagonal"], {"default": "euclidean"}), + "output_type": (["interior", "exterior", "signed"], {"default": "interior"}), + "from_border": ("BOOLEAN", {"default": True}), + } + } + + RETURN_TYPES = ("DATA_FIELD",) + RETURN_NAMES = ("distance",) + FUNCTION = "process" + + DESCRIPTION = ( + "Compute the mask distance transform using Gwyddion-style interior, exterior, or signed output. " + "Supports Euclidean, city-block, chessboard, and octagonal distance variants, with optional " + "image-boundary handling matching mask_edt." + ) + + def process( + self, + field: DataField, + mask: np.ndarray, + distance_type: str, + output_type: str, + from_border: bool, + ) -> tuple: + binary = _normalize_mask(mask) + + interior = _distance_transform(binary, distance_type, bool(from_border)) + interior *= binary + + if output_type == "interior": + distance = interior + else: + exterior_binary = ~binary + exterior = _distance_transform(exterior_binary, distance_type, bool(from_border)) + exterior *= exterior_binary + if output_type == "exterior": + distance = exterior + elif output_type == "signed": + distance = interior - exterior + else: + raise ValueError(f"Unsupported output type: {output_type}") + + scale = float(np.sqrt(field.dx * field.dy)) + result = field.replace( + data=np.asarray(distance, dtype=np.float64) * scale, + si_unit_z=field.si_unit_xy, + ) + return (result,) diff --git a/backend/nodes/surface_common.py b/backend/nodes/surface_common.py new file mode 100644 index 0000000..06ecffb --- /dev/null +++ b/backend/nodes/surface_common.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from backend.data_types import DataField + + +_LENGTH_UNITS = {"m", "km", "cm", "mm", "um", "µm", "nm", "pm", "fm"} + + +def unit_dimension_key(unit: str) -> str: + text = str(unit or "").strip().replace("µ", "u") + if not text: + return "" + if text in _LENGTH_UNITS: + return "length" + return text + + +def require_compatible_xy_z_units(field: DataField, node_name: str) -> None: + xy_key = unit_dimension_key(field.si_unit_xy) + z_key = unit_dimension_key(field.si_unit_z) + if xy_key and z_key and xy_key != z_key: + raise ValueError(f"{node_name} requires compatible XY and Z units, matching Gwyddion's topography-only behavior.") diff --git a/backend/nodes/watershed_segmentation.py b/backend/nodes/watershed_segmentation.py new file mode 100644 index 0000000..ecd301a --- /dev/null +++ b/backend/nodes/watershed_segmentation.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +from functools import lru_cache + +import numpy as np +from scipy.ndimage import label + +from backend.execution_context import emit_preview +from backend.data_types import DataField, encode_preview +from backend.node_registry import register_node +from backend.nodes.helpers import _mask_overlay + + +def _working_height(field: DataField, invert_height: bool) -> np.ndarray: + data = np.asarray(field.data, dtype=np.float64) + return -data if invert_height else data.copy() + + +def _next_indices(data: np.ndarray) -> np.ndarray: + yres, xres = data.shape + flat_idx = np.arange(yres * xres, dtype=np.int64).reshape(yres, xres) + + right_val = np.full_like(data, -np.inf, dtype=np.float64) + right_val[:, :-1] = data[:, 1:] + left_val = np.full_like(data, -np.inf, dtype=np.float64) + left_val[:, 1:] = data[:, :-1] + down_val = np.full_like(data, -np.inf, dtype=np.float64) + down_val[:-1, :] = data[1:, :] + up_val = np.full_like(data, -np.inf, dtype=np.float64) + up_val[1:, :] = data[:-1, :] + + right_idx = flat_idx.copy() + right_idx[:, :-1] = flat_idx[:, 1:] + left_idx = flat_idx.copy() + left_idx[:, 1:] = flat_idx[:, :-1] + down_idx = flat_idx.copy() + down_idx[:-1, :] = flat_idx[1:, :] + up_idx = flat_idx.copy() + up_idx[1:, :] = flat_idx[:-1, :] + + next_idx = flat_idx.copy() + local = ( + (data >= right_val) + & (data >= left_val) + & (data >= down_val) + & (data >= up_val) + ) + + right_mask = (~local) & (right_val >= data) & (right_val >= left_val) & (right_val >= down_val) & (right_val >= up_val) + next_idx[right_mask] = right_idx[right_mask] + + unresolved = (~local) & (~right_mask) + left_mask = unresolved & (left_val >= data) & (left_val >= right_val) & (left_val >= down_val) & (left_val >= up_val) + next_idx[left_mask] = left_idx[left_mask] + + unresolved &= ~left_mask + down_mask = unresolved & (down_val >= data) & (down_val >= right_val) & (down_val >= left_val) & (down_val >= up_val) + next_idx[down_mask] = down_idx[down_mask] + + unresolved &= ~down_mask + next_idx[unresolved] = up_idx[unresolved] + return next_idx.ravel() + + +def _terminal_indices(data: np.ndarray) -> np.ndarray: + terminals = _next_indices(np.asarray(data, dtype=np.float64)) + while True: + jumped = terminals[terminals] + if np.array_equal(jumped, terminals): + return terminals + terminals = jumped + + +@lru_cache(maxsize=32) +def _source_order(shape: tuple[int, int]) -> np.ndarray: + yres, xres = shape + if yres < 3 or xres < 3: + return np.zeros(0, dtype=np.int64) + rows, cols = np.mgrid[1:yres - 1, 1:xres - 1] + order = (rows.ravel(order="F") * xres + cols.ravel(order="F")).astype(np.int64) + order.setflags(write=False) + return order + + +def _location_step(data: np.ndarray, water: np.ndarray, dropsize: float) -> None: + terminals = _terminal_indices(data) + ordered_sources = _source_order(data.shape) + counts = np.bincount(terminals[ordered_sources], minlength=data.size).astype(np.float64) + water += counts.reshape(data.shape) + data -= dropsize * counts.reshape(data.shape) + + +def _seed_labels(water: np.ndarray, threshold: int) -> np.ndarray: + structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.int8) + labeled, ngrains = label(water > 0.0, structure=structure) + if ngrains <= 0: + return np.zeros_like(labeled, dtype=np.int32) + + sizes = np.bincount(labeled.ravel(), minlength=ngrains + 1) + seeds = np.zeros_like(labeled, dtype=np.int32) + next_label = 1 + flat_water = water.ravel() + flat_labeled = labeled.ravel() + + for grain_id in range(1, ngrains + 1): + if int(sizes[grain_id]) <= int(threshold): + continue + indices = np.flatnonzero(flat_labeled == grain_id) + if indices.size == 0: + continue + peak_index = int(indices[np.argmax(flat_water[indices])]) + seeds.ravel()[peak_index] = next_label + next_label += 1 + + return seeds + + +def _process_mask(labels: np.ndarray, row: int, col: int) -> None: + yres, xres = labels.shape + if col == 0 or row == 0 or col == xres - 1 or row == yres - 1: + labels[row, col] = -1 + return + + if labels[row, col] != 0: + return + + left = int(labels[row, col - 1]) + up = int(labels[row - 1, col]) + right = int(labels[row, col + 1]) + down = int(labels[row + 1, col]) + + if abs(left) + abs(up) + abs(right) + abs(down) == 0: + return + + value = 0 + boundary = False + for candidate in (left, up, right, down): + if value > 0 and candidate > 0 and candidate != value: + boundary = True + break + if candidate > 0: + value = candidate + + labels[row, col] = -1 if boundary else value + + +def _watershed_step( + data: np.ndarray, + water: np.ndarray, + labels: np.ndarray, + seeds: np.ndarray, + dropsize: float, +) -> None: + labels[seeds > 0] = seeds[seeds > 0] + + terminals = _terminal_indices(data) + ordered_sources = _source_order(data.shape) + ordered_terminals = terminals[ordered_sources] + xres = data.shape[1] + + for term in ordered_terminals: + row = int(term // xres) + col = int(term % xres) + _process_mask(labels, row, col) + + counts = np.bincount(ordered_terminals, minlength=data.size).astype(np.float64) + water += counts.reshape(data.shape) + data -= dropsize * counts.reshape(data.shape) + + +def _mark_boundaries(labels: np.ndarray) -> np.ndarray: + result = labels.copy() + if result.shape[0] < 3 or result.shape[1] < 3: + return result + + interior = result[1:-1, 1:-1] + right = result[1:-1, 2:] + down = result[2:, 1:-1] + interior[(interior != right) | (interior != down)] = 0 + return result + + +def _combine_masks(result_mask: np.ndarray, existing_mask: np.ndarray | None, combine_mode: str) -> np.ndarray: + if existing_mask is None or combine_mode == "replace": + return result_mask + + existing = np.asarray(existing_mask) > 127 + current = np.asarray(result_mask, dtype=bool) + if existing.shape != current.shape: + raise ValueError("Existing mask must have the same shape as the watershed output.") + + if combine_mode == "union": + merged = current | existing + elif combine_mode == "intersection": + merged = current & existing + else: + raise ValueError(f"Unsupported combine mode: {combine_mode}") + + return merged.astype(np.uint8) * 255 + + +@register_node(display_name="Watershed Segmentation") +class WatershedSegmentation: + _CUSTOM_PREVIEW = True + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "field": ("DATA_FIELD",), + "invert_height": ("BOOLEAN", {"default": False}), + "locate_steps": ("INT", {"default": 10, "min": 1, "max": 200, "step": 1}), + "locate_threshold": ("INT", {"default": 10, "min": 0, "max": 100000, "step": 1}), + "locate_drop_size": ("FLOAT", {"default": 0.1, "min": 0.0001, "max": 1.0, "step": 0.01}), + "watershed_steps": ("INT", {"default": 20, "min": 1, "max": 2000, "step": 1}), + "watershed_drop_size": ("FLOAT", {"default": 0.1, "min": 0.0001, "max": 1.0, "step": 0.01}), + "combine_mode": (["replace", "union", "intersection"], {"default": "replace"}), + }, + "optional": { + "mask": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + + DESCRIPTION = ( + "Segment a height field into grains using the two-stage Gwyddion watershed workflow: " + "drop-based seed location followed by watershed growth. Supports hill or valley detection " + "and optional union/intersection with an existing mask." + ) + + def process( + self, + field: DataField, + invert_height: bool, + locate_steps: int, + locate_threshold: int, + locate_drop_size: float, + watershed_steps: int, + watershed_drop_size: float, + combine_mode: str, + mask: np.ndarray | None = None, + ) -> tuple: + working = _working_height(field, bool(invert_height)) + water = np.zeros_like(working, dtype=np.float64) + + q = float((np.max(working) - np.min(working)) / 50.0) + locate_drop = float(locate_drop_size) * q + watershed_drop = float(watershed_drop_size) * q + + locate_field = working.copy() + for _ in range(int(locate_steps)): + _location_step(locate_field, water, locate_drop) + + seeds = _seed_labels(water, int(locate_threshold)) + labels = np.zeros_like(seeds, dtype=np.int32) + watershed_field = working.copy() + for _ in range(int(watershed_steps)): + _watershed_step(watershed_field, water, labels, seeds, watershed_drop) + + labels = _mark_boundaries(labels) + result_mask = (labels > 0).astype(np.uint8) * 255 + result_mask = _combine_masks(result_mask, mask, combine_mode) + + emit_preview(encode_preview(_mask_overlay(field, result_mask))) + return (result_mask,) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index e86e527..6980828 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -526,6 +526,81 @@ def test_plane_level(): print(" PASS\n") +def test_facet_level(): + print("=== Test: FacetLevelField ===") + from backend.node_registry import get_node_info + from backend.nodes.facet_level_field import FacetLevelField + from backend.nodes.plane_level_field import PlaneLevelField + + def fit_pixel_plane(data: np.ndarray, region: np.ndarray) -> tuple[float, float, float]: + yy, xx = np.mgrid[0:data.shape[0], 0:data.shape[1]] + A = np.column_stack([ + np.ones(int(np.count_nonzero(region)), dtype=np.float64), + xx[region].astype(np.float64), + yy[region].astype(np.float64), + ]) + coeffs, _, _, _ = np.linalg.lstsq(A, data[region].ravel().astype(np.float64), rcond=None) + return float(coeffs[0]), float(coeffs[1]), float(coeffs[2]) + + node = FacetLevelField() + plane_node = PlaneLevelField() + assert get_node_info("FacetLevelField")["category"] == "Flatten" + + N = 96 + yy, xx = np.mgrid[0:N, 0:N] + base = 0.055 * xx + 0.028 * yy + terraces = np.zeros((N, N), dtype=np.float64) + terraces[:, 54:] += 6.0 + terraces[18:70, 68:88] += 3.5 + field = make_field(data=base + terraces) + + plane_leveled, = plane_node.process(field) + facet_leveled, = node.process(field, masking="ignore") + + left_region = xx < 48 + right_region = (xx > 60) & ~((yy >= 18) & (yy < 70) & (xx >= 68) & (xx < 88)) + _, plane_left_bx, plane_left_by = fit_pixel_plane(plane_leveled.data, left_region) + _, plane_right_bx, plane_right_by = fit_pixel_plane(plane_leveled.data, right_region) + _, facet_left_bx, facet_left_by = fit_pixel_plane(facet_leveled.data, left_region) + _, facet_right_bx, facet_right_by = fit_pixel_plane(facet_leveled.data, right_region) + plane_slope = float(max(np.hypot(plane_left_bx, plane_left_by), np.hypot(plane_right_bx, plane_right_by))) + facet_slope = float(max(np.hypot(facet_left_bx, facet_left_by), np.hypot(facet_right_bx, facet_right_by))) + assert facet_slope < plane_slope * 1e-6 + + mask = np.zeros((N, N), dtype=np.uint8) + mask[24:72, 24:72] = 255 + base_only = 0.035 * xx + 0.014 * yy + masked_facet = 5.0 - 0.065 * xx + 0.045 * yy + competing = np.where(mask > 0, masked_facet, base_only) + competing_field = make_field(data=competing) + + excluded, = node.process(competing_field, masking="exclude", mask=mask) + included, = node.process(competing_field, masking="include", mask=mask) + + outer_region = (mask == 0) & (xx > 4) & (xx < N - 4) & (yy > 4) & (yy < N - 4) + inner_region = (mask > 0) & (xx > 28) & (xx < 68) & (yy > 28) & (yy < 68) + _, excl_outer_bx, excl_outer_by = fit_pixel_plane(excluded.data, outer_region) + _, excl_inner_bx, excl_inner_by = fit_pixel_plane(excluded.data, inner_region) + _, incl_outer_bx, incl_outer_by = fit_pixel_plane(included.data, outer_region) + _, incl_inner_bx, incl_inner_by = fit_pixel_plane(included.data, inner_region) + + excl_outer_slope = float(np.hypot(excl_outer_bx, excl_outer_by)) + excl_inner_slope = float(np.hypot(excl_inner_bx, excl_inner_by)) + incl_outer_slope = float(np.hypot(incl_outer_bx, incl_outer_by)) + incl_inner_slope = float(np.hypot(incl_inner_bx, incl_inner_by)) + assert excl_outer_slope < incl_outer_slope * 0.2 + assert incl_inner_slope < excl_inner_slope * 0.2 + + bad_units = DataField(data=np.zeros((16, 16), dtype=np.float64), xreal=1e-6, yreal=1e-6, si_unit_xy="nm", si_unit_z="V") + try: + node.process(bad_units, masking="ignore") + except ValueError as exc: + assert "compatible XY and Z units" in str(exc) + else: + assert False, "Facet level should reject incompatible XY/Z units." + print(" PASS\n") + + def test_poly_level(): print("=== Test: PolyLevelField ===") from backend.nodes.poly_level_field import PolyLevelField @@ -571,6 +646,78 @@ def test_fix_zero(): print(" PASS\n") +def test_curvature(): + print("=== Test: Curvature ===") + from backend.node_registry import get_node_info + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.curvature import Curvature + + node = Curvature() + assert get_node_info("Curvature")["category"] == "Measure" + + xres, yres = 121, 101 + xreal, yreal = 8.0e-6, 6.0e-6 + xoff, yoff = 1.0e-6, -0.5e-6 + x = np.linspace(xoff, xoff + xreal, xres, dtype=np.float64) + y = np.linspace(yoff, yoff + yreal, yres, dtype=np.float64) + yy, xx = np.meshgrid(y, x, indexing="ij") + + x0 = xoff + 0.45 * xreal + y0 = yoff + 0.60 * yreal + rx = 1.2e-6 + ry = 2.4e-6 + z0 = 3.0e-9 + data = z0 + (xx - x0) ** 2 / (2.0 * rx) + (yy - y0) ** 2 / (2.0 * ry) + field = DataField(data=data, xreal=xreal, yreal=yreal, xoff=xoff, yoff=yoff, si_unit_xy="m", si_unit_z="m") + + previews = [] + tables = [] + with execution_callbacks(preview=lambda nid, uri: previews.append(uri), table=lambda nid, rows: tables.append(rows)), active_node("test"): + output, table, profile1, profile2 = node.process(field, masking="ignore") + + rows = {row["quantity"]: row for row in table} + recovered_radii = sorted([rows["Curvature radius 1"]["value"], rows["Curvature radius 2"]["value"]]) + expected_radii = sorted([rx, ry]) + assert len(previews) == 1 + assert previews[0].startswith("data:image/png;base64,") + assert len(tables) == 1 + assert abs(rows["Center x position"]["value"] - x0) < xreal * 0.02 + assert abs(rows["Center y position"]["value"] - y0) < yreal * 0.02 + assert abs(rows["Center value"]["value"] - z0) < 5e-11 + assert np.allclose(recovered_radii, expected_radii, rtol=0.08, atol=5e-8) + assert output.overlays[-1]["kind"] == "markup" + assert len(output.overlays[-1]["shapes"]) == 3 + assert isinstance(profile1, LineData) + assert isinstance(profile2, LineData) + assert profile1.x_unit == field.si_unit_xy + assert profile1.y_unit == field.si_unit_z + assert profile2.x_unit == field.si_unit_xy + assert profile2.y_unit == field.si_unit_z + assert len(profile1) > 10 + assert len(profile2) > 10 + + mask = np.zeros((yres, xres), dtype=np.uint8) + mask[:, :xres // 2] = 255 + left = 1.0e-9 + (xx - (xoff + 0.25 * xreal)) ** 2 / (2.0 * 0.9e-6) + (yy - (yoff + 0.5 * yreal)) ** 2 / (2.0 * 1.8e-6) + right = 2.0e-9 + (xx - (xoff + 0.75 * xreal)) ** 2 / (2.0 * 1.6e-6) + (yy - (yoff + 0.5 * yreal)) ** 2 / (2.0 * 3.2e-6) + split_field = DataField(data=np.where(mask > 0, left, right), xreal=xreal, yreal=yreal, xoff=xoff, yoff=yoff, si_unit_xy="m", si_unit_z="m") + _, include_table, _, _ = node.process(split_field, masking="include", mask=mask) + _, exclude_table, _, _ = node.process(split_field, masking="exclude", mask=mask) + include_radii = sorted([row["value"] for row in include_table if row["quantity"].startswith("Curvature radius")]) + exclude_radii = sorted([row["value"] for row in exclude_table if row["quantity"].startswith("Curvature radius")]) + assert np.allclose(include_radii, [0.9e-6, 1.8e-6], rtol=0.12, atol=5e-8) + assert np.allclose(exclude_radii, [1.6e-6, 3.2e-6], rtol=0.12, atol=5e-8) + + bad_units = DataField(data=np.zeros((16, 16), dtype=np.float64), xreal=1e-6, yreal=1e-6, si_unit_xy="nm", si_unit_z="V") + try: + node.process(bad_units, masking="ignore") + except ValueError as exc: + assert "compatible XY and Z units" in str(exc) + else: + assert False, "Curvature should reject incompatible XY/Z units." + print(" PASS\n") + + def test_line_correction(): print("=== Test: LineCorrection ===") from backend.node_registry import get_node_info @@ -866,6 +1013,80 @@ def test_height_histogram(): print(" PASS\n") +def test_fractal_dimension(): + print("=== Test: FractalDimension ===") + from backend.node_registry import get_node_info + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.fractal_dimension import FractalDimension + + node = FractalDimension() + assert get_node_info("FractalDimension")["category"] == "Measure" + + N = 129 + yy, xx = np.mgrid[0:N, 0:N] / (N - 1) + data = 0.25 * xx + 0.12 * yy + 0.03 * np.sin(6.0 * np.pi * xx) + 0.02 * np.cos(4.0 * np.pi * yy) + field = make_field(data=data, xreal=4.0e-6, yreal=4.0e-6) + + overlays = [] + tables = [] + with execution_callbacks(overlay=lambda nid, payload: overlays.append(payload), table=lambda nid, rows: tables.append(rows)), active_node("test"): + dimension, curve, table = node.process( + field, + method="partitioning", + interpolation="linear", + x1=0.0, + y1=0.5, + x2=1.0, + y2=0.5, + ) + + assert np.isfinite(dimension) + assert 1.5 < dimension < 2.5 + assert isinstance(curve, LineData) + assert len(curve) > 3 + assert curve.x_axis is not None + assert np.all(np.diff(curve.x_axis) > 0.0) + assert len(overlays) == 1 + assert overlays[0]["kind"] == "line_plot" + assert len(tables) == 1 + assert table[0]["quantity"] == "Dimension" + + methods = ["partitioning", "cube_counting", "triangulation", "psdf", "hhcf"] + for method in methods: + dim, line, measurements = node.process( + field, + method=method, + interpolation="linear", + x1=0.0, + y1=0.5, + x2=1.0, + y2=0.5, + ) + assert np.isfinite(dim), f"{method} should produce a finite fractal dimension" + if method == "psdf": + assert -1.0 < dim < 3.2 + else: + assert 1.2 < dim < 3.2 + assert isinstance(line, LineData) + assert len(line) >= 2 + assert measurements[0]["quantity"] == "Dimension" + + narrowed_dim, _, narrowed_table = node.process( + field, + method="partitioning", + interpolation="linear", + x1=0.15, + y1=0.5, + x2=0.55, + y2=0.5, + ) + assert np.isfinite(narrowed_dim) + fit_from = next(row["value"] for row in narrowed_table if row["quantity"] == "Fit from") + fit_to = next(row["value"] for row in narrowed_table if row["quantity"] == "Fit to") + assert fit_to > fit_from + print(" PASS\n") + + def test_cross_section(): print("=== Test: CrossSection ===") from backend.nodes.cross_section import CrossSection @@ -1167,6 +1388,93 @@ def test_particle_analysis(): print(" PASS\n") +def test_grain_distance_transform(): + print("=== Test: GrainDistanceTransform ===") + from backend.nodes.grain_distance_transform import GrainDistanceTransform + + node = GrainDistanceTransform() + field = make_field(data=np.zeros((7, 7), dtype=np.float64), xreal=7.0, yreal=7.0) + mask = np.zeros((7, 7), dtype=np.uint8) + mask[2:5, 2:5] = 255 + + interior, = node.process(field, mask, distance_type="euclidean", output_type="interior", from_border=True) + assert interior.data.shape == field.data.shape + assert interior.si_unit_z == field.si_unit_xy + assert np.isclose(interior.data[3, 3], 2.0) + assert np.isclose(interior.data[2, 2], 1.0) + assert np.isclose(interior.data[0, 0], 0.0) + + exterior, = node.process(field, mask, distance_type="cityblock", output_type="exterior", from_border=True) + assert np.isclose(exterior.data[1, 1], 2.0) + assert np.isclose(exterior.data[2, 1], 1.0) + assert np.isclose(exterior.data[3, 3], 0.0) + + signed, = node.process(field, mask, distance_type="chess", output_type="signed", from_border=True) + assert signed.data[3, 3] > 0.0 + assert signed.data[0, 0] < 0.0 + + edge_field = make_field(data=np.zeros((5, 5), dtype=np.float64), xreal=5.0, yreal=5.0) + edge_mask = np.zeros((5, 5), dtype=np.uint8) + edge_mask[:, :2] = 255 + from_edge, = node.process(edge_field, edge_mask, distance_type="euclidean", output_type="interior", from_border=True) + not_from_edge, = node.process(edge_field, edge_mask, distance_type="euclidean", output_type="interior", from_border=False) + assert not_from_edge.data[2, 0] > from_edge.data[2, 0] + print(" PASS\n") + + +def test_watershed_segmentation(): + print("=== Test: WatershedSegmentation ===") + from scipy.ndimage import label + from backend.execution_context import active_node, execution_callbacks + from backend.nodes.watershed_segmentation import WatershedSegmentation + + node = WatershedSegmentation() + y, x = np.mgrid[-1:1:64j, -1:1:64j] + data = ( + 2.0 * np.exp(-((x + 0.45) ** 2 + y**2) / 0.05) + + 2.0 * np.exp(-((x - 0.45) ** 2 + y**2) / 0.05) + - 0.3 * np.exp(-(x**2 + y**2) / 0.12) + ) + field = make_field(data=data, xreal=2.0e-6, yreal=2.0e-6) + + previews = [] + with execution_callbacks(preview=lambda nid, uri: previews.append(uri)), active_node("test"): + mask, = node.process( + field, + invert_height=False, + locate_steps=10, + locate_threshold=8, + locate_drop_size=0.1, + watershed_steps=20, + watershed_drop_size=0.1, + combine_mode="replace", + ) + assert mask.dtype == np.uint8 + assert mask.shape == field.data.shape + assert len(previews) == 1 + assert previews[0].startswith("data:image/png;base64,") + + _, ngrains = label(mask > 127) + assert ngrains >= 2 + + seed_mask = np.zeros_like(mask) + seed_mask[:, :32] = 255 + intersected, = node.process( + field, + invert_height=False, + locate_steps=10, + locate_threshold=8, + locate_drop_size=0.1, + watershed_steps=20, + watershed_drop_size=0.1, + combine_mode="intersection", + mask=seed_mask, + ) + assert np.count_nonzero(intersected) < np.count_nonzero(mask) + assert np.all(intersected[:, 40:] == 0) + print(" PASS\n") + + # ========================================================================= # I/O # ========================================================================= @@ -2814,6 +3122,7 @@ if __name__ == "__main__": test_plane_level() test_poly_level() test_fix_zero() + test_curvature() test_line_correction() test_scar_removal() test_angle_measure() @@ -2821,6 +3130,7 @@ if __name__ == "__main__": # Analysis test_statistics() test_height_histogram() + test_fractal_dimension() test_cross_section() test_line_cursors() test_fft2d()