Files
tono/backend/nodes/fft_2d_invert.py
2026-03-28 21:06:22 -07:00

105 lines
3.8 KiB
Python

from __future__ import annotations
import numpy as np
from backend.node_registry import register_node
from backend.data_types import DataField
@register_node(display_name="Inverse 2D FFT")
class InverseFFT2D:
    """Node that rebuilds a spatial-domain image from a 2D frequency spectrum.

    The spectrum may be supplied as a magnitude, a log magnitude or a PSDF,
    optionally accompanied by a phase field. The amplitude is recovered from
    the chosen representation, combined with the phase (zero when omitted)
    and passed through an inverse 2D FFT.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs accepted by this node.

        Returns:
            A dict with a ``"required"`` section (``spectrum`` and the
            ``representation`` choice list) and an ``"optional"`` section
            (``phase``).
        """
        return {
            "required": {
                "spectrum": ("DATA_FIELD",),
                "representation": (["magnitude", "log_magnitude", "psdf"],),
            },
            "optional": {
                "phase": ("DATA_FIELD",),
            },
        }

    OUTPUTS = (
        ('DATA_FIELD', 'image'),
    )
    # Name of the method that performs the node's work.
    FUNCTION = "process"
    DESCRIPTION = (
        "Reconstruct a spatial-domain image from a 2D frequency spectrum. "
        "For exact reconstruction, connect magnitude/phase (or log magnitude/phase, "
        "or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed."
    )

    def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple:
        """Run the inverse 2D FFT and wrap the result in a spatial DataField.

        Args:
            spectrum: Frequency-domain field holding the spectrum values.
            representation: How ``spectrum.data`` encodes the amplitude; one
                of ``"magnitude"``, ``"log_magnitude"`` or ``"psdf"``.
            phase: Optional frequency-domain field with the phase values. It
                must have the same shape as ``spectrum``. When omitted, zero
                phase is used.

        Returns:
            A one-element tuple containing the reconstructed spatial-domain
            ``DataField``.

        Raises:
            ValueError: If ``spectrum`` or ``phase`` is not in the frequency
                domain, if their shapes differ, if ``representation`` is not
                supported, or if the spectrum's extents are not positive.
        """
        if spectrum.domain != "frequency":
            raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.")
        if phase is not None:
            if phase.data.shape != spectrum.data.shape:
                raise ValueError("Phase input must have the same shape as the spectrum.")
            if phase.domain != "frequency":
                raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.")

        amplitude = self._resolve_amplitude(spectrum, representation)
        phase_data = phase.data if phase is not None else np.zeros_like(amplitude)
        # Combine amplitude and phase into the complex spectrum.
        F = amplitude * np.exp(1j * phase_data)
        # ifftshift undoes a centred (fftshift-ed) layout before the inverse
        # transform -- presumably the forward node stores spectra centred;
        # TODO confirm. Only the real part of the result is kept.
        spatial = np.fft.ifft2(np.fft.ifftshift(F)).real

        xreal, yreal = self._recover_spatial_extent(spectrum, representation)
        z_unit = self._recover_z_unit(spectrum, representation, phase)
        out_field = DataField(
            data=spatial,
            xreal=xreal,
            yreal=yreal,
            # The lateral unit is always reported as metres.
            si_unit_xy="m",
            si_unit_z=z_unit,
            domain="spatial",
            colormap=spectrum.colormap,
        )
        return (out_field,)

    def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray:
        """Convert the spectrum data into a linear amplitude array.

        Args:
            spectrum: Field whose ``data`` holds the encoded spectrum.
            representation: Encoding of ``spectrum.data``.

        Returns:
            A float64 array of amplitudes.

        Raises:
            ValueError: If ``representation`` is not one of the supported
                names, or (for ``"psdf"``) if the spectrum's extents are not
                positive.
        """
        data = np.asarray(spectrum.data, dtype=np.float64)
        if representation == "magnitude":
            # Negative magnitudes are clipped to zero.
            return np.clip(data, 0.0, None)
        if representation == "log_magnitude":
            # expm1 inverts a log1p encoding. Unlike the other branches,
            # negative inputs are not clipped here.
            return np.expm1(data)
        if representation == "psdf":
            xreal, yreal = self._recover_spatial_extent(spectrum, representation)
            n = spectrum.xres * spectrum.yres
            dx = xreal / spectrum.xres
            dy = yreal / spectrum.yres
            # NOTE(review): assumes the forward node defined
            # PSDF = |F|^2 * dx * dy / (n * 4 * pi^2) -- confirm against the
            # 2D FFT node.
            scale = n * 4.0 * np.pi ** 2 / (dx * dy)
            return np.sqrt(np.clip(data, 0.0, None) * scale)
        raise ValueError(f"Unsupported spectrum representation: {representation}")

    def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]:
        """Derive the spatial-domain extent from the spectrum's metadata.

        For ``"psdf"`` the extent is ``2 * pi * res / extent`` per axis; for
        every other representation it is ``res / extent``. The 2*pi factor
        suggests PSDF axes are in angular frequency -- TODO confirm.

        Args:
            spectrum: Field providing ``xres``, ``yres``, ``xreal``, ``yreal``.
            representation: Encoding of the spectrum; selects the formula.

        Returns:
            ``(xreal, yreal)`` of the spatial-domain image as Python floats.

        Raises:
            ValueError: If the spectrum's ``xreal`` or ``yreal`` is zero,
                negative or NaN.
        """
        freq_xreal = spectrum.xreal
        freq_yreal = spectrum.yreal
        # Reject zero, negative and NaN extents up front; otherwise the
        # divisions below fail with ZeroDivisionError or yield inf/negative
        # extents. ``not (a > 0)`` is also true for NaN.
        if not (freq_xreal > 0 and freq_yreal > 0):
            raise ValueError(
                "Spectrum xreal and yreal must be positive to recover the spatial extent."
            )
        if representation == "psdf":
            xreal = 2.0 * np.pi * spectrum.xres / freq_xreal
            yreal = 2.0 * np.pi * spectrum.yres / freq_yreal
        else:
            xreal = spectrum.xres / freq_xreal
            yreal = spectrum.yres / freq_yreal
        return float(xreal), float(yreal)

    def _recover_z_unit(
        self,
        spectrum: DataField,
        representation: str,
        phase: DataField | None,
    ) -> str:
        """Choose the z unit for the reconstructed image.

        Order of preference:

        1. A non-blank string ``phase.si_unit_z`` when a phase is given.
        2. ``spectrum.si_unit_z`` unchanged for non-PSDF representations.
        3. For PSDF, the base unit parsed out of ``"(<unit>)^2 m^2"`` or
           ``"<unit>^2 m^2"``.
        4. An empty string when the PSDF unit matches neither pattern.

        Args:
            spectrum: Field providing the fallback ``si_unit_z``.
            representation: Encoding of the spectrum.
            phase: Optional phase field whose ``si_unit_z`` takes priority.

        Returns:
            The z unit for the output field.
        """
        if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip():
            return phase.si_unit_z
        if representation != "psdf":
            return spectrum.si_unit_z
        unit = str(spectrum.si_unit_z or "").strip()
        if unit.startswith("(") and ")^2 m^2" in unit:
            # Keep what sits between the opening "(" and ")^2 m^2".
            return unit.split(")^2 m^2", 1)[0][1:]
        if unit.endswith("^2 m^2"):
            # [:-6] drops the 6-character "^2 m^2" suffix; removesuffix then
            # drops one further trailing "^2" if present.
            return unit[:-6].removesuffix("^2").strip()
        return ""