
PyTorch torch.nn.Embedding Function



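torch.nn.Embedding is a PyTorch module that stores a lookup table of embedding vectors for a fixed-size vocabulary. It is most often used for word embeddings.

It maps discrete indices (for example, word IDs) to vectors in a continuous space: looking up index i returns row i of its weight matrix. It is one of the most fundamental operations in natural language processing.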
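A minimal sketch of the lookup-table view: the values below (5 rows, 3 dimensions, indices [4, 0, 4]) are arbitrary illustration choices. It checks that calling the layer is the same as selecting rows of weight, and the same as multiplying one-hot vectors by weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A small table: 5 rows (vocabulary entries), each a 3-dimensional vector
embedding = nn.Embedding(num_embeddings=5, embedding_dim=3)
indices = torch.tensor([4, 0, 4])

# Looking up indices is the same as selecting rows of the weight matrix
looked_up = embedding(indices)
selected = embedding.weight[indices]

# It is also equivalent to multiplying one-hot vectors by the weight matrix
one_hot = nn.functional.one_hot(indices, num_classes=5).float()
via_matmul = one_hot @ embedding.weight

print(torch.allclose(looked_up, selected))    # True
print(torch.allclose(looked_up, via_matmul))  # True
```

The direct row lookup avoids building the one-hot matrix, which is why an embedding layer is preferred over a Linear layer on one-hot inputs.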

Function Definition

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
Parameter Description:

  • num_embeddings (int): Vocabulary size, i.e., the number of rows in the embedding matrix.
  • embedding_dim (int): Dimension of each embedding vector.
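  • padding_idx (int, optional): Index of the padding entry. Its embedding vector is initialized to zeros and receives no gradient, so it is not updated during training. Default is None.
  • max_norm (float, optional): If given, any embedding vector whose norm exceeds max_norm is renormalized to norm max_norm when it is looked up. This modifies weight in place. Default is None.
  • norm_type (float): The p of the p-norm used for the max_norm check. Default is 2.0.
  • scale_grad_by_freq (bool): If True, gradients are scaled by the inverse frequency of each index in the mini-batch. Default is False.
  • sparse (bool): If True, the gradient with respect to weight is a sparse tensor. Default is False.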
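A small sketch of the two gradient-related options, using an arbitrary batch in which index 1 appears three times. Summing the output makes every upstream gradient 1, so the effect of scale_grad_by_freq on row 1 is easy to read, and sparse=True changes only the type of the resulting gradient.

```python
import torch
import torch.nn as nn

indices = torch.tensor([1, 1, 1, 2])  # index 1 appears three times in the batch

# Default: gradients for a repeated index accumulate over its occurrences
plain = nn.Embedding(5, 4)
plain(indices).sum().backward()
print(torch.allclose(plain.weight.grad[1], torch.full((4,), 3.0)))  # True

# scale_grad_by_freq=True divides each row's gradient by that index's count in the batch
scaled = nn.Embedding(5, 4, scale_grad_by_freq=True)
scaled(indices).sum().backward()
print(torch.allclose(scaled.weight.grad[1], torch.ones(4)))  # True

# sparse=True makes the gradient of weight a sparse tensor
sparse_emb = nn.Embedding(5, 4, sparse=True)
sparse_emb(indices).sum().backward()
print(sparse_emb.weight.grad.is_sparse)  # True
```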

Attributes:

  • weight (Tensor): Learnable weights of shape (num_embeddings, embedding_dim), initialized from a standard normal distribution N(0, 1).
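A quick check of the attribute, assuming an arbitrary 1000 x 64 table: weight is registered as the module's only parameter, and its default entries follow N(0, 1), so their sample mean and standard deviation should land near 0 and 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64)

# weight is an nn.Parameter registered on the module, so it appears in parameters()
print(isinstance(embedding.weight, nn.Parameter))          # True
print([name for name, _ in embedding.named_parameters()])  # ['weight']

# Default initialization draws every entry from a standard normal N(0, 1)
w = embedding.weight.detach()
print(f"mean={w.mean().item():.3f}, std={w.std().item():.3f}")
```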

Usage Examples

Example 1: Basic Usage

Create and use an embedding layer:

Example

import torch
import torch.nn as nn

# Create embedding layer: vocabulary size 10000, embedding dimension 256
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)

# Word indices, each in the range [0, num_embeddings)
# Shape: (batch, seq_len)
input_indices = torch.tensor([[12,45,678],[901,23,56]])

# Look up embedding vectors
output = embedding(input_indices)

print("Input indices shape:", input_indices.shape)
print("Output embeddings shape:", output.shape)  # (2, 3, 256)

# View embedding matrix shape
print("Embedding matrix shape:", embedding.weight.shape)
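The input does not have to be 2-D. A sketch, with arbitrary shapes: an index tensor of any shape (*) yields an output of shape (*, embedding_dim), and an index outside [0, num_embeddings) is rejected. The (batch, sentences, words) layout is only an illustrative assumption.

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)

# The lookup accepts an index tensor of any shape (*) and returns (*, embedding_dim)
single_sentence = torch.tensor([12, 45, 678])       # shape (3,)
nested = torch.randint(0, 10000, (2, 4, 5))         # e.g. (batch, sentences, words)

print(embedding(single_sentence).shape)  # torch.Size([3, 256])
print(embedding(nested).shape)           # torch.Size([2, 4, 5, 256])

# Indices must lie in [0, num_embeddings); 10000 is out of range here
raised = False
try:
    embedding(torch.tensor([10000]))
except (IndexError, RuntimeError):
    raised = True
print("Out-of-range index rejected:", raised)  # True
```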

Example 2: Using padding_idx

Specify padding index:

Example

import torch
import torch.nn as nn

# Create embedding layer with padding
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64, padding_idx=0)

# 0 is used as padding
input_indices = torch.tensor([[1,2,3],[4,0,0]])  # Second sentence has padding
output = embedding(input_indices)

print("Input shape:", input_indices.shape)
print("Output shape:", output.shape)
print("Padding embedding vector:", output[1,1].tolist())  # All zeros
print("Non-padding embedding vector:", output[0,0].tolist()[:5])  # Non-zero

Example 3: Pretrained Word Vectors

Load pretrained word vectors:

Example

import torch
import torch.nn as nn
import numpy as np

# Simulate pretrained word vectors (GloVe, Word2Vec, etc. can be used in practice)
vocab_size = 1000
embedding_dim = 300

# Random initialization (pretrained vectors should be loaded in practice)
pretrained_weights = np.random.randn(vocab_size, embedding_dim).astype('float32')

# Create embedding layer
embedding = nn.Embedding(vocab_size, embedding_dim)

# Load pretrained weights (copy in place so the existing Parameter is kept)
with torch.no_grad():
    embedding.weight.copy_(torch.from_numpy(pretrained_weights))

# Freeze embedding layer (weights are not updated during training)
embedding.weight.requires_grad = False

print("Embedding layer trainable:", embedding.weight.requires_grad)
print("Embedding matrix shape:", embedding.weight.shape)
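PyTorch also offers the class method nn.Embedding.from_pretrained, which builds the layer from a 2-D float tensor in one call. A sketch, again using random numbers as a stand-in for real GloVe or Word2Vec vectors:

```python
import numpy as np
import torch
import torch.nn as nn

vocab_size, embedding_dim = 1000, 300
# Stand-in for real pretrained vectors (random here, as in Example 3)
pretrained_weights = np.random.randn(vocab_size, embedding_dim).astype("float32")

# from_pretrained creates the layer directly from the tensor;
# freeze=True (the default) sets requires_grad=False on the weight
embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_weights), freeze=True)

print("Embedding layer trainable:", embedding.weight.requires_grad)  # False
print("Embedding matrix shape:", embedding.weight.shape)             # torch.Size([1000, 300])
```

from_pretrained also accepts padding_idx, max_norm, norm_type, scale_grad_by_freq and sparse, so the constructor options remain available.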

Example 4: Limiting Embedding Vector Norm

Use max_norm to constrain vector norms:

Example

import torch
import torch.nn as nn

# Limit maximum embedding vector norm to 1.0
embedding = nn.Embedding(1000, 64, max_norm=1.0)

# Input
input_indices = torch.tensor([1, 2, 3])

# Original norms of the rows that will be looked up (rows 1, 2, 3)
original_norm = embedding.weight.detach().norm(dim=1)[input_indices]
print("Original weight norms:", original_norm.tolist())

# Output vector norms after lookup (at most 1.0; the lookup also
# renormalizes these rows of embedding.weight in place)
output = embedding(input_indices)
output_norm = output.norm(dim=1)
print("Output vector norms:", output_norm.tolist())
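Because max_norm works by rewriting the stored rows, the weight matrix itself changes after a forward pass. This sketch, assuming the same 1000 x 64 layer with max_norm=1.0, compares row norms before and after one lookup: only the looked-up rows shrink.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(1000, 64, max_norm=1.0)
input_indices = torch.tensor([1, 2, 3])

before = embedding.weight.detach().norm(dim=1)

# The forward pass renormalizes the looked-up rows of embedding.weight in place
embedding(input_indices)

after = embedding.weight.detach().norm(dim=1)
print("Rows 1-3 before:", before[1:4].tolist())
print("Rows 1-3 after:", after[1:4].tolist())
print("Row 0 (not looked up) unchanged:", torch.equal(before[0], after[0]))  # True
```

This in-place update is worth keeping in mind if the same weight tensor is used elsewhere, for example when it is tied to an output layer.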

Example 5: Complete Text Classification Model

Text classification using embedding layer:

Example

import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super(TextClassifier, self).__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Classifier
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        # LSTM returns hidden with shape (num_layers, batch, hidden_dim)
        _, (hidden, _) = self.lstm(embedded)
        hidden = hidden[-1]  # Final hidden state of the last layer: (batch, hidden_dim)
        # Classification
        logits = self.classifier(hidden)  # (batch, num_classes)
        return logits

# Parameters
VOCAB_SIZE = 10000
EMBED_DIM = 128
HIDDEN_DIM = 256

model = TextClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)

# Input: batch=4, sequence length=50
input_ids = torch.randint(1, VOCAB_SIZE, (4, 50))
output = model(input_ids)

print("Model structure:")
print(model)
print("Input shape:", input_ids.shape)
print("Output shape:", output.shape)
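A sketch of one training step for this classifier. It repeats the architecture so the block runs on its own, taking the last layer's hidden state with hidden[-1]. The random token IDs and labels, the Adam optimizer and the learning rate of 1e-3 are placeholder assumptions standing in for a real labeled dataset and tuned settings.

```python
import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    # Same architecture as Example 5, classifying from the last layer's final hidden state
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        return self.classifier(hidden[-1])

torch.manual_seed(0)
model = TextClassifier(vocab_size=10000, embed_dim=128, hidden_dim=256)

# Random token ids and labels stand in for a real dataset
input_ids = torch.randint(1, 10000, (4, 50))
labels = torch.randint(0, 2, (4,))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
before = model.embedding.weight.detach().clone()

# One training step: forward, loss, backward, update
logits = model(input_ids)
loss = loss_fn(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("Logits shape:", logits.shape)  # torch.Size([4, 2])
print("Loss:", loss.item())
print("Embedding weights updated:", not torch.equal(before, model.embedding.weight.detach()))
```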

Difference Between Embedding and EmbeddingBag

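| Type | Input | Output | Typical use |
|---|---|---|---|
| nn.Embedding | Word indices | One vector per index (a sequence of word vectors) | Sequence models such as LSTMs and Transformers |
| nn.EmbeddingBag | Word indices + offsets marking where each bag starts | One pooled vector per bag (sum, mean, or max) | Bag-of-words text classification; more time- and memory-efficient than Embedding followed by pooling |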
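A minimal sketch of the offsets format, using two arbitrary bags, [1, 2, 4] and [5, 4], packed into one flat tensor. With mode="mean" (the default), each output row should equal the mean of that bag's rows in weight, which the last lines check by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="mean")

# Two bags packed into one flat tensor: [1, 2, 4] and [5, 4]
flat_indices = torch.tensor([1, 2, 4, 5, 4])
offsets = torch.tensor([0, 3])  # start position of each bag in flat_indices

pooled = bag(flat_indices, offsets)
print(pooled.shape)  # torch.Size([2, 3])

# Same result as looking up the rows and averaging them per bag
manual = torch.stack([
    bag.weight[[1, 2, 4]].mean(dim=0),
    bag.weight[[5, 4]].mean(dim=0),
])
print(torch.allclose(pooled, manual, atol=1e-6))  # True
```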

Frequently Asked Questions

Q1: How to choose embedding dimension?

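There is no fixed rule. Larger dimensions can capture more information but add parameters and need more data to train. Common starting points are:

  • Small datasets: 50–100 dimensions
  • Medium datasets: 100–300 dimensions
  • Large datasets: 300–500 dimensions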
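The cost of a larger dimension is easy to quantify: the table holds num_embeddings × embedding_dim parameters. This sketch assumes a 10,000-word vocabulary (as in Example 1) and float32 weights, at 4 bytes per parameter.

```python
import torch.nn as nn

counts = {}
for dim in (50, 100, 300, 500):
    layer = nn.Embedding(num_embeddings=10000, embedding_dim=dim)
    # The embedding matrix is the layer's only parameter tensor
    counts[dim] = sum(p.numel() for p in layer.parameters())
    megabytes = counts[dim] * layer.weight.element_size() / 1e6
    print(f"dim={dim}: {counts[dim]:,} parameters, about {megabytes:.0f} MB in float32")
```

For example, dim=300 gives 10,000 × 300 = 3,000,000 parameters, about 12 MB. The cost grows linearly with both vocabulary size and dimension.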

Q2: What is the purpose of padding_idx?

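The embedding vector at padding_idx is initialized to zeros, and that row receives no gradient during backpropagation. Padding tokens therefore keep a fixed vector (zero by default) that is never updated by training.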
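A short check of both properties, using an arbitrary padded sequence [3, 5, 0, 0] with padding_idx=0: the padding row starts at zero, and after a backward pass its gradient is still zero while real tokens receive gradient.

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
print(embedding.weight[0])  # initialized to all zeros

batch = torch.tensor([[3, 5, 0, 0]])  # trailing 0s are padding
embedding(batch).sum().backward()

# The padding row receives no gradient, so an optimizer step leaves it unchanged
print(embedding.weight.grad[0])  # tensor([0., 0., 0., 0.])
print(embedding.weight.grad[3])  # tensor([1., 1., 1., 1.])
```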

Q3: When to freeze the embedding layer?

Usually when using pretrained word vectors. A common approach is to keep the embedding layer frozen while the rest of the model trains, then unfreeze it to fine-tune the vectors on the task.
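A sketch of the freeze-then-unfreeze schedule on a hypothetical toy model (mean-pooled embeddings followed by a Linear layer). The random "pretrained" vectors and data, the SGD learning rate and the single step per phase are all placeholder assumptions. All parameters are registered with the optimizer up front; while frozen, the embedding simply receives no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pretrained = torch.randn(100, 16)  # stand-in for real pretrained vectors
embedding = nn.Embedding.from_pretrained(pretrained, freeze=True)
classifier = nn.Linear(16, 2)  # hypothetical downstream model

optimizer = torch.optim.SGD(
    list(embedding.parameters()) + list(classifier.parameters()), lr=0.1
)

def train_step(ids, labels):
    # Mean-pool the token vectors, then classify
    logits = classifier(embedding(ids).mean(dim=1))
    loss = nn.functional.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

ids = torch.randint(0, 100, (8, 5))
labels = torch.randint(0, 2, (8,))

# Phase 1: embedding frozen, only the classifier learns
train_step(ids, labels)
print("Frozen phase, embedding grad:", embedding.weight.grad)  # None

# Phase 2: unfreeze to fine-tune the embeddings as well
embedding.weight.requires_grad_(True)
before = embedding.weight.detach().clone()
train_step(ids, labels)
print("Embedding changed after unfreezing:", not torch.equal(before, embedding.weight.detach()))
```

In practice the fine-tuning phase often uses a smaller learning rate for the embedding than for the rest of the model, for example through a separate optimizer parameter group.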


Usage Scenarios

Main application scenarios of nn.Embedding include:

  • Word vector representation: Converting words into dense vectors
  • Text classification: As input layer for NLP models
  • Sequence models: Input for LSTM, GRU
  • Recommendation systems: Embedding representations for users and items
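For the recommendation case, a common pattern is one embedding table for users and one for items, scored by a dot product, as in matrix factorization. The sketch below is a hypothetical model: the class name, table sizes and dimension of 32 are illustration choices, and the training loss (for example MSE on ratings) is left out.

```python
import torch
import torch.nn as nn

class DotProductRecommender(nn.Module):
    """Hypothetical scorer: score(user, item) = <user vector, item vector>."""

    def __init__(self, num_users, num_items, dim=32):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, dim)
        self.item_emb = nn.Embedding(num_items, dim)

    def forward(self, user_ids, item_ids):
        # (batch, dim) * (batch, dim), summed over dim -> (batch,)
        return (self.user_emb(user_ids) * self.item_emb(item_ids)).sum(dim=1)

torch.manual_seed(0)
model = DotProductRecommender(num_users=500, num_items=2000)
users = torch.tensor([0, 0, 42])
items = torch.tensor([10, 11, 1999])
scores = model(users, items)
print(scores.shape)  # torch.Size([3])
```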

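Note: embedding.weight is an nn.Parameter, so it is included in model.parameters() and updated by the optimizer like any other weight. With sparse=True, only a few optimizers support the resulting sparse gradients, such as optim.SGD, optim.SparseAdam, and optim.Adagrad.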
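A small sketch of what an optimizer step does to the table, assuming plain SGD (no momentum or weight decay) and an arbitrary loss on indices 1 and 4. Only the rows that were looked up receive a nonzero gradient, so only those rows change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=6, embedding_dim=3)
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

before = embedding.weight.detach().clone()

# Arbitrary loss that touches only rows 1 and 4
loss = embedding(torch.tensor([1, 4])).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

changed = (embedding.weight.detach() != before).any(dim=1)
print(changed.tolist())  # [False, True, False, False, True, False]
```

With optimizers that keep per-parameter state, such as Adam with weight decay or momentum, rows that were not looked up in the current batch can still move, which is one reason sparse=True is paired with sparse-aware optimizers.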

