PyTorch torch.cat Function


torch.cat is a PyTorch function that concatenates a sequence of tensors along a specified, already existing dimension and returns a single larger tensor.

It is one of the most common operations in deep learning, for example when concatenating feature maps or combining batches of data.

Function Definition

torch.cat(tensors, dim=0, *, out=None)

Parameters:

  • tensors (sequence of Tensors): The tensors to concatenate. All of them must have the same shape in every dimension except the concatenation dimension.
  • dim (int, optional): The dimension along which to concatenate. Defaults to 0.
  • out (Tensor, optional): Tensor to write the result into. It must be passed by keyword.

Return Value:

  • torch.Tensor: The concatenated tensor.
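The shape rule and the parameters above can be checked directly. The following is a minimal sketch rather than a complete reference. The names x, y, and buf are made up for this illustration. It shows a valid concatenation, a negative dim (which counts from the last dimension), the keyword-only out argument, and the RuntimeError raised when the shapes differ outside the concatenation dimension.

```python
import torch

x = torch.zeros(2, 3)
y = torch.ones(5, 3)   # differs from x only in dimension 0

# Allowed: the sizes differ only along the concatenation dimension
rows = torch.cat([x, y], dim=0)
print(rows.shape)  # torch.Size([7, 3])

# A negative dim counts from the end, so dim=-1 is the last dimension
cols = torch.cat([x, x], dim=-1)
print(cols.shape)  # torch.Size([2, 6])

# out is keyword-only: the result is written into a preallocated tensor
buf = torch.empty(7, 3)
torch.cat([x, y], dim=0, out=buf)

# Not allowed: x and y differ in dimension 0, which is not the
# concatenation dimension here, so PyTorch raises a RuntimeError
try:
    torch.cat([x, y], dim=1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```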

Usage Examples

Example 1: Concatenate along the first dimension

import torch

# Create two tensors
a = torch.randn(2, 3)
b = torch.randn(2, 3)

# Concatenate along the first dimension (rows)
c = torch.cat([a, b], dim=0)

print("Shape of a:", a.shape)
print("Shape of b:", b.shape)
print("Shape of c:", c.shape)
print(c)

Output (the values differ on each run because torch.randn draws random numbers):

Shape of a: torch.Size([2, 3])
Shape of b: torch.Size([2, 3])
Shape of c: torch.Size([4, 3])
tensor([[ 0.2532,  0.3643,  0.5341],
        [ 0.9578,  0.9086, -0.2847],
        [-0.7108, -0.0142,  0.7168],
        [-0.1542, -0.9841, -1.4945]])

Example 2: Concatenate along the second dimension

import torch

# Create two tensors with the same number of rows
a = torch.randn(2, 3)
b = torch.randn(2, 4)

# Concatenate along the second dimension (columns)
c = torch.cat([a, b], dim=1)

print("Shape of a:", a.shape)
print("Shape of b:", b.shape)
print("Shape of c:", c.shape)

Output:

Shape of a: torch.Size([2, 3])
Shape of b: torch.Size([2, 4])
Shape of c: torch.Size([2, 7])
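Concatenation can be undone with torch.split, which cuts a tensor into pieces of given sizes along one dimension. The sketch below assumes the same shapes as Example 2. The names a_back and b_back are chosen only for this illustration.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 4)
c = torch.cat([a, b], dim=1)   # shape (2, 7)

# Split c back into pieces of width 3 and 4 along the same dimension
a_back, b_back = torch.split(c, [3, 4], dim=1)

print(a_back.shape, b_back.shape)
print(torch.equal(a, a_back), torch.equal(b, b_back))  # True True
```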

Example 3: Concatenate multiple tensors

import torch

# Create several 1-D tensors
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
c = torch.tensor([5, 6])

# Concatenate all of them (dim defaults to 0)
result = torch.cat([a, b, c])

print(result)

Output:

tensor([1, 2, 3, 4, 5, 6])
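Because torch.cat accepts any sequence of tensors, a common pattern is to append per-batch results to a Python list and concatenate once at the end. The sketch below is only an illustration: fake_model is a hypothetical stand-in for a real network, and the data is made up.

```python
import torch

def fake_model(batch):
    # Hypothetical stand-in for a real model: doubles every value
    return batch * 2

data = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # 5 samples, 2 features

outputs = []
for batch in torch.split(data, 2, dim=0):   # batches of 2, 2, and 1 samples
    outputs.append(fake_model(batch))

# One concatenation at the end restores the full sample dimension
all_outputs = torch.cat(outputs, dim=0)
print(all_outputs.shape)  # torch.Size([5, 2])
```

Concatenating once after the loop avoids allocating a new, larger tensor on every iteration.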

Example 4: Concatenating features in neural networks

import torch

# Simulate feature maps from two different layers
feature1 = torch.randn(1, 64, 32, 32)   # Features from the first layer
feature2 = torch.randn(1, 128, 32, 32)  # Features from the second layer

# Concatenate the features along the channel dimension (dim=1)
combined = torch.cat([feature1, feature2], dim=1)

print("Feature 1 shape:", feature1.shape)
print("Feature 2 shape:", feature2.shape)
print("Shape after concatenation:", combined.shape)

Output:

Feature 1 shape: torch.Size([1, 64, 32, 32])
Feature 2 shape: torch.Size([1, 128, 32, 32])
Shape after concatenation: torch.Size([1, 192, 32, 32])

In neural networks, torch.cat is commonly used to fuse features from different layers, for example in U-Net-style skip connections and in some Feature Pyramid Network (FPN) variants.
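Feature maps from different layers often have different spatial sizes, in which case they cannot be concatenated directly. One possible approach, sketched below with made-up tensors named fine and coarse, is to resize one map with torch.nn.functional.interpolate so that only the channel dimension differs. Nearest-neighbor upsampling is an assumption made for this illustration, and real architectures may use other resizing methods.

```python
import torch
import torch.nn.functional as F

# Hypothetical feature maps at two resolutions
fine = torch.randn(1, 64, 32, 32)     # high resolution, fewer channels
coarse = torch.randn(1, 128, 16, 16)  # low resolution, more channels

# Bring the coarse map up to the spatial size of the fine map
coarse_up = F.interpolate(coarse, size=fine.shape[-2:], mode="nearest")

# Now only the channel dimension differs, so concatenating on dim=1 is valid
fused = torch.cat([fine, coarse_up], dim=1)
print(fused.shape)  # torch.Size([1, 192, 32, 32])
```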


Difference between torch.cat and torch.stack

  • torch.cat: Concatenates along an existing dimension. The output size along that dimension is the sum of the input sizes, and the number of dimensions stays the same.
  • torch.stack: Stacks along a new dimension. All inputs must have exactly the same shape, and the output has one more dimension than the inputs.

Example:

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# Apply cat and stack to the same inputs
cat_result = torch.cat([a, b], dim=0)
stack_result = torch.stack([a, b], dim=0)

print("cat result shape:", cat_result.shape)      # torch.Size([4, 3])
print("stack result shape:", stack_result.shape)  # torch.Size([2, 2, 3])
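The relationship between the two functions can be made concrete: stacking gives the same result as adding a size-1 dimension to each input and then concatenating along it. A minimal sketch, where via_cat is a name chosen only for this illustration:

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

stacked = torch.stack([a, b], dim=0)   # shape (2, 2, 3)

# Same result: add a size-1 dimension to each input, then concatenate along it
via_cat = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)

print(via_cat.shape)
print(torch.equal(stacked, via_cat))  # True
```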
