## PyTorch torch.nn.Transformer
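`torch.nn.Transformer` implements the full encoder-decoder Transformer architecture in PyTorch. Based on the seminal paper *"Attention Is All You Need"* by Vaswani et al., it contains both encoder and decoder stacks and is designed for sequence-to-sequence (Seq2Seq) tasks. It operates on already-embedded inputs of width `d_model`, so token embeddings, positional encodings, and the final output projection are supplied by the surrounding model.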
---
### Function Definition
```python
torch.nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    batch_first=False
)
```
#### Key Parameters:
* **`d_model`**: The number of expected features in the encoder/decoder inputs (default: `512`).
* **`nhead`**: The number of heads in the multi-head attention models (default: `8`).
* **`num_encoder_layers`**: The number of sub-encoder-layers in the encoder (default: `6`).
* **`num_decoder_layers`**: The number of sub-decoder-layers in the decoder (default: `6`).
* **`dim_feedforward`**: The dimension of the feedforward network model (default: `2048`).
* **`dropout`**: The dropout value (default: `0.1`).
* **`activation`**: The activation function of the encoder/decoder intermediate layer. Can be a string (`'relu'` or `'gelu'`) or a unary callable (default: `'relu'`).
* **`batch_first`**: If `True`, the input and output tensors are provided as `(batch, seq, feature)`. If `False`, they are `(seq, batch, feature)` (default: `False`; setting it explicitly is recommended).
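The effect of `batch_first` on tensor layout can be checked with a minimal sketch like the one below. The small layer sizes and the variable names are illustrative assumptions chosen to keep the example fast; they are not recommended settings.

```python
import torch
import torch.nn as nn

# Illustrative (assumed) sizes, much smaller than the defaults.
kwargs = dict(d_model=16, nhead=2, num_encoder_layers=1,
              num_decoder_layers=1, dim_feedforward=32)

seq_first_model = nn.Transformer(**kwargs)                     # batch_first left at its default
batch_first_model = nn.Transformer(batch_first=True, **kwargs)

src_len, tgt_len, batch = 5, 7, 3

# Default layout: (seq, batch, feature)
out_a = seq_first_model(torch.randn(src_len, batch, 16),
                        torch.randn(tgt_len, batch, 16))

# batch_first=True layout: (batch, seq, feature)
out_b = batch_first_model(torch.randn(batch, src_len, 16),
                          torch.randn(batch, tgt_len, 16))

print(out_a.shape)  # torch.Size([7, 3, 16])
print(out_b.shape)  # torch.Size([3, 7, 16])
```

In both cases the output follows the target sequence length, not the source length.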
---
## Code Examples
### Example 1: Basic Usage
This example demonstrates how to initialize a basic Transformer model and pass random tensors through it.
```python
import torch
import torch.nn as nn
# Initialize the Transformer model
# We set batch_first=True so that the batch dimension comes first
transformer = nn.Transformer(d_model=512, nhead=8, batch_first=True)
# Encoder input: (batch_size, sequence_length, d_model)
src = torch.randn(32, 10, 512) # Batch size: 32, Sequence length: 10
# Decoder input: (batch_size, sequence_length, d_model)
tgt = torch.randn(32, 20, 512) # Batch size: 32, Sequence length: 20
# Forward pass
output = transformer(src, tgt)
print("Output shape:", output.shape)
# Expected output: torch.Size([32, 20, 512])
```
---
### Example 2: Simple Machine Translation Model
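In real-world scenarios, you need to embed token IDs into continuous vectors and scale them before passing them to the Transformer. This example shows how to wrap `torch.nn.Transformer` inside a custom PyTorch module. For brevity, it shares one embedding between source and target and omits positional encodings and attention masks, which a real translation model would need.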
```python
import torch
import torch.nn as nn
class TransformerMT(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4):
        super().__init__()
        self.d_model = d_model
        # Embedding layer to convert token IDs to continuous vectors
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Complete Transformer module
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, batch_first=True)
        # Linear layer to project output back to vocabulary space
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        # Embed and scale inputs
        src = self.embedding(src) * (self.d_model ** 0.5)
        tgt = self.embedding(tgt) * (self.d_model ** 0.5)
        # Pass through the Transformer
        out = self.transformer(src, tgt)
        # Project to vocabulary size
        return self.fc(out)

# Instantiate the model with a vocabulary size of 10,000
model = TransformerMT(vocab_size=10000)
# Generate dummy source and target token IDs (Batch size: 32)
src = torch.randint(0, 10000, (32, 50)) # Source sequence length: 50
tgt = torch.randint(0, 10000, (32, 40)) # Target sequence length: 40
# Forward pass
output = model(src, tgt)
print("Output shape:", output.shape)
# Expected output: torch.Size([32, 40, 10000])
```
---
## Common Use Cases
* **Machine Translation**: Translating text from a source language to a target language.
* **Text Generation & Summarization**: Generating coherent text sequences based on input prompts.
* **Sequence-to-Sequence (Seq2Seq) Tasks**: Any task requiring mapping an input sequence to an output sequence of potentially different lengths (e.g., speech-to-text, conversational AI).
---
## Important Considerations
> 💡 **Pro-Tip: Tensor Dimensions**
> * When **`batch_first=True`**, the expected input shape is `(batch, seq, d_model)`.
> * When **`batch_first=False`**, the expected input shape is `(seq, batch, d_model)`.
> * For autoregressive decoding (like text generation), you must pass a **`tgt_mask`** (causal mask) to the decoder during training to prevent it from "peeking" at future tokens.
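
The causal mask mentioned above can be built with the module's `generate_square_subsequent_mask` helper. The following is a minimal sketch, assuming small illustrative layer sizes, dropout disabled, and random inputs; it perturbs only the last target position and checks that the outputs at earlier positions do not change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative (assumed) sizes; dropout is disabled so both passes are comparable.
model = nn.Transformer(d_model=16, nhead=2, num_encoder_layers=1,
                       num_decoder_layers=1, dim_feedforward=32,
                       dropout=0.0, batch_first=True)
model.eval()

src = torch.randn(2, 5, 16)  # (batch, src_len, d_model)
tgt = torch.randn(2, 4, 16)  # (batch, tgt_len, d_model)

# Square float mask of shape (tgt_len, tgt_len):
# 0.0 on and below the diagonal, -inf above it (future positions).
tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))

with torch.no_grad():
    out = model(src, tgt, tgt_mask=tgt_mask)

    # Change only the last target position and run the model again.
    tgt_changed = tgt.clone()
    tgt_changed[:, -1, :] += 1.0
    out_changed = model(src, tgt_changed, tgt_mask=tgt_mask)

# With the causal mask, earlier positions cannot see the changed last position.
causal_ok = torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5)

print(tgt_mask)
print("Earlier positions unchanged:", causal_ok)
```

Without `tgt_mask`, the same perturbation would generally alter the outputs at every target position, which is the "peeking" the mask prevents.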