remaining med value features
This commit is contained in:
@@ -27,18 +27,18 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|
||||
|
||||
| # | Feature | Gwyddion Source | Description |
|
||||
|---|---------|---------------|-------------|
|
||||
| 15 | Correlation / Pattern Matching | crosscor.c, maskcor.c | Find repeated features or align images via cross-correlation. |
|
||||
| ~~15~~ | ~~Correlation / Pattern Matching~~ | ~~crosscor.c, maskcor.c~~ | ~~Find repeated features or align images via cross-correlation.~~ **DONE** |
|
||||
| ~~16~~ | ~~Slope Distribution~~ | ~~slope_dist.c~~ | ~~Angular histogram of surface slopes. Characterizes surface texture directionality.~~ **DONE** |
|
||||
| ~~17~~ | ~~Grain Filtering~~ | ~~grain_filter.c~~ | ~~Remove grains by size, height, or border contact. Refine grain masks post-detection.~~ **DONE** |
|
||||
| ~~18~~ | ~~Field Arithmetic~~ | ~~arithmetic.c~~ | ~~Add/subtract/multiply/divide two DATA_FIELDs. Useful for difference maps, normalization.~~ **DONE** |
|
||||
| 19 | Spot Removal | spotremove.c | Interpolate over selected point defects (dust, spikes). |
|
||||
| ~~19~~ | ~~Spot Removal~~ | ~~spotremove.c~~ | ~~Interpolate over selected point defects (dust, spikes).~~ **DONE** |
|
||||
| ~~20~~ | ~~Tip Modeling / Deconvolution~~ | ~~tip_blind.c, tip_model.c~~ | ~~Estimate tip shape from image, deconvolve to recover true surface.~~ **DONE** |
|
||||
| ~~21~~ | ~~Radial Profile~~ | ~~rprofile tool~~ | ~~Azimuthally averaged profile from a center point. Good for circular features.~~ **DONE** |
|
||||
| 22 | Wavelet Transform | dwt.c, cwt.c | Discrete/continuous wavelet analysis. Multi-scale roughness decomposition. |
|
||||
| ~~22~~ | ~~Wavelet Transform~~ | ~~dwt.c, cwt.c~~ | ~~Discrete/continuous wavelet analysis. Multi-scale roughness decomposition.~~ **DONE** |
|
||||
| ~~23~~ | ~~Scale / Resample~~ | ~~scale.c, resample.c~~ | ~~Resize fields with interpolation.~~ **DONE** |
|
||||
| ~~24~~ | ~~Gradient~~ | ~~gradient.c~~ | ~~Compute x/y gradient magnitude maps.~~ **DONE** |
|
||||
| 25 | Custom Convolution | convolution_filter.c | User-defined kernel convolution. |
|
||||
| 26 | Local Contrast Enhancement | local_contrast.c | Enhance visibility of local features in images. |
|
||||
| ~~25~~ | ~~Custom Convolution~~ | ~~convolution_filter.c~~ | ~~User-defined kernel convolution.~~ **DONE** |
|
||||
| ~~26~~ | ~~Local Contrast Enhancement~~ | ~~local_contrast.c~~ | ~~Enhance visibility of local features in images.~~ **DONE** |
|
||||
|
||||
## Lower Priority
|
||||
|
||||
@@ -53,11 +53,11 @@ Reference for future implementation. Grouped by value to typical SPM workflows.
|
||||
| 33 | Facet Analysis | facet_analysis.c | Orientation distribution of surface facets (stereographic projection). |
|
||||
| 34 | Shape Fitting | fit-shape.c | Fit geometric primitives: sphere, paraboloid, cylinder, etc. |
|
||||
| 35 | Synthetic Surface Generation | *_synth.c (~20 modules) | Generate test surfaces: FBM, noise, lattice, waves, particles, fibers, etc. |
|
||||
| 36 | Entropy | entropy.c | Information entropy of height distribution. |
|
||||
| ~~36~~ | ~~Entropy~~ | ~~entropy.c~~ | ~~Information entropy of height distribution.~~ **DONE** |
|
||||
| 37 | Indentation Analysis | indent_analyze.c, hertz.c | Nanoindentation curve fitting (Hertz model). |
|
||||
| 38 | Deconvolution | deconvolve.c | Blind/regularized deconvolution for image restoration. |
|
||||
| 39 | Canny / Harris Detection | filters.c | Corner and edge feature detection beyond basic Sobel/Prewitt. |
|
||||
| 40 | Kuwahara Filter | filters.c | Edge-preserving smoothing filter. |
|
||||
| ~~40~~ | ~~Kuwahara Filter~~ | ~~filters.c~~ | ~~Edge-preserving smoothing filter.~~ **DONE** |
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -69,4 +69,13 @@ from backend.nodes import (
|
||||
tip_model,
|
||||
tip_deconvolution,
|
||||
tip_blind_estimate,
|
||||
# New: remaining Gwyddion feature-gap nodes
|
||||
filter_kuwahara,
|
||||
entropy,
|
||||
local_contrast,
|
||||
filter_custom,
|
||||
spot_removal,
|
||||
wavelet_denoise,
|
||||
cross_correlate,
|
||||
template_match,
|
||||
)
|
||||
|
||||
68
backend/nodes/cross_correlate.py
Normal file
68
backend/nodes/cross_correlate.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
@register_node(display_name="Cross-Correlate")
|
||||
class CrossCorrelate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field_a": ("DATA_FIELD",),
|
||||
"field_b": ("DATA_FIELD",),
|
||||
"mode": (["full", "same", "valid"], {"default": "same"}),
|
||||
"normalize": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'correlation'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Compute 2D cross-correlation between two fields. The correlation peak indicates "
|
||||
"the offset where the two fields best match. Useful for drift measurement and feature "
|
||||
"alignment. Equivalent to Gwyddion crosscor.c."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field_a: DataField,
|
||||
field_b: DataField,
|
||||
mode: str,
|
||||
normalize: bool,
|
||||
) -> tuple:
|
||||
from scipy.signal import fftconvolve
|
||||
|
||||
a = field_a.data - field_a.data.mean()
|
||||
b = field_b.data - field_b.data.mean()
|
||||
|
||||
# Cross-correlation via FFT: correlate(a,b) = ifft(fft(a) * conj(fft(b)))
|
||||
# Achieved by convolving a with the flipped b
|
||||
corr = fftconvolve(a, b[::-1, ::-1], mode=mode)
|
||||
|
||||
if normalize:
|
||||
denom = np.sqrt((a ** 2).sum() * (b ** 2).sum())
|
||||
if denom > 0:
|
||||
corr = corr / denom
|
||||
|
||||
if mode == "same":
|
||||
# Output is the same shape as field_a — reuse its physical dimensions
|
||||
return (field_a.replace(data=corr),)
|
||||
|
||||
# For "full" mode: output shape is (Na+Nb-1, Ma+Mb-1)
|
||||
# Scale physical dimensions proportionally
|
||||
na, ma = field_a.data.shape
|
||||
nb, mb = field_b.data.shape
|
||||
out_y, out_x = corr.shape
|
||||
|
||||
# Physical size per pixel stays the same as field_a; total physical size scales
|
||||
new_xreal = field_a.xreal * out_x / ma if ma > 0 else field_a.xreal
|
||||
new_yreal = field_a.yreal * out_y / na if na > 0 else field_a.yreal
|
||||
|
||||
return (field_a.replace(data=corr, xreal=new_xreal, yreal=new_yreal),)
|
||||
74
backend/nodes/entropy.py
Normal file
74
backend/nodes/entropy.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, RecordTable
|
||||
from backend.execution_context import emit_table
|
||||
|
||||
|
||||
@register_node(display_name="Entropy")
|
||||
class Entropy:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"mode": (["height values", "slope magnitude"], {"default": "height values"}),
|
||||
"n_bins": ("INT", {"default": 256, "min": 16, "max": 1024}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('FLOAT', 'entropy'),
|
||||
('FLOAT', 'normalised_entropy'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Shannon entropy of the height or slope distribution. "
|
||||
"H = -\u03a3 p\u00b7ln(p). Equivalent to Gwyddion entropy.c."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, mode: str, n_bins: int) -> tuple:
|
||||
n_bins = max(16, int(n_bins))
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
|
||||
if mode == "slope magnitude":
|
||||
# Compute slope magnitude from Sobel-like finite differences.
|
||||
# Central differences along x and y (axis=1 and axis=0).
|
||||
# np.gradient uses central differences, same spirit as Sobel.
|
||||
dy, dx = np.gradient(data)
|
||||
values = np.hypot(dx, dy).ravel()
|
||||
else:
|
||||
values = data.ravel()
|
||||
|
||||
# Remove non-finite values before binning.
|
||||
values = values[np.isfinite(values)]
|
||||
|
||||
if values.size == 0:
|
||||
h = 0.0
|
||||
h_norm = 0.0
|
||||
else:
|
||||
counts, _ = np.histogram(values, bins=n_bins)
|
||||
total = counts.sum()
|
||||
if total == 0:
|
||||
h = 0.0
|
||||
h_norm = 0.0
|
||||
else:
|
||||
# Probability distribution; skip zero bins.
|
||||
p = counts[counts > 0].astype(np.float64) / float(total)
|
||||
h = float(-np.sum(p * np.log(p)))
|
||||
# Maximum possible entropy for n_bins equally occupied bins is ln(n_bins).
|
||||
h_max = float(np.log(n_bins))
|
||||
h_norm = h / h_max if h_max > 0.0 else 0.0
|
||||
|
||||
table = RecordTable([
|
||||
{"quantity": "entropy", "value": h, "unit": "nat"},
|
||||
{"quantity": "normalised entropy", "value": h_norm, "unit": ""},
|
||||
{"quantity": "mode", "value": mode, "unit": ""},
|
||||
{"quantity": "n_bins", "value": n_bins, "unit": ""},
|
||||
])
|
||||
emit_table(table)
|
||||
|
||||
return (h, h_norm)
|
||||
127
backend/nodes/filter_custom.py
Normal file
127
backend/nodes/filter_custom.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
from backend.execution_context import emit_warning
|
||||
|
||||
_DEFAULT_KERNEL = "0 -1 0\n-1 5 -1\n0 -1 0"
|
||||
_MAX_KERNEL_DIM = 51
|
||||
|
||||
|
||||
def _parse_kernel(kernel_str: str) -> np.ndarray | None:
|
||||
"""Parse a multi-line kernel string into a 2-D float64 array.
|
||||
|
||||
Returns *None* and issues a warning via emit_warning if the string is
|
||||
invalid. The returned array is always at least 1×1 and at most
|
||||
_MAX_KERNEL_DIM × _MAX_KERNEL_DIM.
|
||||
"""
|
||||
lines = [ln for ln in kernel_str.splitlines() if ln.strip()]
|
||||
if not lines:
|
||||
emit_warning("Custom Convolution: kernel string is empty. Using identity.")
|
||||
return None
|
||||
|
||||
rows = []
|
||||
for ln in lines:
|
||||
try:
|
||||
row = [float(v) for v in ln.split()]
|
||||
except ValueError:
|
||||
emit_warning(
|
||||
f"Custom Convolution: could not parse kernel row {ln!r}. "
|
||||
"Input returned unchanged."
|
||||
)
|
||||
return None
|
||||
if not row:
|
||||
continue
|
||||
rows.append(row)
|
||||
|
||||
if not rows:
|
||||
emit_warning("Custom Convolution: kernel has no valid rows. Input returned unchanged.")
|
||||
return None
|
||||
|
||||
# All rows must have the same length.
|
||||
ncols = len(rows[0])
|
||||
for i, row in enumerate(rows):
|
||||
if len(row) != ncols:
|
||||
emit_warning(
|
||||
f"Custom Convolution: row {i} has {len(row)} values but row 0 has {ncols}. "
|
||||
"All rows must be the same length. Input returned unchanged."
|
||||
)
|
||||
return None
|
||||
|
||||
arr = np.array(rows, dtype=np.float64)
|
||||
|
||||
if arr.ndim != 2 or arr.size == 0:
|
||||
emit_warning("Custom Convolution: kernel is empty after parsing. Input returned unchanged.")
|
||||
return None
|
||||
|
||||
nrows, ncols = arr.shape
|
||||
if nrows > _MAX_KERNEL_DIM or ncols > _MAX_KERNEL_DIM:
|
||||
emit_warning(
|
||||
f"Custom Convolution: kernel size {nrows}×{ncols} exceeds maximum "
|
||||
f"{_MAX_KERNEL_DIM}×{_MAX_KERNEL_DIM}. Input returned unchanged."
|
||||
)
|
||||
return None
|
||||
|
||||
if not np.all(np.isfinite(arr)):
|
||||
emit_warning("Custom Convolution: kernel contains non-finite values. Input returned unchanged.")
|
||||
return None
|
||||
|
||||
return arr
|
||||
|
||||
|
||||
@register_node(display_name="Custom Convolution")
|
||||
class CustomConvolution:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"kernel": ("STRING", {
|
||||
"multiline": True,
|
||||
"default": _DEFAULT_KERNEL,
|
||||
"placeholder": "kernel rows, space-separated",
|
||||
}),
|
||||
"normalize": ("BOOLEAN", {"default": True}),
|
||||
"boundary": (["reflect", "nearest", "wrap"], {"default": "reflect"}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'result'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Apply a user-defined convolution kernel. "
|
||||
"Enter rows of space-separated numbers. "
|
||||
"Example sharpen: '0 -1 0 / -1 5 -1 / 0 -1 0' (use newlines, not slashes). "
|
||||
"Equivalent to Gwyddion convolution_filter.c."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
kernel: str,
|
||||
normalize: bool,
|
||||
boundary: str,
|
||||
) -> tuple:
|
||||
from scipy.ndimage import convolve
|
||||
|
||||
kernel_arr = _parse_kernel(kernel)
|
||||
if kernel_arr is None:
|
||||
# Fallback: return input unchanged.
|
||||
return (field.replace(data=field.data.copy()),)
|
||||
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
|
||||
# scipy.ndimage.convolve boundary mode names match our choices directly.
|
||||
result = convolve(data, kernel_arr, mode=boundary)
|
||||
|
||||
if normalize:
|
||||
abs_sum = float(np.sum(np.abs(kernel_arr)))
|
||||
if abs_sum > 0.0:
|
||||
result = result / abs_sum
|
||||
|
||||
return (field.replace(data=result),)
|
||||
86
backend/nodes/filter_kuwahara.py
Normal file
86
backend/nodes/filter_kuwahara.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
def _kuwahara_pass(data: np.ndarray) -> np.ndarray:
|
||||
"""Single pass of the 5x5 Kuwahara filter.
|
||||
|
||||
Divides a 5x5 neighbourhood around each pixel into four overlapping 3x3
|
||||
quadrants (TL, TR, BL, BR), computes the mean and variance of each quadrant,
|
||||
and replaces the centre pixel with the mean of the quadrant that has the
|
||||
smallest variance. Boundary pixels are handled by reflecting the image.
|
||||
"""
|
||||
# Pad with reflect so every pixel has a full 5x5 neighbourhood.
|
||||
padded = np.pad(data, pad_width=2, mode="reflect")
|
||||
|
||||
rows, cols = data.shape
|
||||
|
||||
# For each of the four 3x3 quadrant offsets we need the per-pixel mean and
|
||||
# variance. The quadrant window positions (relative to the padded array)
|
||||
# for a centre pixel at (r+2, c+2) are:
|
||||
# TL : rows r..r+2, cols c..c+2 → offset (0,0) in padded
|
||||
# TR : rows r..r+2, cols c+2..c+4 → offset (0,2)
|
||||
# BL : rows r+2..r+4, cols c..c+2 → offset (2,0)
|
||||
# BR : rows r+2..r+4, cols c+2..c+4 → offset (2,2)
|
||||
# Each window is 3×3 = 9 pixels.
|
||||
|
||||
quadrant_offsets = [(0, 0), (0, 2), (2, 0), (2, 2)]
|
||||
n = 9 # 3×3 quadrant
|
||||
|
||||
# Accumulate sum and sum-of-squares for each quadrant using a simple nested
|
||||
# index loop over the 3×3 window positions.
|
||||
quad_sum = np.zeros((4, rows, cols), dtype=np.float64)
|
||||
quad_sum2 = np.zeros((4, rows, cols), dtype=np.float64)
|
||||
|
||||
for qi, (dr0, dc0) in enumerate(quadrant_offsets):
|
||||
for drow in range(3):
|
||||
for dcol in range(3):
|
||||
patch = padded[dr0 + drow: dr0 + drow + rows,
|
||||
dc0 + dcol: dc0 + dcol + cols]
|
||||
quad_sum[qi] += patch
|
||||
quad_sum2[qi] += patch * patch
|
||||
|
||||
quad_mean = quad_sum / n
|
||||
# var = E[x^2] - (E[x])^2, clamped to 0 to avoid floating-point negatives.
|
||||
quad_var = np.maximum(quad_sum2 / n - quad_mean * quad_mean, 0.0)
|
||||
|
||||
# Select the quadrant index with minimum variance for each pixel.
|
||||
best_qi = np.argmin(quad_var, axis=0) # shape (rows, cols)
|
||||
|
||||
# Gather the mean from the winning quadrant.
|
||||
result = np.choose(best_qi, quad_mean)
|
||||
return result.astype(np.float64)
|
||||
|
||||
|
||||
@register_node(display_name="Kuwahara Filter")
|
||||
class KuwaharaFilter:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"iterations": ("INT", {"default": 1, "min": 1, "max": 20}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'filtered'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Edge-preserving smoothing using Kuwahara's minimum-variance quadrant method. "
|
||||
"Unlike Gaussian blur, sharp boundaries are preserved. "
|
||||
"Equivalent to Gwyddion's Kuwahara filter."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, iterations: int) -> tuple:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
iterations = max(1, int(iterations))
|
||||
for _ in range(iterations):
|
||||
data = _kuwahara_pass(data)
|
||||
return (field.replace(data=data),)
|
||||
64
backend/nodes/local_contrast.py
Normal file
64
backend/nodes/local_contrast.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
@register_node(display_name="Local Contrast")
|
||||
class LocalContrast:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"kernel_size": ("INT", {"default": 10, "min": 2, "max": 100}),
|
||||
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'result'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Expand the local dynamic range at each pixel. "
|
||||
"Reveals fine surface features that are hidden by global contrast range. "
|
||||
"Equivalent to Gwyddion local_contrast.c."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, kernel_size: int, weight: float) -> tuple:
|
||||
from scipy.ndimage import minimum_filter, maximum_filter
|
||||
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
kernel_size = max(2, int(kernel_size))
|
||||
weight = float(np.clip(weight, 0.0, 1.0))
|
||||
|
||||
local_min = minimum_filter(data, size=kernel_size, mode="reflect")
|
||||
local_max = maximum_filter(data, size=kernel_size, mode="reflect")
|
||||
|
||||
global_min = float(data.min())
|
||||
global_max = float(data.max())
|
||||
|
||||
local_range = local_max - local_min
|
||||
eps = np.finfo(np.float64).eps * max(abs(global_max), abs(global_min), 1.0)
|
||||
|
||||
# Remap each pixel from its local range to the global range.
|
||||
# Where local_range is near zero, the pixel is already flat: leave it
|
||||
# unchanged (enhancement factor = 1).
|
||||
safe_range = np.where(local_range > eps, local_range, 1.0)
|
||||
global_span = global_max - global_min
|
||||
if global_span <= eps:
|
||||
# Uniform field – nothing to enhance.
|
||||
return (field.replace(data=data.copy()),)
|
||||
|
||||
enhancement_factor = global_span / safe_range
|
||||
# Locally enhanced value: remap v from [local_min, local_max] → [global_min, global_max]
|
||||
v_enhanced = global_min + enhancement_factor * (data - local_min)
|
||||
|
||||
# Blend between original and enhanced.
|
||||
result = (1.0 - weight) * data + weight * v_enhanced
|
||||
|
||||
return (field.replace(data=result),)
|
||||
122
backend/nodes/spot_removal.py
Normal file
122
backend/nodes/spot_removal.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
@register_node(display_name="Spot Removal")
|
||||
class SpotRemoval:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": (["laplace", "mean", "zero"], {"default": "laplace"}),
|
||||
"max_iter": ("INT", {"default": 100, "min": 1, "max": 2000, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("IMAGE", {"label": "defects"}),
|
||||
},
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'result'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Fill defect pixels (hot pixels, dropouts, scan artifacts) by interpolation. "
|
||||
"The mask defines defect locations. Laplace method solves the 2D Laplace equation "
|
||||
"for smooth inpainting. Equivalent to Gwyddion spotremove.c."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
method: str,
|
||||
max_iter: int,
|
||||
mask: np.ndarray | None = None,
|
||||
) -> tuple:
|
||||
if mask is None:
|
||||
return (field,)
|
||||
|
||||
mask_array = np.asarray(mask)
|
||||
# Reshape mask to match field shape if it has extra dimensions (e.g. HxWx1)
|
||||
if mask_array.ndim == 3:
|
||||
mask_array = mask_array[:, :, 0]
|
||||
if mask_array.shape != field.data.shape:
|
||||
raise ValueError(
|
||||
f"Mask shape {mask_array.shape} does not match field shape {field.data.shape}."
|
||||
)
|
||||
|
||||
defect = mask_array > 0
|
||||
|
||||
if not np.any(defect):
|
||||
return (field,)
|
||||
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
|
||||
if method == "zero":
|
||||
result = data.copy()
|
||||
result[defect] = 0.0
|
||||
return (field.replace(data=result),)
|
||||
|
||||
if method == "mean":
|
||||
result = _mean_fill(data, defect)
|
||||
return (field.replace(data=result),)
|
||||
|
||||
# method == "laplace"
|
||||
result = _laplace_fill(data, defect, int(max_iter))
|
||||
return (field.replace(data=result),)
|
||||
|
||||
|
||||
def _mean_fill(data: np.ndarray, defect: np.ndarray) -> np.ndarray:
|
||||
"""Fill defect pixels with the mean of non-defect neighbours in a 3x3 window."""
|
||||
result = data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
# Global fallback: mean of all non-defect pixels
|
||||
non_defect_vals = data[~defect]
|
||||
global_mean = float(non_defect_vals.mean()) if non_defect_vals.size > 0 else 0.0
|
||||
|
||||
defect_coords = np.argwhere(defect)
|
||||
for y, x in defect_coords:
|
||||
y0 = max(y - 1, 0)
|
||||
y1 = min(y + 2, yres)
|
||||
x0 = max(x - 1, 0)
|
||||
x1 = min(x + 2, xres)
|
||||
|
||||
neighbourhood_data = data[y0:y1, x0:x1]
|
||||
neighbourhood_defect = defect[y0:y1, x0:x1]
|
||||
good = neighbourhood_data[~neighbourhood_defect]
|
||||
|
||||
if good.size > 0:
|
||||
result[y, x] = float(good.mean())
|
||||
else:
|
||||
result[y, x] = global_mean
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _laplace_fill(data: np.ndarray, defect: np.ndarray, max_iter: int) -> np.ndarray:
|
||||
"""Iterative Laplace solver: set defect pixels to neighbour average each iteration."""
|
||||
non_defect_vals = data[~defect]
|
||||
init_val = float(non_defect_vals.mean()) if non_defect_vals.size > 0 else 0.0
|
||||
|
||||
result = data.copy()
|
||||
result[defect] = init_val
|
||||
|
||||
for _ in range(max_iter):
|
||||
# Compute neighbour averages via rolled arrays
|
||||
neighbour_sum = (
|
||||
np.roll(result, -1, axis=0)
|
||||
+ np.roll(result, 1, axis=0)
|
||||
+ np.roll(result, -1, axis=1)
|
||||
+ np.roll(result, 1, axis=1)
|
||||
)
|
||||
new_vals = neighbour_sum / 4.0
|
||||
result[defect] = new_vals[defect]
|
||||
|
||||
return result
|
||||
52
backend/nodes/template_match.py
Normal file
52
backend/nodes/template_match.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
@register_node(display_name="Template Match")
|
||||
class TemplateMatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("DATA_FIELD",),
|
||||
"template": ("DATA_FIELD",),
|
||||
"threshold": (
|
||||
"FLOAT",
|
||||
{"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'score'),
|
||||
('IMAGE', 'detections'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Find a template pattern within a larger data field using normalised cross-correlation. "
|
||||
"The score output shows match quality (1 = perfect match). Detections mask marks positions "
|
||||
"above the threshold. Equivalent to Gwyddion maskcor.c."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
image: DataField,
|
||||
template: DataField,
|
||||
threshold: float,
|
||||
) -> tuple:
|
||||
from skimage.feature import match_template
|
||||
|
||||
score = match_template(image.data, template.data, pad_input=True)
|
||||
|
||||
# Clip to [0, 1] for display (match_template returns values in [-1, 1])
|
||||
score_clipped = np.clip(score, 0.0, 1.0)
|
||||
|
||||
detections = (score_clipped >= float(threshold)).astype(np.uint8) * 255
|
||||
|
||||
score_field = image.replace(data=score_clipped)
|
||||
return (score_field, detections)
|
||||
74
backend/nodes/wavelet_denoise.py
Normal file
74
backend/nodes/wavelet_denoise.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
@register_node(display_name="Wavelet Denoise")
|
||||
class WaveletDenoise:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"wavelet": (
|
||||
["db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"],
|
||||
{"default": "db4"},
|
||||
),
|
||||
"method": (["BayesShrink", "VisuShrink"], {"default": "BayesShrink"}),
|
||||
"sigma": (
|
||||
"FLOAT",
|
||||
{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01},
|
||||
),
|
||||
"mode": (["soft", "hard"], {"default": "soft"}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'denoised'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Denoise using wavelet coefficient thresholding. BayesShrink adapts the threshold "
|
||||
"per sub-band; VisuShrink uses a global threshold. Equivalent to applying DWT from "
|
||||
"Gwyddion dwt.c with coefficient thresholding."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
wavelet: str,
|
||||
method: str,
|
||||
sigma: float,
|
||||
mode: str,
|
||||
) -> tuple:
|
||||
from skimage.restoration import denoise_wavelet
|
||||
|
||||
d = field.data
|
||||
dmin = float(d.min())
|
||||
drange = float(np.ptp(d))
|
||||
|
||||
if drange == 0:
|
||||
return (field,)
|
||||
|
||||
norm = (d - dmin) / drange
|
||||
sigma_val = float(sigma) if sigma > 0 else None
|
||||
|
||||
# `mode` is a Python builtin name; use threshold_mode locally to avoid shadowing
|
||||
threshold_mode = mode
|
||||
|
||||
denoised_norm = denoise_wavelet(
|
||||
norm,
|
||||
wavelet=wavelet,
|
||||
method=method,
|
||||
mode=threshold_mode,
|
||||
sigma=sigma_val,
|
||||
rescale_sigma=True,
|
||||
channel_axis=None,
|
||||
)
|
||||
|
||||
result = denoised_norm * drange + dmin
|
||||
return (field.replace(data=result),)
|
||||
@@ -17,6 +17,7 @@ dependencies = [
|
||||
"nanonispy>=1.1",
|
||||
"numpy>=1.26,<3",
|
||||
"pillow>=10,<12",
|
||||
"pywavelets>=1.8.0",
|
||||
"scikit-image>=0.22,<1",
|
||||
"scipy>=1.12,<2",
|
||||
]
|
||||
|
||||
89
tests/node_tests/cross_correlate.py
Normal file
89
tests/node_tests/cross_correlate.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_cross_correlate_same_field_peak_at_center():
|
||||
"""Correlating a field with itself in 'same' mode peaks at the centre."""
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=True)
|
||||
|
||||
peak_y, peak_x = np.unravel_index(np.argmax(result.data), result.data.shape)
|
||||
cy, cx = result.data.shape[0] // 2, result.data.shape[1] // 2
|
||||
# Peak should be within a few pixels of centre
|
||||
assert abs(peak_y - cy) <= 2
|
||||
assert abs(peak_x - cx) <= 2
|
||||
|
||||
|
||||
def test_cross_correlate_same_mode_shape_equals_a():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
a = make_field(data=rng.standard_normal((32, 48)))
|
||||
b = make_field(data=rng.standard_normal((32, 48)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(a, b, mode="same", normalize=True)
|
||||
assert result.data.shape == a.data.shape
|
||||
|
||||
|
||||
def test_cross_correlate_full_mode_shape():
|
||||
"""Full mode output shape should be Na+Nb-1 × Ma+Mb-1."""
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
a = make_field(data=rng.standard_normal((20, 30)))
|
||||
b = make_field(data=rng.standard_normal((20, 30)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(a, b, mode="full", normalize=True)
|
||||
assert result.data.shape == (20 + 20 - 1, 30 + 30 - 1)
|
||||
|
||||
|
||||
def test_cross_correlate_normalized_peak_is_one():
|
||||
"""Self-correlation normalised should give peak = 1."""
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=True)
|
||||
assert result.data.max() == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_cross_correlate_unnormalized_runs():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(4)
|
||||
data = rng.standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=False)
|
||||
assert result.data.shape == (16, 16)
|
||||
|
||||
|
||||
def test_cross_correlate_valid_mode():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(5)
|
||||
a = make_field(data=rng.standard_normal((16, 16)))
|
||||
b = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(a, b, mode="valid", normalize=True)
|
||||
# Valid mode output: (16-8+1, 16-8+1) = (9, 9)
|
||||
assert result.data.shape == (9, 9)
|
||||
|
||||
|
||||
def test_cross_correlate_preserves_metadata_same_mode():
|
||||
from backend.nodes.cross_correlate import CrossCorrelate
|
||||
|
||||
rng = np.random.default_rng(6)
|
||||
field = make_field(data=rng.standard_normal((16, 16)))
|
||||
node = CrossCorrelate()
|
||||
result, = node.process(field, field, mode="same", normalize=True)
|
||||
assert result.xreal == field.xreal
|
||||
assert result.yreal == field.yreal
|
||||
75
tests/node_tests/entropy.py
Normal file
75
tests/node_tests/entropy.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_entropy_uniform_field_low():
|
||||
"""A field with a single unique value has zero entropy."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
node = Entropy()
|
||||
field = make_field(data=np.full((32, 32), 3.14))
|
||||
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||
# All values fall in one bin → p=1 → H = 0
|
||||
assert h == pytest.approx(0.0, abs=1e-10)
|
||||
assert h_norm == pytest.approx(0.0, abs=1e-10)
|
||||
|
||||
|
||||
def test_entropy_random_field_positive():
|
||||
"""A random field should have positive entropy."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
field = make_field(data=rng.standard_normal((64, 64)))
|
||||
node = Entropy()
|
||||
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||
assert h > 0.0
|
||||
assert 0.0 < h_norm <= 1.0
|
||||
|
||||
|
||||
def test_entropy_normalised_leq_one():
|
||||
"""Normalised entropy should never exceed 1."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
field = make_field(data=rng.uniform(0, 1, (64, 64)))
|
||||
node = Entropy()
|
||||
_, h_norm = node.process(field, mode="height values", n_bins=64)
|
||||
assert h_norm <= 1.0 + 1e-12
|
||||
|
||||
|
||||
def test_entropy_slope_mode():
|
||||
"""Slope mode should work and return valid entropy values."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = Entropy()
|
||||
h, h_norm = node.process(field, mode="slope magnitude", n_bins=128)
|
||||
assert h > 0.0
|
||||
assert 0.0 <= h_norm <= 1.0
|
||||
|
||||
|
||||
def test_entropy_more_uniform_is_higher():
|
||||
"""Uniformly distributed values have higher entropy than a spiked distribution."""
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
rng = np.random.default_rng(4)
|
||||
uniform = rng.uniform(0, 1, (64, 64))
|
||||
spiked = np.zeros((64, 64))
|
||||
spiked[0, 0] = 1.0
|
||||
|
||||
node = Entropy()
|
||||
h_uniform, _ = node.process(make_field(data=uniform), mode="height values", n_bins=64)
|
||||
h_spiked, _ = node.process(make_field(data=spiked), mode="height values", n_bins=64)
|
||||
assert h_uniform > h_spiked
|
||||
|
||||
|
||||
def test_entropy_returns_floats():
|
||||
from backend.nodes.entropy import Entropy
|
||||
|
||||
field = make_field()
|
||||
node = Entropy()
|
||||
h, h_norm = node.process(field, mode="height values", n_bins=256)
|
||||
assert isinstance(h, float)
|
||||
assert isinstance(h_norm, float)
|
||||
95
tests/node_tests/filter_custom.py
Normal file
95
tests/node_tests/filter_custom.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_custom_convolution_identity_kernel():
|
||||
"""[[1]] with normalize=True (abs_sum=1) should return input unchanged."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(0).standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_custom_convolution_uniform_kernel_normalized():
|
||||
"""An all-ones kernel with normalize=True is a box filter (mean filter)."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(1).standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
# 3x3 all-ones kernel, normalized → each pixel becomes mean of its neighbourhood
|
||||
kernel = "1 1 1\n1 1 1\n1 1 1"
|
||||
result, = node.process(field, kernel=kernel, normalize=True, boundary="reflect")
|
||||
# Output std should be less than input std (smoothing)
|
||||
assert result.data.std() < data.std()
|
||||
|
||||
|
||||
def test_custom_convolution_sharpen_increases_variation():
|
||||
"""A sharpening kernel should increase local variation on a smooth field."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
# Smooth ramp field — very low frequency content
|
||||
data = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||
field = make_field(data=data)
|
||||
sharpen = "0 -1 0\n-1 5 -1\n0 -1 0"
|
||||
result, = node.process(field, kernel=sharpen, normalize=False, boundary="reflect")
|
||||
# Sharpening without normalisation keeps the ramp intact plus adds edges
|
||||
# The std of the sharpened field should differ from input
|
||||
assert result.data.std() != pytest.approx(data.std(), rel=0.0)
|
||||
|
||||
|
||||
def test_custom_convolution_shape_preserved():
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
field = make_field(shape=(48, 64))
|
||||
result, = node.process(field, kernel="0 1 0\n1 1 1\n0 1 0", normalize=True, boundary="reflect")
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_custom_convolution_invalid_kernel_fallback():
|
||||
"""An invalid kernel string should return the input field unchanged."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(2).standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel="", normalize=True, boundary="reflect")
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_custom_convolution_ragged_kernel_fallback():
|
||||
"""A ragged (non-rectangular) kernel should be rejected gracefully."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
data = np.random.default_rng(3).standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel="1 2\n1 2 3", normalize=True, boundary="reflect")
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_custom_convolution_boundary_modes():
|
||||
"""All boundary modes should produce valid output without error."""
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
field = make_field()
|
||||
for mode in ("reflect", "nearest", "wrap"):
|
||||
result, = node.process(field, kernel="1 1 1\n1 1 1\n1 1 1", normalize=True, boundary=mode)
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_custom_convolution_preserves_metadata():
|
||||
from backend.nodes.filter_custom import CustomConvolution
|
||||
|
||||
node = CustomConvolution()
|
||||
field = make_field()
|
||||
result, = node.process(field, kernel="1", normalize=False, boundary="reflect")
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
75
tests/node_tests/filter_kuwahara.py
Normal file
75
tests/node_tests/filter_kuwahara.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_kuwahara_shape_preserved():
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(shape=(48, 64))
|
||||
result, = node.process(field, iterations=1)
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_kuwahara_flat_field_unchanged():
|
||||
"""A constant field should pass through the Kuwahara filter unchanged."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=np.full((32, 32), 7.5))
|
||||
result, = node.process(field, iterations=1)
|
||||
assert np.allclose(result.data, 7.5)
|
||||
|
||||
|
||||
def test_kuwahara_reduces_noise():
|
||||
"""Applying the filter to a noisy field should reduce standard deviation."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
noisy = rng.standard_normal((64, 64))
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=noisy)
|
||||
result, = node.process(field, iterations=1)
|
||||
assert result.data.std() < noisy.std()
|
||||
|
||||
|
||||
def test_kuwahara_preserves_step_edge():
|
||||
"""The Kuwahara filter should preserve a sharp step edge better than a blur."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
# Left half = 0, right half = 1
|
||||
data = np.zeros((32, 64))
|
||||
data[:, 32:] = 1.0
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, iterations=1)
|
||||
|
||||
# The edge column should have a large jump (edge preserved)
|
||||
col_before = result.data[:, 30].mean()
|
||||
col_after = result.data[:, 34].mean()
|
||||
assert col_after - col_before > 0.5
|
||||
|
||||
|
||||
def test_kuwahara_multiple_iterations():
|
||||
"""Running multiple iterations should further reduce noise."""
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
noisy = rng.standard_normal((32, 32))
|
||||
node = KuwaharaFilter()
|
||||
field = make_field(data=noisy)
|
||||
result1, = node.process(field, iterations=1)
|
||||
result3, = node.process(field, iterations=3)
|
||||
assert result3.data.std() <= result1.data.std()
|
||||
|
||||
|
||||
def test_kuwahara_preserves_metadata():
|
||||
from backend.nodes.filter_kuwahara import KuwaharaFilter
|
||||
|
||||
node = KuwaharaFilter()
|
||||
field = make_field()
|
||||
result, = node.process(field, iterations=1)
|
||||
assert result.xreal == field.xreal
|
||||
assert result.yreal == field.yreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
67
tests/node_tests/local_contrast.py
Normal file
67
tests/node_tests/local_contrast.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_local_contrast_shape_preserved():
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field(shape=(48, 64))
|
||||
result, = node.process(field, kernel_size=10, weight=0.5)
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_local_contrast_weight_zero_unchanged():
|
||||
"""weight=0 blends 100% original → result equals input."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
data = np.random.default_rng(0).standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
result, = node.process(field, kernel_size=5, weight=0.0)
|
||||
assert np.allclose(result.data, data)
|
||||
|
||||
|
||||
def test_local_contrast_uniform_field_unchanged():
|
||||
"""A flat field has nothing to enhance; it should be returned as-is."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field(data=np.full((32, 32), 2.0))
|
||||
result, = node.process(field, kernel_size=5, weight=1.0)
|
||||
assert np.allclose(result.data, 2.0)
|
||||
|
||||
|
||||
def test_local_contrast_increases_dynamic_range():
|
||||
"""Weight=1 full enhancement should not compress global range beyond input."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
data = rng.standard_normal((64, 64))
|
||||
field = make_field(data=data)
|
||||
node = LocalContrast()
|
||||
result, = node.process(field, kernel_size=8, weight=1.0)
|
||||
# Global min/max should be preserved (by construction of the algorithm)
|
||||
assert np.isclose(result.data.min(), data.min(), atol=1e-6)
|
||||
assert np.isclose(result.data.max(), data.max(), atol=1e-6)
|
||||
|
||||
|
||||
def test_local_contrast_preserves_metadata():
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field()
|
||||
result, = node.process(field, kernel_size=10, weight=0.5)
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
|
||||
|
||||
def test_local_contrast_weight_clipped():
|
||||
"""Values outside [0,1] should be clipped without error."""
|
||||
from backend.nodes.local_contrast import LocalContrast
|
||||
|
||||
node = LocalContrast()
|
||||
field = make_field()
|
||||
result, = node.process(field, kernel_size=5, weight=2.0)
|
||||
assert result.data.shape == field.data.shape
|
||||
@@ -156,10 +156,16 @@ def test_save_generic():
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# LINE as plot image (PNG / TIFF)
|
||||
node.save(filename="line_plot_png", directory_path=tmpdir, format="PNG", value=line)
|
||||
assert Path(tmpdir, "line_plot_png.png").exists()
|
||||
node.save(filename="line_plot_tiff", directory_path=tmpdir, format="TIFF", value=line)
|
||||
assert Path(tmpdir, "line_plot_tiff.tiff").exists()
|
||||
|
||||
# Unsupported LINE format
|
||||
try:
|
||||
node.save(filename="line_bad", directory_path=tmpdir, format="TIFF", value=line)
|
||||
assert False, "Expected ValueError for LINE + TIFF"
|
||||
node.save(filename="line_bad", directory_path=tmpdir, format="OBJ", value=line)
|
||||
assert False, "Expected ValueError for LINE + OBJ"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
110
tests/node_tests/spot_removal.py
Normal file
110
tests/node_tests/spot_removal.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def _make_mask(shape, defect_positions):
|
||||
"""Create a uint8 mask array with 255 at given (row, col) positions."""
|
||||
mask = np.zeros(shape, dtype=np.uint8)
|
||||
for r, c in defect_positions:
|
||||
mask[r, c] = 255
|
||||
return mask
|
||||
|
||||
|
||||
def test_spot_removal_no_mask_returns_field_unchanged():
|
||||
"""Without a mask input the field should be returned as-is."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
field = make_field()
|
||||
result, = node.process(field, method="laplace", max_iter=50)
|
||||
# Should be the identical object (short-circuit path)
|
||||
assert result is field
|
||||
|
||||
|
||||
def test_spot_removal_zero_fill():
|
||||
"""method='zero' sets defect pixels to exactly 0."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.ones((16, 16)) * 5.0
|
||||
field = make_field(data=data)
|
||||
mask = _make_mask((16, 16), [(4, 4), (8, 8)])
|
||||
result, = node.process(field, method="zero", max_iter=1, mask=mask)
|
||||
assert result.data[4, 4] == pytest.approx(0.0)
|
||||
assert result.data[8, 8] == pytest.approx(0.0)
|
||||
# Non-defect pixels should stay 5.0
|
||||
assert result.data[0, 0] == pytest.approx(5.0)
|
||||
|
||||
|
||||
def test_spot_removal_mean_fill_surrounded_by_constant():
|
||||
"""On a constant field, mean fill should give back the constant."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.full((16, 16), 3.0)
|
||||
field = make_field(data=data)
|
||||
mask = _make_mask((16, 16), [(7, 7)])
|
||||
result, = node.process(field, method="mean", max_iter=1, mask=mask)
|
||||
assert result.data[7, 7] == pytest.approx(3.0, abs=1e-10)
|
||||
|
||||
|
||||
def test_spot_removal_laplace_fill_surrounded_by_constant():
|
||||
"""Laplace fill on a constant field should recover the constant at the defect."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.full((16, 16), 2.5)
|
||||
field = make_field(data=data)
|
||||
mask = _make_mask((16, 16), [(8, 8)])
|
||||
result, = node.process(field, method="laplace", max_iter=200, mask=mask)
|
||||
assert result.data[8, 8] == pytest.approx(2.5, abs=1e-3)
|
||||
|
||||
|
||||
def test_spot_removal_laplace_smooth_interpolation():
|
||||
"""Laplace should interpolate between boundary values smoothly."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
# Left half = 0, right half = 10; single defect in the middle
|
||||
data = np.zeros((16, 16))
|
||||
data[:, 8:] = 10.0
|
||||
field = make_field(data=data)
|
||||
# Defect at the boundary column
|
||||
mask = _make_mask((16, 16), [(8, 7)])
|
||||
result, = node.process(field, method="laplace", max_iter=500, mask=mask)
|
||||
# The filled value should be between 0 and 10
|
||||
filled = result.data[8, 7]
|
||||
assert 0.0 <= filled <= 10.0
|
||||
|
||||
|
||||
def test_spot_removal_shape_preserved():
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
field = make_field(shape=(48, 64))
|
||||
mask = _make_mask((48, 64), [(10, 20)])
|
||||
result, = node.process(field, method="mean", max_iter=10, mask=mask)
|
||||
assert result.data.shape == (48, 64)
|
||||
|
||||
|
||||
def test_spot_removal_mask_shape_mismatch_raises():
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
field = make_field(shape=(16, 16))
|
||||
bad_mask = np.ones((32, 32), dtype=np.uint8)
|
||||
with pytest.raises(ValueError, match="Mask shape"):
|
||||
node.process(field, method="zero", max_iter=1, mask=bad_mask)
|
||||
|
||||
|
||||
def test_spot_removal_empty_mask_unchanged():
|
||||
"""An all-zero mask means no defects — field returned unchanged."""
|
||||
from backend.nodes.spot_removal import SpotRemoval
|
||||
|
||||
node = SpotRemoval()
|
||||
data = np.random.default_rng(0).standard_normal((16, 16))
|
||||
field = make_field(data=data)
|
||||
mask = np.zeros((16, 16), dtype=np.uint8)
|
||||
result, = node.process(field, method="laplace", max_iter=50, mask=mask)
|
||||
assert result is field
|
||||
93
tests/node_tests/template_match.py
Normal file
93
tests/node_tests/template_match.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_template_match_exact_match_score_one():
|
||||
"""When template equals the image, the peak score should be 1."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
data = rng.standard_normal((32, 32))
|
||||
image_field = make_field(data=data)
|
||||
# Template is the full image → perfect correlation everywhere → peak = 1
|
||||
template_field = make_field(data=data)
|
||||
node = TemplateMatch()
|
||||
score_field, detections = node.process(image_field, template_field, threshold=0.9)
|
||||
assert score_field.data.max() == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_template_match_output_shape_matches_image():
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
score_field, detections = node.process(image_field, template_field, threshold=0.5)
|
||||
assert score_field.data.shape == image_field.data.shape
|
||||
assert detections.shape == image_field.data.shape
|
||||
|
||||
|
||||
def test_template_match_score_in_range():
|
||||
"""Score values should be clipped to [0, 1]."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((6, 6)))
|
||||
node = TemplateMatch()
|
||||
score_field, _ = node.process(image_field, template_field, threshold=0.5)
|
||||
assert score_field.data.min() >= 0.0 - 1e-10
|
||||
assert score_field.data.max() <= 1.0 + 1e-10
|
||||
|
||||
|
||||
def test_template_match_detections_binary():
|
||||
"""Detection mask values should be 0 or 255 only."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
_, detections = node.process(image_field, template_field, threshold=0.5)
|
||||
unique_values = set(np.unique(detections))
|
||||
assert unique_values <= {0, 255}
|
||||
|
||||
|
||||
def test_template_match_threshold_zero_all_detected():
|
||||
"""threshold=0 should mark all pixels as detections (score always >= 0)."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(4)
|
||||
image_field = make_field(data=rng.standard_normal((16, 16)))
|
||||
template_field = make_field(data=rng.standard_normal((4, 4)))
|
||||
node = TemplateMatch()
|
||||
_, detections = node.process(image_field, template_field, threshold=0.0)
|
||||
assert np.all(detections == 255)
|
||||
|
||||
|
||||
def test_template_match_threshold_one_sparse_detections():
|
||||
"""threshold=1.0 should detect very few (or no) positions."""
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(5)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
_, detections = node.process(image_field, template_field, threshold=1.0)
|
||||
# At threshold=1.0, only perfect matches count (rare for random data)
|
||||
detected_count = int((detections == 255).sum())
|
||||
assert detected_count < 10 # very few or none
|
||||
|
||||
|
||||
def test_template_match_preserves_metadata():
|
||||
from backend.nodes.template_match import TemplateMatch
|
||||
|
||||
rng = np.random.default_rng(6)
|
||||
image_field = make_field(data=rng.standard_normal((32, 32)))
|
||||
template_field = make_field(data=rng.standard_normal((8, 8)))
|
||||
node = TemplateMatch()
|
||||
score_field, _ = node.process(image_field, template_field, threshold=0.5)
|
||||
assert score_field.xreal == image_field.xreal
|
||||
assert score_field.yreal == image_field.yreal
|
||||
85
tests/node_tests/wavelet_denoise.py
Normal file
85
tests/node_tests/wavelet_denoise.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tests.node_tests._shared import make_field
|
||||
|
||||
|
||||
def test_wavelet_denoise_shape_preserved():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field(shape=(64, 64))
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == (64, 64)
|
||||
|
||||
|
||||
def test_wavelet_denoise_reduces_noise():
|
||||
"""Denoising noisy data should reduce standard deviation."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(0)
|
||||
clean = np.outer(np.linspace(0, 1, 32), np.linspace(0, 1, 32))
|
||||
noisy = clean + rng.normal(0, 0.1, clean.shape)
|
||||
field = make_field(data=noisy)
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
# Denoised should be closer to clean than the noisy input
|
||||
denoised_err = np.std(result.data - clean)
|
||||
noisy_err = np.std(noisy - clean)
|
||||
assert denoised_err < noisy_err
|
||||
|
||||
|
||||
def test_wavelet_denoise_uniform_field_unchanged():
|
||||
"""A flat field (no variation) is returned as-is."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field(data=np.full((32, 32), 5.0))
|
||||
result, = node.process(field, wavelet="db1", method="VisuShrink", sigma=0.0, mode="hard")
|
||||
# The short-circuit path returns the original field object
|
||||
assert result is field
|
||||
|
||||
|
||||
def test_wavelet_denoise_preserves_range():
|
||||
"""Output values should stay within the input data range (approx)."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(1)
|
||||
data = rng.standard_normal((32, 32))
|
||||
field = make_field(data=data)
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
# The normalisation ensures output is within [data.min(), data.max()]
|
||||
assert result.data.min() >= data.min() - 1e-10
|
||||
assert result.data.max() <= data.max() + 1e-10
|
||||
|
||||
|
||||
def test_wavelet_denoise_all_wavelets():
|
||||
"""All supported wavelets should run without error."""
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(2)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = WaveletDenoise()
|
||||
for wavelet in ("db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"):
|
||||
result, = node.process(field, wavelet=wavelet, method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_wavelet_denoise_visu_shrink():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
rng = np.random.default_rng(3)
|
||||
field = make_field(data=rng.standard_normal((32, 32)))
|
||||
node = WaveletDenoise()
|
||||
result, = node.process(field, wavelet="db4", method="VisuShrink", sigma=0.0, mode="soft")
|
||||
assert result.data.shape == field.data.shape
|
||||
|
||||
|
||||
def test_wavelet_denoise_preserves_metadata():
|
||||
from backend.nodes.wavelet_denoise import WaveletDenoise
|
||||
|
||||
node = WaveletDenoise()
|
||||
field = make_field()
|
||||
result, = node.process(field, wavelet="db4", method="BayesShrink", sigma=0.0, mode="soft")
|
||||
assert result.xreal == field.xreal
|
||||
assert result.si_unit_z == field.si_unit_z
|
||||
Reference in New Issue
Block a user