
# PyTorch `torch.nn.RNN`

The PyTorch `torch.nn.RNN` class implements a multi-layer Elman recurrent neural network (RNN) with either a $\tanh$ or a ReLU non-linearity. Unlike feedforward neural networks, RNNs maintain an internal hidden state that acts as a memory, which lets them process input sequences of arbitrary length. This makes them well suited to sequential data such as time-series forecasting, natural language processing (NLP), and speech recognition.

---

## Introduction

An Elman RNN applies a recurrence relation to a sequence of input vectors. For each element of the input sequence, the RNN layer computes the next hidden state $h_t$ from the current input $x_t$ and the previous hidden state $h_{t-1}$:

$$h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$$

Where:

* $h_t$ is the hidden state at time $t$.
* $x_t$ is the input at time $t$. In a stacked RNN, the input to every layer after the first is the hidden state $h_t$ produced by the layer below.
* $h_{t-1}$ is the hidden state of the previous time step $t-1$ (or the initial hidden state $h_0$ at $t=0$).
* $W_{ih}$ and $W_{hh}$ are the learnable input-to-hidden and hidden-to-hidden weights.
* $b_{ih}$ and $b_{hh}$ are the corresponding bias vectors.

With `nonlinearity='relu'`, $\tanh$ in the equation above is replaced by $\text{ReLU}$.

While more advanced architectures such as LSTMs (`torch.nn.LSTM`) and GRUs (`torch.nn.GRU`) are often preferred in practice because they mitigate the vanishing gradient problem, understanding `torch.nn.RNN` is fundamental to sequence modeling in PyTorch.

---

## Syntax and Parameters

### Initialization Signature

```python
class torch.nn.RNN(*args, **kwargs)
```

### Key Constructor Parameters

| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `input_size` | `int` | *Required* | The number of expected features in the input $x$. |
| `hidden_size` | `int` | *Required* | The number of features in the hidden state $h$. |
| `num_layers` | `int` | `1` | Number of recurrent layers stacked on top of each other. |
| `nonlinearity` | `str` | `'tanh'` | The non-linearity to use. Can be either `'tanh'` or `'relu'`. |
| `bias` | `bool` | `True` | If `False`, the layer does not use the bias weights $b_{ih}$ and $b_{hh}$. |
| `batch_first` | `bool` | `False` | If `True`, the input and output tensors are provided as `(batch, seq, feature)` instead of `(seq, batch, feature)`. Hidden states are not affected. |
| `dropout` | `float` | `0.0` | If non-zero, introduces a Dropout layer with dropout probability equal to `dropout` on the outputs of each RNN layer except the last. It therefore only has an effect when `num_layers > 1`. |
| `bidirectional` | `bool` | `False` | If `True`, becomes a bidirectional RNN (`num_directions = 2`). |

### Input and Output Shapes

Assuming `batch_first=False` (the default), where `num_directions` is 2 for a bidirectional RNN and 1 otherwise:

* **Inputs**:
  * `input` of shape `(seq_len, batch_size, input_size)`: the features of the input sequence.
  * `h_0` of shape `(num_layers * num_directions, batch_size, hidden_size)`: the initial hidden state for each element in the batch. Defaults to zeros if not provided.
* **Outputs**:
  * `output` of shape `(seq_len, batch_size, num_directions * hidden_size)`: the output features ($h_t$) from the last layer of the RNN, for each time step $t$.
  * `h_n` of shape `(num_layers * num_directions, batch_size, hidden_size)`: the final hidden state of every layer (and direction) for each element in the batch.

With `batch_first=True`, `input` and `output` become `(batch_size, seq_len, ...)`, while `h_0` and `h_n` keep the shapes above.

---

## Code Example

Below is a complete, self-contained example that initializes a `torch.nn.RNN` layer, passes data through it, and checks the output shapes.

```python
import torch
import torch.nn as nn

# 1. Define hyperparameters
batch_size = 3
seq_len = 5
input_size = 10   # e.g., word embedding dimension
hidden_size = 20  # Dimension of the hidden state
num_layers = 2    # Number of stacked RNN layers

# 2. Instantiate the RNN layer
# batch_first=True matches the common (batch, sequence, feature) data layout
rnn = nn.RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    nonlinearity='tanh',
)

# 3. Create dummy input data
# Shape: (batch_size, seq_len, input_size)
dummy_input = torch.randn(batch_size, seq_len, input_size)

# 4. Initialize the hidden state (optional)
# Shape: (num_layers * num_directions, batch_size, hidden_size)
# Since bidirectional=False, num_directions is 1.
h0 = torch.zeros(num_layers, batch_size, hidden_size)

# 5. Forward pass
# If h0 is not passed explicitly, PyTorch initializes it to zeros.
output, hn = rnn(dummy_input, h0)

# 6. Inspect output shapes
print("--- Shape Verification ---")
print(f"Input Shape: {dummy_input.shape}")
print(f"Output Shape: {output.shape}")  # Expected: (batch_size, seq_len, hidden_size)
print(f"h_n Shape: {hn.shape}")         # Expected: (num_layers, batch_size, hidden_size)

# For a unidirectional RNN, the last time step of the output equals
# the last layer's final hidden state: output[:, -1, :] == hn[-1, :, :]
last_step_output = output[:, -1, :]
last_layer_hidden = hn[-1, :, :]
assert torch.allclose(last_step_output, last_layer_hidden, atol=1e-6)
print("\nVerification successful: Last output step matches final hidden state.")
```

---

## Best Practices and Common Pitfalls

### 1. Watch Out for the `batch_first` Flag

By default, PyTorch recurrent layers expect inputs of shape `(seq_len, batch_size, input_size)`. This is a historical convention, commonly attributed to the sequence-major layout used by cuDNN's RNN kernels. However, many data pipelines, including the default `DataLoader` collation, produce tensors of shape `(batch_size, seq_len, input_size)`.

* **Tip**: Explicitly set `batch_first=True` if your data pipeline produces batch-first tensors; otherwise the layer may silently interpret the batch dimension as the time dimension. Note that `batch_first` **only** affects the input and output tensors; the hidden state tensors `h_0` and `h_n` always keep the shape `(num_layers * num_directions, batch_size, hidden_size)`.

### 2. The Vanishing and Exploding Gradient Problem

Standard RNNs struggle to learn long-term dependencies (as a rough rule of thumb, dependencies spanning more than 10–20 steps) because gradients tend to vanish or explode during backpropagation through time (BPTT).

* **Tip**: If you are training on long sequences, use `torch.nn.LSTM` or `torch.nn.GRU` instead of `torch.nn.RNN`.
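Swapping `nn.RNN` for `nn.GRU` is usually a one-line change because both return `(output, h_n)`. `nn.LSTM` differs in one important way: it also tracks a cell state, so its second return value is a tuple `(h_n, c_n)`. The sketch below uses arbitrary sizes chosen only for illustration and shows the two return structures side by side:

```python
import torch
import torch.nn as nn

# Arbitrary sizes chosen for illustration only
batch_size, seq_len, input_size, hidden_size, num_layers = 3, 5, 10, 20, 2
x = torch.randn(batch_size, seq_len, input_size)

# nn.GRU takes the same core arguments as nn.RNN (it has no `nonlinearity` option)
gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
gru_out, gru_hn = gru(x)  # same (output, h_n) structure as nn.RNN

# nn.LSTM also keeps a cell state, so its second return value is a tuple (h_n, c_n)
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
lstm_out, (lstm_hn, lstm_cn) = lstm(x)

print(gru_out.shape, gru_hn.shape)
# torch.Size([3, 5, 20]) torch.Size([2, 3, 20])
print(lstm_out.shape, lstm_hn.shape, lstm_cn.shape)
# torch.Size([3, 5, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
```

If you supply an initial state to `nn.LSTM` yourself, it must likewise be a tuple `(h_0, c_0)`; with the default settings, both tensors have shape `(num_layers * num_directions, batch_size, hidden_size)`.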
If you must use a standard RNN, apply gradient clipping with `torch.nn.utils.clip_grad_norm_` to prevent exploding gradients. Clipping does not help with vanishing gradients.

### 3. Reusing Hidden States Across Batches

When processing continuous sequences (such as long text documents split into chunks), you may want to pass the final hidden state `h_n` of the current batch as the initial hidden state `h_0` of the next batch.

* **Pitfall**: If you pass `h_n` to the next iteration without detaching it, autograd keeps it connected to the computation graphs of all previous batches. Because the previous batch's `backward()` call has already freed its graph, the next `backward()` typically fails with `RuntimeError: Trying to backward through the graph a second time`. Working around this with `retain_graph=True` instead makes memory use and step time grow with every batch, eventually causing an out-of-memory error.
* **Solution**: Always call `.detach()` on the hidden state if you are carrying it over to a new batch:

```python
# Carry over state without tracking history
h0 = hn.detach()
```
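Detaching carried-over state and clipping gradients usually appear together in a truncated-BPTT training loop. The following is a minimal sketch, assuming a toy next-step regression task on one long random sequence. The `nn.Linear` readout `head`, the chunk length, the learning rate, and `max_norm=1.0` are hypothetical, illustrative choices, not recommendations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy setup (illustrative): one long sequence split into fixed-length chunks
input_size, hidden_size, chunk_len, num_chunks = 4, 16, 25, 8
long_seq = torch.randn(1, chunk_len * num_chunks + 1, input_size)  # (batch=1, time, feature)

rnn = nn.RNN(input_size, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, input_size)  # hypothetical readout predicting the next step
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
loss_fn = nn.MSELoss()

hn = None  # None lets nn.RNN start from a zero hidden state
losses = []
for i in range(num_chunks):
    start = i * chunk_len
    inputs = long_seq[:, start:start + chunk_len]           # steps t .. t+chunk_len-1
    targets = long_seq[:, start + 1:start + chunk_len + 1]  # the same steps shifted by one

    # Detach so backpropagation stops at the chunk boundary (truncated BPTT)
    h0 = hn.detach() if hn is not None else None
    output, hn = rnn(inputs, h0)
    loss = loss_fn(head(output), targets)

    optimizer.zero_grad()
    loss.backward()
    # If the total L2 norm of all gradients exceeds max_norm, scale them down to max_norm
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    optimizer.step()
    losses.append(loss.item())

print(f"processed {len(losses)} chunks, final loss {losses[-1]:.4f}")
```

Removing the `.detach()` call in this loop reproduces the pitfall described above: the second `loss.backward()` tries to traverse the first chunk's already-freed graph.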