add line correction and scar removal nodes
This commit is contained in:
@@ -25,12 +25,15 @@ from backend.nodes import (
|
||||
plane_level_field,
|
||||
poly_level_field,
|
||||
fix_zero,
|
||||
line_correction,
|
||||
# Mask
|
||||
draw_mask,
|
||||
threshold_mask,
|
||||
mask_morphology,
|
||||
mask_invert,
|
||||
mask_combine,
|
||||
# Correction
|
||||
scar_removal,
|
||||
# Display
|
||||
color_map,
|
||||
font_node,
|
||||
|
||||
388
backend/nodes/line_correction.py
Normal file
388
backend/nodes/line_correction.py
Normal file
@@ -0,0 +1,388 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField, LineData
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
|
||||
if mask is None:
|
||||
return None
|
||||
|
||||
mask_array = np.asarray(mask)
|
||||
if mask_array.shape[:2] != shape:
|
||||
raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
|
||||
return mask_array > 127
|
||||
|
||||
|
||||
def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float:
|
||||
values = np.asarray(values, dtype=np.float64)
|
||||
if values.size == 0:
|
||||
return 0.0
|
||||
|
||||
sorted_values = np.sort(values, kind="mergesort")
|
||||
count = sorted_values.size
|
||||
nlowest = int(np.rint(trim_fraction * count))
|
||||
nhighest = int(np.rint(trim_fraction * count))
|
||||
|
||||
if nlowest + nhighest + 1 >= count:
|
||||
return float(np.median(sorted_values))
|
||||
|
||||
trimmed = sorted_values[nlowest:count - nhighest]
|
||||
return float(trimmed.mean()) if trimmed.size else float(np.median(sorted_values))
|
||||
|
||||
|
||||
def _masked_values(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray:
|
||||
if mask is None or masking == "ignore":
|
||||
return data
|
||||
if masking == "include":
|
||||
return data[mask]
|
||||
if masking == "exclude":
|
||||
return data[~mask]
|
||||
raise ValueError(f"Unknown masking mode: {masking}")
|
||||
|
||||
|
||||
def _global_masked_median(data: np.ndarray, mask: np.ndarray | None, masking: str) -> float:
|
||||
selected = _masked_values(data, mask, masking)
|
||||
if selected.size == 0:
|
||||
selected = np.asarray(data, dtype=np.float64).ravel()
|
||||
return float(np.median(selected))
|
||||
|
||||
|
||||
def _find_row_shifts_trimmed_mean(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
trim_fraction: float,
|
||||
mincount: int = 0,
|
||||
) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
if yres == 0:
|
||||
return np.zeros(0, dtype=np.float64)
|
||||
|
||||
if mincount <= 0:
|
||||
mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
|
||||
|
||||
total_median = _global_masked_median(data, mask, masking)
|
||||
shifts = np.empty(yres, dtype=np.float64)
|
||||
|
||||
for i in range(yres):
|
||||
row = data[i]
|
||||
row_mask = None if mask is None else mask[i]
|
||||
|
||||
if row_mask is None or masking == "ignore":
|
||||
shifts[i] = _trimmed_mean_or_median(row, trim_fraction)
|
||||
continue
|
||||
|
||||
values = _masked_values(row, row_mask, masking)
|
||||
if values.size >= mincount:
|
||||
shifts[i] = _trimmed_mean_or_median(values, trim_fraction)
|
||||
else:
|
||||
shifts[i] = total_median
|
||||
|
||||
shifts -= shifts.mean()
|
||||
return shifts
|
||||
|
||||
|
||||
def _slope_level_row_shifts(shifts: np.ndarray) -> np.ndarray:
|
||||
shifts = np.asarray(shifts, dtype=np.float64).copy()
|
||||
if shifts.size <= 1:
|
||||
shifts -= shifts.mean() if shifts.size else 0.0
|
||||
return shifts
|
||||
|
||||
x = np.arange(shifts.size, dtype=np.float64)
|
||||
A = np.column_stack((np.ones_like(x), x))
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, shifts, rcond=None)
|
||||
intercept, slope = coeffs
|
||||
shifts -= intercept + slope * x
|
||||
return shifts
|
||||
|
||||
|
||||
def _find_row_shifts_trimmed_diff(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
trim_fraction: float,
|
||||
mincount: int = 0,
|
||||
) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
shifts = np.zeros(yres, dtype=np.float64)
|
||||
if yres <= 1:
|
||||
return shifts
|
||||
|
||||
if mincount <= 0:
|
||||
mincount = int(np.rint(np.log(max(xres, 1)) + 1.0))
|
||||
|
||||
for i in range(yres - 1):
|
||||
upper = data[i]
|
||||
lower = data[i + 1]
|
||||
|
||||
if mask is None or masking == "ignore":
|
||||
diffs = lower - upper
|
||||
else:
|
||||
upper_mask = mask[i]
|
||||
lower_mask = mask[i + 1]
|
||||
valid = upper_mask & lower_mask if masking == "include" else (~upper_mask & ~lower_mask)
|
||||
diffs = (lower - upper)[valid]
|
||||
|
||||
if diffs.size >= mincount:
|
||||
shifts[i + 1] = _trimmed_mean_or_median(diffs, trim_fraction)
|
||||
else:
|
||||
shifts[i + 1] = 0.0
|
||||
|
||||
shifts = np.cumsum(shifts)
|
||||
return _slope_level_row_shifts(shifts)
|
||||
|
||||
|
||||
def _vandermonde(x: np.ndarray, degree: int) -> np.ndarray:
|
||||
return np.vander(np.asarray(x, dtype=np.float64), N=degree + 1, increasing=True)
|
||||
|
||||
|
||||
def _row_level_poly(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
degree: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
yres, xres = data.shape
|
||||
corrected = data.copy()
|
||||
background = np.zeros_like(corrected)
|
||||
shifts = np.zeros(yres, dtype=np.float64)
|
||||
|
||||
if yres == 0 or xres == 0:
|
||||
return corrected, background, shifts
|
||||
|
||||
xc = 0.5 * (xres - 1)
|
||||
avg = float(data.mean())
|
||||
x_all = np.arange(xres, dtype=np.float64) - xc
|
||||
design_all = _vandermonde(x_all, degree)
|
||||
|
||||
for i in range(yres):
|
||||
row = data[i]
|
||||
row_mask = None if mask is None else mask[i]
|
||||
|
||||
if row_mask is None or masking == "ignore":
|
||||
valid = np.ones(xres, dtype=bool)
|
||||
elif masking == "include":
|
||||
valid = row_mask
|
||||
else:
|
||||
valid = ~row_mask
|
||||
|
||||
coeffs = np.zeros(degree + 1, dtype=np.float64)
|
||||
if np.count_nonzero(valid) > degree:
|
||||
design = design_all[valid]
|
||||
coeffs, _, _, _ = np.linalg.lstsq(design, row[valid], rcond=None)
|
||||
|
||||
coeffs[0] -= avg
|
||||
row_background = design_all @ coeffs
|
||||
corrected[i] = row - row_background
|
||||
background[i] = row_background
|
||||
shifts[i] = coeffs[0]
|
||||
|
||||
return corrected, background, shifts
|
||||
|
||||
|
||||
def _calculate_segment_correction(upper: np.ndarray, middle: np.ndarray, lower: np.ndarray) -> np.ndarray:
|
||||
length = upper.size
|
||||
if length < 4:
|
||||
return np.zeros(length, dtype=np.float64)
|
||||
|
||||
corr = float(np.mean((upper + lower) / 2.0 - middle))
|
||||
return (3.0 * corr + (upper + lower) / 2.0 - middle) / 4.0
|
||||
|
||||
|
||||
def _line_correct_step_iter(data: np.ndarray) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
if yres < 3 or xres == 0:
|
||||
return data.copy()
|
||||
|
||||
corrections = np.zeros_like(data)
|
||||
vertical_diff_energy = float(np.mean((data[1:] - data[:-1]) ** 2))
|
||||
if vertical_diff_energy <= 0.0:
|
||||
return data.copy()
|
||||
|
||||
threshold = 3.0
|
||||
for i in range(yres - 2):
|
||||
upper = data[i]
|
||||
middle = data[i + 1]
|
||||
lower = data[i + 2]
|
||||
marker_row = corrections[i + 1]
|
||||
|
||||
candidates = (middle - upper) * (middle - lower) > threshold * vertical_diff_energy
|
||||
if np.any(candidates):
|
||||
signs = np.where(2.0 * middle[candidates] - upper[candidates] - lower[candidates] > 0.0, 1.0, -1.0)
|
||||
marker_row[candidates] = signs
|
||||
|
||||
segment_start = 0
|
||||
while segment_start < xres:
|
||||
sign = marker_row[segment_start]
|
||||
if sign == 0.0:
|
||||
segment_start += 1
|
||||
continue
|
||||
|
||||
segment_end = segment_start + 1
|
||||
while segment_end < xres and marker_row[segment_end] == sign:
|
||||
segment_end += 1
|
||||
|
||||
marker_row[segment_start:segment_end] = _calculate_segment_correction(
|
||||
upper[segment_start:segment_end],
|
||||
middle[segment_start:segment_end],
|
||||
lower[segment_start:segment_end],
|
||||
)
|
||||
segment_start = segment_end
|
||||
|
||||
return data + corrections
|
||||
|
||||
|
||||
def _conservative_filter(data: np.ndarray, size: int) -> np.ndarray:
|
||||
if size <= 1:
|
||||
return data.copy()
|
||||
|
||||
from scipy.ndimage import maximum_filter, minimum_filter
|
||||
|
||||
footprint = np.ones((size, size), dtype=bool)
|
||||
footprint[size // 2, size // 2] = False
|
||||
min_neighbours = minimum_filter(data, footprint=footprint, mode="nearest")
|
||||
max_neighbours = maximum_filter(data, footprint=footprint, mode="nearest")
|
||||
return np.clip(data, min_neighbours, max_neighbours)
|
||||
|
||||
|
||||
def _line_correct_step(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
corrected = data.copy()
|
||||
avg = float(corrected.mean())
|
||||
|
||||
shifts = _find_row_shifts_trimmed_mean(corrected, None, "ignore", 0.5, 0)
|
||||
corrected -= shifts[:, np.newaxis]
|
||||
corrected = _line_correct_step_iter(corrected)
|
||||
corrected = _line_correct_step_iter(corrected)
|
||||
corrected = _conservative_filter(corrected, 5)
|
||||
corrected += avg - float(corrected.mean())
|
||||
|
||||
background = data - corrected
|
||||
step_shifts = background.mean(axis=1) if background.size else np.zeros(data.shape[0], dtype=np.float64)
|
||||
return corrected, step_shifts
|
||||
|
||||
|
||||
def _line_axis(length: int, real_extent: float) -> np.ndarray | None:
|
||||
if length <= 0:
|
||||
return None
|
||||
return np.linspace(0.0, float(real_extent), int(length), dtype=np.float64)
|
||||
|
||||
|
||||
@register_node(display_name="Line Correction")
|
||||
class LineCorrection:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"method": ([
|
||||
"median",
|
||||
"median_diff",
|
||||
"trimmed_mean",
|
||||
"trimmed_diff",
|
||||
"polynomial",
|
||||
"step",
|
||||
], {"default": "median"}),
|
||||
"direction": (["horizontal", "vertical"], {"default": "horizontal"}),
|
||||
"masking": (["ignore", "include", "exclude"], {"default": "ignore"}),
|
||||
"trim_fraction": ("FLOAT", {
|
||||
"default": 0.05,
|
||||
"min": 0.0,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"show_when_widget_value": {"method": ["trimmed_mean", "trimmed_diff"]},
|
||||
}),
|
||||
"polynomial_degree": ("INT", {
|
||||
"default": 1,
|
||||
"min": 0,
|
||||
"max": 5,
|
||||
"step": 1,
|
||||
"show_when_widget_value": {"method": ["polynomial"]},
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD", "DATA_FIELD", "LINE")
|
||||
RETURN_NAMES = ("corrected", "background", "row shifts")
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Correct scan-line mismatches using Gwyddion-derived row alignment methods. "
|
||||
"Supports median and trimmed row alignment, difference-based alignment, polynomial row leveling, "
|
||||
"and the step-line correction path from Gwyddion's linecorrect/linematch modules."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
method: str,
|
||||
direction: str,
|
||||
masking: str,
|
||||
trim_fraction: float,
|
||||
polynomial_degree: int,
|
||||
mask: np.ndarray | None = None,
|
||||
) -> tuple:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
mask_array = _normalize_mask(mask, data.shape)
|
||||
|
||||
if direction not in {"horizontal", "vertical"}:
|
||||
raise ValueError(f"Unknown direction: {direction}")
|
||||
|
||||
working = data.copy()
|
||||
working_mask = None if mask_array is None else mask_array.copy()
|
||||
if direction == "vertical":
|
||||
working = working.T
|
||||
if working_mask is not None:
|
||||
working_mask = working_mask.T
|
||||
|
||||
if method == "median":
|
||||
shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, 0.5, 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "median_diff":
|
||||
shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, 0.5, 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "trimmed_mean":
|
||||
shifts = _find_row_shifts_trimmed_mean(working, working_mask, masking, float(trim_fraction), 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "trimmed_diff":
|
||||
shifts = _find_row_shifts_trimmed_diff(working, working_mask, masking, float(trim_fraction), 0)
|
||||
corrected = working - shifts[:, np.newaxis]
|
||||
background = np.broadcast_to(shifts[:, np.newaxis], working.shape).copy()
|
||||
elif method == "polynomial":
|
||||
corrected, background, shifts = _row_level_poly(
|
||||
working,
|
||||
working_mask,
|
||||
masking,
|
||||
int(polynomial_degree),
|
||||
)
|
||||
elif method == "step":
|
||||
corrected, shifts = _line_correct_step(working)
|
||||
background = working - corrected
|
||||
else:
|
||||
raise ValueError(f"Unknown line correction method: {method}")
|
||||
|
||||
if direction == "vertical":
|
||||
corrected = corrected.T
|
||||
background = background.T
|
||||
line_axis = _line_axis(field.xres, field.xreal)
|
||||
else:
|
||||
line_axis = _line_axis(field.yres, field.yreal)
|
||||
|
||||
corrected_field = field.replace(data=corrected)
|
||||
background_field = field.replace(data=background)
|
||||
shift_line = LineData(
|
||||
data=np.asarray(shifts, dtype=np.float64),
|
||||
x_axis=line_axis,
|
||||
x_unit=field.si_unit_xy,
|
||||
y_unit=field.si_unit_z,
|
||||
)
|
||||
|
||||
return (corrected_field, background_field, shift_line)
|
||||
219
backend/nodes/scar_removal.py
Normal file
219
backend/nodes/scar_removal.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from backend.data_types import DataField
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
def _mark_scars_one_sign(
|
||||
data: np.ndarray,
|
||||
threshold_high: float,
|
||||
threshold_low: float,
|
||||
min_length: int,
|
||||
max_width: int,
|
||||
negative: bool,
|
||||
) -> np.ndarray:
|
||||
yres, xres = data.shape
|
||||
marks = np.zeros_like(data, dtype=np.float64)
|
||||
|
||||
min_length = max(int(min_length), 1)
|
||||
max_width = min(int(max_width), yres - 2)
|
||||
threshold_high = max(float(threshold_high), float(threshold_low))
|
||||
threshold_low = float(threshold_low)
|
||||
|
||||
if min_length > xres or max_width < 1 or threshold_low <= 0.0:
|
||||
return marks
|
||||
|
||||
vertical_rms = float(np.sqrt(np.sum((data[:-1] - data[1:]) ** 2) / max(xres * yres, 1)))
|
||||
if vertical_rms == 0.0:
|
||||
return marks
|
||||
|
||||
for i in range(yres - (max_width + 1)):
|
||||
for j in range(xres):
|
||||
if negative:
|
||||
top = data[i, j]
|
||||
bottom = data[i + 1, j]
|
||||
width = 0
|
||||
for candidate in range(1, max_width + 1):
|
||||
top = min(data[i, j], data[i + candidate + 1, j])
|
||||
bottom = max(bottom, data[i + candidate, j])
|
||||
if top - bottom >= threshold_low * vertical_rms:
|
||||
width = candidate
|
||||
break
|
||||
|
||||
if width:
|
||||
for candidate in range(width, 0, -1):
|
||||
marks[i + candidate, j] = max(
|
||||
marks[i + candidate, j],
|
||||
(top - data[i + candidate, j]) / vertical_rms,
|
||||
)
|
||||
else:
|
||||
bottom = data[i, j]
|
||||
top = data[i + 1, j]
|
||||
width = 0
|
||||
for candidate in range(1, max_width + 1):
|
||||
bottom = max(data[i, j], data[i + candidate + 1, j])
|
||||
top = min(top, data[i + candidate, j])
|
||||
if top - bottom >= threshold_low * vertical_rms:
|
||||
width = candidate
|
||||
break
|
||||
|
||||
if width:
|
||||
for candidate in range(width, 0, -1):
|
||||
marks[i + candidate, j] = max(
|
||||
marks[i + candidate, j],
|
||||
(data[i + candidate, j] - bottom) / vertical_rms,
|
||||
)
|
||||
|
||||
for i in range(yres):
|
||||
row = marks[i]
|
||||
for j in range(1, xres):
|
||||
if row[j] >= threshold_low and row[j - 1] >= threshold_high:
|
||||
row[j] = threshold_high
|
||||
for j in range(xres - 1, 0, -1):
|
||||
if row[j - 1] >= threshold_low and row[j] >= threshold_high:
|
||||
row[j - 1] = threshold_high
|
||||
|
||||
for i in range(yres):
|
||||
row = marks[i]
|
||||
run_length = 0
|
||||
for j in range(xres):
|
||||
if row[j] >= threshold_high:
|
||||
row[j] = 1.0
|
||||
run_length += 1
|
||||
continue
|
||||
|
||||
if 0 < run_length < min_length:
|
||||
row[j - run_length:j] = 0.0
|
||||
row[j] = 0.0
|
||||
run_length = 0
|
||||
|
||||
if 0 < run_length < min_length:
|
||||
row[xres - run_length:xres] = 0.0
|
||||
|
||||
return marks
|
||||
|
||||
|
||||
def _mark_scars(
|
||||
data: np.ndarray,
|
||||
scar_type: str,
|
||||
threshold_high: float,
|
||||
threshold_low: float,
|
||||
min_length: int,
|
||||
max_width: int,
|
||||
) -> np.ndarray:
|
||||
if scar_type == "positive":
|
||||
return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False)
|
||||
if scar_type == "negative":
|
||||
return _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True)
|
||||
if scar_type == "both":
|
||||
positive = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=False)
|
||||
negative = _mark_scars_one_sign(data, threshold_high, threshold_low, min_length, max_width, negative=True)
|
||||
return np.maximum(positive, negative)
|
||||
raise ValueError(f"Unknown scar type: {scar_type}")
|
||||
|
||||
|
||||
def _laplace_inpaint(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
||||
mask = np.asarray(mask, dtype=bool)
|
||||
if not np.any(mask):
|
||||
return data.copy()
|
||||
|
||||
if np.all(mask):
|
||||
return np.zeros_like(data, dtype=np.float64)
|
||||
|
||||
from scipy.sparse import csr_matrix
|
||||
from scipy.sparse.linalg import MatrixRankWarning, spsolve
|
||||
from skimage.restoration import inpaint_biharmonic
|
||||
|
||||
yres, xres = data.shape
|
||||
unknown_indices = -np.ones((yres, xres), dtype=np.int64)
|
||||
unknown_coords = np.argwhere(mask)
|
||||
unknown_indices[mask] = np.arange(unknown_coords.shape[0], dtype=np.int64)
|
||||
|
||||
rows: list[int] = []
|
||||
cols: list[int] = []
|
||||
values: list[float] = []
|
||||
rhs = np.zeros(unknown_coords.shape[0], dtype=np.float64)
|
||||
|
||||
for row_index, (y, x) in enumerate(unknown_coords):
|
||||
degree = 0
|
||||
for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
|
||||
if ny < 0 or ny >= yres or nx < 0 or nx >= xres:
|
||||
continue
|
||||
|
||||
degree += 1
|
||||
if mask[ny, nx]:
|
||||
rows.append(row_index)
|
||||
cols.append(int(unknown_indices[ny, nx]))
|
||||
values.append(-1.0)
|
||||
else:
|
||||
rhs[row_index] += float(data[ny, nx])
|
||||
|
||||
rows.append(row_index)
|
||||
cols.append(row_index)
|
||||
values.append(float(degree))
|
||||
|
||||
matrix = csr_matrix((values, (rows, cols)), shape=(unknown_coords.shape[0], unknown_coords.shape[0]))
|
||||
restored = data.copy()
|
||||
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("error", category=MatrixRankWarning)
|
||||
solved = spsolve(matrix, rhs)
|
||||
except (MatrixRankWarning, RuntimeError, ValueError):
|
||||
return np.asarray(inpaint_biharmonic(data, mask, channel_axis=None), dtype=np.float64)
|
||||
|
||||
restored[mask] = np.asarray(solved, dtype=np.float64)
|
||||
return restored
|
||||
|
||||
|
||||
@register_node(display_name="Scar Removal")
|
||||
class ScarRemoval:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
"scar_type": (["both", "positive", "negative"], {"default": "both"}),
|
||||
"threshold_high": ("FLOAT", {"default": 0.666, "min": 0.0, "max": 2.0, "step": 0.01}),
|
||||
"threshold_low": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 2.0, "step": 0.01}),
|
||||
"min_length": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
|
||||
"max_width": ("INT", {"default": 4, "min": 1, "max": 32, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD", "IMAGE")
|
||||
RETURN_NAMES = ("corrected", "scar mask")
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Detect and remove horizontal scan scars using Gwyddion-derived scar marking thresholds, "
|
||||
"then interpolate over the detected mask with a Laplace-style inpaint."
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
scar_type: str,
|
||||
threshold_high: float,
|
||||
threshold_low: float,
|
||||
min_length: int,
|
||||
max_width: int,
|
||||
) -> tuple:
|
||||
threshold_high = float(max(threshold_high, threshold_low))
|
||||
threshold_low = float(min(threshold_high, threshold_low))
|
||||
|
||||
marks = _mark_scars(
|
||||
np.asarray(field.data, dtype=np.float64),
|
||||
scar_type,
|
||||
threshold_high,
|
||||
threshold_low,
|
||||
int(min_length),
|
||||
int(max_width),
|
||||
)
|
||||
scar_mask = marks > 0.0
|
||||
corrected = _laplace_inpaint(np.asarray(field.data, dtype=np.float64), scar_mask)
|
||||
return (field.replace(data=corrected), scar_mask.astype(np.uint8) * 255)
|
||||
Reference in New Issue
Block a user