Machine learning models often take hours or days to train, especially on large datasets with many features. If your machine crashes or shuts down before you've saved the trained model, you'll lose it and need to re-train it from scratch.
Pickle is a useful Python tool for saving your ML models to disk. It lets you avoid lengthy re-training and makes it easy to share, commit, and re-load pre-trained machine learning models. Most data scientists working in ML use Pickle or Joblib to save their models for future use.
Pickle is Python's built-in, general-purpose object serialization module: it serializes Python objects into a byte stream and deserializes them back again. While it's most commonly associated with saving and reloading trained machine learning models, it works on most Python objects, not just models. Here's how you can use Pickle to save a trained model to a file and reload it to obtain predictions.
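To illustrate the general-purpose side first, here's a minimal sketch of that round trip on an ordinary, made-up dictionary rather than a model. It uses pickle.dumps(), which returns the serialized bytes instead of writing a file, and pickle.loads() to restore them. It also shows that "most objects" isn't "all objects": a generator, for example, can't be pickled.

```python
import pickle

# A hypothetical, non-ML object: picklable Python objects all work the same way
settings = {"name": "demo", "test_size": 0.3, "features": ["age", "bmi", "bp"]}

# Serialize to an in-memory byte stream instead of a file
payload = pickle.dumps(settings)
print(type(payload))  # <class 'bytes'>

# Deserialize the bytes back into an equal but separate object
restored = pickle.loads(payload)
print(restored == settings, restored is settings)  # True False

# Not everything can be pickled: generators, open files and locks are common exceptions
try:
    pickle.dumps(x * x for x in range(3))
except (TypeError, pickle.PicklingError) as exc:
    print("Cannot pickle:", exc)
```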
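First, load up the packages. We're using Pandas, Pickle, the load_diabetes() and train_test_split() functions from scikit-learn's datasets and model_selection modules, and the XGBClassifier model from XGBoost. You can install anything you don't have by entering `pip3 install package-name` into your terminal, for example `pip3 install pandas scikit-learn xgboost`. Pickle is part of the Python standard library, so it doesn't need installing.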
```python
import pandas as pd
import pickle
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
```
Next, we'll load the dataset. scikit-learn's built-in datasets have a useful return_X_y argument that returns the features and the target separately, and adding as_frame=True lets you assign them to an X dataframe and a y series.
```python
X, y = load_diabetes(return_X_y=True, as_frame=True)
X.head()
```
|   | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.038076 | 0.050680 | 0.061696 | 0.021872 | -0.044223 | -0.034821 | -0.043401 | -0.002592 | 0.019908 | -0.017646 |
| 1 | -0.001882 | -0.044642 | -0.051474 | -0.026328 | -0.008449 | -0.019163 | 0.074412 | -0.039493 | -0.068330 | -0.092204 |
| 2 | 0.085299 | 0.050680 | 0.044451 | -0.005671 | -0.045599 | -0.034194 | -0.032356 | -0.002592 | 0.002864 | -0.025930 |
| 3 | -0.089063 | -0.044642 | -0.011595 | -0.036656 | 0.012191 | 0.024991 | -0.036038 | 0.034309 | 0.022692 | -0.009362 |
| 4 | 0.005383 | -0.044642 | -0.036385 | 0.021872 | 0.003935 | 0.015596 | 0.008142 | -0.002592 | -0.031991 | -0.046641 |
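Now that we have our data, we'll use train_test_split() to create our training and test sets. 30% of the data is assigned to the test set and held out from training. We'll then configure a basic XGBClassifier model and train it on the X_train and y_train data. While this is very quick for such a tiny model, training can take hours or days on larger datasets, so you won't want to repeat it too often.

Note that the diabetes target is a continuous measure of disease progression, so this is really a regression problem. XGBClassifier treats each distinct target value as a separate class (hence objective='multi:softprob' in the output below), and newer XGBoost releases raise a ValueError for labels like these, so XGBRegressor is the better choice in practice. The Pickle steps that follow are the same for either model.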
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
model = XGBClassifier(random_state=0)
model.fit(X_train, y_train)
```
```
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.300000012, max_delta_step=0, max_depth=6,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=100, n_jobs=0, num_parallel_tree=1,
              objective='multi:softprob', random_state=0, reg_alpha=0,
              reg_lambda=1, scale_pos_weight=None, subsample=1,
              tree_method='exact', validate_parameters=1, verbosity=None)
```
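To make the 30% hold-out concrete: the diabetes dataset has 442 rows, and scikit-learn rounds a fractional test_size up when sizing the test set, giving the remainder to training. The helper below, split_sizes(), is a hypothetical name rather than a scikit-learn function; it just reproduces that arithmetic. The 133 test rows it predicts match the length of the prediction array returned at the end of this post.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    """Hypothetical helper: turn a fractional test_size into (train, test) row counts.

    The test count is rounded up and the training set gets whatever is left.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


# The scikit-learn diabetes dataset has 442 rows
n_train, n_test = split_sizes(442, 0.3)
print(n_train, n_test)  # 309 133
```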
To save the ML model using Pickle, all we need to do is open a file in binary write mode ('wb') and pass it, along with the model object, to Pickle's dump() function. This serializes the object into a “byte stream” and writes it to a file called model.pkl. You can then store this file, or commit it to Git, and run the model on unseen test data without having to re-train it from scratch.
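```python
# Open in binary write mode; the with block closes the file once the model is written
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
```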
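As a self-contained sketch of the same save step, the code below swaps in a tiny hypothetical MeanModel class for the trained XGBoost model, so it runs without scikit-learn or XGBoost installed; the pickle.dump() call itself is the same for a real model. Two further assumptions: the file goes into a fresh temporary directory rather than your working directory, and it passes pickle.HIGHEST_PROTOCOL, the newest pickle protocol your Python supports. Files written with it may not load on older Python versions.

```python
import os
import pickle
import tempfile


class MeanModel:
    """Hypothetical stand-in for a trained model: always predicts the mean training target."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


model = MeanModel().fit([[1], [2], [3]], [10.0, 20.0, 30.0])

# Write into a new temporary directory instead of the current working directory
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "model.pkl")

# 'wb' = write binary; the with block guarantees the file is closed, even on error
with open(path, "wb") as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

print(os.path.getsize(path) > 0)  # True
```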
To load a saved model from a Pickle file, open it in binary read mode ('rb') and pass it to Pickle's load() function, which deserializes it. By assigning the result to a new model object, you can then call the original model's predict() method, pass in some test data, and get back an array of predictions.
```python
# Open in binary read mode and deserialize the saved model
with open('model.pkl', 'rb') as f:
    pickled_model = pickle.load(f)

pickled_model.predict(X_test)
```
```
array([288., 308., 144., 158.,  48., 258.,  65., 225.,  70., 185.,  88.,
       178., 116.,  42., 144., 200., 185.,  90.,  70., 268., 178., 202.,
        66.,  52., 220., 242.,  87.,  65.,  48., 178., 104.,  83., 101.,
       139., 164., 131.,  91., 202., 104., 242.,  83., 214., 115., 225.,
       178.,  78.,  90., 200., 183., 242., 124.,  72., 113., 164., 233.,
        91., 139., 143., 150., 296., 221., 200., 140., 111., 225.,  96.,
        83., 268., 220.,  72.,  52., 164., 185., 141., 118., 258.,  97.,
       190., 245., 233., 138., 144.,  52., 277.,  77.,  64.,  59., 144.,
        94., 109.,  71.,  66.,  90., 173.,  75., 129., 232., 237.,  69.,
        97.,  91., 143., 273.,  64., 295., 139., 262., 242., 113.,  97.,
       144.,  53., 115., 185.,  96., 168., 164., 200., 151.,  83., 252.,
        53.,  39., 151., 200., 111., 123.,  85., 259., 248., 242.,  52.,
       180.])
```
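A quick sanity check after reloading is to confirm that the restored model gives exactly the same predictions as the original. The sketch below does this with a hypothetical LinearModel stand-in and an in-memory io.BytesIO buffer in place of model.pkl. With the real model, you'd compare model.predict(X_test) against pickled_model.predict(X_test) in the same way. One caution from the Python documentation: unpickling can execute arbitrary code, so only load pickle files from sources you trust.

```python
import io
import pickle


class LinearModel:
    """Hypothetical stand-in for a trained model: y = slope * x + intercept."""

    def __init__(self, slope, intercept):
        self.slope = slope
        self.intercept = intercept

    def predict(self, X):
        return [self.slope * x + self.intercept for x in X]


model = LinearModel(slope=2.0, intercept=1.0)
X_new = [0.0, 1.5, 3.0]

# Serialize into an in-memory buffer (standing in for the model.pkl file)
buffer = io.BytesIO()
pickle.dump(model, buffer)
buffer.seek(0)  # rewind so load() reads from the start

# Deserialize into a new, independent object
pickled_model = pickle.load(buffer)

print(pickled_model is model)  # False
print(pickled_model.predict(X_new) == model.predict(X_new))  # True
```

Pickle stores a reference to the object's class (module and name) rather than the class's code. The environment that loads the file therefore needs the same class or library available, ideally at the same version. For the real example, that means having XGBoost installed wherever model.pkl is loaded.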
Matt Clarke, Saturday, March 06, 2021