

## PyTorch torch.nn.Transformer

`torch.nn.Transformer` is a complete, ready-to-use Transformer model implemented in PyTorch. It is based on the seminal paper *"Attention Is All You Need"* by Vaswani et al., contains both encoder and decoder stacks, and is designed for sequence-to-sequence (Seq2Seq) tasks.

---

### Function Definition

```python
# Most commonly used parameters, shown with their default values
torch.nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    batch_first=False
)
```

#### Key Parameters:

* **`d_model`**: The number of expected features in the encoder/decoder inputs (default: `512`).
* **`nhead`**: The number of heads in the multi-head attention layers (default: `8`). `d_model` must be divisible by `nhead`.
* **`num_encoder_layers`**: The number of sub-encoder-layers in the encoder (default: `6`).
* **`num_decoder_layers`**: The number of sub-decoder-layers in the decoder (default: `6`).
* **`dim_feedforward`**: The hidden dimension of the position-wise feedforward network (default: `2048`).
* **`dropout`**: The dropout probability (default: `0.1`).
* **`activation`**: The activation function of the encoder/decoder intermediate (feedforward) layer. Can be a string (`'relu'` or `'gelu'`) or a unary callable (default: `relu`).
* **`batch_first`**: If `True`, the input and output tensors are provided as `(batch, seq, feature)`. If `False`, they are `(seq, batch, feature)` (default: `False`). Because the default is sequence-first, it is good practice to set this explicitly.

---

## Code Examples

### Example 1: Basic Usage

This example demonstrates how to initialize a basic Transformer model and pass random tensors through it.
```python
import torch
import torch.nn as nn

# Initialize the Transformer model
# We set batch_first=True so that the batch dimension comes first
transformer = nn.Transformer(d_model=512, nhead=8, batch_first=True)

# Encoder input: (batch_size, sequence_length, d_model)
src = torch.randn(32, 10, 512)  # Batch size: 32, Sequence length: 10

# Decoder input: (batch_size, sequence_length, d_model)
tgt = torch.randn(32, 20, 512)  # Batch size: 32, Sequence length: 20

# Forward pass
output = transformer(src, tgt)
print("Output shape:", output.shape)
# Expected output: torch.Size([32, 20, 512])
```

The output has the same shape as the decoder input `tgt`.

---

### Example 2: Simple Machine Translation Model

In real-world scenarios, token IDs must first be embedded into continuous vectors, which (following the original paper) are scaled by the square root of `d_model` before being passed to the Transformer. This example shows how to wrap `torch.nn.Transformer` inside a custom PyTorch module. Note that `nn.Transformer` does not add positional encodings by itself; a real model needs them, but they are omitted here for brevity, as are attention masks.

```python
import torch
import torch.nn as nn

class TransformerMT(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4):
        super().__init__()
        self.d_model = d_model
        # Embedding layer to convert token IDs to continuous vectors
        # (shared by source and target in this simplified example)
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Complete Transformer module
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, batch_first=True)
        # Linear layer to project output back to vocabulary space
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        # Embed and scale inputs by sqrt(d_model)
        src = self.embedding(src) * (self.d_model ** 0.5)
        tgt = self.embedding(tgt) * (self.d_model ** 0.5)
        # Pass through the Transformer (no positional encoding or masks here)
        out = self.transformer(src, tgt)
        # Project to vocabulary size
        return self.fc(out)

# Instantiate the model with a vocabulary size of 10,000
model = TransformerMT(vocab_size=10000)

# Generate dummy source and target token IDs (Batch size: 32)
src = torch.randint(0, 10000, (32, 50))  # Source sequence length: 50
tgt = torch.randint(0, 10000, (32, 40))  # Target sequence length: 40

# Forward pass
output = model(src, tgt)
print("Output shape:", output.shape)
# Expected output: torch.Size([32, 40, 10000])
```

---

## Common Use Cases

* **Machine Translation**: Translating text from a source language to a target language.
* **Text Generation & Summarization**: Generating coherent text sequences conditioned on an input sequence, such as a prompt or a document to summarize.
* **Sequence-to-Sequence (Seq2Seq) Tasks**: Any task that maps an input sequence to an output sequence, possibly of a different length (e.g., speech-to-text, conversational AI).

---

## Important Considerations

> 💡 **Pro-Tip: Tensor Dimensions**
>
> * When **`batch_first=True`**, the expected input shape is `(batch, seq, d_model)`.
> * When **`batch_first=False`** (the default), the expected input shape is `(seq, batch, d_model)`.
> * For autoregressive decoding (like text generation), you must pass a **`tgt_mask`** (causal mask) during training to prevent the decoder from "peeking" at future tokens. The helper `nn.Transformer.generate_square_subsequent_mask(sz)` builds such a mask.
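The causal-mask advice above can be checked with a small experiment. This is a minimal sketch, assuming a recent PyTorch release in which `nn.Transformer.generate_square_subsequent_mask` can be called on the class with just the sequence length; the tiny model sizes are arbitrary choices to keep it fast. It builds the mask, confirms that it matches a hand-built upper-triangular `-inf` mask, and then shows that with the mask applied, changing the last target token leaves the decoder outputs at all earlier positions unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small model so the example runs quickly (sizes are arbitrary choices)
model = nn.Transformer(
    d_model=32, nhead=4, num_encoder_layers=2, num_decoder_layers=2,
    dim_feedforward=64, dropout=0.0, batch_first=True,
)
model.eval()  # disable dropout so repeated calls are deterministic

batch, src_len, tgt_len = 2, 7, 5
src = torch.randn(batch, src_len, 32)
tgt = torch.randn(batch, tgt_len, 32)

# Causal mask of shape (tgt_len, tgt_len): 0.0 on and below the diagonal,
# -inf above it, so position i can only attend to positions <= i
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len)

# Equivalent manual construction, for comparison
manual_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
print(torch.equal(tgt_mask, manual_mask))  # True

with torch.no_grad():
    out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 5, 32])

# With the causal mask, changing the LAST target token must not change
# the decoder outputs at any earlier position
tgt_changed = tgt.clone()
tgt_changed[:, -1] = torch.randn(batch, 32)
with torch.no_grad():
    out_changed = model(src, tgt_changed, tgt_mask=tgt_mask)
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))  # True
```

Without `tgt_mask`, the final check would generally fail, because every decoder position could then attend to the modified last token.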