Source code for penaltyblog.matchflow.steps.transform

import copy
import itertools
import random
from collections import OrderedDict, defaultdict
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

from .utils import (
    fast_get_field,
    flatten_dict,
    get_field,
    reservoir_sample,
    set_nested_field,
)

# Type aliases for records and streams
Record = Dict[str, Any]
RecordStream = Iterator[Record]

if TYPE_CHECKING:
    from ..flow import Flow


def _coerce_join_key(value, strategy="strict"):
    """
    Coerce a join key value for consistent comparison across data types.

    Args:
        value: The join key value to coerce
        strategy: Coercion strategy - 'strict', 'auto', or 'string'

    Returns:
        Coerced value for comparison
    """
    if value is None:
        return None

    if strategy == "strict":
        return value
    elif strategy == "string":
        return str(value)
    elif strategy == "auto":
        # Smart coercion - numeric values get consistent string representation
        if isinstance(value, (int, float)):
            # Convert to string for consistent comparison
            if isinstance(value, float) and value.is_integer():
                # 1.0 -> "1" to match with int 1
                return str(int(value))
            else:
                return str(value)
        elif isinstance(value, str):
            # Try to convert numeric strings for consistency
            try:
                # Check if it's a valid integer
                if value.lstrip("-").isdigit():
                    return str(int(value))
                # Check if it's a valid float
                float_val = float(value)
                if float_val.is_integer():
                    return str(int(float_val))
                else:
                    return str(float_val)
            except ValueError:
                # Not a number, return as-is
                return value
        else:
            # Other types converted to string
            return str(value)
    else:
        raise ValueError(f"Unknown coercion strategy: {strategy}")


[docs] def apply_filter(records: RecordStream, step: dict) -> RecordStream: """ Filter records based on a predicate. Args: records: A stream of records to filter. step: Configuration dict containing the predicate to apply. Returns: A stream of filtered records. """ pred = step["predicate"] for r in records: if pred(r): yield r
[docs] def apply_assign(records: RecordStream, step: dict) -> RecordStream: """ Assign new fields to each record. Args: records: A stream of records to assign fields to. step: Configuration dict containing the fields to assign. Returns: A stream of records with assigned fields. """ fields = step["fields"] for r in records: new = dict(r) for k, func in fields.items(): new[k] = func(r) yield new
[docs] def apply_select(records: RecordStream, step: dict) -> RecordStream: """ Select specific fields from each record. Args: records: A stream of records to select fields from. step: Configuration dict containing the fields to select. Returns: A stream of records with selected fields. """ field_names = step["fields"] if all("." not in f for f in field_names): for record in records: yield {k: record.get(k) for k in field_names} else: compiled_paths = step.get("_compiled_fields") if not compiled_paths: compiled_paths = [f.split(".") for f in field_names] step["_compiled_fields"] = compiled_paths for record in records: out: dict[str, Any] = {} for path in compiled_paths: value = fast_get_field(record, path) set_nested_field(out, ".".join(path), value) yield out
def _del_nested(d, parts) -> None: """Delete a nested key from a dict given a list of parts.""" for part in parts[:-1]: d = d.get(part, {}) d.pop(parts[-1], None) def _set_nested(d, parts, value) -> None: """Set a nested key in a dict given a list of parts.""" for part in parts[:-1]: if part not in d or not isinstance(d[part], dict): d[part] = {} d = d[part] d[parts[-1]] = value
[docs] def apply_rename(records: RecordStream, step: dict) -> RecordStream: """ Rename fields in each record. Args: records: A stream of records to rename fields in. step: Configuration dict containing the mapping of old to new field names. Returns: A stream of records with renamed fields. """ mapping = step["mapping"] for r in records: new = dict(r) for old, new_key in mapping.items(): # Try flat key match first if old in new: new[new_key] = new.pop(old) else: # Try nested (dot-path) access old_parts = old.split(".") new_parts = new_key.split(".") val = get_field(new, old_parts) if val is not None: _del_nested(new, old_parts) _set_nested(new, new_parts, val) yield new
[docs] def apply_sort(records: RecordStream, step: dict) -> RecordStream: """ Sort records based on a list of keys. Args: records: A stream of records to sort. step: Configuration dict containing the keys to sort by. Returns: A stream of sorted records. """ keys = step["keys"] ascending = step.get("ascending", [True] * len(keys)) records_list: List[Record] = list(records) def sort_key(r): key_parts = [] for k, asc in zip(keys, ascending): v = r.get(k) # Reverse sort by negating numeric or inverting sortable value key_parts.append(v if asc else _invert(v)) return tuple(key_parts) return iter(sorted(records_list, key=sort_key))
def _invert(value): if isinstance(value, (int, float)): return -value elif isinstance(value, str): return "".join(chr(255 - ord(c)) for c in value) # crude inverse for strings else: return value # fallback: not reversed
[docs] def apply_limit(records: RecordStream, step: dict) -> RecordStream: """ Limit the number of records. Args: records: A stream of records to limit. step: Configuration dict containing the count to limit by. Returns: A stream of limited records. """ count = step["count"] for i, record in enumerate(records): if i >= count: break yield record
[docs] def apply_drop(records: RecordStream, step: dict) -> RecordStream: """ Drop specified fields from each record. Args: records: A stream of records to drop fields from. step: Configuration dict containing the keys to drop. Returns: A stream of records with dropped fields. """ keys = step["keys"] for record in records: new = dict(record) for key in keys: parts = key.split(".") d = new try: for part in parts[:-1]: next_d = d.get(part) if not isinstance(next_d, dict): raise KeyError d = next_d d.pop(parts[-1], None) # silently ignore missing keys except Exception: continue # skip malformed paths yield new
[docs] def apply_flatten(records: RecordStream, step: dict) -> RecordStream: """ Flatten nested dictionaries into a single-level dictionary using dot notation. Args: records: A stream of records to flatten. step: Configuration dict containing the keys to flatten. Returns: A stream of flattened records. """ for r in records: yield flatten_dict(r)
[docs] def apply_distinct(records: RecordStream, step: dict) -> RecordStream: keys = step.get("keys") keep = step.get("keep", "first") if keep == "first": return _distinct_first(records, keys) elif keep == "last": return _distinct_last(records, keys) else: raise ValueError("distinct keep must be 'first' or 'last'")
def _record_identity(record, keys): if keys: return tuple(get_field(record, k.split(".")) for k in keys) return tuple(sorted(record.items())) def _distinct_first(records, keys): seen = set() for record in records: identity = _record_identity(record, keys) if identity in seen: continue seen.add(identity) yield record def _distinct_last(records, keys): seen = OrderedDict() for record in records: identity = _record_identity(record, keys) seen[identity] = record # later value overwrites yield from seen.values()
[docs] def apply_dropna(records: RecordStream, step: dict) -> RecordStream: """ Drop records with missing values. Args: records: A stream of records to drop missing values from. step: Configuration dict containing the fields to drop missing values from. Returns: A stream of records with dropped records. """ fields = step.get("fields") compiled = step.get("_compiled_fields") if fields: if not compiled: compiled = [f.split(".") for f in fields] step["_compiled_fields"] = compiled for record in records: if any(get_field(record, path) is None for path in compiled): continue yield record else: inferred_keys = None for record in records: if inferred_keys is None: inferred_keys = list(record.keys()) if any(record.get(k) is None for k in inferred_keys): continue yield record
[docs] def apply_explode(records: RecordStream, step: dict) -> RecordStream: """ Explode records based on a list of fields. Args: records: A stream of records to explode. step: Configuration dict containing the fields to explode. Returns: A stream of exploded records. """ fields = step["fields"] compiled = step.get("_compiled_fields") if not compiled: compiled = [f.split(".") for f in fields] step["_compiled_fields"] = compiled for record in records: values = [get_field(record, f) for f in compiled] if all(isinstance(v, list) for v in values): lengths = [len(v) for v in values] if len(set(lengths)) != 1: raise ValueError( f"Cannot explode fields with mismatched lengths: {lengths}" ) if lengths[0] == 0: yield copy.deepcopy(record) else: for i in range(lengths[0]): new_record = copy.deepcopy(record) for f, v in zip(compiled, values): set_nested_field(new_record, ".".join(f), v[i]) yield new_record else: yield copy.deepcopy(record)
[docs] def apply_join(records: RecordStream, step: dict) -> RecordStream: """ Join records based on a list of keys. Dispatcher function that selects the appropriate join strategy. Args: records: A stream of records to join. step: Configuration dict containing the keys to join by. Returns: A stream of joined records. """ # Future logic to select strategy will go here. # For now, always use hash join. return _apply_hash_join(records, step)
def _apply_sort_merge_join(records: RecordStream, step: dict) -> RecordStream: """ Sort-merge join implementation for memory-efficient joins on pre-sorted data. Args: records: A stream of records to join (assumed to be sorted on join keys). step: Configuration dict containing the keys to join by. Returns: A stream of joined records. """ from ..executor import FlowExecutor # Extract parameters on = step.get("on") left_on = step.get("left_on") right_on = step.get("right_on") lsuffix = step.get("lsuffix", "") rsuffix = step.get("rsuffix", "_right") how = step.get("how", "left") type_coercion = step.get("type_coercion", "strict") # Determine join keys left_keys: Union[List[str], None] = None right_keys: Union[List[str], None] = None if on is not None: left_keys = right_keys = on else: left_keys = left_on right_keys = right_on # Compile join keys compiled_left = step.get("_compiled_left") compiled_right = step.get("_compiled_right") if not compiled_left and left_keys is not None: compiled_left = [k.split(".") for k in left_keys] step["_compiled_left"] = compiled_left if not compiled_right: compiled_right = [k.split(".") for k in right_keys] if right_keys else [] step["_compiled_right"] = compiled_right # Get left and right iterators left_iter = records right_iter = FlowExecutor(step["right_plan"]).execute() # Key extraction functions def left_key(record): return tuple( _coerce_join_key(get_field(record, k), type_coercion) for k in (compiled_left or []) ) def right_key(record): return tuple( _coerce_join_key(get_field(record, k), type_coercion) for k in (compiled_right or []) ) # Helper function to merge records with suffix handling def merge_records(left_rec, right_rec, is_left_primary=True): if is_left_primary: merged = dict(left_rec) for rk, rv in right_rec.items(): if right_keys is not None and rk in right_keys: continue if rk in merged: if rsuffix: merged[rk + rsuffix] = rv else: merged[rk] = rv # Apply lsuffix if needed if lsuffix: for lk in list(merged.keys()): if ( lk in left_rec and left_keys is not None and lk not in left_keys and lk in right_rec ): if lk + lsuffix not in merged: merged[lk + lsuffix] = merged.pop(lk) else: # Right is primary (for right joins) merged = dict(right_rec) for lk, lv in left_rec.items(): if left_keys is not None and lk in left_keys: continue if lk in merged: if lsuffix: merged[lk + lsuffix] = lv else: merged[lk] = lv return merged # Create null record for unmatched sides def create_null_left(right_rec, sample_left=None): result = dict(right_rec) if sample_left: for lk in sample_left.keys(): if right_keys is not None and lk not in right_keys: field_name = lk + lsuffix if (lsuffix and lk in result) else lk if field_name not in result: result[field_name] = None return result def create_null_right(left_rec, sample_right=None): result = dict(left_rec) if sample_right: for rk in sample_right.keys(): if left_keys is not None and rk not in left_keys: field_name = rk + rsuffix if (rsuffix and rk in result) else rk if field_name not in result: result[field_name] = None return result # Group by key using itertools.groupby left_grouped = itertools.groupby(left_iter, key=left_key) right_grouped = itertools.groupby(right_iter, key=right_key) # Convert to iterators we can peek at try: left_key_val, left_group_iter = next(left_grouped) left_group = list(left_group_iter) # Materialize group left_has_data = True except StopIteration: left_has_data = False left_key_val = None left_group = [] try: right_key_val, right_group_iter = next(right_grouped) right_group = list(right_group_iter) # Materialize group right_has_data = True except StopIteration: right_has_data = False right_key_val = None right_group = [] sample_left = left_group[0] if left_group else None sample_right = right_group[0] if right_group else None # Main sort-merge loop while left_has_data or right_has_data: if not right_has_data or ( left_has_data and left_key_val is not None and right_key_val is not None and left_key_val < right_key_val ): # Left key is smaller or no more right data if how in ["left", "outer"]: for left_rec in left_group: yield create_null_right(left_rec, sample_right) elif how == "anti": for left_rec in left_group: yield dict(left_rec) # Advance left try: left_key_val, left_group_iter = next(left_grouped) left_group = list(left_group_iter) except StopIteration: left_has_data = False elif not left_has_data or ( right_has_data and left_key_val is not None and right_key_val is not None and right_key_val < left_key_val ): # Right key is smaller or no more left data if how in ["right", "outer"]: for right_rec in right_group: yield create_null_left(right_rec, sample_left) # Advance right try: right_key_val, right_group_iter = next(right_grouped) right_group = list(right_group_iter) except StopIteration: right_has_data = False else: # Keys match if how != "anti": # Create cartesian product of matching groups for left_rec in left_group: for right_rec in right_group: if how == "right": yield merge_records( left_rec, right_rec, is_left_primary=False ) else: yield merge_records( left_rec, right_rec, is_left_primary=True ) # Advance both try: left_key_val, left_group_iter = next(left_grouped) left_group = list(left_group_iter) except StopIteration: left_has_data = False try: right_key_val, right_group_iter = next(right_grouped) right_group = list(right_group_iter) except StopIteration: right_has_data = False def _apply_hash_join(records: RecordStream, step: dict) -> RecordStream: """ Hash join implementation for joining records. Args: records: A stream of records to join. step: Configuration dict containing the keys to join by. Returns: A stream of joined records. """ from ..executor import FlowExecutor # Extract parameters on = step.get("on") left_on = step.get("left_on") right_on = step.get("right_on") lsuffix = step.get("lsuffix", "") rsuffix = step.get("rsuffix", "_right") how = step.get("how", "left") type_coercion = step.get("type_coercion", "strict") # Determine join keys left_keys: Union[List[str], None] = None right_keys: Union[List[str], None] = None if on is not None: left_keys = right_keys = on else: left_keys = left_on right_keys = right_on # Compile join keys compiled_left = step.get("_compiled_left") compiled_right = step.get("_compiled_right") if not compiled_left and left_keys is not None: compiled_left = [k.split(".") for k in left_keys] step["_compiled_left"] = compiled_left if not compiled_right: compiled_right = [k.split(".") for k in right_keys] if right_keys else [] step["_compiled_right"] = compiled_right # Execute right plan and build index right_records: List[Record] = list(FlowExecutor(step["right_plan"]).execute()) right_index: dict[tuple[Any, ...], list[dict]] = {} for r in right_records: key = tuple( _coerce_join_key(get_field(r, k), type_coercion) for k in (compiled_right or []) ) right_index.setdefault(key, []).append(r) # Handle right join by swapping if how == "right": # Swap left and right, then do a left join def right_join_generator(): # Build left index left_records: List[Record] = list(records) left_index: dict[tuple[Any, ...], list[dict]] = {} for l in left_records: key = tuple( _coerce_join_key(get_field(l, k), type_coercion) for k in (compiled_left or []) ) left_index.setdefault(key, []).append(l) # Process right records as primary for right in right_records: key = tuple( _coerce_join_key(get_field(right, k), type_coercion) for k in (compiled_right or []) ) matches = left_index.get(key) if not matches: # No left match - yield right with nulls for left fields joined = dict(right) # Add null values for left-only fields for left_rec in ( left_records[:1] if left_records else [{}] ): # Use first left record as template for lk in left_rec.keys(): if right_keys is not None and lk not in right_keys: left_name = ( lk + lsuffix if (lsuffix and lk in joined) else lk ) if left_name not in joined: joined[left_name] = None break yield joined else: for left in matches: joined = dict(right) for lk, lv in left.items(): if right_keys is not None and lk in right_keys: continue if lk in joined: joined[lk + lsuffix] = lv else: joined[lk] = lv yield joined yield from right_join_generator() return # Track matched right keys for outer join matched_right_keys: Optional[set[tuple[Any, ...]]] = ( set() if how == "outer" else None ) # Process left records for left in records: key = tuple( _coerce_join_key(get_field(left, k), type_coercion) for k in (compiled_left or []) ) matches = right_index.get(key) if not matches: # No right match if how in ["left", "outer"]: yield dict(left) elif how == "anti": yield dict(left) # For inner join, skip unmatched left records continue # Has matches if how == "anti": # Anti join - skip records that have matches continue # Mark this key as matched for outer join if matched_right_keys is not None: matched_right_keys.add(key) # Join matched records for right in matches: joined = dict(left) for rk, rv in right.items(): if right_keys is not None and rk in right_keys: continue if rk in joined: # Handle suffix collision if rsuffix: joined[rk + rsuffix] = rv else: # If no rsuffix, left value takes precedence pass else: joined[rk] = rv # Apply lsuffix to overlapping left fields if needed if lsuffix: for lk in list(joined.keys()): if lk in left and left_keys is not None and lk not in left_keys: # Check if this left field conflicts with a right field if any(lk in r for r in matches): # Rename left field with lsuffix if lk + lsuffix not in joined: joined[lk + lsuffix] = joined.pop(lk) yield joined # Handle outer join - emit unmatched right records if how == "outer": for key, right_group in right_index.items(): if matched_right_keys is not None and key not in matched_right_keys: for right in right_group: # Create record with right data and null left fields joined = dict(right) # Add null values for left-only fields left_sample = ( next(iter(records), None) if hasattr(records, "__iter__") else None ) if left_sample: for lk in left_sample.keys(): if left_keys is not None and lk not in left_keys: left_name = ( lk + lsuffix if (lsuffix and lk in joined) else lk ) if left_name not in joined: joined[left_name] = None yield joined
[docs] def apply_split_array(records: RecordStream, step: dict) -> RecordStream: """ Split an array into multiple records. Args: records: A stream of records to split arrays from. step: Configuration dict containing the field to split. Returns: A stream of split records. """ field = step["field"] into = step["into"] for record in records: if field not in record and "." not in field: # Simple field missing entirely yield record continue value = get_field(record, field) # Skip if field is missing or explicitly None if value is None: yield record continue new_record = dict(record) if isinstance(value, (list, tuple)): for i, key in enumerate(into): new_record[key] = value[i] if i < len(value) else None yield new_record else: # Field exists but isn't a list → treat as error or pass through unchanged yield record
[docs] def apply_pivot(records: RecordStream, step: dict) -> RecordStream: """ Pivot records based on a list of index fields. Args: records: A stream of records to pivot. step: Configuration dict containing the index fields to pivot by. Returns: A stream of pivoted records. """ index_fields = step["index"] col_field = step["columns"] val_field = step["values"] compiled_index = step.get("_compiled_index") if not compiled_index: compiled_index = [f.split(".") for f in index_fields] step["_compiled_index"] = compiled_index compiled_col = step.get("_compiled_col") or col_field.split(".") compiled_val = step.get("_compiled_val") or val_field.split(".") step["_compiled_col"] = compiled_col step["_compiled_val"] = compiled_val grouped = defaultdict(list) for r in records: key = tuple(get_field(r, f) for f in compiled_index) grouped[key].append(r) for key, rows in grouped.items(): result = {f: k for f, k in zip(index_fields, key)} for row in rows: col = get_field(row, compiled_col) val = get_field(row, compiled_val) if col is not None: result[col] = val yield result
[docs] def apply_summary(records: RecordStream, step: dict) -> RecordStream: """ Apply a summary function to the records. Args: records: A stream of records to apply the summary function to. step: Configuration dict containing the summary function to apply. Returns: A stream of summary results. """ agg_func = step["agg"] rows: List[Record] = list(records) result = agg_func(rows) if not isinstance(result, dict): raise ValueError("summary function must return a dict") yield result
[docs] def apply_sample_fraction(records: RecordStream, step: dict) -> RecordStream: """ Sample a fraction of the records. Args: records: A stream of records to sample a fraction from. step: Configuration dict containing the fraction to sample. Returns: A stream of sampled records. """ p = step["p"] seed = step.get("seed") rng = random.Random(seed) for r in records: if rng.random() < p: yield r
[docs] def apply_sample_n(records: RecordStream, step: dict) -> RecordStream: """ Sample a fixed number of records. Args: records: A stream of records to sample from. step: Configuration dict with 'n' and optional 'seed'. Returns: A stream of sampled records. """ n = step["n"] seed = step.get("seed") for r in reservoir_sample(records, n, seed): yield r
[docs] def apply_map(records: RecordStream, step: dict) -> RecordStream: """ Apply a function to each record. Args: records: A stream of records to apply the function to. step: Configuration dict containing the function to apply. Returns: A stream of mapped records. """ func = step["func"] for r in records: result = func(r) if result is None: continue if not isinstance(result, dict): raise TypeError("map function must return a dict") yield result
[docs] def apply_fused(records: RecordStream, step: dict) -> RecordStream: """ Apply a fused sequence of map/assign/filter operations. Args: records: A stream of records to apply fused operations to. step: Configuration dict with an 'ops' list and potentially embedded steps. Returns: A stream of records with the fused operations applied. """ # Extract embedded steps embedded_steps = step.get("steps", []) # Sanity fallback: reconstruct from original plan if needed if not embedded_steps: raise ValueError("Fused step missing original embedded steps") # Apply them sequentially for sub_step in embedded_steps: op = sub_step["op"] if op == "map": records = apply_map(records, sub_step) elif op == "assign": records = apply_assign(records, sub_step) elif op == "filter": records = apply_filter(records, sub_step) else: raise ValueError(f"Unsupported op in fused step: {op}") return records