

# PyTorch Tensor Concatenation: A Complete Guide to `torch.cat`

In PyTorch, combining multiple tensors along a specific dimension is one of the most fundamental operations. It shows up in data preprocessing, in feature engineering, and in model architecture design, for example in the skip connections of U-Net or when concatenating channel features in CNNs.

Many developers look for a `torch.concat` function, but PyTorch's primary implementation of this operation is **`torch.cat`**. (Note: newer PyTorch versions also provide `torch.concat` as an alias of `torch.cat`, and the two behave identically.)

This guide covers the syntax, parameters, practical code examples, and key considerations for using `torch.cat` effectively.

---

## Syntax and Usage

The `torch.cat` function concatenates a sequence of tensors along a given dimension. Each tensor must meet one of two conditions. It must either have the same shape as the others in every dimension except the concatenating one, or be an empty 1-D tensor of shape `(0,)`.

### Syntax

```python
torch.cat(tensors, dim=0, *, out=None) -> Tensor
```

### Parameters

| Parameter | Type | Description |
| :--- | :--- | :--- |
| `tensors` | `Sequence` | A Python sequence (list or tuple) of PyTorch tensors to concatenate. |
| `dim` | `int` (optional) | The dimension (axis) along which to concatenate. Negative values count from the last dimension. Defaults to `0`. |
| `out` | `Tensor` (optional, keyword-only) | A tensor to write the result into. |

### Key Difference: `torch.cat` vs. `torch.stack`

* **`torch.cat`** joins the given tensors along an **existing** dimension. The output has the same number of dimensions as the inputs.
* **`torch.stack`** joins the tensors along a **new** dimension. The output has one more dimension than the inputs, and all inputs must have exactly the same shape.

---

## Code Examples

Let's explore how `torch.cat` works in practice with 1D, 2D, and multi-dimensional tensors.

### 1. Concatenating 1D Tensors (Vectors)

For 1D tensors, concatenation simply appends one vector to the end of another.
```python
import torch

# Create two 1D tensors
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6, 7])

# Concatenate along the only dimension (dim=0)
result = torch.cat((tensor1, tensor2), dim=0)

print("Tensor 1:", tensor1)
print("Tensor 2:", tensor2)
print("Concatenated Result:", result)
print("Result Shape:", result.shape)
```

**Output:**

```text
Tensor 1: tensor([1, 2, 3])
Tensor 2: tensor([4, 5, 6, 7])
Concatenated Result: tensor([1, 2, 3, 4, 5, 6, 7])
Result Shape: torch.Size([7])
```

---

### 2. Concatenating 2D Tensors (Matrices)

When working with 2D tensors, you can concatenate them either vertically (`dim=0`) or horizontally (`dim=1`).

```python
import torch

# Create two 2D tensors of shape (2, 3)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])

# Concatenate vertically (add rows, dim=0)
# The resulting shape is (2+2, 3) -> (4, 3)
result_dim0 = torch.cat((x, y), dim=0)

# Concatenate horizontally (add columns, dim=1)
# The resulting shape is (2, 3+3) -> (2, 6)
result_dim1 = torch.cat((x, y), dim=1)

print("Concatenate along dim=0 (Vertical):\n", result_dim0)
print("Shape along dim=0:", result_dim0.shape)
print("\nConcatenate along dim=1 (Horizontal):\n", result_dim1)
print("Shape along dim=1:", result_dim1.shape)
```

**Output:**

```text
Concatenate along dim=0 (Vertical):
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
Shape along dim=0: torch.Size([4, 3])

Concatenate along dim=1 (Horizontal):
 tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]])
Shape along dim=1: torch.Size([2, 6])
```

---

### 3. Real-World Deep Learning Example: Concatenating Feature Maps

Convolutional neural networks such as DenseNet or U-Net often concatenate feature maps from different layers along the channel dimension. That is `dim=1` for the `[Batch, Channel, Height, Width]` layout. The batch and spatial dimensions must match, and only the channel counts may differ.
```python
import torch

# Simulate two feature maps: [Batch Size, Channels, Height, Width]
# Feature map 1: Batch=2, Channels=16, H=28, W=28
features_1 = torch.randn(2, 16, 28, 28)
# Feature map 2: Batch=2, Channels=32, H=28, W=28
features_2 = torch.randn(2, 32, 28, 28)

# Concatenate along the channel dimension (dim=1): 16 + 32 = 48 channels
combined_features = torch.cat((features_1, features_2), dim=1)

print("Features 1 Shape:", features_1.shape)
print("Features 2 Shape:", features_2.shape)
print("Combined Features Shape:", combined_features.shape)
```

**Output:**

```text
Features 1 Shape: torch.Size([2, 16, 28, 28])
Features 2 Shape: torch.Size([2, 32, 28, 28])
Combined Features Shape: torch.Size([2, 48, 28, 28])
```

---

## Considerations and Best Practices

To avoid runtime errors when using `torch.cat`, keep the following rules in mind.

### 1. Dimension Matching Rule

All dimensions of the input tensors **must match exactly**, except for the dimension along which you are concatenating.

* **Invalid Example:** Concatenating a `(2, 3)` tensor and a `(3, 4)` tensor along `dim=0` raises a `RuntimeError`, because their second dimensions (3 and 4) do not match.

```python
import torch

t1 = torch.randn(2, 3)
t2 = torch.randn(3, 4)

# This raises a RuntimeError with a message similar to:
# "Sizes of tensors must match except in dimension 0 ..."
try:
    torch.cat((t1, t2), dim=0)
except RuntimeError as e:
    print("Error:", e)
```

### 2. Device Matching (CPU vs. GPU)

All tensors passed to `torch.cat` must reside on the **same device**: either all on the CPU or all on the same GPU. Concatenating a CPU tensor with a CUDA tensor raises a `RuntimeError`.

```python
import torch

if torch.cuda.is_available():
    t1 = torch.randn(2, 3, device="cuda")
    t2 = torch.randn(2, 3)  # lives on the CPU
    # Incorrect: torch.cat((t1, t2)) raises a RuntimeError (device mismatch)
    # Correct: move all tensors to the same device first
    result = torch.cat((t1, t2.to(t1.device)))
```

### 3. Memory Efficiency

`torch.cat` creates a **new tensor** in memory and copies the data from the source tensors.
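Here is a minimal sketch of what the copy means in practice. The tensors and values are arbitrary and chosen only for illustration. Once `torch.cat` has returned, modifying an input does not affect the result. The optional `out=` argument from the parameter table writes into a preallocated tensor of the correct shape instead of allocating a new one.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0])

# torch.cat allocates a new tensor and copies the inputs into it
result = torch.cat((a, b))

# Modifying an input afterwards leaves the concatenated result unchanged
a[0] = 100.0
print(result.tolist())  # [1.0, 2.0, 3.0, 4.0, 5.0]

# The result has its own storage, separate from the inputs
print(result.data_ptr() == a.data_ptr())  # False

# With out=, the data is copied into an existing tensor of the right shape and dtype
buffer = torch.empty(5)
torch.cat((a, b), out=buffer)
print(buffer.tolist())  # [100.0, 2.0, 3.0, 4.0, 5.0]
```

Note that `out=` still copies the data. It only avoids allocating the destination tensor.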
Every call allocates a new tensor and copies all of its inputs. As a result, concatenating very large tensors, or calling `torch.cat` repeatedly inside a loop to grow a result, can lead to high memory consumption and slow performance.

* **Tip:** Instead of concatenating tensors incrementally inside a loop, append them to a standard Python list first, then call `torch.cat` once on the entire list after the loop.

---

## Summary

* Use **`torch.cat`** (or its alias `torch.concat`) to join tensors along an existing dimension.
* Ensure all non-concatenated dimensions are identical in size.
* Ensure all tensors are on the same device (CPU/GPU) and share the same data type.
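A short sketch can show the loop tip and the `torch.cat` vs. `torch.stack` distinction together. The loop is hypothetical: random `(4, 8)` batches stand in for whatever a real loop produces. The `torch.concat` line assumes a PyTorch version that provides that alias.

```python
import torch

torch.manual_seed(0)  # make the random example reproducible

# Hypothetical loop: each step yields a batch of 4 vectors of size 8
batches = []
for _ in range(10):
    batch = torch.randn(4, 8)
    batches.append(batch)  # appending to a list does not copy tensor data

# Preferred: one torch.cat call after the loop copies each row exactly once
all_rows = torch.cat(batches, dim=0)
print(all_rows.shape)  # torch.Size([40, 8])

# Anti-pattern: growing a tensor inside the loop re-copies every earlier row each step
grown = torch.empty(0, 8)
for batch in batches:
    grown = torch.cat((grown, batch), dim=0)
print(torch.equal(grown, all_rows))  # True

# For contrast, torch.stack on the same list adds a new leading dimension
stacked = torch.stack(batches, dim=0)
print(stacked.shape)  # torch.Size([10, 4, 8])

# torch.concat is an alias of torch.cat and produces the same tensor
print(torch.equal(torch.concat(batches, dim=0), all_rows))  # True
```

Both approaches produce the same values. The difference is cost. The in-loop version copies every previously accumulated row again on each iteration, so total copying grows roughly quadratically with the number of steps. The list version copies each row only once.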