PyTorch torch.nn Reference Manual
torch.nn.MultiheadAttention is PyTorch's multi-head attention module.
It is a core component of the Transformer architecture: it lets the model jointly attend to information from different representation subspaces at different positions.
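For reference, the computation follows the standard multi-head attention formulation from the original Transformer paper (Vaswani et al., 2017), written here with bias terms omitted. W_i^Q, W_i^K and W_i^V are the per-head query, key and value projections, W^O is the output projection, h corresponds to num_heads, and d_k = embed_dim / num_heads.

```latex
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}\left(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V}\right), \quad i = 1, \dots, h \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}\left(\mathrm{head}_1, \dots, \mathrm{head}_h\right) W^{O}
\end{aligned}
```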
Class Definition
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
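Parameters:
- embed_dim (int): Total dimension of the model (size of each input embedding). It is split evenly across the heads.
- num_heads (int): Number of parallel attention heads. embed_dim must be divisible by num_heads.
- dropout (float): Dropout probability applied to the attention weights. Default: 0.0.
- bias (bool): If True, the input and output projection layers have learnable biases. Default: True.
- add_bias_kv (bool): If True, appends a learnable bias vector to the key and value sequences. Default: False.
- add_zero_attn (bool): If True, appends an all-zero position to the key and value sequences. Default: False.
- kdim (int): Number of features in the key. Default: None (same as embed_dim).
- vdim (int): Number of features in the value. Default: None (same as embed_dim).
- batch_first (bool): If True, inputs and outputs have shape (batch, seq, feature); if False, (seq, batch, feature). Default: False.
- device, dtype: Optional factory arguments for the module's parameters.
Calling the module as mha(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True) returns a tuple (attn_output, attn_output_weights). The weights are averaged over heads unless average_attn_weights=False, and are None when need_weights=False.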
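A minimal sketch of how kdim and vdim are used, assuming a cross-attention setting where keys and values come from a source whose feature sizes differ from the query's (the sizes 64, 32 and 48 are arbitrary illustration values). The output keeps the query's sequence length and embed_dim, while the attention weights have one column per source position.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Query features are 64-dim; keys are 32-dim and values are 48-dim (illustrative sizes)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, kdim=32, vdim=48, batch_first=True)

query = torch.randn(2, 10, 64)  # (batch, target length L=10, embed_dim)
key = torch.randn(2, 7, 32)     # (batch, source length S=7, kdim)
value = torch.randn(2, 7, 48)   # (batch, source length S=7, vdim)

output, weights = mha(query, key, value)
print("Output shape:", output.shape)    # torch.Size([2, 10, 64])
print("Weights shape:", weights.shape)  # torch.Size([2, 10, 7]), averaged over heads
```

Internally, when kdim or vdim differs from embed_dim, the module uses separate query, key and value projection matrices instead of one packed in_proj_weight.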
Usage Examples
Example 1: Basic Usage
Example
import torch
import torch.nn as nn
# Create multi-head attention: 512-dim, 8 heads
# batch_first=True makes the expected input layout (batch, seq, dim)
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
# Input: batch=4, sequence length=100, dim=512
query = torch.randn(4,100,512)
key = torch.randn(4,100,512)
value = torch.randn(4,100,512)
# Forward pass
output, attn_weight = mha(query, key, value)
print("Query shape:", query.shape) # torch.Size([4, 100, 512])
print("Output shape:", output.shape) # torch.Size([4, 100, 512])
print("Attention weights shape:", attn_weight.shape) # torch.Size([4, 100, 100]), averaged over heads
Example 2: Self-Attention
Example
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
# Self-attention: the same tensor is used as query, key and value
x = torch.randn(2,50,256)
output, weights = mha(x, x, x)
print("Input shape:", x.shape) # torch.Size([2, 50, 256])
print("Output shape:", output.shape) # torch.Size([2, 50, 256])
print("Attention weights shape:", weights.shape) # torch.Size([2, 50, 50])
Example 3: Attention with Mask
Example
import torch
import torch.nn as nn
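mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
# Input: batch=1, sequence length=20, dim=128
x = torch.randn(1,20,128)
# Causal (upper-triangular) mask for decoder-style attention:
# True above the diagonal means position i cannot attend to positions j > i
mask = torch.triu(torch.ones(20,20), diagonal=1).bool()
output, _ = mha(x, x, x, attn_mask=mask)
print("Input shape:", x.shape) # torch.Size([1, 20, 128])
print("Output shape:", output.shape) # torch.Size([1, 20, 128])
print("Mask shape:", mask.shape) # torch.Size([20, 20])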
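To see what a causal mask like the one in Example 3 actually does, a small check can be sketched as follows (the seed and the split point at position 10 are arbitrary choices). With a boolean attn_mask, True marks a blocked query/key pair, so perturbing only the later positions should leave the outputs at earlier positions unchanged, and every weight above the diagonal should be zero. The same mask can also be written as a float mask with -inf at the blocked positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L = 20
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
# Boolean causal mask: True above the diagonal = position i may not attend to j > i
causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
# Equivalent float mask: -inf is added to the blocked attention scores
float_causal = torch.zeros(L, L).masked_fill(causal, float("-inf"))

x = torch.randn(1, L, 128)
x_changed = x.clone()
x_changed[:, 10:] = torch.randn(1, L - 10, 128)  # alter only positions 10..19

with torch.no_grad():
    out_a, w_a = mha(x, x, x, attn_mask=causal)
    out_b, _ = mha(x_changed, x_changed, x_changed, attn_mask=causal)
    out_f, _ = mha(x, x, x, attn_mask=float_causal)

# Positions 0..9 never see positions 10..19, so their outputs match
print("Prefix unchanged:", torch.allclose(out_a[:, :10], out_b[:, :10], atol=1e-6))
# Largest attention weight placed on a future position
print("Max future weight:", w_a[0].triu(diagonal=1).max().item())
print("Bool and float masks agree:", torch.allclose(out_a, out_f, atol=1e-6))
```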
Example 4: Complete Transformer Encoder Layer
Example
import torch
import torch.nn as nn
class TransformerLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Position-wise feed-forward network (expands to 4 * d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x):
        # Self-attention + residual connection, then LayerNorm (post-norm)
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)
        # FFN + residual connection, then LayerNorm
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x
# Test
layer = TransformerLayer(d_model=512, nhead=8)
x = torch.randn(4,100,512)
output = layer(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
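PyTorch also ships nn.TransformerEncoderLayer, which by default implements the same post-norm block (self-attention and a ReLU feed-forward network, each followed by a residual connection and LayerNorm) plus dropout. A minimal sketch of a roughly equivalent configuration, assuming dropout is switched off so the structure lines up with the hand-written layer above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Built-in encoder layer; dim_feedforward=4*d_model mirrors the hand-written FFN
layer = nn.TransformerEncoderLayer(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.0,       # the hand-written layer has no dropout
    batch_first=True,  # inputs are (batch, seq, d_model)
)
x = torch.randn(4, 100, 512)
output = layer(x)
print("Output shape:", output.shape)  # torch.Size([4, 100, 512])
```

The hand-written version is useful for understanding and customizing the block; the built-in layer is usually the more convenient choice when the standard structure is enough.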
Example 5: Inspecting Attention Weights
Example
import torch
import torch.nn as nn
import numpy as np
mha = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)
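# A short sequence: batch=1, 5 positions, dim=64
x = torch.randn(1,5,64)
# average_attn_weights=False returns per-head weights of shape (batch, num_heads, L, S);
# no_grad() lets the weights be converted to NumPy below
with torch.no_grad():
    _, attn = mha(x, x, x, average_attn_weights=False)
attn = attn.squeeze(0) # Remove batch dimension -> (num_heads, 5, 5)
print("Attention weights of the first head (first 3 positions):")
print(attn[0, :3, :3].tolist())
print("Visualization - attention of the first head from position 0 to all positions:")
print(np.array2string(attn[0,0].numpy(), precision=2))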
Frequently Asked Questions
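Q1: How do I choose num_heads?
embed_dim must be divisible by num_heads, because each head works on embed_dim // num_heads dimensions. Common choices are 8, 12 and 16 heads, for example 8 heads with embed_dim=512 (original Transformer base), 12 with 768 (BERT-base) and 16 with 1024 (BERT-large), each giving 64 dimensions per head.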
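A short sketch of the per-head size and of what happens when the divisibility rule is violated (the sizes are the illustrative pairs mentioned above; the exact exception type is an implementation detail, so the sketch catches both AssertionError and ValueError):

```python
import torch.nn as nn

# Each head operates on embed_dim // num_heads features
for embed_dim, num_heads in [(512, 8), (768, 12), (1024, 16)]:
    mha = nn.MultiheadAttention(embed_dim, num_heads)
    print(embed_dim, num_heads, "-> head_dim =", mha.head_dim)  # 64 in all three cases

# 512 is not divisible by 12, so construction fails
try:
    nn.MultiheadAttention(embed_dim=512, num_heads=12)
    failed = False
except (AssertionError, ValueError) as err:
    failed = True
    print("Error:", err)
```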
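Q2: Why are separate Q, K, and V projections needed?
The module projects the query, key and value inputs with separate learned weight matrices. The query projection determines what each position looks for, the key projection determines how each position is matched against queries, and the value projection determines what information is passed on once a match is made. Because every head has its own slice of these projections, different heads can learn different notions of relevance, which increases representational capacity.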
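To make the three projections concrete, the sketch below recomputes the module's output by hand from its own parameters (in_proj_weight, in_proj_bias and out_proj, which exist when kdim and vdim equal embed_dim). It illustrates the math for self-attention with the default dropout=0.0; it is not PyTorch's internal code path, and the sizes are arbitrary.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
E, H = 16, 4            # embed_dim, num_heads (illustrative sizes)
d = E // H              # per-head dimension
mha = nn.MultiheadAttention(E, H, batch_first=True)  # dropout defaults to 0.0
x = torch.randn(2, 5, E)  # (batch N, seq L, embed_dim)

with torch.no_grad():
    ref_out, ref_w = mha(x, x, x, average_attn_weights=False)

    # in_proj_weight stacks W_q, W_k, W_v row-wise: shape (3E, E)
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    q = x @ W_q.T + b_q
    k = x @ W_k.T + b_k
    v = x @ W_v.T + b_v

    def split_heads(t):
        # (N, L, E) -> (N, H, L, d): each head gets a contiguous slice of d features
        n, length, _ = t.shape
        return t.view(n, length, H, d).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    # Scaled dot-product attention, computed independently per head
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
    heads = weights @ v                                                 # (N, H, L, d)
    concat = heads.transpose(1, 2).reshape(x.shape[0], x.shape[1], E)  # (N, L, E)
    out = mha.out_proj(concat)                                          # output projection W^O

print("Output matches:", torch.allclose(out, ref_out, atol=1e-5))
print("Weights match:", torch.allclose(weights, ref_w, atol=1e-5))
```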
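Q3: What is key_padding_mask?
An optional (batch, source_length) mask passed to the forward call that excludes padding positions in the key from attention, so padded tokens receive no attention weight. For a boolean mask, True marks a position that is ignored.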
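A minimal sketch of key_padding_mask, assuming a batch of two sequences padded to length 6 where the second sequence has only 4 real tokens (the lengths are made up for illustration). The mask is built from the sequence lengths, and the attention paid to the padded keys comes out as exactly zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

lengths = torch.tensor([6, 4])   # real lengths; sequence 2 has 2 padding slots
max_len = 6
x = torch.randn(2, max_len, 32)  # (batch, seq, embed_dim), padded
# True where the position is padding: shape (batch, seq)
padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]

with torch.no_grad():
    output, weights = mha(x, x, x, key_padding_mask=padding_mask)

print(padding_mask)
# Attention paid to the padded keys of sequence 2 (all zeros)
print(weights[1, :, 4:])
```

Note that the padded query positions still produce outputs (they attend to the real tokens); those outputs are typically ignored or masked out later, for example in the loss.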
Usage Scenarios
- Transformers: encoder and decoder layers
- Self-attention models: BERT, GPT
- Sequence modeling: an alternative to RNNs
Tip: The default is batch_first=False, which expects inputs of shape (seq, batch, embed_dim). With batch_first=True, the shape is (batch, seq, embed_dim).
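A small sketch of the two layouts: it copies one module's weights into a second module that differs only in batch_first and checks that transposing the input gives the same result. It relies only on the public constructor and load_state_dict; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha_seq_first = nn.MultiheadAttention(embed_dim=64, num_heads=4)                      # default: (seq, batch, E)
mha_batch_first = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)  # (batch, seq, E)
# batch_first is not a parameter, so the two state dicts are interchangeable
mha_batch_first.load_state_dict(mha_seq_first.state_dict())

x = torch.randn(3, 10, 64)  # (batch=3, seq=10, embed_dim=64)
with torch.no_grad():
    out_bf, _ = mha_batch_first(x, x, x)
    x_sf = x.transpose(0, 1)  # (seq=10, batch=3, embed_dim=64)
    out_sf, _ = mha_seq_first(x_sf, x_sf, x_sf)

print(out_bf.shape, out_sf.shape)  # torch.Size([3, 10, 64]) torch.Size([10, 3, 64])
print("Same result:", torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5))
```

Passing (batch, seq, dim) tensors to a module left at the default does not raise an error when the shapes happen to be compatible; the batch and sequence axes are silently swapped, which is why setting the flag explicitly is worthwhile.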