# PyTorch torch.nn.EmbeddingBag Module
`torch.nn.EmbeddingBag` is a PyTorch module that computes sums, means, or max values of "bags" of embeddings without instantiating the intermediate embeddings.
Because the lookup and the reduction happen in a single operation, it is typically faster and more memory-efficient than `torch.nn.Embedding` followed by a separate aggregation (such as `torch.sum()` or `torch.mean()`). This makes it a common choice for text classification, recommendation systems, and bag-of-words representations.
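To make the comparison concrete, here is a minimal sketch that checks the fused module against the two-step version. The weight table and indices are arbitrary illustrative values, and both modules are built from the same weights with `from_pretrained` so their outputs can be compared directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.randn(10, 4)  # illustrative table: 10 embeddings of size 4

# Fused lookup + reduction
bag = nn.EmbeddingBag.from_pretrained(weight, mode='sum')
# Two-step version: lookup, then reduce over the sequence dimension
emb = nn.Embedding.from_pretrained(weight)

indices = torch.tensor([[1, 2, 3], [4, 5, 6]])  # two bags of three indices each

fused = bag(indices)                 # shape (2, 4), no (2, 3, 4) intermediate
two_step = emb(indices).sum(dim=1)   # materializes a (2, 3, 4) tensor first

print(torch.allclose(fused, two_step))
```

The two results should agree up to floating-point tolerance; the difference lies in the intermediate `(2, 3, 4)` tensor that only the two-step version allocates.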
---
## Function Definition
```python
torch.nn.EmbeddingBag(
    num_embeddings,
    embedding_dim,
    max_norm=None,
    norm_type=2.0,
    scale_grad_by_freq=False,
    mode='mean',
    sparse=False,
    include_last_offset=False
)
```
### Key Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| `num_embeddings` | `int` | Size of the dictionary of embeddings (vocabulary size). |
| `embedding_dim` | `int` | The size of each embedding vector. |
| `max_norm` | `float`, optional | If given, each embedding vector with a norm larger than `max_norm` is renormalized to have norm `max_norm`. |
| `norm_type` | `float`, optional | The $p$-norm to compute for the `max_norm` option. Default is `2.0`. |
| `scale_grad_by_freq` | `bool`, optional | If `True`, gradients are scaled by the inverse frequency of the words in the mini-batch. Not supported when `mode='max'`. Default is `False`. |
| `mode` | `str`, optional | The aggregation method. Options are `'sum'`, `'mean'`, or `'max'`. Default is `'mean'`. |
| `sparse` | `bool`, optional | If `True`, the gradient with respect to the `weight` matrix will be a sparse tensor. Default is `False`. |
| `include_last_offset` | `bool`, optional | If `True`, the `offsets` tensor includes an extra element at the end which marks the end of the last sequence. Default is `False`. |
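
As a small illustration of the `sparse` option from the table, the sketch below runs one backward pass and inspects the gradient of the weight matrix. The sizes and indices are arbitrary; `mode='sum'` is used here as an assumption that keeps the example simple.

```python
import torch
import torch.nn as nn

# Illustrative sizes: 1000 embeddings of dimension 16
ebag = nn.EmbeddingBag(1000, 16, mode='sum', sparse=True)

indices = torch.tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0, 3])  # two bags: [1, 2, 3] and [4, 5]

out = ebag(indices, offsets)
out.sum().backward()

grad = ebag.weight.grad
print(grad.is_sparse)  # sparse gradient instead of a dense (1000, 16) tensor
print(grad.shape)
```

A sparse gradient only helps if the optimizer can consume it, for example `torch.optim.SGD` or `torch.optim.SparseAdam`.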
---
## How It Works: Inputs and Offsets
Whereas `nn.Embedding` is usually fed a 2D tensor of padded sequences, `nn.EmbeddingBag` can also accept a **1D concatenated tensor** of indices together with an **`offsets`** tensor.
* **`input`**: A 1D or 2D tensor containing the indices of the embeddings to extract.
* **`offsets`**: A 1D tensor containing the starting index position of each bag (sequence) inside the 1D input tensor. It is only used when `input` is 1D.
For example, if your input sequences are `[1, 2, 3]` and `[4, 5]`, you can flatten them into a single 1D tensor: `indices = [1, 2, 3, 4, 5]`. The corresponding offsets would be `offsets = [0, 3]`, indicating that the first bag starts at index `0` and the second bag starts at index `3`.
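The flattening step described above can be sketched as follows. The `sequences` list is the example from the text; computing the offsets from a cumulative sum of the bag lengths is one possible approach, not the only one.

```python
import torch
import torch.nn as nn

sequences = [[1, 2, 3], [4, 5]]  # variable-length bags from the example above

# Concatenate all indices into one 1D tensor
indices = torch.tensor([i for seq in sequences for i in seq])

# Each bag starts where the previous ones end: [0, len(bag_0), len(bag_0) + len(bag_1), ...]
lengths = torch.tensor([len(seq) for seq in sequences])
offsets = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)[:-1]])

print(indices.tolist())  # [1, 2, 3, 4, 5]
print(offsets.tolist())  # [0, 3]

# Illustrative module: 10 embeddings of size 4
ebag = nn.EmbeddingBag(10, 4, mode='sum')
out = ebag(indices, offsets)  # one row per bag
print(out.shape)
```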
---
## Code Examples
### Example 1: Basic Usage with Offsets
This example demonstrates how to pass a 1D tensor of indices along with an `offsets` tensor to aggregate embeddings.
```python
import torch
import torch.nn as nn
# Define an EmbeddingBag with a vocabulary size of 1000 and embedding dimension of 128
ebag = nn.EmbeddingBag(1000, 128, mode='mean')
# Flattened word indices for two sentences: [1, 2, 3] and [4, 5]
indices = torch.tensor([1, 2, 3, 4, 5])
# Offsets indicating the starting index of each sentence in the 1D tensor
offsets = torch.tensor([0, 3])
# Forward pass
output = ebag(indices, offsets)
print("Input indices shape:", indices.shape)
print("Offsets shape:", offsets.shape)
print("Output shape (Batch Size, Embedding Dim):", output.shape)
# Output shape will be torch.Size([2, 128])
```
---
### Example 2: Fast Text Classification (FastText Architecture)
`EmbeddingBag` provides the averaging step at the heart of FastText-style classifiers: each sentence is represented by the mean of its word embeddings, which is then passed to a linear layer.
```python
import torch
import torch.nn as nn
class FastText(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super().__init__()
        # EmbeddingBag aggregates word embeddings into a single sentence representation
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

# Instantiate model: vocab size 10000, embedding size 128, 2 output classes
model = FastText(10000, 128, 2)
# Simulate a batch of 100 token indices (three sequences, flattened)
text = torch.randint(0, 10000, (100,))
# Start offsets for 3 variable-length sequences: tokens 0-29, 30-59, and 60-99.
# With the default include_last_offset=False, only the bag starts are passed;
# a trailing 100 would be appended only if include_last_offset=True.
offsets = torch.tensor([0, 30, 60])
# Forward pass
output = model(text, offsets)
print("Output shape:", output.shape)
# Output shape: torch.Size([3, 2])
```
---
### Example 3: Comparing Aggregation Modes
You can aggregate embeddings using `'mean'`, `'sum'`, or `'max'`.
```python
import torch
import torch.nn as nn
modes = ['mean', 'sum', 'max']
for mode in modes:
    ebag = nn.EmbeddingBag(100, 32, mode=mode)
    indices = torch.tensor([1, 2, 3, 4])
    offsets = torch.tensor([0, 2])  # Bag 1: [1, 2], Bag 2: [3, 4]
    out = ebag(indices, offsets)
    print(f"Mode '{mode}' output shape:", out.shape)
```
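Since the shapes are identical across modes, a worked example with known values may show the difference more clearly. The 3x2 weight table below is a hand-picked illustration, loaded with `from_pretrained` so the results can be checked by hand.

```python
import torch
import torch.nn as nn

# Hand-picked embedding table: row i is the embedding for index i
weight = torch.tensor([[0.0, 0.0],
                       [1.0, 4.0],
                       [3.0, 2.0]])

indices = torch.tensor([1, 2])  # a single bag containing rows 1 and 2
offsets = torch.tensor([0])

results = {}
for mode in ['sum', 'mean', 'max']:
    ebag = nn.EmbeddingBag.from_pretrained(weight, mode=mode)
    results[mode] = ebag(indices, offsets)[0].tolist()
    print(mode, results[mode])

# sum  -> [4.0, 6.0]  (1 + 3, 4 + 2)
# mean -> [2.0, 3.0]  (sum divided by the bag size of 2)
# max  -> [3.0, 4.0]  (element-wise maximum across the bag)
```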
---
## Common Use Cases
* **Text Classification**: Powering models like FastText where sentence embeddings are computed as the average of their constituent word embeddings.
* **Recommendation Systems**: Aggregating sparse categorical features (e.g., user search history, liked items) into a single dense vector representation.
* **Bag-of-Words (BoW) Models**: Efficiently scaling up bag-of-words representations for large vocabularies.
---
## Key Considerations & Tips
1. **Performance Advantage**: `nn.EmbeddingBag` is typically faster and more memory-efficient than running `nn.Embedding` followed by a reduction operation, because it avoids allocating memory for the intermediate embedding tensors.
2. **Handling 2D Inputs**: If your input is a 2D tensor of shape `(batch_size, sequence_length)` containing padded sequences of equal length, you do **not** need to provide the `offsets` parameter. PyTorch will automatically treat each row as a bag. Note that padding indices are included in the reduction unless they are excluded via the `padding_idx` argument.
3. **The `include_last_offset` Parameter**: If you set `include_last_offset=True`, the `offsets` tensor must have a size equal to the number of bags plus one. The last element represents the total length of the 1D index tensor.
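The second and third points can be checked with a short sketch. The weight table is random illustrative data, and the three calls below describe the same two bags in three different ways, so their outputs are expected to match.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.randn(10, 4)  # arbitrary illustrative embedding table

# 2D input: each row is one bag, so no offsets are needed
bag = nn.EmbeddingBag.from_pretrained(weight, mode='sum')
out_2d = bag(torch.tensor([[1, 2], [3, 4]]))

# The same bags as a flat 1D tensor plus start offsets
flat = torch.tensor([1, 2, 3, 4])
out_1d = bag(flat, torch.tensor([0, 2]))

# include_last_offset=True: offsets has num_bags + 1 entries, ending at len(flat)
bag_end = nn.EmbeddingBag.from_pretrained(weight, mode='sum', include_last_offset=True)
out_end = bag_end(flat, torch.tensor([0, 2, 4]))

print(torch.allclose(out_2d, out_1d), torch.allclose(out_1d, out_end))
```

The `num_bags + 1` layout is convenient when offsets come from a cumulative sum of lengths, since the final entry does not have to be trimmed.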