How to save and load a machine learning model in PyTorch
The main question running through your mind is probably this: Why should I save and load my machine learning model? You’re likely asking because you’ve worked with small datasets and simple models that train in seconds. But the real world runs on big data and deep neural networks, and training them can take hours or days. Saving your model means you don’t have to retrain it every time you want to use it, share it, or pick up training again. It’s a simple and useful PyTorch feature, and one you should incorporate into your future projects.
When it comes to saving a model, you have two options: save the best or save the latest. Your end goal in any project you take up is quite simple: get the best model. To save the best, you track a validation metric during training and save the model’s state only when that metric improves, so the checkpoint on disk is always the best one seen so far. PyTorch doesn’t do this tracking for you; it takes a few lines in your own training loop. Now you’re probably thinking, “Great! Why do I even need the latest option?” Well, because life’s unpredictable. Your training run may crash partway through. You could lose connection to your working environment. Your laptop could burst into flames (okay, maybe not, but you get the drift). Saving the latest state of your model lets you load it later and continue from where you left off, which is very useful when you’re dealing with large datasets or complex models.
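Here’s a minimal sketch of both strategies inside one training loop. The tiny nn.Linear model, the random data, and the file names latest.pt and best.pt are hypothetical placeholders; the pattern is what matters: overwrite the latest checkpoint every epoch, and overwrite the best one only when the validation loss improves.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup: a tiny linear model and random data stand in for a real project.
torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(64, 10), torch.randn(64, 1)
x_val, y_val = torch.randn(16, 10), torch.randn(16, 1)

best_val_loss = float("inf")
for epoch in range(5):
    # One full-batch training step per epoch (kept small for illustration).
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    # Measure validation loss without tracking gradients.
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()

    state = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "val_loss": val_loss,
    }
    # "Save the latest": overwrite every epoch, so a crash loses at most one epoch of work.
    torch.save(state, "latest.pt")
    # "Save the best": overwrite only when the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(state, "best.pt")
```

Keeping the two files separate means a crash never costs you the best model, and the best model never blocks you from resuming the most recent run.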
Saving your model in PyTorch
So how do you actually save your model? Believe it or not, it takes just one line of code:
```python
# Saving a checkpoint
torch.save(state, path)
```
Quite simple, right? Let’s break down the parameters, starting with the easy one. The path is simply where you want to save your model, given as a file path (torch.save also accepts a file-like object). For large-scale industrial projects, I would recommend keeping your checkpoints in cloud storage rather than only on your local machine. The state parameter, on the other hand, depends on what you really want to save.
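To make the path side concrete, here’s a minimal sketch assuming a hypothetical checkpoints directory and a toy nn.Linear model. It shows that torch.save accepts either a filesystem path or a file-like object such as io.BytesIO; the in-memory bytes are one way to get something you could hand to your cloud provider’s upload API (the upload itself isn’t shown, since it depends on the provider).

```python
import io
from pathlib import Path

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# 1) Save to a local file. The directory name "checkpoints" is a hypothetical choice.
ckpt_dir = Path("checkpoints")
ckpt_dir.mkdir(parents=True, exist_ok=True)
local_path = ckpt_dir / "model.pt"
torch.save(model.state_dict(), local_path)

# 2) Save to an in-memory buffer instead. torch.save accepts file-like objects,
#    so these bytes could be passed to a cloud SDK's upload call (not shown).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
payload = buffer.getvalue()

# Reading the bytes back works the same way: wrap them in a buffer and load it.
restored = torch.load(io.BytesIO(payload))
```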
- If you just want to save the model parameters, use model.state_dict(), a dictionary that maps each layer to its learned parameters and buffers.
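- If you want to save the entire model, just pass model itself. This pickles the whole object, so the serialized data is bound to the specific classes and the exact directory structure used when the model was saved; loading it can break if you later rename or move that code.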
- If you want to save the model to resume training later, you need to save more than just the model: you also need the state of the optimizer, the current epoch, the best score so far, and so on. This is where you declare the state as a dictionary, as shown in the code snippet below:
```python
state = {
    'state_dict': model.state_dict(),
    'epoch': epoch,
    'optimizer': optimizer.state_dict(),
    # ... any other values you want to keep
}
torch.save(state, filepath)
```
Loading your model
Great! Now that you’ve saved the model, how do you go about loading it?
```python
# Loading a checkpoint
state = torch.load(path)
```
Simple again. What you do next depends on what you saved:
- If you saved just the model parameters, first create an instance of the same model class, then call model.load_state_dict(state).
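- If you saved the entire model, torch.load returns the model object itself, so there’s no extra step and you’re done! Keep in mind that the model’s class definition must be importable when you load it, and from PyTorch 2.6 onward torch.load defaults to weights_only=True, so you need to pass weights_only=False (only do this for files you trust).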
- If you saved a checkpoint to resume training later, the process is basically the same as above, except that the state is now a dictionary and each entry has to be restored separately. The model and optimizer must already exist, created the same way as when you saved them:
```python
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
epoch = state['epoch']  # epoch is a plain int, so read it back directly
```
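Putting the resume-training steps together, here’s a minimal end-to-end sketch. The small nn.Sequential model, the Adam optimizer, the make_model helper, and the file name resume.pt are hypothetical stand-ins for your own; the key detail is that you rebuild the model and optimizer exactly as before and then load their saved states into them.

```python
import torch
import torch.nn as nn

def make_model():
    # Hypothetical architecture; any nn.Module is restored the same way.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

torch.manual_seed(0)
model = make_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so the optimizer holds internal state (Adam's running averages).
x, y = torch.randn(32, 8), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch = 3  # pretend we just finished epoch 3
torch.save({"state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch}, "resume.pt")

# --- Later, possibly in a new process ---
# Recreate the same architecture and optimizer first; load_state_dict fills them in.
new_model = make_model()
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
state = torch.load("resume.pt")
new_model.load_state_dict(state["state_dict"])
new_optimizer.load_state_dict(state["optimizer"])
start_epoch = state["epoch"] + 1  # continue with the epoch after the saved one

new_model.train()  # switch back to training mode before continuing the loop
```

Restoring the optimizer state matters for optimizers like Adam, which keep per-parameter running averages; without it, training would resume with those estimates reset.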
And there you have it, folks. This tutorial has given you the tools you need for saving and loading a model in PyTorch, a useful feature that you should definitely use in your projects.
Check out the official PyTorch tutorial on saving and loading models: https://pytorch.org/tutorials/beginner/saving_loading_models.html