add masking to stats
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.data_types import DataField, RecordTable
|
||||
from backend.nodes.helpers import mask_to_bool
|
||||
|
||||
|
||||
@register_node(display_name="Statistics")
|
||||
@@ -11,7 +12,10 @@ class Statistics:
|
||||
return {
|
||||
"required": {
|
||||
"field": ("DATA_FIELD",),
|
||||
}
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
OUTPUTS = (
|
||||
@@ -21,13 +25,24 @@ class Statistics:
|
||||
|
||||
DESCRIPTION = (
|
||||
"Compute basic surface statistics: min, max, mean, RMS roughness, median, "
|
||||
"and skewness."
|
||||
"and skewness. When a mask is provided, only pixels inside the mask are "
|
||||
"included."
|
||||
)
|
||||
|
||||
KEYWORDS = ("mean", "rms", "min", "max", "skewness", "kurtosis", "median", "roughness")
|
||||
|
||||
def process(self, field: DataField) -> tuple:
|
||||
def process(self, field: DataField, mask: np.ndarray | None = None) -> tuple:
|
||||
d = field.data
|
||||
if mask is not None:
|
||||
selector = mask_to_bool(mask)
|
||||
if selector.shape != d.shape:
|
||||
raise ValueError(
|
||||
f"Mask shape {selector.shape} does not match field shape {d.shape}"
|
||||
)
|
||||
d = d[selector]
|
||||
if d.size == 0:
|
||||
raise ValueError("Mask selects no pixels")
|
||||
|
||||
mean = float(d.mean())
|
||||
rms = float(np.sqrt(np.mean((d - mean) ** 2)))
|
||||
skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0
|
||||
|
||||
@@ -7,6 +7,7 @@ Compute basic surface statistics: min, max, mean, RMS roughness, median, and ske
|
||||
| Name | Type | Required | Description |
|
||||
|------|------|----------|-------------|
|
||||
| field | DATA_FIELD | Yes | Input field to analyze |
|
||||
| mask | IMAGE | No | Optional binary mask — only pixels inside the mask contribute to the statistics |
|
||||
|
||||
## Outputs
|
||||
|
||||
@@ -20,4 +21,4 @@ None.
|
||||
|
||||
## Notes
|
||||
|
||||
- None.
|
||||
- When a mask is provided, it must match the field's pixel resolution. Only pixels where the mask is non-zero are included in the statistics.
|
||||
|
||||
@@ -26,3 +26,46 @@ def test_statistics():
|
||||
assert const_stats["RMS"] == 0.0
|
||||
assert const_stats["skewness"] == 0.0
|
||||
assert const_stats["kurtosis"] == 0.0
|
||||
|
||||
|
||||
def test_statistics_with_mask():
|
||||
"""A mask restricts the stats to pixels where mask != 0."""
|
||||
from backend.nodes.statistics import Statistics
|
||||
node = Statistics()
|
||||
|
||||
data = np.array([[1, 2], [3, 4]], dtype=np.float64)
|
||||
field = make_field(data=data)
|
||||
# Mask selects only pixels >= 3 (bottom row).
|
||||
mask = np.array([[0, 0], [255, 255]], dtype=np.uint8)
|
||||
|
||||
table, = node.process(field, mask=mask)
|
||||
stats = {row["quantity"]: row["value"] for row in table}
|
||||
assert stats["min"] == 3.0
|
||||
assert stats["max"] == 4.0
|
||||
assert stats["mean"] == 3.5
|
||||
|
||||
|
||||
def test_statistics_mask_shape_mismatch():
|
||||
from backend.nodes.statistics import Statistics
|
||||
node = Statistics()
|
||||
|
||||
field = make_field(data=np.zeros((4, 4)))
|
||||
bad_mask = np.zeros((3, 3), dtype=np.uint8)
|
||||
try:
|
||||
node.process(field, mask=bad_mask)
|
||||
raise AssertionError("expected shape mismatch to raise")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_statistics_empty_mask():
|
||||
from backend.nodes.statistics import Statistics
|
||||
node = Statistics()
|
||||
|
||||
field = make_field(data=np.ones((4, 4)))
|
||||
empty_mask = np.zeros((4, 4), dtype=np.uint8)
|
||||
try:
|
||||
node.process(field, mask=empty_mask)
|
||||
raise AssertionError("expected empty mask to raise")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user