import copy
[docs]
class FlowOptimizer:
    """
    Optimizer for a flow plan.

    A plan is a list of step dicts. Each step carries an ``"op"`` name plus
    op-specific parameters, and may carry a ``"_notes"`` list of
    human-readable annotations added by the optimizer.

    Performs conservative optimizations: it fuses simple operations,
    pushes down filters, limits, and select/drop operations only when
    provably safe, and eliminates redundant steps. No pass ever removes a
    step unless it is an exact repeat of the step immediately before it.
    """

    # Upper bound on optimisation rounds; ``optimize`` stops earlier once a
    # round leaves the plan unchanged.
    MAX_PASSES = 5

    # Ops treated as order-sensitive: deferred select/drop steps are
    # re-emitted before them and filters never cross them.
    _ORDER_SENSITIVE_OPS = frozenset(
        {
            "sort",
            "group_summary",
            "group_cumulative",
            "group_rolling_summary",
            "pivot",
        }
    )

    # Additional ops that a held-back filter is never moved across.
    _FILTER_BARRIER_OPS = frozenset(
        {
            "select",
            "drop",
            "dropna",
            "rename",
            "flatten",
            "map",
            "assign",
            "pipe",
            "join",
            "group_by",
            "summary",
            "limit",
            "explode",
        }
    )

    # Ops a ``limit`` may be hoisted in front of.
    _LIMIT_COMMUTING_OPS = frozenset({"assign", "select", "drop", "rename"})

    # Ops that are merged into a single "fused" step when adjacent.
    _FUSABLE_OPS = frozenset({"map", "assign", "filter"})

    # Maps an op name to a callable returning the set of field names that
    # the step reads or writes. Ops without an entry are treated as
    # touching no fields (see ``_get_fields_used``).
    FIELD_USAGE_HANDLERS = {
        "select": lambda step: set(step.get("fields", [])),
        "drop": lambda step: set(step.get("keys", [])),
        "dropna": lambda step: set(step.get("fields") or []),
        "rename": lambda step: set(step.get("mapping", {}).keys())
        | set(step.get("mapping", {}).values()),
        "assign": lambda step: set(step.get("fields", {}).keys()),
        "cast": lambda step: set(step.get("casts", {}).keys()),
        "filter": lambda step: set(),
        "join": lambda step: set(step.get("on", [])),
        "sort": lambda step: set(step.get("keys", [])),
        "group_by": lambda step: set(step.get("keys", [])),
        "group_rolling_summary": lambda step: (
            ({step.get("time_field")} if step.get("time_field") else set())
            | {
                agg[1]
                for agg in step.get("aggregators", {}).values()
                if isinstance(agg, tuple)
            }
        ),
    }

    def __init__(self, plan):
        """Store the plan to optimise; it is never modified in place."""
        self.plan = plan

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _add_note(step, note):
        """Append *note* to ``step["_notes"]``.

        A fresh list is bound instead of appending in place, because steps
        are shallow-copied between passes and an in-place append would
        also change the step the copy was taken from.
        """
        step["_notes"] = [*step.get("_notes", []), note]

    @staticmethod
    def _without_notes(step):
        """Return a copy of *step* with the ``"_notes"`` annotation removed."""
        return {key: value for key, value in step.items() if key != "_notes"}

    def _is_order_sensitive(self, op: str) -> bool:
        """Return True if *op* is one of the order-sensitive operations."""
        return op in self._ORDER_SENSITIVE_OPS

    def _blocks_filter_pushdown(self, op: str) -> bool:
        """Return True if a held-back filter must be emitted before *op*."""
        return self._is_order_sensitive(op) or op in self._FILTER_BARRIER_OPS

    # ------------------------------------------------------------------
    # Driver
    # ------------------------------------------------------------------
    def optimize(self):
        """Return an optimised deep copy of ``self.plan``.

        Runs up to ``MAX_PASSES`` rounds of ``_optimize_once`` and stops
        as soon as a round produces a plan equal to its input. The result
        is finally passed through ``_validate_rolling_has_sort``, which
        only adds warning notes.
        """
        plan = copy.deepcopy(self.plan)
        for _ in range(self.MAX_PASSES):
            new_plan = self._optimize_once(plan)
            if new_plan == plan:
                break
            plan = new_plan
        return self._validate_rolling_has_sort(plan)

    def _optimize_once(self, plan):
        """Apply every rewrite pass once, in a fixed order."""
        plan = self._fuse_map_assign_filter(plan)
        plan = self._pushdown_filters(plan)
        plan = self._pushdown_limit(plan)
        plan = self._pushdown_select_drop(plan)
        plan = self._eliminate_redundant_steps(plan)
        return plan

    # ------------------------------------------------------------------
    # Field-usage analysis
    # ------------------------------------------------------------------
    def _get_fields_used(self, step):
        """Return the set of field names *step* touches.

        A ``"fused"`` step reports the union of its inner steps' fields.
        Without this, an ``assign`` hidden inside a fused step would be
        invisible to the select/drop analysis, which could then move a
        ``select`` behind it and discard the assigned field.
        """
        op = step.get("op")
        if op == "fused":
            used = set()
            for inner in step.get("steps", []):
                used |= self._get_fields_used(inner)
            return used
        handler = self.FIELD_USAGE_HANDLERS.get(op)
        return handler(step) if handler is not None else set()

    def _compute_required_fields(self, plan):
        """For each index ``i`` return the fields touched by steps after ``i``.

        The result has the same length as *plan*; the entry for the last
        step is always the empty set.
        """
        required = set()
        required_by_step = []
        for step in reversed(plan):
            # Snapshot first so that a step's own fields are not counted
            # among the fields required *after* it.
            required_by_step.append(required.copy())
            required |= self._get_fields_used(step)
        required_by_step.reverse()
        return required_by_step

    def _is_already_early_enough(self, plan, index):
        """Return True if the step before *index* is a ``from_*`` step."""
        if index <= 0:
            return False
        # ``or ""`` guards against a step carrying an explicit ``"op": None``.
        previous_op = plan[index - 1].get("op") or ""
        return previous_op.startswith("from_")

    # ------------------------------------------------------------------
    # select / drop
    # ------------------------------------------------------------------
    def _pushdown_select_drop(self, plan):
        """Re-position ``select``/``drop`` steps that later steps do not depend on.

        A ``select`` is movable when every field touched by later steps is
        among its selected fields; a ``drop`` is movable when none of the
        fields touched by later steps is among its dropped keys. A movable
        step that directly follows a ``from_*`` step stays where it is.
        Any other movable step is held back and re-emitted immediately
        before the next order-sensitive step, immediately after the next
        ``from_*`` step, or at the end of the plan, whichever comes first.

        Steps still held back when the plan ends are always emitted, so
        no step is ever lost.
        """
        required_after = self._compute_required_fields(plan)
        new_plan = []
        held_back = []

        for i, step in enumerate(plan):
            op = step.get("op")

            if op in {"select", "drop"}:
                if op == "select":
                    movable = required_after[i].issubset(
                        set(step.get("fields", []))
                    )
                else:
                    movable = required_after[i].isdisjoint(
                        set(step.get("keys", []))
                    )
                if movable:
                    if self._is_already_early_enough(plan, i):
                        new_plan.append(step)
                    else:
                        moved = dict(step)
                        # Temporary marker, stripped by _annotate_moves.
                        moved["_original_index"] = i
                        held_back.append(moved)
                    continue

            if self._is_order_sensitive(op):
                new_plan.extend(held_back)
                held_back = []

            new_plan.append(step)

            if op and op.startswith("from_") and held_back:
                new_plan.extend(held_back)
                held_back = []

        # Flush whatever is still held back; dropping it would silently
        # remove select/drop steps from the plan.
        new_plan.extend(held_back)
        return self._annotate_moves(new_plan)

    def _annotate_moves(self, plan):
        """Strip ``_original_index`` markers, noting steps that moved earlier.

        A step whose new index is smaller than its recorded original index
        receives a note; the marker is removed in every case. Marked steps
        are copied before being changed.
        """
        result = []
        for idx, step in enumerate(plan):
            if "_original_index" in step:
                step = dict(step)
                orig = step.pop("_original_index")
                if idx < orig:
                    note = (
                        "moved earlier in plan"
                        if orig - idx > 1
                        else "reordered (same logical position)"
                    )
                    self._add_note(step, note)
            result.append(step)
        return result

    # ------------------------------------------------------------------
    # Redundancy elimination
    # ------------------------------------------------------------------
    def _eliminate_redundant_steps(self, plan):
        """Remove a ``drop``/``dropna`` step that exactly repeats its predecessor.

        Two steps count as repeats only when they are equal once their
        ``"_notes"`` are ignored. Consecutive drops with *different* keys
        are both kept, since removing one would change the result.
        """
        new_plan = []
        previous = None
        for step in plan:
            is_repeat = (
                previous is not None
                and step.get("op") in {"drop", "dropna"}
                and self._without_notes(step) == self._without_notes(previous)
            )
            if is_repeat:
                continue
            new_plan.append(dict(step))
            previous = step
        return new_plan

    # ------------------------------------------------------------------
    # filter
    # ------------------------------------------------------------------
    def _pushdown_filters(self, plan):
        """Re-emit ``filter`` steps at the next barrier.

        Each ``filter`` is held back when met and emitted, in original
        order, immediately before the next step that blocks filter
        movement (see ``_blocks_filter_pushdown``) or the next ``from_*``
        step. Filters still held back at the end are appended. A filter
        that ends up at a smaller index than it started at is tagged with
        a note. Every step in the result is a shallow copy.
        """
        new_plan = []
        held_back = []  # (original index, filter copy) pairs

        def flush(offset):
            for orig, filt in held_back:
                tagged = dict(filt)
                if len(new_plan) + offset < orig:
                    self._add_note(tagged, "pushed down from later step")
                new_plan.append(tagged)
            held_back.clear()

        for idx, step in enumerate(plan):
            op = step.get("op")
            if op == "filter":
                held_back.append((idx, dict(step)))
                continue
            if held_back:
                if self._blocks_filter_pushdown(op):
                    flush(0)
                elif op and op.startswith("from_"):
                    flush(1)
            new_plan.append(dict(step))

        new_plan.extend(filt for _, filt in held_back)
        return new_plan

    # ------------------------------------------------------------------
    # limit
    # ------------------------------------------------------------------
    def _pushdown_limit(self, plan):
        """Hoist each ``limit`` in front of preceding assign/select/drop/rename steps.

        The plan is walked back to front. A ``limit`` is carried past any
        run of ops in ``_LIMIT_COMMUTING_OPS`` and placed directly after
        the first other step (or at the start of the plan). A limit that
        actually moved is tagged with a note. When a second ``limit`` is
        met while one is still being carried, the carried one is placed
        first so that both limits survive.
        """
        reversed_out = []  # built back to front, reversed once at the end
        carried = None
        moved = False

        for step in reversed(plan):
            op = step.get("op")
            if op == "limit":
                if carried is not None:
                    if moved:
                        self._add_note(carried, "pushed down from later step")
                    reversed_out.append(carried)
                    moved = False
                carried = dict(step)
            elif carried is not None and op in self._LIMIT_COMMUTING_OPS:
                moved = True
                reversed_out.append(dict(step))
            else:
                if carried is not None:
                    if moved:
                        self._add_note(carried, "pushed down from later step")
                    reversed_out.append(carried)
                    carried, moved = None, False
                reversed_out.append(dict(step))

        if carried is not None:
            if moved:
                self._add_note(carried, "pushed down to earliest safe point")
            reversed_out.append(carried)

        reversed_out.reverse()
        return reversed_out

    # ------------------------------------------------------------------
    # fusion
    # ------------------------------------------------------------------
    def _fuse_map_assign_filter(self, plan):
        """Merge each run of two or more adjacent map/assign/filter steps.

        A run becomes one ``{"op": "fused", ...}`` step listing the op
        names (``"ops"``) and shallow copies of the original steps
        (``"steps"``). Single fusable steps and all other steps are
        shallow-copied unchanged.
        """
        new_plan = []
        total = len(plan)
        start = 0
        while start < total:
            end = start
            while end < total and plan[end].get("op") in self._FUSABLE_OPS:
                end += 1
            if end - start > 1:
                run = plan[start:end]
                op_names = [s["op"] for s in run]
                new_plan.append(
                    {
                        "op": "fused",
                        "ops": op_names,
                        "steps": [dict(s) for s in run],
                        "_notes": [f"fused: {', '.join(op_names)}"],
                    }
                )
                start = end
            else:
                new_plan.append(dict(plan[start]))
                start += 1
        return new_plan

    # ------------------------------------------------------------------
    # validation
    # ------------------------------------------------------------------
    def _validate_rolling_has_sort(self, plan):
        """Warn about ``group_rolling_summary`` steps lacking a preceding sort.

        For each ``group_rolling_summary`` the steps between the most
        recent ``group_by`` (or the start of the plan) and the rolling
        step are searched for a ``sort``. If none is found, a copy of the
        step carrying a warning note is emitted in its place. The plan's
        order is never changed.
        """
        validated_plan = []
        last_group_by_idx = -1
        for idx, step in enumerate(plan):
            op = step.get("op")
            if op == "group_by":
                last_group_by_idx = idx
            elif op == "group_rolling_summary":
                window = plan[last_group_by_idx + 1 : idx]
                if not any(p.get("op") == "sort" for p in window):
                    step = dict(step)
                    self._add_note(
                        step,
                        "⚠️ group_rolling_summary used without prior sort — results may be unstable",
                    )
            validated_plan.append(step)
        return validated_plan