deduplication pass
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from backend.node_registry import register_node
|
||||
from backend.execution_context import emit_preview, emit_overlay
|
||||
from backend.data_types import DataField, encode_preview, RecordTable
|
||||
from backend.nodes.helpers import _mask_overlay
|
||||
from backend.execution_context import emit_overlay
|
||||
from backend.data_types import DataField, RecordTable
|
||||
from backend.nodes.helpers import bool_to_mask, histogram_with_centers, emit_mask_preview
|
||||
|
||||
|
||||
@register_node(display_name="Threshold Mask")
|
||||
@@ -36,9 +36,7 @@ class ThresholdMask:
|
||||
def process(self, field: DataField, method: str, threshold: float, direction: str) -> tuple:
|
||||
data = field.data
|
||||
|
||||
raw_counts, bin_edges = np.histogram(data.ravel(), bins=256)
|
||||
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
|
||||
counts = raw_counts.astype(np.float64)
|
||||
counts, bin_centers = histogram_with_centers(data)
|
||||
xmin = float(bin_centers[0]) if len(bin_centers) else 0.0
|
||||
xmax = float(bin_centers[-1]) if len(bin_centers) else 1.0
|
||||
|
||||
@@ -70,11 +68,11 @@ class ThresholdMask:
|
||||
})
|
||||
|
||||
if direction == "above":
|
||||
mask = (data >= t).astype(np.uint8) * 255
|
||||
mask = bool_to_mask(data >= t)
|
||||
else:
|
||||
mask = (data < t).astype(np.uint8) * 255
|
||||
mask = bool_to_mask(data < t)
|
||||
|
||||
emit_preview(encode_preview(_mask_overlay(field, mask)))
|
||||
emit_mask_preview(field, mask)
|
||||
|
||||
table = RecordTable([
|
||||
{"quantity": "threshold", "value": threshold, "unit": field.si_unit_xy},
|
||||
|
||||
Reference in New Issue
Block a user