PyTorch Model Saving and Loading

In deep learning projects, saving and loading models is a crucial step, mainly for the following reasons:

- Resuming interrupted training: When training is unexpectedly interrupted, you can resume from the last saved checkpoint.
- Model deployment: Deploying the trained model to production environments.
- Model sharing: Making it easy to share trained models among team members.
- Transfer learning: Saving pre-trained models for use in other tasks.
- Performance evaluation: Saving models at different training stages for comparison.

Basic Saving and Loading Methods

Saving the Entire Model

This is the simplest method; it saves both the model architecture and its parameters:
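
```python
import torch
import torchvision.models as models

# Create and train a model
# (torchvision >= 0.13 uses `weights=`; older versions used `pretrained=True`)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# ... Training code ...

# Save the entire model (architecture + parameters, stored via pickle)
torch.save(model, 'model.pth')

# Load the entire model. PyTorch 2.6+ defaults to weights_only=True, which
# rejects pickled modules, so full unpickling must be requested explicitly.
# Only do this for files from a trusted source.
loaded_model = torch.load('model.pth', weights_only=False)
```

Advantages:

- Code is simple and intuitive.
- The complete model structure is preserved.

Disadvantages:

- Slightly larger file size, since the pickled model object is stored along with the weights.
- Dependent on the model class definition: the same class must be importable from the same module path at load time, so refactoring or moving the code can break loading. Loading also requires full unpickling (weights_only=False), which is unsafe for untrusted files.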
Saving Only Model Parameters (Recommended)
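
The recommended approach is to save only the model's state dictionary (state_dict), which maps each parameter and buffer name to its tensor:

```python
import torch
import torchvision.models as models

# Save model parameters
torch.save(model.state_dict(), 'model_weights.pth')

# Load model parameters
model = models.resnet18()  # Must first create a model with the same architecture
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()  # Set to evaluation mode (affects layers such as dropout and batch norm)
```

Advantages:

- Slightly smaller file size.
- More flexible: the weights can be loaded into any model whose parameter names and shapes match, and partially into a different architecture with strict=False.
- Better compatibility: the file holds only a dictionary of tensors, so it is more robust to code refactoring and PyTorch upgrades.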
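To make the idea of a state_dict concrete, here is a minimal sketch. The small two-layer nn.Sequential model is a hypothetical stand-in (not from the original), and an in-memory io.BytesIO buffer stands in for a file on disk. It prints the state_dict's names and shapes, then shows that a freshly built model with the same architecture accepts the saved weights.

```python
import io

import torch
import torch.nn as nn

# Hypothetical small model, used only for illustration
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# The state_dict maps parameter/buffer names to tensors (ReLU has none)
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)

# Round-trip through torch.save / torch.load (BytesIO stands in for a file)
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)
restored = torch.load(buffer)

# A new instance of the same architecture accepts the restored weights
new_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
new_model.load_state_dict(restored)
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), new_model.state_dict().values())
)
print(same)  # True
```

The layer names come from the positions inside nn.Sequential ("0", "2"); in a custom nn.Module they come from the attribute names instead, which is why the architecture must match when loading.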
Saving and Loading Training State

In real-world projects, we often need to save additional information, such as the optimizer state and the current epoch number:

```python
# Save checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # Add any other information that needs to be saved
}
torch.save(checkpoint, 'checkpoint.pth')

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()  # Use model.eval() for inference, or model.train() to resume training
```
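When resuming, the training loop usually continues from the epoch after the saved one. The following is a minimal sketch under stated assumptions: save_checkpoint, load_checkpoint, and train_one_epoch are hypothetical helper names (not PyTorch APIs), the model is a tiny nn.Linear trained on random data, and the checkpoint is written to a temporary directory. SGD with momentum is used so the optimizer actually has state worth restoring.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, epoch, loss):
    # Hypothetical helper: bundle everything needed to resume training
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def load_checkpoint(path, model, optimizer):
    # Hypothetical helper: restore state and return the next epoch to run
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1, checkpoint['loss']


def train_one_epoch(model, optimizer):
    # Placeholder training step on random data
    inputs, targets = torch.randn(16, 4), torch.randn(16, 1)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')

# First run: train 3 epochs (0, 1, 2), checkpointing after each
for epoch in range(3):
    loss = train_one_epoch(model, optimizer)
    save_checkpoint(ckpt_path, model, optimizer, epoch, loss)

# Later run: rebuild the objects, then continue where training stopped
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
start_epoch, last_loss = load_checkpoint(ckpt_path, model, optimizer)
model.train()  # back to training mode before continuing
for epoch in range(start_epoch, 5):
    loss = train_one_epoch(model, optimizer)
print(f'Resumed at epoch {start_epoch}')
```

Mapping to the CPU on load is a conservative default here; the model and optimizer can be moved to a GPU afterwards if one is available.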
Loading Models Across Devices

CPU/GPU Compatibility Handling

```python
# Save as usual; map_location is specified when loading, not when saving
torch.save(model.state_dict(), 'model_weights.pth')

# Load to CPU (e.g., when the model was trained on GPU)
device = torch.device('cpu')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))

# Load to GPU
device = torch.device('cuda')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
model.to(device)  # Move the model's parameters to the GPU as well
```

Loading Models Trained with Multiple GPUs
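
```python
# Save a multi-GPU model: wrappers such as nn.DataParallel or
# DistributedDataParallel keep the real model in .module, and saving
# model.module.state_dict() avoids a 'module.' prefix on every key
torch.save(model.module.state_dict(), 'multigpu_model.pth')

# Load to a single GPU (ModelClass stands for your own model class)
model = ModelClass()
model.load_state_dict(torch.load('multigpu_model.pth'))
model.to('cuda')
```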
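The CPU/GPU and multi-GPU cases above can be combined into one loading routine. The sketch below is illustrative only: the helper name load_weights and the tiny nn.Linear model are assumptions, and the 'module.' prefix is simulated by hand to mimic a checkpoint saved from a DataParallel/DistributedDataParallel wrapper without going through .module. The helper maps tensors onto whichever device is available and strips the prefix when every key carries it.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_weights(model, path):
    # Hypothetical helper: load onto the available device and strip the
    # 'module.' prefix that DataParallel/DDP wrappers add to key names
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(path, map_location=device)
    prefix = 'module.'
    if all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)
    return model.to(device)


source = nn.Linear(3, 2)
# Simulate a checkpoint saved as wrapper.state_dict() (keys like 'module.weight')
wrapped_state = {'module.' + k: v for k, v in source.state_dict().items()}
path = os.path.join(tempfile.mkdtemp(), 'multigpu_model.pth')
torch.save(wrapped_state, path)

model = load_weights(nn.Linear(3, 2), path)
print(next(model.parameters()).device)
```

Stripping the prefix at load time is useful when you receive a checkpoint you cannot re-save; when you control the saving code, saving model.module.state_dict() as shown above avoids the problem entirely.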
Model Conversion and Compatibility
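
PyTorch Version Compatibility

Since PyTorch 1.6, torch.save writes a zip-based file format by default (_use_new_zipfile_serialization=True). If the file must be loaded by an older PyTorch release that predates this format, save it in the legacy format instead:

```python
# Default zip-based format (PyTorch 1.6+); passing True explicitly is equivalent
torch.save(model.state_dict(), 'model.pth', _use_new_zipfile_serialization=True)

# Legacy format, for loading in PyTorch versions older than 1.6
torch.save(model.state_dict(), 'model_legacy.pth', _use_new_zipfile_serialization=False)
```

Converting to TorchScript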

```python
# Convert the model to TorchScript format
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'model_scripted.pt')

# Load the TorchScript model (the Python class definition is not needed)
loaded_script = torch.jit.load('model_scripted.pt')
```
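A practical benefit of TorchScript is that the saved file can be loaded without the original Python class. The following minimal sketch assumes a small hypothetical TinyNet module and a temporary directory: it scripts the model with torch.jit.script, saves it with torch.jit.save, reloads it with torch.jit.load, and checks that the reloaded module produces the same outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical model used only for this illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyNet().eval()
scripted = torch.jit.script(model)
path = os.path.join(tempfile.mkdtemp(), 'tinynet_scripted.pt')
torch.jit.save(scripted, path)

# Loading needs only the file, not the TinyNet class
loaded = torch.jit.load(path)
x = torch.randn(3, 4)
with torch.no_grad():
    match = torch.allclose(model(x), loaded(x))
print(match)
```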
Best Practices and Common Issues

Best Practices

- Naming conventions: Use meaningful filenames, e.g., resnet18_epoch50.pth.
- Regular checkpointing: Save checkpoints every few epochs.
- Verify loading: Test loading a file immediately after saving it.
- Document models: Record the model architecture and training parameters.
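- Version control: Track model files together with the code that produced them; large binary files are usually better handled by tools such as Git LFS or DVC than by plain Git.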
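The naming and "verify loading" practices can be combined into a small helper. This is a sketch under assumptions: save_and_verify and make_model are hypothetical names, and the {arch}_epoch{N}.pth pattern simply mirrors the example filename above rather than any PyTorch convention. The helper saves the state_dict, reloads it into a fresh instance, and compares every tensor.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_and_verify(model, make_model, directory, arch, epoch):
    # Hypothetical helper: save under a descriptive name, then reload and compare
    path = os.path.join(directory, f'{arch}_epoch{epoch}.pth')
    torch.save(model.state_dict(), path)

    fresh = make_model()  # new instance with the same architecture
    fresh.load_state_dict(torch.load(path, map_location='cpu'))
    reloaded = fresh.state_dict()
    for name, tensor in model.state_dict().items():
        if not torch.equal(tensor.cpu(), reloaded[name]):
            raise RuntimeError(f'tensor mismatch in {name}')
    return path


def make_model():
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))


model = make_model()
path = save_and_verify(model, make_model, tempfile.mkdtemp(), 'mlp', 50)
print(os.path.basename(path))  # mlp_epoch50.pth
```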
Solutions to Common Issues
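
Issue 1: Missing key(s) in state_dict

Solution: Ensure the model architecture matches exactly, or use strict=False to skip keys that do not match (tensors whose shapes differ still raise an error):

```python
model.load_state_dict(torch.load('model.pth'), strict=False)
```

Issue 2: CUDA out of memory

Solution: Load the tensors onto the CPU first, so torch.load does not allocate GPU memory for them:

```python
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
```

Issue 3: Unable to load a model saved with an older version

Solution: Try loading the file with the PyTorch version that saved it (or a close one), then re-save it, for example as a plain state_dict, so newer versions can read it.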
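Because strict=False in Issue 1 silently skips names that do not match, it is worth checking what was skipped: load_state_dict returns a named tuple with missing_keys and unexpected_keys fields. In this sketch, the hypothetical Backbone and Classifier models are assumptions chosen so that some keys overlap and some do not, and an extra key is added by hand to show the "unexpected" case.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)


class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)  # same name and shape as in Backbone
        self.head = nn.Linear(4, 3)      # new layer, absent from the saved weights


pretrained = Backbone().state_dict()
pretrained['extra.bias'] = torch.zeros(3)  # a key the target model does not have

model = Classifier()
result = model.load_state_dict(pretrained, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # ['extra.bias']
```

Logging these two lists after a non-strict load makes it obvious which layers kept their random initialization, which is exactly what you expect for a new classification head in transfer learning.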
Practical Application Examples

Image Classification Model Saving and Loading Workflow

Complete Code Example

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Initialization
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Simulate the training process
for epoch in range(5):
    # Simulated training step on random data
    inputs = torch.randn(32, 10)
    labels = torch.randint(0, 2, (32,))

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Save a checkpoint every 2 epochs (epochs 0, 2 and 4)
    if epoch % 2 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }
        torch.save(checkpoint, f'checkpoint_epoch{epoch}.pth')
        print(f'Checkpoint saved at epoch {epoch}')

# Final save
torch.save(model.state_dict(), 'final_model.pth')

# Loading example
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('final_model.pth'))
loaded_model.eval()

# Test the loaded model
test_input = torch.randn(1, 10)
with torch.no_grad():
    output = loaded_model(test_input)
print(f'Test output: {output}')
```