combine fft filter into a single node, fix tests
This commit is contained in:
@@ -31,7 +31,7 @@ from threading import RLock
|
||||
from time import perf_counter
|
||||
from typing import Any, Callable
|
||||
|
||||
from backend.node_registry import NODE_CLASS_MAPPINGS, get_node_output_types
|
||||
from backend.node_registry import NODE_CLASS_MAPPINGS, get_node_output_types, get_node_output_accepted_types
|
||||
from backend.execution_context import active_node, execution_callbacks
|
||||
|
||||
|
||||
@@ -462,16 +462,19 @@ class ExecutionEngine:
|
||||
return
|
||||
|
||||
return_types = get_node_output_types(cls)
|
||||
output_accepted = get_node_output_accepted_types(cls)
|
||||
|
||||
for slot, type_name in enumerate(return_types):
|
||||
if slot >= len(result):
|
||||
break
|
||||
value = result[slot]
|
||||
all_types = {type_name} | set(output_accepted[slot] if slot < len(output_accepted) else [])
|
||||
|
||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
||||
# For polymorphic outputs, check the actual runtime type first.
|
||||
if isinstance(value, DataField) and ("DATA_FIELD" in all_types) and on_preview:
|
||||
arr = render_datafield_preview(value, value.colormap)
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return # one preview per node is enough
|
||||
return
|
||||
|
||||
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
|
||||
arr = image_to_uint8(value)
|
||||
@@ -488,7 +491,7 @@ class ExecutionEngine:
|
||||
on_preview(node_id, encode_preview(arr))
|
||||
return
|
||||
|
||||
if type_name == "LINE" and isinstance(value, (np.ndarray, LineData)) and on_preview:
|
||||
if "LINE" in all_types and isinstance(value, (np.ndarray, LineData)) and on_preview:
|
||||
preview = self._render_line_preview(cls, slot, result)
|
||||
if preview:
|
||||
on_preview(node_id, preview)
|
||||
|
||||
@@ -15,28 +15,38 @@ NODE_CLASS_MAPPINGS: dict[str, type] = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
|
||||
|
||||
|
||||
def get_node_output_specs(cls: type) -> tuple[tuple[str, str], ...]:
|
||||
def get_node_output_specs(cls: type) -> tuple[tuple[str, str, dict], ...]:
|
||||
raw_outputs = getattr(cls, "OUTPUTS", None)
|
||||
if raw_outputs is None:
|
||||
raise AttributeError(f"{cls.__name__} must define OUTPUTS.")
|
||||
|
||||
specs: list[tuple[str, str]] = []
|
||||
specs: list[tuple[str, str, dict]] = []
|
||||
for index, output in enumerate(raw_outputs):
|
||||
if not isinstance(output, (list, tuple)) or len(output) != 2:
|
||||
if not isinstance(output, (list, tuple)) or len(output) not in (2, 3):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.OUTPUTS[{index}] must be a 2-item tuple of (type, name)."
|
||||
f"{cls.__name__}.OUTPUTS[{index}] must be a 2- or 3-item tuple of (type, name[, meta])."
|
||||
)
|
||||
type_name, name = output
|
||||
specs.append((str(type_name), str(name)))
|
||||
type_name = output[0]
|
||||
name = output[1]
|
||||
meta: dict = output[2] if len(output) == 3 else {}
|
||||
specs.append((str(type_name), str(name), meta))
|
||||
return tuple(specs)
|
||||
|
||||
|
||||
def get_node_output_types(cls: type) -> tuple[str, ...]:
|
||||
return tuple(type_name for type_name, _ in get_node_output_specs(cls))
|
||||
return tuple(type_name for type_name, _, _meta in get_node_output_specs(cls))
|
||||
|
||||
|
||||
def get_node_output_names(cls: type) -> tuple[str, ...]:
|
||||
return tuple(name for _, name in get_node_output_specs(cls))
|
||||
return tuple(name for _, name, _meta in get_node_output_specs(cls))
|
||||
|
||||
|
||||
def get_node_output_accepted_types(cls: type) -> tuple[list[str], ...]:
|
||||
"""Return per-slot accepted_types lists (empty list means only the declared type)."""
|
||||
return tuple(
|
||||
list(meta.get("accepted_types", []))
|
||||
for _, _, meta in get_node_output_specs(cls)
|
||||
)
|
||||
|
||||
|
||||
def register_node(display_name: str | None = None):
|
||||
@@ -77,6 +87,7 @@ def get_node_info(class_name: str) -> dict[str, Any]:
|
||||
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
|
||||
"output": list(get_node_output_types(cls)),
|
||||
"output_name": list(get_node_output_names(cls)),
|
||||
"output_accepted_types": list(get_node_output_accepted_types(cls)),
|
||||
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
||||
"manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)),
|
||||
"description": getattr(cls, "DESCRIPTION", ""),
|
||||
|
||||
@@ -4,8 +4,7 @@ from backend.nodes import (
|
||||
colormap,
|
||||
crop_resize,
|
||||
fft_2d_inverse,
|
||||
filter_fft_1d,
|
||||
filter_fft_2d,
|
||||
filter_fft,
|
||||
filter_gaussian,
|
||||
filter_median,
|
||||
flip,
|
||||
|
||||
89
backend/nodes/filter_fft.py
Normal file
89
backend/nodes/filter_fft.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, LineData
|
||||
from backend.nodes.helpers import _cached_1d_transfer, _cached_2d_transfer
|
||||
|
||||
|
||||
@register_node(display_name="FFT Filter")
|
||||
class FFTFilter:
|
||||
"""Frequency-domain filtering of a line profile or 2-D data field.
|
||||
|
||||
Accepts either a LINE or DATA_FIELD and returns a filtered output of the
|
||||
same type. Uses a Butterworth transfer function with configurable order
|
||||
for a smooth roll-off. Equivalent to Gwyddion fft_filter_1d / fft_filter_2d.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input": ("LINE", {
|
||||
"label": "input",
|
||||
"accepted_types": ["DATA_FIELD"],
|
||||
}),
|
||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
||||
"cutoff": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"cutoff_high": ("FLOAT", {
|
||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('LINE', 'filtered', {"accepted_types": ["DATA_FIELD"]}),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Frequency-domain filtering of a line profile or 2-D data field. "
|
||||
"Connect a LINE for 1-D filtering or a DATA_FIELD for 2-D filtering — "
|
||||
"the output mirrors the input type. "
|
||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
||||
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency."
|
||||
)
|
||||
|
||||
def process(self, input, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
if isinstance(input, DataField):
|
||||
return self._process_field(input, filter_type, float(cutoff), float(cutoff_high), int(order))
|
||||
return self._process_line(input, filter_type, float(cutoff), float(cutoff_high), int(order))
|
||||
|
||||
def _process_line(self, line, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
z = np.asarray(line, dtype=np.float64).ravel()
|
||||
n = len(z)
|
||||
|
||||
Z = np.fft.rfft(z)
|
||||
H = _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
|
||||
Z *= H
|
||||
filtered = np.fft.irfft(Z, n=n)
|
||||
|
||||
if isinstance(line, LineData):
|
||||
return (
|
||||
LineData(
|
||||
data=filtered,
|
||||
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
|
||||
x_unit=line.x_unit,
|
||||
y_unit=line.y_unit,
|
||||
),
|
||||
)
|
||||
return (filtered,)
|
||||
|
||||
def _process_field(self, field: DataField, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
data = field.data
|
||||
yres, xres = data.shape
|
||||
|
||||
mean_val = float(data.mean())
|
||||
centered = data - mean_val
|
||||
|
||||
spectrum = np.fft.rfft2(centered)
|
||||
transfer = _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order)
|
||||
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
|
||||
result += mean_val
|
||||
|
||||
return (field.replace(data=result),)
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import LineData
|
||||
from backend.nodes.helpers import _cached_1d_transfer
|
||||
|
||||
|
||||
@register_node(display_name="FFT Filter 1D")
|
||||
class FFTFilter1D:
|
||||
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
|
||||
|
||||
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
|
||||
transfer function with configurable order for a smooth roll-off.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"line": ("LINE",),
|
||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
||||
"cutoff": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"cutoff_high": ("FLOAT", {
|
||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('LINE', 'filtered'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Frequency-domain filtering of a 1-D line profile. "
|
||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
||||
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
|
||||
"Equivalent to Gwyddion fft_filter_1d."
|
||||
)
|
||||
|
||||
def process(self, line, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
z = np.asarray(line, dtype=np.float64).ravel()
|
||||
n = len(z)
|
||||
|
||||
Z = np.fft.rfft(z)
|
||||
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
|
||||
Z *= H
|
||||
filtered = np.fft.irfft(Z, n=n)
|
||||
|
||||
if isinstance(line, LineData):
|
||||
return (
|
||||
LineData(
|
||||
data=filtered,
|
||||
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
|
||||
x_unit=line.x_unit,
|
||||
y_unit=line.y_unit,
|
||||
),
|
||||
)
|
||||
return (filtered,)
|
||||
@@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
from backend.nodes.helpers import _cached_2d_transfer
|
||||
|
||||
|
||||
@register_node(display_name="FFT Filter 2D")
|
||||
class FFTFilter2D:
|
||||
"""Frequency-domain filtering of 2-D data fields (images).
|
||||
|
||||
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
|
||||
Butterworth transfer function in the frequency domain to remove or
|
||||
isolate periodic features.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
||||
"cutoff": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"cutoff_high": ("FLOAT", {
|
||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||
}),
|
||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
('DATA_FIELD', 'filtered'),
|
||||
)
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Frequency-domain filtering of a 2-D data field. "
|
||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
||||
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
|
||||
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
|
||||
"bandpass/notch to isolate or remove periodic noise. "
|
||||
"Equivalent to Gwyddion fft_filter_2d."
|
||||
)
|
||||
|
||||
def process(self, field: DataField, filter_type: str, cutoff: float,
|
||||
cutoff_high: float, order: int) -> tuple:
|
||||
data = field.data
|
||||
yres, xres = data.shape
|
||||
|
||||
mean_val = float(data.mean())
|
||||
centered = data - mean_val
|
||||
|
||||
spectrum = np.fft.rfft2(centered)
|
||||
transfer = _cached_2d_transfer(
|
||||
yres, xres, filter_type,
|
||||
float(cutoff), float(cutoff_high), int(order),
|
||||
)
|
||||
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
|
||||
result += mean_val
|
||||
|
||||
return (field.replace(data=result),)
|
||||
Reference in New Issue
Block a user