# PyTorch: `torch.nn.BatchNorm2d`
`torch.nn.BatchNorm2d` is a core PyTorch module used to apply **2D Batch Normalization** over a 4D input tensor (typically representing a batch of images or feature maps).
Batch Normalization is a highly effective technique used to accelerate the training of deep neural networks, stabilize optimization, and reduce sensitivity to network initialization.
---
## Introduction
In deep neural networks, the distribution of each layer's inputs changes during training as the parameters of the previous layers change. This phenomenon is known as **internal covariate shift**.
`BatchNorm2d` mitigates this issue by normalizing the activations of a convolutional layer across the mini-batch for each channel independently. For a given channel, it:
1. Calculates the mean and variance across the batch ($N$) and spatial dimensions ($H, W$).
2. Normalizes the features to have a mean of 0 and a variance of 1.
3. Applies a learnable affine transformation (scale $\gamma$ and shift $\beta$) to restore the representational power of the network.
### Mathematical Formula
For a 4D tensor of shape $(N, C, H, W)$, the normalization for each channel $c$ is calculated as:
$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \times \gamma + \beta$$
Where:
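* $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are the mean and variance computed per channel over the batch and spatial dimensions ($N \cdot H \cdot W$ values per channel). The variance used for normalization is the biased estimator (it divides by $N \cdot H \cdot W$).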
* $\epsilon$ (epsilon) is a small constant added to the denominator for numerical stability.
* $\gamma$ (gamma) and $\beta$ (beta) are learnable parameter vectors of size $C$.
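As a rough check of the formula above, the following sketch computes the normalization by hand and compares it with `nn.BatchNorm2d` in training mode. The tensor sizes and seed are arbitrary choices for illustration. At initialization $\gamma = 1$ and $\beta = 0$, so the affine step is an identity here, but it is written out to mirror the formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(8, 3, 5, 5)  # (N, C, H, W), illustrative sizes
eps = 1e-5

# Steps 1-2: per-channel statistics over batch and spatial dims (0, 2, 3).
# keepdim=True gives shape (1, C, 1, 1) so the stats broadcast against x.
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
x_hat = (x - mean) / torch.sqrt(var + eps)

# Step 3: affine transform with gamma and beta of shape (C,),
# set to the module's initial values (ones and zeros)
gamma = torch.ones(3)
beta = torch.zeros(3)
y_manual = x_hat * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)

bn = nn.BatchNorm2d(3, eps=eps)
bn.train()  # training mode: normalize with the current batch statistics
y_module = bn(x)

print(torch.allclose(y_manual, y_module, atol=1e-5))  # True
```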
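During **training**, the layer also maintains per-channel running estimates of the mean and variance, which are then used for normalization during **evaluation** (inference). With the default exponential moving average, each training step updates a running statistic $\hat{x}$ with the current batch statistic $x_t$ as:
$$\hat{x}_{\text{new}} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$$
Unlike the normalization step itself, the running variance is updated with the unbiased batch variance.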
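A minimal sketch of the running-statistics update, assuming the default exponential moving average: the running mean starts at 0 and the running variance at 1, and each training-mode forward pass blends in the current batch statistics with weight `momentum`. The loop recomputes the expected buffers by hand (using the unbiased variance, dividing by $n - 1$) and compares them with the module's `running_mean` and `running_var`. Batch shapes and the number of steps are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(2, momentum=0.1)
bn.train()

# Buffers start at mean 0 and variance 1
expected_mean = torch.zeros(2)
expected_var = torch.ones(2)

for step in range(3):
    x = torch.randn(4, 2, 3, 3) * 2 + 1
    bn(x)  # a training-mode forward pass updates the running buffers

    # Per-channel batch statistics; n = N * H * W values per channel
    batch_mean = x.mean(dim=(0, 2, 3))
    n = x.numel() // x.size(1)
    centered = x - batch_mean.view(1, -1, 1, 1)
    batch_var = (centered ** 2).sum(dim=(0, 2, 3)) / (n - 1)  # unbiased

    # Exponential moving average update
    expected_mean = (1 - bn.momentum) * expected_mean + bn.momentum * batch_mean
    expected_var = (1 - bn.momentum) * expected_var + bn.momentum * batch_var

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # True
print(bn.num_batches_tracked)  # tensor(3)
```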
---
## Syntax and Parameters
### Constructor Signature
```python
class torch.nn.BatchNorm2d(
    num_features,
    eps=1e-05,
    momentum=0.1,
    affine=True,
    track_running_stats=True,
    device=None,
    dtype=None
)
```
### Parameters
| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `num_features` | `int` | *Required* | The number of channels ($C$) in the input tensor. |
| `eps` | `float` | `1e-05` | A value added to the denominator for numerical stability. |
| `momentum` | `float` or `None` | `0.1` | The weight given to the current batch statistics when updating `running_mean` and `running_var`. Set to `None` to use a cumulative moving average (a simple average over all batches seen) instead. |
| `affine` | `bool` | `True` | When set to `True`, this module has learnable affine parameters ($\gamma$ and $\beta$). |
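| `track_running_stats` | `bool` | `True` | When set to `True`, this module tracks the running mean and variance. If `False`, the `running_mean` and `running_var` buffers are `None`, and the module uses batch statistics during both training and evaluation. |
| `device` | `torch.device` or `None` | `None` | Device on which the parameters and buffers are created. |
| `dtype` | `torch.dtype` or `None` | `None` | Data type of the parameters and buffers. |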
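The sketch below illustrates how three of the constructor flags from the table change the module's state. The channel count and tensor sizes are arbitrary. It checks that `affine=False` leaves no learnable `weight`/`bias`, that `track_running_stats=False` gives identical outputs in train and eval mode, and that `momentum=None` turns the running mean into a plain average of the batch means.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)

# affine=False: no learnable gamma (weight) or beta (bias)
bn_no_affine = nn.BatchNorm2d(3, affine=False)
print(bn_no_affine.weight, bn_no_affine.bias)  # None None

# track_running_stats=False: no running buffers, so batch statistics
# are used in both train and eval mode
bn_no_stats = nn.BatchNorm2d(3, track_running_stats=False)
print(bn_no_stats.running_mean)  # None
bn_no_stats.train()
y_train = bn_no_stats(x)
bn_no_stats.eval()
y_eval = bn_no_stats(x)
print(torch.allclose(y_train, y_eval))  # True

# momentum=None: cumulative moving average, i.e. a simple mean of batch stats
bn_cma = nn.BatchNorm2d(3, momentum=None)  # modules start in training mode
batches = [torch.randn(4, 3, 8, 8) for _ in range(5)]
for b in batches:
    bn_cma(b)
avg_mean = torch.stack([b.mean(dim=(0, 2, 3)) for b in batches]).mean(dim=0)
print(torch.allclose(bn_cma.running_mean, avg_mean, atol=1e-6))  # True
```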
### Input and Output Shapes
* **Input:** $(N, C, H, W)$ (a 4D tensor is required; unbatched $(C, H, W)$ input is not accepted)
* **Output:** $(N, C, H, W)$ (same shape as input)
Where:
* $N$ is the batch size.
* $C$ is the number of channels (must match `num_features`).
* $H$ is the height of the feature map.
* $W$ is the width of the feature map.
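Because the module expects a 4D tensor, a single image has to be given an explicit batch dimension. A small sketch, assuming a 3-channel 32×32 image chosen for illustration: the 3D tensor is rejected with a `ValueError` (the exact message text may vary between PyTorch versions), while `unsqueeze(0)` / `squeeze(0)` adds and removes a batch dimension of size 1.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=3)
bn.eval()  # inference on a single image uses the running statistics

image = torch.randn(3, 32, 32)  # one unbatched image: (C, H, W)

# A 3D tensor is rejected because BatchNorm2d expects (N, C, H, W)
try:
    bn(image)
    rejected = False
except ValueError as err:
    rejected = True
    print(f"Rejected: {err}")

# Add a batch dimension of size 1, normalize, then drop it again
out = bn(image.unsqueeze(0)).squeeze(0)
print(out.shape)  # torch.Size([3, 32, 32])
```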
---
## Code Example
The following complete, executable script demonstrates how to initialize `BatchNorm2d`, pass a dummy image batch through it, and inspect its learnable parameters and running statistics.
```python
import torch
import torch.nn as nn
# Set seed for reproducibility
torch.manual_seed(42)
# 1. Define dimensions: Batch size (N)=2, Channels (C)=3, Height (H)=4, Width (W)=4
batch_size = 2
channels = 3
height = 4
width = 4
# 2. Create a dummy input tensor representing a batch of feature maps
# Shape: (N, C, H, W)
input_tensor = torch.randn(batch_size, channels, height, width) * 10 + 5
print("--- Input Tensor (First Channel of First Batch Item) ---")
print(input_tensor[0, 0])
# 3. Initialize BatchNorm2d
# num_features must match the number of input channels (C)
bn = nn.BatchNorm2d(num_features=channels)
# 4. Set the module to training mode (this is the default state)
bn.train()
# 5. Forward pass through BatchNorm2d
output_tensor = bn(input_tensor)
print("\n--- Output Tensor (First Channel of First Batch Item) ---")
print(output_tensor[0, 0])
# 6. Inspect Learnable Parameters (Gamma and Beta)
# Since affine=True by default, weight (gamma) is initialized to 1s, and bias (beta) to 0s
print("\n--- Learnable Parameters ---")
print(f"Weight (Gamma) shape: {bn.weight.shape} -> values: {bn.weight.data}")
print(f"Bias (Beta) shape: {bn.bias.shape} -> values: {bn.bias.data}")
# 7. Inspect Running Statistics
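# These are exponential moving averages of the batch statistics, used in eval mode.
# After one pass: running_mean = 0.9 * 0 + 0.1 * batch_mean, running_var = 0.9 * 1 + 0.1 * batch_var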
print("\n--- Running Statistics (Updated after 1 forward pass) ---")
print(f"Running Mean: {bn.running_mean}")
print(f"Running Variance: {bn.running_var}")
# 8. Switch to Evaluation Mode
# In eval mode, the layer uses the running statistics instead of the batch statistics
bn.eval()
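eval_output = bn(input_tensor)
# After a single update the running stats are still far from this batch's statistics,
# so the eval output differs noticeably from the training-mode output above
print("\n--- Eval Output (First Channel of First Batch Item) ---")
print(eval_output[0, 0])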
```
---
## Best Practices and Common Pitfalls
### 1. Always Toggle `train()` and `eval()` Modes
`BatchNorm2d` behaves fundamentally differently during training and evaluation:
* **During Training (`model.train()`)**: It calculates mean and variance from the current mini-batch and updates its running statistics.
* **During Evaluation (`model.eval()`)**: It stops updating its running statistics and uses them to normalize validation/test data.
**Pitfall:** Forgetting to call `model.eval()` during validation or inference means each test batch is normalized with its own statistics, so a sample's prediction depends on which other samples share its batch and becomes unreliable for small batches. The layer also keeps updating its running statistics with test data, overwriting the estimates learned during training.
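The pitfall can be reproduced with a short sketch. The "training" loop on random data, the batch sizes, and the scaling of the second batch are all arbitrary illustrations. One test sample is placed in two batches with very different contents: in train mode its normalized output changes with its batch mates and the running mean is overwritten (note that `torch.no_grad()` does not prevent this), while in eval mode the output depends only on the sample itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)

# Simulate training so the running statistics hold meaningful estimates
bn.train()
for _ in range(20):
    bn(torch.randn(16, 3, 8, 8))

# One test sample placed in two batches with very different contents
sample = torch.randn(1, 3, 8, 8)
batch_a = torch.cat([sample, torch.randn(7, 3, 8, 8)])
batch_b = torch.cat([sample, torch.randn(7, 3, 8, 8) * 5 + 3])

# Pitfall: still in train mode during "inference"
stats_before = bn.running_mean.clone()
with torch.no_grad():  # no_grad does NOT stop running-stat updates
    train_a = bn(batch_a)[0]
    train_b = bn(batch_b)[0]
print(torch.allclose(train_a, train_b))            # False: depends on batch mates
print(torch.equal(stats_before, bn.running_mean))  # False: stats overwritten

# Correct: eval mode uses the stored running statistics
bn.eval()
with torch.no_grad():
    eval_a = bn(batch_a)[0]
    eval_b = bn(batch_b)[0]
print(torch.allclose(eval_a, eval_b))              # True
```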
### 2. Disable Bias in Preceding Convolutional Layers
When a 2D convolutional layer is immediately followed by a `BatchNorm2d` layer, the convolution's bias term ($\mathbf{b}$) becomes redundant.
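The first step of Batch Normalization subtracts the per-channel batch mean. A convolution bias adds the same constant to every position of a channel, so it shifts that channel's mean by exactly the same amount and is subtracted out again. BatchNorm's own learnable shift $\beta$ already provides a per-channel offset, so keeping the convolution bias only adds parameters and computation without changing the result.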
```python
import torch.nn as nn

# REDUNDANT: the conv bias is cancelled by BatchNorm's mean subtraction
model_step = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, bias=True),
    nn.BatchNorm2d(64)
)
# PREFERRED: no conv bias; BatchNorm's beta provides the per-channel shift
model_step = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, bias=False),  # Set bias=False
    nn.BatchNorm2d(64)
)
```
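To see the cancellation numerically, the sketch below builds two convolutions with identical kernels, one with a bias and one without, and feeds both through the same `BatchNorm2d` in training mode (so batch statistics are used). The bias is deliberately exaggerated with `uniform_(-5.0, 5.0)`, an arbitrary choice; the normalized outputs still agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv_with_bias = nn.Conv2d(3, 8, kernel_size=3, bias=True)
conv_no_bias = nn.Conv2d(3, 8, kernel_size=3, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # identical kernels
    conv_with_bias.bias.uniform_(-5.0, 5.0)           # exaggerate the bias

bn = nn.BatchNorm2d(8)
bn.train()  # normalize with batch statistics

x = torch.randn(4, 3, 16, 16)
with torch.no_grad():
    out_with_bias = bn(conv_with_bias(x))
    out_no_bias = bn(conv_no_bias(x))

# The per-channel bias shifts each channel's mean and is subtracted out
print(torch.allclose(out_with_bias, out_no_bias, atol=1e-4))  # True
```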
### 3. Batch Size Sensitivity
Batch Normalization relies on batch statistics to estimate the global dataset statistics. If your batch size is very small (e.g., $N < 4$ due to GPU memory constraints when training large models like segmentation networks), the per-batch mean and variance become noisy estimates, and they can differ substantially from the running statistics used at inference. This can degrade model performance.
* **Tip:** If you must train with very small batch sizes, consider alternative normalization techniques such as **Group Normalization** (`torch.nn.GroupNorm`) or **Layer Normalization** (`torch.nn.LayerNorm`).
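As one illustration of the tip, the sketch below swaps `nn.GroupNorm` into a conv block in place of `BatchNorm2d`. The choice of `num_groups=8` is an arbitrary example (it must divide the 64 channels). Because GroupNorm computes statistics per sample, the output for a sample is the same whether it is processed alone or inside a batch, which is why it is insensitive to batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same conv block shape as above, with GroupNorm replacing BatchNorm2d.
# num_groups=8 is an illustrative choice; it must divide num_channels (64).
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.GroupNorm(num_groups=8, num_channels=64),
    nn.ReLU(),
)
block.train()  # GroupNorm behaves the same in train and eval mode

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    out_batch = block(x)       # both samples together
    out_single = block(x[:1])  # first sample alone

# Statistics are computed per sample, so batch size does not matter
print(torch.allclose(out_batch[:1], out_single, atol=1e-5))  # True
```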