

PyTorch torch.save and torch.load Functions


torch.save and torch.load are functions in PyTorch used for serializing (saving) and deserializing (loading) tensors, models, and other Python objects.

They are commonly used to save trained models and to write checkpoints from which training can later be resumed.
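As a rough sketch of the checkpoint use case, the snippet below bundles the model and optimizer state into one dictionary, saves it, and restores it into freshly built objects. The toy nn.Linear model, the SGD optimizer, the dictionary keys (epoch, model_state, optimizer_state, loss), and the file name checkpoint.pth are all illustrative assumptions, not a format PyTorch requires.

```python
import torch
import torch.nn as nn

# Toy model and optimizer, assumed only for illustration
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Run one training step so there is trained state to save
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Hypothetical checkpoint layout: everything needed to resume in one dict
checkpoint = {
    "epoch": 5,  # hypothetical epoch counter
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "loss": loss.item(),
}
torch.save(checkpoint, "checkpoint.pth")

# Later: rebuild the objects first, then restore their saved state
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
ckpt = torch.load("checkpoint.pth")
new_model.load_state_dict(ckpt["model_state"])
new_optimizer.load_state_dict(ckpt["optimizer_state"])
start_epoch = ckpt["epoch"] + 1
print("Resuming from epoch", start_epoch)
```

Because the checkpoint contains only tensors, numbers, and nested dictionaries and lists, it can be loaded with the restrictive weights_only=True mode.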

Function Definitions

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL)
torch.load(f, map_location=None, pickle_module=None, *, weights_only=None)

torch.save Parameters:

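  • obj: The object to save: a tensor, a model, a dictionary, a list, or another picklable Python object.
  • f: A file path (str or os.PathLike) or a writable binary file-like object.
  • pickle_module (optional): The module used to pickle metadata and objects; defaults to Python's pickle.
  • pickle_protocol (optional): The pickle protocol version, overriding the default.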
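Since f may be a file-like object rather than a path, one way to illustrate it is to serialize into an in-memory io.BytesIO buffer, which is handy when the bytes are sent over a network or stored somewhere other than the local disk. The dictionary key "w" is an arbitrary choice for this sketch.

```python
import io
import torch

# f can be any writable binary file-like object, not only a path
buffer = io.BytesIO()
torch.save({"w": torch.ones(2, 2)}, buffer)

# Rewind the buffer before reading the bytes back
buffer.seek(0)
restored = torch.load(buffer)
print(restored["w"])
```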

torch.load Parameters:

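  • f: A file path (str or os.PathLike) or a readable, seekable binary file-like object.
  • map_location (optional): Remaps storage locations at load time. It can be a string, a torch.device, a dict, or a function, for example 'cpu' to load GPU-saved tensors on a CPU-only machine.
  • pickle_module (optional): The module used to unpickle metadata and objects; it must match the pickle_module used when saving.
  • weights_only (bool, optional, keyword-only): If True, the unpickler accepts only tensors, primitive types, dictionaries, and a small set of allowed types, and rejects arbitrary Python objects such as a full nn.Module. The default is True since PyTorch 2.6; earlier versions behave as if it were False.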
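To see what weights_only changes, the sketch below saves a dictionary containing a user-defined object and tries to load it both ways. RunConfig is a hypothetical class invented for this example; in the PyTorch versions this sketch assumes, the restricted loader reports a disallowed class by raising pickle.UnpicklingError.

```python
import io
import pickle
import torch

class RunConfig:  # hypothetical user-defined class, not part of PyTorch
    def __init__(self, lr):
        self.lr = lr

buffer = io.BytesIO()
torch.save({"weights": torch.zeros(3), "config": RunConfig(0.01)}, buffer)

# weights_only=True refuses to rebuild arbitrary classes such as RunConfig
buffer.seek(0)
try:
    torch.load(buffer, weights_only=True)
    blocked = False
except pickle.UnpicklingError:
    blocked = True
print("Blocked by weights_only=True:", blocked)

# weights_only=False restores the full object; use it only for trusted files
buffer.seek(0)
data = torch.load(buffer, weights_only=False)
print("config.lr =", data["config"].lr)
```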

Usage Examples

Example 1: Save and Load Tensors

Example

import torch

# Create some tensors
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.randn(3, 4)

# Save both tensors to one file as a dictionary
torch.save({'x': x, 'y': y}, 'tensors.pth')

# Load the dictionary back from the file
loaded = torch.load('tensors.pth')
print("Loaded data:", loaded)
print("x:", loaded['x'])
print("y:", loaded['y'])

Output (y is random, so its values differ on each run):

Loaded data: {'x': tensor([1, 2, 3, 4, 5]), 'y': tensor([[-0.3042, -0.9077, -1.0826,  0.9333],
        [ 0.0551,  0.6728,  0.5942, -0.1522],
        [-0.3744,  0.9239, -0.2104, -0.5239]])}
x: tensor([1, 2, 3, 4, 5])
y: tensor([[-0.3042, -0.9077, -1.0826,  0.9333],
        [ 0.0551,  0.6728,  0.5942, -0.1522],
        [-0.3744,  0.9239, -0.2104, -0.5239]])
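One detail worth knowing when saving tensors: torch.save writes a tensor's whole underlying storage, so saving a small slice (view) of a large tensor also saves the large tensor's data. The sketch below measures this with an in-memory buffer; saved_size is a hypothetical helper written for this example, and the exact byte counts depend on the PyTorch version, so only the rough magnitudes are meaningful.

```python
import io
import torch

def saved_size(obj):
    """Hypothetical helper: number of bytes torch.save writes for obj."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes

big = torch.arange(100_000)   # int64 tensor, about 800 KB of data
small_view = big[:10]         # a view that shares big's storage

print("view :", saved_size(small_view), "bytes")          # includes all of big's data
print("clone:", saved_size(small_view.clone()), "bytes")  # only the 10 elements
```

Calling .clone() before saving is a simple way to keep files small when only part of a tensor is needed.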

Example 2: Save and Load Models

Example

import torch
import torch.nn as nn

# Define a simple neural network

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create model instance

model = SimpleNet()

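# Save the entire model object (this pickles a reference to the SimpleNet class along with the weights)

torch.save(model, 'model.pth')

# Load the model; PyTorch 2.6+ needs weights_only=False to unpickle a full nn.Module.
# Only do this for files you trust, because unpickling can run arbitrary code.

loaded_model = torch.load('model.pth', weights_only=False)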

print("Model saved and loaded")
print(loaded_model)

Example 3: Save Only Model Parameters (Recommended)

Example

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()

# Save only model parameters (recommended approach)

torch.save(model.state_dict(), 'model_weights.pth')

# Create new model and load parameters

new_model = SimpleNet()
new_model.load_state_dict(torch.load('model_weights.pth'))

print("Model parameters saved and loaded")
print(new_model.state_dict().keys())

Output:

Model parameters saved and loaded
odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

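Saving only the model parameters (state_dict) is recommended over saving the entire model. Saving the entire model pickles a reference to the model's class, so loading requires the same class definition to be importable from the same module path, and the file can break after refactoring. A state_dict only maps parameter names to tensors, so it can be loaded into any model instance whose parameter names and shapes match, or, with load_state_dict(..., strict=False), into a model that shares only some of those parameters.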
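To illustrate partial loading, the sketch below copies a state_dict into a larger model using strict=False. Both models are hypothetical nn.Sequential stand-ins with the same layer sizes as SimpleNet, and the target adds one extra output layer. Keys that match are copied, and the rest are reported instead of raising an error; note that matching keys must still have matching shapes.

```python
import torch
import torch.nn as nn

# Source model: same layer sizes as SimpleNet, written as nn.Sequential
source = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# Hypothetical target model with an extra output layer appended
target = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2), nn.Linear(2, 1))

# strict=False copies every matching key and reports the rest
result = target.load_state_dict(source.state_dict(), strict=False)
print("missing:", result.missing_keys)        # ['3.weight', '3.bias']
print("unexpected:", result.unexpected_keys)  # []
```

The new layer (index 3) keeps its random initialization and would typically be trained afterwards.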

Example 4: Migrating Between CPU and GPU

Example

import torch

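# Save a tensor on the GPU if one is available (otherwise on the CPU, so the example still runs)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(2, 3, device=device)
torch.save(x, 'tensor_gpu.pth')

# Load it onto the CPU regardless of the device it was saved from

x_cpu = torch.load('tensor_gpu.pth', map_location='cpu')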

print("Loaded on CPU:", x_cpu.device)

The map_location parameter remaps tensors to a different device at load time, so a file written on a GPU machine can be loaded on a CPU-only machine.
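map_location accepts several forms. The sketch below tries four of them on a CPU tensor so it runs without a GPU; the dictionary form maps a saved device tag such as 'cuda:0' to a target device, and the callable form receives each storage and its saved location, where returning the storage unchanged keeps it on the CPU. The names in the forms dictionary are labels invented for this example.

```python
import io
import torch

buffer = io.BytesIO()
torch.save(torch.randn(2, 3), buffer)

# Four accepted forms of map_location, all targeting the CPU here
forms = {
    "string": "cpu",
    "device": torch.device("cpu"),
    "dict": {"cuda:0": "cpu"},                 # remap a saved device tag
    "callable": lambda storage, loc: storage,  # keep the CPU-deserialized storage
}

devices = {}
for name, location in forms.items():
    buffer.seek(0)  # rewind before each load
    devices[name] = torch.load(buffer, map_location=location).device
    print(name, devices[name])

# On a GPU machine, map_location="cuda:0" would place tensors on the first GPU
```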

Notes

  • Saved files typically use .pth or .pt extensions.
  • It is recommended to save only model parameters (state_dict) instead of saving the entire model.
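  • Only load files from trusted sources. torch.load is built on pickle, and with weights_only=False a malicious file can execute arbitrary code; weights_only=True (the default since PyTorch 2.6) restricts what can be deserialized. Files saved by a newer PyTorch version may also fail to load in an older one.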

