## PyTorch torch.vsplit
`torch.vsplit` is a PyTorch function that splits a tensor vertically, that is, along its first dimension (`dim=0`, the rows of a 2D tensor). It is equivalent to calling `torch.tensor_split(input, indices_or_sections, dim=0)`, except that when `indices_or_sections` is an integer, it must evenly divide the size of dimension 0 or a `RuntimeError` is raised.
It is useful for slicing datasets, partitioning batches, or dividing multi-dimensional feature maps along their first dimension.
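The equivalence can be checked directly. The following minimal sketch compares `torch.vsplit` with `torch.tensor_split(..., dim=0)` and with plain slicing on an arbitrary example tensor `x`; all three produce the same pieces.

```python
import torch

x = torch.arange(12).reshape(4, 3)  # arbitrary 4x3 example tensor

a = torch.vsplit(x, [1, 3])                 # vertical split at rows 1 and 3
b = torch.tensor_split(x, [1, 3], dim=0)    # same split via tensor_split
c = (x[:1], x[1:3], x[3:])                  # same split via slicing

# Every corresponding piece holds identical values
for p, q, r in zip(a, b, c):
    assert torch.equal(p, q) and torch.equal(q, r)

print([tuple(t.shape) for t in a])  # [(1, 3), (2, 3), (1, 3)]
```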
---
### Syntax
```python
torch.vsplit(input, indices_or_sections) -> Tuple[Tensor, ...]
```
### Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| `input` | `Tensor` | The tensor to be split. Must be at least 2-dimensional. |
| `indices_or_sections` | `int` or `list` / `tuple` of `int` | **If an integer $N$:** the tensor is split into $N$ equal sections along the first dimension. If the size of that dimension is not divisible by $N$, a `RuntimeError` is raised. **If a list/tuple of indices:** the tensor is split at the specified indices along the first dimension. For example, `[1, 3]` splits the tensor into the slices `[:1]`, `[1:3]`, and `[3:]`. |
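
The two forms of `indices_or_sections` can be contrasted on a small example. This sketch uses an arbitrary 6x2 tensor `z`: an integer divides the rows into equally sized sections, while a list or tuple of indices marks the rows where each new piece begins.

```python
import torch

z = torch.arange(12).reshape(6, 2)  # arbitrary example: 6 rows, 2 columns

# Integer form: 6 rows / 3 sections -> three chunks of 2 rows each
by_sections = torch.vsplit(z, 3)
print([tuple(t.shape) for t in by_sections])  # [(2, 2), (2, 2), (2, 2)]

# Index form: cut before row 1 and before row 4 -> z[:1], z[1:4], z[4:]
by_indices = torch.vsplit(z, [1, 4])
print([tuple(t.shape) for t in by_indices])  # [(1, 2), (3, 2), (2, 2)]

# A tuple of indices gives the same result as a list
same = torch.vsplit(z, (1, 4))
assert all(torch.equal(p, q) for p, q in zip(by_indices, same))
```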
### Return Value
* Returns a tuple of tensors, each of which is a view of `input`. Modifying a returned tensor in place also modifies the original tensor.
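Because the chunks are views, they share memory with the input. This minimal sketch, using an arbitrary 4x3 tensor `x`, writes into one chunk in place and checks the effect on `x`; it also compares `data_ptr()` values to show that each chunk points into the original storage.

```python
import torch

x = torch.arange(12).reshape(4, 3)  # arbitrary example tensor (int64)
top, bottom = torch.vsplit(x, 2)    # rows 0-1 and rows 2-3

# An in-place write to a chunk is visible through the original tensor
top[0, 0] = 100
print(x[0, 0].item())  # 100

# The first chunk starts at the same address as x; the second starts
# 6 elements (two rows of three) further into the same storage
print(top.data_ptr() == x.data_ptr())                            # True
print(bottom.data_ptr() == x.data_ptr() + 6 * x.element_size())  # True
```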
---
### Code Examples
The following examples demonstrate how to split 2D and 3D tensors using both integer sections and index lists.
```python
import torch
# ---------------------------------------------------------
# Example 1: Splitting a 2D Tensor into Equal Sections
# ---------------------------------------------------------
x = torch.arange(12).reshape(4, 3)
print("Original 2D Tensor:")
print(x)
print("-" * 40)
# Split vertically into 2 equal parts
result_equal = torch.vsplit(x, 2)
print("Split vertically into 2 equal parts:")
for i, t in enumerate(result_equal):
    print(f" Chunk {i}:\n{t}")
print("-" * 40)
# ---------------------------------------------------------
# Example 2: Splitting a 2D Tensor by Specific Indices
# ---------------------------------------------------------
# Splitting at indices [1, 3] produces three slices:
# Chunk 0: x[0:1]
# Chunk 1: x[1:3]
# Chunk 2: x[3:]
result_indices = torch.vsplit(x, [1, 3])
print("Split vertically at indices [1, 3]:")
for i, t in enumerate(result_indices):
    print(f" Chunk {i}:\n{t}")
print("-" * 40)
# ---------------------------------------------------------
# Example 3: Splitting a 3D Tensor
# ---------------------------------------------------------
y = torch.arange(24).reshape(4, 3, 2)
print("Original 3D Tensor Shape:", y.shape)
# Split vertically (along dim=0) into 2 parts
result_3d = torch.vsplit(y, 2)
print("\nSplit 3D tensor along the first dimension into 2 parts:")
for i, t in enumerate(result_3d):
    print(f" Chunk {i} Shape: {t.shape}")
```
### Output
```text
Original 2D Tensor:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
----------------------------------------
Split vertically into 2 equal parts:
 Chunk 0:
tensor([[0, 1, 2],
        [3, 4, 5]])
 Chunk 1:
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
----------------------------------------
Split vertically at indices [1, 3]:
 Chunk 0:
tensor([[0, 1, 2]])
 Chunk 1:
tensor([[3, 4, 5],
        [6, 7, 8]])
 Chunk 2:
tensor([[ 9, 10, 11]])
----------------------------------------
Original 3D Tensor Shape: torch.Size([4, 3, 2])

Split 3D tensor along the first dimension into 2 parts:
 Chunk 0 Shape: torch.Size([2, 3, 2])
 Chunk 1 Shape: torch.Size([2, 3, 2])
```
---
### Key Considerations
1. **Dimensionality Constraint**: `torch.vsplit` requires the input tensor to be at least **2-dimensional**. Attempting to use it on a 1D tensor will result in a `RuntimeError`.
2. **Divisibility**: When passing an integer $N$ for `indices_or_sections`, the size of the first dimension (`input.size(0)`) must be exactly divisible by $N$. If it is not, PyTorch will raise a `RuntimeError`. If you need to split a tensor into unequal parts, pass a list of indices instead.
3. **Memory Efficiency**: `torch.vsplit` returns **views** of the original tensor, so no data is copied and the operation is cheap. However, modifying a chunk in place also modifies the original tensor. If you need independent copies, call `.clone()` on the resulting chunks.
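The three considerations can be exercised in one short script. This is a sketch on arbitrary example tensors: it triggers both `RuntimeError` cases, shows two ways to get unequal parts (explicit indices with `torch.vsplit`, or an integer with `torch.tensor_split`, which makes the leading sections one row larger when the size does not divide evenly), and uses `.clone()` to get chunks that no longer share memory with the input.

```python
import torch

# Consideration 1: a 1-D input is rejected
try:
    torch.vsplit(torch.arange(6), 2)
except RuntimeError as err:
    print("1-D input rejected:", err)

# Consideration 2: 5 rows cannot be divided into 2 equal sections
z = torch.arange(10).reshape(5, 2)
try:
    torch.vsplit(z, 2)
except RuntimeError as err:
    print("Non-divisible split rejected:", err)

# Unequal parts, option A: explicit indices (z[:3] and z[3:])
unequal = torch.vsplit(z, [3])
print([tuple(t.shape) for t in unequal])  # [(3, 2), (2, 2)]

# Unequal parts, option B: tensor_split accepts a non-dividing integer
balanced = torch.tensor_split(z, 2, dim=0)
print([tuple(t.shape) for t in balanced])  # [(3, 2), (2, 2)]

# Consideration 3: clone() gives each chunk its own storage
x = torch.arange(12).reshape(4, 3)
top, bottom = (t.clone() for t in torch.vsplit(x, 2))
top[0, 0] = 100
print(x[0, 0].item())  # 0, the original is unchanged
```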