105 lines
3.8 KiB
Python
105 lines
3.8 KiB
Python
from __future__ import annotations
|
|
import numpy as np
|
|
from backend.node_registry import register_node
|
|
from backend.data_types import DataField
|
|
|
|
|
|
@register_node(display_name="Inverse 2D FFT")
class InverseFFT2D:
    """Node that rebuilds a spatial-domain field from a 2D frequency spectrum.

    The spectrum arrives as one of three real-valued representations
    (``magnitude``, ``log_magnitude`` or ``psdf``) together with an optional
    phase field. Both are recombined into a complex spectrum, which is then
    passed through an inverse 2D FFT.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's required and optional inputs."""
        return {
            "required": {
                "spectrum": ("DATA_FIELD",),
                "representation": (["magnitude", "log_magnitude", "psdf"],),
            },
            "optional": {
                "phase": ("DATA_FIELD",),
            },
        }

    # Single output: the reconstructed field, labelled "image".
    OUTPUTS = (
        ("DATA_FIELD", "image"),
    )
    # Matches the name of the method below; presumably how the node framework
    # locates the entry point -- confirm against backend.node_registry.
    FUNCTION = "process"

    DESCRIPTION = (
        "Reconstruct a spatial-domain image from a 2D frequency spectrum. "
        "For exact reconstruction, connect magnitude/phase (or log magnitude/phase, "
        "or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed."
    )

    def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple:
        """Reconstruct the spatial-domain field described by ``spectrum``.

        Args:
            spectrum: Frequency-domain field holding the chosen representation.
            representation: ``"magnitude"``, ``"log_magnitude"`` or ``"psdf"``.
            phase: Optional frequency-domain field of phase angles (radians,
                since it is fed to ``exp(1j * phase)``) with the same shape as
                ``spectrum``. Zero phase is used when omitted.

        Returns:
            A one-element tuple holding the spatial-domain ``DataField``.

        Raises:
            ValueError: If ``spectrum`` or ``phase`` is not frequency-domain,
                their shapes differ, ``representation`` is unsupported, or the
                spectrum's resolution/extent cannot be inverted.
        """
        if spectrum.domain != "frequency":
            raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.")

        if phase is not None:
            if phase.data.shape != spectrum.data.shape:
                raise ValueError("Phase input must have the same shape as the spectrum.")
            if phase.domain != "frequency":
                raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.")

        # Resolving the amplitude first also rejects unsupported representations.
        amplitude = self._resolve_amplitude(spectrum, representation)
        # Validate the axes before paying for the transform.
        xreal, yreal = self._recover_spatial_extent(spectrum, representation)

        phase_data = np.zeros_like(amplitude) if phase is None else phase.data
        complex_spectrum = amplitude * np.exp(1j * phase_data)

        # ifftshift undoes a centred (zero frequency in the middle) layout before
        # the inverse transform; only the real part of the result is kept.
        spatial = np.fft.ifft2(np.fft.ifftshift(complex_spectrum)).real
        z_unit = self._recover_z_unit(spectrum, representation, phase)

        out_field = DataField(
            data=spatial,
            xreal=xreal,
            yreal=yreal,
            # NOTE(review): the lateral unit is fixed to metres regardless of the
            # spectrum's own si_unit_xy -- confirm this matches the forward node.
            si_unit_xy="m",
            si_unit_z=z_unit,
            domain="spatial",
            colormap=spectrum.colormap,
        )
        return (out_field,)

    def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray:
        """Convert the stored representation back into a linear amplitude array.

        Args:
            spectrum: Field whose ``data`` holds the representation values.
            representation: ``"magnitude"``, ``"log_magnitude"`` or ``"psdf"``.

        Returns:
            A float64 array of amplitudes with the same shape as the data.

        Raises:
            ValueError: If ``representation`` is not one of the supported names.
        """
        data = np.asarray(spectrum.data, dtype=np.float64)

        if representation == "magnitude":
            # Negative magnitudes are treated as zero.
            return np.clip(data, 0.0, None)

        if representation == "log_magnitude":
            # expm1 is the inverse of log1p -- presumably the compression applied
            # by the forward node.
            return np.expm1(data)

        if representation == "psdf":
            xreal, yreal = self._recover_spatial_extent(spectrum, representation)
            n = spectrum.xres * spectrum.yres
            dx = xreal / spectrum.xres
            dy = yreal / spectrum.yres
            # Inverts psdf = |F|^2 * dx * dy / (n * 4 * pi^2).
            scale = n * 4.0 * np.pi ** 2 / (dx * dy)
            return np.sqrt(np.clip(data, 0.0, None) * scale)

        raise ValueError(f"Unsupported spectrum representation: {representation}")

    def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]:
        """Derive the physical width and height of the spatial-domain field.

        For ``"psdf"`` the extent is ``2 * pi * res / real`` per axis; for every
        other representation it is ``res / real``.

        Args:
            spectrum: Field providing ``xres``, ``yres``, ``xreal`` and ``yreal``.
            representation: Name of the spectrum representation.

        Returns:
            ``(xreal, yreal)`` of the spatial field as plain floats.

        Raises:
            ValueError: If the resolution is not positive or the frequency
                extent is not a positive finite number. Without this check the
                divisions below either raise a bare ZeroDivisionError or
                silently yield inf/nan.
        """
        xres, yres = spectrum.xres, spectrum.yres
        freq_x, freq_y = spectrum.xreal, spectrum.yreal

        if xres <= 0 or yres <= 0:
            raise ValueError("Spectrum must contain at least one sample along each axis.")
        if not (np.isfinite(freq_x) and np.isfinite(freq_y) and freq_x > 0 and freq_y > 0):
            raise ValueError("Spectrum xreal and yreal must be positive, finite extents.")

        if representation == "psdf":
            xreal = 2.0 * np.pi * xres / freq_x
            yreal = 2.0 * np.pi * yres / freq_y
        else:
            xreal = xres / freq_x
            yreal = yres / freq_y
        return float(xreal), float(yreal)

    def _recover_z_unit(
        self,
        spectrum: DataField,
        representation: str,
        phase: DataField | None,
    ) -> str:
        """Work out the value unit of the reconstructed field.

        A non-blank string unit on ``phase`` takes precedence. Otherwise the
        spectrum's unit is returned unchanged, except for ``"psdf"``, where the
        base unit is extracted from ``"(<unit>)^2 m^2"`` or ``"<unit>^2 m^2"``.

        Args:
            spectrum: Field whose ``si_unit_z`` describes the spectrum values.
            representation: Name of the spectrum representation.
            phase: Optional phase field whose ``si_unit_z`` may carry the unit.

        Returns:
            The recovered unit, or ``""`` when a PSDF unit has neither
            recognised form.
        """
        if phase is not None:
            phase_unit = phase.si_unit_z
            if isinstance(phase_unit, str) and phase_unit.strip():
                return phase_unit

        if representation != "psdf":
            return spectrum.si_unit_z

        unit = str(spectrum.si_unit_z or "").strip()
        if unit.startswith("(") and ")^2 m^2" in unit:
            # "(<unit>)^2 m^2": take what sits between "(" and the first ")^2 m^2".
            return unit.split(")^2 m^2", 1)[0][1:]
        if unit.endswith("^2 m^2"):
            # Remove only the "^2 m^2" marker. Stripping a further "^2" would
            # corrupt base units that themselves end in an exponent ("V^2").
            return unit.removesuffix("^2 m^2").strip()
        return ""
|