BigFrames ML Cross-Validation

This demo shows how to perform cross-validation with bigframes.ml.

1. Prepare Data

import bigframes.pandas as bpd
# Read the penguins table from BigQuery and drop rows with missing values
df = bpd.read_gbq("bigframes-dev.bqml_tutorial.penguins")
df = df.dropna()
df
species island culmen_length_mm culmen_depth_mm flipper_length_mm body_mass_g sex
0 Gentoo penguin (Pygoscelis papua) Biscoe 45.2 16.4 223.0 5950.0 MALE
1 Gentoo penguin (Pygoscelis papua) Biscoe 46.5 14.5 213.0 4400.0 FEMALE
2 Adelie Penguin (Pygoscelis adeliae) Biscoe 37.7 16.0 183.0 3075.0 FEMALE
3 Gentoo penguin (Pygoscelis papua) Biscoe 46.4 15.6 221.0 5000.0 MALE
4 Gentoo penguin (Pygoscelis papua) Biscoe 46.1 13.2 211.0 4500.0 FEMALE
5 Adelie Penguin (Pygoscelis adeliae) Torgersen 43.1 19.2 197.0 3500.0 MALE
6 Gentoo penguin (Pygoscelis papua) Biscoe 45.2 15.8 215.0 5300.0 MALE
7 Adelie Penguin (Pygoscelis adeliae) Dream 36.2 17.3 187.0 3300.0 FEMALE
8 Chinstrap penguin (Pygoscelis antarctica) Dream 46.0 18.9 195.0 4150.0 FEMALE
9 Gentoo penguin (Pygoscelis papua) Biscoe 54.3 15.7 231.0 5650.0 MALE
11 Adelie Penguin (Pygoscelis adeliae) Torgersen 39.5 17.4 186.0 3800.0 FEMALE
12 Gentoo penguin (Pygoscelis papua) Biscoe 42.7 13.7 208.0 3950.0 FEMALE
13 Adelie Penguin (Pygoscelis adeliae) Biscoe 41.0 20.0 203.0 4725.0 MALE
14 Gentoo penguin (Pygoscelis papua) Biscoe 48.5 15.0 219.0 4850.0 FEMALE
15 Chinstrap penguin (Pygoscelis antarctica) Dream 49.6 18.2 193.0 3775.0 MALE
16 Gentoo penguin (Pygoscelis papua) Biscoe 50.8 17.3 228.0 5600.0 MALE
17 Gentoo penguin (Pygoscelis papua) Biscoe 46.2 14.1 217.0 4375.0 FEMALE
18 Adelie Penguin (Pygoscelis adeliae) Biscoe 38.8 17.2 180.0 3800.0 MALE
19 Chinstrap penguin (Pygoscelis antarctica) Dream 51.0 18.8 203.0 4100.0 MALE
20 Gentoo penguin (Pygoscelis papua) Biscoe 42.9 13.1 215.0 5000.0 FEMALE
21 Gentoo penguin (Pygoscelis papua) Biscoe 50.4 15.3 224.0 5550.0 MALE
22 Gentoo penguin (Pygoscelis papua) Biscoe 49.0 16.1 216.0 5550.0 MALE
23 Gentoo penguin (Pygoscelis papua) Biscoe 43.4 14.4 218.0 4600.0 FEMALE
24 Gentoo penguin (Pygoscelis papua) Biscoe 45.0 15.4 220.0 5050.0 MALE
25 Gentoo penguin (Pygoscelis papua) Biscoe 47.5 14.0 212.0 4875.0 FEMALE

25 rows × 7 columns

[334 rows x 7 columns in total]
# Select the feature columns (X) and the regression target, body mass in grams (y)
X = df[["species", "island", "culmen_length_mm"]]
y = df["body_mass_g"]

2.1 Create a KFold Instance and Train/Test on Each Fold (Manual Approach)

from bigframes.ml import linear_model, model_selection

# n_splits sets the number of folds. With n_splits=5, the dataset is split into
# 5 parts; in each iteration, 4 parts are used for training and the remaining
# part is used for evaluation, so every row is evaluated exactly once.
kf = model_selection.KFold(n_splits=5)
for X_train, X_test, y_train, y_test in kf.split(X, y):
    # Train a fresh model on this fold's training data
    model = linear_model.LinearRegression()
    model.fit(X_train, y_train)
    # Evaluate on the held-out fold; score() returns a one-row DataFrame of metrics
    score = model.score(X_test, y_test)
    print(score)
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0            297.36838       148892.914876                0.009057   

   median_absolute_error  r2_score  explained_variance  
0             238.424052  0.814613            0.816053  

[1 rows x 6 columns]
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0             307.6149       139013.303482                0.007907   

   median_absolute_error  r2_score  explained_variance  
0             266.589811  0.782835            0.794297  

[1 rows x 6 columns]
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           348.412701       180661.063512                 0.01125   

   median_absolute_error  r2_score  explained_variance  
0              313.29406  0.744053             0.74537  

[1 rows x 6 columns]
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           309.991882       151820.705254                0.008898   

   median_absolute_error  r2_score  explained_variance  
0             212.758708  0.694001            0.694287  

[1 rows x 6 columns]
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           256.569216       103495.042886                0.006605   

   median_absolute_error  r2_score  explained_variance  
0             222.940815  0.818589            0.832344  

[1 rows x 6 columns]
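To make the fold mechanics of the manual loop above concrete, the sketch below reproduces the splitting idea in plain Python: row positions are cut into `n_splits` nearly equal folds, and each fold serves once as the test set while the remaining folds form the training set. It follows the scikit-learn `KFold` convention (fold sizes differ by at most one row, optional shuffling). `kfold_indices` is a hypothetical helper, not part of bigframes, and it is an assumption that bigframes assigns rows to folds the same way; the differing per-fold scores between sections 2.1 and 2.2 suggest the actual fold membership is not fixed, so treat this only as an illustration of the technique.

```python
import random
from typing import Iterator, List, Optional, Tuple


def kfold_indices(
    n_samples: int, n_splits: int = 5, shuffle: bool = False, seed: Optional[int] = None
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_positions, test_positions) for each of n_splits folds.

    Hypothetical helper, not part of bigframes: it mimics scikit-learn's KFold,
    where fold sizes differ by at most one row and each row is tested exactly once.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    positions = list(range(n_samples))
    if shuffle:
        # Optional: randomize fold membership (reproducible for a fixed seed).
        random.Random(seed).shuffle(positions)
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds each take one extra row from the remainder.
        stop = start + base + (1 if fold < extra else 0)
        test = positions[start:stop]
        train = positions[:start] + positions[stop:]
        yield train, test
        start = stop


# 334 rows (the size of the cleaned penguins table) split into 5 folds.
for train, test in kfold_indices(334, n_splits=5):
    print(f"train={len(train)} test={len(test)}")
```

With 334 rows and 5 folds, each iteration trains on about 80% of the rows and tests on the remaining 20%, matching the "4 pieces for training, 1 for evaluation" description in the code comment above.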

2.2 Use the cross_validate Function (Automatic Approach)

# model_selection.cross_validate automates the loop from section 2.1. The returned
# dict holds the per-fold evaluation results under "test_score", plus the per-fold
# training and scoring times under "fit_time" and "score_time".
model = linear_model.LinearRegression()
scores = model_selection.cross_validate(model, X, y, cv=5)
scores
{'test_score': [   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
  0           322.341485       157616.627179                0.009137   
  
     median_absolute_error  r2_score  explained_variance  
  0             269.412639  0.705594            0.724882  
  
  [1 rows x 6 columns],
     mean_absolute_error  mean_squared_error  mean_squared_log_error  \
  0           289.682121       136550.318797                 0.00878   
  
     median_absolute_error  r2_score  explained_variance  
  0             212.874686  0.799363             0.81416  
  
  [1 rows x 6 columns],
     mean_absolute_error  mean_squared_error  mean_squared_log_error  \
  0           325.358522       155218.752974                0.009606   
  
     median_absolute_error  r2_score  explained_variance  
  0             267.301671  0.777174              0.7782  
  
  [1 rows x 6 columns],
     mean_absolute_error  mean_squared_error  mean_squared_log_error  \
  0           286.874056       120586.575364                0.007484   
  
     median_absolute_error  r2_score  explained_variance  
  0             247.656578   0.79281            0.796001  
  
  [1 rows x 6 columns],
     mean_absolute_error  mean_squared_error  mean_squared_log_error  \
  0           287.989397       145947.465344                0.008447   
  
     median_absolute_error  r2_score  explained_variance  
  0             186.777549  0.791452            0.798825  
  
  [1 rows x 6 columns]],
 'fit_time': [18.79181448201416,
  19.092008439009078,
  75.7446747609647,
  17.520530884969048,
  21.157033596013207],
 'score_time': [4.247669544012751,
  6.792615927988663,
  4.502274781989399,
  4.484583999030292,
  4.224339194013737]}
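Cross-validation results are usually reported as the mean score across folds together with its spread. The sketch below summarizes the `r2_score` column of the `test_score` output above in plain Python. The values are copied by hand from the printed (rounded) output rather than read from the live `scores` object; with live results, each one-row DataFrame could first be brought into memory, for example with `to_pandas()`, before aggregating.

```python
from statistics import mean, stdev

# r2_score of each fold, copied by hand from the cross_validate output above.
r2_per_fold = [0.705594, 0.799363, 0.777174, 0.79281, 0.791452]

r2_mean = mean(r2_per_fold)
r2_std = stdev(r2_per_fold)  # sample standard deviation across the 5 folds
print(f"R^2 across folds: {r2_mean:.4f} +/- {r2_std:.4f}")
```

The same summary can be applied to any other metric column, such as `mean_absolute_error`, to see how stable the model is from fold to fold.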