performance buff

This commit is contained in:
2026-03-25 19:26:20 -07:00
parent ca59bac478
commit 61b68c142b
5 changed files with 113 additions and 68 deletions

View File

@@ -10,6 +10,7 @@ Gwyddion equivalents:
"""
from __future__ import annotations
from functools import lru_cache
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@@ -38,7 +39,7 @@ class GaussianFilter:
def process(self, field: DataField, sigma: float) -> tuple:
from scipy.ndimage import gaussian_filter
data = gaussian_filter(field.data.copy(), sigma=float(sigma))
data = gaussian_filter(field.data, sigma=float(sigma))
return (field.replace(data=data),)
@@ -66,7 +67,7 @@ class MedianFilter:
def process(self, field: DataField, size: int) -> tuple:
from scipy.ndimage import median_filter
size = max(1, int(size))
data = median_filter(field.data.copy(), size=size)
data = median_filter(field.data, size=size)
return (field.replace(data=data),)
@@ -97,7 +98,7 @@ class EdgeDetect:
def process(self, field: DataField, method: str, sigma: float) -> tuple:
from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace
data = field.data.copy()
data = field.data
if method == "sobel":
sx = sobel(data, axis=1)
@@ -158,6 +159,45 @@ def _build_1d_transfer(n: int, filter_type: str, cutoff: float,
return H
@lru_cache(maxsize=64)
def _cached_1d_transfer(n: int, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> np.ndarray:
transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
transfer.setflags(write=False)
return transfer
@lru_cache(maxsize=32)
def _fft_radius_grid(yres: int, xres: int) -> np.ndarray:
fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0
fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0
radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0)
np.clip(radius, 0.0, 1.0, out=radius)
radius.setflags(write=False)
return radius
@lru_cache(maxsize=128)
def _cached_2d_transfer(yres: int, xres: int, filter_type: str,
cutoff: float, cutoff_high: float, order: int) -> np.ndarray:
radius = _fft_radius_grid(yres, xres)
if filter_type == "lowpass":
transfer = _butterworth_lp(radius, cutoff, order)
elif filter_type == "highpass":
transfer = _butterworth_hp(radius, cutoff, order)
elif filter_type == "bandpass":
transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
elif filter_type == "notch":
band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order)
transfer = 1.0 - band
else:
transfer = np.ones_like(radius)
transfer.setflags(write=False)
return transfer
# ---------------------------------------------------------------------------
# FFTFilter1D — frequency-domain filtering of LINE profiles
# ---------------------------------------------------------------------------
@@ -206,7 +246,7 @@ class FFTFilter1D:
Z = np.fft.rfft(z)
# Build and apply transfer function
H = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
Z *= H
# Inverse FFT
@@ -258,45 +298,24 @@ class FFTFilter2D:
def process(self, field: DataField, filter_type: str, cutoff: float,
cutoff_high: float, order: int) -> tuple:
data = field.data.copy()
data = field.data
yres, xres = data.shape
# Subtract mean to avoid DC leakage artefacts
mean_val = data.mean()
data -= mean_val
# Subtract mean to avoid DC leakage artefacts.
mean_val = float(data.mean())
centered = data - mean_val
# Forward 2D FFT
F = np.fft.fft2(data)
F = np.fft.fftshift(F)
# Build radial frequency grid normalised to [0, 1] (1 = Nyquist)
fy = np.fft.fftshift(np.fft.fftfreq(yres)) # range [-0.5, 0.5)
fx = np.fft.fftshift(np.fft.fftfreq(xres))
FX, FY = np.meshgrid(fx, fy)
# Normalise so that corner = 1 in each axis independently,
# then take Euclidean norm; max radial value = 1.0 at Nyquist.
R = np.sqrt((FX / 0.5) ** 2 + (FY / 0.5) ** 2)
R = np.clip(R / R.max(), 0, 1) if R.max() > 0 else R
# Build transfer function
if filter_type == "lowpass":
H = _butterworth_lp(R, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(R, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(R)
# Apply filter
F *= H
# Inverse FFT
F = np.fft.ifftshift(F)
result = np.fft.ifft2(F).real
# Real-valued FFT keeps only the unique half-plane and avoids shift copies.
spectrum = np.fft.rfft2(centered)
transfer = _cached_2d_transfer(
yres,
xres,
filter_type,
float(cutoff),
float(cutoff_high),
int(order),
)
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
# Restore DC
result += mean_val