How to save and load machine learning models using Pickle

Machine learning (ML) models can take days to train. Pickle’s dump() and load() functions let you save a trained model so you can share it or re-use it later without re-training.


Machine learning models often take hours or days to train, especially on large datasets with many features. If your Python session ends, for example because your machine restarts, the trained model held in memory is lost and you’ll need to re-train it from scratch.

Pickle, a module in Python’s standard library, lets you save trained ML models so you can avoid lengthy re-training and share, commit, and re-load pre-trained models. Pickle and Joblib are widely used by data scientists to save ML models for later use.

Pickle is a general-purpose serialization module: it converts a Python object into a byte stream (serializing, or “pickling”) and rebuilds the object from that byte stream (deserializing, or “unpickling”). While it’s most commonly associated with saving and reloading trained machine learning models, it works on most kinds of Python object, not just models. Here’s how you can use Pickle to save a trained model to a file and reload it to obtain predictions.
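To illustrate that Pickle is not limited to models, here is a minimal round trip on an ordinary dictionary. It uses pickle.dumps() and pickle.loads(), the in-memory counterparts of the dump() and load() functions used with files later on. It also shows that not every object can be pickled: a lambda fails, because pickle saves functions by reference to an importable name and a lambda has none. The settings dictionary is purely illustrative data.

```python
import pickle

# Illustrative data: any picklable Python object works, not just models.
settings = {"model": "xgboost", "test_size": 0.3, "features": ["age", "bmi", "bp"]}

data = pickle.dumps(settings)   # serialize to a bytes object ("pickling")
restored = pickle.loads(data)   # rebuild a new, equal object ("unpickling")

# Functions are pickled by name; a lambda has no importable name, so this fails.
try:
    pickle.dumps(lambda x: x * 2)
    lambda_pickled = True
except (pickle.PicklingError, AttributeError):
    lambda_pickled = False

print(type(data).__name__, restored == settings, lambda_pickled)
```

The restored dictionary is equal to the original but is a separate object, which is exactly what you get back when you unpickle a saved model.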

Load packages

First load up the packages. We’re using Pandas, Pickle, the load_diabetes() dataset loader from scikit-learn, the train_test_split() function from scikit-learn’s model_selection module, and the XGBRegressor model from XGBoost. Pickle ships with Python’s standard library; you can install anything else you don’t have by entering pip3 install package-name into your terminal (the package names are pandas, scikit-learn and xgboost).

import pandas as pd
import pickle
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

Load the data

Next we’ll load the dataset. scikit-learn’s built-in dataset loaders have a useful return_X_y argument that returns the features and the target separately; adding as_frame=True makes X a pandas DataFrame and y a pandas Series.

X, y = load_diabetes(return_X_y=True, as_frame=True)
X.head()
age sex bmi bp s1 s2 s3 s4 s5 s6
0 0.038076 0.050680 0.061696 0.021872 -0.044223 -0.034821 -0.043401 -0.002592 0.019908 -0.017646
1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 -0.019163 0.074412 -0.039493 -0.068330 -0.092204
2 0.085299 0.050680 0.044451 -0.005671 -0.045599 -0.034194 -0.032356 -0.002592 0.002864 -0.025930
3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 0.024991 -0.036038 0.034309 0.022692 -0.009362
4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 -0.002592 -0.031991 -0.046641

Train the model

Now that we have our data, we use train_test_split() to create training and test sets: 30% of the rows go into the test set and are held out from training. We then configure a basic XGBRegressor model and train it on X_train and y_train. We use a regressor rather than a classifier because the diabetes target is a continuous measure of disease progression, not a set of class labels. Training takes moments for a dataset this small, but it can take hours or days on larger ones, so you won’t want to repeat it more often than you need to.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
model = XGBRegressor(random_state=0)
model.fit(X_train, y_train)
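It can help to know in advance how many rows end up in each set. As far as I understand scikit-learn’s current behaviour, when test_size is a float and train_size is left unset, train_test_split() rounds the test-set size up and gives the remaining rows to the training set. The sketch below mirrors that arithmetic in plain Python; split_sizes is a hypothetical helper, not part of scikit-learn, and the 442-row count is the size of the diabetes dataset.

```python
import math


def split_sizes(n_samples, test_size):
    """Hypothetical helper mirroring train_test_split's sizing for a float
    test_size with train_size unset: round the test set up, train gets the rest."""
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


# The diabetes dataset has 442 rows; 30% of 442 is 132.6, rounded up to 133.
n_train, n_test = split_sizes(442, 0.3)
print(n_train, n_test)
```

With test_size=0.3 this gives 309 training rows and 133 test rows, so predicting on X_test later returns 133 values.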

Save the model with Pickle

To save the ML model with Pickle, open a file in binary write mode ('wb') and pass it, along with the model object, to Pickle’s dump() function. This serializes the model into a “byte stream” and writes it to a file called model.pkl. You can then store the file, or commit it to Git, and later use it to make predictions on unseen data without re-training the model from scratch.

with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
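Because a model can take hours to train, it can be worth making sure an interrupted save never leaves a half-written model.pkl behind. The sketch below shows one common pattern, not something Pickle does for you: it pickles to a temporary file in the same directory, then swaps that file into place with os.replace(), which is atomic on POSIX systems when both paths are on the same filesystem. It also passes protocol=pickle.HIGHEST_PROTOCOL, the newest pickle format your Python supports, which older Python versions may not be able to read. The save_model name is hypothetical, and a plain dict stands in for the trained model so the example runs without XGBoost.

```python
import os
import pickle
import tempfile


def save_model(obj, path):
    """Hypothetical helper: pickle obj to path, replacing any existing file
    only once the new file has been written completely."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temporary file in the same directory so os.replace()
    # below stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        # Swap the finished file into place in a single step.
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temporary file, then re-raise the error.
        os.remove(tmp_path)
        raise


# Usage: a dict stands in for a trained model; a fitted model is passed the same way.
stand_in_model = {"n_estimators": 100, "weights": [0.5, -1.2, 3.0]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    save_model(stand_in_model, path)
    files_after_save = sorted(os.listdir(tmp))  # only model.pkl should remain
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(files_after_save, restored == stand_in_model)
```

One side effect worth knowing: tempfile.mkstemp() creates the file readable and writable only by its owner, and the saved model.pkl keeps those permissions.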

Load the model from Pickle

To load a saved model from a Pickle file, open the file in binary read mode ('rb') and pass the file object to Pickle’s load() function, which deserializes it. Assigning the result to a new variable gives you back the trained model, so you can call its predict() function, pass in some test data and get back an array of predictions. Only unpickle files from sources you trust, because loading a pickle can execute arbitrary code.

with open('model.pkl', 'rb') as f:
    pickled_model = pickle.load(f)

pickled_model.predict(X_test)
This returns a NumPy array containing one prediction for each of the 133 rows in X_test.
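A quick way to confirm that the reload worked is to check that the original and reloaded models give identical predictions on the same data. The sketch below shows the idea using only the standard library: a hypothetical linear model stored as a dict, plus a hypothetical predict() helper, stand in for the XGBoost model and its predict() method, and the three test rows are made-up illustrative data.

```python
import os
import pickle
import tempfile


def predict(model, rows):
    """Hypothetical stand-in for model.predict(): a linear model stored as a dict."""
    return [sum(w * x for w, x in zip(model["weights"], row)) + model["bias"] for row in rows]


model = {"weights": [2.0, -1.0], "bias": 0.5}   # illustrative "trained" model
X_test = [[1.0, 2.0], [3.0, 0.0], [0.0, 4.0]]   # illustrative test rows

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    with open(path, "wb") as f:
        pickle.dump(model, f)
    with open(path, "rb") as f:
        pickled_model = pickle.load(f)

original_predictions = predict(model, X_test)
reloaded_predictions = predict(pickled_model, X_test)
print(original_predictions == reloaded_predictions)  # the reload changed nothing
```

With the XGBoost model from this article, the equivalent check would compare model.predict(X_test) with pickled_model.predict(X_test), for example using numpy.array_equal(); the two arrays should match exactly.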

Matt Clarke, Saturday, March 06, 2021

Matt Clarke is an Ecommerce and Marketing Director who uses data science to help in his work. Matt has a Master's degree in Internet Retailing (plus two other Master's degrees in different fields) and specialises in the technical side of ecommerce and marketing.