133 lines
5.1 KiB
Python
133 lines
5.1 KiB
Python
from __future__ import annotations
|
|
import numpy as np
|
|
from backend.node_registry import register_node
|
|
from backend.execution_context import emit_value
|
|
from backend.data_types import DataField, LineData, MeasureTable
|
|
from backend.nodes.helpers import (
|
|
LINE_OPS,
|
|
TABLE_OPS,
|
|
ARRAY_OPS,
|
|
_scalar_payload,
|
|
_apply_scalar_unit,
|
|
_common_table_unit,
|
|
extract_numeric_table_values,
|
|
resolve_table_column_name,
|
|
)
|
|
|
|
|
|
@register_node(display_name="Stats")
class Stats:
    """Polymorphic scalar stats node for LINE, RECORD_TABLE, DATA_FIELD, or IMAGE inputs.

    The input is classified and flattened to a float64 array, then reduced with
    an operation looked up in one of the tables from ``backend.nodes.helpers``:

    * ``LINE_OPS`` for ``LineData`` and 1-D ``numpy`` arrays,
    * ``TABLE_OPS`` for record tables (a non-empty ``list``),
    * ``ARRAY_OPS`` for ``DataField`` objects and non-1-D ``numpy`` arrays.

    A table entry is either a bare callable or a tuple whose first item is the
    callable. For ``LINE_OPS``, a truthy second tuple item is used as an
    explicit output unit.
    """

    # Class-level hooks not referenced inside this class body; presumably
    # populated by the node framework — TODO confirm.
    _broadcast_value_fn = None
    _current_node_id: str = ""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for the node registry / UI.

        ``column`` is only shown for RECORD_TABLE sources, and the set of
        ``operation`` choices is keyed by the source type of ``input``.
        """
        return {
            "required": {
                "input": ("DATA_FIELD", {
                    "accepted_types": ["IMAGE", "LINE", "RECORD_TABLE"],
                }),
                "column": ("STRING", {
                    "default": "value",
                    "choices_from_table_input": "input",
                    "show_when_source_type": {
                        "input": ["RECORD_TABLE"],
                    },
                }),
                "operation": ("STRING", {
                    "default": "mean",
                    "choices_by_source_type": {
                        "LINE": list(LINE_OPS.keys()),
                        "RECORD_TABLE": list(TABLE_OPS.keys()),
                        "DATA_FIELD": list(ARRAY_OPS.keys()),
                        "IMAGE": list(ARRAY_OPS.keys()),
                    },
                    "source_type_input": "input",
                }),
            }
        }

    OUTPUTS = (
        ('FLOAT', 'value'),
    )
    FUNCTION = "process"

    DESCRIPTION = (
        "Compute a contextual scalar statistic from a LINE, record table, DATA_FIELD, or IMAGE. "
        "The available operations adapt to the connected input type."
    )

    def process(self, input, operation: str, column: str = "value") -> tuple:
        """Reduce ``input`` with ``operation`` and emit the result as a scalar value.

        Args:
            input: A ``DataField``, ``LineData``, non-empty record table
                (``list``), or ``numpy`` array. The name shadows the builtin
                but is part of the node's public interface.
            operation: Key into the operation table for the detected source type.
            column: Record-table column to reduce; ignored for other inputs.

        Returns:
            One-element tuple ``(result,)`` matching ``OUTPUTS``.

        Raises:
            ValueError: If the input is unsupported or empty, or ``operation``
                is not available for the detected source type.
        """
        source_type, values, resolved_column = self._resolve_input_values(input, column)

        ops = self._ops_for_source_type(source_type)
        if operation not in ops:
            raise ValueError(f"Operation '{operation}' is not valid for {source_type} input.")

        op_entry = ops[operation]
        # Table entries are either a bare callable or a tuple (callable, ...).
        fn = op_entry[0] if isinstance(op_entry, tuple) else op_entry
        result = fn(values)

        unit = self._resolve_output_unit(input, source_type, resolved_column, operation)
        emit_value(_scalar_payload(result, unit))
        return (result,)

    @staticmethod
    def _ops_for_source_type(source_type: str):
        """Return the operation table for ``source_type``; ARRAY_OPS is the fallback."""
        if source_type == "RECORD_TABLE":
            return TABLE_OPS
        if source_type == "LINE":
            return LINE_OPS
        return ARRAY_OPS

    def _resolve_output_unit(self, input_value, source_type: str, column: str | None, operation: str) -> str:
        """Return the unit string attached to the result of ``operation``.

        - DATA_FIELD: based on the field's ``si_unit_z``.
        - LINE: an explicit unit in the ``LINE_OPS`` entry takes precedence;
          otherwise ``LineData.y_unit`` is used, and raw arrays get ``""``.
        - RECORD_TABLE: based on the common unit of ``column``.
        - Anything else (e.g. IMAGE): ``""``.

        Every base unit is passed through ``_apply_scalar_unit`` together with
        the operation name.
        """
        if source_type == "DATA_FIELD" and isinstance(input_value, DataField):
            return _apply_scalar_unit(input_value.si_unit_z, operation)

        if source_type == "LINE":
            line_entry = LINE_OPS.get(operation)
            # Optional second tuple item is an explicit unit override.
            explicit_unit = line_entry[1] if isinstance(line_entry, tuple) and len(line_entry) > 1 else ""
            if explicit_unit:
                return _apply_scalar_unit(explicit_unit, operation)
            if isinstance(input_value, LineData):
                return _apply_scalar_unit(input_value.y_unit, operation)
            return ""

        if source_type == "RECORD_TABLE" and isinstance(input_value, list) and column:
            return _apply_scalar_unit(_common_table_unit(input_value, column), operation)

        return ""

    def _resolve_input_values(self, input_value, column: str) -> tuple[str, np.ndarray, str | None]:
        """Classify ``input_value`` and flatten it to a float64 1-D array.

        Returns:
            ``(source_type, values, column_name)`` where ``column_name`` is the
            resolved record-table column, or ``None`` for non-table inputs.

        Raises:
            ValueError: For measurement tables, unsupported input types, or
                inputs that contain no values.
        """
        if isinstance(input_value, DataField):
            values = np.asarray(input_value.data, dtype=np.float64)
            # Same guard as the LINE/array branches: an empty field would
            # otherwise reach the reduction and give NaN or a raw numpy error.
            if values.size == 0:
                raise ValueError("Stats requires a non-empty input.")
            return ("DATA_FIELD", values.ravel(), None)

        # Checked before the generic ``list`` branch; presumably MeasureTable is
        # list-like and must not be treated as a record table — TODO confirm.
        if isinstance(input_value, MeasureTable):
            raise ValueError("Stats only accepts record tables, not measurement tables.")

        if isinstance(input_value, list):
            if not input_value:
                raise ValueError("Stats requires a non-empty record table input.")
            column_name = resolve_table_column_name(input_value, column)
            # Convert first and test ``.size`` so the emptiness check works
            # whether the helper returns a list or a numpy array (``not arr``
            # is ambiguous for multi-element arrays).
            values = np.asarray(extract_numeric_table_values(input_value, column_name), dtype=np.float64)
            if values.size == 0:
                raise ValueError(f"Column '{column_name}' has no numeric values.")
            return ("RECORD_TABLE", values, column_name)

        if isinstance(input_value, LineData):
            values = np.asarray(input_value.data, dtype=np.float64)
            if values.size == 0:
                raise ValueError("Stats requires a non-empty input.")
            return ("LINE", values.ravel(), None)

        if isinstance(input_value, np.ndarray):
            values = np.asarray(input_value, dtype=np.float64)
            if values.size == 0:
                raise ValueError("Stats requires a non-empty input.")
            # 1-D arrays are treated like lines; any other rank is an image.
            if values.ndim == 1:
                return ("LINE", values.ravel(), None)
            return ("IMAGE", values.ravel(), None)

        raise ValueError(f"Unsupported Stats input type: {type(input_value).__name__}")
|