
## PyTorch torch.triu

The `torch.triu` function in PyTorch is used to extract the upper triangular part of a matrix (2D tensor) or a batch of matrices. Any elements below the specified diagonal are set to zero, while the elements on and above the diagonal remain unchanged.

---

### Function Definition

```python
torch.triu(input, diagonal=0, *, out=None) -> Tensor
```

### Parameter Descriptions

* **`input`** (*Tensor*): The input tensor. It must be a tensor of shape `(*, m, n)`, where `*` represents zero or more batch dimensions.
* **`diagonal`** (*int, optional*): The diagonal to consider.
  * `diagonal = 0` (default) refers to the main diagonal.
  * `diagonal > 0` refers to diagonals above the main diagonal.
  * `diagonal < 0` refers to diagonals below the main diagonal.
* **`out`** (*Tensor, optional*): The output tensor.

---

## How the Diagonal Parameter Works

The `diagonal` parameter controls which elements are retained:

$$
\text{out}_{i,j} =
\begin{cases}
\text{input}_{i,j} & \text{if } j \ge i + \text{diagonal} \\
0 & \text{otherwise}
\end{cases}
$$

* **`diagonal = 0`**: Retains the main diagonal and everything above it.
* **`diagonal = 1`**: Excludes the main diagonal, retaining only elements starting from the first diagonal above it.
* **`diagonal = -1`**: Retains the main diagonal, everything above it, and the first diagonal below it.

---

## Code Examples

### Example 1: Extracting the Standard Upper Triangular Part (Default)

By default, `diagonal=0` extracts the upper triangular part including the main diagonal.

```python
import torch

# Create a 3x3 matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Extract the upper triangular part
y = torch.triu(a)
print(y)
```

**Output:**

```text
tensor([[1, 2, 3],
        [0, 5, 6],
        [0, 0, 9]])
```

---

### Example 2: Using a Positive Diagonal Offset (`diagonal > 0`)

Setting `diagonal=1` shifts the boundary up, zeroing out the main diagonal as well.

```python
import torch

# Create a 3x3 matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Extract the upper triangular part starting above the main diagonal
y = torch.triu(a, diagonal=1)
print(y)
```

**Output:**

```text
tensor([[0, 2, 3],
        [0, 0, 6],
        [0, 0, 0]])
```

---

### Example 3: Using a Negative Diagonal Offset (`diagonal < 0`)

Setting `diagonal=-1` shifts the boundary down, retaining one diagonal below the main diagonal.

```python
import torch

# Create a 3x3 matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Extract the upper triangular part including one diagonal below the main diagonal
y = torch.triu(a, diagonal=-1)
print(y)
```

**Output:**

```text
tensor([[1, 2, 3],
        [4, 5, 6],
        [0, 8, 9]])
```

---

### Example 4: Handling Non-Square Matrices

`torch.triu` also works on non-square (rectangular) matrices. The same rule $j \ge i + \text{diagonal}$ applies.

```python
import torch

# Create a 2x3 non-square matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# Extract the upper triangular part
y = torch.triu(a)
print(y)
```

**Output:**

```text
tensor([[1, 2, 3],
        [0, 5, 6]])
```

---

### Example 5: Batch Processing (Multi-dimensional Tensors)

`torch.triu` supports batched inputs. If the input tensor has more than two dimensions, the operation is applied independently to each matrix formed by the last two dimensions.

```python
import torch

# Create a batched tensor of shape (2, 3, 3)
a = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[10, 11, 12],
                   [13, 14, 15],
                   [16, 17, 18]]])

# Apply triu to each 3x3 matrix in the batch
y = torch.triu(a)
print(y)
```

**Output:**

```text
tensor([[[ 1,  2,  3],
         [ 0,  5,  6],
         [ 0,  0,  9]],

        [[10, 11, 12],
         [ 0, 14, 15],
         [ 0,  0, 18]]])
```

---

## Key Considerations

1. **In-place Operations**: `torch.triu` returns a new tensor and leaves `input` unchanged. To modify an existing tensor in place, use `input.triu_(diagonal)` instead.
2. **Memory and Data Type**: The returned tensor has the same shape, data type, and device as the input. It is allocated separately and does not share storage with the input.
3. **Common Use Cases**:
   * **Transformer Attention Masks**: `torch.triu(..., diagonal=1)` marks the positions strictly above the main diagonal. These are the future tokens that a causal (autoregressive) self-attention layer must not attend to.
   * **Linear Algebra**: Upper triangular matrices appear in factorizations such as LU and QR. For example, the `U` factor can be extracted with `torch.triu` from the packed result returned by `torch.linalg.lu_factor`.
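The sketch below is illustrative and is not taken from the official PyTorch documentation. It does three things:

* It checks the index rule $j \ge i + \text{diagonal}$ against `torch.triu` on a non-square matrix.
* It contrasts the out-of-place `torch.triu` with the in-place `Tensor.triu_`.
* It builds a causal mask for self-attention.

The sequence length, the random `scores` tensor, and the choice to apply the mask with `masked_fill` and `-inf` before `softmax` are assumptions made for illustration. Real attention layers may combine masks differently.

```python
import torch

# 1) Index rule: keep input[i, j] where j >= i + diagonal, otherwise 0.
x = torch.arange(1, 13).reshape(3, 4)          # non-square 3x4 example
rows = torch.arange(x.shape[-2]).unsqueeze(1)  # column vector of row indices i
cols = torch.arange(x.shape[-1]).unsqueeze(0)  # row vector of column indices j
for d in (-1, 0, 1):
    # Broadcasting (3, 1) against (1, 4) gives a (3, 4) boolean keep-mask
    manual = torch.where(cols >= rows + d, x, torch.zeros_like(x))
    assert torch.equal(manual, torch.triu(x, diagonal=d))

# 2) Out-of-place vs. in-place.
a = torch.arange(1, 10).reshape(3, 3)
b = torch.triu(a)  # returns a new tensor; `a` is left untouched
assert a[2, 0].item() == 7
a.triu_()          # in-place variant: zeroes the lower triangle of `a` itself
assert torch.equal(a, b)

# 3) Causal (look-ahead) mask for self-attention (illustrative sizes).
seq_len = 4
torch.manual_seed(0)
scores = torch.randn(seq_len, seq_len)  # stand-in for query-key attention scores
# True strictly above the main diagonal, i.e. the "future" positions j > i
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
print(weights)  # every entry above the main diagonal is 0
```

A boolean mask applied with `masked_fill` keeps the mask independent of the score dtype. A common alternative is an additive mask built with `torch.triu(torch.full((n, n), float("-inf")), diagonal=1)`, which is added to the scores before `softmax`.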