
## PyTorch torch.unflatten

`torch.unflatten` is a PyTorch function that expands a single dimension of a tensor into multiple dimensions. It is the inverse of `torch.flatten` and is useful for reshaping tensors in neural networks. A typical case is a fully connected layer whose flat output must be turned back into a multi-dimensional feature map before a convolutional layer.

---

### Function Definition

```python
torch.unflatten(input, dim, sizes) -> Tensor
```

#### Parameter Descriptions:

* **`input`** *(Tensor)*: The input tensor to be reshaped.
* **`dim`** *(int or str)*: The index of the dimension to expand, or its name if the tensor has named dimensions (see Example 3).
* **`sizes`** *(Tuple or List or torch.Size)*: The target shape of the expanded dimension. The product of the elements in `sizes` must equal the size of the original dimension specified by `dim`. One element may be `-1`, in which case that size is inferred from the others.

---

## Code Examples

### Example 1: Unflattening a 1D Tensor into a 2D Matrix

In this example, we take a 1D tensor of size 12 and expand its single dimension (dimension 0) into a 2D shape of `(3, 4)`.

```python
import torch

# Create a 1D tensor with the 12 elements 0, 1, ..., 11
x = torch.arange(12)

# Unflatten dimension 0 into a 3x4 matrix
y = torch.unflatten(x, dim=0, sizes=(3, 4))

print("Shape of y:", y.shape)
print("Tensor y:\n", y)
```

**Output:**

```text
Shape of y: torch.Size([3, 4])
Tensor y:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
```

---

### Example 2: Unflattening into a 3D Tensor

You can also expand a single dimension into three or more dimensions, provided the product of `sizes` still equals the size of the original dimension.

```python
import torch

# Create a flat tensor with 24 elements
x = torch.randn(24)

# Unflatten dimension 0 into a 3D tensor of shape 2x3x4 (2 * 3 * 4 = 24)
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))

print("Shape of y:", y.shape)
```

**Output:**

```text
Shape of y: torch.Size([2, 3, 4])
```

---

### Example 3: Unflattening Named Dimensions

PyTorch supports named tensors (a prototype feature). If your tensor has named dimensions, you can specify the target dimension by its name instead of its index.
When `dim` is given as a name, the new dimensions must be named as well, so the sizes are passed as `(name, size)` pairs. This named form is documented for the `Tensor.unflatten` method:

```python
import torch

# Create a flat tensor and name its only dimension 'N'
# (named tensors are a prototype feature and may emit a UserWarning)
x = torch.randn(12).rename('N')

# Unflatten 'N' by name; the new dimensions are named 'H' and 'W'
y = x.unflatten('N', (('H', 3), ('W', 4)))

print("Shape of y:", y.shape)
print("Names of y:", y.names)
```

**Output:**

```text
Shape of y: torch.Size([3, 4])
Names of y: ('H', 'W')
```

---

## Key Considerations

1. **Product Match Requirement**: The product of the dimensions specified in `sizes` must exactly match the size of the dimension being unflattened. For example, if `input.size(dim)` is `12`, the product of `sizes` must be `12` (e.g., `(3, 4)`, `(2, 6)`, or `(2, 3, 2)`). If they do not match, PyTorch raises a `RuntimeError`. A single `-1` in `sizes` lets PyTorch infer that entry, so `(3, -1)` is equivalent to `(3, 4)` here.
2. **Memory Sharing**: `torch.unflatten` returns a view of the input tensor rather than a copy, because splitting one dimension never requires moving data. Modifying the returned tensor therefore modifies the original tensor, since both share the same underlying storage. This matches the behavior of `Tensor.view`, and of `torch.reshape` when it can return a view.
3. **Alternative to `reshape` / `view`**: You can get the same result with `tensor.view()` or `tensor.reshape()`, but those require the full target shape, including every dimension you are not changing. `torch.unflatten` only needs the sizes of the dimension being split. This makes the intent clearer and avoids mistakes when the other dimensions, such as the batch size, vary.
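The three considerations above can be checked directly. The following is a minimal sketch that exercises each one in turn: `-1` inference, the `RuntimeError` on a size mismatch, view semantics, and equivalence with `reshape`. The variable names and the `(8, 12)` batch shape are illustrative choices, not part of the original examples.

```python
import torch

# 1. Product match: a single -1 entry is inferred (12 / 3 = 4)
x = torch.arange(12)
inferred = torch.unflatten(x, 0, (3, -1))
print("Inferred shape:", inferred.shape)  # -> torch.Size([3, 4])

# Sizes whose product is not 12 (here 5 * 3 = 15) raise a RuntimeError
mismatch_rejected = False
try:
    torch.unflatten(x, 0, (5, 3))
except RuntimeError:
    mismatch_rejected = True
print("Mismatch rejected:", mismatch_rejected)  # -> True

# 2. Memory sharing: the result is a view, so writes show up in x
view = torch.unflatten(x, 0, (3, 4))
view[0, 0] = 100
print("x[0] after writing through the view:", x[0].item())  # -> 100
print("Same storage:", view.data_ptr() == x.data_ptr())  # -> True

# 3. Compared with reshape: only the split dimension's sizes are needed
batch = torch.randn(8, 12)               # e.g. 8 samples with 12 features each
a = batch.unflatten(1, (3, 4))           # batch size never mentioned
b = batch.reshape(batch.shape[0], 3, 4)  # full target shape required
print("Shapes:", a.shape, b.shape)
print("Identical:", torch.equal(a, b))  # -> True
```

The write through `view` also changed `x`, which is the aliasing described in point 2. Call `.clone()` on the result if you need an independent copy.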