Source code for bigframes.ml.remote

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BigFrames general remote models."""

from __future__ import annotations

from typing import Mapping, Optional
import warnings

from bigframes.core import global_session, log_adapter
import bigframes.dataframe
import bigframes.exceptions as bfe
from bigframes.ml import base, core, globals, utils
import bigframes.session

_REMOTE_MODEL_STATUS = "remote_model_status"


@log_adapter.class_logger
class VertexAIModel(base.BaseEstimator):
    """Remote model backed by a Vertex AI HTTPS endpoint.

    The caller provides the HTTPS endpoint together with the input and output
    schemas. For background, see "Deploy model on Vertex AI":
    https://cloud.google.com/bigquery/docs/bigquery-ml-remote-model-tutorial#Deploy-Model-on-Vertex-AI.

    Args:
        endpoint (str):
            Vertex AI HTTPS endpoint.
        input (Mapping):
            Input schema as `{column_name: column_type}`. Supported types are
            "bool", "string", "int64", "float64", "array<bool>",
            "array<string>", "array<int64>", "array<float64>".
        output (Mapping):
            Output label schema as `{column_name: column_type}`. Accepts the
            same types as `input`.
        session (bigframes.Session or None):
            BQ session used to create the model. If None, the global default
            session is used.
        connection_name (str or None):
            Connection to the remote service, formatted as
            <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>. If None,
            the default connection from the session context is used.
            BigQuery DataFrames will try to create the connection and attach
            permissions if the connection isn't fully set up.
    """

    def __init__(
        self,
        endpoint: str,
        input: Mapping[str, str],
        output: Mapping[str, str],
        *,
        session: Optional[bigframes.session.Session] = None,
        connection_name: Optional[str] = None,
    ):
        # The schemas are stored exactly as given here; _create_bqml_model
        # later replaces them with standardized dict copies.
        self.endpoint, self.input, self.output = endpoint, input, output

        # Fall back to the global default session when none is supplied.
        self.session = session or global_session.get_global_session()
        self._bq_connection_manager = self.session.bqconnectionmanager

        # May still be None here; _create_bqml_model overwrites it with the
        # value returned by the session's _create_bq_connection.
        self.connection_name = connection_name

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Resolve the connection, standardize both schemas, and create the remote model."""
        # Parse the connection and create it if needed, passing the
        # "aiplatform.user" IAM role to the session helper.
        self.connection_name = self.session._create_bq_connection(
            connection=self.connection_name, iam_role="aiplatform.user"
        )

        def _standardized(schema):
            # Normalize every declared column type through utils.standardize_type,
            # checked against the supported remote-model dtypes.
            return {
                column: utils.standardize_type(
                    dtype, globals._REMOTE_MODEL_SUPPORTED_DTYPES
                )
                for column, dtype in schema.items()
            }

        self.input = _standardized(self.input)
        self.output = _standardized(self.output)

        return self._bqml_model_factory.create_remote_model(
            session=self.session,
            connection_name=self.connection_name,
            input=self.input,
            output=self.output,
            options={"endpoint": self.endpoint},
        )

    def predict(
        self,
        X: utils.ArrayType,
    ) -> bigframes.dataframe.DataFrame:
        """Run the remote model over the given input data.

        Args:
            X (bigframes.pandas.DataFrame or bigframes.pandas.Series or pandas.DataFrame or pandas.Series):
                Input DataFrame or Series; it must conform to the model's
                input schema.

        Returns:
            bigframes.pandas.DataFrame:
                The input columns followed by the predicted values, including
                the `remote_model_status` column, which is null for rows that
                succeeded.
        """
        [X] = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
        predictions = self._bqml_model.predict(X)

        # Unlike the LLM models, a general remote model reports a null status
        # for successful rows, so any non-null entry marks a failed prediction.
        failed = predictions[_REMOTE_MODEL_STATUS].notna()
        if failed.any():
            warnings.warn(
                bfe.format_message(
                    f"Some predictions failed. Check column {_REMOTE_MODEL_STATUS} for "
                    "detailed status. You may want to filter the failed rows and retry."
                ),
                category=RuntimeWarning,
            )

        return predictions