Files
tono/backend/nodes/helpers.py
2026-03-27 13:27:39 -07:00

888 lines
27 KiB
Python

"""
Shared helper functions for argonode nodes.
"""
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
from typing import Callable
import numpy as np
from backend.runtime_paths import demo_dir, input_dir, output_dir
# ---------------------------------------------------------------------------
# Scalar payload helpers (from display.py)
# ---------------------------------------------------------------------------
def _scalar_payload(value: float, unit: str = "") -> dict:
payload = {"value": float(value)}
if isinstance(unit, str) and unit.strip():
payload["unit"] = unit.strip()
return payload
# ---------------------------------------------------------------------------
# Measurement helpers (from display.py — used by ValueDisplay)
# ---------------------------------------------------------------------------
def _measurement_names(table: list) -> list[str]:
names = []
for row in table:
if not isinstance(row, dict):
continue
quantity = row.get("quantity")
if isinstance(quantity, str) and quantity and quantity not in names:
names.append(quantity)
return names
def _measurement_entry(table: list, selection: str) -> dict:
    """Return the first row of *table* whose ``quantity`` matches *selection*.

    If *selection* is not one of the table's quantity names, the first named
    quantity is used instead.

    Raises:
        ValueError: if the table has no rows with a usable ``quantity`` name, or
            if no row matches the chosen name.
    """
    candidates = _measurement_names(table)
    if len(candidates) == 0:
        raise ValueError("Measurement table has no selectable rows.")
    wanted = candidates[0] if selection not in candidates else selection
    match = next(
        (row for row in table if isinstance(row, dict) and row.get("quantity") == wanted),
        None,
    )
    if match is None:
        raise ValueError(f"Measurement '{wanted}' was not found.")
    return match
def _measurement_value(table: list, selection: str) -> float:
    """Return the finite numeric ``value`` of the selected measurement row.

    The row is located with ``_measurement_entry``.

    Raises:
        ValueError: if the row's value is a boolean (Python or NumPy), cannot be
            converted to float (including ints too large for a float), or is
            NaN/infinite.
    """
    row = _measurement_entry(table, selection)
    message = f"Measurement '{row.get('quantity', selection)}' does not have a numeric value."
    value = row.get("value")
    # Booleans convert to 0.0/1.0 but are not measurements; np.bool_ is not a
    # subclass of bool, so check it explicitly.
    if isinstance(value, (bool, np.bool_)):
        raise ValueError(message)
    try:
        numeric = float(value)
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError: an int too large for a float must surface as the same
        # ValueError as every other non-numeric value.
        raise ValueError(message) from exc
    if not np.isfinite(numeric):
        raise ValueError(message)
    return numeric
# ---------------------------------------------------------------------------
# SI formatting helpers (from display.py — used by Annotations)
# ---------------------------------------------------------------------------
# SI prefix table as (scale, symbol) pairs, ordered from the largest scale to
# the smallest; _format_with_unit walks it in this order and takes the first
# scale that leaves a mantissa in [1, 1000). "u" is an ASCII stand-in for micro.
_SI_PREFIXES = [
    (1e24, "Y"), (1e21, "Z"), (1e18, "E"), (1e15, "P"), (1e12, "T"),
    (1e9, "G"), (1e6, "M"), (1e3, "k"), (1.0, ""), (1e-3, "m"),
    (1e-6, "u"), (1e-9, "n"), (1e-12, "p"), (1e-15, "f"),
    (1e-18, "a"), (1e-21, "z"), (1e-24, "y"),
]
# Base units that _format_with_unit may decorate with an SI prefix. Ohm is
# accepted in several spellings, including the Greek capital omega (U+03A9).
_PREFIXABLE_UNITS = {"m", "s", "A", "V", "W", "Hz", "F", "C", "J", "N", "Pa", "T", "H", "S", "g", "K", "Ohm", "ohm", "\u03a9"}
def _format_numeric(value: float) -> str:
if not np.isfinite(value):
return str(value)
abs_value = abs(value)
if abs_value == 0:
return "0"
if abs_value >= 1e4 or abs_value < 1e-3:
return f"{value:.3e}"
return f"{value:.4g}"
def _format_with_unit(value: float, unit: str) -> str:
    """Format *value* followed by *unit*, choosing an SI prefix when possible.

    Units in ``_PREFIXABLE_UNITS`` get the largest prefix from ``_SI_PREFIXES``
    that leaves a mantissa in ``[1, 1000)``. Other units, zero, non-finite
    values and magnitudes outside the prefix table are printed unprefixed. A
    blank or missing unit returns just the formatted number.
    """
    unit = (unit or "").strip()
    if not unit:
        return _format_numeric(value)
    if unit in _PREFIXABLE_UNITS and np.isfinite(value) and value != 0:
        # _format_numeric shows 4 significant digits for mantissas in [1, 1000).
        # Round to that precision *before* choosing the prefix, so 999.96 V
        # becomes "1 kV" rather than "1000 V".
        rounded = float(f"{value:.4g}")
        magnitude = abs(rounded)
        for scale, prefix in _SI_PREFIXES:
            if 1 <= magnitude / scale < 1000:
                return f"{_format_numeric(rounded / scale)} {prefix}{unit}"
    return f"{_format_numeric(value)} {unit}"
def _nice_length(target: float) -> float:
if not np.isfinite(target) or target <= 0:
return 0.0
exponent = np.floor(np.log10(target))
base = 10.0 ** exponent
for step in (5.0, 2.0, 1.0):
candidate = step * base
if candidate <= target:
return candidate
return base
def _display_value_range(field) -> tuple[float, float, float]:
data = np.asarray(field.data, dtype=np.float64)
dmin = float(data.min())
dmax = float(data.max())
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmax <= dmin:
return dmin, dmin, dmin
offset = float(field.display_offset)
scale = float(field.display_scale)
if not np.isfinite(offset):
offset = 0.0
if not np.isfinite(scale) or scale <= 0.0:
scale = 1.0
low_norm = float(np.clip(offset, 0.0, 1.0))
high_norm = float(np.clip(offset + scale, 0.0, 1.0))
if high_norm < low_norm:
low_norm, high_norm = high_norm, low_norm
mid_norm = 0.5 * (low_norm + high_norm)
span = dmax - dmin
return (
dmin + low_norm * span,
dmin + mid_norm * span,
dmin + high_norm * span,
)
def _render_annotation_text(text: str, size_px: int, color: tuple[int, int, int]):
    """Rasterize *text* into a tightly cropped RGBA Pillow image tinted *color*.

    The DejaVuSans TrueType font is tried first at ``size_px`` (rounded, at
    least 8). If anything in that path raises, Pillow's default font is used
    instead. It is rendered into an 8-bit mask, upscaled bilinearly so its
    height reaches roughly ``size_px`` (never downscaled), and applied as the
    alpha channel of a solid-*color* image.
    """
    from PIL import Image, ImageDraw, ImageFont

    size_px = max(8, int(round(size_px)))

    def _measure(mode, blank, font):
        # Probe-render on a 1x1 canvas to obtain the text's bounding box.
        probe = Image.new(mode, (1, 1), blank)
        box = ImageDraw.Draw(probe).textbbox((0, 0), text, font=font)
        return box, max(1, box[2] - box[0]), max(1, box[3] - box[1])

    try:
        vector_font = ImageFont.truetype("DejaVuSans.ttf", size_px)
        box, w, h = _measure("RGBA", (0, 0, 0, 0), vector_font)
        rendered = Image.new("RGBA", (w, h), (0, 0, 0, 0))
        ImageDraw.Draw(rendered).text((-box[0], -box[1]), text, font=vector_font, fill=(*color, 255))
        return rendered
    except Exception:
        # Any failure above (e.g. the TrueType font cannot be loaded) uses the
        # bitmap default font instead.
        bitmap_font = ImageFont.load_default()
        box, w, h = _measure("L", 0, bitmap_font)
        glyph_mask = Image.new("L", (w, h), 0)
        ImageDraw.Draw(glyph_mask).text((-box[0], -box[1]), text, font=bitmap_font, fill=255)
        factor = max(1.0, size_px / max(1, h))
        target = (max(1, int(round(w * factor))), max(1, int(round(h * factor))))
        # Pillow >= 9.1 exposes filters under Image.Resampling; older versions on Image.
        resampling = getattr(Image, "Resampling", Image)
        glyph_mask = glyph_mask.resize(target, resample=resampling.BILINEAR)
        rendered = Image.new("RGBA", target, (*color, 0))
        rendered.putalpha(glyph_mask)
        return rendered
def _import_ibw_loader():
"""Import igor's binary wave loader with NumPy 2 compatibility."""
if not hasattr(np, "complex"):
# igor 0.3 still references np.complex at import time.
setattr(np, "complex", complex)
try:
from igor.binarywave import load as load_ibw
except ImportError:
raise ImportError("Install 'igor' package to load .ibw files: pip install igor")
return load_ibw
# ---------------------------------------------------------------------------
# Markup helpers (from display.py — used by Markup)
# ---------------------------------------------------------------------------
def _normalize_markup_color(color: object, default: str = "#ffd54f") -> str:
if isinstance(color, str):
text = color.strip()
if len(text) == 4 and text.startswith("#"):
text = "#" + "".join(ch * 2 for ch in text[1:])
if len(text) == 7 and text.startswith("#"):
try:
int(text[1:], 16)
return text.lower()
except ValueError:
pass
return default
def _parse_markup_shapes(raw_shapes) -> list[dict]:
if isinstance(raw_shapes, str):
try:
raw_shapes = json.loads(raw_shapes or "[]")
except json.JSONDecodeError:
raw_shapes = []
if not isinstance(raw_shapes, list):
return []
parsed = []
for shape in raw_shapes:
if not isinstance(shape, dict):
continue
kind = str(shape.get("kind", "")).strip().lower()
if kind not in {"line", "rectangle", "circle", "arrow"}:
continue
try:
x1 = float(shape.get("x1"))
y1 = float(shape.get("y1"))
x2 = float(shape.get("x2"))
y2 = float(shape.get("y2"))
width = int(round(float(shape.get("width", 3))))
except (TypeError, ValueError):
continue
coords = [x1, y1, x2, y2]
if not all(np.isfinite(value) for value in coords):
continue
parsed.append({
"kind": kind,
"x1": float(np.clip(x1, 0.0, 1.0)),
"y1": float(np.clip(y1, 0.0, 1.0)),
"x2": float(np.clip(x2, 0.0, 1.0)),
"y2": float(np.clip(y2, 0.0, 1.0)),
"width": max(1, min(128, width)),
"color": _normalize_markup_color(shape.get("color")),
})
return parsed
def _draw_arrow(draw, start, end, color, width):
dx = end[0] - start[0]
dy = end[1] - start[1]
length = float(np.hypot(dx, dy))
if length <= 1e-6:
radius = max(1.0, width / 2.0)
draw.ellipse(
(start[0] - radius, start[1] - radius, start[0] + radius, start[1] + radius),
fill=color,
)
return
ux = dx / length
uy = dy / length
head_length = max(10.0, width * 4.0)
head_width = max(8.0, width * 3.0)
shaft_end = (
end[0] - ux * head_length,
end[1] - uy * head_length,
)
draw.line((start, shaft_end), fill=color, width=width)
px = -uy
py = ux
left = (
shaft_end[0] + px * head_width / 2.0,
shaft_end[1] + py * head_width / 2.0,
)
right = (
shaft_end[0] - px * head_width / 2.0,
shaft_end[1] - py * head_width / 2.0,
)
draw.polygon([end, left, right], fill=color)
def _render_markup_image(image, shapes):
    """Burn markup *shapes* into a copy of *image* and return it as uint8.

    Each shape is a dict with ``kind``, fractional ``x1``/``y1``/``x2``/``y2``,
    ``width`` and ``color``. This is the layout ``_parse_markup_shapes`` builds.
    Coordinates are scaled to pixels here. A 2-D (grey) image is first expanded
    to three channels so colored markup is visible. The input is not modified.
    """
    from PIL import Image as PILImage, ImageDraw
    from backend.data_types import image_to_uint8

    base = image_to_uint8(image)
    if base.ndim == 2:
        base = np.repeat(base[:, :, np.newaxis], 3, axis=2)
    canvas = PILImage.fromarray(base.copy())
    draw = ImageDraw.Draw(canvas)
    height, width = base.shape[:2]
    for shape in shapes:
        x1 = float(shape["x1"]) * width
        y1 = float(shape["y1"]) * height
        x2 = float(shape["x2"]) * width
        y2 = float(shape["y2"]) * height
        color = str(shape["color"])
        stroke_width = int(shape["width"])
        kind = str(shape["kind"])
        if kind == "line":
            draw.line(((x1, y1), (x2, y2)), fill=color, width=stroke_width)
        elif kind in ("rectangle", "circle"):
            # Pillow requires x0 <= x1 and y0 <= y1 for bounding boxes and raises
            # ValueError otherwise. Shapes dragged right-to-left or bottom-to-top
            # arrive with swapped corners, so order them first.
            box = (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
            if kind == "rectangle":
                draw.rectangle(box, outline=color, width=stroke_width)
            else:
                draw.ellipse(box, outline=color, width=stroke_width)
        elif kind == "arrow":
            _draw_arrow(draw, (x1, y1), (x2, y2), color, stroke_width)
    return np.asarray(canvas, dtype=np.uint8)
# ---------------------------------------------------------------------------
# Mask helpers (from mask.py — used by multiple mask nodes)
# ---------------------------------------------------------------------------
def _mask_overlay(field, mask):
    """Return a grey rendering of *field* with masked pixels tinted red.

    Pixels where ``mask > 127`` are blended 55 % original / 45 % pure red using
    integer arithmetic rounded half-up. Only channels 0..2 of the rendered
    image are touched; the code assumes it has at least three channels on its
    last axis (NOTE(review): confirm datafield_to_uint8 always returns RGB).
    If no pixel is selected, the rendered image is returned as-is, uncopied.
    """
    from backend.data_types import datafield_to_uint8

    grey = datafield_to_uint8(field, "gray")
    selected = mask > 127
    if not np.any(selected):
        return grey
    tinted = grey.copy()
    # Additive term per channel: only red receives the 45 % share of 255.
    for channel, bias in enumerate((255 * 45, 0, 0)):
        plane = tinted[..., channel]
        # Widen to uint16 so value * 55 + bias cannot overflow uint8.
        original = plane[selected].astype(np.uint16)
        plane[selected] = (original * 55 + bias + 50) // 100
    return tinted
@lru_cache(maxsize=128)
def _mask_structure(radius: int, shape: str):
radius = max(1, int(radius))
if shape == "disk":
y, x = np.ogrid[-radius:radius + 1, -radius:radius + 1]
struct = (x * x + y * y) <= radius * radius
else:
size = 2 * radius + 1
struct = np.ones((size, size), dtype=bool)
struct.setflags(write=False)
return struct
def _clamp_fraction(value) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, numeric))
def _parse_mask_strokes(mask_paths) -> list[dict]:
if isinstance(mask_paths, list):
raw_strokes = mask_paths
elif isinstance(mask_paths, str) and mask_paths.strip():
try:
parsed = json.loads(mask_paths)
except json.JSONDecodeError:
return []
raw_strokes = parsed if isinstance(parsed, list) else []
else:
return []
strokes = []
for stroke in raw_strokes:
if not isinstance(stroke, dict):
continue
raw_points = stroke.get("points")
if not isinstance(raw_points, list):
continue
points = []
for point in raw_points:
if not isinstance(point, dict):
continue
if "x" not in point or "y" not in point:
continue
points.append({
"x": _clamp_fraction(point.get("x")),
"y": _clamp_fraction(point.get("y")),
})
if not points:
continue
try:
size = max(1, int(round(float(stroke.get("size", 1)))))
except (TypeError, ValueError):
size = 1
strokes.append({
"size": size,
"points": points,
})
return strokes
def _rasterize_mask(width, height, strokes, default_pen_size):
    """Rasterize mask strokes into a ``(height, width)`` uint8 array (0 or 255).

    Each stroke is a dict with ``points`` and an optional pen ``size`` in
    pixels. Each point is a dict with fractional ``x``/``y``, mapped onto
    pixel indices ``0 .. width-1`` / ``0 .. height-1``. Missing or unparseable
    sizes (including infinities) fall back to *default_pen_size*.

    Strokes are drawn as a polyline plus a filled disc at every vertex, so
    joints look round and single-point taps become dots.
    """
    from PIL import Image as PILImage, ImageDraw

    width = max(1, int(width))
    height = max(1, int(height))
    default_pen_size = max(1, int(default_pen_size))
    mask_image = PILImage.new("L", (width, height), 0)
    draw = ImageDraw.Draw(mask_image)
    for stroke in strokes:
        points = stroke.get("points") or []
        if not points:
            continue
        try:
            size = max(1, int(round(float(stroke.get("size", default_pen_size)))))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: round(inf) or ints too large for a float.
            size = default_pen_size
        pixel_points = [
            (
                int(round(_clamp_fraction(point.get("x")) * (width - 1))),
                int(round(_clamp_fraction(point.get("y")) * (height - 1))),
            )
            for point in points
        ]
        radius = max(0.5, size / 2.0)
        if len(pixel_points) > 1:
            draw.line(pixel_points, fill=255, width=size)
        for x, y in pixel_points:
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=255)
    return np.asarray(mask_image, dtype=np.uint8)
# ---------------------------------------------------------------------------
# Path / directory helpers (from io.py)
# ---------------------------------------------------------------------------
# Runtime directories, resolved once at import time via backend.runtime_paths.
DEMO_DIR = demo_dir()
INPUT_DIR = input_dir()
OUTPUT_DIR = output_dir()
# NOTE(review): not referenced in this part of the file; presumably caps how
# many fields a save node writes — confirm against its users.
_MAX_SAVE_FIELDS = 8
# Extensions that _list_demo_files reports from DEMO_DIR.
# NOTE(review): ".bmp" is path-compatible below but not listed here — intentional?
_DEMO_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".npy", ".npz",
                    ".gwy", ".sxm", ".ibw"}
# Scanning-probe formats; list_channels reads these with gwyfile, nanonispy and igor.
_SPM_EXTENSIONS = {".gwy", ".sxm", ".ibw"}
# Raster image formats.
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tiff", ".tif", ".bmp"}
# NumPy array containers.
_ARRAY_EXTENSIONS = {".npy", ".npz"}
# Every extension that list_folder_paths reports as a FILE_PATH entry.
_PATH_COMPATIBLE_EXTENSIONS = _IMAGE_EXTENSIONS | _ARRAY_EXTENSIONS | _SPM_EXTENSIONS
def _resolve_path(filepath: str):
path = Path(filepath)
if path.is_absolute():
return path
candidate = INPUT_DIR / filepath
if candidate.exists():
return candidate
candidate = DEMO_DIR / filepath
if candidate.exists():
return candidate
return INPUT_DIR / filepath
def _channels_from_gwy(path):
    """Return the data-field names of a Gwyddion ``.gwy`` file (possibly empty)."""
    import gwyfile

    container = gwyfile.load(str(path))
    fields = gwyfile.util.get_datafields(container)
    return list(fields) if fields else []


def _channels_from_sxm(path):
    """Return the signal names of a Nanonis ``.sxm`` scan (possibly empty)."""
    import nanonispy as nap

    scan = nap.read.Scan(str(path))
    return list(scan.signals) if scan.signals else []


def _channels_from_ibw(path):
    """Return channel names for an Igor ``.ibw`` wave (possibly empty).

    For waves with 3+ dimensions, the non-empty labels of dimension 2 are
    decoded as NUL-terminated ASCII. If fewer label sets exist, the last one
    is used instead. If no usable label results and axis 2 has more than one
    layer, the channels are named ``ch0``, ``ch1``, ...
    """
    load_ibw = _import_ibw_loader()
    payload = load_ibw(str(path))["wave"]
    raw = payload["wData"]
    labels = payload.get("labels", None)
    if raw.ndim >= 3 and labels:
        dim_idx = min(2, len(labels) - 1)
        if dim_idx >= 0 and labels[dim_idx]:
            names = []
            for entry in labels[dim_idx]:
                if not entry:
                    continue
                name = entry.split(b"\x00")[0].decode("ascii", errors="replace").strip()
                if name:
                    names.append(name)
            if names:
                return names
    if raw.ndim >= 3 and raw.shape[2] > 1:
        return [f"ch{i}" for i in range(raw.shape[2])]
    return []


def list_channels(filepath: str) -> list[dict]:
    """List the channels in *filepath* as ``{"name", "type": "DATA_FIELD"}`` dicts.

    ``.gwy``, ``.sxm`` and ``.ibw`` files are inspected with their readers.
    The following all yield a single generic ``"field"`` channel:

    - missing files;
    - other extensions;
    - any reader failure;
    - files with no discoverable channel names.
    """
    path = _resolve_path(filepath)
    if path.exists():
        reader = {
            ".gwy": _channels_from_gwy,
            ".sxm": _channels_from_sxm,
            ".ibw": _channels_from_ibw,
        }.get(path.suffix.lower())
        if reader is not None:
            try:
                names = reader(path)
            except Exception:
                # Deliberately best-effort: an optional reader package may be
                # missing or the file unreadable; fall back to the generic channel.
                names = []
            if names:
                return [{"name": name, "type": "DATA_FIELD"} for name in names]
    return [{"name": "field", "type": "DATA_FIELD"}]
def list_folder_paths(folderpath: str) -> list[dict]:
    """Describe a folder and its loadable files.

    Returns ``[]`` if the folder does not exist or is not a directory.
    Otherwise the first entry describes the folder itself
    (``type="DIRECTORY"``). It is followed by one ``FILE_PATH`` entry per
    visible regular file whose extension is in ``_PATH_COMPATIBLE_EXTENSIONS``,
    sorted case-insensitively by name. All paths are resolved to absolute
    strings.
    """
    folder = _resolve_path(folderpath)
    if not (folder.exists() and folder.is_dir()):
        return []
    listing = [{"name": "directory", "type": "DIRECTORY", "path": str(folder.resolve())}]
    ordered = sorted(folder.iterdir(), key=lambda item: item.name.lower())
    listing.extend(
        {"name": item.name, "type": "FILE_PATH", "path": str(item.resolve())}
        for item in ordered
        if item.is_file()
        and not item.name.startswith(".")
        and item.suffix.lower() in _PATH_COMPATIBLE_EXTENSIONS
    )
    return listing
def _list_demo_files() -> list[str]:
    """Return the sorted names of visible demo files with a supported extension.

    Returns ``[]`` when DEMO_DIR does not exist.
    """
    if DEMO_DIR.exists():
        names = [
            entry.name
            for entry in DEMO_DIR.iterdir()
            if entry.is_file()
            and not entry.name.startswith(".")
            and entry.suffix.lower() in _DEMO_EXTENSIONS
        ]
        names.sort()
        return names
    return []
# ---------------------------------------------------------------------------
# Butterworth / FFT helpers (from filters.py — used by FFTFilter1D, FFTFilter2D)
# ---------------------------------------------------------------------------
def _butterworth_lp(freq, cutoff, order):
with np.errstate(divide="ignore", over="ignore"):
return 1.0 / (1.0 + (freq / cutoff) ** (2 * order))
def _butterworth_hp(freq, cutoff, order):
with np.errstate(divide="ignore", invalid="ignore"):
h = 1.0 / (1.0 + (cutoff / freq) ** (2 * order))
h = np.where(np.isfinite(h), h, 0.0)
return h
def _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order):
freq = np.linspace(0, 1, n // 2 + 1)
if filter_type == "lowpass":
H = _butterworth_lp(freq, cutoff, order)
elif filter_type == "highpass":
H = _butterworth_hp(freq, cutoff, order)
elif filter_type == "bandpass":
H = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
elif filter_type == "notch":
bp = _butterworth_hp(freq, cutoff, order) * _butterworth_lp(freq, cutoff_high, order)
H = 1.0 - bp
else:
H = np.ones_like(freq)
return H
@lru_cache(maxsize=64)
def _cached_1d_transfer(n, filter_type, cutoff, cutoff_high, order):
    """Memoized ``_build_1d_transfer``; the returned array is read-only.

    The same array object is handed to every caller with matching arguments,
    so in-place edits are forbidden.
    """
    response = _build_1d_transfer(n, filter_type, cutoff, cutoff_high, order)
    response.flags.writeable = False
    return response
@lru_cache(maxsize=32)
def _fft_radius_grid(yres, xres):
fy = np.fft.fftfreq(yres)[:, np.newaxis] * 2.0
fx = np.fft.rfftfreq(xres)[np.newaxis, :] * 2.0
radius = np.sqrt(fx * fx + fy * fy) / np.sqrt(2.0)
np.clip(radius, 0.0, 1.0, out=radius)
radius.setflags(write=False)
return radius
@lru_cache(maxsize=128)
def _cached_2d_transfer(yres, xres, filter_type, cutoff, cutoff_high, order):
    """Memoized, read-only 2-D Butterworth transfer over ``_fft_radius_grid``.

    The filter types match the 1-D builder: lowpass, highpass, bandpass
    (high-pass at *cutoff* times low-pass at *cutoff_high*), and notch
    (one minus that band). Any other type gives all ones.
    """
    r = _fft_radius_grid(yres, xres)
    builders = {
        "lowpass": lambda: _butterworth_lp(r, cutoff, order),
        "highpass": lambda: _butterworth_hp(r, cutoff, order),
        "bandpass": lambda: _butterworth_hp(r, cutoff, order) * _butterworth_lp(r, cutoff_high, order),
        "notch": lambda: 1.0 - _butterworth_hp(r, cutoff, order) * _butterworth_lp(r, cutoff_high, order),
    }
    build = builders.get(filter_type)
    response = build() if build is not None else np.ones_like(r)
    # Shared between callers via the cache, so forbid in-place edits.
    response.setflags(write=False)
    return response
# ---------------------------------------------------------------------------
# Cross-section and stats helpers (from analysis.py)
# ---------------------------------------------------------------------------
def _extend_to_edges(x1, y1, x2, y2):
dx = x2 - x1
dy = y2 - y1
t_candidates = []
if abs(dx) > 1e-12:
for bx in (0.0, 1.0):
t = (bx - x1) / dx
y_at_t = y1 + t * dy
if -1e-9 <= y_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if abs(dy) > 1e-12:
for by in (0.0, 1.0):
t = (by - y1) / dy
x_at_t = x1 + t * dx
if -1e-9 <= x_at_t <= 1.0 + 1e-9:
t_candidates.append(t)
if len(t_candidates) < 2:
return x1, y1, x2, y2
t_min = min(t_candidates)
t_max = max(t_candidates)
return (
np.clip(x1 + t_min * dx, 0, 1),
np.clip(y1 + t_min * dy, 0, 1),
np.clip(x1 + t_max * dx, 0, 1),
np.clip(y1 + t_max * dy, 0, 1),
)
def _safe_rq(d):
return float(np.sqrt(np.mean(d * d)))
LINE_OPS: dict[str, tuple] = {}
def _line_op(name, unit=""):
def decorator(fn):
LINE_OPS[name] = (fn, unit)
return fn
return decorator
@_line_op("min")
def _op_min(z):
    """Lowest sample of the profile."""
    return float(np.min(z))


@_line_op("max")
def _op_max(z):
    """Highest sample of the profile."""
    return float(np.max(z))


@_line_op("mean")
def _op_mean(z):
    """Arithmetic mean of the profile."""
    return float(np.mean(z))


@_line_op("median")
def _op_median(z):
    """Median of the profile."""
    middle = np.median(z)
    return float(middle)


@_line_op("sum")
def _op_sum(z):
    """Sum of all samples."""
    return float(np.sum(z))


@_line_op("range")
def _op_range(z):
    """Peak-to-peak spread (max minus min)."""
    top, bottom = z.max(), z.min()
    return float(top - bottom)


@_line_op("length", unit="pts")
def _op_length(z):
    """Number of samples, reported in points."""
    count = len(z)
    return float(count)


@_line_op("rms")
def _op_rms(z):
    """Root mean square of the raw (not mean-centred) samples."""
    squares = z * z
    return float(np.sqrt(squares.mean()))


@_line_op("Ra")
def _op_ra(z):
    """Mean absolute deviation from the profile mean."""
    deviation = np.abs(z - z.mean())
    return float(deviation.mean())


@_line_op("Rq")
def _op_rq(z):
    """RMS deviation from the profile mean."""
    return _safe_rq(z - z.mean())


@_line_op("Rsk")
def _op_rsk(z):
    """Skewness: mean cubed deviation / Rq**3; 0.0 unless Rq > 0."""
    centered = z - z.mean()
    rq = _safe_rq(centered)
    if not rq > 0:
        return 0.0
    return float(np.mean(centered ** 3) / rq ** 3)


@_line_op("Rku")
def _op_rku(z):
    """Kurtosis: mean fourth-power deviation / Rq**4; 0.0 unless Rq > 0."""
    centered = z - z.mean()
    rq = _safe_rq(centered)
    if not rq > 0:
        return 0.0
    return float(np.mean(centered ** 4) / rq ** 4)


@_line_op("Rp")
def _op_rp(z):
    """Highest point above the profile mean."""
    centered = z - z.mean()
    return float(centered.max())


@_line_op("Rv")
def _op_rv(z):
    """Depth of the lowest point below the profile mean (negated minimum)."""
    centered = z - z.mean()
    return float(-centered.min())


@_line_op("Rt")
def _op_rt(z):
    """Total height: highest minus lowest mean-centred value."""
    centered = z - z.mean()
    return float(centered.max() - centered.min())


@_line_op("Dq")
def _op_dq(z):
    """RMS of successive sample differences."""
    steps = np.diff(z)
    return float(np.sqrt((steps * steps).mean()))


@_line_op("Da")
def _op_da(z):
    """Mean absolute successive sample difference."""
    steps = np.abs(np.diff(z))
    return float(steps.mean())
# Reducers for a numeric table column, keyed by operation name. Each takes the
# column values and returns a Python float. "avg" and "mean" are aliases, and
# key order is preserved.
TABLE_OPS: dict[str, Callable[[np.ndarray], float]] = dict(
    min=lambda values: float(np.min(values)),
    max=lambda values: float(np.max(values)),
    avg=lambda values: float(np.mean(values)),
    mean=lambda values: float(np.mean(values)),
    median=lambda values: float(np.median(values)),
    sum=lambda values: float(np.sum(values)),
    range=lambda values: float(np.max(values) - np.min(values)),
    std=lambda values: float(np.std(values)),
    variance=lambda values: float(np.var(values)),
    count=lambda values: float(len(values)),
)
# Reducers over every element of a NumPy array, keyed by operation name.
# Differences from TABLE_OPS: adds "rms", and "count" uses values.size
# (all elements) rather than len(). Key order is preserved.
ARRAY_OPS: dict[str, Callable[[np.ndarray], float]] = dict(
    min=lambda values: float(np.min(values)),
    max=lambda values: float(np.max(values)),
    avg=lambda values: float(np.mean(values)),
    mean=lambda values: float(np.mean(values)),
    median=lambda values: float(np.median(values)),
    sum=lambda values: float(np.sum(values)),
    range=lambda values: float(np.max(values) - np.min(values)),
    std=lambda values: float(np.std(values)),
    variance=lambda values: float(np.var(values)),
    rms=lambda values: float(np.sqrt(np.mean(values * values))),
    count=lambda values: float(values.size),
)
def _square_unit(unit: str) -> str:
unit = str(unit or "").strip()
if not unit:
return ""
if any(token in unit for token in ("^", "(", ")", "/", "*", " ")):
return f"({unit})^2"
return f"{unit}^2"
def _apply_scalar_unit(base_unit: str, operation: str) -> str:
unit = str(base_unit or "").strip()
if operation == "count":
return "count"
if not unit:
return ""
if operation == "variance":
return _square_unit(unit)
return unit
def _common_table_unit(table: list, column: str) -> str:
candidates = []
seen = set()
unit_key = f"{column}_unit"
for row in table:
if not isinstance(row, dict):
continue
unit = None
if unit_key in row and isinstance(row.get(unit_key), str):
unit = row.get(unit_key)
elif column == "value" and isinstance(row.get("unit"), str):
unit = row.get("unit")
if unit is None:
continue
unit = unit.strip()
if not unit or unit in seen:
continue
seen.add(unit)
candidates.append(unit)
if len(candidates) == 1:
return candidates[0]
return ""
def extract_numeric_table_values(table: list, column: str) -> list[float]:
    """Collect the finite numeric values of *column* from the dict rows of *table*.

    Each value is converted with ``float()``, so numeric strings count. The
    following are skipped:

    - rows that are not dicts or lack the column;
    - booleans (Python or NumPy);
    - values ``float()`` rejects, including ints too large for a float;
    - NaN and infinities.
    """
    values = []
    for row in table:
        if not isinstance(row, dict) or column not in row:
            continue
        value = row[column]
        # np.bool_ is not a bool subclass, so it needs its own check.
        if isinstance(value, (bool, np.bool_)):
            continue
        try:
            numeric = float(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: ints beyond float range, e.g. huge JSON integers.
            continue
        if np.isfinite(numeric):
            values.append(numeric)
    return values
def resolve_table_column_name(table: list, column: str) -> str:
    """Pick the table column a statistic should operate on.

    Resolution order:

    1. A non-blank *column* is returned (stripped).
    2. Otherwise ``"value"`` is used if it holds any numeric data.
    3. Otherwise the single column holding numeric data is chosen.

    Raises:
        ValueError: if no column, or more than one, holds numeric data.
    """
    explicit = str(column or "").strip()
    if explicit:
        return explicit
    if extract_numeric_table_values(table, "value"):
        return "value"
    checked = set()
    numeric = []
    for row in table:
        if not isinstance(row, dict):
            continue
        for key in row:
            if key in checked:
                continue
            checked.add(key)
            if extract_numeric_table_values(table, key):
                numeric.append(key)
    if not numeric:
        raise ValueError("Stats could not find any numeric columns in the input table.")
    if len(numeric) > 1:
        raise ValueError(
            "Stats found multiple numeric columns; set the column name explicitly."
        )
    return numeric[0]