Source code for bigframes.ml.impute
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers for missing value imputation. This module is styled after
scikit-learn's impute module: https://scikit-learn.org/stable/modules/impute.html."""
from __future__ import annotations
import typing
from typing import Iterable, List, Literal, Optional
import bigframes_vendored.sklearn.impute._base
from bigframes.core import log_adapter
import bigframes.core.utils as core_utils
from bigframes.ml import base, core, globals, utils
import bigframes.pandas as bpd
[docs]
@log_adapter.class_logger
class SimpleImputer(
    base.Transformer,
    bigframes_vendored.sklearn.impute._base.SimpleImputer,
):
    __doc__ = bigframes_vendored.sklearn.impute._base.SimpleImputer.__doc__

    def __init__(
        self,
        strategy: Literal["mean", "median", "most_frequent"] = "mean",
    ):
        # The strategy is stored as given (not validated here); _compile_to_sql
        # passes it to the SQL generator for every transformed column.
        self.strategy = strategy
        # Set by fit(); remains None until the transformer has been fitted.
        self._bqml_model: Optional[core.BqmlModel] = None
        self._bqml_model_factory = globals.bqml_model_factory()
        self._base_sql_generator = globals.base_sql_generator()

    def _keys(self):
        """Return the (fitted BQML model, strategy) tuple describing this transformer."""
        # NOTE(review): presumably consumed by the base class for equality /
        # hashing of transformers — confirm against base.Transformer.
        return (self._bqml_model, self.strategy)

    def _compile_to_sql(
        self,
        X: bpd.DataFrame,
        columns: Optional[Iterable[str]] = None,
    ) -> List[str]:
        """Compile this transformer to SQL expressions for a BQML TRANSFORM clause.

        Args:
            X: DataFrame to transform.
            columns: Labels of the columns to transform. If None, all columns
                of X are transformed.

        Returns:
            A list of SQL expression strings, one per column, each built by the
            base SQL generator's ``ml_imputer`` from the standardized column id,
            this transformer's strategy, and the output name
            ``imputer_{column}``.
        """
        if columns is None:
            columns = X.columns
        # Convert the labels to standardized column ids before they are used in
        # SQL; the second element of the returned pair is not needed here.
        columns, _ = core_utils.get_standardized_ids(columns)
        return [
            self._base_sql_generator.ml_imputer(
                column, self.strategy, f"imputer_{column}"
            )
            for column in columns
        ]

    @classmethod
    def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
        """Parse SQL to tuple(SimpleImputer, column_label).

        Args:
            sql: SQL string of format "ML.IMPUTER({col_label}, {strategy}) OVER()"

        Returns:
            tuple(SimpleImputer, column_label)

        Raises:
            ValueError: if the argument list does not split into exactly two
                parts on ", ".
        """
        # Take the text between the first "(" and the first ")", i.e. the
        # argument list "{col_label}, {strategy}".
        s = sql[sql.find("(") + 1 : sql.find(")")]
        # Tuple unpacking raises ValueError unless there are exactly two parts.
        col_label, strategy = s.split(", ")
        # strategy[1:-1] drops the first and last characters — presumably the
        # quotes around the strategy literal.
        return cls(strategy[1:-1]), _unescape_id(col_label)  # type: ignore[arg-type]

    def fit(
        self,
        X: utils.ArrayType,
        y=None,  # ignored
    ) -> SimpleImputer:
        # Convert the array-like input into a single DataFrame.
        (X,) = utils.batch_convert_to_dataframe(X)

        transform_sqls = self._compile_to_sql(X)
        # Create a "transform_only" model whose TRANSFORM list is the
        # per-column imputer expressions compiled above.
        self._bqml_model = self._bqml_model_factory.create_model(
            X,
            options={"model_type": "transform_only"},
            transforms=transform_sqls,
        )
        # NOTE(review): helper defined outside this class; presumably records
        # the fitted model's output column names on self — confirm.
        self._extract_output_names()
        return self
def _unescape_id(id: str) -> str:
"""Very simple conversion to removed ` characters from ids.
A proper sql parser should be used instead.
"""
return id.removeprefix("`").removesuffix("`")