Source code for bigframes.core.window.rolling

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import datetime
from typing import Literal, Mapping, Sequence, TYPE_CHECKING, Union

import bigframes_vendored.pandas.core.window.rolling as vendored_pandas_rolling
import numpy
import pandas

from bigframes import dtypes
from bigframes.core import agg_expressions
from bigframes.core import expression as ex
from bigframes.core import log_adapter, ordering, utils, window_spec
import bigframes.core.blocks as blocks
from bigframes.core.window import ordering as window_ordering
import bigframes.operations.aggregations as agg_ops

if TYPE_CHECKING:
    import bigframes.dataframe as df
    import bigframes.series as series


@log_adapter.class_logger
class Window(vendored_pandas_rolling.Window):
    __doc__ = vendored_pandas_rolling.Window.__doc__

    def __init__(
        self,
        block: blocks.Block,
        window_spec: window_spec.WindowSpec,
        value_column_ids: Sequence[str],
        drop_null_groups: bool = True,
        is_series: bool = False,
        skip_agg_column_id: str | None = None,
    ):
        self._block = block
        self._window_spec = window_spec
        self._value_column_ids = value_column_ids
        self._drop_null_groups = drop_null_groups
        self._is_series = is_series
        # The column ID that won't be aggregated on.
        # This is equivalent to pandas `on` parameter in rolling()
        self._skip_agg_column_id = skip_agg_column_id
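
    # A hedged sketch of the pass-through behavior (the window size and column
    # names below are hypothetical, not part of this module):
    #
    #   df.rolling(window="3s", on="ts").sum()
    #
    # builds a Window whose skip_agg_column_id points at the "ts" column, so
    # _aggregated_columns excludes it from aggregation and _apply_aggs
    # re-selects it next to the results, mirroring pandas' rolling(on=...).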

    def count(self):
        return self._apply_aggregate_op(agg_ops.count_op)

    def sum(self):
        return self._apply_aggregate_op(agg_ops.sum_op)

    def mean(self):
        return self._apply_aggregate_op(agg_ops.mean_op)

    def var(self):
        return self._apply_aggregate_op(agg_ops.var_op)

    def std(self):
        return self._apply_aggregate_op(agg_ops.std_op)

    def max(self):
        return self._apply_aggregate_op(agg_ops.max_op)

    def min(self):
        return self._apply_aggregate_op(agg_ops.min_op)

    def agg(self, func) -> Union[df.DataFrame, series.Series]:
        if utils.is_dict_like(func):
            return self._agg_dict(func)
        elif utils.is_list_like(func):
            return self._agg_list(func)
        else:
            return self._agg_func(func)

    aggregate = agg

    def _agg_func(self, func) -> df.DataFrame:
        ids, labels = self._aggregated_columns()
        aggregations = [
            agg(col_id, agg_ops.lookup_agg_func(func)[0]) for col_id in ids
        ]
        return self._apply_aggs(aggregations, labels)

    def _agg_dict(self, func: Mapping) -> df.DataFrame:
        aggregations: list[agg_expressions.Aggregation] = []
        column_labels = []
        function_labels = []

        want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())

        for label, funcs_for_id in func.items():
            col_id = self._block.label_to_col_id[label][-1]  # get last matching column
            func_list = (
                funcs_for_id if utils.is_list_like(funcs_for_id) else [funcs_for_id]
            )
            for f in func_list:
                f_op, f_label = agg_ops.lookup_agg_func(f)
                aggregations.append(agg(col_id, f_op))
                column_labels.append(label)
                function_labels.append(f_label)

        if want_aggfunc_level:
            result_labels: pandas.Index = utils.combine_indices(
                pandas.Index(column_labels),
                pandas.Index(function_labels),
            )
        else:
            result_labels = pandas.Index(column_labels)

        return self._apply_aggs(aggregations, result_labels)

    def _agg_list(self, func: Sequence) -> df.DataFrame:
        ids, labels = self._aggregated_columns()
        aggregations = [
            agg(col_id, agg_ops.lookup_agg_func(f)[0]) for col_id in ids for f in func
        ]

        if self._is_series:
            # if series, no need to rebuild
            result_cols_idx = pandas.Index(
                [agg_ops.lookup_agg_func(f)[1] for f in func]
            )
        else:
            if self._block.column_labels.nlevels > 1:
                # Restructure MultiIndex for proper format: (idx1, idx2, func)
                # rather than ((idx1, idx2), func).
                column_labels = [
                    tuple(label) + (agg_ops.lookup_agg_func(f)[1],)
                    for label in labels.to_frame(index=False).to_numpy()
                    for f in func
                ]
            else:  # Single-level index
                column_labels = [
                    (label, agg_ops.lookup_agg_func(f)[1])
                    for label in labels
                    for f in func
                ]
            result_cols_idx = pandas.MultiIndex.from_tuples(
                column_labels, names=[*self._block.column_labels.names, None]
            )
        return self._apply_aggs(aggregations, result_cols_idx)

    def _apply_aggs(
        self, exprs: Sequence[agg_expressions.Aggregation], labels: pandas.Index
    ):
        block, ids = self._block.apply_analytic(
            agg_exprs=exprs,
            window=self._window_spec,
            result_labels=labels,
            skip_null_groups=self._drop_null_groups,
        )

        if self._window_spec.grouping_keys:
            original_index_ids = block.index_columns
            block = block.reset_index(drop=False)
            # grouping keys will always be direct column references, but we should
            # probably refactor this class to enforce this statically
            index_ids = (
                *[col.id.name for col in self._window_spec.grouping_keys],  # type: ignore
                *original_index_ids,
            )
            block = block.set_index(col_ids=index_ids)

        if self._skip_agg_column_id is not None:
            block = block.select_columns([self._skip_agg_column_id, *ids])
        else:
            block = block.select_columns(ids).with_column_labels(labels)

        if self._is_series and (len(block.value_columns) == 1):
            import bigframes.series as series

            return series.Series(block)
        else:
            import bigframes.dataframe as df

            return df.DataFrame(block)

    def _apply_aggregate_op(
        self,
        op: agg_ops.UnaryAggregateOp,
    ):
        ids, labels = self._aggregated_columns()
        aggregations = [agg(col_id, op) for col_id in ids]
        return self._apply_aggs(aggregations, labels)

    def _aggregated_columns(self) -> tuple[Sequence[str], pandas.Index]:
        agg_col_ids = [
            col_id
            for col_id in self._value_column_ids
            if col_id != self._skip_agg_column_id
        ]
        labels: pandas.Index = pandas.Index(
            [self._block.col_id_to_label[col] for col in agg_col_ids]
        )
        return agg_col_ids, labels
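

# A minimal usage sketch (an assumption, not part of this module): `Window`
# objects are normally created via the public `DataFrame.rolling()` /
# `Series.rolling()` APIs rather than constructed directly; the column names
# "a" and "b" below are illustrative.
#
#   import bigframes.pandas as bpd
#
#   df = bpd.DataFrame({"a": [1, 2, 3, 4], "b": [4, 3, 2, 1]})
#   roll = df.rolling(window=2)
#   roll.mean()                           # -> Window._apply_aggregate_op(mean_op)
#   roll.agg("sum")                       # scalar func   -> Window._agg_func
#   roll.agg(["sum", "mean"])             # list of funcs -> Window._agg_list
#   roll.agg({"a": ["sum"], "b": "max"})  # dict of funcs -> Window._agg_dict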


def create_range_window(
    block: blocks.Block,
    window: pandas.Timedelta | numpy.timedelta64 | datetime.timedelta | str,
    *,
    value_column_ids: Sequence[str] = tuple(),
    min_periods: int | None,
    on: str | None = None,
    closed: Literal["right", "left", "both", "neither"],
    is_series: bool,
    grouping_keys: Sequence[str] = tuple(),
    drop_null_groups: bool = True,
) -> Window:
    if on is None:
        # Rolling on index
        index_dtypes = block.index.dtypes
        if len(index_dtypes) > 1:
            raise ValueError("Range rolling on MultiIndex is not supported")
        if index_dtypes[0] != dtypes.TIMESTAMP_DTYPE:
            raise ValueError("Index type should be timestamps with timezones")
        rolling_key_col_id = block.index_columns[0]
    else:
        # Rolling on a specific column
        rolling_key_col_id = block.resolve_label_exact_or_error(on)
        if block.expr.get_column_type(rolling_key_col_id) != dtypes.TIMESTAMP_DTYPE:
            raise ValueError(f"Column {on} type should be timestamps with timezones")

    order_direction = window_ordering.find_order_direction(
        block.expr.node, rolling_key_col_id
    )
    if order_direction is None:
        target_str = "index" if on is None else f"column {on}"
        raise ValueError(
            f"The {target_str} might not be in a monotonic order. "
            f"Please sort by {target_str} before rolling."
        )

    if isinstance(window, str):
        window = pandas.Timedelta(window)
    spec = window_spec.WindowSpec(
        bounds=window_spec.RangeWindowBounds.from_timedelta_window(window, closed),
        min_periods=1 if min_periods is None else min_periods,
        ordering=(
            ordering.OrderingExpression(ex.deref(rolling_key_col_id), order_direction),
        ),
        grouping_keys=tuple(ex.deref(col) for col in grouping_keys),
    )

    selected_value_col_ids = (
        value_column_ids if value_column_ids else block.value_columns
    )

    # This step must be done after finding the order direction of the window key.
    if grouping_keys:
        block = block.order_by([ordering.ascending_over(col) for col in grouping_keys])

    return Window(
        block,
        spec,
        value_column_ids=selected_value_col_ids,
        is_series=is_series,
        skip_agg_column_id=None if on is None else rolling_key_col_id,
        drop_null_groups=drop_null_groups,
    )


def agg(input: str, op: agg_ops.AggregateOp) -> agg_expressions.Aggregation:
    if isinstance(op, agg_ops.UnaryAggregateOp):
        return agg_expressions.UnaryAggregation(op, ex.deref(input))
    else:
        assert isinstance(op, agg_ops.NullaryAggregateOp)
        return agg_expressions.NullaryAggregation(op)
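

# A minimal sketch of the module-level `agg` helper above (the column id
# "col_0" is hypothetical): it lifts a column id and an aggregate op into an
# aggregation expression, dereferencing the column only for unary ops.
#
#   sum_agg = agg("col_0", agg_ops.sum_op)
#   # equivalent to:
#   # agg_expressions.UnaryAggregation(agg_ops.sum_op, ex.deref("col_0"))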