# PyTorch Attention
The attention mechanism is one of the most important concepts in deep learning.

It lets a model learn to "focus" on the most relevant parts of its input, and it has been highly successful in natural language processing, computer vision, and other fields.

This section covers the core principles of the attention mechanism, its PyTorch implementation, and several common attention variants.

> **Applicable Version:** The code in this article targets PyTorch 2.0+. The `batch_first` parameter of `nn.MultiheadAttention` was introduced in PyTorch 1.9; earlier versions require manual dimension handling.

* * *

## 1. Attention Mechanism Basics

### 1.1 Why Attention Is Needed

Traditional sequence-to-sequence (Seq2Seq) models have a fundamental problem: the encoder must compress all of the input into a single fixed-length vector. For long sequences this vector becomes an information bottleneck. The longer the sentence, the more information is lost.

The core idea of attention is to let the decoder "see" all of the encoder's hidden states when generating each output, and to assign them different weights dynamically based on the current context. This mirrors how a human translator works, looking back at the corresponding part of the source text for each word being translated.

### 1.2 The Essence of Attention

Attention can be viewed as a weighted sum. Given queries, keys, and values, it computes the similarity between each query and every key, turns those similarities into weights, and uses the weights to sum the values:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
$$

Where:

* **Q (Query)**: query vectors, representing "what information am I looking for"
* **K (Key)**: key vectors, representing "what information do I have here" (used for matching)
* **V (Value)**: value vectors, representing "what content can I actually provide"
* **$d_{k}$**: dimension of the keys; dividing by $\sqrt{d_{k}}$ keeps the dot products from growing too large and saturating the softmax, which would make its gradients vanish

Why divide by $\sqrt{d_{k}}$? If the components of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of $d_{k}$ such products and therefore has variance $d_{k}$. When $d_{k}$ is large, the softmax inputs become very large in magnitude, the softmax saturates, and its gradients approach zero. Dividing by $\sqrt{d_{k}}$ brings the variance back to 1, which keeps gradients flowing normally.

## Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention

    Parameters:
        Q: Query tensor [batch, n_heads, seq_len_q, d_k]
        K: Key tensor [batch, n_heads, seq_len_k, d_k]
        V: Value tensor [batch, n_heads, seq_len_k, d_v]
        mask: Mask tensor, masked positions set to False/0

    Returns:
        output: Attention output [batch, n_heads, seq_len_q, d_v]
        attention_weights: Attention weights [batch, n_heads, seq_len_q, seq_len_k]
    """
    d_k = Q.size(-1)

    # 1. Compute QK^T / sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # 2. Apply mask (optional)
    # Set masked positions to a very large negative value so they become ~0 after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # 3. Softmax normalization -> attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # 4. Weighted sum of the values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights


# Test
batch_size, n_heads = 2, 4
seq_len_q, seq_len_k = 5, 6
d_k, d_v = 8, 16

Q = torch.randn(batch_size, n_heads, seq_len_q, d_k)
K = torch.randn(batch_size, n_heads, seq_len_k, d_k)
V = torch.randn(batch_size, n_heads, seq_len_k, d_v)

output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {output.shape}")  # [2, 4, 5, 16]
print(f"Attention weights shape: {attn_weights.shape}")  # [2, 4, 5, 6]
print(f"Row sum of weights (should be 1.0): {attn_weights[0, 0, 0].sum().item():.4f}")
```

> Attention weights are essentially a probability distribution: after softmax, each row sums to 1. A larger weight means the model "focuses" more on that position.

* * *

## 2. PyTorch Attention Modules

### 2.1 Multi-Head Attention

Multi-head attention lets the model jointly attend to information from different representation subspaces at different positions. It projects Q, K, and V with separate learned linear projections into several subspaces (heads), computes attention independently in each subspace, then concatenates the results and applies a final linear projection.

This is like having several people read the same text from different angles at the same time: one focuses on grammatical structure, another on semantic relationships, another on long-distance dependencies. The multi-head mechanism lets the model capture richer patterns.

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_{1}, \ldots, \text{head}_{h})\, W^{O}
$$

$$
\text{where}\ \text{head}_{i} = \text{Attention}(Q W_{i}^{Q}, K W_{i}^{K}, V W_{i}^{V})
$$

## Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head

        # Q, K, V linear projections (all heads computed at once, more efficient)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Parameters:
            query: [batch, seq_len_q, d_model]
            key: [batch, seq_len_k, d_model]
            value: [batch, seq_len_k, d_model]
            mask: [batch, 1, 1, seq_len_k] or [batch, 1, seq_len_q, seq_len_k]
                  (positions where mask == 0 are ignored)
        """
        batch_size = query.size(0)

        # 1. Linear projection, then split heads
        # [batch, seq_len, d_model] -> [batch, seq_len, n_heads, d_k]
        #                           -> [batch, n_heads, seq_len, d_k]
        Q = self.w_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 2. Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        # Note: the weights returned below are taken after dropout
        attn_weights = self.dropout(attn_weights)

        # 3. Weighted sum
        context = torch.matmul(attn_weights, V)

        # 4. Merge heads: [batch, n_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 5. Output projection
        output = self.w_o(context)

        return output, attn_weights


# Test
d_model, n_heads = 128, 8
seq_len, batch = 10, 4

layer = MultiHeadAttention(d_model, n_heads)

query = torch.randn(batch, seq_len, d_model)
key = torch.randn(batch, seq_len, d_model)
value = torch.randn(batch, seq_len, d_model)

output, attn_weights = layer(query, key, value)

print(f"Output shape: {output.shape}")  # [4, 10, 128]
print(f"Attention weights shape: {attn_weights.shape}")  # [4, 8, 10, 10]
print(f"Dimension per head d_k = {d_model // n_heads}")
```

### 2.2 PyTorch Built-in MultiheadAttention

PyTorch provides a highly optimized `nn.MultiheadAttention` that uses fused kernels under the hood, so it is typically faster than a manual implementation. It is the recommended choice for production.

## Example

```python
import torch
import torch.nn as nn


class TransformerAttention(nn.Module):
    """Self-attention layer using PyTorch's built-in MultiheadAttention"""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,  # Input format is [batch, seq, features]
            # Note: nn.MultiheadAttention has no `attn_implementation` argument
            # (that option belongs to Hugging Face Transformers). In PyTorch 2.0+,
            # fused kernels are selected automatically when conditions allow.
        )
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask=None):
        """
        Self-attention: Q, K, V are all x

        Parameters:
            x: [batch, seq_len, d_model]
            key_padding_mask: [batch, seq_len], True indicates a padding position
        """
        attn_output, attn_weights = self.attention(
            x, x, x,  # self-attention
            key_padding_mask=key_padding_mask,
        )

        # Post-Norm residual connection: LayerNorm(x + Attention(x)).
        # (Pre-Norm, x + Attention(LayerNorm(x)), is often more stable in deep stacks.)
        output = self.layernorm(x + attn_output)

        return output, attn_weights


# Test
d_model, n_heads = 128, 8
seq_len, batch = 10, 4

model = TransformerAttention(d_model, n_heads)
x = torch.randn(batch, seq_len, d_model)

# Create padding mask: True means this position should be ignored
key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
key_padding_mask[0, 7:] = True  # Last 3 positions of the first sample are padding
key_padding_mask[1, 5:] = True  # Last 5 positions of the second sample are padding

output, attn_weights = model(x, key_padding_mask=key_padding_mask)

print(f"Output shape: {output.shape}")  # [4, 10, 128]
print(f"Attention weights shape: {attn_weights.shape}")  # [4, 10, 10] (averaged over heads by default)
```

> **Note on mask type:** The `key_padding_mask` of `nn.MultiheadAttention` uses the **True = ignore** convention, the opposite of the 0 = ignore convention used by the custom implementations above. Always check which convention an API expects.

* * *

## 3. Variants of the Attention Mechanism

### 3.1 Self-Attention

Self-attention is a special form of attention in which Q, K, and V all come from the same input sequence. It lets every position in the sequence attend directly to every other position, capturing global dependencies. It is the core component of the Transformer and the key to its advantage over RNNs: an RNN must pass information along one step at a time, whereas self-attention connects any two positions in a single step.

## Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SelfAttention(nn.Module):
    """Multi-head self-attention layer with optional causal mask (for autoregressive generation)"""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch_size):
        """[batch, seq, d_model] -> [batch, n_heads, seq, d_k]"""
        return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, causal=False):
        """
        Self-attention: Q, K, V all come from the same input x

        Parameters:
            x: [batch, seq_len, d_model]
            causal: Whether to apply a causal mask (prevents attending to future positions)
        """
        batch_size, seq_len, _ = x.size()

        # Projection + split heads
        Q = self._split_heads(self.w_q(x), batch_size)
        K = self._split_heads(self.w_k(x), batch_size)
        V = self._split_he
```
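A causal mask lets each query position attend only to itself and to earlier positions. The following standalone sketch is independent of the `SelfAttention` class. It builds that mask with `torch.tril` and checks a manual computation against PyTorch 2.0+'s `F.scaled_dot_product_attention(..., is_causal=True)`. The tensor sizes are arbitrary choices. Note that this boolean mask uses **True = keep**, which differs from the `key_padding_mask` convention described in Section 2.2.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (arbitrary choices for this sketch)
batch, n_heads, seq_len, d_k = 2, 4, 5, 8
Q = torch.randn(batch, n_heads, seq_len, d_k)
K = torch.randn(batch, n_heads, seq_len, d_k)
V = torch.randn(batch, n_heads, seq_len, d_k)

# Lower-triangular boolean mask: query i may attend to keys j <= i.
# Convention in this sketch: True = keep, False = masked.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Manual scaled dot-product attention with the causal mask applied
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
scores = scores.masked_fill(~causal_mask, float("-inf"))  # mask broadcasts over batch/heads
weights = F.softmax(scores, dim=-1)
manual_out = weights @ V

# Built-in version (PyTorch 2.0+): is_causal=True applies the same mask internally
builtin_out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual_out, builtin_out, atol=1e-5))  # True
print(weights[0, 0])  # upper triangle is exactly 0: no attention to future positions
```

Using `float("-inf")` instead of `-1e9` makes the masked weights exactly zero. This is safe here because every row keeps at least the diagonal entry, so no row is fully masked.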
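The scaling argument from Section 1.2 can be checked numerically: the raw dot-product variance grows as $d_k$, and scaling brings it back to 1. This minimal sketch makes the same assumption as that argument, namely i.i.d. standard-normal components in the query and key. The sample count and the dimensions tried are arbitrary.

```python
import torch

torch.manual_seed(0)

n_samples = 10_000
results = {}

for d_k in (16, 64, 256):
    # Assumption from Section 1.2: components are i.i.d. with mean 0, variance 1
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)

    dots = (q * k).sum(dim=-1)   # raw dot products q . k, one per sample
    scaled = dots / d_k ** 0.5   # dot products divided by sqrt(d_k)

    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:3d}  var(q.k)={results[d_k][0]:7.2f}  "
          f"var(q.k / sqrt(d_k))={results[d_k][1]:.3f}")
```

The raw variance should come out close to $d_k$ (roughly 16, 64, and 256). The scaled variance should stay near 1 regardless of $d_k$.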