Files
tono/backend/nodes/psf_estimation.py
2026-04-04 00:25:53 -07:00

178 lines
6.0 KiB
Python

"""PSF estimation — estimate and fit point spread functions for deconvolution."""
from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField, RecordTable
@register_node(display_name="PSF Estimation")
class PSFEstimation:
    """Estimate a point spread function from a measured/ideal image pair.

    Both inputs are transformed with ``np.fft.fft2``; the ideal image's
    spectrum is divided out of the measured one using a regularised
    estimator, the result is transformed back, ``fftshift``-ed so the zero
    shift sits at the array centre, cropped to ``psf_size`` and normalised
    to unit sum.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification (types, choices, limits)."""
        return {
            "required": {
                "measured": ("DATA_FIELD",),
                "ideal": ("DATA_FIELD",),
                "method": (["wiener", "least_squares", "gaussian_fit"], {"default": "wiener"}),
                "regularization": ("FLOAT", {"default": 0.01, "min": 1e-6, "max": 1.0, "step": 0.001}),
                "psf_size": ("INT", {"default": 32, "min": 4, "max": 128}),
            }
        }

    OUTPUTS = (
        ('DATA_FIELD', 'psf'),
        ('RECORD_TABLE', 'parameters'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Estimate a point spread function (PSF) from a measured (blurred) image "
        "and an ideal (sharp) reference. The PSF can then be used with the "
        "Deconvolution node to restore other images. Three methods are available: "
        "pseudo-Wiener deconvolution, regularised least-squares, and Gaussian fit. "
        "Equivalent to Gwyddion's psf.c / psf-fit.c modules."
    )

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _crop_centre(arr: np.ndarray, size: int) -> np.ndarray:
        """Crop the central *size x size* region from *arr*.

        The window is clamped per axis to the array's extent, so asking for
        a window larger than *arr* returns that whole axis. Without the clamp
        the slice start goes negative and yields a truncated, off-centre
        region. The element at ``shape // 2``, where ``fftshift`` puts the
        zero shift, always lands at ``result.shape // 2``.
        """
        h, w = arr.shape
        size_y = min(size, h)
        size_x = min(size, w)
        y0 = h // 2 - size_y // 2
        x0 = w // 2 - size_x // 2
        return arr[y0 : y0 + size_y, x0 : x0 + size_x]

    @staticmethod
    def _normalise(psf: np.ndarray) -> np.ndarray:
        """Normalise so that the PSF sums to 1.

        Arrays whose sum is (numerically) zero are returned unchanged to
        avoid dividing by zero.
        """
        s = psf.sum()
        if abs(s) > 1e-30:
            psf = psf / s
        return psf

    @staticmethod
    def _fit_gaussian_2d(psf: np.ndarray):
        """Fit a 2-D Gaussian to *psf* using moment analysis.

        Negative values are clipped to zero first. The centroid and the
        per-axis standard deviations (in pixels) come from the first and
        second moments of the clipped array. Each sigma is floored at 1e-6.

        Returns (gaussian_array, sigma_x, sigma_y, amplitude).
        ``gaussian_array`` is the Gaussian sampled on the same grid and
        normalised to unit sum. ``amplitude`` is the peak of the clipped
        *input*, not of the returned normalised Gaussian. If the clipped
        input sums to ~0, an all-zero array and zero parameters are returned.
        """
        h, w = psf.shape
        psf_pos = np.maximum(psf, 0.0)
        total = psf_pos.sum()
        if total < 1e-30:
            return np.zeros_like(psf), 0.0, 0.0, 0.0
        y_idx, x_idx = np.mgrid[0:h, 0:w]
        # centroid
        cx = float(np.sum(x_idx * psf_pos) / total)
        cy = float(np.sum(y_idx * psf_pos) / total)
        # second moments -> sigma
        sx = float(np.sqrt(np.sum(psf_pos * (x_idx - cx) ** 2) / total))
        sy = float(np.sqrt(np.sum(psf_pos * (y_idx - cy) ** 2) / total))
        # Guard against a degenerate (single-pixel) PSF producing 1/0 below.
        sx = max(sx, 1e-6)
        sy = max(sy, 1e-6)
        amplitude = float(psf_pos.max())
        gauss = amplitude * np.exp(
            -((x_idx - cx) ** 2 / (2 * sx ** 2) + (y_idx - cy) ** 2 / (2 * sy ** 2))
        )
        gauss = PSFEstimation._normalise(gauss)
        return gauss, sx, sy, amplitude

    # ------------------------------------------------------------------
    # methods
    # ------------------------------------------------------------------
    def _wiener(
        self,
        F_measured: np.ndarray,
        F_ideal: np.ndarray,
        regularization: float,
        psf_size: int,
    ) -> np.ndarray:
        """Pseudo-Wiener PSF estimation.

        Computes ``conj(I) * M / (|I|**2 + regularization)`` in the frequency
        domain. Here ``M`` and ``I`` are the spectra of the measured and
        ideal images. The result is inverse-transformed, centred, cropped
        and normalised.
        """
        F_psf = np.conj(F_ideal) * F_measured / (np.abs(F_ideal) ** 2 + regularization)
        psf = np.real(np.fft.ifft2(F_psf))
        psf = np.fft.fftshift(psf)
        psf = self._crop_centre(psf, psf_size)
        return self._normalise(psf)

    def _least_squares(
        self,
        F_measured: np.ndarray,
        F_ideal: np.ndarray,
        regularization: float,
        psf_size: int,
    ) -> np.ndarray:
        """Regularised least-squares PSF estimation.

        Frequencies where ``|I| < regularization`` (or ``|I| == 0``) are
        zeroed. Elsewhere the estimate is ``M / (I + regularization * I/|I|)``.
        This adds ``regularization`` to the magnitude of ``I`` while keeping
        its phase.

        The unit phasor ``I/|I|`` is written out explicitly rather than
        using ``np.sign``. NumPy < 2 defines ``np.sign`` on complex input as
        the sign of the real part, and NumPy >= 2 defines it as ``z/|z|``,
        so ``np.sign`` would make the result version-dependent. Division
        happens only at retained frequencies, so no divide-by-zero warnings
        are emitted.
        """
        abs_ideal = np.abs(F_ideal)
        keep = (abs_ideal >= regularization) & (abs_ideal > 0.0)
        F_psf = np.zeros_like(F_measured)
        # I + reg * I/|I| == I * (1 + reg/|I|)
        denom = F_ideal[keep] * (1.0 + regularization / abs_ideal[keep])
        F_psf[keep] = F_measured[keep] / denom
        psf = np.real(np.fft.ifft2(F_psf))
        psf = np.fft.fftshift(psf)
        psf = self._crop_centre(psf, psf_size)
        return self._normalise(psf)

    # ------------------------------------------------------------------
    # main entry
    # ------------------------------------------------------------------
    def process(
        self,
        measured: DataField,
        ideal: DataField,
        method: str,
        regularization: float,
        psf_size: int,
    ) -> tuple:
        """Estimate the PSF that maps *ideal* onto *measured*.

        Parameters
        ----------
        measured, ideal:
            Fields whose ``.data`` must be 2-D arrays of identical shape.
        method:
            ``"wiener"``, ``"least_squares"`` or ``"gaussian_fit"``. The
            Gaussian fit is applied to the Wiener estimate.
        regularization:
            Regularisation constant used by the chosen estimator.
        psf_size:
            Requested edge length (pixels) of the output PSF. It is clamped
            per axis to the input image size.

        Returns
        -------
        tuple
            ``(psf_field, parameters)``. ``psf_field`` is derived from
            *measured* via ``replace`` with the PSF data, a physical size
            scaled to the PSF's pixel count, and zero offsets. ``parameters``
            holds the sigma_x/sigma_y/amplitude rows for ``"gaussian_fit"``
            and is an empty ``RecordTable`` otherwise.

        Raises
        ------
        ValueError
            If the inputs are not 2-D, have different shapes, or *method* is
            unknown.
        """
        measured_data = np.asarray(measured.data, dtype=np.float64)
        ideal_data = np.asarray(ideal.data, dtype=np.float64)
        if measured_data.ndim != 2 or ideal_data.ndim != 2:
            raise ValueError(
                "measured and ideal fields must be 2-D, got "
                f"{measured_data.ndim}-D and {ideal_data.ndim}-D data"
            )
        if measured_data.shape != ideal_data.shape:
            raise ValueError(
                "measured and ideal fields must have the same shape, got "
                f"{measured_data.shape} and {ideal_data.shape}"
            )

        F_measured = np.fft.fft2(measured_data)
        F_ideal = np.fft.fft2(ideal_data)
        parameters = RecordTable()
        if method == "wiener":
            psf = self._wiener(F_measured, F_ideal, regularization, psf_size)
        elif method == "least_squares":
            psf = self._least_squares(F_measured, F_ideal, regularization, psf_size)
        elif method == "gaussian_fit":
            raw_psf = self._wiener(F_measured, F_ideal, regularization, psf_size)
            psf, sigma_x, sigma_y, amplitude = self._fit_gaussian_2d(raw_psf)
            parameters = RecordTable([
                {"quantity": "sigma_x", "value": sigma_x, "unit": "px"},
                {"quantity": "sigma_y", "value": sigma_y, "unit": "px"},
                {"quantity": "amplitude", "value": amplitude, "unit": ""},
            ])
        else:
            raise ValueError(f"Unknown PSF estimation method: {method!r}")

        # Build output DataField: inherit spatial metadata and scale the
        # physical size by the PSF's *actual* pixel count per axis, which may
        # be smaller than psf_size after clamping in _crop_centre.
        yres, xres = measured_data.shape
        psf_yres, psf_xres = psf.shape
        psf_xreal = measured.xreal * psf_xres / xres
        psf_yreal = measured.yreal * psf_yres / yres
        psf_field = measured.replace(
            data=psf,
            xreal=psf_xreal,
            yreal=psf_yreal,
            xoff=0.0,
            yoff=0.0,
        )
        return (psf_field, parameters)