From 61b68c142b8de64acba43a659764e28f93bef1ca Mon Sep 17 00:00:00 2001 From: matei jordache Date: Wed, 25 Mar 2026 19:26:20 -0700 Subject: [PATCH] performance buff --- backend/data_types.py | 20 ++++++-- backend/nodes/__init__.py | 2 +- backend/nodes/filters.py | 99 +++++++++++++++++++++++---------------- backend/nodes/mask.py | 58 ++++++++++++++--------- tests/test_nodes.py | 2 +- 5 files changed, 113 insertions(+), 68 deletions(-) diff --git a/backend/data_types.py b/backend/data_types.py index ec29b3f..e65d77a 100644 --- a/backend/data_types.py +++ b/backend/data_types.py @@ -13,6 +13,7 @@ DataField mirrors Gwyddion's GwyDataField structure: from __future__ import annotations from dataclasses import dataclass, field +from functools import lru_cache import numpy as np @@ -141,14 +142,21 @@ def datafield_to_uint8(df: DataField, colormap: str = "gray") -> np.ndarray: Normalize a DataField to a uint8 (H, W, 3) RGB array using matplotlib colormap. Returns shape (H, W, 3) uint8. """ - import matplotlib.cm as cm normalized = normalize_for_colormap( df.data, offset=df.display_offset, scale=df.display_scale, ) - cmap = cm.get_cmap(colormap) + if colormap == "gray": + grey = np.rint(normalized * 255.0).astype(np.uint8) + rgb = np.empty(grey.shape + (3,), dtype=np.uint8) + rgb[..., 0] = grey + rgb[..., 1] = grey + rgb[..., 2] = grey + return rgb + + cmap = _get_colormap(colormap) rgba = cmap(normalized) # (H, W, 4) float [0,1] rgb = (rgba[:, :, :3] * 255).astype(np.uint8) return rgb @@ -180,6 +188,12 @@ def encode_preview(arr: np.ndarray) -> str: img = Image.fromarray(arr) buf = io.BytesIO() - img.save(buf, format="PNG") + img.save(buf, format="PNG", compress_level=1, optimize=False) b64 = base64.b64encode(buf.getvalue()).decode() return f"data:image/png;base64,{b64}" + + +@lru_cache(maxsize=len(COLORMAPS)) +def _get_colormap(colormap: str): + import matplotlib.cm as cm + return cm.get_cmap(colormap) diff --git a/backend/nodes/__init__.py b/backend/nodes/__init__.py index b38760b..d2fc2ff 100644 --- a/backend/nodes/__init__.py +++ b/backend/nodes/__init__.py @@ -4,4 +4,4 @@ from . import io, filters, modify, level, analysis, mask, display try: from . import particle except ImportError: - from . import grains + from . import particless diff --git a/backend/nodes/filters.py b/backend/nodes/filters.py index 8a5f872..916e963 100644 --- a/backend/nodes/filters.py +++ b/backend/nodes/filters.py @@ -10,6 +10,7 @@ Gwyddion equivalents: """ from __future__ import annotations +from functools import lru_cache import numpy as np from backend.node_registry import register_node from backend.data_types import DataField @@ -38,7 +39,7 @@ class GaussianFilter: def process(self, field: DataField, sigma: float) -> tuple: from scipy.ndimage import gaussian_filter - data = gaussian_filter(field.data.copy(), sigma=float(sigma)) + data = gaussian_filter(field.data, sigma=float(sigma)) return (field.replace(data=data),) @@ -66,7 +67,7 @@ class MedianFilter: def process(self, field: DataField, size: int) -> tuple: from scipy.ndimage import median_filter size = max(1, int(size)) - data = median_filter(field.data.copy(), size=size) + data = median_filter(field.data, size=size) return (field.replace(data=data),) @@ -97,7 +98,7 @@ class EdgeDetect: def process(self, field: DataField, method: str, sigma: float) -> tuple: from scipy.ndimage import sobel, prewitt, gaussian_laplace, laplace - data = field.data.copy() + data = field.data if method == "sobel": sx = sobel(data, axis=1) @@ -158,6 +159,45 @@ def _build_1d_transfer(n: int, filter_type: str, cutoff: float, return H +@lru_cache(maxsize=64) +def _cached_1d_transfer(n: int, filter_type: str, cutoff: float, + cutoff_high: float, order: int) -> np.ndarray: + transfer = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order) + transfer.setflags(write=False) + return transfer + + +@lru_cache(maxsize=32) +def _fft_radius_grid(yres: int, xres: int) -> np.ndarray: + fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0 + fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0 + radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0) + np.clip(radius, 0.0, 1.0, out=radius) + radius.setflags(write=False) + return radius + + +@lru_cache(maxsize=128) +def _cached_2d_transfer(yres: int, xres: int, filter_type: str, + cutoff: float, cutoff_high: float, order: int) -> np.ndarray: + radius = _fft_radius_grid(yres, xres) + + if filter_type == "lowpass": + transfer = _butterworth_lp(radius, cutoff, order) + elif filter_type == "highpass": + transfer = _butterworth_hp(radius, cutoff, order) + elif filter_type == "bandpass": + transfer = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order) + elif filter_type == "notch": + band = _butterworth_hp(radius, cutoff, order) * _butterworth_lp(radius, cutoff_high, order) + transfer = 1.0 - band + else: + transfer = np.ones_like(radius) + + transfer.setflags(write=False) + return transfer + + # --------------------------------------------------------------------------- # FFTFilter1D — frequency-domain filtering of LINE profiles # --------------------------------------------------------------------------- @@ -206,7 +246,7 @@ class FFTFilter1D: Z = np.fft.rfft(z) # Build and apply transfer function - H = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order) + H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order)) Z *= H # Inverse FFT @@ -258,45 +298,24 @@ class FFTFilter2D: def process(self, field: DataField, filter_type: str, cutoff: float, cutoff_high: float, order: int) -> tuple: - data = field.data.copy() + data = field.data yres, xres = data.shape - # Subtract mean to avoid DC leakage artefacts - mean_val = data.mean() - data -= mean_val + # Subtract mean to avoid DC leakage artefacts. + mean_val = float(data.mean()) + centered = data - mean_val - # Forward 2D FFT - F = np.fft.fft2(data) - F = np.fft.fftshift(F) - - # Build radial frequency grid normalised to [0, 1] (1 = Nyquist) - fy = np.fft.fftshift(np.fft.fftfreq(yres)) # range [-0.5, 0.5) - fx = np.fft.fftshift(np.fft.fftfreq(xres)) - FX, FY = np.meshgrid(fx, fy) - # Normalise so that corner = 1 in each axis independently, - # then take Euclidean norm; max radial value = 1.0 at Nyquist. - R = np.sqrt((FX / 0.5) ** 2 + (FY / 0.5) ** 2) - R = np.clip(R / R.max(), 0, 1) if R.max() > 0 else R - - # Build transfer function - if filter_type == "lowpass": - H = _butterworth_lp(R, cutoff, order) - elif filter_type == "highpass": - H = _butterworth_hp(R, cutoff, order) - elif filter_type == "bandpass": - H = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order) - elif filter_type == "notch": - bp = _butterworth_hp(R, cutoff, order) * _butterworth_lp(R, cutoff_high, order) - H = 1.0 - bp - else: - H = np.ones_like(R) - - # Apply filter - F *= H - - # Inverse FFT - F = np.fft.ifftshift(F) - result = np.fft.ifft2(F).real + # Real-valued FFT keeps only the unique half-plane and avoids shift copies. + spectrum = np.fft.rfft2(centered) + transfer = _cached_2d_transfer( + yres, + xres, + filter_type, + float(cutoff), + float(cutoff_high), + int(order), + ) + result = np.fft.irfft2(spectrum * transfer, s=(yres, xres)) # Restore DC result += mean_val diff --git a/backend/nodes/mask.py b/backend/nodes/mask.py index 4222af3..a370e1e 100644 --- a/backend/nodes/mask.py +++ b/backend/nodes/mask.py @@ -9,6 +9,7 @@ Gwyddion equivalents: """ from __future__ import annotations +from functools import lru_cache import json import numpy as np from backend.node_registry import register_node @@ -21,13 +22,36 @@ def _mask_overlay(field: DataField, mask: np.ndarray) -> np.ndarray: Returns (H, W, 3) uint8 array. """ grey = datafield_to_uint8(field, "gray") # (H, W, 3) uint8 - overlay = grey.astype(np.float64) - mask_bool = mask == 255 - alpha = 0.45 - overlay[mask_bool, 0] = overlay[mask_bool, 0] * (1 - alpha) + 255 * alpha - overlay[mask_bool, 1] = overlay[mask_bool, 1] * (1 - alpha) - overlay[mask_bool, 2] = overlay[mask_bool, 2] * (1 - alpha) - return np.clip(overlay, 0, 255).astype(np.uint8) + mask_bool = mask > 127 + if not np.any(mask_bool): + return grey + + overlay = grey.copy() + red = overlay[..., 0] + green = overlay[..., 1] + blue = overlay[..., 2] + + # Integer alpha blend equivalent to a 45% red overlay, without float64 work. + red_vals = red[mask_bool].astype(np.uint16) + green_vals = green[mask_bool].astype(np.uint16) + blue_vals = blue[mask_bool].astype(np.uint16) + red[mask_bool] = ((red_vals * 55) + (255 * 45) + 50) // 100 + green[mask_bool] = ((green_vals * 55) + 50) // 100 + blue[mask_bool] = ((blue_vals * 55) + 50) // 100 + return overlay + + +@lru_cache(maxsize=128) +def _mask_structure(radius: int, shape: str) -> np.ndarray: + radius = max(1, int(radius)) + if shape == "disk": + y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1] + struct = (x * x + y * y) <= radius * radius + else: + size = 2 * radius + 1 + struct = np.ones((size, size), dtype=bool) + struct.setflags(write=False) + return struct def _clamp_fraction(value) -> float: @@ -285,31 +309,19 @@ class MaskMorphology: def process(self, mask: np.ndarray, operation: str, radius: int, shape: str, field: DataField | None = None) -> tuple: - from scipy.ndimage import binary_dilation, binary_erosion + from scipy.ndimage import binary_closing, binary_dilation, binary_erosion, binary_opening binary = mask > 127 - - if shape == "disk": - y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1] - struct = (x * x + y * y) <= radius * radius - else: - size = 2 * radius + 1 - struct = np.ones((size, size), dtype=bool) + struct = _mask_structure(radius, shape) if operation == "dilate": result = binary_dilation(binary, structure=struct) elif operation == "erode": result = binary_erosion(binary, structure=struct) elif operation == "open": - result = binary_dilation( - binary_erosion(binary, structure=struct), - structure=struct, - ) + result = binary_opening(binary, structure=struct) elif operation == "close": - result = binary_erosion( - binary_dilation(binary, structure=struct), - structure=struct, - ) + result = binary_closing(binary, structure=struct) else: raise ValueError(f"Unknown morphological operation: {operation}") diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 909eb78..71a13b4 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -758,7 +758,7 @@ def test_draw_mask(): def test_particle_analysis(): print("=== Test: ParticleAnalysis ===") - from backend.nodes.grains import ParticleAnalysis + from backend.nodes.particless import ParticleAnalysis node = ParticleAnalysis() # Create a field with two distinct particles