
Save and Load Machine Learning Models in Python with scikit-learn


In this tutorial, let's learn how to save and load a machine learning model in Python with scikit-learn.

Once we create a machine learning model, our job doesn't end there. We can save the model so it can be reused in the future. Either the pickle or the joblib library can be used for this purpose. In both, the dump() function is used to save (serialize) the trained model to a file, and the load() function is used to read it back so it can be used again. The two libraries expose dump() and load() with a very similar interface. Now let's demonstrate how to do it.

Syntax of the dump() method:

pickle.dump(obj, file, protocol=None, *, fix_imports=True, buffer_callback=None)

Parameters:

  • obj: the Python object to be pickled (serialized).
  • file: an open file object (or any object with a write() method) that the pickled object is written to.
  • protocol: an integer selecting the pickle protocol version; if None, the default protocol is used.
  • fix_imports: if True (the default) and protocol is less than 3, pickle tries to map new Python 3 names to the old module names used in Python 2, so the resulting data stream is readable by Python 2. It must be passed as a keyword argument.
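
As a quick illustration of these parameters before the full model example below, here is a minimal sketch that pickles a plain Python dictionary to a file; the file name example.pkl and the dictionary contents are made up for illustration:

Python3

import pickle

# a simple object to serialize
scores = {'alice': 0.91, 'bob': 0.87}

# open the file in binary write mode and dump the object,
# explicitly requesting the highest available protocol
with open('example.pkl', 'wb') as f:
    pickle.dump(scores, f, protocol=pickle.HIGHEST_PROTOCOL)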

Syntax of the load() method:

pickle.load(file, *, fix_imports=True, encoding='ASCII', errors='strict', buffers=None)

The load() method reads the pickled representation of an object from the open file object file and returns the reconstituted object hierarchy.
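
A matching minimal sketch that loads the object back from the hypothetical example.pkl file created above:

Python3

import pickle

# open the file in binary read mode and rebuild the object
with open('example.pkl', 'rb') as f:
    scores = pickle.load(f)

print(scores)  # {'alice': 0.91, 'bob': 0.87}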

Example 1: Saving and loading models using pickle

pickle is Python's built-in way of serializing objects. A trained machine learning model can be serialized (pickled) and the resulting bytes saved to a file. Later, you can load this file to deserialize the model and use it to make new predictions. The example that follows trains a linear regression model: the model is fitted on the training data, and dump() is used to save it, taking the trained model and an open file. The model is then loaded back with load(), predictions are made on the test data, and the root mean squared error metric is used to evaluate them.

Python3




# import packages
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn import metrics
import pickle

# import the dataset
dataset = pd.read_csv('headbrain1.csv')

X = dataset.iloc[:, :-1].values
Y = dataset.iloc[:, -1].values

# train test split
X_train, X_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)

# create and fit a linear regression model
regressor = LinearRegression()
regressor.fit(X_train, y_train)

# save the model to disk
filename = 'linear_model.sav'
with open(filename, 'wb') as file:
    pickle.dump(regressor, file)

# load the model from disk
with open(filename, 'rb') as file:
    load_model = pickle.load(file)

# evaluate the loaded model on the test data
y_pred = load_model.predict(X_test)
print('root mean squared error : ', np.sqrt(
    metrics.mean_squared_error(y_test, y_pred)))


Output:

root mean squared error :  72.11529287182815
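
Because the model now lives on disk, it can also be loaded in a completely separate script or session and used for new predictions. The following is a minimal sketch of such a script; the input value is made up for illustration and assumes headbrain1.csv has a single input column, as in the example above:

Python3

import pickle
import numpy as np

# load the previously saved model
with open('linear_model.sav', 'rb') as file:
    model = pickle.load(file)

# predict on a new, made-up feature value (assumes one input column)
new_data = np.array([[4100]])
print(model.predict(new_data))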

Example 2: Saving and loading models using joblib

Joblib is part of the SciPy ecosystem and provides tools for pipelining Python jobs. In particular, it offers utilities for efficiently saving and loading Python objects that make heavy use of NumPy data structures, which is helpful for machine learning models that store large arrays of parameters or the complete dataset. Let's look at a simple example where we save and load a linear regression model; the same steps are repeated, this time using the joblib library.

Python3




# import packages
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn import metrics
import joblib

# import the dataset
dataset = pd.read_csv('headbrain1.csv')

X = dataset.iloc[:, :-1].values
Y = dataset.iloc[:, -1].values

# train test split
X_train, X_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)

# create and fit a linear regression model
regressor = LinearRegression()
regressor.fit(X_train, y_train)

# save the model to disk (joblib accepts a filename directly)
filename = 'linear_model_2.sav'
joblib.dump(regressor, filename)

# load the model from disk
load_model = joblib.load(filename)

# evaluate the loaded model on the test data
y_pred = load_model.predict(X_test)
print('root mean squared error : ', np.sqrt(
    metrics.mean_squared_error(y_test, y_pred)))


Output:

root mean squared error :  72.11529287182815
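
joblib can also compress the saved file, which is often useful for models that contain large NumPy arrays. Below is a minimal sketch that re-saves the model from the example above with compression; the output file name and compression level are illustrative:

Python3

import joblib

# load the model saved in the example above
regressor = joblib.load('linear_model_2.sav')

# re-save it with compression (compress accepts levels 0-9);
# higher levels give smaller files at the cost of slower saving
joblib.dump(regressor, 'linear_model_compressed.sav', compress=3)

# the compressed file is loaded back exactly the same way
model = joblib.load('linear_model_compressed.sav')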

