Files
tono/backend/nodes/wavelet_denoise.py

75 lines
2.1 KiB
Python

from __future__ import annotations
import numpy as np
from backend.data_types import DataField
from backend.node_registry import register_node
@register_node(display_name="Wavelet Denoise")
class WaveletDenoise:
    """Node that denoises a ``DataField`` by wavelet coefficient thresholding.

    The actual denoising is delegated to
    ``skimage.restoration.denoise_wavelet``; this class normalises the data,
    forwards the user's choices, and maps the result back to the original
    value range.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification.

        Each entry maps an input name to a tuple of ``(type, options)`` or
        ``(choices, options)``. Presumably consumed by the node registry / UI
        to build the input widgets -- confirm against ``register_node``.
        """
        return {
            "required": {
                "field": ("DATA_FIELD",),
                "wavelet": (
                    ["db1", "db2", "db4", "db8", "sym4", "coif1", "bior1.3"],
                    {"default": "db4"},
                ),
                "method": (["BayesShrink", "VisuShrink"], {"default": "BayesShrink"}),
                # Expressed in normalised units (fraction of the data range);
                # 0 means "let skimage estimate the noise level".
                "sigma": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
                "mode": (["soft", "hard"], {"default": "soft"}),
            }
        }

    OUTPUTS = (
        ('DATA_FIELD', 'denoised'),
    )
    FUNCTION = "process"
    DESCRIPTION = (
        "Denoise using wavelet coefficient thresholding. BayesShrink adapts the threshold "
        "per sub-band; VisuShrink uses a global threshold. Equivalent to applying DWT from "
        "Gwyddion dwt.c with coefficient thresholding."
    )

    def process(
        self,
        field: DataField,
        wavelet: str,
        method: str,
        sigma: float,
        mode: str,
    ) -> tuple:
        """Denoise ``field.data`` and return it wrapped in a new field.

        The data is min/max-normalised to ``[0, 1]`` before denoising and mapped
        back with the same offset and scale afterwards, so a user-supplied
        ``sigma`` is a fraction of the data range.
        NOTE(review): skimage's ``denoise_wavelet`` appears to clip its output to
        ``[0, 1]`` for non-negative input, which is why normalising first keeps
        the result inside the original data range -- verify for the installed
        skimage version.

        Args:
            field: Input field. Its ``data`` array is read, and the result is
                built with ``field.replace`` (presumably a copy with ``data``
                swapped -- confirm against ``DataField``).
            wavelet: Wavelet name forwarded to skimage (e.g. ``"db4"``).
            method: ``"BayesShrink"`` or ``"VisuShrink"``, forwarded to skimage.
            sigma: Noise standard deviation in normalised units. Any value
                ``<= 0`` lets skimage estimate it from the data.
            mode: ``"soft"`` or ``"hard"`` thresholding, forwarded to skimage.

        Returns:
            A 1-tuple holding the denoised field. Empty and constant fields are
            returned unchanged (the very same object).

        Raises:
            ValueError: If the data contains NaN or infinite values, which would
                otherwise silently turn the entire output into NaN.
        """
        # Imported lazily so skimage is only needed when the node actually runs.
        from skimage.restoration import denoise_wavelet

        data = field.data

        # Nothing to denoise; also avoids min()/ptp() raising on an empty array.
        if data.size == 0:
            return (field,)

        # NaN/inf would poison dmin/drange and make every output value NaN.
        if not np.isfinite(data).all():
            raise ValueError(
                "Wavelet Denoise requires finite data; the field contains NaN or inf values."
            )

        dmin = float(data.min())
        drange = float(np.ptp(data))

        # A constant field carries no noise to remove and would divide by zero.
        if drange == 0:
            return (field,)

        norm = (data - dmin) / drange

        # skimage interprets sigma=None as "estimate the noise level from the data".
        sigma_val = float(sigma) if sigma > 0 else None

        denoised_norm = denoise_wavelet(
            norm,
            wavelet=wavelet,
            method=method,
            mode=mode,
            sigma=sigma_val,
            rescale_sigma=True,
            channel_axis=None,
        )

        # Undo the normalisation to return data in the original units.
        result = denoised_norm * drange + dmin
        return (field.replace(data=result),)