fft multi channel output
This commit is contained in:
@@ -266,21 +266,20 @@ class FFT2D:
|
||||
"field": ("DATA_FIELD",),
|
||||
"windowing": (["hann", "hamming", "blackman", "none"],),
|
||||
"level": (["mean", "plane", "none"],),
|
||||
"output": (["log_magnitude", "magnitude", "phase", "psdf"],),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("spectrum",)
|
||||
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "DATA_FIELD", "DATA_FIELD")
|
||||
RETURN_NAMES = ("log_magnitude", "magnitude", "phase", "psdf")
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Compute the 2D FFT with optional windowing and mean/plane subtraction. "
|
||||
"Output can be log magnitude, magnitude, phase, or PSDF. "
|
||||
"Outputs log magnitude, magnitude, phase, and PSDF as separate channels. "
|
||||
"Equivalent to gwy_data_field_2dfft / gwy_data_field_2dpsdf."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, windowing: str, level: str, output: str) -> tuple:
|
||||
def process(self, field: DataField, windowing: str, level: str) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
@@ -320,8 +319,59 @@ class FFT2D:
|
||||
F = np.fft.fftshift(np.fft.fft2(data))
|
||||
n = xres * yres
|
||||
|
||||
if output == "log_magnitude":
|
||||
mag = np.abs(F)
|
||||
magnitude = np.abs(F)
|
||||
log_magnitude = np.log1p(magnitude)
|
||||
phase = np.angle(F)
|
||||
|
||||
dx = field.xreal / xres
|
||||
dy = field.yreal / yres
|
||||
psdf = (magnitude ** 2) * dx * dy / (n * 4.0 * np.pi ** 2)
|
||||
|
||||
spatial_freq_xreal = xres / field.xreal
|
||||
spatial_freq_yreal = yres / field.yreal
|
||||
angular_freq_xreal = 2.0 * np.pi * xres / field.xreal
|
||||
angular_freq_yreal = 2.0 * np.pi * yres / field.yreal
|
||||
|
||||
return (
|
||||
DataField(
|
||||
data=log_magnitude,
|
||||
xreal=spatial_freq_xreal,
|
||||
yreal=spatial_freq_yreal,
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=field.si_unit_z,
|
||||
domain="frequency",
|
||||
colormap=field.colormap,
|
||||
),
|
||||
DataField(
|
||||
data=magnitude,
|
||||
xreal=spatial_freq_xreal,
|
||||
yreal=spatial_freq_yreal,
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=field.si_unit_z,
|
||||
domain="frequency",
|
||||
colormap=field.colormap,
|
||||
),
|
||||
DataField(
|
||||
data=phase,
|
||||
xreal=spatial_freq_xreal,
|
||||
yreal=spatial_freq_yreal,
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=field.si_unit_z,
|
||||
domain="frequency",
|
||||
colormap=field.colormap,
|
||||
),
|
||||
DataField(
|
||||
data=psdf,
|
||||
xreal=angular_freq_xreal,
|
||||
yreal=angular_freq_yreal,
|
||||
si_unit_xy="1/m",
|
||||
si_unit_z=f"({field.si_unit_z})^2 m^2",
|
||||
domain="frequency",
|
||||
colormap=field.colormap,
|
||||
),
|
||||
)
|
||||
|
||||
if False: # Unreachable legacy block retained below.
|
||||
# Log scale with floor to avoid log(0)
|
||||
result = np.log1p(mag)
|
||||
elif output == "magnitude":
|
||||
@@ -359,6 +409,109 @@ class FFT2D:
|
||||
return (out_field,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InverseFFT2D
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@register_node(display_name="Inverse 2D FFT")
|
||||
class InverseFFT2D:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"spectrum": ("DATA_FIELD",),
|
||||
"representation": (["magnitude", "log_magnitude", "psdf"],),
|
||||
},
|
||||
"optional": {
|
||||
"phase": ("DATA_FIELD",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
RETURN_NAMES = ("image",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "analysis"
|
||||
DESCRIPTION = (
|
||||
"Reconstruct a spatial-domain image from a 2D frequency spectrum. "
|
||||
"For exact reconstruction, connect magnitude/phase (or log magnitude/phase, "
|
||||
"or PSDF/phase) from the 2D FFT node. If phase is omitted, zero phase is assumed."
|
||||
)
|
||||
|
||||
def process(self, spectrum: DataField, representation: str, phase: DataField | None = None) -> tuple:
|
||||
if spectrum.domain != "frequency":
|
||||
raise ValueError("Inverse 2D FFT requires a frequency-domain DATA_FIELD input.")
|
||||
|
||||
if phase is not None:
|
||||
if phase.data.shape != spectrum.data.shape:
|
||||
raise ValueError("Phase input must have the same shape as the spectrum.")
|
||||
if phase.domain != "frequency":
|
||||
raise ValueError("Phase input must also be a frequency-domain DATA_FIELD.")
|
||||
|
||||
amplitude = self._resolve_amplitude(spectrum, representation)
|
||||
phase_data = phase.data if phase is not None else np.zeros_like(amplitude)
|
||||
F = amplitude * np.exp(1j * phase_data)
|
||||
|
||||
spatial = np.fft.ifft2(np.fft.ifftshift(F)).real
|
||||
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
|
||||
z_unit = self._recover_z_unit(spectrum, representation, phase)
|
||||
|
||||
out_field = DataField(
|
||||
data=spatial,
|
||||
xreal=xreal,
|
||||
yreal=yreal,
|
||||
si_unit_xy="m",
|
||||
si_unit_z=z_unit,
|
||||
domain="spatial",
|
||||
colormap=spectrum.colormap,
|
||||
)
|
||||
return (out_field,)
|
||||
|
||||
def _resolve_amplitude(self, spectrum: DataField, representation: str) -> np.ndarray:
|
||||
data = np.asarray(spectrum.data, dtype=np.float64)
|
||||
|
||||
if representation == "magnitude":
|
||||
return np.clip(data, 0.0, None)
|
||||
if representation == "log_magnitude":
|
||||
return np.expm1(data)
|
||||
if representation == "psdf":
|
||||
xreal, yreal = self._recover_spatial_extent(spectrum, representation)
|
||||
n = spectrum.xres * spectrum.yres
|
||||
dx = xreal / spectrum.xres
|
||||
dy = yreal / spectrum.yres
|
||||
scale = n * 4.0 * np.pi ** 2 / (dx * dy)
|
||||
return np.sqrt(np.clip(data, 0.0, None) * scale)
|
||||
|
||||
raise ValueError(f"Unsupported spectrum representation: {representation}")
|
||||
|
||||
def _recover_spatial_extent(self, spectrum: DataField, representation: str) -> tuple[float, float]:
|
||||
if representation == "psdf":
|
||||
xreal = 2.0 * np.pi * spectrum.xres / spectrum.xreal
|
||||
yreal = 2.0 * np.pi * spectrum.yres / spectrum.yreal
|
||||
else:
|
||||
xreal = spectrum.xres / spectrum.xreal
|
||||
yreal = spectrum.yres / spectrum.yreal
|
||||
return float(xreal), float(yreal)
|
||||
|
||||
def _recover_z_unit(
|
||||
self,
|
||||
spectrum: DataField,
|
||||
representation: str,
|
||||
phase: DataField | None,
|
||||
) -> str:
|
||||
if phase is not None and isinstance(phase.si_unit_z, str) and phase.si_unit_z.strip():
|
||||
return phase.si_unit_z
|
||||
|
||||
if representation != "psdf":
|
||||
return spectrum.si_unit_z
|
||||
|
||||
unit = str(spectrum.si_unit_z or "").strip()
|
||||
if unit.startswith("(") and ")^2 m^2" in unit:
|
||||
return unit.split(")^2 m^2", 1)[0][1:]
|
||||
if unit.endswith("^2 m^2"):
|
||||
return unit[:-6].removesuffix("^2").strip()
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CrossSection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user