## PyTorch torch.triu
The `torch.triu` function in PyTorch is used to extract the upper triangular part of a matrix (2D tensor) or a batch of matrices. Any elements below the specified diagonal are set to zero, while the elements on and above the diagonal remain unchanged.
---
### Function Definition
```python
torch.triu(input, diagonal=0, *, out=None) -> Tensor
```
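`torch.triu` is also available as the tensor method `Tensor.triu` and as the in-place variant `Tensor.triu_`. The sketch below assumes a recent PyTorch release. It compares these calling forms and the `out=` keyword on a small matrix, and checks that the functional form returns a new tensor instead of modifying its input. The variable names `a`, `buf`, and `b` are illustrative only.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Functional form and method form give the same result.
y_func = torch.triu(a, diagonal=0)
y_method = a.triu(diagonal=0)

# The `out=` keyword writes the result into a preallocated tensor.
buf = torch.empty_like(a)
torch.triu(a, out=buf)

# The in-place variant modifies the tensor it is called on, so apply it to a copy.
b = a.clone()
b.triu_()

print(torch.equal(y_func, y_method), torch.equal(y_func, buf), torch.equal(y_func, b))  # True True True

# torch.triu returns a new tensor: writing to it leaves `a` unchanged.
y_func[0, 0] = 100
print(a[0, 0].item())  # 1
```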
### Parameter Descriptions
* **`input`** (*Tensor*): The input tensor. It must be a tensor of shape `(*, m, n)`, where `*` represents zero or more batch dimensions.
* **`diagonal`** (*int, optional*): The diagonal to consider.
  * `diagonal = 0` (default) refers to the main diagonal.
  * `diagonal > 0` refers to diagonals above the main diagonal.
  * `diagonal < 0` refers to diagonals below the main diagonal.
* **`out`** (*Tensor, optional*): The output tensor.
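The shape requirement on `input` can be checked directly. The following sketch assumes that, as in current PyTorch releases, passing a 1D tensor raises a `RuntimeError`. The exact error message may vary between versions, so the code only reports the exception type.

```python
import torch

# A 2D matrix and a 3D batch of matrices are both valid inputs.
matrix = torch.ones(2, 3)
batch = torch.ones(4, 2, 3)
print(torch.triu(matrix).shape)  # torch.Size([2, 3])
print(torch.triu(batch).shape)   # torch.Size([4, 2, 3])

# A 1D tensor has no (m, n) matrix dimensions, so triu rejects it.
vector = torch.ones(3)
try:
    torch.triu(vector)
    rejected = False
except RuntimeError as err:
    rejected = True
    print(type(err).__name__)  # RuntimeError
```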
---
## How the Diagonal Parameter Works
The `diagonal` parameter controls which elements are retained:
$$\text{out}_{i,j} = \begin{cases} \text{input}_{i,j} & \text{if } j \ge i + \text{diagonal} \\ 0 & \text{otherwise} \end{cases}$$
* **`diagonal = 0`**: Retains the main diagonal and everything above it.
* **`diagonal = 1`**: Excludes the main diagonal, retaining only elements starting from the first diagonal above it.
* **`diagonal = -1`**: Retains the main diagonal, everything above it, and the first diagonal below it.
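The formula can be checked against `torch.triu` by building the mask explicitly from row and column indices. This is an illustrative sketch only; `triu_by_formula` is a hypothetical helper, not part of PyTorch.

```python
import torch

def triu_by_formula(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    """Hypothetical reference: out[i, j] = x[i, j] if j >= i + diagonal, else 0."""
    m, n = x.shape[-2], x.shape[-1]
    i = torch.arange(m).unsqueeze(1)  # row indices, shape (m, 1)
    j = torch.arange(n).unsqueeze(0)  # column indices, shape (1, n)
    keep = j >= i + diagonal          # boolean mask of shape (m, n), broadcast over any batch dims
    return torch.where(keep, x, torch.zeros_like(x))

a = torch.arange(1, 13).reshape(3, 4)
for d in (-2, -1, 0, 1, 2):
    print(d, torch.equal(triu_by_formula(a, d), torch.triu(a, d)))  # True for every d
```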
---
## Code Examples
### Example 1: Extracting the Standard Upper Triangular Part (Default)
By default, `diagonal=0` extracts the upper triangular part including the main diagonal.
```python
import torch
# Create a 3x3 matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
# Extract the upper triangular part
y = torch.triu(a)
print(y)
```
**Output:**
```text
tensor([[1, 2, 3],
        [0, 5, 6],
        [0, 0, 9]])
```
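Because `torch.triu(a)` keeps the main diagonal and everything above it, adding the strictly lower part from the companion function `torch.tril` with `diagonal=-1` recovers the original matrix. The two parts never overlap, which this short check illustrates:

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

upper = torch.triu(a)                      # main diagonal and above
strict_lower = torch.tril(a, diagonal=-1)  # strictly below the main diagonal

# The two parts are disjoint, so their sum is the original matrix.
print(torch.equal(upper + strict_lower, a))  # True
```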
---
### Example 2: Using a Positive Diagonal Offset (`diagonal > 0`)
Setting `diagonal=1` shifts the boundary up, zeroing out the main diagonal as well.
```python
import torch
# Create a 3x3 matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
# Extract the upper triangular part starting above the main diagonal
y = torch.triu(a, diagonal=1)
print(y)
```
**Output:**
```text
tensor([[0, 2, 3],
        [0, 0, 6],
        [0, 0, 0]])
```
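Two consequences of a positive offset are easy to verify. With `diagonal=1`, the main diagonal (read with `torch.diagonal`) is all zeros. Once the offset reaches the number of columns, no element satisfies `j >= i + diagonal`, so the result is entirely zero.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.triu(a, diagonal=1)
print(torch.diagonal(y))  # tensor([0, 0, 0])

# Column indices only go up to 2, so j >= i + 3 never holds.
z = torch.triu(a, diagonal=3)
print(torch.count_nonzero(z).item())  # 0
```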
---
### Example 3: Using a Negative Diagonal Offset (`diagonal < 0`)
Setting `diagonal=-1` shifts the boundary down, retaining one diagonal below the main diagonal.
```python
import torch
# Create a 3x3 matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
# Extract the upper triangular part including one diagonal below the main diagonal
y = torch.triu(a, diagonal=-1)
print(y)
```
**Output:**
```text
tensor([[1, 2, 3],
        [4, 5, 6],
        [0, 8, 9]])
```
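Symmetrically, a sufficiently negative offset keeps every element. For an m x n matrix, `j >= i + diagonal` holds for all entries once `diagonal <= -(m - 1)`, because `i` is at most `m - 1` and `j` is at least 0. For the 3x3 example that threshold is -2:

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

m = a.shape[0]
# With diagonal = -(m - 1) = -2, even the bottom-left corner (i=2, j=0) satisfies j >= i + diagonal.
full = torch.triu(a, diagonal=-(m - 1))
print(torch.equal(full, a))  # True

# One step less negative (diagonal=-1) drops only the bottom-left element, matching Example 3.
print(torch.triu(a, diagonal=-(m - 2)))
```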
---
### Example 4: Handling Non-Square Matrices
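`torch.triu` also works on non-square (rectangular) matrices: the rule `j >= i + diagonal` is applied to every position in exactly the same way, regardless of whether the matrix is square.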
```python
import torch
# Create a 2x3 non-square matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
# Extract the upper triangular part
y = torch.triu(a)
print(y)
```
**Output:**
```text
tensor([[1, 2, 3],
        [0, 5, 6]])
```
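The same rule applies to tall matrices (more rows than columns). Any row whose index `i` exceeds the last column index `n - 1` is zeroed entirely, because no column satisfies `j >= i`:

```python
import torch

# Create a 3x2 tall matrix
a = torch.tensor([[1, 2],
                  [3, 4],
                  [5, 6]])

y = torch.triu(a)
print(y)
# tensor([[1, 2],
#         [0, 4],
#         [0, 0]])
```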
---
### Example 5: Batch Processing (Multi-dimensional Tensors)
`torch.triu` supports batched inputs. If the input tensor has more than two dimensions, the operation is applied to each of the innermost 2D matrices individually.
```python
import torch
# Create a batched tensor of shape (2, 3, 3)
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                  [[10, 11, 12], [13, 14, 15], [16, 17, 18]]])
# Apply triu to the batch
y = torch.triu(a)
print(y)
```
**Output:**
```text
tensor([[[ 1,  2,  3],
         [ 0,  5,  6],
         [ 0,  0,  9]],

        [[10, 11, 12],
         [ 0, 14, 15],
         [ 0,  0, 18]]])
```
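One way to confirm that batching applies the operation to each matrix independently is to compare against a Python loop over the batch dimension. This sketch also passes a negative `diagonal`, which is applied identically to every matrix in the batch:

```python
import torch

# Batch of shape (2, 3, 3) with the same values as the example above
a = torch.arange(1, 19).reshape(2, 3, 3)

batched = torch.triu(a, diagonal=-1)
# Iterating over a tensor yields its slices along dim 0, i.e. one 3x3 matrix at a time.
looped = torch.stack([torch.triu(m, diagonal=-1) for m in a])

print(torch.equal(batched, looped))  # True
```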
---
## Key Considerations
1. **In-place Operations**: `torch.triu` returns a new tensor. If you want to perform the operation in-place on an existing tensor, use `input.triu_(diagonal)` instead.
2. **Memory and Data Type**: The returned tensor has the same shape, data type, and device as the input tensor. It is a newly allocated tensor, not a view, so modifying it does not change the input.
3. **Common Use Cases**:
* **Transformer Attention Masks**: Used to create causal masks (upper triangular masks) in self-attention mechanisms to prevent tokens from attending to future tokens.
* **Linear Algebra**: Used in matrix decompositions (such as LU or QR decomposition) where upper triangular matrices are required.
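Both use cases can be sketched in a few lines. The causal mask follows a common convention (an assumption here, since attention APIs differ): `True` marks positions to block, so `torch.triu(..., diagonal=1)` blocks every key position after the query position. The LU part assumes `torch.linalg.lu_factor`, which returns L and U packed into a single matrix; `torch.triu` extracts U from that packed form, and the result is compared with the U returned by `torch.linalg.lu`. The helpers `causal_mask` and `upper_factor` are hypothetical names used only for illustration.

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """Hypothetical helper: True strictly above the main diagonal marks future positions to block."""
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

def upper_factor(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: extract U from the packed LU factorization of a square matrix A."""
    LU, pivots = torch.linalg.lu_factor(A)
    return torch.triu(LU)

# Causal masking: blocked scores become -inf, so softmax assigns them zero weight.
scores = torch.randn(4, 4)
masked = scores.masked_fill(causal_mask(4), float("-inf"))
weights = torch.softmax(masked, dim=-1)
print(causal_mask(4))

# LU decomposition: the U extracted with triu matches the U from torch.linalg.lu.
A = torch.tensor([[4.0, 3.0, 2.0],
                  [2.0, 1.0, 3.0],
                  [3.0, 2.0, 1.0]])
P, L, U = torch.linalg.lu(A)
print(torch.allclose(upper_factor(A), U))  # True
```

Masking with `diagonal=1` leaves the main diagonal unblocked, so every row keeps at least one finite score and softmax never sees a row made entirely of `-inf`.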