## PyTorch torch.unflatten
`torch.unflatten` is a PyTorch function that expands one dimension of a tensor into multiple dimensions. It reverses `torch.flatten` for the dimensions that were merged, which makes it useful when a flattened feature vector has to be turned back into a multi-dimensional layout, for example when moving from fully connected layers back to convolutional layers in a neural network.
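
As a rough illustration of that round trip, the sketch below flattens a batch of feature maps and then restores the original layout. The shapes (batch of 8, 16 channels, 5x5 spatial size) are arbitrary values chosen for this example.

```python
import torch

# A batch of 8 feature maps with 16 channels and 5x5 spatial size
# (shapes are assumptions made only for this illustration).
features = torch.randn(8, 16, 5, 5)

# Merge the channel and spatial dimensions (dims 1..3) into one vector
# per sample, as is commonly done before a fully connected layer.
flat = torch.flatten(features, start_dim=1)

# Split dimension 1 back into (channels, height, width).
restored = torch.unflatten(flat, 1, (16, 5, 5))

print(flat.shape)                       # torch.Size([8, 400])
print(restored.shape)                   # torch.Size([8, 16, 5, 5])
print(torch.equal(restored, features))  # True
```

Only dimension 1 is touched in both calls, so the batch dimension never has to be mentioned.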
---
### Function Definition
```python
torch.unflatten(input, dim, sizes) -> Tensor
```
#### Parameter Descriptions:
* **`input`** *(Tensor)*: The input tensor to be reshaped.
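* **`dim`** *(int or str)*: The index of the dimension to expand (negative indices count from the end), or its name when `input` is a named tensor.
* **`sizes`** *(Tuple or List or torch.Size)*: The target shape of the expanded dimension. The product of the elements in `sizes` must equal the size of the original dimension specified by `dim`. One element may be `-1`, in which case that size is inferred from the others.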
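
The following minimal sketch shows how `dim` and `sizes` interact, including a negative `dim` and a `-1` entry in `sizes`. The input shape `(2, 12, 5)` is an arbitrary choice for illustration.

```python
import torch

x = torch.zeros(2, 12, 5)

# Split dimension 1 (size 12) and let PyTorch infer the second size: 12 / 3 = 4.
a = torch.unflatten(x, 1, (3, -1))
print(a.shape)  # torch.Size([2, 3, 4, 5])

# A negative dim counts from the end: -1 is the last dimension (size 5).
b = torch.unflatten(x, -1, (5, 1))
print(b.shape)  # torch.Size([2, 12, 5, 1])

# sizes may also be given as a list.
c = torch.unflatten(x, 1, [2, 6])
print(c.shape)  # torch.Size([2, 2, 6, 5])
```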
---
## Code Examples
### Example 1: Unflattening a 1D Tensor into a 2D Matrix
In this example, we take a 1D tensor of size 12 and expand its single dimension (dimension 0) into a 2D shape of `(3, 4)`.
```python
import torch
# Create a 1D tensor
x = torch.arange(12)
# Unflatten dimension 0 into a 3x4 matrix
y = torch.unflatten(x, dim=0, sizes=(3, 4))
print("Shape of y:", y.shape)
print("Tensor y:\n", y)
```
**Output:**
```text
Shape of y: torch.Size([3, 4])
Tensor y:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
```
---
### Example 2: Unflattening into a 3D Tensor
You can also expand a single dimension into three or more dimensions, provided the total number of elements remains consistent.
```python
import torch
# Create a flat tensor with 24 elements
x = torch.randn(24)
# Unflatten dimension 0 into a 3D tensor of shape 2x3x4
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))
print("Shape of y:", y.shape)
```
**Output:**
```text
Shape of y: torch.Size([2, 3, 4])
```
---
### Example 3: Unflattening Named Dimensions
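PyTorch supports named tensors (a prototype feature that may emit a warning). If your tensor has named dimensions, you can specify the target dimension using its name instead of its index. In that case each new dimension also needs a name, so `sizes` is given as a sequence of `(name, size)` pairs. The form shown in the named tensor documentation is the tensor method `Tensor.unflatten`, which is used here; the names `'H'` and `'W'` are arbitrary.
```python
import torch
# Create a flat tensor
x = torch.randn(12)
# Assign the name 'N' to its only dimension
x = x.rename('N')
# Unflatten 'N' by name; each new dimension is a (name, size) pair
y = x.unflatten('N', (('H', 3), ('W', 4)))
print("Shape of y:", y.shape)
```
**Output:**
```text
Shape of y: torch.Size([3, 4])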
```
---
## Key Considerations
1. **Product Match Requirement**: The product of the dimensions specified in `sizes` must exactly match the size of the dimension being unflattened. For example, if `input.size(dim)` is `12`, the product of `sizes` must be `12` (e.g., `(3, 4)`, `(2, 6)`, or `(2, 3, 2)`). If they do not match, PyTorch will raise a `RuntimeError`.
2. **Memory Sharing**: Like `Tensor.view`, `torch.unflatten` returns a view of the original tensor rather than a copy. Modifying the returned tensor in place also modifies the original tensor, because both share the same underlying data storage.
3. **Alternative to `reshape` / `view`**: While you can achieve the same result using `tensor.view()` or `tensor.reshape()`, `torch.unflatten` is often more readable and less error-prone when you only want to split one *specific* dimension, because the remaining dimensions are left untouched and do not need to be spelled out or calculated.
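
The three points above can be checked with a short sketch. The tensor shape and the values written into it are arbitrary choices for illustration.

```python
import torch

x = torch.arange(24).reshape(2, 12)

# 1. A size mismatch raises RuntimeError: 5 * 2 != 12.
raised = False
try:
    torch.unflatten(x, 1, (5, 2))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# 2. The result shares storage with the input, so an in-place write
#    through y is visible in x.
y = torch.unflatten(x, 1, (3, 4))
y[0, 0, 0] = 100
print(x[0, 0].item())  # 100

# 3. The equivalent reshape call has to spell out every dimension,
#    including the batch dimension that is not being changed.
z = x.reshape(x.shape[0], 3, 4)
print(torch.equal(y, z))  # True
```

If an independent copy is needed, calling `.clone()` on the result avoids the shared storage.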