remaining med value features
This commit is contained in:
@@ -27,18 +27,18 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|
|||||||
|
|
||||||
| # | Feature | Gwyddion Source | Description |
|
| # | Feature | Gwyddion Source | Description |
|
||||||
|---|---------|---------------|-------------|
|
|---|---------|---------------|-------------|
|
||||||
| 15 | Correlation / Pattern Matching | crosscor.c, maskcor.c | Find repeated features or align images via cross-correlation. |
|
| ~~15~~ | ~~Correlation / Pattern Matching~~ | ~~crosscor.c, maskcor.c~~ | ~~Find repeated features or align images via cross-correlation.~~ **DONE** |
|
||||||
| ~~16~~ | ~~Slope Distribution~~ | ~~slope_dist.c~~ | ~~Angular histogram of surface slopes. Characterizes surface texture directionality.~~ **DONE** |
|
| ~~16~~ | ~~Slope Distribution~~ | ~~slope_dist.c~~ | ~~Angular histogram of surface slopes. Characterizes surface texture directionality.~~ **DONE** |
|
||||||
| ~~17~~ | ~~Grain Filtering~~ | ~~grain_filter.c~~ | ~~Remove grains by size, height, or border contact. Refine grain masks post-detection.~~ **DONE** |
|
| ~~17~~ | ~~Grain Filtering~~ | ~~grain_filter.c~~ | ~~Remove grains by size, height, or border contact. Refine grain masks post-detection.~~ **DONE** |
|
||||||
| ~~18~~ | ~~Field Arithmetic~~ | ~~arithmetic.c~~ | ~~Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization.~~ **DONE** |
|
| ~~18~~ | ~~Field Arithmetic~~ | ~~arithmetic.c~~ | ~~Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization.~~ **DONE** |
|
||||||
| 19 | Spot Removal | spotremove.c | Interpolate over selected point defects (dust, spikes). |
|
| ~~19~~ | ~~Spot Removal~~ | ~~spotremove.c~~ | ~~Interpolate over selected point defects (dust, spikes).~~ **DONE** |
|
||||||
| ~~20~~ | ~~Tip Modeling / Deconvolution~~ | ~~tip_blind.c, tip_model.c~~ | ~~Estimate tip shape from image, deconvolve to recover true surface.~~ **DONE** |
|
| ~~20~~ | ~~Tip Modeling / Deconvolution~~ | ~~tip_blind.c, tip_model.c~~ | ~~Estimate tip shape from image, deconvolve to recover true surface.~~ **DONE** |
|
||||||
| ~~21~~ | ~~Radial Profile~~ | ~~rprofile tool~~ | ~~Azimuthally averaged profile from a center point. Good for circular features.~~ **DONE** |
|
| ~~21~~ | ~~Radial Profile~~ | ~~rprofile tool~~ | ~~Azimuthally averaged profile from a center point. Good for circular features.~~ **DONE** |
|
||||||
| 22 | Wavelet Transform | dwt.c, cwt.c | Discrete/continuous wavelet analysis. Multi-scale roughness decomposition. |
|
| ~~22~~ | ~~Wavelet Transform~~ | ~~dwt.c, cwt.c~~ | ~~Discrete/continuous wavelet analysis. Multi-scale roughness decomposition.~~ **DONE** |
|
||||||
| ~~23~~ | ~~Scale / Resample~~ | ~~scale.c, resample.c~~ | ~~Resize fields with interpolation.~~ **DONE** |
|
| ~~23~~ | ~~Scale / Resample~~ | ~~scale.c, resample.c~~ | ~~Resize fields with interpolation.~~ **DONE** |
|
||||||
| ~~24~~ | ~~Gradient~~ | ~~gradient.c~~ | ~~Compute x/y gradient magnitude maps.~~ **DONE** |
|
| ~~24~~ | ~~Gradient~~ | ~~gradient.c~~ | ~~Compute x/y gradient magnitude maps.~~ **DONE** |
|
||||||
| 25 | Custom Convolution | convolution_filter.c | User-defined kernel convolution. |
|
| ~~25~~ | ~~Custom Convolution~~ | ~~convolution_filter.c~~ | ~~User-defined kernel convolution.~~ **DONE** |
|
||||||
| 26 | Local Contrast Enhancement | local_contrast.c | Enhance visibility of local features in images. |
|
| ~~26~~ | ~~Local Contrast Enhancement~~ | ~~local_contrast.c~~ | ~~Enhance visibility of local features in images.~~ **DONE** |
|
||||||
|
|
||||||
## Lower Priority
|
## Lower Priority
|
||||||
|
|
||||||
@@ -53,11 +53,11 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|
|||||||
| 33 | Facet Analysis | facet_analysis.c | Orientation distribution of surface facets (stereographic projection). |
|
| 33 | Facet Analysis | facet_analysis.c | Orientation distribution of surface facets (stereographic projection). |
|
||||||
| 34 | Shape Fitting | fit-shape.c | Fit geometric primitives: sphere, paraboloid, cylinder, etc. |
|
| 34 | Shape Fitting | fit-shape.c | Fit geometric primitives: sphere, paraboloid, cylinder, etc. |
|
||||||
| 35 | Synthetic Surface Generation | *_synth.c (~20 modules) | Generate test surfaces: FBM, noise, lattice, waves, particles, fibers, etc. |
|
| 35 | Synthetic Surface Generation | *_synth.c (~20 modules) | Generate test surfaces: FBM, noise, lattice, waves, particles, fibers, etc. |
|
||||||
| 36 | Entropy | entropy.c | Information entropy of height distribution. |
|
| ~~36~~ | ~~Entropy~~ | ~~entropy.c~~ | ~~Information entropy of height distribution.~~ **DONE** |
|
||||||
| 37 | Indentation Analysis | indent_analyze.c, hertz.c | Nanoindentation curve fitting (Hertz model). |
|
| 37 | Indentation Analysis | indent_analyze.c, hertz.c | Nanoindentation curve fitting (Hertz model). |
|
||||||
| 38 | Deconvolution | deconvolve.c | Blind/regularized deconvolution for image restoration. |
|
| 38 | Deconvolution | deconvolve.c | Blind/regularized deconvolution for image restoration. |
|
||||||
| 39 | Canny / Harris Detection | filters.c | Corner and edge feature detection beyond basic Sobel/Prewitt. |
|
| 39 | Canny / Harris Detection | filters.c | Corner and edge feature detection beyond basic Sobel/Prewitt. |
|
||||||
| 40 | Kuwahara Filter | filters.c | Edge-preserving smoothing filter. |
|
| ~~40~~ | ~~Kuwahara Filter~~ | ~~filters.c~~ | ~~Edge-preserving smoothing filter.~~ **DONE** |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -69,4 +69,13 @@ from backend.nodes import (
|
|||||||
tip_model,
|
tip_model,
|
||||||
tip_deconvolution,
|
tip_deconvolution,
|
||||||
tip_blind_estimate,
|
tip_blind_estimate,
|
||||||
|
# New: remaining Gwyddion feature-gap nodes
|
||||||
|
filter_kuwahara,
|
||||||
|
entropy,
|
||||||
|
local_contrast,
|
||||||
|
filter_custom,
|
||||||
|
spot_removal,
|
||||||
|
wavelet_denoise,
|
||||||
|
cross_correlate,
|
||||||
|
template_match,
|
||||||
)
|
)
|
||||||
|
|||||||
68
backend/nodes/cross_correlate.py
Normal file
68
backend/nodes/cross_correlate.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.data_types import DataField
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Cross-Correlate")
|
||||||
|
class CrossCorrelate:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field_a": ("DATA_FIELD",),
|
||||||
|
"field_b": ("DATA_FIELD",),
|
||||||
|
"mode": (["full", "same", "valid"], {"default": "same"}),
|
||||||
|
"normalize": ("BOOLEAN", {"default": True}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'correlation'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Compute 2D cross-correlation between two fields. The correlation peak indicates "
|
||||||
|
"the offset where the two fields best match. Useful for drift measurement and feature "
|
||||||
|
"alignment. Equivalent to Gwyddion crosscor.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
field_a: DataField,
|
||||||
|
field_b: DataField,
|
||||||
|
mode: str,
|
||||||
|
normalize: bool,
|
||||||
|
) -> tuple:
|
||||||
|
from scipy.signal import fftconvolve
|
||||||
|
|
||||||
|
a = field_a.data - field_a.data.mean()
|
||||||
|
b = field_b.data - field_b.data.mean()
|
||||||
|
|
||||||
|
# Cross-correlation via FFT: correlate(a,b) = ifft(fft(a) * conj(fft(b)))
|
||||||
|
# Achieved by convolving a with the flipped b
|
||||||
|
corr = fftconvolve(a, b[::-1, ::-1], mode=mode)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
denom = np.sqrt((a ** 2).sum() * (b ** 2).sum())
|
||||||
|
if denom > 0:
|
||||||
|
corr = corr / denom
|
||||||
|
|
||||||
|
if mode == "same":
|
||||||
|
# Output is the same shape as field_a — reuse its physical dimensions
|
||||||
|
return (field_a.replace(data=corr),)
|
||||||
|
|
||||||
|
# For "full" mode: output shape is (Na+Nb-1, Ma+Mb-1)
|
||||||
|
# Scale physical dimensions proportionally
|
||||||
|
na, ma = field_a.data.shape
|
||||||
|
nb, mb = field_b.data.shape
|
||||||
|
out_y, out_x = corr.shape
|
||||||
|
|
||||||
|
# Physical size per pixel stays the same as field_a; total physical size scales
|
||||||
|
new_xreal = field_a.xreal * out_x / ma if ma > 0 else field_a.xreal
|
||||||
|
new_yreal = field_a.yreal * out_y / na if na > 0 else field_a.yreal
|
||||||
|
|
||||||
|
return (field_a.replace(data=corr, xreal=new_xreal, yreal=new_yreal),)
|
||||||
74
backend/nodes/entropy.py
Normal file
74
backend/nodes/entropy.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField, RecordTable
|
||||||
|
from backend.execution_context import emit_table
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Entropy")
|
||||||
|
class Entropy:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"mode": (["height values", "slope magnitude"], {"default": "height values"}),
|
||||||
|
"n_bins": ("INT", {"default": 256, "min": 16, "max": 1024}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('FLOAT', 'entropy'),
|
||||||
|
('FLOAT', 'normalised_entropy'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Shannon entropy of the height or slope distribution. "
|
||||||
|
"H = -\u03a3 p\u00b7ln(p). Equivalent to Gwyddion entropy.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, mode: str, n_bins: int) -> tuple:
|
||||||
|
n_bins = max(16, int(n_bins))
|
||||||
|
data = np.asarray(field.data, dtype=np.float64)
|
||||||
|
|
||||||
|
if mode == "slope magnitude":
|
||||||
|
# Compute slope magnitude from Sobel-like finite differences.
|
||||||
|
# Central differences along x and y (axis=1 and axis=0).
|
||||||
|
# np.gradient uses central differences, same spirit as Sobel.
|
||||||
|
dy, dx = np.gradient(data)
|
||||||
|
values = np.hypot(dx, dy).ravel()
|
||||||
|
else:
|
||||||
|
values = data.ravel()
|
||||||
|
|
||||||
|
# Remove non-finite values before binning.
|
||||||
|
values = values[np.isfinite(values)]
|
||||||
|
|
||||||
|
if values.size == 0:
|
||||||
|
h = 0.0
|
||||||
|
h_norm = 0.0
|
||||||
|
else:
|
||||||
|
counts, _ = np.histogram(values, bins=n_bins)
|
||||||
|
total = counts.sum()
|
||||||
|
if total == 0:
|
||||||
|
h = 0.0
|
||||||
|
h_norm = 0.0
|
||||||
|
else:
|
||||||
|
# Probability distribution; skip zero bins.
|
||||||
|
p = counts[counts > 0].astype(np.float64) / float(total)
|
||||||
|
h = float(-np.sum(p * np.log(p)))
|
||||||
|
# Maximum possible entropy for n_bins equally occupied bins is ln(n_bins).
|
||||||
|
h_max = float(np.log(n_bins))
|
||||||
|
h_norm = h / h_max if h_max > 0.0 else 0.0
|
||||||
|
|
||||||
|
table = RecordTable([
|
||||||
|
{"quantity": "entropy", "value": h, "unit": "nat"},
|
||||||
|
{"quantity": "normalised entropy", "value": h_norm, "unit": ""},
|
||||||
|
{"quantity": "mode", "value": mode, "unit": ""},
|
||||||
|
{"quantity": "n_bins", "value": n_bins, "unit": ""},
|
||||||
|
])
|
||||||
|
emit_table(table)
|
||||||
|
|
||||||
|
return (h, h_norm)
|
||||||
127
backend/nodes/filter_custom.py
Normal file
127
backend/nodes/filter_custom.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
from backend.execution_context import emit_warning
|
||||||
|
|
||||||
|
_DEFAULT_KERNEL = "0 -1 0\n-1 5 -1\n0 -1 0"
|
||||||
|
_MAX_KERNEL_DIM = 51
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_kernel(kernel_str: str) -> np.ndarray | None:
|
||||||
|
"""Parse a multi-line kernel string into a 2-D float64 array.
|
||||||
|
|
||||||
|
Returns *None* and issues a warning via emit_warning if the string is
|
||||||
|
invalid. The returned array is always at least 1×1 and at most
|
||||||
|
_MAX_KERNEL_DIM × _MAX_KERNEL_DIM.
|
||||||
|
"""
|
||||||
|
lines = [ln for ln in kernel_str.splitlines() if ln.strip()]
|
||||||
|
if not lines:
|
||||||
|
emit_warning("Custom Convolution: kernel string is empty. Using identity.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for ln in lines:
|
||||||
|
try:
|
||||||
|
row = [float(v) for v in ln.split()]
|
||||||
|
except ValueError:
|
||||||
|
emit_warning(
|
||||||
|
f"Custom Convolution: could not parse kernel row {ln!r}. "
|
||||||
|
"Input returned unchanged."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
if not row:
|
||||||
|
continue
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
emit_warning("Custom Convolution: kernel has no valid rows. Input returned unchanged.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# All rows must have the same length.
|
||||||
|
ncols = len(rows[0])
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
if len(row) != ncols:
|
||||||
|
emit_warning(
|
||||||
|
f"Custom Convolution: row {i} has {len(row)} values but row 0 has {ncols}. "
|
||||||
|
"All rows must be the same length. Input returned unchanged."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
arr = np.array(rows, dtype=np.float64)
|
||||||
|
|
||||||
|
if arr.ndim != 2 or arr.size == 0:
|
||||||
|
emit_warning("Custom Convolution: kernel is empty after parsing. Input returned unchanged.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
nrows, ncols = arr.shape
|
||||||
|
if nrows > _MAX_KERNEL_DIM or ncols > _MAX_KERNEL_DIM:
|
||||||
|
emit_warning(
|
||||||
|
f"Custom Convolution: kernel size {nrows}×{ncols} exceeds maximum "
|
||||||
|
f"{_MAX_KERNEL_DIM}×{_MAX_KERNEL_DIM}. Input returned unchanged."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not np.all(np.isfinite(arr)):
|
||||||
|
emit_warning("Custom Convolution: kernel contains non-finite values. Input returned unchanged.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Custom Convolution")
|
||||||
|
class CustomConvolution:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"kernel": ("STRING", {
|
||||||
|
"multiline": True,
|
||||||
|
"default": _DEFAULT_KERNEL,
|
||||||
|
"placeholder": "kernel rows, space-separated",
|
||||||
|
}),
|
||||||
|
"normalize": ("BOOLEAN", {"default": True}),
|
||||||
|
"boundary": (["reflect", "nearest", "wrap"], {"default": "reflect"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'result'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Apply a user-defined convolution kernel. "
|
||||||
|
"Enter rows of space-separated numbers. "
|
||||||
|
"Example sharpen: '0 -1 0 / -1 5 -1 / 0 -1 0' (use newlines, not slashes). "
|
||||||
|
"Equivalent to Gwyddion convolution_filter.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
field: DataField,
|
||||||
|
kernel: str,
|
||||||
|
normalize: bool,
|
||||||
|
boundary: str,
|
||||||
|
) -> tuple:
|
||||||
|
from scipy.ndimage import convolve
|
||||||
|
|
||||||
|
kernel_arr = _parse_kernel(kernel)
|
||||||
|
if kernel_arr is None:
|
||||||
|
# Fallback: return input unchanged.
|
||||||
|
return (field.replace(data=field.data.copy()),)
|
||||||
|
|
||||||
|
data = np.asarray(field.data, dtype=np.float64)
|
||||||
|
|
||||||
|
# scipy.ndimage.convolve boundary mode names match our choices directly.
|
||||||
|
result = convolve(data, kernel_arr, mode=boundary)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
abs_sum = float(np.sum(np.abs(kernel_arr)))
|
||||||
|
if abs_sum > 0.0:
|
||||||
|
result = result / abs_sum
|
||||||
|
|
||||||
|
return (field.replace(data=result),)
|
||||||
86
backend/nodes/filter_kuwahara.py
Normal file
86
backend/nodes/filter_kuwahara.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
def _kuwahara_pass(data: np.ndarray) -> np.ndarray:
|
||||||
|
"""Single pass of the 5x5 Kuwahara filter.
|
||||||
|
|
||||||
|
Divides a 5x5 neighbourhood around each pixel into four overlapping 3x3
|
||||||
|
quadrants (TL, TR, BL, BR), computes the mean and variance of each quadrant,
|
||||||
|
and replaces the centre pixel with the mean of the quadrant that has the
|
||||||
|
smallest variance. Boundary pixels are handled by reflecting the image.
|
||||||
|
"""
|
||||||
|
# Pad with reflect so every pixel has a full 5x5 neighbourhood.
|
||||||
|
padded = np.pad(data, pad_width=2, mode="reflect")
|
||||||
|
|
||||||
|
rows, cols = data.shape
|
||||||
|
|
||||||
|
# For each of the four 3x3 quadrant offsets we need the per-pixel mean and
|
||||||
|
# variance. The quadrant window positions (relative to the padded array)
|
||||||
|
# for a centre pixel at (r+2, c+2) are:
|
||||||
|
# TL : rows r..r+2, cols c..c+2 → offset (0,0) in padded
|
||||||
|
# TR : rows r..r+2, cols c+2..c+4 → offset (0,2)
|
||||||
|
# BL : rows r+2..r+4, cols c..c+2 → offset (2,0)
|
||||||
|
# BR : rows r+2..r+4, cols c+2..c+4 → offset (2,2)
|
||||||
|
# Each window is 3×3 = 9 pixels.
|
||||||
|
|
||||||
|
quadrant_offsets = [(0, 0), (0, 2), (2, 0), (2, 2)]
|
||||||
|
n = 9 # 3×3 quadrant
|
||||||
|
|
||||||
|
# Accumulate sum and sum-of-squares for each quadrant using a simple nested
|
||||||
|
# index loop over the 3×3 window positions.
|
||||||
|
quad_sum = np.zeros((4, rows, cols), dtype=np.float64)
|
||||||
|
quad_sum2 = np.zeros((4, rows, cols), dtype=np.float64)
|
||||||
|
|
||||||
|
for qi, (dr0, dc0) in enumerate(quadrant_offsets):
|
||||||
|
for drow in range(3):
|
||||||
|
for dcol in range(3):
|
||||||
|
patch = padded[dr0 + drow: dr0 + drow + rows,
|
||||||
|
dc0 + dcol: dc0 + dcol + cols]
|
||||||
|
quad_sum[qi] += patch
|
||||||
|
quad_sum2[qi] += patch * patch
|
||||||
|
|
||||||
|
quad_mean = quad_sum / n
|
||||||
|
# var = E[x^2] - (E[x])^2, clamped to 0 to avoid floating-point negatives.
|
||||||
|
quad_var = np.maximum(quad_sum2 / n - quad_mean * quad_mean, 0.0)
|
||||||
|
|
||||||
|
# Select the quadrant index with minimum variance for each pixel.
|
||||||
|
best_qi = np.argmin(quad_var, axis=0) # shape (rows, cols)
|
||||||
|
|
||||||
|
# Gather the mean from the winning quadrant.
|
||||||
|
result = np.choose(best_qi, quad_mean)
|
||||||
|
return result.astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Kuwahara Filter")
|
||||||
|
class KuwaharaFilter:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"iterations": ("INT", {"default": 1, "min": 1, "max": 20}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'filtered'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Edge-preserving smoothing using Kuwahara's minimum-variance quadrant method. "
|
||||||
|
"Unlike Gaussian blur, sharp boundaries are preserved. "
|
||||||
|
"Equivalent to Gwyddion's Kuwahara filter."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, iterations: int) -> tuple:
|
||||||
|
data = np.asarray(field.data, dtype=np.float64)
|
||||||
|
iterations = max(1, int(iterations))
|
||||||
|
for _ in range(iterations):
|
||||||
|
data = _kuwahara_pass(data)
|
||||||
|
return (field.replace(data=data),)
|
||||||
64
backend/nodes/local_contrast.py
Normal file
64
backend/nodes/local_contrast.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Local Contrast")
|
||||||
|
class LocalContrast:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"kernel_size": ("INT", {"default": 10, "min": 2, "max": 100}),
|
||||||
|
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'result'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Expand the local dynamic range at each pixel. "
|
||||||
|
"Reveals fine surface features that are hidden by global contrast range. "
|
||||||
|
"Equivalent to Gwyddion local_contrast.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, field: DataField, kernel_size: int, weight: float) -> tuple:
|
||||||
|
from scipy.ndimage import minimum_filter, maximum_filter
|
||||||
|
|
||||||
|
data = np.asarray(field.data, dtype=np.float64)
|
||||||
|
kernel_size = max(2, int(kernel_size))
|
||||||
|
weight = float(np.clip(weight, 0.0, 1.0))
|
||||||
|
|
||||||
|
local_min = minimum_filter(data, size=kernel_size, mode="reflect")
|
||||||
|
local_max = maximum_filter(data, size=kernel_size, mode="reflect")
|
||||||
|
|
||||||
|
global_min = float(data.min())
|
||||||
|
global_max = float(data.max())
|
||||||
|
|
||||||
|
local_range = local_max - local_min
|
||||||
|
eps = np.finfo(np.float64).eps * max(abs(global_max), abs(global_min), 1.0)
|
||||||
|
|
||||||
|
# Remap each pixel from its local range to the global range.
|
||||||
|
# Where local_range is near zero, the pixel is already flat: leave it
|
||||||
|
# unchanged (enhancement factor = 1).
|
||||||
|
safe_range = np.where(local_range > eps, local_range, 1.0)
|
||||||
|
global_span = global_max - global_min
|
||||||
|
if global_span <= eps:
|
||||||
|
# Uniform field – nothing to enhance.
|
||||||
|
return (field.replace(data=data.copy()),)
|
||||||
|
|
||||||
|
enhancement_factor = global_span / safe_range
|
||||||
|
# Locally enhanced value: remap v from [local_min, local_max] → [global_min, global_max]
|
||||||
|
v_enhanced = global_min + enhancement_factor * (data - local_min)
|
||||||
|
|
||||||
|
# Blend between original and enhanced.
|
||||||
|
result = (1.0 - weight) * data + weight * v_enhanced
|
||||||
|
|
||||||
|
return (field.replace(data=result),)
|
||||||
122
backend/nodes/spot_removal.py
Normal file
122
backend/nodes/spot_removal.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.data_types import DataField
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Spot Removal")
|
||||||
|
class SpotRemoval:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"method": (["laplace", "mean", "zero"], {"default": "laplace"}),
|
||||||
|
"max_iter": ("INT", {"default": 100, "min": 1, "max": 2000, "step": 1}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"mask": ("IMAGE", {"label": "defects"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'result'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Fill defect pixels (hot pixels, dropouts, scan artifacts) by interpolation. "
|
||||||
|
"The mask defines defect locations. Laplace method solves the 2D Laplace equation "
|
||||||
|
"for smooth inpainting. Equivalent to Gwyddion spotremove.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
field: DataField,
|
||||||
|
method: str,
|
||||||
|
max_iter: int,
|
||||||
|
mask: np.ndarray | None = None,
|
||||||
|
) -> tuple:
|
||||||
|
if mask is None:
|
||||||
|
return (field,)
|
||||||
|
|
||||||
|
mask_array = np.asarray(mask)
|
||||||
|
# Reshape mask to match field shape if it has extra dimensions (e.g. HxWx1)
|
||||||
|
if mask_array.ndim == 3:
|
||||||
|
mask_array = mask_array[:, :, 0]
|
||||||
|
if mask_array.shape != field.data.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Mask shape {mask_array.shape} does not match field shape {field.data.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
|
defect = mask_array > 0
|
||||||
|
|
||||||
|
if not np.any(defect):
|
||||||
|
return (field,)
|
||||||
|
|
||||||
|
data = np.asarray(field.data, dtype=np.float64)
|
||||||
|
|
||||||
|
if method == "zero":
|
||||||
|
result = data.copy()
|
||||||
|
result[defect] = 0.0
|
||||||
|
return (field.replace(data=result),)
|
||||||
|
|
||||||
|
if method == "mean":
|
||||||
|
result = _mean_fill(data, defect)
|
||||||
|
return (field.replace(data=result),)
|
||||||
|
|
||||||
|
# method == "laplace"
|
||||||
|
result = _laplace_fill(data, defect, int(max_iter))
|
||||||
|
return (field.replace(data=result),)
|
||||||
|
|
||||||
|
|
||||||
|
def _mean_fill(data: np.ndarray, defect: np.ndarray) -> np.ndarray:
|
||||||
|
"""Fill defect pixels with the mean of non-defect neighbours in a 3x3 window."""
|
||||||
|
result = data.copy()
|
||||||
|
yres, xres = data.shape
|
||||||
|
|
||||||
|
# Global fallback: mean of all non-defect pixels
|
||||||
|
non_defect_vals = data[~defect]
|
||||||
|
global_mean = float(non_defect_vals.mean()) if non_defect_vals.size > 0 else 0.0
|
||||||
|
|
||||||
|
defect_coords = np.argwhere(defect)
|
||||||
|
for y, x in defect_coords:
|
||||||
|
y0 = max(y - 1, 0)
|
||||||
|
y1 = min(y + 2, yres)
|
||||||
|
x0 = max(x - 1, 0)
|
||||||
|
x1 = min(x + 2, xres)
|
||||||
|
|
||||||
|
neighbourhood_data = data[y0:y1, x0:x1]
|
||||||
|
neighbourhood_defect = defect[y0:y1, x0:x1]
|
||||||
|
good = neighbourhood_data[~neighbourhood_defect]
|
||||||
|
|
||||||
|
if good.size > 0:
|
||||||
|
result[y, x] = float(good.mean())
|
||||||
|
else:
|
||||||
|
result[y, x] = global_mean
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _laplace_fill(data: np.ndarray, defect: np.ndarray, max_iter: int) -> np.ndarray:
|
||||||
|
"""Iterative Laplace solver: set defect pixels to neighbour average each iteration."""
|
||||||
|
non_defect_vals = data[~defect]
|
||||||
|
init_val = float(non_defect_vals.mean()) if non_defect_vals.size > 0 else 0.0
|
||||||
|
|
||||||
|
result = data.copy()
|
||||||
|
result[defect] = init_val
|
||||||
|
|
||||||
|
for _ in range(max_iter):
|
||||||
|
# Compute neighbour averages via rolled arrays
|
||||||
|
neighbour_sum = (
|
||||||
|
np.roll(result, -1, axis=0)
|
||||||
|
+ np.roll(result, 1, axis=0)
|
||||||
|
+ np.roll(result, -1, axis=1)
|
||||||
|
+ np.roll(result, 1, axis=1)
|
||||||
|
)
|
||||||
|
new_vals = neighbour_sum / 4.0
|
||||||
|
result[defect] = new_vals[defect]
|
||||||
|
|
||||||
|
return result
|
||||||
52
backend/nodes/template_match.py
Normal file
52
backend/nodes/template_match.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.data_types import DataField
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Template Match")
|
||||||
|
class TemplateMatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("DATA_FIELD",),
|
||||||
|
"template": ("DATA_FIELD",),
|
||||||
|
"threshold": (
|
||||||
|
"FLOAT",
|
||||||
|
{"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'score'),
|
||||||
|
('IMAGE', 'detections'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Find a template pattern within a larger data field using normalised cross-correlation. "
|
||||||
|
"The score output shows match quality (1 = perfect match). Detections mask marks positions "
|
||||||
|
"above the threshold. Equivalent to Gwyddion maskcor.c."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
image: DataField,
|
||||||
|
template: DataField,
|
||||||
|
threshold: float,
|
||||||
|
) -> tuple:
|
||||||
|
from skimage.feature import match_template
|
||||||
|
|
||||||
|
score = match_template(image.data, template.data, pad_input=True)
|
||||||
|
|
||||||
|
# Clip to [0, 1] for display (match_template returns values in [-1, 1])
|
||||||
|
score_clipped = np.clip(score, 0.0, 1.0)
|
||||||
|
|
||||||
|
detections = (score_clipped >= float(threshold)).astype(np.uint8) * 255
|
||||||
|
|
||||||
|
score_field = image.replace(data=score_clipped)
|
||||||
|
return (score_field, detections)
|
||||||
74
backend/nodes/wavelet_denoise.py
Normal file
74
backend/nodes/wavelet_denoise.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from backend.data_types import DataField
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="Wavelet Denoise")
|
||||||
|
class WaveletDenoise:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"field": ("DATA_FIELD",),
|
||||||
|
"wavelet": (
|
||||||
|
["db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"],
|
||||||
|
{"default": "db4"},
|
||||||
|
),
|
||||||
|
"method": (["BayesShrink", "VisuShrink"], {"default": "BayesShrink"}),
|
||||||
|
"sigma": (
|
||||||
|
"FLOAT",
|
||||||
|
{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01},
|
||||||
|
),
|
||||||
|
"mode": (["soft", "hard"], {"default": "soft"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('DATA_FIELD', 'denoised'),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Denoise using wavelet coefficient thresholding. BayesShrink adapts the threshold "
|
||||||
|
"per sub-band; VisuShrink uses a global threshold. Equivalent to applying DWT from "
|
||||||
|
"Gwyddion dwt.c with coefficient thresholding."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
field: DataField,
|
||||||
|
wavelet: str,
|
||||||
|
method: str,
|
||||||
|
sigma: float,
|
||||||
|
mode: str,
|
||||||
|
) -> tuple:
|
||||||
|
from skimage.restoration import denoise_wavelet
|
||||||
|
|
||||||
|
d = field.data
|
||||||
|
dmin = float(d.min())
|
||||||
|
drange = float(np.ptp(d))
|
||||||
|
|
||||||
|
if drange == 0:
|
||||||
|
return (field,)
|
||||||
|
|
||||||
|
norm = (d - dmin) / drange
|
||||||
|
sigma_val = float(sigma) if sigma > 0 else None
|
||||||
|
|
||||||
|
# `mode` is a Python builtin name; use threshold_mode locally to avoid shadowing
|
||||||
|
threshold_mode = mode
|
||||||
|
|
||||||
|
denoised_norm = denoise_wavelet(
|
||||||
|
norm,
|
||||||
|
wavelet=wavelet,
|
||||||
|
method=method,
|
||||||
|
mode=threshold_mode,
|
||||||
|
sigma=sigma_val,
|
||||||
|
rescale_sigma=True,
|
||||||
|
channel_axis=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = denoised_norm * drange + dmin
|
||||||
|
return (field.replace(data=result),)
|
||||||
@@ -17,6 +17,7 @@ dependencies = [
|
|||||||
"nanonispy>=1.1",
|
"nanonispy>=1.1",
|
||||||
"numpy>=1.26,<3",
|
"numpy>=1.26,<3",
|
||||||
"pillow>=10,<12",
|
"pillow>=10,<12",
|
||||||
|
"pywavelets>=1.8.0",
|
||||||
"scikit-image>=0.22,<1",
|
"scikit-image>=0.22,<1",
|
||||||
"scipy>=1.12,<2",
|
"scipy>=1.12,<2",
|
||||||
]
|
]
|
||||||
|
|||||||
89
tests/node_tests/cross_correlate.py
Normal file
89
tests/node_tests/cross_correlate.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_same_field_peak_at_center():
|
||||||
|
"""Correlating a field with itself in 'same' mode peaks at the centre."""
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
data = rng.standard_normal((32, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(field, field, mode="same", normalize=True)
|
||||||
|
|
||||||
|
peak_y, peak_x = np.unravel_index(np.argmax(result.data), result.data.shape)
|
||||||
|
cy, cx = result.data.shape[0] // 2, result.data.shape[1] // 2
|
||||||
|
# Peak should be within a few pixels of centre
|
||||||
|
assert abs(peak_y - cy) <= 2
|
||||||
|
assert abs(peak_x - cx) <= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_same_mode_shape_equals_a():
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(1)
|
||||||
|
a = make_field(data=rng.standard_normal((32, 48)))
|
||||||
|
b = make_field(data=rng.standard_normal((32, 48)))
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(a, b, mode="same", normalize=True)
|
||||||
|
assert result.data.shape == a.data.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_full_mode_shape():
|
||||||
|
"""Full mode output shape should be Na+Nb-1 × Ma+Mb-1."""
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(2)
|
||||||
|
a = make_field(data=rng.standard_normal((20, 30)))
|
||||||
|
b = make_field(data=rng.standard_normal((20, 30)))
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(a, b, mode="full", normalize=True)
|
||||||
|
assert result.data.shape == (20 + 20 - 1, 30 + 30 - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_normalized_peak_is_one():
|
||||||
|
"""Self-correlation normalised should give peak = 1."""
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(3)
|
||||||
|
data = rng.standard_normal((32, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(field, field, mode="same", normalize=True)
|
||||||
|
assert result.data.max() == pytest.approx(1.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_unnormalized_runs():
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(4)
|
||||||
|
data = rng.standard_normal((16, 16))
|
||||||
|
field = make_field(data=data)
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(field, field, mode="same", normalize=False)
|
||||||
|
assert result.data.shape == (16, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_valid_mode():
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(5)
|
||||||
|
a = make_field(data=rng.standard_normal((16, 16)))
|
||||||
|
b = make_field(data=rng.standard_normal((8, 8)))
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(a, b, mode="valid", normalize=True)
|
||||||
|
# Valid mode output: (16-8+1, 16-8+1) = (9, 9)
|
||||||
|
assert result.data.shape == (9, 9)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_correlate_preserves_metadata_same_mode():
|
||||||
|
from backend.nodes.cross_correlate import CrossCorrelate
|
||||||
|
|
||||||
|
rng = np.random.default_rng(6)
|
||||||
|
field = make_field(data=rng.standard_normal((16, 16)))
|
||||||
|
node = CrossCorrelate()
|
||||||
|
result, = node.process(field, field, mode="same", normalize=True)
|
||||||
|
assert result.xreal == field.xreal
|
||||||
|
assert result.yreal == field.yreal
|
||||||
75
tests/node_tests/entropy.py
Normal file
75
tests/node_tests/entropy.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_uniform_field_low():
|
||||||
|
"""A field with a single unique value has zero entropy."""
|
||||||
|
from backend.nodes.entropy import Entropy
|
||||||
|
|
||||||
|
node = Entropy()
|
||||||
|
field = make_field(data=np.full((32, 32), 3.14))
|
||||||
|
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||||
|
# All values fall in one bin → p=1 → H = 0
|
||||||
|
assert h == pytest.approx(0.0, abs=1e-10)
|
||||||
|
assert h_norm == pytest.approx(0.0, abs=1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_random_field_positive():
|
||||||
|
"""A random field should have positive entropy."""
|
||||||
|
from backend.nodes.entropy import Entropy
|
||||||
|
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
field = make_field(data=rng.standard_normal((64, 64)))
|
||||||
|
node = Entropy()
|
||||||
|
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||||
|
assert h > 0.0
|
||||||
|
assert 0.0 < h_norm <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_normalised_leq_one():
|
||||||
|
"""Normalised entropy should never exceed 1."""
|
||||||
|
from backend.nodes.entropy import Entropy
|
||||||
|
|
||||||
|
rng = np.random.default_rng(2)
|
||||||
|
field = make_field(data=rng.uniform(0, 1, (64, 64)))
|
||||||
|
node = Entropy()
|
||||||
|
_, h_norm = node.process(field, mode="height values", n_bins=64)
|
||||||
|
assert h_norm <= 1.0 + 1e-12
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_slope_mode():
|
||||||
|
"""Slope mode should work and return valid entropy values."""
|
||||||
|
from backend.nodes.entropy import Entropy
|
||||||
|
|
||||||
|
rng = np.random.default_rng(3)
|
||||||
|
field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
node = Entropy()
|
||||||
|
h, h_norm = node.process(field, mode="slope magnitude", n_bins=128)
|
||||||
|
assert h > 0.0
|
||||||
|
assert 0.0 <= h_norm <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_more_uniform_is_higher():
|
||||||
|
"""Uniformly distributed values have higher entropy than a spiked distribution."""
|
||||||
|
from backend.nodes.entropy import Entropy
|
||||||
|
|
||||||
|
rng = np.random.default_rng(4)
|
||||||
|
uniform = rng.uniform(0, 1, (64, 64))
|
||||||
|
spiked = np.zeros((64, 64))
|
||||||
|
spiked[0, 0] = 1.0
|
||||||
|
|
||||||
|
node = Entropy()
|
||||||
|
h_uniform, _ = node.process(make_field(data=uniform), mode="height values", n_bins=64)
|
||||||
|
h_spiked, _ = node.process(make_field(data=spiked), mode="height values", n_bins=64)
|
||||||
|
assert h_uniform > h_spiked
|
||||||
|
|
||||||
|
|
||||||
|
def test_entropy_returns_floats():
|
||||||
|
from backend.nodes.entropy import Entropy
|
||||||
|
|
||||||
|
field = make_field()
|
||||||
|
node = Entropy()
|
||||||
|
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||||
|
assert isinstance(h, float)
|
||||||
|
assert isinstance(h_norm, float)
|
||||||
95
tests/node_tests/filter_custom.py
Normal file
95
tests/node_tests/filter_custom.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_identity_kernel():
|
||||||
|
"""[[1]] with normalize=True (abs_sum=1) should return input unchanged."""
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
data = np.random.default_rng(0).standard_normal((32, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
|
||||||
|
assert np.allclose(result.data, data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_uniform_kernel_normalized():
|
||||||
|
"""An all-ones kernel with normalize=True is a box filter (mean filter)."""
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
data = np.random.default_rng(1).standard_normal((32, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
# 3x3 all-ones kernel, normalized → each pixel becomes mean of its neighbourhood
|
||||||
|
kernel = "1 1 1\n1 1 1\n1 1 1"
|
||||||
|
result, = node.process(field, kernel=kernel, normalize=True, boundary="reflect")
|
||||||
|
# Output std should be less than input std (smoothing)
|
||||||
|
assert result.data.std() < data.std()
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_sharpen_increases_variation():
|
||||||
|
"""A sharpening kernel should increase local variation on a smooth field."""
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
# Smooth ramp field — very low frequency content
|
||||||
|
data = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
sharpen = "0 -1 0\n-1 5 -1\n0 -1 0"
|
||||||
|
result, = node.process(field, kernel=sharpen, normalize=False, boundary="reflect")
|
||||||
|
# Sharpening without normalisation keeps the ramp intact plus adds edges
|
||||||
|
# The std of the sharpened field should differ from input
|
||||||
|
assert result.data.std() != pytest.approx(data.std(), rel=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_shape_preserved():
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
field = make_field(shape=(48, 64))
|
||||||
|
result, = node.process(field, kernel="0 1 0\n1 1 1\n0 1 0", normalize=True, boundary="reflect")
|
||||||
|
assert result.data.shape == (48, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_invalid_kernel_fallback():
|
||||||
|
"""An invalid kernel string should return the input field unchanged."""
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
data = np.random.default_rng(2).standard_normal((16, 16))
|
||||||
|
field = make_field(data=data)
|
||||||
|
result, = node.process(field, kernel="", normalize=True, boundary="reflect")
|
||||||
|
assert np.allclose(result.data, data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_ragged_kernel_fallback():
|
||||||
|
"""A ragged (non-rectangular) kernel should be rejected gracefully."""
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
data = np.random.default_rng(3).standard_normal((16, 16))
|
||||||
|
field = make_field(data=data)
|
||||||
|
result, = node.process(field, kernel="1 2\n1 2 3", normalize=True, boundary="reflect")
|
||||||
|
assert np.allclose(result.data, data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_boundary_modes():
|
||||||
|
"""All boundary modes should produce valid output without error."""
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
field = make_field()
|
||||||
|
for mode in ("reflect", "nearest", "wrap"):
|
||||||
|
result, = node.process(field, kernel="1 1 1\n1 1 1\n1 1 1", normalize=True, boundary=mode)
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_convolution_preserves_metadata():
|
||||||
|
from backend.nodes.filter_custom import CustomConvolution
|
||||||
|
|
||||||
|
node = CustomConvolution()
|
||||||
|
field = make_field()
|
||||||
|
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
|
||||||
|
assert result.xreal == field.xreal
|
||||||
|
assert result.si_unit_z == field.si_unit_z
|
||||||
75
tests/node_tests/filter_kuwahara.py
Normal file
75
tests/node_tests/filter_kuwahara.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_kuwahara_shape_preserved():
|
||||||
|
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||||
|
|
||||||
|
node = KuwaharaFilter()
|
||||||
|
field = make_field(shape=(48, 64))
|
||||||
|
result, = node.process(field, iterations=1)
|
||||||
|
assert result.data.shape == (48, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kuwahara_flat_field_unchanged():
|
||||||
|
"""A constant field should pass through the Kuwahara filter unchanged."""
|
||||||
|
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||||
|
|
||||||
|
node = KuwaharaFilter()
|
||||||
|
field = make_field(data=np.full((32, 32), 7.5))
|
||||||
|
result, = node.process(field, iterations=1)
|
||||||
|
assert np.allclose(result.data, 7.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kuwahara_reduces_noise():
|
||||||
|
"""Applying the filter to a noisy field should reduce standard deviation."""
|
||||||
|
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||||
|
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
noisy = rng.standard_normal((64, 64))
|
||||||
|
node = KuwaharaFilter()
|
||||||
|
field = make_field(data=noisy)
|
||||||
|
result, = node.process(field, iterations=1)
|
||||||
|
assert result.data.std() < noisy.std()
|
||||||
|
|
||||||
|
|
||||||
|
def test_kuwahara_preserves_step_edge():
|
||||||
|
"""The Kuwahara filter should preserve a sharp step edge better than a blur."""
|
||||||
|
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||||
|
|
||||||
|
# Left half = 0, right half = 1
|
||||||
|
data = np.zeros((32, 64))
|
||||||
|
data[:, 32:] = 1.0
|
||||||
|
node = KuwaharaFilter()
|
||||||
|
field = make_field(data=data)
|
||||||
|
result, = node.process(field, iterations=1)
|
||||||
|
|
||||||
|
# The edge column should have a large jump (edge preserved)
|
||||||
|
col_before = result.data[:, 30].mean()
|
||||||
|
col_after = result.data[:, 34].mean()
|
||||||
|
assert col_after - col_before > 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_kuwahara_multiple_iterations():
|
||||||
|
"""Running multiple iterations should further reduce noise."""
|
||||||
|
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||||
|
|
||||||
|
rng = np.random.default_rng(1)
|
||||||
|
noisy = rng.standard_normal((32, 32))
|
||||||
|
node = KuwaharaFilter()
|
||||||
|
field = make_field(data=noisy)
|
||||||
|
result1, = node.process(field, iterations=1)
|
||||||
|
result3, = node.process(field, iterations=3)
|
||||||
|
assert result3.data.std() <= result1.data.std()
|
||||||
|
|
||||||
|
|
||||||
|
def test_kuwahara_preserves_metadata():
|
||||||
|
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||||
|
|
||||||
|
node = KuwaharaFilter()
|
||||||
|
field = make_field()
|
||||||
|
result, = node.process(field, iterations=1)
|
||||||
|
assert result.xreal == field.xreal
|
||||||
|
assert result.yreal == field.yreal
|
||||||
|
assert result.si_unit_z == field.si_unit_z
|
||||||
67
tests/node_tests/local_contrast.py
Normal file
67
tests/node_tests/local_contrast.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_contrast_shape_preserved():
|
||||||
|
from backend.nodes.local_contrast import LocalContrast
|
||||||
|
|
||||||
|
node = LocalContrast()
|
||||||
|
field = make_field(shape=(48, 64))
|
||||||
|
result, = node.process(field, kernel_size=10, weight=0.5)
|
||||||
|
assert result.data.shape == (48, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_contrast_weight_zero_unchanged():
|
||||||
|
"""weight=0 blends 100% original → result equals input."""
|
||||||
|
from backend.nodes.local_contrast import LocalContrast
|
||||||
|
|
||||||
|
node = LocalContrast()
|
||||||
|
data = np.random.default_rng(0).standard_normal((32, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
result, = node.process(field, kernel_size=5, weight=0.0)
|
||||||
|
assert np.allclose(result.data, data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_contrast_uniform_field_unchanged():
|
||||||
|
"""A flat field has nothing to enhance; it should be returned as-is."""
|
||||||
|
from backend.nodes.local_contrast import LocalContrast
|
||||||
|
|
||||||
|
node = LocalContrast()
|
||||||
|
field = make_field(data=np.full((32, 32), 2.0))
|
||||||
|
result, = node.process(field, kernel_size=5, weight=1.0)
|
||||||
|
assert np.allclose(result.data, 2.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_contrast_increases_dynamic_range():
|
||||||
|
"""Weight=1 full enhancement should not compress global range beyond input."""
|
||||||
|
from backend.nodes.local_contrast import LocalContrast
|
||||||
|
|
||||||
|
rng = np.random.default_rng(1)
|
||||||
|
data = rng.standard_normal((64, 64))
|
||||||
|
field = make_field(data=data)
|
||||||
|
node = LocalContrast()
|
||||||
|
result, = node.process(field, kernel_size=8, weight=1.0)
|
||||||
|
# Global min/max should be preserved (by construction of the algorithm)
|
||||||
|
assert np.isclose(result.data.min(), data.min(), atol=1e-6)
|
||||||
|
assert np.isclose(result.data.max(), data.max(), atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_contrast_preserves_metadata():
|
||||||
|
from backend.nodes.local_contrast import LocalContrast
|
||||||
|
|
||||||
|
node = LocalContrast()
|
||||||
|
field = make_field()
|
||||||
|
result, = node.process(field, kernel_size=10, weight=0.5)
|
||||||
|
assert result.xreal == field.xreal
|
||||||
|
assert result.si_unit_z == field.si_unit_z
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_contrast_weight_clipped():
|
||||||
|
"""Values outside [0,1] should be clipped without error."""
|
||||||
|
from backend.nodes.local_contrast import LocalContrast
|
||||||
|
|
||||||
|
node = LocalContrast()
|
||||||
|
field = make_field()
|
||||||
|
result, = node.process(field, kernel_size=5, weight=2.0)
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
@@ -156,10 +156,16 @@ def test_save_generic():
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# LINE as plot image (PNG / TIFF)
|
||||||
|
node.save(filename="line_plot_png", directory_path=tmpdir, format="PNG", value=line)
|
||||||
|
assert Path(tmpdir, "line_plot_png.png").exists()
|
||||||
|
node.save(filename="line_plot_tiff", directory_path=tmpdir, format="TIFF", value=line)
|
||||||
|
assert Path(tmpdir, "line_plot_tiff.tiff").exists()
|
||||||
|
|
||||||
# Unsupported LINE format
|
# Unsupported LINE format
|
||||||
try:
|
try:
|
||||||
node.save(filename="line_bad", directory_path=tmpdir, format="TIFF", value=line)
|
node.save(filename="line_bad", directory_path=tmpdir, format="OBJ", value=line)
|
||||||
assert False, "Expected ValueError for LINE + TIFF"
|
assert False, "Expected ValueError for LINE + OBJ"
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
110
tests/node_tests/spot_removal.py
Normal file
110
tests/node_tests/spot_removal.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mask(shape, defect_positions):
|
||||||
|
"""Create a uint8 mask array with 255 at given (row, col) positions."""
|
||||||
|
mask = np.zeros(shape, dtype=np.uint8)
|
||||||
|
for r, c in defect_positions:
|
||||||
|
mask[r, c] = 255
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_no_mask_returns_field_unchanged():
|
||||||
|
"""Without a mask input the field should be returned as-is."""
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
field = make_field()
|
||||||
|
result, = node.process(field, method="laplace", max_iter=50)
|
||||||
|
# Should be the identical object (short-circuit path)
|
||||||
|
assert result is field
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_zero_fill():
|
||||||
|
"""method='zero' sets defect pixels to exactly 0."""
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
data = np.ones((16, 16)) * 5.0
|
||||||
|
field = make_field(data=data)
|
||||||
|
mask = _make_mask((16, 16), [(4, 4), (8, 8)])
|
||||||
|
result, = node.process(field, method="zero", max_iter=1, mask=mask)
|
||||||
|
assert result.data[4, 4] == pytest.approx(0.0)
|
||||||
|
assert result.data[8, 8] == pytest.approx(0.0)
|
||||||
|
# Non-defect pixels should stay 5.0
|
||||||
|
assert result.data[0, 0] == pytest.approx(5.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_mean_fill_surrounded_by_constant():
|
||||||
|
"""On a constant field, mean fill should give back the constant."""
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
data = np.full((16, 16), 3.0)
|
||||||
|
field = make_field(data=data)
|
||||||
|
mask = _make_mask((16, 16), [(7, 7)])
|
||||||
|
result, = node.process(field, method="mean", max_iter=1, mask=mask)
|
||||||
|
assert result.data[7, 7] == pytest.approx(3.0, abs=1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_laplace_fill_surrounded_by_constant():
|
||||||
|
"""Laplace fill on a constant field should recover the constant at the defect."""
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
data = np.full((16, 16), 2.5)
|
||||||
|
field = make_field(data=data)
|
||||||
|
mask = _make_mask((16, 16), [(8, 8)])
|
||||||
|
result, = node.process(field, method="laplace", max_iter=200, mask=mask)
|
||||||
|
assert result.data[8, 8] == pytest.approx(2.5, abs=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_laplace_smooth_interpolation():
|
||||||
|
"""Laplace should interpolate between boundary values smoothly."""
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
# Left half = 0, right half = 10; single defect in the middle
|
||||||
|
data = np.zeros((16, 16))
|
||||||
|
data[:, 8:] = 10.0
|
||||||
|
field = make_field(data=data)
|
||||||
|
# Defect at the boundary column
|
||||||
|
mask = _make_mask((16, 16), [(8, 7)])
|
||||||
|
result, = node.process(field, method="laplace", max_iter=500, mask=mask)
|
||||||
|
# The filled value should be between 0 and 10
|
||||||
|
filled = result.data[8, 7]
|
||||||
|
assert 0.0 <= filled <= 10.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_shape_preserved():
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
field = make_field(shape=(48, 64))
|
||||||
|
mask = _make_mask((48, 64), [(10, 20)])
|
||||||
|
result, = node.process(field, method="mean", max_iter=10, mask=mask)
|
||||||
|
assert result.data.shape == (48, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_mask_shape_mismatch_raises():
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
field = make_field(shape=(16, 16))
|
||||||
|
bad_mask = np.ones((32, 32), dtype=np.uint8)
|
||||||
|
with pytest.raises(ValueError, match="Mask shape"):
|
||||||
|
node.process(field, method="zero", max_iter=1, mask=bad_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_spot_removal_empty_mask_unchanged():
|
||||||
|
"""An all-zero mask means no defects — field returned unchanged."""
|
||||||
|
from backend.nodes.spot_removal import SpotRemoval
|
||||||
|
|
||||||
|
node = SpotRemoval()
|
||||||
|
data = np.random.default_rng(0).standard_normal((16, 16))
|
||||||
|
field = make_field(data=data)
|
||||||
|
mask = np.zeros((16, 16), dtype=np.uint8)
|
||||||
|
result, = node.process(field, method="laplace", max_iter=50, mask=mask)
|
||||||
|
assert result is field
|
||||||
93
tests/node_tests/template_match.py
Normal file
93
tests/node_tests/template_match.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_exact_match_score_one():
|
||||||
|
"""When template equals the image, the peak score should be 1."""
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
data = rng.standard_normal((32, 32))
|
||||||
|
image_field = make_field(data=data)
|
||||||
|
# Template is the full image → perfect correlation everywhere → peak = 1
|
||||||
|
template_field = make_field(data=data)
|
||||||
|
node = TemplateMatch()
|
||||||
|
score_field, detections = node.process(image_field, template_field, threshold=0.9)
|
||||||
|
assert score_field.data.max() == pytest.approx(1.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_output_shape_matches_image():
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(1)
|
||||||
|
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||||
|
node = TemplateMatch()
|
||||||
|
score_field, detections = node.process(image_field, template_field, threshold=0.5)
|
||||||
|
assert score_field.data.shape == image_field.data.shape
|
||||||
|
assert detections.shape == image_field.data.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_score_in_range():
|
||||||
|
"""Score values should be clipped to [0, 1]."""
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(2)
|
||||||
|
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
template_field = make_field(data=rng.standard_normal((6, 6)))
|
||||||
|
node = TemplateMatch()
|
||||||
|
score_field, _ = node.process(image_field, template_field, threshold=0.5)
|
||||||
|
assert score_field.data.min() >= 0.0 - 1e-10
|
||||||
|
assert score_field.data.max() <= 1.0 + 1e-10
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_detections_binary():
|
||||||
|
"""Detection mask values should be 0 or 255 only."""
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(3)
|
||||||
|
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||||
|
node = TemplateMatch()
|
||||||
|
_, detections = node.process(image_field, template_field, threshold=0.5)
|
||||||
|
unique_values = set(np.unique(detections))
|
||||||
|
assert unique_values <= {0, 255}
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_threshold_zero_all_detected():
|
||||||
|
"""threshold=0 should mark all pixels as detections (score always >= 0)."""
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(4)
|
||||||
|
image_field = make_field(data=rng.standard_normal((16, 16)))
|
||||||
|
template_field = make_field(data=rng.standard_normal((4, 4)))
|
||||||
|
node = TemplateMatch()
|
||||||
|
_, detections = node.process(image_field, template_field, threshold=0.0)
|
||||||
|
assert np.all(detections == 255)
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_threshold_one_sparse_detections():
|
||||||
|
"""threshold=1.0 should detect very few (or no) positions."""
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(5)
|
||||||
|
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||||
|
node = TemplateMatch()
|
||||||
|
_, detections = node.process(image_field, template_field, threshold=1.0)
|
||||||
|
# At threshold=1.0, only perfect matches count (rare for random data)
|
||||||
|
detected_count = int((detections == 255).sum())
|
||||||
|
assert detected_count < 10 # very few or none
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_match_preserves_metadata():
|
||||||
|
from backend.nodes.template_match import TemplateMatch
|
||||||
|
|
||||||
|
rng = np.random.default_rng(6)
|
||||||
|
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||||
|
node = TemplateMatch()
|
||||||
|
score_field, _ = node.process(image_field, template_field, threshold=0.5)
|
||||||
|
assert score_field.xreal == image_field.xreal
|
||||||
|
assert score_field.yreal == image_field.yreal
|
||||||
85
tests/node_tests/wavelet_denoise.py
Normal file
85
tests/node_tests/wavelet_denoise.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_shape_preserved():
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
node = WaveletDenoise()
|
||||||
|
field = make_field(shape=(64, 64))
|
||||||
|
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||||
|
assert result.data.shape == (64, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_reduces_noise():
|
||||||
|
"""Denoising noisy data should reduce standard deviation."""
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
clean = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||||
|
noisy = clean + rng.normal(0, 0.1, clean.shape)
|
||||||
|
field = make_field(data=noisy)
|
||||||
|
node = WaveletDenoise()
|
||||||
|
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||||
|
# Denoised should be closer to clean than the noisy input
|
||||||
|
denoised_err = np.std(result.data - clean)
|
||||||
|
noisy_err = np.std(noisy - clean)
|
||||||
|
assert denoised_err < noisy_err
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_uniform_field_unchanged():
|
||||||
|
"""A flat field (no variation) is returned as-is."""
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
node = WaveletDenoise()
|
||||||
|
field = make_field(data=np.full((32, 32), 5.0))
|
||||||
|
result, = node.process(field, wavelet="db1", method="VisuShrink", sigma=0.0, mode="hard")
|
||||||
|
# The short-circuit path returns the original field object
|
||||||
|
assert result is field
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_preserves_range():
|
||||||
|
"""Output values should stay within the input data range (approx)."""
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
rng = np.random.default_rng(1)
|
||||||
|
data = rng.standard_normal((32, 32))
|
||||||
|
field = make_field(data=data)
|
||||||
|
node = WaveletDenoise()
|
||||||
|
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||||
|
# The normalisation ensures output is within [data.min(), data.max()]
|
||||||
|
assert result.data.min() >= data.min() - 1e-10
|
||||||
|
assert result.data.max() <= data.max() + 1e-10
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_all_wavelets():
|
||||||
|
"""All supported wavelets should run without error."""
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
rng = np.random.default_rng(2)
|
||||||
|
field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
node = WaveletDenoise()
|
||||||
|
for wavelet in ("db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"):
|
||||||
|
result, = node.process(field, wavelet=wavelet, method="BayesShrink", sigma=0.0, mode="soft")
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_visu_shrink():
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
rng = np.random.default_rng(3)
|
||||||
|
field = make_field(data=rng.standard_normal((32, 32)))
|
||||||
|
node = WaveletDenoise()
|
||||||
|
result, = node.process(field, wavelet="db4", method="VisuShrink", sigma=0.0, mode="soft")
|
||||||
|
assert result.data.shape == field.data.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavelet_denoise_preserves_metadata():
|
||||||
|
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||||
|
|
||||||
|
node = WaveletDenoise()
|
||||||
|
field = make_field()
|
||||||
|
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||||
|
assert result.xreal == field.xreal
|
||||||
|
assert result.si_unit_z == field.si_unit_z
|
||||||
Reference in New Issue
Block a user