PyTorch torch.nn.LSTM function | Rookie Tutorial
torch.nn.LSTM is the module in PyTorch used for Long Short-Term Memory networks.
LSTM is a special type of recurrent neural network capable of learning long-term dependencies, widely used in sequence modeling tasks.
Function Definition
torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)
Parameter Explanation:
- input_size (int): Dimension of input features.
- hidden_size (int): Dimension of hidden state.
- num_layers (int): Number of LSTM layers. Default is 1.
- bias (bool): Whether to use bias. Default is True.
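- batch_first (bool): If True, input and output shapes are (batch, seq, feature); if False, they are (seq, batch, feature). Does not apply to the hidden and cell states. Default is False.
- dropout (float): Dropout probability applied to the output of every layer except the last, so it has an effect only when num_layers > 1. Default is 0.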
- bidirectional (bool): Whether to use bidirectional LSTM. Default is False.
Input and Output
Input:
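- input: Tensor with shape (batch, seq_len, input_size) when batch_first=True, or (seq_len, batch, input_size) otherwise
- h_0: Initial hidden state (optional, defaults to zeros), shape (num_layers * num_directions, batch, hidden_size)
- c_0: Initial cell state (optional, defaults to zeros), shape (num_layers * num_directions, batch, hidden_size)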
Output:
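- output: Hidden states of the last layer at every time step, shape (batch, seq_len, num_directions * hidden_size) when batch_first=True
- h_n: Final hidden state of every layer, shape (num_layers * num_directions, batch, hidden_size)
- c_n: Final cell state of every layer, shape (num_layers * num_directions, batch, hidden_size)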
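The relationship between output and h_n is easy to confuse, so here is a minimal sketch of it. The sizes (8, 16, batch of 3) are arbitrary values chosen for illustration, and the check assumes a unidirectional LSTM with batch_first=True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 8)  # (batch, seq_len, input_size)

output, (h_n, c_n) = lstm(x)

# output: the top layer's hidden state at every time step
# h_n:    every layer's hidden state at the final time step
last_step_matches = torch.allclose(output[:, -1, :], h_n[-1])

print(output.shape)       # torch.Size([3, 5, 16])
print(h_n.shape)          # torch.Size([2, 3, 16])
print(c_n.shape)          # torch.Size([2, 3, 16])
print(last_step_matches)  # True
```

For a unidirectional LSTM, output[:, -1, :] and h_n[-1] therefore carry the same values; which one to use is mostly a matter of convenience.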
Usage Examples
Example 1: Basic Usage
Create and use an LSTM:
import torch
import torch.nn as nn
# Create LSTM: input dimension 256, hidden dimension 512, 2 layers
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)
# Create input: batch=4, sequence length=10, input dimension=256
input_tensor = torch.randn(4, 10, 256)
# Forward pass
output, (h_n, c_n) = lstm(input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", output.shape) # (4, 10, 512)
print("Hidden state shape:", h_n.shape) # (2, 4, 512)
print("Cell state shape:", c_n.shape) # (2, 4, 512)
Example 2: Bidirectional LSTM
Use bidirectional LSTM to capture bidirectional context:
import torch
import torch.nn as nn
# Bidirectional LSTM
bilstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)
input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = bilstm(input_tensor)
print("Bidirectional LSTM output shape:", output.shape) # (4, 10, 512) = 256*2
print("Hidden state shape:", h_n.shape) # (4, 4, 256) = 2 layers * 2 directions
print("Last layer hidden state:", h_n[-2:, :, :].shape) # (2, 4, 256): forward and backward states of the last layer
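The following sketch checks where the two directions of the last layer end up in output. It assumes batch_first=True and uses small, arbitrary sizes; the variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, batch, hidden_size = 2, 2, 6
bilstm = nn.LSTM(input_size=4, hidden_size=hidden_size, num_layers=num_layers,
                 batch_first=True, bidirectional=True)
x = torch.randn(batch, 7, 4)
output, (h_n, _) = bilstm(x)

# Regroup h_n as (num_layers, num_directions, batch, hidden_size)
h_by_layer = h_n.reshape(num_layers, 2, batch, hidden_size)
fwd_final = h_by_layer[-1, 0]  # same tensor values as h_n[-2]
bwd_final = h_by_layer[-1, 1]  # same tensor values as h_n[-1]

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
fwd_ok = torch.allclose(fwd_final, output[:, -1, :hidden_size])
bwd_ok = torch.allclose(bwd_final, output[:, 0, hidden_size:])
print(fwd_ok, bwd_ok)
```

This is why output[:, -1, :] alone is not a full summary of a bidirectional LSTM: its backward half has seen only the last token.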
Example 3: Initialize Hidden State Manually
Manually initialize hidden states:
import torch
import torch.nn as nn
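batch_size = 4
num_layers = 2
hidden_size = 512
# num_layers must match the first dimension of h_0 and c_0
lstm = nn.LSTM(input_size=256, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
# Manually create initial hidden states: (num_layers * num_directions, batch, hidden_size)
h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)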
# Pass in initial states
input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = lstm(input_tensor, (h_0, c_0))
print("Custom initial state used successfully")
print("Output shape:", output.shape)
Example 4: Complete Sentiment Classification Model
Text classification based on LSTM:
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super(LSTMClassifier, self).__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        # Fully connected classification layer
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x) # (batch, seq_len, embed_dim)
        # LSTM output
        output, (hidden, cell) = self.lstm(embedded)
        # Concatenate the final hidden states from both directions
        # hidden: (4, batch, hidden_dim) - 2 layers * 2 directions
        # hidden[-2] is the last layer's forward state, hidden[-1] its backward state
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1) # (batch, hidden_dim*2)
        # Classification
        logits = self.fc(hidden)
        return logits

# Instantiate model
vocab_size = 10000
model = LSTMClassifier(vocab_size=vocab_size, embed_dim=128, hidden_dim=128, num_classes=2)
# Test input: batch=8, sequence length=50
input_ids = torch.randint(1, vocab_size, (8, 50))
output = model(input_ids)
print("Model structure:")
print(model)
print("Input shape:", input_ids.shape)
print("Output shape:", output.shape) # (8, 2)
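The test input above has no padding, but real batches usually mix sequence lengths. One common way to keep padded steps out of the final state is torch.nn.utils.rnn.pack_padded_sequence. The sketch below is not part of the model above: it uses a small unidirectional LSTM, and the lengths tensor is an assumed extra input that the original example does not have.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=5, batch_first=True)

# Two sequences with true lengths 6 and 3; the second is zero-padded to length 6
x = torch.randn(2, 6, 4)
lengths = torch.tensor([6, 3])  # assumed to be known for each sequence
x[1, 3:] = 0.0

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Reference: run the short sequence on its own, without any padding
_, (h_ref, _) = lstm(x[1:2, :3])
short_seq_ok = torch.allclose(h_n[-1, 1], h_ref[-1, 0], atol=1e-6)
print(output.shape, out_lengths.tolist(), short_seq_ok)
```

With packing, h_n for the short sequence is its state after step 3 rather than after three extra padding steps.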
Example 5: Stacked Multi-layer LSTM
Deep LSTM network:
import torch
import torch.nn as nn
# 4-layer stacked LSTM with dropout
deep_lstm = nn.LSTM(
    input_size=256,
    hidden_size=512,
    num_layers=4,
    batch_first=True,
    dropout=0.4 # Dropout between layers
)
input_tensor = torch.randn(2, 20, 256)
output, (h_n, c_n) = deep_lstm(input_tensor)
print("4-layer LSTM output shape:", output.shape)
print("Hidden state shape (4 layers):", h_n.shape)
print("Cell state shape (4 layers):", c_n.shape)
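Inter-layer dropout is active only in training mode. The sketch below illustrates this with a small 2-layer LSTM; the sizes and the dropout rate of 0.5 are arbitrary values picked for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=8, num_layers=2, dropout=0.5, batch_first=True)
x = torch.randn(2, 5, 8)

# Training mode: a fresh dropout mask is sampled on every forward pass
lstm.train()
out_a, _ = lstm(x)
out_b, _ = lstm(x)
train_differs = not torch.allclose(out_a, out_b)

# Evaluation mode: dropout is disabled, so repeated passes agree
lstm.eval()
with torch.no_grad():
    out_c, _ = lstm(x)
    out_d, _ = lstm(x)
eval_same = torch.equal(out_c, out_d)

print(train_differs, eval_same)
```

Calling model.eval() before inference is therefore needed to get reproducible predictions from a stacked LSTM that uses dropout.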
Concept of LSTM Gates
LSTM controls information flow through three gates:
- Forget Gate: Determines how much of the previous cell state to retain
- Input Gate: Determines how much new candidate information to write into the cell state
- Output Gate: Determines how much of the cell state to expose as the hidden state
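To make the three gates concrete, the sketch below recomputes a single time step by hand and compares it with nn.LSTMCell. It relies on PyTorch storing the four weight blocks in the order input gate, forget gate, candidate, output gate; the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch = 3, 4, 2
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, input_size)
h_prev = torch.randn(batch, hidden_size)
c_prev = torch.randn(batch, hidden_size)

with torch.no_grad():
    # All four pre-activations come from one pair of matrix products
    pre = x @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = pre.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)  # gates in (0, 1)
    g = torch.tanh(g)                                               # candidate values
    c_manual = f * c_prev + i * g           # forget old state, add gated candidate
    h_manual = o * torch.tanh(c_manual)     # expose part of the cell state
    h_ref, c_ref = cell(x, (h_prev, c_prev))

gates_match = (torch.allclose(h_manual, h_ref, atol=1e-5)
               and torch.allclose(c_manual, c_ref, atol=1e-5))
print(gates_match)
```

nn.LSTM applies the same update at every time step and in every layer, using fused kernels rather than an explicit Python loop.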
Frequently Asked Questions
Q1: What does batch_first=True mean?
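With batch_first=True, the first dimension of the input and output tensors is the batch size: (batch, seq_len, feature). With the default batch_first=False, the first dimension is the sequence length: (seq_len, batch, feature). The setting does not affect the shapes of h_n and c_n.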
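As a quick check that batch_first changes only the tensor layout and not the computation, the sketch below runs two LSTMs that share the same weights on the same data in both layouts. The sizes are arbitrary, and copying weights via load_state_dict is just a convenience for the comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm_bf = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
lstm_sf = nn.LSTM(input_size=4, hidden_size=6, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())  # share weights between the two modules

x = torch.randn(3, 5, 4)  # (batch, seq_len, feature)

with torch.no_grad():
    out_bf, (h_bf, _) = lstm_bf(x)
    out_sf, (h_sf, _) = lstm_sf(x.transpose(0, 1))  # (seq_len, batch, feature)

same_output = torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6)
same_hidden = torch.allclose(h_bf, h_sf, atol=1e-6)
print(out_bf.shape, out_sf.shape, h_bf.shape, h_sf.shape)
print(same_output, same_hidden)
```

Both modules return h_n with shape (1, 3, 6); only output follows the batch_first layout.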
Q2: When to use bidirectional LSTM?
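Use it when the whole sequence is available at once and each position benefits from both left and right context, as in sequence labeling and sentiment analysis. Machine translation commonly uses an encoder-decoder architecture, in which only the encoder can be bidirectional because the decoder generates tokens left to right.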
Q3: How to choose hidden layer size?
Typically between 128 and 512; adjust based on task complexity and data volume. Too small leads to underfitting; too large may cause overfitting.
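Hidden size also drives the parameter count, which grows roughly quadratically with it. The sketch below compares the count reported by PyTorch with a closed-form formula for a single-layer, unidirectional LSTM; lstm_param_count is a helper written for this illustration, not a PyTorch function.

```python
import torch.nn as nn

def lstm_param_count(input_size, hidden_size):
    # Four weight blocks (three gates plus the candidate), each with
    # input weights, recurrent weights, and two bias vectors
    return 4 * (input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size)

input_size = 256
counts = {}
for hidden_size in (128, 256, 512):
    lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
    actual = sum(p.numel() for p in lstm.parameters())
    counts[hidden_size] = (actual, lstm_param_count(input_size, hidden_size))
    print(hidden_size, actual)
```

With input_size=256, a hidden size of 128 gives 197,632 parameters; doubling the hidden size more than doubles that number because of the hidden_size * hidden_size term.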
Application Scenarios
nn.LSTM is mainly used in the following scenarios:
- Natural Language Processing: Text classification, named entity recognition
- Time Series Prediction: Stock prediction, speech recognition
- Sequence-to-Sequence Tasks: Machine translation, text generation
Tip: When using bidirectional=True, the output dimension becomes hidden_size * 2.