combine fft filter into a single node, fix tests
This commit is contained in:
@@ -31,7 +31,7 @@ from threading import RLock
|
|||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
from backend.node_registry import NODE_CLASS_MAPPINGS, get_node_output_types
|
from backend.node_registry import NODE_CLASS_MAPPINGS, get_node_output_types, get_node_output_accepted_types
|
||||||
from backend.execution_context import active_node, execution_callbacks
|
from backend.execution_context import active_node, execution_callbacks
|
||||||
|
|
||||||
|
|
||||||
@@ -462,16 +462,19 @@ class ExecutionEngine:
|
|||||||
return
|
return
|
||||||
|
|
||||||
return_types = get_node_output_types(cls)
|
return_types = get_node_output_types(cls)
|
||||||
|
output_accepted = get_node_output_accepted_types(cls)
|
||||||
|
|
||||||
for slot, type_name in enumerate(return_types):
|
for slot, type_name in enumerate(return_types):
|
||||||
if slot >= len(result):
|
if slot >= len(result):
|
||||||
break
|
break
|
||||||
value = result[slot]
|
value = result[slot]
|
||||||
|
all_types = {type_name} | set(output_accepted[slot] if slot < len(output_accepted) else [])
|
||||||
|
|
||||||
if type_name == "DATA_FIELD" and isinstance(value, DataField) and on_preview:
|
# For polymorphic outputs, check the actual runtime type first.
|
||||||
|
if isinstance(value, DataField) and ("DATA_FIELD" in all_types) and on_preview:
|
||||||
arr = render_datafield_preview(value, value.colormap)
|
arr = render_datafield_preview(value, value.colormap)
|
||||||
on_preview(node_id, encode_preview(arr))
|
on_preview(node_id, encode_preview(arr))
|
||||||
return # one preview per node is enough
|
return
|
||||||
|
|
||||||
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
|
if type_name == "IMAGE" and isinstance(value, np.ndarray) and on_preview:
|
||||||
arr = image_to_uint8(value)
|
arr = image_to_uint8(value)
|
||||||
@@ -488,7 +491,7 @@ class ExecutionEngine:
|
|||||||
on_preview(node_id, encode_preview(arr))
|
on_preview(node_id, encode_preview(arr))
|
||||||
return
|
return
|
||||||
|
|
||||||
if type_name == "LINE" and isinstance(value, (np.ndarray, LineData)) and on_preview:
|
if "LINE" in all_types and isinstance(value, (np.ndarray, LineData)) and on_preview:
|
||||||
preview = self._render_line_preview(cls, slot, result)
|
preview = self._render_line_preview(cls, slot, result)
|
||||||
if preview:
|
if preview:
|
||||||
on_preview(node_id, preview)
|
on_preview(node_id, preview)
|
||||||
|
|||||||
@@ -15,28 +15,38 @@ NODE_CLASS_MAPPINGS: dict[str, type] = {}
|
|||||||
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
|
NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_node_output_specs(cls: type) -> tuple[tuple[str, str], ...]:
|
def get_node_output_specs(cls: type) -> tuple[tuple[str, str, dict], ...]:
|
||||||
raw_outputs = getattr(cls, "OUTPUTS", None)
|
raw_outputs = getattr(cls, "OUTPUTS", None)
|
||||||
if raw_outputs is None:
|
if raw_outputs is None:
|
||||||
raise AttributeError(f"{cls.__name__} must define OUTPUTS.")
|
raise AttributeError(f"{cls.__name__} must define OUTPUTS.")
|
||||||
|
|
||||||
specs: list[tuple[str, str]] = []
|
specs: list[tuple[str, str, dict]] = []
|
||||||
for index, output in enumerate(raw_outputs):
|
for index, output in enumerate(raw_outputs):
|
||||||
if not isinstance(output, (list, tuple)) or len(output) != 2:
|
if not isinstance(output, (list, tuple)) or len(output) not in (2, 3):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{cls.__name__}.OUTPUTS[{index}] must be a 2-item tuple of (type, name)."
|
f"{cls.__name__}.OUTPUTS[{index}] must be a 2- or 3-item tuple of (type, name[, meta])."
|
||||||
)
|
)
|
||||||
type_name, name = output
|
type_name = output[0]
|
||||||
specs.append((str(type_name), str(name)))
|
name = output[1]
|
||||||
|
meta: dict = output[2] if len(output) == 3 else {}
|
||||||
|
specs.append((str(type_name), str(name), meta))
|
||||||
return tuple(specs)
|
return tuple(specs)
|
||||||
|
|
||||||
|
|
||||||
def get_node_output_types(cls: type) -> tuple[str, ...]:
|
def get_node_output_types(cls: type) -> tuple[str, ...]:
|
||||||
return tuple(type_name for type_name, _ in get_node_output_specs(cls))
|
return tuple(type_name for type_name, _, _meta in get_node_output_specs(cls))
|
||||||
|
|
||||||
|
|
||||||
def get_node_output_names(cls: type) -> tuple[str, ...]:
|
def get_node_output_names(cls: type) -> tuple[str, ...]:
|
||||||
return tuple(name for _, name in get_node_output_specs(cls))
|
return tuple(name for _, name, _meta in get_node_output_specs(cls))
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_output_accepted_types(cls: type) -> tuple[list[str], ...]:
|
||||||
|
"""Return per-slot accepted_types lists (empty list means only the declared type)."""
|
||||||
|
return tuple(
|
||||||
|
list(meta.get("accepted_types", []))
|
||||||
|
for _, _, meta in get_node_output_specs(cls)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_node(display_name: str | None = None):
|
def register_node(display_name: str | None = None):
|
||||||
@@ -77,6 +87,7 @@ def get_node_info(class_name: str) -> dict[str, Any]:
|
|||||||
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
|
"input_order": {k: list(v.keys()) for k, v in input_types.items()},
|
||||||
"output": list(get_node_output_types(cls)),
|
"output": list(get_node_output_types(cls)),
|
||||||
"output_name": list(get_node_output_names(cls)),
|
"output_name": list(get_node_output_names(cls)),
|
||||||
|
"output_accepted_types": list(get_node_output_accepted_types(cls)),
|
||||||
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
"output_node": bool(getattr(cls, "OUTPUT_NODE", False)),
|
||||||
"manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)),
|
"manual_trigger": bool(getattr(cls, "MANUAL_TRIGGER", False)),
|
||||||
"description": getattr(cls, "DESCRIPTION", ""),
|
"description": getattr(cls, "DESCRIPTION", ""),
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from backend.nodes import (
|
|||||||
colormap,
|
colormap,
|
||||||
crop_resize,
|
crop_resize,
|
||||||
fft_2d_inverse,
|
fft_2d_inverse,
|
||||||
filter_fft_1d,
|
filter_fft,
|
||||||
filter_fft_2d,
|
|
||||||
filter_gaussian,
|
filter_gaussian,
|
||||||
filter_median,
|
filter_median,
|
||||||
flip,
|
flip,
|
||||||
|
|||||||
89
backend/nodes/filter_fft.py
Normal file
89
backend/nodes/filter_fft.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
from backend.node_registry import register_node
|
||||||
|
from backend.data_types import DataField, LineData
|
||||||
|
from backend.nodes.helpers import _cached_1d_transfer, _cached_2d_transfer
|
||||||
|
|
||||||
|
|
||||||
|
@register_node(display_name="FFT Filter")
|
||||||
|
class FFTFilter:
|
||||||
|
"""Frequency-domain filtering of a line profile or 2-D data field.
|
||||||
|
|
||||||
|
Accepts either a LINE or DATA_FIELD and returns a filtered output of the
|
||||||
|
same type. Uses a Butterworth transfer function with configurable order
|
||||||
|
for a smooth roll-off. Equivalent to Gwyddion fft_filter_1d / fft_filter_2d.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"input": ("LINE", {
|
||||||
|
"label": "input",
|
||||||
|
"accepted_types": ["DATA_FIELD"],
|
||||||
|
}),
|
||||||
|
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
||||||
|
"cutoff": ("FLOAT", {
|
||||||
|
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||||
|
}),
|
||||||
|
"cutoff_high": ("FLOAT", {
|
||||||
|
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
||||||
|
}),
|
||||||
|
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OUTPUTS = (
|
||||||
|
('LINE', 'filtered', {"accepted_types": ["DATA_FIELD"]}),
|
||||||
|
)
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Frequency-domain filtering of a line profile or 2-D data field. "
|
||||||
|
"Connect a LINE for 1-D filtering or a DATA_FIELD for 2-D filtering — "
|
||||||
|
"the output mirrors the input type. "
|
||||||
|
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
||||||
|
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, input, filter_type: str, cutoff: float,
|
||||||
|
cutoff_high: float, order: int) -> tuple:
|
||||||
|
if isinstance(input, DataField):
|
||||||
|
return self._process_field(input, filter_type, float(cutoff), float(cutoff_high), int(order))
|
||||||
|
return self._process_line(input, filter_type, float(cutoff), float(cutoff_high), int(order))
|
||||||
|
|
||||||
|
def _process_line(self, line, filter_type: str, cutoff: float,
|
||||||
|
cutoff_high: float, order: int) -> tuple:
|
||||||
|
z = np.asarray(line, dtype=np.float64).ravel()
|
||||||
|
n = len(z)
|
||||||
|
|
||||||
|
Z = np.fft.rfft(z)
|
||||||
|
H = _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
|
||||||
|
Z *= H
|
||||||
|
filtered = np.fft.irfft(Z, n=n)
|
||||||
|
|
||||||
|
if isinstance(line, LineData):
|
||||||
|
return (
|
||||||
|
LineData(
|
||||||
|
data=filtered,
|
||||||
|
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
|
||||||
|
x_unit=line.x_unit,
|
||||||
|
y_unit=line.y_unit,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return (filtered,)
|
||||||
|
|
||||||
|
def _process_field(self, field: DataField, filter_type: str, cutoff: float,
|
||||||
|
cutoff_high: float, order: int) -> tuple:
|
||||||
|
data = field.data
|
||||||
|
yres, xres = data.shape
|
||||||
|
|
||||||
|
mean_val = float(data.mean())
|
||||||
|
centered = data - mean_val
|
||||||
|
|
||||||
|
spectrum = np.fft.rfft2(centered)
|
||||||
|
transfer = _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order)
|
||||||
|
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
|
||||||
|
result += mean_val
|
||||||
|
|
||||||
|
return (field.replace(data=result),)
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import numpy as np
|
|
||||||
from backend.node_registry import register_node
|
|
||||||
from backend.data_types import LineData
|
|
||||||
from backend.nodes.helpers import _cached_1d_transfer
|
|
||||||
|
|
||||||
|
|
||||||
@register_node(display_name="FFT Filter 1D")
|
|
||||||
class FFTFilter1D:
|
|
||||||
"""Bandpass / lowpass / highpass / notch filtering of 1-D line profiles.
|
|
||||||
|
|
||||||
Equivalent to Gwyddion's fft_filter_1d module. Uses a Butterworth
|
|
||||||
transfer function with configurable order for a smooth roll-off.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"line": ("LINE",),
|
|
||||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
|
||||||
"cutoff": ("FLOAT", {
|
|
||||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
|
||||||
}),
|
|
||||||
"cutoff_high": ("FLOAT", {
|
|
||||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
|
||||||
}),
|
|
||||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
OUTPUTS = (
|
|
||||||
('LINE', 'filtered'),
|
|
||||||
)
|
|
||||||
FUNCTION = "process"
|
|
||||||
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Frequency-domain filtering of a 1-D line profile. "
|
|
||||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
|
||||||
"with a Butterworth roll-off. Cutoffs are fractions of the Nyquist frequency. "
|
|
||||||
"Equivalent to Gwyddion fft_filter_1d."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, line, filter_type: str, cutoff: float,
|
|
||||||
cutoff_high: float, order: int) -> tuple:
|
|
||||||
z = np.asarray(line, dtype=np.float64).ravel()
|
|
||||||
n = len(z)
|
|
||||||
|
|
||||||
Z = np.fft.rfft(z)
|
|
||||||
H = _cached_1d_transfer(n, filter_type, float(cutoff), float(cutoff_high), int(order))
|
|
||||||
Z *= H
|
|
||||||
filtered = np.fft.irfft(Z, n=n)
|
|
||||||
|
|
||||||
if isinstance(line, LineData):
|
|
||||||
return (
|
|
||||||
LineData(
|
|
||||||
data=filtered,
|
|
||||||
x_axis=line.x_axis.copy() if line.x_axis is not None else None,
|
|
||||||
x_unit=line.x_unit,
|
|
||||||
y_unit=line.y_unit,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return (filtered,)
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import numpy as np
|
|
||||||
from backend.node_registry import register_node
|
|
||||||
from backend.data_types import DataField
|
|
||||||
from backend.nodes.helpers import _cached_2d_transfer
|
|
||||||
|
|
||||||
|
|
||||||
@register_node(display_name="FFT Filter 2D")
|
|
||||||
class FFTFilter2D:
|
|
||||||
"""Frequency-domain filtering of 2-D data fields (images).
|
|
||||||
|
|
||||||
Equivalent to Gwyddion's fft_filter_2d module. Applies a radial
|
|
||||||
Butterworth transfer function in the frequency domain to remove or
|
|
||||||
isolate periodic features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"field": ("DATA_FIELD",),
|
|
||||||
"filter_type": (["lowpass", "highpass", "bandpass", "notch"],),
|
|
||||||
"cutoff": ("FLOAT", {
|
|
||||||
"default": 0.1, "min": 0.001, "max": 1.0, "step": 0.001,
|
|
||||||
}),
|
|
||||||
"cutoff_high": ("FLOAT", {
|
|
||||||
"default": 0.4, "min": 0.001, "max": 1.0, "step": 0.001,
|
|
||||||
}),
|
|
||||||
"order": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
OUTPUTS = (
|
|
||||||
('DATA_FIELD', 'filtered'),
|
|
||||||
)
|
|
||||||
FUNCTION = "process"
|
|
||||||
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Frequency-domain filtering of a 2-D data field. "
|
|
||||||
"Supports lowpass, highpass, bandpass, and notch (band-reject) modes "
|
|
||||||
"with a radial Butterworth roll-off. Cutoffs are fractions of the "
|
|
||||||
"Nyquist frequency. Use lowpass to smooth, highpass to sharpen, or "
|
|
||||||
"bandpass/notch to isolate or remove periodic noise. "
|
|
||||||
"Equivalent to Gwyddion fft_filter_2d."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, field: DataField, filter_type: str, cutoff: float,
|
|
||||||
cutoff_high: float, order: int) -> tuple:
|
|
||||||
data = field.data
|
|
||||||
yres, xres = data.shape
|
|
||||||
|
|
||||||
mean_val = float(data.mean())
|
|
||||||
centered = data - mean_val
|
|
||||||
|
|
||||||
spectrum = np.fft.rfft2(centered)
|
|
||||||
transfer = _cached_2d_transfer(
|
|
||||||
yres, xres, filter_type,
|
|
||||||
float(cutoff), float(cutoff_high), int(order),
|
|
||||||
)
|
|
||||||
result = np.fft.irfft2(spectrum * transfer, s=(yres, xres))
|
|
||||||
result += mean_val
|
|
||||||
|
|
||||||
return (field.replace(data=result),)
|
|
||||||
@@ -454,10 +454,15 @@ function socketTypesCompatible(sourceType, targetSpecOrType) {
|
|||||||
return socketSpecAcceptsType(sourceType, targetSpecOrType);
|
return socketSpecAcceptsType(sourceType, targetSpecOrType);
|
||||||
}
|
}
|
||||||
|
|
||||||
function outputTypeCanConnectToTarget(outputType, targetSpecOrType) {
|
function outputTypeCanConnectToTarget(outputType, targetSpecOrType, outputAcceptedTypes = []) {
|
||||||
if (socketTypesCompatible(outputType, targetSpecOrType)) {
|
if (socketTypesCompatible(outputType, targetSpecOrType)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
// Polymorphic output: the output socket declares it can also produce the target type
|
||||||
|
if (outputAcceptedTypes.length > 0) {
|
||||||
|
const targetType = Array.isArray(targetSpecOrType) ? targetSpecOrType[0] : targetSpecOrType;
|
||||||
|
if (outputAcceptedTypes.includes(targetType)) return true;
|
||||||
|
}
|
||||||
return outputType === 'ANNOTATION_SOURCE'
|
return outputType === 'ANNOTATION_SOURCE'
|
||||||
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
|
&& !socketTypesCompatible('ANNOTATION_SOURCE', targetSpecOrType)
|
||||||
&& (
|
&& (
|
||||||
@@ -674,8 +679,8 @@ function ContextMenu({
|
|||||||
});
|
});
|
||||||
if (!hasMatch) continue;
|
if (!hasMatch) continue;
|
||||||
} else {
|
} else {
|
||||||
const hasMatch = def.output.some((type) =>
|
const hasMatch = def.output.some((type, idx) =>
|
||||||
outputTypeCanConnectToTarget(type, filterSpec || filterType)
|
outputTypeCanConnectToTarget(type, filterSpec || filterType, def.output_accepted_types?.[idx] || [])
|
||||||
);
|
);
|
||||||
if (!hasMatch) continue;
|
if (!hasMatch) continue;
|
||||||
}
|
}
|
||||||
@@ -1392,7 +1397,16 @@ function Flow() {
|
|||||||
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
|
const resolvedTarget = getResolvedHandleRef(connection.target, connection.targetHandle);
|
||||||
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
|
const targetNode = reactFlow.getNode(resolvedTarget.nodeId);
|
||||||
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
|
const targetSpec = getNodeInputSpecForHandle(targetNode, resolvedTarget.handleId) || resolvedTarget.type;
|
||||||
return socketTypesCompatible(srcType, targetSpec);
|
if (socketTypesCompatible(srcType, targetSpec)) return true;
|
||||||
|
// Polymorphic output: check if the source output declares it can produce the target type
|
||||||
|
const srcProxy = parseGroupProxyHandle(connection.sourceHandle);
|
||||||
|
const srcNodeId = srcProxy ? srcProxy.nodeId : connection.source;
|
||||||
|
const srcHandleId = srcProxy ? srcProxy.realHandle : connection.sourceHandle;
|
||||||
|
const srcNode = reactFlow.getNode(srcNodeId);
|
||||||
|
const srcSlot = getOutputSlot(srcHandleId);
|
||||||
|
const srcAcceptedTypes = srcNode?.data?.definition?.output_accepted_types?.[srcSlot] || [];
|
||||||
|
const targetType = Array.isArray(targetSpec) ? targetSpec[0] : targetSpec;
|
||||||
|
return Array.isArray(srcAcceptedTypes) && srcAcceptedTypes.includes(targetType);
|
||||||
}, [reactFlow]);
|
}, [reactFlow]);
|
||||||
|
|
||||||
const onConnect = useCallback((params) => {
|
const onConnect = useCallback((params) => {
|
||||||
@@ -1765,8 +1779,8 @@ function Flow() {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Dragged from an input → connect from the first matching output on the new node
|
// Dragged from an input → connect from the first matching output on the new node
|
||||||
const outputIdx = def.output.findIndex((type) =>
|
const outputIdx = def.output.findIndex((type, idx) =>
|
||||||
outputTypeCanConnectToTarget(type, filterSpec)
|
outputTypeCanConnectToTarget(type, filterSpec, def.output_accepted_types?.[idx] || [])
|
||||||
);
|
);
|
||||||
if (outputIdx !== -1) {
|
if (outputIdx !== -1) {
|
||||||
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
|
const outputType = resolveOutputTypeForTarget(def.output[outputIdx], filterSpec);
|
||||||
|
|||||||
63
tests/node_tests/filter_fft.py
Normal file
63
tests/node_tests/filter_fft.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import numpy as np
|
||||||
|
from tests.node_tests._shared import make_field
|
||||||
|
|
||||||
|
|
||||||
|
def test_fft_filter_line():
|
||||||
|
from backend.nodes.filter_fft import FFTFilter
|
||||||
|
node = FFTFilter()
|
||||||
|
|
||||||
|
n = 256
|
||||||
|
t = np.arange(n, dtype=np.float64) / n
|
||||||
|
low = np.sin(2 * np.pi * 3 * t)
|
||||||
|
high = np.sin(2 * np.pi * 80 * t)
|
||||||
|
line = low + high
|
||||||
|
|
||||||
|
filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||||
|
assert len(filtered_lp) == n
|
||||||
|
corr_low = np.corrcoef(filtered_lp, low)[0, 1]
|
||||||
|
corr_high = np.corrcoef(filtered_lp, high)[0, 1]
|
||||||
|
assert corr_low > 0.95
|
||||||
|
assert abs(corr_high) < 0.3
|
||||||
|
|
||||||
|
filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||||
|
assert abs(np.corrcoef(filtered_hp, low)[0, 1]) < 0.3
|
||||||
|
assert np.corrcoef(filtered_hp, high)[0, 1] > 0.95
|
||||||
|
|
||||||
|
filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||||
|
assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3
|
||||||
|
assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9
|
||||||
|
|
||||||
|
filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
|
||||||
|
assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95
|
||||||
|
assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3
|
||||||
|
|
||||||
|
|
||||||
|
def test_fft_filter_field():
|
||||||
|
from backend.nodes.filter_fft import FFTFilter
|
||||||
|
from backend.data_types import DataField
|
||||||
|
node = FFTFilter()
|
||||||
|
|
||||||
|
N = 128
|
||||||
|
y, x = np.mgrid[0:N, 0:N] / N
|
||||||
|
low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
|
||||||
|
high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
|
||||||
|
data = low_2d + high_2d
|
||||||
|
field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
|
||||||
|
|
||||||
|
result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
||||||
|
assert isinstance(result_lp, DataField)
|
||||||
|
assert result_lp.data.shape == (N, N)
|
||||||
|
assert result_lp.xreal == field.xreal
|
||||||
|
assert result_lp.si_unit_z == field.si_unit_z
|
||||||
|
corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
|
||||||
|
corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
|
||||||
|
assert corr_low > 0.9
|
||||||
|
assert abs(corr_high) < 0.3
|
||||||
|
|
||||||
|
result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
||||||
|
assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3
|
||||||
|
assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9
|
||||||
|
|
||||||
|
const = make_field(data=np.ones((32, 32)) * 7.0)
|
||||||
|
result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
|
||||||
|
assert np.allclose(result_const.data, 7.0, atol=1e-10)
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def test_fft_filter_1d():
|
|
||||||
from backend.nodes.filter_fft_1d import FFTFilter1D
|
|
||||||
node = FFTFilter1D()
|
|
||||||
|
|
||||||
n = 256
|
|
||||||
t = np.arange(n, dtype=np.float64) / n
|
|
||||||
low = np.sin(2 * np.pi * 3 * t)
|
|
||||||
high = np.sin(2 * np.pi * 80 * t)
|
|
||||||
line = low + high
|
|
||||||
|
|
||||||
filtered_lp, = node.process(line, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
|
||||||
assert len(filtered_lp) == n
|
|
||||||
corr_low = np.corrcoef(filtered_lp, low)[0, 1]
|
|
||||||
corr_high = np.corrcoef(filtered_lp, high)[0, 1]
|
|
||||||
assert corr_low > 0.95
|
|
||||||
assert abs(corr_high) < 0.3
|
|
||||||
|
|
||||||
filtered_hp, = node.process(line, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
|
||||||
corr_low_hp = np.corrcoef(filtered_hp, low)[0, 1]
|
|
||||||
corr_high_hp = np.corrcoef(filtered_hp, high)[0, 1]
|
|
||||||
assert abs(corr_low_hp) < 0.3
|
|
||||||
assert corr_high_hp > 0.95
|
|
||||||
|
|
||||||
filtered_bp, = node.process(line, filter_type="bandpass", cutoff=0.4, cutoff_high=0.8, order=4)
|
|
||||||
assert abs(np.corrcoef(filtered_bp, low)[0, 1]) < 0.3
|
|
||||||
assert np.corrcoef(filtered_bp, high)[0, 1] > 0.9
|
|
||||||
|
|
||||||
filtered_notch, = node.process(line, filter_type="notch", cutoff=0.4, cutoff_high=0.8, order=4)
|
|
||||||
assert np.corrcoef(filtered_notch, low)[0, 1] > 0.95
|
|
||||||
assert abs(np.corrcoef(filtered_notch, high)[0, 1]) < 0.3
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
from tests.node_tests._shared import make_field
|
|
||||||
|
|
||||||
|
|
||||||
def test_fft_filter_2d():
|
|
||||||
from backend.nodes.filter_fft_2d import FFTFilter2D
|
|
||||||
node = FFTFilter2D()
|
|
||||||
|
|
||||||
N = 128
|
|
||||||
y, x = np.mgrid[0:N, 0:N] / N
|
|
||||||
low_2d = np.sin(2 * np.pi * 3 * x) + np.sin(2 * np.pi * 3 * y)
|
|
||||||
high_2d = np.sin(2 * np.pi * 40 * x) + np.sin(2 * np.pi * 40 * y)
|
|
||||||
data = low_2d + high_2d
|
|
||||||
field = make_field(data=data, shape=None, xreal=1e-6, yreal=1e-6)
|
|
||||||
|
|
||||||
result_lp, = node.process(field, filter_type="lowpass", cutoff=0.15, cutoff_high=0.4, order=4)
|
|
||||||
assert result_lp.data.shape == (N, N)
|
|
||||||
assert result_lp.xreal == field.xreal
|
|
||||||
assert result_lp.si_unit_z == field.si_unit_z
|
|
||||||
corr_low = np.corrcoef(result_lp.data.ravel(), low_2d.ravel())[0, 1]
|
|
||||||
corr_high = np.corrcoef(result_lp.data.ravel(), high_2d.ravel())[0, 1]
|
|
||||||
assert corr_low > 0.9
|
|
||||||
assert abs(corr_high) < 0.3
|
|
||||||
|
|
||||||
result_hp, = node.process(field, filter_type="highpass", cutoff=0.4, cutoff_high=0.4, order=4)
|
|
||||||
assert abs(np.corrcoef(result_hp.data.ravel(), low_2d.ravel())[0, 1]) < 0.3
|
|
||||||
assert np.corrcoef(result_hp.data.ravel(), high_2d.ravel())[0, 1] > 0.9
|
|
||||||
|
|
||||||
const = make_field(data=np.ones((32, 32)) * 7.0)
|
|
||||||
result_const, = node.process(const, filter_type="lowpass", cutoff=0.5, cutoff_high=0.5, order=2)
|
|
||||||
assert np.allclose(result_const.data, 7.0, atol=1e-10)
|
|
||||||
@@ -36,7 +36,7 @@ def test_threshold_otsu_bimodal():
|
|||||||
data[70:100, 80:110] = 10.0 # another bright region
|
data[70:100, 80:110] = 10.0 # another bright region
|
||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
|
|
||||||
mask, = node.process(field, method="otsu", threshold=0.0, direction="above")
|
mask, table = node.process(field, method="otsu", threshold=0.0, direction="above")
|
||||||
bright_pixels = (mask == 255)
|
bright_pixels = (mask == 255)
|
||||||
# Should capture both bright regions
|
# Should capture both bright regions
|
||||||
assert bright_pixels[40, 40], "Otsu missed bright region 1"
|
assert bright_pixels[40, 40], "Otsu missed bright region 1"
|
||||||
@@ -57,7 +57,7 @@ def test_threshold_relative_range():
|
|||||||
data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5
|
data[10:20, 10:20] = 8.0 # bright patch, range = [2, 8], midpoint = 5
|
||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
|
|
||||||
mask, = node.process(field, method="relative", threshold=0.5, direction="above")
|
mask, table = node.process(field, method="relative", threshold=0.5, direction="above")
|
||||||
# Only the bright patch (value 8 >= 5) should be masked
|
# Only the bright patch (value 8 >= 5) should be masked
|
||||||
assert np.all(mask[10:20, 10:20] == 255)
|
assert np.all(mask[10:20, 10:20] == 255)
|
||||||
assert np.all(mask[0:10, :] == 0)
|
assert np.all(mask[0:10, :] == 0)
|
||||||
@@ -74,7 +74,7 @@ def test_threshold_empty_mask():
|
|||||||
data = np.ones((64, 64))
|
data = np.ones((64, 64))
|
||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
|
|
||||||
mask, = node.process(field, method="absolute", threshold=999.0, direction="above")
|
mask, table = node.process(field, method="absolute", threshold=999.0, direction="above")
|
||||||
assert mask.sum() == 0, "Mask should be completely empty"
|
assert mask.sum() == 0, "Mask should be completely empty"
|
||||||
print(" PASS\n")
|
print(" PASS\n")
|
||||||
|
|
||||||
@@ -88,7 +88,7 @@ def test_threshold_full_mask():
|
|||||||
data = np.ones((64, 64)) * 5.0
|
data = np.ones((64, 64)) * 5.0
|
||||||
field = make_field(data)
|
field = make_field(data)
|
||||||
|
|
||||||
mask, = node.process(field, method="absolute", threshold=-1.0, direction="above")
|
mask, table = node.process(field, method="absolute", threshold=-1.0, direction="above")
|
||||||
assert np.all(mask == 255), "Mask should be all white"
|
assert np.all(mask == 255), "Mask should be all white"
|
||||||
print(" PASS\n")
|
print(" PASS\n")
|
||||||
|
|
||||||
@@ -345,7 +345,7 @@ def test_pipeline_synthetic():
|
|||||||
|
|
||||||
# Step 1: threshold
|
# Step 1: threshold
|
||||||
thresh = ThresholdMask()
|
thresh = ThresholdMask()
|
||||||
mask, = thresh.process(field, method="absolute", threshold=1.0, direction="above")
|
mask, table = thresh.process(field, method="absolute", threshold=1.0, direction="above")
|
||||||
|
|
||||||
# Grains are well above noise, so mask should capture all 5
|
# Grains are well above noise, so mask should capture all 5
|
||||||
assert mask.max() == 255, "No grains detected"
|
assert mask.max() == 255, "No grains detected"
|
||||||
@@ -387,7 +387,7 @@ def test_pipeline_demo_image():
|
|||||||
|
|
||||||
# Threshold to find grains (they are raised above background)
|
# Threshold to find grains (they are raised above background)
|
||||||
thresh = ThresholdMask()
|
thresh = ThresholdMask()
|
||||||
mask, = thresh.process(field, method="otsu", threshold=0.0, direction="above")
|
mask, table = thresh.process(field, method="otsu", threshold=0.0, direction="above")
|
||||||
|
|
||||||
# Should detect grains
|
# Should detect grains
|
||||||
assert mask.max() == 255, "No grains found in demo image"
|
assert mask.max() == 255, "No grains found in demo image"
|
||||||
|
|||||||
3377
tests/test_nodes.py
3377
tests/test_nodes.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user