## PyTorch torch.nn.GRU Module
`torch.nn.GRU` is the Gated Recurrent Unit (GRU) module in PyTorch.
The GRU is a simplified variant of the Long Short-Term Memory (LSTM) network. It has fewer parameters, is usually cheaper to compute, and achieves comparable performance on many sequence modeling tasks.
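
To make the gating concrete, the sketch below recomputes a single GRU time step by hand and compares it with `nn.GRU`. It assumes the gate equations and the reset/update/candidate weight ordering described in the PyTorch documentation; the tensor sizes and variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 4, 3
gru = nn.GRU(input_size, hidden_size)  # single layer, unidirectional

x = torch.randn(2, input_size)         # one time step for a batch of 2
h_prev = torch.randn(2, hidden_size)   # previous hidden state

# PyTorch stacks the reset (r), update (z), and candidate (n) weights row-wise.
w_ir, w_iz, w_in = gru.weight_ih_l0.chunk(3, dim=0)
w_hr, w_hz, w_hn = gru.weight_hh_l0.chunk(3, dim=0)
b_ir, b_iz, b_in = gru.bias_ih_l0.chunk(3)
b_hr, b_hz, b_hn = gru.bias_hh_l0.chunk(3)

with torch.no_grad():
    r = torch.sigmoid(x @ w_ir.T + b_ir + h_prev @ w_hr.T + b_hr)    # reset gate
    z = torch.sigmoid(x @ w_iz.T + b_iz + h_prev @ w_hz.T + b_hz)    # update gate
    n = torch.tanh(x @ w_in.T + b_in + r * (h_prev @ w_hn.T + b_hn)) # candidate state
    h_manual = (1 - z) * n + z * h_prev                              # new hidden state

    # Same step through the module: input (seq=1, batch, feature), h0 (1, batch, hidden)
    _, h_n = gru(x.unsqueeze(0), h_prev.unsqueeze(0))

matches = torch.allclose(h_manual, h_n[0], atol=1e-5)
print("Manual step matches nn.GRU:", matches)
```

The update gate `z` interpolates between the previous hidden state and the candidate state, which is the role an LSTM splits across separate gates and a cell state.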
---
### Function Definition
```python
torch.nn.GRU(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False)
```
### Key Parameters
* **`input_size`**: The number of expected features in the input `x`.
* **`hidden_size`**: The number of features in the hidden state `h`.
* **`num_layers`**: Number of recurrent layers (default: `1`).
* **`bias`**: If `False`, the layer does not use bias weights (default: `True`).
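* **`batch_first`**: If `True`, the input and output tensors are provided as `(batch, seq, feature)` instead of `(seq, batch, feature)`. This does not apply to the hidden state, which is always `(num_layers * num_directions, batch, hidden_size)` (default: `False`).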
* **`dropout`**: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to `dropout` (default: `0`).
* **`bidirectional`**: If `True`, becomes a bidirectional GRU (default: `False`).
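
As a quick illustration of how these parameters interact, the following sketch builds a stacked, bidirectional GRU with the default `batch_first=False` and inspects the resulting shapes. The sizes are arbitrary values chosen for the example.

```python
import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size = 5, 3, 8, 16
num_layers = 2

# dropout is applied between the two layers; bidirectional doubles the directions
gru = nn.GRU(input_size, hidden_size, num_layers=num_layers,
             dropout=0.1, bidirectional=True)

x = torch.randn(seq_len, batch, input_size)  # (seq, batch, feature)
# Optional initial hidden state: (num_layers * num_directions, batch, hidden_size)
h0 = torch.zeros(num_layers * 2, batch, hidden_size)

output, h_n = gru(x, h0)
output_shape = tuple(output.shape)
h_n_shape = tuple(h_n.shape)
print(output_shape)  # (5, 3, 32): last dim is 2 * hidden_size
print(h_n_shape)     # (4, 3, 16): one entry per layer and direction
```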
---
## Code Examples
### Example 1: Basic Usage
This example demonstrates how to initialize a basic 2-layer GRU and pass a batch of sequence data through it.
```python
import torch
import torch.nn as nn
# Initialize GRU: Input size = 256, Hidden size = 256, 2 Layers
gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)
# Input tensor: batch_size=4, sequence_length=10, input_size=256
x = torch.randn(4, 10, 256)
# Forward pass
output, hidden = gru(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
print("Hidden state shape:", hidden.shape)
```
**Output:**
```text
Input shape: torch.Size([4, 10, 256])
Output shape: torch.Size([4, 10, 256])
Hidden state shape: torch.Size([2, 4, 256])
```
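
The two return values are related: `output` holds the top layer's hidden state at every time step, while `hidden` holds the final state of every layer. For a unidirectional GRU, a small check along these lines (reusing the sizes from the example above) should confirm that the last time step of `output` equals the last layer of `hidden`.

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 256)

with torch.no_grad():
    output, hidden = gru(x)

last_step = output[:, -1, :]   # top layer at the final time step: (4, 256)
top_layer_final = hidden[-1]   # final hidden state of the last layer: (4, 256)
same = torch.allclose(last_step, top_layer_final)
print("output[:, -1] equals hidden[-1]:", same)
```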
---
### Example 2: Performance Comparison (LSTM vs. GRU)
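Because a GRU has fewer gates than an LSTM, it performs less computation per time step and is often faster to train. The following script times 100 forward passes through each module; the absolute numbers depend on your hardware and backend, so treat them as a rough comparison.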
```python
import torch
import torch.nn as nn
import time
# Configure LSTM and GRU with identical dimensions
lstm = nn.LSTM(256, 256, 1, batch_first=True)
gru = nn.GRU(256, 256, 1, batch_first=True)
# Input tensor: batch_size=32, sequence_length=100, input_size=256
x = torch.randn(32, 100, 256)
# Performance benchmark: time 100 forward passes per module
for model, name in [(lstm, "LSTM"), (gru, "GRU")]:
    start = time.time()
    for _ in range(100):
        _ = model(x)
    print(f"{name} execution time: {time.time() - start:.3f}s")
```
---
### Example 3: Text Classification with Bidirectional GRU
This example shows how to build a text classifier using a bidirectional GRU. We concatenate the final hidden states from both directions to feed into a fully connected layer.
```python
import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Bidirectional GRU
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        # The linear layer input dimension is doubled because of bidirectionality
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)    # (batch, seq, embed_dim)
        _, hidden = self.gru(embedded)  # hidden: (2, batch, hidden_dim)
        # Concatenate the final hidden states of the forward and backward passes:
        # hidden[-2] is the last forward hidden state, hidden[-1] is the last backward hidden state
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, hidden_dim * 2)
        return self.fc(hidden)

# Instantiate model
model = GRUClassifier(vocab_size=10000, embed_dim=128, hidden_dim=128, num_classes=2)
# Dummy input: batch_size=8, sequence_length=50
x = torch.randint(0, 10000, (8, 50))
output = model(x)
print("Input shape:", x.shape, "-> Output shape:", output.shape)
```
**Output:**
```text
Input shape: torch.Size([8, 50]) -> Output shape: torch.Size([8, 2])
```
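
In a bidirectional GRU, the returned hidden state stacks the forward direction's final state and the backward direction's final state along its first dimension. The sketch below, using small illustrative sizes, checks that layout against `output`: the forward state should match the last time step of the first half of the features, and the backward state should match the first time step of the second half.

```python
import torch
import torch.nn as nn

hidden_dim = 6
gru = nn.GRU(input_size=4, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
x = torch.randn(2, 7, 4)  # batch_size=2, sequence_length=7, input_size=4

with torch.no_grad():
    output, hidden = gru(x)  # output: (2, 7, 12), hidden: (2, 2, 6)

# Forward direction finishes at the last time step
forward_final = output[:, -1, :hidden_dim]
# Backward direction finishes at the first time step
backward_final = output[:, 0, hidden_dim:]

fwd_ok = torch.allclose(hidden[-2], forward_final)
bwd_ok = torch.allclose(hidden[-1], backward_final)
print("forward matches:", fwd_ok, "| backward matches:", bwd_ok)
```

This is why taking `output[:, -1, :]` alone is usually not what you want for a bidirectional classifier: its backward half has only seen the last token.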
---
## LSTM vs. GRU Comparison
| Feature | LSTM | GRU |
| :--- | :--- | :--- |
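| **Parameter Count** | Higher (4 weight blocks per layer) | Lower (3 weight blocks per layer, about 25% fewer at equal sizes) |
| **Gating Mechanism** | 3 gates (Input, Forget, Output) plus a separate cell state | 2 gates (Reset, Update), hidden state only |
| **Computation Speed** | Typically slower | Typically faster |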
| **Memory Footprint** | Larger | Smaller |
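
The parameter-count row can be checked directly. The following sketch counts the learnable parameters of an LSTM and a GRU with identical dimensions; `count_params` is a small helper defined here for illustration, not a PyTorch function.

```python
import torch.nn as nn

def count_params(module):
    """Total number of learnable parameters in a module (illustrative helper)."""
    return sum(p.numel() for p in module.parameters())

lstm = nn.LSTM(256, 256, 1, batch_first=True)
gru = nn.GRU(256, 256, 1, batch_first=True)

lstm_params = count_params(lstm)
gru_params = count_params(gru)
ratio = gru_params / lstm_params

print("LSTM parameters:", lstm_params)  # 526336
print("GRU parameters:", gru_params)    # 394752
print("GRU / LSTM ratio:", ratio)       # 0.75
```

The ratio is exactly 3/4 here because the LSTM stores four gate-sized weight and bias blocks per layer while the GRU stores three.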
---
## Common Use Cases
* **Sequence Modeling**: Natural Language Processing (NLP) tasks such as text classification and sentiment analysis, as well as speech and audio processing.
* **Rapid Prototyping**: Ideal when computational resources or memory are limited, and you need to train recurrent networks quickly.
* **Encoder-Decoder Architectures**: Frequently used as the encoder or decoder backbone in machine translation and sequence-to-sequence models.
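
As a rough illustration of the encoder-decoder use case, the sketch below wires two GRUs together: the encoder's final hidden state initializes the decoder, which is fed the target sequence (teacher forcing). The class name `Seq2SeqGRU`, the vocabulary sizes, and the dimensions are hypothetical choices for this example, and the model omits attention and inference-time decoding.

```python
import torch
import torch.nn as nn

class Seq2SeqGRU(nn.Module):
    """Hypothetical minimal encoder-decoder; names and sizes are illustrative."""

    def __init__(self, src_vocab, tgt_vocab, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, embed_dim)
        self.encoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, tgt_vocab)

    def forward(self, src, tgt_in):
        # Encode the source; keep only the final hidden state as the context
        _, context = self.encoder(self.src_embed(src))  # (1, batch, hidden_dim)
        # Decode the target tokens, starting from the encoder's context
        dec_out, _ = self.decoder(self.tgt_embed(tgt_in), context)
        return self.out(dec_out)  # (batch, tgt_len, tgt_vocab)

model = Seq2SeqGRU(src_vocab=100, tgt_vocab=120)
src = torch.randint(0, 100, (4, 9))     # batch_size=4, source length=9
tgt_in = torch.randint(0, 120, (4, 7))  # batch_size=4, target length=7
logits = model(src, tgt_in)
logits_shape = tuple(logits.shape)
print(logits_shape)  # (4, 7, 120)
```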