Source code for bigframes.ml.pipeline

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""For composing estimators together. This module is styled after scikit-learn's
pipeline module: https://scikit-learn.org/stable/modules/pipeline.html."""


from __future__ import annotations

from typing import List, Optional, Tuple

import bigframes_vendored.constants as constants
import bigframes_vendored.sklearn.pipeline
from google.cloud import bigquery

from bigframes.core import log_adapter
import bigframes.dataframe
from bigframes.ml import (
    base,
    compose,
    forecasting,
    impute,
    loader,
    preprocessing,
    utils,
)
import bigframes.session


@log_adapter.class_logger
class Pipeline(
    base.BaseEstimator,
    bigframes_vendored.sklearn.pipeline.Pipeline,
):
    # The class docstring is taken verbatim from the vendored scikit-learn
    # Pipeline. Methods below use `#` comments rather than new docstrings;
    # NOTE(review): presumably so vendored method docs stay visible - confirm.
    __doc__ = bigframes_vendored.sklearn.pipeline.Pipeline.__doc__

    def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
        # Validate a (transform, estimator) pair of named steps and store both.
        # Raises NotImplementedError for any other step count, for a transform
        # type outside the list below, or for an unsupported estimator.
        self.steps = steps

        if len(steps) != 2:
            raise NotImplementedError(
                f"Currently only two step (transform, estimator) pipelines are supported. {constants.FEEDBACK_LINK}"
            )

        transform = steps[0][1]
        estimator = steps[1][1]

        # Transformer types this class accepts as the first step.
        supported_transforms = (
            compose.ColumnTransformer,
            preprocessing.StandardScaler,
            preprocessing.OneHotEncoder,
            preprocessing.MaxAbsScaler,
            preprocessing.MinMaxScaler,
            preprocessing.KBinsDiscretizer,
            preprocessing.LabelEncoder,
            preprocessing.PolynomialFeatures,
            impute.SimpleImputer,
        )
        if not isinstance(transform, supported_transforms):
            raise NotImplementedError(
                f"Transformer type {type(transform)} is not yet supported by Pipeline. {constants.FEEDBACK_LINK}"
            )
        self._transform = transform

        # The estimator must be a TrainablePredictor, and must not be
        # ARIMAPlus: BQML doesn't support ARIMA_PLUS with transformers.
        # b/298676367
        if not isinstance(estimator, base.TrainablePredictor) or isinstance(
            estimator, forecasting.ARIMAPlus
        ):
            raise NotImplementedError(
                f"Estimator type {type(estimator)} is not supported by Pipeline. {constants.FEEDBACK_LINK}"
            )
        self._estimator = estimator

    @classmethod
    def _from_bq(
        cls, session: bigframes.session.Session, bq_model: bigquery.Model
    ) -> Pipeline:
        # Rebuild a Pipeline from an existing BigQuery model: recover the
        # transform step first, then load the estimator step.
        extracted = compose.ColumnTransformer._extract_from_bq_model(bq_model)
        transform_step = ("transform", extracted._merge(bq_model))
        estimator_step = ("estimator", loader._model_from_bq(session, bq_model))
        return cls([transform_step, estimator_step])

    def fit(
        self,
        X: utils.BigFramesArrayType,
        y: Optional[utils.BigFramesArrayType] = None,
    ) -> Pipeline:
        # Compile the transform step to SQL for X and hand it to the
        # estimator's _fit together with the data. Returns self.
        (features,) = utils.batch_convert_to_dataframe(X)
        compiled_transforms = self._transform._compile_to_sql(features)

        labels = y
        if labels is not None:
            # If labels columns are present, they should pass through
            # un-transformed, so their column names are appended as-is.
            (labels,) = utils.batch_convert_to_dataframe(labels)
            compiled_transforms.extend(labels.columns.tolist())

        self._estimator._fit(X=features, y=labels, transforms=compiled_transforms)
        return self

    def predict(self, X: utils.ArrayType) -> bigframes.dataframe.DataFrame:
        # Delegates directly to the estimator; the transform step is not
        # invoked here. NOTE(review): presumably the transforms passed to _fit
        # are applied by the fitted model itself - confirm.
        predictions = self._estimator.predict(X)
        return predictions

    def score(
        self,
        X: utils.BigFramesArrayType,
        y: Optional[utils.BigFramesArrayType] = None,
    ) -> bigframes.dataframe.DataFrame:
        # Convert the inputs to DataFrames and delegate scoring to the
        # estimator.
        (features,) = utils.batch_convert_to_dataframe(X)

        labels = y
        if labels is not None:
            (labels,) = utils.batch_convert_to_dataframe(labels)

        return self._estimator.score(X=features, y=labels)

    def to_gbq(self, model_name: str, replace: bool = False) -> Pipeline:
        """Save the pipeline to BigQuery.

        Args:
            model_name (str):
                The name of the model(pipeline).
            replace (bool, default False):
                Whether to replace if the model(pipeline) already exists.
                Default to False.

        Returns:
            Pipeline: Saved model(pipeline).

        Raises:
            RuntimeError: If the estimator has no fitted BQML model yet.
        """
        fitted_model = self._estimator._bqml_model
        if not fitted_model:
            raise RuntimeError("A model must be fitted before it can be saved")

        # Copy the fitted model under the requested name, then read it back
        # through the copy's session.
        saved_model = fitted_model.copy(model_name, replace)
        return saved_model.session.read_gbq_model(model_name)