From 9fbd305854b6658ad7032c9db84b2662214a756c Mon Sep 17 00:00:00 2001 From: matei jordache Date: Thu, 16 Apr 2026 00:06:15 -0700 Subject: [PATCH] add masking to stats --- backend/nodes/statistics.py | 21 ++++++++++++++--- docs/nodes/Statistics.md | 3 ++- tests/node_tests/statistics.py | 43 ++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/backend/nodes/statistics.py b/backend/nodes/statistics.py index a8ed8b6..2d19eaf 100644 --- a/backend/nodes/statistics.py +++ b/backend/nodes/statistics.py @@ -2,6 +2,7 @@ from __future__ import annotations import numpy as np from backend.node_registry import register_node from backend.data_types import DataField, RecordTable +from backend.nodes.helpers import mask_to_bool @register_node(display_name="Statistics") @@ -11,7 +12,10 @@ class Statistics: return { "required": { "field": ("DATA_FIELD",), - } + }, + "optional": { + "mask": ("IMAGE",), + }, } OUTPUTS = ( @@ -21,13 +25,24 @@ class Statistics: DESCRIPTION = ( "Compute basic surface statistics: min, max, mean, RMS roughness, median, " - "and skewness." + "and skewness. When a mask is provided, only pixels inside the mask are " + "included." ) KEYWORDS = ("mean", "rms", "min", "max", "skewness", "kurtosis", "median", "roughness") - def process(self, field: DataField) -> tuple: + def process(self, field: DataField, mask: np.ndarray | None = None) -> tuple: d = field.data + if mask is not None: + selector = mask_to_bool(mask) + if selector.shape != d.shape: + raise ValueError( + f"Mask shape {selector.shape} does not match field shape {d.shape}" + ) + d = d[selector] + if d.size == 0: + raise ValueError("Mask selects no pixels") + mean = float(d.mean()) rms = float(np.sqrt(np.mean((d - mean) ** 2))) skewness = float(np.mean(((d - mean) / rms) ** 3)) if rms > 0 else 0.0 diff --git a/docs/nodes/Statistics.md b/docs/nodes/Statistics.md index 30b0041..a82cd75 100644 --- a/docs/nodes/Statistics.md +++ b/docs/nodes/Statistics.md @@ -7,6 +7,7 @@ Compute basic surface statistics: min, max, mean, RMS roughness, median, and ske | Name | Type | Required | Description | |------|------|----------|-------------| | field | DATA_FIELD | Yes | Input field to analyze | +| mask | IMAGE | No | Optional binary mask — only pixels inside the mask contribute to the statistics | ## Outputs @@ -20,4 +21,4 @@ None. ## Notes -- None. +- When a mask is provided, it must match the field's pixel resolution. Only pixels where the mask is non-zero are included in the statistics. diff --git a/tests/node_tests/statistics.py b/tests/node_tests/statistics.py index f0ed4d5..1447d3e 100644 --- a/tests/node_tests/statistics.py +++ b/tests/node_tests/statistics.py @@ -26,3 +26,46 @@ def test_statistics(): assert const_stats["RMS"] == 0.0 assert const_stats["skewness"] == 0.0 assert const_stats["kurtosis"] == 0.0 + + +def test_statistics_with_mask(): + """A mask restricts the stats to pixels where mask != 0.""" + from backend.nodes.statistics import Statistics + node = Statistics() + + data = np.array([[1, 2], [3, 4]], dtype=np.float64) + field = make_field(data=data) + # Mask selects only pixels >= 3 (bottom row). + mask = np.array([[0, 0], [255, 255]], dtype=np.uint8) + + table, = node.process(field, mask=mask) + stats = {row["quantity"]: row["value"] for row in table} + assert stats["min"] == 3.0 + assert stats["max"] == 4.0 + assert stats["mean"] == 3.5 + + +def test_statistics_mask_shape_mismatch(): + from backend.nodes.statistics import Statistics + node = Statistics() + + field = make_field(data=np.zeros((4, 4))) + bad_mask = np.zeros((3, 3), dtype=np.uint8) + try: + node.process(field, mask=bad_mask) + raise AssertionError("expected shape mismatch to raise") + except ValueError: + pass + + +def test_statistics_empty_mask(): + from backend.nodes.statistics import Statistics + node = Statistics() + + field = make_field(data=np.ones((4, 4))) + empty_mask = np.zeros((4, 4), dtype=np.uint8) + try: + node.process(field, mask=empty_mask) + raise AssertionError("expected empty mask to raise") + except ValueError: + pass