
PyTorch torch.nn.MultiheadAttention



torch.nn.MultiheadAttention is PyTorch's module for multi-head attention.

It is a core component of the Transformer architecture, allowing the model to simultaneously attend to information from different representation subspaces at different positions.
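To make the mechanism concrete, the sketch below recomputes the module's self-attention output by hand. Each head i computes softmax(Q_i K_i^T / sqrt(d)) V_i on its own d = embed_dim / num_heads slice of the projected queries, keys, and values, and the concatenated heads pass through an output projection. It assumes the default configuration (bias=True, kdim and vdim equal to embed_dim, dropout=0.0), in which the module stores the stacked Q/K/V projection in in_proj_weight and in_proj_bias and the output projection in out_proj, and it relies on average_attn_weights=False (PyTorch 1.11 or later). The sizes are arbitrary illustrative values.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

E, h = 16, 4   # embedding dim and number of heads (illustrative values)
d = E // h     # per-head dimension
mha = nn.MultiheadAttention(embed_dim=E, num_heads=h, batch_first=True)

x = torch.randn(2, 5, E)  # (batch, seq, embed_dim)


def split_heads(t):
    # (N, L, E) -> (N, h, L, d): each head gets a contiguous slice of E
    n, length, _ = t.shape
    return t.view(n, length, h, d).transpose(1, 2)


with torch.no_grad():
    # Reference result from the module, with per-head weights (N, h, L, S)
    ref_out, ref_w = mha(x, x, x, average_attn_weights=False)

    # in_proj_weight stacks W_Q, W_K, W_V row-wise: shape (3E, E)
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

    q = split_heads(x @ W_q.T + b_q)
    k = split_heads(x @ W_k.T + b_k)
    v = split_heads(x @ W_v.T + b_v)

    # Scaled dot-product attention, computed independently for every head
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
    heads = weights @ v  # (N, h, L, d)

    # Concatenate the heads and apply the output projection W_O
    concat = heads.transpose(1, 2).reshape(x.shape[0], x.shape[1], E)
    out = mha.out_proj(concat)

print(torch.allclose(out, ref_out, atol=1e-5))
print(torch.allclose(weights, ref_w, atol=1e-5))
```

Both comparisons should print True under the assumptions above; with dropout > 0 in training mode the module's result would differ from this deterministic recomputation.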

Class Definition

torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
Parameters:

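  • embed_dim (int): Total dimension of the model (size of the query embeddings and of the output).
  • num_heads (int): Number of parallel attention heads. embed_dim is split across the heads, so each head works with embed_dim // num_heads dimensions.
  • dropout (float): Dropout probability applied to the attention weights. Default: 0.0.
  • bias (bool): If True, adds bias to the input and output projections. Default: True.
  • add_bias_kv (bool): If True, adds a learnable bias to the key and value sequences. Default: False.
  • add_zero_attn (bool): If True, adds a new batch of zeros to the key and value sequences. Default: False.
  • kdim (int): Total number of features for keys. Default: None (same as embed_dim).
  • vdim (int): Total number of features for values. Default: None (same as embed_dim).
  • batch_first (bool): If True, input and output tensors have shape (batch, seq, feature); otherwise (seq, batch, feature). Default: False.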
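kdim and vdim matter mainly for cross-attention, where keys and values come from a source whose feature size differs from the queries. A minimal sketch with illustrative sizes, assuming batch_first=True; the output always has embed_dim features and follows the query's sequence length:

```python
import torch
import torch.nn as nn

# Queries are 64-dim, keys 32-dim, values 48-dim (illustrative sizes)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, kdim=32, vdim=48, batch_first=True)

query = torch.randn(2, 10, 64)  # (batch, target_len, embed_dim)
key = torch.randn(2, 7, 32)     # (batch, source_len, kdim)
value = torch.randn(2, 7, 48)   # (batch, source_len, vdim)

output, weights = mha(query, key, value)
print(output.shape)   # torch.Size([2, 10, 64])
print(weights.shape)  # torch.Size([2, 10, 7])
```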

Usage Examples

Example 1: Basic Usage


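import torch
import torch.nn as nn

# Create multi-head attention: 512-dim embeddings, 8 heads (64 dims per head).
# batch_first=True so tensors are (batch, seq, embed_dim); the default is False.
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

# Input: batch=4, sequence length=100, dim=512
query = torch.randn(4, 100, 512)
key = torch.randn(4, 100, 512)
value = torch.randn(4, 100, 512)

# Forward pass; the returned weights are averaged over heads by default
output, attn_weights = mha(query, key, value)

print("Query shape:", query.shape)                     # torch.Size([4, 100, 512])
print("Output shape:", output.shape)                   # torch.Size([4, 100, 512])
print("Attention weights shape:", attn_weights.shape)  # torch.Size([4, 100, 100])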
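If only the output is needed, forward() accepts need_weights=False, and the second return value is then None; recent PyTorch versions may also use a faster fused attention kernel in that case. A sketch with the same sizes as Example 1:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(4, 100, 512)

# need_weights=False skips computing/returning the attention map
output, attn_weights = mha(x, x, x, need_weights=False)
print(output.shape)   # torch.Size([4, 100, 512])
print(attn_weights)   # None
```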

Example 2: Self-Attention


import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)

# Self-attention: the same tensor is used as query, key, and value (q = k = v = x)
x = torch.randn(2, 50, 256)  # (batch, seq, embed_dim)

output, weights = mha(x, x, x)

print("Input shape:", x.shape)                    # torch.Size([2, 50, 256])
print("Output shape:", output.shape)              # torch.Size([2, 50, 256])
print("Attention weights shape:", weights.shape)  # torch.Size([2, 50, 50])
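Each row of the attention weights is a softmax distribution over key positions, and the default return value averages the heads, so every row should sum to 1. A small check, assuming batch_first=True:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
x = torch.randn(2, 50, 256)
_, weights = mha(x, x, x)  # (batch, query_len, key_len), averaged over heads

# Sum over the key dimension: one value per (batch, query position)
row_sums = weights.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones(2, 50), atol=1e-5))  # True
```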

Example 3: Attention with Mask


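import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)

# Input: batch=1, sequence length=20, dim=128
x = torch.randn(1, 20, 128)

# Causal (upper-triangular) mask for a decoder: True marks positions that may NOT
# be attended to, so position i only sees positions 0..i
mask = torch.triu(torch.ones(20, 20), diagonal=1).bool()

output, _ = mha(x, x, x, attn_mask=mask)

print("Input shape:", x.shape)        # torch.Size([1, 20, 128])
print("Output shape:", output.shape)  # torch.Size([1, 20, 128])
print("Mask shape:", mask.shape)      # torch.Size([20, 20])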
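attn_mask accepts either a boolean mask, where True blocks a query/key pair, or a float mask that is added to the attention scores. A sketch, assuming batch_first=True, that checks the causal boolean mask behaves like a float mask holding -inf in the blocked positions, and that blocked positions receive zero weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
x = torch.randn(1, 20, 128)

# Boolean mask: True = position may NOT be attended to
bool_mask = torch.triu(torch.ones(20, 20, dtype=torch.bool), diagonal=1)

# Equivalent additive float mask: -inf where blocked, 0.0 elsewhere
float_mask = torch.zeros(20, 20).masked_fill(bool_mask, float("-inf"))

out_bool, w_bool = mha(x, x, x, attn_mask=bool_mask)
out_float, w_float = mha(x, x, x, attn_mask=float_mask)

print(torch.allclose(out_bool, out_float, atol=1e-6))  # True
# Future positions receive exactly zero weight
print(w_bool[0][bool_mask].abs().max().item())         # 0.0
```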

Example 4: Complete Transformer Encoder Layer


import torch
import torch.nn as nn

class TransformerLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Self-attention + residual connection, then LayerNorm (post-norm)
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Position-wise feed-forward network + residual connection
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

# Test
layer = TransformerLayer(d_model=512, nhead=8)
x = torch.randn(4, 100, 512)
output = layer(x)

print("Input shape:", x.shape)        # torch.Size([4, 100, 512])
print("Output shape:", output.shape)  # torch.Size([4, 100, 512])
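For comparison, PyTorch provides nn.TransformerEncoderLayer, which by default uses the same post-norm layout as TransformerLayer (self-attention, add & norm, ReLU feed-forward, add & norm) and additionally applies dropout. A sketch with dropout disabled and dim_feedforward set to 4 * d_model to mirror the class above:

```python
import torch
import torch.nn as nn

# Built-in post-norm encoder layer; dropout=0.0 to match the hand-written layer
builtin = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048, dropout=0.0, batch_first=True
)
x = torch.randn(4, 100, 512)
out = builtin(x)
print(out.shape)  # torch.Size([4, 100, 512])
```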

Example 5: Inspecting Attention Weights


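import torch
import torch.nn as nn
import numpy as np

mha = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)

# Short example sequence: batch=1, length=5, dim=64
x = torch.randn(1, 5, 64)

# average_attn_weights=False returns per-head weights: (batch, num_heads, L, S).
# no_grad() so the weights can be converted with .numpy()
with torch.no_grad():
    _, attn = mha(x, x, x, average_attn_weights=False)

attn = attn.squeeze(0)  # Remove batch dimension -> (num_heads, L, S)

print("Attention weights of the first head (first 3 positions):")
print(attn[0, :3, :3].tolist())

print("Attention from position 0 to all positions (first head):")
print(np.array2string(attn[0, 0].numpy(), precision=2))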
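With the default average_attn_weights=True, the returned weights are the mean of the per-head weights over the head dimension. A sketch confirming this on one module; dropout stays at its default of 0.0, so both calls are deterministic:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)
x = torch.randn(1, 5, 64)

with torch.no_grad():
    _, per_head = mha(x, x, x, average_attn_weights=False)  # (1, 2, 5, 5)
    _, averaged = mha(x, x, x)                              # (1, 5, 5), the default

print(torch.allclose(per_head.mean(dim=1), averaged, atol=1e-6))  # True
```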

Frequently Asked Questions

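Q1: How do I choose num_heads?

embed_dim must be divisible by num_heads, because the embedding is split evenly across the heads and each head works with embed_dim // num_heads dimensions. Common values are 8, 12, and 16.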
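A quick check of the per-head size and of the divisibility requirement. In current PyTorch releases the constructor enforces this with an assert, so the sketch catches AssertionError and, in case other versions differ, ValueError:

```python
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
print(mha.head_dim)  # 64 (512 // 8)

raised = False
try:
    nn.MultiheadAttention(embed_dim=100, num_heads=8)  # 100 is not divisible by 8
except (AssertionError, ValueError) as err:
    raised = True
    print("Error:", err)
```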

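Q2: Why are separate Q, K, and V projections needed?

Projecting the input separately into queries, keys, and values lets the model learn a different role for each: queries describe what a position is looking for, keys are what queries are matched against, and values carry the content that gets aggregated. Each head has its own projections, which increases representational capacity.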

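Q3: What is key_padding_mask?

A mask of shape (batch, source_len) passed to forward() that marks padding positions in the keys. With a boolean mask, positions set to True are ignored, so no query attends to them.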
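A sketch of key_padding_mask on a padded batch, assuming batch_first=True and a boolean mask where True marks padding. The padded key columns receive zero attention weight:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

# Two sequences padded to length 5; the second has only 3 real tokens
x = torch.randn(2, 5, 32)
key_padding_mask = torch.tensor([
    [False, False, False, False, False],
    [False, False, False, True, True],  # True = padding, ignored as a key
])

_, weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(weights.shape)      # torch.Size([2, 5, 5])
print(weights[1, :, 3:])  # all zeros: no query attends to the padded keys
```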


Usage Scenarios

  • Transformer: Encoder and decoder
  • Self-attention models: BERT, GPT
  • Sequence modeling: Replacement for RNNs

Tip: With batch_first=True, inputs and outputs have shape (batch, seq, embed_dim). With the default batch_first=False, they have shape (seq, batch, embed_dim).
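A sketch of the default layout (batch_first=False): the batch-first tensor is transposed to (seq, batch, embed_dim) before the call, and the output comes back in that same layout, while the returned attention weights are still (batch, target_len, source_len):

```python
import torch
import torch.nn as nn

x_bf = torch.randn(4, 100, 512)  # (batch, seq, embed_dim)

mha_default = nn.MultiheadAttention(embed_dim=512, num_heads=8)  # batch_first=False
x_sf = x_bf.transpose(0, 1)                                      # (seq, batch, embed_dim)
out_sf, w = mha_default(x_sf, x_sf, x_sf)

print(out_sf.shape)  # torch.Size([100, 4, 512]), i.e. (seq, batch, embed_dim)
print(w.shape)       # torch.Size([4, 100, 100]), i.e. (batch, target_len, source_len)
```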


