{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BigFrames ML Cross-Validation\n", "\n", "This demo shows how to perform k-fold cross-validation with `bigframes.ml`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Prepare Data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import bigframes.pandas as bpd" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Query job aa2b9845-0e66-4f42-a360-ffe03215caf6 is DONE. 0 Bytes processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job fe2bc354-672e-4d08-b969-bb2ede299fca is DONE. 28.9 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 8d16fa20-391f-4917-86fc-1a595dba3fc6 is DONE. 33.6 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
speciesislandculmen_length_mmculmen_depth_mmflipper_length_mmbody_mass_gsex
0Gentoo penguin (Pygoscelis papua)Biscoe45.216.4223.05950.0MALE
1Gentoo penguin (Pygoscelis papua)Biscoe46.514.5213.04400.0FEMALE
2Adelie Penguin (Pygoscelis adeliae)Biscoe37.716.0183.03075.0FEMALE
3Gentoo penguin (Pygoscelis papua)Biscoe46.415.6221.05000.0MALE
4Gentoo penguin (Pygoscelis papua)Biscoe46.113.2211.04500.0FEMALE
5Adelie Penguin (Pygoscelis adeliae)Torgersen43.119.2197.03500.0MALE
6Gentoo penguin (Pygoscelis papua)Biscoe45.215.8215.05300.0MALE
7Adelie Penguin (Pygoscelis adeliae)Dream36.217.3187.03300.0FEMALE
8Chinstrap penguin (Pygoscelis antarctica)Dream46.018.9195.04150.0FEMALE
9Gentoo penguin (Pygoscelis papua)Biscoe54.315.7231.05650.0MALE
11Adelie Penguin (Pygoscelis adeliae)Torgersen39.517.4186.03800.0FEMALE
12Gentoo penguin (Pygoscelis papua)Biscoe42.713.7208.03950.0FEMALE
13Adelie Penguin (Pygoscelis adeliae)Biscoe41.020.0203.04725.0MALE
14Gentoo penguin (Pygoscelis papua)Biscoe48.515.0219.04850.0FEMALE
15Chinstrap penguin (Pygoscelis antarctica)Dream49.618.2193.03775.0MALE
16Gentoo penguin (Pygoscelis papua)Biscoe50.817.3228.05600.0MALE
17Gentoo penguin (Pygoscelis papua)Biscoe46.214.1217.04375.0FEMALE
18Adelie Penguin (Pygoscelis adeliae)Biscoe38.817.2180.03800.0MALE
19Chinstrap penguin (Pygoscelis antarctica)Dream51.018.8203.04100.0MALE
20Gentoo penguin (Pygoscelis papua)Biscoe42.913.1215.05000.0FEMALE
21Gentoo penguin (Pygoscelis papua)Biscoe50.415.3224.05550.0MALE
22Gentoo penguin (Pygoscelis papua)Biscoe49.016.1216.05550.0MALE
23Gentoo penguin (Pygoscelis papua)Biscoe43.414.4218.04600.0FEMALE
24Gentoo penguin (Pygoscelis papua)Biscoe45.015.4220.05050.0MALE
25Gentoo penguin (Pygoscelis papua)Biscoe47.514.0212.04875.0FEMALE
\n", "

25 rows × 7 columns

\n", "
[334 rows x 7 columns in total]" ], "text/plain": [ " species island culmen_length_mm \\\n", "0 Gentoo penguin (Pygoscelis papua) Biscoe 45.2 \n", "1 Gentoo penguin (Pygoscelis papua) Biscoe 46.5 \n", "2 Adelie Penguin (Pygoscelis adeliae) Biscoe 37.7 \n", "3 Gentoo penguin (Pygoscelis papua) Biscoe 46.4 \n", "4 Gentoo penguin (Pygoscelis papua) Biscoe 46.1 \n", "5 Adelie Penguin (Pygoscelis adeliae) Torgersen 43.1 \n", "6 Gentoo penguin (Pygoscelis papua) Biscoe 45.2 \n", "7 Adelie Penguin (Pygoscelis adeliae) Dream 36.2 \n", "8 Chinstrap penguin (Pygoscelis antarctica) Dream 46.0 \n", "9 Gentoo penguin (Pygoscelis papua) Biscoe 54.3 \n", "11 Adelie Penguin (Pygoscelis adeliae) Torgersen 39.5 \n", "12 Gentoo penguin (Pygoscelis papua) Biscoe 42.7 \n", "13 Adelie Penguin (Pygoscelis adeliae) Biscoe 41.0 \n", "14 Gentoo penguin (Pygoscelis papua) Biscoe 48.5 \n", "15 Chinstrap penguin (Pygoscelis antarctica) Dream 49.6 \n", "16 Gentoo penguin (Pygoscelis papua) Biscoe 50.8 \n", "17 Gentoo penguin (Pygoscelis papua) Biscoe 46.2 \n", "18 Adelie Penguin (Pygoscelis adeliae) Biscoe 38.8 \n", "19 Chinstrap penguin (Pygoscelis antarctica) Dream 51.0 \n", "20 Gentoo penguin (Pygoscelis papua) Biscoe 42.9 \n", "21 Gentoo penguin (Pygoscelis papua) Biscoe 50.4 \n", "22 Gentoo penguin (Pygoscelis papua) Biscoe 49.0 \n", "23 Gentoo penguin (Pygoscelis papua) Biscoe 43.4 \n", "24 Gentoo penguin (Pygoscelis papua) Biscoe 45.0 \n", "25 Gentoo penguin (Pygoscelis papua) Biscoe 47.5 \n", "\n", " culmen_depth_mm flipper_length_mm body_mass_g sex \n", "0 16.4 223.0 5950.0 MALE \n", "1 14.5 213.0 4400.0 FEMALE \n", "2 16.0 183.0 3075.0 FEMALE \n", "3 15.6 221.0 5000.0 MALE \n", "4 13.2 211.0 4500.0 FEMALE \n", "5 19.2 197.0 3500.0 MALE \n", "6 15.8 215.0 5300.0 MALE \n", "7 17.3 187.0 3300.0 FEMALE \n", "8 18.9 195.0 4150.0 FEMALE \n", "9 15.7 231.0 5650.0 MALE \n", "11 17.4 186.0 3800.0 FEMALE \n", "12 13.7 208.0 3950.0 FEMALE \n", "13 20.0 203.0 4725.0 MALE \n", "14 15.0 219.0 
4850.0 FEMALE \n", "15 18.2 193.0 3775.0 MALE \n", "16 17.3 228.0 5600.0 MALE \n", "17 14.1 217.0 4375.0 FEMALE \n", "18 17.2 180.0 3800.0 MALE \n", "19 18.8 203.0 4100.0 MALE \n", "20 13.1 215.0 5000.0 FEMALE \n", "21 15.3 224.0 5550.0 MALE \n", "22 16.1 216.0 5550.0 MALE \n", "23 14.4 218.0 4600.0 FEMALE \n", "24 15.4 220.0 5050.0 MALE \n", "25 14.0 212.0 4875.0 FEMALE \n", "...\n", "\n", "[334 rows x 7 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Read the penguins table and drop rows with missing values\n", "df = bpd.read_gbq(\"bigframes-dev.bqml_tutorial.penguins\")\n", "df = df.dropna()\n", "df" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Select X and y from the dataset\n", "X = df[\n", " [\n", " \"species\",\n", " \"island\",\n", " \"culmen_length_mm\",\n", " ]\n", " ]\n", "y = df[\"body_mass_g\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.1 Create a KFold Splitter and Train/Test on Each Fold (Manual Approach)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from bigframes.ml import model_selection, linear_model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Create a KFold instance. n_splits sets how many folds the data is split into; for example, n_splits=5 splits the entire dataset into 5 pieces.\n", "# In each fold, 4 pieces are used for training and the remaining piece is used for evaluation.\n", "kf = model_selection.KFold(n_splits=5)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Query job 9ce9fb43-306d-46e9-bbe5-d98ee55143bd is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 8c86156d-ee97-4f66-9dc1-db15ff3d8e8e is DONE. 16.4 kB processed. 
Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job b8f2b382-b938-4dff-8bdb-129703ade285 is DONE. 37.3 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", "0 297.36838 148892.914876 0.009057 \n", "\n", " median_absolute_error r2_score explained_variance \n", "0 238.424052 0.814613 0.816053 \n", "\n", "[1 rows x 6 columns]\n" ] }, { "data": { "text/html": [ "Query job ec2968f3-1713-4617-8a26-6fe4267f8061 is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job c7a1b80f-26f5-41b1-bcdc-b276af141671 is DONE. 16.4 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 82054991-c22f-41b3-9802-f16919949e26 is DONE. 37.3 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", "0 307.6149 139013.303482 0.007907 \n", "\n", " median_absolute_error r2_score explained_variance \n", "0 266.589811 0.782835 0.794297 \n", "\n", "[1 rows x 6 columns]\n" ] }, { "data": { "text/html": [ "Query job 3e5ae019-7c5b-44ea-8392-85145fdb6802 is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job c35dfd28-504d-4d12-b039-da890b9cb51d is DONE. 16.5 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 29ac1bb3-f864-400e-8cac-0b4c7f78ebcd is DONE. 37.3 kB processed. 
Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", "0 348.412701 180661.063512 0.01125 \n", "\n", " median_absolute_error r2_score explained_variance \n", "0 313.29406 0.744053 0.74537 \n", "\n", "[1 rows x 6 columns]\n" ] }, { "data": { "text/html": [ "Query job d90f5938-2894-4c93-8691-21162a2fca4c is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 4c6328b3-2d3f-42bb-9f83-4f8c84773c95 is DONE. 16.4 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 8a885a6a-d3ad-4569-80ce-4f57d9b86105 is DONE. 37.3 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", "0 309.991882 151820.705254 0.008898 \n", "\n", " median_absolute_error r2_score explained_variance \n", "0 212.758708 0.694001 0.694287 \n", "\n", "[1 rows x 6 columns]\n" ] }, { "data": { "text/html": [ "Query job d1e60370-11c8-4f49-a8d5-85417662aa51 is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job d8e8712a-6347-4725-a27d-49810d4acc1c is DONE. 16.5 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 6a0ebaa6-5572-404f-a41d-b90e2c65d948 is DONE. 37.3 kB processed. 
Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", "0 256.569216 103495.042886 0.006605 \n", "\n", " median_absolute_error r2_score explained_variance \n", "0 222.940815 0.818589 0.832344 \n", "\n", "[1 rows x 6 columns]\n" ] } ], "source": [ "for X_train, X_test, y_train, y_test in kf.split(X, y):\n", " model = linear_model.LinearRegression()\n", " model.fit(X_train, y_train)\n", " score = model.score(X_test, y_test)\n", "\n", " print(score)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2.2 Use cross_validate Function to Do Cross Validation (Automatic Approach)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Query job 5bdcd65d-7d72-4094-be3a-cf67a1787cf4 is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job bb0504b2-b656-4a08-9bf8-dcab0d188022 is DONE. 16.4 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 8c5c4b66-9a14-455a-a3f5-99f0f522713f is DONE. 37.3 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 9c9b81de-35b6-4561-8881-57da8b73cc7f is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job b781f1aa-6572-49e5-ab8d-f1908b497a1c is DONE. 16.4 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 41a2a58e-0289-4d58-8e39-de286f2a91fb is DONE. 37.3 kB processed. 
Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 7ee839a9-f77c-49b0-844e-8eecc1647b97 is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job a317d488-8589-4faa-940b-e59af91caf4d is DONE. 16.5 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 2de96ea8-519a-4976-a641-eb26a4bd38fb is DONE. 37.3 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 41a7d5a0-c76b-4ef3-a3da-d4d5a2ebbb0e is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 9e82ddc9-8461-4644-ba34-957a7426ff8e is DONE. 16.4 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job 0fa84d07-fdfa-41c9-b601-9326a94f3a09 is DONE. 37.3 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job d4495568-f1b5-431b-b892-4fc7dcbccfd5 is DONE. 37.0 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job af1e6460-3078-4a8b-8992-9e7df9dcfbb3 is DONE. 16.5 kB processed. Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Query job f14401bf-fd80-401a-a61d-52614fba1ca7 is DONE. 37.3 kB processed. 
Open Job" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'test_score': [ mean_absolute_error mean_squared_error mean_squared_log_error \\\n", " 0 322.341485 157616.627179 0.009137 \n", " \n", " median_absolute_error r2_score explained_variance \n", " 0 269.412639 0.705594 0.724882 \n", " \n", " [1 rows x 6 columns],\n", " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", " 0 289.682121 136550.318797 0.00878 \n", " \n", " median_absolute_error r2_score explained_variance \n", " 0 212.874686 0.799363 0.81416 \n", " \n", " [1 rows x 6 columns],\n", " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", " 0 325.358522 155218.752974 0.009606 \n", " \n", " median_absolute_error r2_score explained_variance \n", " 0 267.301671 0.777174 0.7782 \n", " \n", " [1 rows x 6 columns],\n", " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", " 0 286.874056 120586.575364 0.007484 \n", " \n", " median_absolute_error r2_score explained_variance \n", " 0 247.656578 0.79281 0.796001 \n", " \n", " [1 rows x 6 columns],\n", " mean_absolute_error mean_squared_error mean_squared_log_error \\\n", " 0 287.989397 145947.465344 0.008447 \n", " \n", " median_absolute_error r2_score explained_variance \n", " 0 186.777549 0.791452 0.798825 \n", " \n", " [1 rows x 6 columns]],\n", " 'fit_time': [18.79181448201416,\n", " 19.092008439009078,\n", " 75.7446747609647,\n", " 17.520530884969048,\n", " 21.157033596013207],\n", " 'score_time': [4.247669544012751,\n", " 6.792615927988663,\n", " 4.502274781989399,\n", " 4.484583999030292,\n", " 4.224339194013737]}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# By using model_selection.cross_validate, the above 2.1 process is automated. 
The returned scores contain the evaluation results for each fold (test_score), along with per-fold fit_time and score_time.\n", "model = linear_model.LinearRegression()\n", "scores = model_selection.cross_validate(model, X, y, cv=5)\n", "scores" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "venv (3.10.14)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }