

## PyTorch torch.dsplit

`torch.dsplit` is a PyTorch function that splits a tensor into multiple sub-tensors depth-wise, that is, along the third dimension (`dim=2`).

It is equivalent to calling `torch.tensor_split(input, indices_or_sections, dim=2)`, except that when `indices_or_sections` is an integer, it must evenly divide the size of the third dimension. Because the split dimension is fixed, code that partitions multi-dimensional data (such as 3D or 4D tensors) depth-wise is easier to read than code that passes `dim=2` explicitly.

---

## Syntax

```python
torch.dsplit(input, indices_or_sections)
```

### Parameters

| Parameter | Type | Description |
| :--- | :--- | :--- |
| `input` | `Tensor` | The tensor to be split. Must have at least 3 dimensions. |
| `indices_or_sections` | `int`, or `list` / `tuple` of `int`s | **If an integer $N$:** splits the tensor into $N$ equal parts along the third dimension. The size of the third dimension must be divisible by $N$.<br>**If a list/tuple of indices:** splits the tensor at the specified indices along the third dimension. |

### Return Value

* Returns a tuple of sub-tensors split from `input`. Each sub-tensor is a view of `input`.

---

## Code Examples

### Example 1: Splitting a 3D Tensor into Equal Parts

In this example, we create a 3D tensor of shape `(2, 3, 4)` and split it into 2 equal parts along the depth dimension (`dim=2`).

```python
import torch

# Create a 3D tensor of shape (2, 3, 4)
x = torch.arange(24).reshape(2, 3, 4)
print("Original 3D Tensor:")
print(x)

# Split depth-wise into 2 equal parts
result = torch.dsplit(x, 2)
print("\nSplit into 2 parts along depth:")
for i, t in enumerate(result):
    print(f"  Chunk {i}:\n{t}")
```

**Output:**

```text
Original 3D Tensor:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

Split into 2 parts along depth:
  Chunk 0:
tensor([[[ 0,  1],
         [ 4,  5],
         [ 8,  9]],

        [[12, 13],
         [16, 17],
         [20, 21]]])
  Chunk 1:
tensor([[[ 2,  3],
         [ 6,  7],
         [10, 11]],

        [[14, 15],
         [18, 19],
         [22, 23]]])
```

---

### Example 2: Splitting at Specific Indices

You can also pass a list of indices to split the tensor at specific positions. For example, passing `[1, 3]` splits the third dimension at indices `1` and `3`, producing three chunks that correspond to the slices `x[:, :, :1]`, `x[:, :, 1:3]`, and `x[:, :, 3:]`.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Split at indices [1, 3]
result = torch.dsplit(x, [1, 3])
print("Split at indices [1, 3]:")
for i, t in enumerate(result):
    print(f"  Chunk {i}:\n{t}")
```

**Output:**

```text
Split at indices [1, 3]:
  Chunk 0:
tensor([[[ 0],
         [ 4],
         [ 8]],

        [[12],
         [16],
         [20]]])
  Chunk 1:
tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10]],

        [[13, 14],
         [17, 18],
         [21, 22]]])
  Chunk 2:
tensor([[[ 3],
         [ 7],
         [11]],

        [[15],
         [19],
         [23]]])
```

---

### Example 3: Splitting a 4D Tensor

`torch.dsplit` works on any tensor with 3 or more dimensions. Here, we apply it to a 4D tensor of shape `(2, 2, 4, 2)`. The split still occurs along the third dimension (index `2`), not the last one.
```python
import torch

# Create a 4D tensor of shape (2, 2, 4, 2)
y = torch.arange(32).reshape(2, 2, 4, 2)
print("Original 4D Tensor shape:", y.shape)

# Split into 2 parts along the third dimension
result = torch.dsplit(y, 2)
print("Split into 2 parts along the third dimension:")
for i, t in enumerate(result):
    print(f"  Chunk {i} shape: {t.shape}")
```

**Output:**

```text
Original 4D Tensor shape: torch.Size([2, 2, 4, 2])
Split into 2 parts along the third dimension:
  Chunk 0 shape: torch.Size([2, 2, 2, 2])
  Chunk 1 shape: torch.Size([2, 2, 2, 2])
```

---

## Considerations

1. **Dimensionality Requirement**: The input tensor must have at least 3 dimensions. Passing a tensor with fewer dimensions to `torch.dsplit` raises a `RuntimeError`. For lower-dimensional tensors, use `torch.hsplit` (at least 1 dimension) or `torch.vsplit` (at least 2 dimensions) instead.
2. **Divisibility**: If `indices_or_sections` is an integer $N$, the size of the third dimension must be exactly divisible by $N$; otherwise PyTorch raises a `RuntimeError`. To split a tensor into unequal parts, pass a list of indices instead, or use `torch.tensor_split(input, N, dim=2)`, which accepts uneven sections.
3. **Memory View**: Like other splitting functions in PyTorch, `torch.dsplit` returns views of the original tensor, so modifying a chunk in place also modifies the original tensor. Call `.clone()` on a chunk if you need an independent copy.
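
Because `torch.dsplit` is equivalent to `torch.tensor_split` with `dim=2`, the chunks from Example 2 can be checked against both `torch.tensor_split` and plain slicing of the third dimension. The following is a minimal consistency check rather than part of the original tutorial, and the variable names are arbitrary.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# dsplit at indices [1, 3] along the third dimension
d_chunks = torch.dsplit(x, [1, 3])

# The same split expressed with tensor_split and an explicit dim
t_chunks = torch.tensor_split(x, [1, 3], dim=2)

# The same split expressed as slices of the third dimension
s_chunks = (x[:, :, :1], x[:, :, 1:3], x[:, :, 3:])

for d, t, s in zip(d_chunks, t_chunks, s_chunks):
    # Each triple of chunks should match in shape and values
    assert torch.equal(d, t) and torch.equal(d, s)

print([tuple(c.shape) for c in d_chunks])  # [(2, 3, 1), (2, 3, 2), (2, 3, 1)]
```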
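
To see how `torch.dsplit` relates to the `torch.vsplit` and `torch.hsplit` functions mentioned in item 1, the sketch below applies all three to the 3D tensor from Example 1 and prints the chunk shapes. For tensors with at least 2 dimensions, `vsplit` cuts `dim=0` and `hsplit` cuts `dim=1`. The section counts were chosen only so that each dimension divides evenly.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)  # sizes: dim0=2, dim1=3, dim2=4

# Each directional split cuts a different dimension of a 3D tensor
v_chunks = torch.vsplit(x, 2)  # dim=0: 2 -> 1 + 1
h_chunks = torch.hsplit(x, 3)  # dim=1: 3 -> 1 + 1 + 1
d_chunks = torch.dsplit(x, 2)  # dim=2: 4 -> 2 + 2

for name, chunks in [("vsplit", v_chunks), ("hsplit", h_chunks), ("dsplit", d_chunks)]:
    print(name, [tuple(c.shape) for c in chunks])
# vsplit [(1, 3, 4), (1, 3, 4)]
# hsplit [(2, 1, 4), (2, 1, 4), (2, 1, 4)]
# dsplit [(2, 3, 2), (2, 3, 2)]
```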
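
The two error conditions in items 1 and 2 can be reproduced with a short script, which also shows `torch.tensor_split` producing uneven sections when the depth is not divisible by $N$. This is a sketch: `try_dsplit` is a hypothetical helper introduced only for this illustration, and it reports just the exception type, not the full message.

```python
import torch

def try_dsplit(t, sections):
    """Hypothetical helper: return the depth of each chunk, or the error type name."""
    try:
        return [c.shape[2] for c in torch.dsplit(t, sections)]
    except RuntimeError as err:
        return type(err).__name__

x3d = torch.arange(24).reshape(2, 3, 4)  # third dimension has size 4
x2d = torch.arange(12).reshape(3, 4)     # only 2 dimensions

print(try_dsplit(x3d, 2))  # [2, 2]
print(try_dsplit(x3d, 3))  # RuntimeError (4 is not divisible by 3)
print(try_dsplit(x2d, 2))  # RuntimeError (fewer than 3 dimensions)

# tensor_split accepts uneven sections: the first (4 % 3) chunks get one extra slice
uneven = torch.tensor_split(x3d, 3, dim=2)
print([c.shape[2] for c in uneven])  # [2, 1, 1]
```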
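
The view behavior in item 3 can be checked by writing into a chunk in place and then reading the original tensor. This is a minimal sketch. It uses `.clone()`, one standard way to get a chunk that no longer shares memory with the source tensor.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
first, second = torch.dsplit(x, 2)

# An in-place write through the view changes x as well
first[0, 0, 0] = 100
print(x[0, 0, 0].item())  # 100

# second covers depth indices 2 and 3 of x, so zeroing it zeroes x[:, :, 2:]
second.zero_()
print(x[0, 0, 2:].tolist())  # [0, 0]

# A cloned chunk owns its own memory, so writing to it leaves x untouched
copy = torch.dsplit(x, 2)[0].clone()
copy.fill_(-1)
print(x[0, 0, 1].item())  # 1
```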