PyTorch torch.save and torch.load Functions
torch.save and torch.load are functions in PyTorch used for serializing (saving) and deserializing (loading) tensors, models, and other Python objects.
These functions are essential in scenarios such as saving trained models and saving checkpoints to resume training.
Function Definitions
torch.save(obj, f, pickle_module, pickle_protocol)
torch.load(f, map_location, pickle_module, weights_only)
torch.save Parameters:
- obj: The object to save, such as a tensor, model, dictionary, or list.
- f: A file path (string or os.PathLike) or a writable file-like object.
- pickle_module (optional): The module used for pickling; defaults to Python's pickle.
- pickle_protocol (optional): The pickle protocol version to use.
torch.load Parameters:
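- f: A file path (string or os.PathLike) or a readable, seekable file-like object.
- map_location (optional): How to remap storage locations, given as a string (e.g. 'cpu'), a torch.device, a dict, or a function.
- pickle_module (optional): The module used for unpickling; it must match the one used when saving and cannot be combined with weights_only=True.
- weights_only (bool, optional): If True, the unpickler only accepts tensors, primitive types, dictionaries, and a small set of allowlisted types, which makes loading untrusted files safer. The default changed from False to True in PyTorch 2.6.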
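Because f also accepts a file-like object, a save/load round trip can be done entirely in memory. The sketch below assumes a PyTorch release that supports the weights_only argument (1.13 or later); the dictionary contents are arbitrary illustration data, and pickle_protocol is passed only to show where the parameter goes.

```python
import io
import torch

# f can be a file-like object instead of a path; here an in-memory buffer
buffer = io.BytesIO()
data = {"step": 3, "weights": torch.ones(2, 2)}

# pickle_protocol is optional; passed explicitly here only for illustration
torch.save(data, buffer, pickle_protocol=2)

buffer.seek(0)  # rewind to the start before reading
restored = torch.load(buffer, map_location="cpu", weights_only=True)

print(restored["step"])                                    # 3
print(torch.equal(restored["weights"], data["weights"]))   # True
```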
Usage Examples
Example 1: Save and Load Tensors
```python
import torch

# Create some tensors
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.randn(3, 4)

# Save both tensors to one file as a dictionary
torch.save({'x': x, 'y': y}, 'tensors.pth')

# Load tensors from file
loaded = torch.load('tensors.pth')
print("Loaded data:", loaded)
print("x:", loaded['x'])
print("y:", loaded['y'])
```
Output:
```
Loaded data: {'x': tensor([1, 2, 3, 4, 5]), 'y': tensor([[-0.3042, -0.9077, -1.0826,  0.9333],
        [ 0.0551,  0.6728,  0.5942, -0.1522],
        [-0.3744,  0.9239, -0.2104, -0.5239]])}
x: tensor([1, 2, 3, 4, 5])
y: tensor([[-0.3042, -0.9077, -1.0826,  0.9333],
        [ 0.0551,  0.6728,  0.5942, -0.1522],
        [-0.3744,  0.9239, -0.2104, -0.5239]])
```
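Tensors saved together in one call keep their storage-sharing (view) relationships, and saving a view writes its entire underlying storage. The sketch below illustrates both behaviors with an in-memory buffer; the tensor sizes are arbitrary, and .clone() is the usual way to save only the viewed elements.

```python
import io
import torch

numbers = torch.arange(1, 10)
evens = numbers[1::2]  # a view sharing storage with `numbers`

buf = io.BytesIO()
torch.save([numbers, evens], buf)
buf.seek(0)
loaded_numbers, loaded_evens = torch.load(buf, weights_only=True)

# Modifying the loaded view also changes the loaded base tensor
loaded_evens *= 2
print(loaded_numbers)  # tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

# Saving a small view writes the whole underlying storage; clone() avoids that
large = torch.arange(1, 1000)
small = large[0:5]

view_buf, clone_buf = io.BytesIO(), io.BytesIO()
torch.save(small, view_buf)
torch.save(small.clone(), clone_buf)
print(len(view_buf.getvalue()) > len(clone_buf.getvalue()))  # True
```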
Example 2: Save and Load Models
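```python
import torch
import torch.nn as nn

# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create model instance
model = SimpleNet()

# Save the entire model (pickles the module object itself)
torch.save(model, 'model.pth')

# Load the entire model. weights_only=False is required on PyTorch 2.6+,
# where the default became True; it runs arbitrary pickle code, so only
# use it on trusted files. SimpleNet must be importable at load time.
loaded_model = torch.load('model.pth', weights_only=False)
print("Model saved and loaded")
print(loaded_model)
```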
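The reason Example 2 needs weights_only=False is that a whole module is pickled as an arbitrary Python class, which the restricted unpickler refuses. A minimal sketch, using nn.Linear in place of SimpleNet and an in-memory buffer; it assumes PyTorch 1.13 or later, where a rejected load raises pickle.UnpicklingError.

```python
import io
import pickle
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
buf = io.BytesIO()
torch.save(model, buf)  # pickles the whole module object

# weights_only=True refuses to unpickle classes such as nn.Linear
buf.seek(0)
try:
    torch.load(buf, weights_only=True)
    rejected = False
except pickle.UnpicklingError:
    rejected = True
print("rejected with weights_only=True:", rejected)  # True

# weights_only=False restores the module but runs full pickle: trusted files only
buf.seek(0)
restored = torch.load(buf, weights_only=False)
print(type(restored).__name__)  # Linear
```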
Example 3: Save Only Model Parameters (Recommended)
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()

# Save only model parameters (recommended approach)
torch.save(model.state_dict(), 'model_weights.pth')

# Create a new model and load the parameters into it;
# weights_only=True restricts unpickling to tensors and basic containers
new_model = SimpleNet()
new_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
print("Model parameters saved and loaded")
print(new_model.state_dict().keys())
```
Output:
```
Model parameters saved and loaded
odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
```
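Saving only the state_dict is recommended over saving the entire model. A state_dict is an ordered dictionary mapping parameter names to tensors, so it can be loaded with weights_only=True and does not depend on the pickled class or the layout of your code. Loading it normally requires a model whose parameter names and shapes match; with load_state_dict(..., strict=False), the matching parameters can also be loaded into a model whose other layers differ.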
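Partial loading with strict=False can be sketched as follows. NewHeadNet is a hypothetical variant that keeps fc1 but replaces fc2 with a differently sized fc3; load_state_dict returns the keys it could not match instead of raising. Parameters with the same name but a different shape would still raise an error.

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):  # same layers as Example 3 (forward omitted)
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

class NewHeadNet(nn.Module):  # hypothetical variant: keeps fc1, new output layer fc3
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc3 = nn.Linear(5, 3)

old_state = SimpleNet().state_dict()
new_model = NewHeadNet()

# strict=False copies matching keys and reports the rest instead of raising
result = new_model.load_state_dict(old_state, strict=False)
print(result.missing_keys)     # ['fc3.weight', 'fc3.bias']
print(result.unexpected_keys)  # ['fc2.weight', 'fc2.bias']
print(torch.equal(new_model.fc1.weight, old_state["fc1.weight"]))  # True
```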
Example 4: Migrating Between CPU and GPU
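```python
import torch

# A GPU is needed to create the CUDA tensor in the first place
if torch.cuda.is_available():
    # Save a tensor that lives on the GPU
    x = torch.randn(2, 3, device='cuda')
    torch.save(x, 'tensor_gpu.pth')

    # Load it onto the CPU instead of the GPU it was saved from
    x_cpu = torch.load('tensor_gpu.pth', map_location='cpu')
    print("Loaded on CPU:", x_cpu.device)
```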
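The map_location parameter remaps storages to a different device at load time. Here a tensor saved from the GPU is loaded onto the CPU, which also lets a file saved on a GPU machine be opened on a machine without CUDA.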
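map_location accepts several forms besides a plain string. The sketch below, which runs without a GPU, loads the same CPU-saved tensor with each form; the dict {"cpu": "cpu"} is a trivial mapping chosen so the example runs anywhere (in practice it is used for remappings such as {"cuda:1": "cuda:0"}), and the function form follows the documented idiom of returning the storage unchanged to keep it on the CPU.

```python
import io
import torch

buf = io.BytesIO()
torch.save(torch.randn(2, 3), buf)

# Each option below loads the tensor onto the CPU
options = [
    "cpu",                              # a device string
    torch.device("cpu"),                # a torch.device
    {"cpu": "cpu"},                     # a dict: saved location -> target location
    lambda storage, location: storage,  # a function; returning storage as-is keeps it on CPU
]

devices = []
for option in options:
    buf.seek(0)  # rewind before every load
    t = torch.load(buf, map_location=option, weights_only=True)
    devices.append(t.device.type)
    print(type(option).__name__, t.device)
```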
Notes
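- Saved files conventionally use the .pth or .pt extension.
- It is recommended to save only the model parameters (state_dict) instead of the entire model.
- torch.load relies on pickle, so loading a file with weights_only=False can execute arbitrary code; only do this for files from trusted sources, and prefer weights_only=True (the default since PyTorch 2.6).
- Since PyTorch 1.6, torch.save writes a zip-based format by default, which older PyTorch releases cannot read.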
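For the checkpoint use case mentioned at the start, a common pattern is to store several state_dicts in one dictionary so training can resume later. A minimal sketch, assuming plain SGD with momentum; the key names such as "epoch" and "model_state_dict" are a convention rather than something torch.save requires, and the epoch value is a placeholder.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

checkpoint = {
    "epoch": 5,  # placeholder progress marker
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),
}
torch.save(checkpoint, "checkpoint.pth")

# Later: rebuild the same objects, then restore their state
model2 = nn.Linear(10, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
ckpt = torch.load("checkpoint.pth", map_location="cpu", weights_only=True)
model2.load_state_dict(ckpt["model_state_dict"])
optimizer2.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
model2.train()  # put the model back in training mode before resuming
print("resuming from epoch", start_epoch)
```

Saving the optimizer state alongside the weights matters for optimizers with internal buffers (momentum here), since restarting them from zero changes the training trajectory.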