from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Any, Iterator, Optional, Tuple, Union
from ..aggs_registry import resolve_aggregator
from .utils import get_field
[docs]
def get_time_window_details(
    window: Union[int, float, str], time_field: Optional[str]
) -> Tuple[bool, Optional[int], Optional[float], Optional[datetime], bool]:
    """
    Determine the mode (count or time) and parse the window size.

    Parameters
    ----------
    window : int, float, or str
        Window size as integer/float (row count or numeric width) or string
        ending in s/m/h/d (e.g. '5m', '1h').
    time_field : str or None
        Name of the time field, required for string (time-based) windows.

    Returns
    -------
    tuple
        (count_mode, count_window, time_window_seconds, origin, is_datetime)

        - Numeric window: ``(True, int(window), float(window), None, False)``.
          The third slot carries the numeric size as a float so callers that
          need a bucket width can use it directly.
        - String window: ``(False, None, <seconds>, None, False)``.

        ``origin`` and ``is_datetime`` are always ``None`` / ``False`` here.

    Raises
    ------
    ValueError
        If `window` is neither a number nor a string ending in s/m/h/d
        (including the empty string), if the string cannot be parsed, or if a
        string window is given without a `time_field`.
    """
    if isinstance(window, (int, float)):
        return True, int(window), float(window), None, False

    # Slice rather than index: an empty string yields "" here and falls
    # through to the ValueError below instead of raising IndexError.
    if isinstance(window, str) and window[-1:].lower() in {"s", "m", "h", "d"}:
        time_window_seconds = parse_window_size(window)
        if time_field is None:
            raise ValueError("String window requires a time_field")
        return False, None, time_window_seconds, None, False

    raise ValueError(
        f"Invalid window {window!r}: use int for row-count or str ending in s/m/h/d for time."
    )
[docs]
def apply_group_rolling_summary(
    records: Iterator[dict[str, Any]], step: dict[str, Any]
) -> Iterator[dict[str, Any]]:
    """
    Lazily apply a rolling summary within each group.

    Two modes:

    - Count mode: `window` is int/float -> last N rows
    - Time mode: `window` is str ending in s/m/h/d -> last T seconds

    In time mode, `time_field` must hold datetime or timedelta values; each
    group is sorted by that field before the window is rolled. In count mode
    the original row order is kept.

    Parameters
    ----------
    records : Iterator[dict]
        Iterator of group dicts, each with '__group_key__' and (optionally)
        '__group_records__'.
    step : dict
        Step configuration dict, must include 'window' and 'aggregators'
        (mapping of output field -> (fn, input_field)), and optionally
        'time_field', 'min_periods', 'step', '__group_keys'.

    Returns
    -------
    Iterator[dict]
        Iterator of copies of the input rows with the group keys and the
        rolling summary fields attached. A row is emitted only when its window
        holds at least `min_periods` rows and its index within the group is a
        multiple of `step`.

    Raises
    ------
    ValueError
        If `window` is invalid, or (during iteration) if the first row of a
        group in time mode has a `time_field` value that is neither datetime
        nor timedelta.
    """
    window = step["window"]
    aggregators = step["aggregators"]
    time_field = step.get("time_field")
    min_periods = step.get("min_periods", 1)
    raw_step = step.get("step")
    # Anything that is not a positive int falls back to emitting every row.
    step_size = raw_step if (isinstance(raw_step, int) and raw_step > 0) else 1
    group_keys = step.get("__group_keys") or []
    count_mode, count_window, time_window_seconds, _, _ = get_time_window_details(
        window, time_field
    )

    def to_seconds(value: Any, origin: Optional[datetime]) -> float:
        """Convert a time value to float seconds (relative to `origin` for datetimes)."""
        if isinstance(value, datetime) and origin is not None:
            return (value - origin).total_seconds()
        if isinstance(value, timedelta):
            return value.total_seconds()
        return float(value) if value is not None else 0.0

    def process_one_group(
        group_key_tuple: tuple[Any, ...],
        group_records: list[dict[str, Any]],
        agg_fns: list[tuple[str, Any]],
    ) -> list[dict[str, Any]]:
        """Roll the window over one group's rows and return the emitted rows."""
        # An empty group has nothing to summarise; returning early also avoids
        # indexing group_records[0] below.
        if not group_records:
            return []

        local_origin: Optional[datetime] = None
        if not count_mode and time_field is not None:
            # Validate the time_field type on the first row only.
            sample = get_field(group_records[0], time_field)
            if isinstance(sample, datetime):
                local_origin = sample
            elif not isinstance(sample, timedelta):
                raise ValueError(
                    f"Rolling-summary: time_field '{time_field}' is {type(sample).__name__}; "
                    "for a string window you must supply datetime or timedelta values."
                )
            # For timedelta values no datetime origin is needed.
            group_records = sorted(
                group_records, key=lambda r: get_field(r, time_field)
            )

        window_deque: deque[dict[str, Any]] = deque()
        results: list[dict[str, Any]] = []
        for idx, row in enumerate(group_records):
            window_deque.append(row)

            # Drop rows that have fallen out of the window.
            if count_mode and count_window is not None:
                while len(window_deque) > count_window:
                    window_deque.popleft()
            elif time_field is not None and time_window_seconds is not None:
                now_s = to_seconds(get_field(row, time_field), local_origin)
                while window_deque:
                    old_s = to_seconds(
                        get_field(window_deque[0], time_field), local_origin
                    )
                    # Strictly older than the window -> evict; a row exactly
                    # `time_window_seconds` old is kept.
                    if now_s - old_s > time_window_seconds:
                        window_deque.popleft()
                    else:
                        break

            # Emit only if the window is full enough and the row is on step.
            if len(window_deque) >= min_periods and idx % step_size == 0:
                # Snapshot so aggregators see a plain list of the window rows.
                current_window = list(window_deque)
                out = dict(row)
                # Reattach group keys.
                for key_name, key_val in zip(group_keys, group_key_tuple):
                    out[key_name] = key_val
                for out_field, agg_fn in agg_fns:
                    out[out_field] = agg_fn(current_window)
                results.append(out)
        return results

    def runner(records_iter: Iterator[dict[str, Any]]) -> Iterator[dict[str, Any]]:
        # Resolve each aggregator once, when iteration starts, instead of for
        # every emitted row.
        # NOTE(review): assumes resolve_aggregator returns a stateless callable
        # that can be reused across rows and groups — confirm.
        agg_fns = [
            (out_field, resolve_aggregator((fn, in_f), out_field))
            for out_field, (fn, in_f) in aggregators.items()
        ]
        for group_dict in records_iter:
            key = group_dict["__group_key__"]
            recs = group_dict.get("__group_records__", [])
            yield from process_one_group(key, recs, agg_fns)

    return runner(records)
[docs]
def parse_window_size(window_str: str) -> float:
    """
    Parse a window size string like '5m', '10m', '1h', '30s', '1d' to seconds (float).

    Parameters
    ----------
    window_str : str
        Window size string: a number followed by a single unit character
        's', 'm', 'h', or 'd' (case-insensitive).

    Returns
    -------
    float
        Window size in seconds.

    Raises
    ------
    ValueError
        If `window_str` is not a string, is empty, has a non-numeric
        magnitude, or has an unrecognized unit.
    """
    if not isinstance(window_str, str):
        raise ValueError(f"Expected string for freq, got {type(window_str).__name__}")

    # Slicing (not indexing) keeps the empty string from raising IndexError;
    # it then fails the float() parse below with a ValueError.
    unit = window_str[-1:].lower()
    try:
        val = float(window_str[:-1])
    except ValueError as err:
        raise ValueError(f"Could not parse window size from '{window_str}'") from err

    seconds_per_unit = {"s": 1, "m": 60, "h": 3600, "d": 86400}
    if unit not in seconds_per_unit:
        raise ValueError(f"Unrecognized unit '{unit}' in window '{window_str}'")
    return val * seconds_per_unit[unit]
[docs]
def apply_group_time_bucket(
    records: Iterator[dict[str, Any]], step: dict[str, Any]
) -> Iterator[dict[str, Any]]:
    """
    Assign each record in a group to a fixed, non-overlapping time bin.

    Two modes:

    - String freq with suffix (e.g. '5m', '1h'): requires datetime or timedelta time_field.
    - Numeric freq (int/float): buckets numeric values directly.

    For datetime values the bins of a group are anchored at the time value of
    the group's first non-null row (in input order); timedelta and numeric
    values are anchored at zero. Rows whose time value is None are dropped.

    Parameters
    ----------
    records : Iterator[dict]
        Iterator of group dicts, each with '__group_key__' and '__group_records__'.
    step : dict
        Step configuration dict, must include 'freq', 'aggregators', 'time_field',
        and optionally 'label' ('left' by default, 'right' to label a bin by its
        upper edge), 'bucket_name' (default 'bucket'), '__group_keys'.

    Returns
    -------
    Iterator[dict]
        One record per (group, bucket) holding the group keys, the bucket label
        under `bucket_name`, and the aggregated fields.

    Raises
    ------
    ValueError
        If `freq` is invalid or zero, or (during iteration) if `freq` is a
        time string and the first non-null time value of a group is neither
        datetime nor timedelta.
    """
    freq = step["freq"]
    aggregators = step["aggregators"]
    time_field = step["time_field"]
    label_side = step.get("label", "left")
    bucket_name = step.get("bucket_name", "bucket")
    # `or []` also covers an explicit None, which would otherwise break zip().
    group_keys = step.get("__group_keys") or []
    numeric_mode, _, bucket_size, _, _ = get_time_window_details(freq, time_field)
    # A zero width would otherwise surface as ZeroDivisionError mid-iteration.
    if not bucket_size:
        raise ValueError(f"time_bucket: freq {freq!r} must be non-zero.")
    edge_offset = 1 if label_side == "right" else 0

    def _get_time(r: dict[str, Any]) -> Any:
        """Return the row's time value, or None when no time_field is configured."""
        if time_field is not None:
            return get_field(r, time_field)
        return None

    def process_one_group(
        group_key_tuple: tuple[Any, ...],
        group_records: list[dict[str, Any]],
        agg_fns: list[tuple[str, Any]],
    ) -> list[dict[str, Any]]:
        """Bucket one group's rows and return one aggregated row per bucket."""
        # Keep only rows that actually carry a time value.
        rows = [r for r in group_records if _get_time(r) is not None]
        if not rows:
            return []

        # Inspect the first row (before sorting) to decide how to measure time.
        sample = _get_time(rows[0])
        origin: Optional[datetime] = None
        if not numeric_mode:
            if isinstance(sample, datetime):
                origin = sample
            elif not isinstance(sample, timedelta):
                raise ValueError(
                    f"time_bucket: field '{time_field}' has type {type(sample).__name__}; "
                    "when freq has a time suffix you must provide datetime or timedelta values."
                )
        rows.sort(key=_get_time)

        # Partition into buckets; because rows are sorted, dict insertion order
        # is ascending bucket order.
        buckets: dict[int, list[dict[str, Any]]] = {}
        labels: dict[int, Union[datetime, timedelta, float]] = {}
        for r in rows:
            t = _get_time(r)
            if numeric_mode:
                total = float(t)
            elif isinstance(t, datetime) and origin is not None:
                total = (t - origin).total_seconds()
            elif isinstance(t, timedelta):
                total = t.total_seconds()
            else:
                total = float(t)

            idx = int(total // bucket_size)
            buckets.setdefault(idx, []).append(r)
            if idx not in labels:
                edge = (idx + edge_offset) * bucket_size
                if numeric_mode:
                    labels[idx] = float(edge)
                elif origin is not None:
                    labels[idx] = origin + timedelta(seconds=edge)
                else:
                    labels[idx] = timedelta(seconds=edge)

        # Build output: group keys, then bucket label, then aggregates.
        out = []
        for idx, members in buckets.items():
            row_out = dict(zip(group_keys, group_key_tuple))
            row_out[bucket_name] = labels[idx]
            for out_field, agg in agg_fns:
                row_out[out_field] = agg(members)
            out.append(row_out)
        return out

    def runner(all_groups: Iterator[dict[str, Any]]) -> Iterator[dict[str, Any]]:
        # Resolve each aggregator once, when iteration starts, instead of for
        # every bucket.
        # NOTE(review): assumes resolve_aggregator returns a stateless callable
        # that can be reused across buckets and groups — confirm.
        agg_fns = [
            (out_field, resolve_aggregator((fn, in_f), out_field))
            for out_field, (fn, in_f) in aggregators.items()
        ]
        for g in all_groups:
            key = g["__group_key__"]
            recs = g.get("__group_records__", [])
            yield from process_one_group(key, recs, agg_fns)

    return runner(records)
[docs]
def apply_group_by(
    records: Iterator[dict[str, Any]], step: dict[str, Any]
) -> Iterator[dict[str, Any]]:
    """
    Partition records into groups keyed by one or more fields.

    All records are consumed before the first group is produced. Groups come
    out in the order their key was first seen, and rows keep their input order
    within a group.

    Parameters
    ----------
    records : Iterator[dict]
        Iterator of records to group.
    step : dict
        Step configuration dict, must include 'keys' (dotted field names).
        The split key paths are cached back into `step` under
        '_compiled_keys' and reused when already present and non-empty.

    Returns
    -------
    Iterator[dict]
        Iterator of group dicts, each with '__group_key__' (tuple of key
        values) and '__group_records__' (list of member records).
    """
    key_names = step["keys"]
    key_paths = step.get("_compiled_keys")
    if not key_paths:
        # Split each dotted name once and remember the result on the step.
        key_paths = [name.split(".") for name in key_names]
        step["_compiled_keys"] = key_paths

    members_by_key: dict[tuple[Any, ...], list[dict[str, Any]]] = {}
    for rec in records:
        group_id = tuple(get_field(rec, path) for path in key_paths)
        members_by_key.setdefault(group_id, []).append(rec)

    yield from (
        {"__group_key__": group_id, "__group_records__": members}
        for group_id, members in members_by_key.items()
    )
[docs]
def apply_group_summary(
    records: Iterator[dict[str, Any]], step: dict[str, Any]
) -> Iterator[dict[str, Any]]:
    """
    Reduce every group to a single summary record.

    Parameters
    ----------
    records : Iterator[dict]
        Iterator of group dicts, each with '__group_key__' and '__group_records__'.
    step : dict
        Step configuration dict, must include 'agg' (a callable taking the
        group's rows and returning a dict) and optionally 'group_keys' (names
        for the key values; without it they are labelled 'group_0', 'group_1', ...).

    Returns
    -------
    Iterator[dict]
        One dict per group: the labelled key values updated with the summary
        returned by 'agg' (summary fields win on a name clash).

    Raises
    ------
    ValueError
        During iteration, if 'agg' returns something other than a dict.
    """
    summarize = step["agg"]
    key_names = step.get("group_keys")  # real key names, when the caller has them

    for bundle in records:
        key_values = bundle["__group_key__"]
        summary = summarize(bundle["__group_records__"])
        if not isinstance(summary, dict):
            raise ValueError("group_summary function must return a dict")

        if key_names:
            labelled = dict(zip(key_names, key_values))
        else:
            labelled = {f"group_{pos}": val for pos, val in enumerate(key_values)}

        yield {**labelled, **summary}
[docs]
def apply_group_cumulative(
    records: Iterator[dict[str, Any]], step: dict[str, Any]
) -> Iterator[dict[str, Any]]:
    """
    Apply a cumulative sum to a field for each group of records.

    The running total restarts at 0 for every group. A row whose field is
    missing or None contributes nothing and carries the previous total.

    Parameters
    ----------
    records : Iterator[dict]
        Iterator of group dicts, each with '__group_records__'.
    step : dict
        Step configuration dict, must include 'field' (top-level key to sum)
        and 'alias' (name of the output field).

    Returns
    -------
    Iterator[dict]
        Iterator of shallow copies of the input rows with the running total
        stored under `alias`; the input rows themselves are not modified.
    """
    field = step["field"]
    alias = step["alias"]
    for group in records:
        total = 0
        for r in group["__group_records__"]:
            value = r.get(field)
            # Treat an explicit None like a missing field rather than failing
            # on `total += None`.
            if value is not None:
                total += value
            new_r = dict(r)
            new_r[alias] = total
            yield new_r