# PyTorch Tensor Concatenation: A Complete Guide to `torch.cat`
In PyTorch, combining multiple tensors along a specific dimension is one of the most fundamental operations in data preprocessing, feature engineering, and deep learning model architecture design (such as skip connections in U-Net or concatenating channel features in CNNs).
While many developers look for a `torch.concat` function, PyTorch's primary implementation for this operation is **`torch.cat`**. (Note: newer PyTorch versions also provide `torch.concat` as an alias of `torch.cat`, matching NumPy's naming convention.)
This comprehensive guide covers the syntax, parameters, practical code examples, and key considerations for using `torch.cat` effectively.
---
## Syntax and Usage
The `torch.cat` function concatenates a sequence of tensors along a given dimension. All tensors must have the same shape except in the concatenating dimension; the one exception is a 1-D empty tensor of shape `(0,)`, which is accepted and skipped.
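The empty-tensor exception can be seen in a minimal sketch like the one below. The variable names are illustrative only, and the behavior shown assumes the 1-D, shape `(0,)` case described in the PyTorch documentation.

```python
import torch

# A 1-D empty tensor of shape (0,) is skipped by torch.cat,
# so it can act as a neutral starting value.
empty = torch.empty(0)
block = torch.ones(2, 3)

result = torch.cat((empty, block), dim=0)
print(result.shape)  # torch.Size([2, 3])
```

An empty tensor with more dimensions, such as shape `(0, 3)`, is not skipped and has to follow the normal shape rule.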
### Syntax
```python
torch.cat(tensors, dim=0, *, out=None) -> Tensor
```
### Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| `tensors` | `Sequence` | A Python sequence (list or tuple) of PyTorch tensors to be concatenated. |
| `dim` | `int` (Optional) | The dimension (axis) along which the tensors will be concatenated. Defaults to `0`. |
| `out` | `Tensor` (Optional) | The output tensor to store the result. |
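
As a rough illustration of the three parameters, the sketch below passes a list instead of a tuple, uses a negative `dim` (which counts from the last dimension, a detail not covered in the table above), and writes into a preallocated `out` tensor. The names `a`, `b`, and `buffer` are placeholders chosen for this example.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# `tensors` may be a list or a tuple; `dim` defaults to 0
from_list = torch.cat([a, b])            # shape (4, 3)

# A negative dim counts from the end; for 2D inputs, dim=-1 means dim=1
last_dim = torch.cat((a, b), dim=-1)     # shape (2, 6)

# `out` stores the result in an existing tensor
buffer = torch.empty(4, 3)
torch.cat((a, b), dim=0, out=buffer)

print(from_list.shape, last_dim.shape, buffer.shape)
```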
### Key Difference: `torch.cat` vs. `torch.stack`
* **`torch.cat`** joins the given tensors along an **existing** dimension. The dimensionality (number of dimensions) of the output tensor remains the same as the inputs.
* **`torch.stack`** joins the sequence of tensors along a **new** dimension. The output tensor will have one more dimension than the input tensors.
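
The difference is easiest to see by applying both functions to the same pair of tensors. This is a small sketch with made-up values:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# cat: joins along the existing dimension, result stays 1D
catted = torch.cat((a, b), dim=0)
print(catted.shape)   # torch.Size([6])

# stack: creates a new leading dimension, result becomes 2D
stacked = torch.stack((a, b), dim=0)
print(stacked.shape)  # torch.Size([2, 3])
```

Note that `torch.stack` requires all inputs to have exactly the same shape, while `torch.cat` allows the sizes to differ along the concatenation dimension.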
---
## Code Examples
Let's explore how `torch.cat` works in practice with 1D, 2D, and multi-dimensional tensors.
### 1. Concatenating 1D Tensors (Vectors)
For 1D tensors, concatenation simply appends one vector to the end of another.
```python
import torch
# Create two 1D tensors
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6, 7])
# Concatenate along the only dimension (dim=0)
result = torch.cat((tensor1, tensor2), dim=0)
print("Tensor 1:", tensor1)
print("Tensor 2:", tensor2)
print("Concatenated Result:", result)
print("Result Shape:", result.shape)
```
**Output:**
```text
Tensor 1: tensor([1, 2, 3])
Tensor 2: tensor([4, 5, 6, 7])
Concatenated Result: tensor([1, 2, 3, 4, 5, 6, 7])
Result Shape: torch.Size([7])
```
---
### 2. Concatenating 2D Tensors (Matrices)
When working with 2D tensors, you can concatenate them either vertically (`dim=0`) or horizontally (`dim=1`).
```python
import torch
# Create two 2D tensors of shape (2, 3)
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.tensor([[7, 8, 9],
                  [10, 11, 12]])
# Concatenate vertically (along rows, dim=0)
# The resulting shape will be (2+2, 3) -> (4, 3)
result_dim0 = torch.cat((x, y), dim=0)
# Concatenate horizontally (along columns, dim=1)
# The resulting shape will be (2, 3+3) -> (2, 6)
result_dim1 = torch.cat((x, y), dim=1)
print("Concatenate along dim=0 (Vertical):\n", result_dim0)
print("Shape along dim=0:", result_dim0.shape)
print("\nConcatenate along dim=1 (Horizontal):\n", result_dim1)
print("Shape along dim=1:", result_dim1.shape)
```
**Output:**
```text
Concatenate along dim=0 (Vertical):
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
Shape along dim=0: torch.Size([4, 3])

Concatenate along dim=1 (Horizontal):
 tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]])
Shape along dim=1: torch.Size([2, 6])
```
---
### 3. Real-World Deep Learning Example: Concatenating Feature Maps
In convolutional neural networks (like DenseNet or U-Net), feature maps from different layers are often concatenated along the channel dimension (`dim=1` for `[Batch, Channel, Height, Width]` format).
```python
import torch
# Simulate two feature maps: [Batch Size, Channels, Height, Width]
# Feature map 1: Batch=2, Channels=16, H=28, W=28
features_1 = torch.randn(2, 16, 28, 28)
# Feature map 2: Batch=2, Channels=32, H=28, W=28
features_2 = torch.randn(2, 32, 28, 28)
# Concatenate along the channel dimension (dim=1)
combined_features = torch.cat((features_1, features_2), dim=1)
print("Features 1 Shape:", features_1.shape)
print("Features 2 Shape:", features_2.shape)
print("Combined Features Shape:", combined_features.shape)
```
**Output:**
```text
Features 1 Shape: torch.Size([2, 16, 28, 28])
Features 2 Shape: torch.Size([2, 32, 28, 28])
Combined Features Shape: torch.Size([2, 48, 28, 28])
```
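
In a real network, the concatenated tensor is usually passed to a layer whose input channel count equals the sum of the concatenated channels. The following is only a sketch of that pattern: the layer name `fuse`, the tensor names, and the choice of 24 output channels are assumptions made for illustration and are not taken from any specific architecture.

```python
import torch
import torch.nn as nn

# Stand-ins for a skip-connection tensor and an upsampled decoder tensor
skip = torch.randn(2, 16, 28, 28)
upsampled = torch.randn(2, 32, 28, 28)

# Channel-wise concatenation: 16 + 32 = 48 channels
merged = torch.cat((skip, upsampled), dim=1)

# The next layer must accept the summed channel count
fuse = nn.Conv2d(in_channels=merged.shape[1], out_channels=24,
                 kernel_size=3, padding=1)
out = fuse(merged)
print(out.shape)  # torch.Size([2, 24, 28, 28])
```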
---
## Considerations and Best Practices
To avoid runtime errors when using `torch.cat`, keep the following rules in mind:
### 1. Dimension Matching Rule
All dimensions of the input tensors **must match exactly**, except for the dimension along which you are concatenating.
* **Invalid Example:** Trying to concatenate a `(2, 3)` tensor and a `(3, 4)` tensor along `dim=0` will raise a `RuntimeError` because their second dimensions (3 and 4) do not match.
```python
import torch
t1 = torch.randn(2, 3)
t2 = torch.randn(3, 4)
# This raises a RuntimeError because the sizes differ
# in a dimension other than the concatenation dimension
try:
    torch.cat((t1, t2), dim=0)
except RuntimeError as e:
    print("Error:", e)
```
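
When the tensors come from different parts of a pipeline, it can help to check shapes before calling `torch.cat`. The helper below, `cat_shape_mismatches`, is a hypothetical function written for this guide and is not part of PyTorch. It is a minimal sketch that compares every tensor against the first one, assumes tensors with at least one dimension, and ignores the special case of 1-D empty tensors.

```python
import torch

def cat_shape_mismatches(tensors, dim=0):
    """Return (index, shape) pairs that cannot be concatenated with tensors[0] along `dim`."""
    ref = list(tensors[0].shape)
    dim = dim % len(ref)  # normalize a negative dim
    mismatches = []
    for i, t in enumerate(tensors[1:], start=1):
        shape = list(t.shape)
        same_rank = len(shape) == len(ref)
        # Every dimension except `dim` must match the reference shape
        sizes_ok = same_rank and all(
            s == r for d, (s, r) in enumerate(zip(shape, ref)) if d != dim
        )
        if not sizes_ok:
            mismatches.append((i, tuple(shape)))
    return mismatches

t1 = torch.randn(2, 3)
t2 = torch.randn(3, 4)   # second dimension differs: incompatible for dim=0
t3 = torch.randn(5, 3)   # only the first dimension differs: compatible for dim=0

print(cat_shape_mismatches([t1, t2, t3], dim=0))  # [(1, (3, 4))]
```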
### 2. Device Matching (CPU vs. GPU)
All tensors passed to `torch.cat` must reside on the **same device** (either all on CPU or all on the same GPU device). Attempting to concatenate a CPU tensor with a CUDA tensor will result in a runtime error.
```python
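import torch
t1 = torch.randn(2, 3)  # stays on the CPU
t2 = torch.randn(2, 3)
if torch.cuda.is_available():
    t2 = t2.to('cuda')
    # Incorrect: torch.cat((t1, t2)) raises a RuntimeError (tensors on different devices)
    # Correct: move everything to one device first
    result = torch.cat((t1.to(t2.device), t2))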
```
### 3. Memory Efficiency
`torch.cat` creates a **new tensor** in memory and copies the data from the source tensors. If you are concatenating very large tensors or performing concatenation inside a loop, it can lead to high memory consumption and slow performance.
* **Tip:** Instead of concatenating tensors incrementally inside a loop, append them to a standard Python list first, and call `torch.cat` once on the entire list outside the loop.
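
The tip above can be sketched as follows. The loop body uses `torch.full` as a stand-in for whatever per-batch computation produces the tensors; the names `chunks` and `all_outputs` are illustrative.

```python
import torch

chunks = []
for step in range(5):
    # Stand-in for a per-batch result of shape (4, 8)
    batch_output = torch.full((4, 8), float(step))
    chunks.append(batch_output)  # cheap: only stores a reference

# A single allocation and copy, instead of one per iteration
all_outputs = torch.cat(chunks, dim=0)
print(all_outputs.shape)  # torch.Size([20, 8])
```

Calling `torch.cat` inside the loop would instead re-copy all previously accumulated data on every iteration.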
---
## Summary
* Use **`torch.cat`** (or its alias `torch.concat`) to join tensors along an existing dimension.
* Ensure all non-concatenated dimensions are identical in size.
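* Ensure all tensors are on the same device (CPU/GPU). Matching data types is also advisable; recent PyTorch versions promote mixed dtypes to a common type rather than raising an error.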