clean tests
This commit is contained in:
@@ -4,6 +4,49 @@ from backend.node_registry import register_node
|
||||
from backend.data_types import DataField
|
||||
|
||||
|
||||
def _normalize_mask(mask: np.ndarray | None, shape: tuple[int, int]) -> np.ndarray | None:
|
||||
if mask is None:
|
||||
return None
|
||||
|
||||
mask_array = np.asarray(mask)
|
||||
if mask_array.shape[:2] != shape:
|
||||
raise ValueError(f"Mask shape {mask_array.shape} does not match field shape {shape}.")
|
||||
return mask_array > 127
|
||||
|
||||
|
||||
def _fit_plane(
|
||||
data: np.ndarray,
|
||||
mask: np.ndarray | None,
|
||||
masking: str,
|
||||
) -> tuple[float, float, float, np.ndarray, np.ndarray]:
|
||||
yres, xres = data.shape
|
||||
x = np.linspace(0.0, 1.0, xres)
|
||||
y = np.linspace(0.0, 1.0, yres)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
if mask is None or masking == "ignore":
|
||||
valid = np.ones(data.shape, dtype=bool)
|
||||
elif masking == "include":
|
||||
valid = mask
|
||||
elif masking == "exclude":
|
||||
valid = ~mask
|
||||
else:
|
||||
raise ValueError(f"Unknown masking mode: {masking}")
|
||||
|
||||
if np.count_nonzero(valid) < 3:
|
||||
raise ValueError("Plane Level requires at least three usable pixels for fitting.")
|
||||
|
||||
A = np.column_stack([
|
||||
np.ones(int(np.count_nonzero(valid)), dtype=np.float64),
|
||||
xx[valid].ravel(),
|
||||
yy[valid].ravel(),
|
||||
])
|
||||
z = data[valid].ravel()
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||
pa, pbx, pby = coeffs
|
||||
return float(pa), float(pbx), float(pby), xx, yy
|
||||
|
||||
|
||||
@register_node(display_name="Plane Level")
|
||||
class PlaneLevelField:
|
||||
@classmethod
|
||||
@@ -11,7 +54,11 @@ class PlaneLevelField:
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
"masking": (["ignore", "include", "exclude"], {"default": "ignore"}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("DATA_FIELD",)
|
||||
@@ -19,27 +66,19 @@ class PlaneLevelField:
|
||||
FUNCTION = "process"
|
||||
|
||||
DESCRIPTION = (
|
||||
"Fit and subtract a least-squares plane from the data. "
|
||||
"Equivalent to gwy_data_field_fit_plane + gwy_data_field_plane_level."
|
||||
"Fit and subtract a least-squares plane from the data. Supports include/exclude mask fitting "
|
||||
"for flattening around features, similar to masked plane fitting workflows in Gwyddion."
|
||||
)
|
||||
|
||||
def process(self, field: DataField) -> tuple:
|
||||
def process(
|
||||
self,
|
||||
field: DataField,
|
||||
masking: str = "ignore",
|
||||
mask: np.ndarray | None = None,
|
||||
) -> tuple:
|
||||
data = field.data.copy()
|
||||
yres, xres = data.shape
|
||||
|
||||
x = np.linspace(0.0, 1.0, xres)
|
||||
y = np.linspace(0.0, 1.0, yres)
|
||||
xx, yy = np.meshgrid(x, y)
|
||||
|
||||
A = np.column_stack([
|
||||
np.ones(xres * yres),
|
||||
xx.ravel(),
|
||||
yy.ravel(),
|
||||
])
|
||||
z = data.ravel()
|
||||
|
||||
coeffs, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
|
||||
pa, pbx, pby = coeffs
|
||||
mask_array = _normalize_mask(mask, data.shape)
|
||||
pa, pbx, pby, xx, yy = _fit_plane(data, mask_array, masking)
|
||||
|
||||
plane = (pa + pbx * xx + pby * yy)
|
||||
return (field.replace(data=data - plane),)
|
||||
|
||||
Reference in New Issue
Block a user