deduplication pass
This commit is contained in:
@@ -4,16 +4,7 @@ import numpy as np
|
||||
|
||||
from backend.data_types import DataField, LineData
|
||||
from backend.node_registry import register_node
|
||||
|
||||
|
||||
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
|
||||
if mask is None:
|
||||
return None
|
||||
|
||||
mask_array = np.asarray(mask)
|
||||
if mask_array.shape[:2] != shape:
|
||||
raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
|
||||
return mask_array > 127
|
||||
from backend.nodes.helpers import normalize_mask, apply_masking, masked_values
|
||||
|
||||
|
||||
def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float:
|
||||
@@ -33,18 +24,8 @@ def _trimmed_mean_or_median(values: np.ndarray, trim_fraction: float) -> float:
|
||||
return float(trimmed.mean()) if trimmed.size else float(np.median(sorted_values))
|
||||
|
||||
|
||||
def _masked_values(data: np.ndarray, mask: np.ndarray | None, masking: str) -> np.ndarray:
|
||||
if mask is None or masking == "ignore":
|
||||
return data
|
||||
if masking == "include":
|
||||
return data[mask]
|
||||
if masking == "exclude":
|
||||
return data[~mask]
|
||||
raise ValueError(f"Unknown masking mode: {masking}")
|
||||
|
||||
|
||||
def _global_masked_median(data: np.ndarray, mask: np.ndarray | None, masking: str) -> float:
|
||||
selected = _masked_values(data, mask, masking)
|
||||
selected = masked_values(data, mask, masking)
|
||||
if selected.size == 0:
|
||||
selected = np.asarray(data, dtype=np.float64).ravel()
|
||||
return float(np.median(selected))
|
||||
@@ -75,7 +56,7 @@ def _find_row_shifts_trimmed_mean(
|
||||
shifts[i] = _trimmed_mean_or_median(row, trim_fraction)
|
||||
continue
|
||||
|
||||
values = _masked_values(row, row_mask, masking)
|
||||
values = masked_values(row, row_mask, masking)
|
||||
if values.size >= mincount:
|
||||
shifts[i] = _trimmed_mean_or_median(values, trim_fraction)
|
||||
else:
|
||||
@@ -162,12 +143,7 @@ def _row_level_poly(
|
||||
row = data[i]
|
||||
row_mask = None if mask is None else mask[i]
|
||||
|
||||
if row_mask is None or masking == "ignore":
|
||||
valid = np.ones(xres, dtype=bool)
|
||||
elif masking == "include":
|
||||
valid = row_mask
|
||||
else:
|
||||
valid = ~row_mask
|
||||
valid = apply_masking(row, row_mask, masking)
|
||||
|
||||
coeffs = np.zeros(degree + 1, dtype=np.float64)
|
||||
if np.count_nonzero(valid) > degree:
|
||||
@@ -331,7 +307,7 @@ class LineCorrection:
|
||||
mask: np.ndarray | None = None,
|
||||
) -> tuple:
|
||||
data = np.asarray(field.data, dtype=np.float64)
|
||||
mask_array = _normalize_mask(mask, data.shape)
|
||||
mask_array = normalize_mask(mask, data.shape)
|
||||
|
||||
if direction not in {"horizontal", "vertical"}:
|
||||
raise ValueError(f"Unknown direction: {direction}")
|
||||
|
||||
Reference in New Issue
Block a user