## PyTorch torch.cummax
`torch.cummax` is a PyTorch function used to compute the cumulative maximum of elements of an input tensor along a specified dimension.
Unlike simple reduction functions (such as `torch.max`), which return a single maximum value, `torch.cummax` tracks the maximum value encountered up to each position along the chosen dimension. It returns a named tuple containing both the cumulative maximum values and their corresponding original indices.
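To make the contrast concrete, the sketch below runs `torch.max` and `torch.cummax` on the same 1D tensor. With a `dim` argument, `torch.max` also returns a `(values, indices)` named tuple, but it collapses that dimension to a single entry. `torch.cummax` keeps one entry per position instead.

```python
import torch

x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])

# Reduction: a single maximum for the whole dimension
overall = torch.max(x, dim=0)
print(overall.values, overall.indices)  # tensor(9) tensor(5)

# Scan: the running maximum at every position
running = torch.cummax(x, dim=0)
print(running.values)  # tensor([3, 3, 4, 4, 5, 9, 9, 9])

# The final running maximum equals the overall maximum
assert running.values[-1] == overall.values
```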
---
## Syntax
```python
torch.cummax(input, dim, *, out=None) -> (Tensor, LongTensor)
```
### Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| `input` | `Tensor` | The input tensor. |
| `dim` | `int` | The dimension along which to compute the cumulative maximum. |
| `out` | `tuple`, optional | Keyword-only. A tuple `(values, indices)` of preallocated tensors that receive the cumulative maximum values and their indices. |
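The `out` parameter lets `torch.cummax` write into existing buffers rather than allocating new ones. A minimal sketch, assuming CPU tensors: the values buffer is created with the input's dtype, the indices buffer with `torch.int64`, and `out` is passed by keyword.

```python
import torch

x = torch.tensor([2.0, 7.0, 1.0, 8.0])

# Preallocate buffers: values match the input dtype, indices are int64
values_buf = torch.empty_like(x)
indices_buf = torch.empty_like(x, dtype=torch.long)

# The results are written into the provided tensors
torch.cummax(x, dim=0, out=(values_buf, indices_buf))

print(values_buf)   # tensor([2., 7., 7., 8.])
print(indices_buf)  # tensor([0, 1, 1, 3])
```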
### Return Value
The function returns a named tuple `(values, indices)` where:
* **`values`**: A tensor of the same shape as `input` containing the cumulative maximum values.
* **`indices`**: A `LongTensor` of the same shape as `input`, giving for each position the index along `dim` at which its cumulative maximum occurs.
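Because the result is a named tuple, the two tensors can be read either by tuple unpacking or by field name. A short sketch on a small 2D tensor:

```python
import torch

x = torch.tensor([[0.5, 2.0, 1.0],
                  [3.0, 0.0, 4.0]])

result = torch.cummax(x, dim=1)

# Access by field name
print(result.values)
print(result.indices)

# Access by position: unpacking yields the same tensors
values, indices = result
assert torch.equal(values, result.values)

# Both outputs have the same shape as the input
print(values.shape, indices.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```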
---
## Code Examples
### Example 1: Cumulative Maximum of a 1D Tensor
This example demonstrates how `torch.cummax` processes a simple 1D tensor.
```python
import torch
# Create a 1D tensor
x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])
# Compute cumulative maximum along the only dimension (dim=0)
values, indices = torch.cummax(x, dim=0)
print("Input Tensor:", x)
print("Cumulative Max Values:", values)
print("Indices of Max Values:", indices)
```
**Output:**
```text
Input Tensor: tensor([3, 1, 4, 1, 5, 9, 2, 6])
Cumulative Max Values: tensor([3, 3, 4, 4, 5, 9, 9, 9])
Indices of Max Values: tensor([0, 0, 2, 2, 4, 5, 5, 5])
```
*Explanation:*
* At index `0`, the max is `3` (index `0`).
* At index `1`, the max of `[3, 1]` is still `3` (index `0`).
* At index `2`, the max of `[3, 1, 4]` becomes `4` (index `2`).
* At index `5`, the max of `[3, 1, 4, 1, 5, 9]` becomes `9` (index `5`), which remains the maximum for the rest of the tensor.
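The walkthrough above amounts to a single left-to-right scan that remembers the best value so far and where it was found. The sketch below is a plain-Python reference for a 1D list, written only to illustrate the idea; `cummax_reference` is a hypothetical name, not a PyTorch function. It updates on `>=`, mirroring the comparison used in PyTorch's CPU kernel, so a tie moves the index to the newer position. NaN handling is deliberately left out.

```python
import torch

def cummax_reference(values):
    """Hypothetical reference scan: running max and its index for a 1D list."""
    out_values, out_indices = [], []
    best, best_idx = None, 0
    for i, v in enumerate(values):
        # '>=' lets a tie move the index to the newer position
        if best is None or v >= best:
            best, best_idx = v, i
        out_values.append(best)
        out_indices.append(best_idx)
    return out_values, out_indices

data = [3, 1, 4, 1, 5, 9, 2, 6]
ref_values, ref_indices = cummax_reference(data)

# Cross-check the reference against torch.cummax
t_values, t_indices = torch.cummax(torch.tensor(data), dim=0)
assert ref_values == t_values.tolist()
assert ref_indices == t_indices.tolist()
print(ref_values)   # [3, 3, 4, 4, 5, 9, 9, 9]
print(ref_indices)  # [0, 0, 2, 2, 4, 5, 5, 5]
```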
---
### Example 2: Cumulative Maximum of a 2D Tensor
The `dim` argument is required for every input. For multi-dimensional tensors, it selects the axis of the scan: with `dim=0` the running maximum moves down each column, and with `dim=1` it moves across each row.
```python
import torch
# Create a 2D tensor (3x3 matrix)
x = torch.tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]], dtype=torch.float32)
print("Original 2D Tensor:")
print(x)
# 1. Cumulative maximum along columns (dim=0)
values_col, indices_col = torch.cummax(x, dim=0)
print("\n--- Cumulative Maximum along Columns (dim=0) ---")
print("Values:")
print(values_col)
print("Indices:")
print(indices_col)
# 2. Cumulative maximum along rows (dim=1)
values_row, indices_row = torch.cummax(x, dim=1)
print("\n--- Cumulative Maximum along Rows (dim=1) ---")
print("Values:")
print(values_row)
print("Indices:")
print(indices_row)
```
**Output:**
```text
Original 2D Tensor:
tensor([[3., 1., 4.],
        [1., 5., 9.],
        [2., 6., 5.]])

--- Cumulative Maximum along Columns (dim=0) ---
Values:
tensor([[3., 1., 4.],
        [3., 5., 9.],
        [3., 6., 9.]])
Indices:
tensor([[0, 0, 0],
        [0, 1, 1],
        [0, 2, 1]])

--- Cumulative Maximum along Rows (dim=1) ---
Values:
tensor([[3., 3., 4.],
        [1., 5., 9.],
        [2., 6., 6.]])
Indices:
tensor([[0, 0, 2],
        [0, 1, 2],
        [0, 1, 1]])
```
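One way to read the `indices` output: gathering the input at those indices along the same `dim` reproduces `values`. The check below reuses the matrix from Example 2 and uses `torch.gather`. It also shows that a negative `dim` counts from the last axis, so `dim=-1` behaves like `dim=1` for a 2D tensor.

```python
import torch

x = torch.tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]], dtype=torch.float32)

for dim in (0, 1):
    values, indices = torch.cummax(x, dim=dim)
    # Look up the input at each recorded index along the same dimension
    recovered = torch.gather(x, dim, indices)
    assert torch.equal(recovered, values)

# -1 refers to the last axis, i.e. dim=1 for a 2D tensor
v_neg, i_neg = torch.cummax(x, dim=-1)
v_pos, i_pos = torch.cummax(x, dim=1)
assert torch.equal(v_neg, v_pos) and torch.equal(i_neg, i_pos)
print("gather(x, dim, indices) reproduces values for dim 0 and 1")
```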
---
## Key Considerations
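1. **Index Tracking**: When a later element equals the current running maximum, PyTorch's implementation moves the index to that later position, so for ties `indices` points to the **most recent** occurrence of the maximum rather than the first. The documentation does not specify tie-breaking, so treat this as implementation behavior rather than a guarantee.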
2. **Data Types**: The returned `indices` tensor always has the data type `torch.int64` (`LongTensor`), regardless of the input tensor's data type.
3. **Devices and Use Cases**: `torch.cummax` supports both CPU and CUDA tensors. It is useful wherever a running peak is needed, for example in sequence processing, dynamic programming, and reinforcement learning algorithms (e.g., tracking historical peak values).
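Tie handling and the index dtype can be checked directly. In the sketch below the value `3.0` appears twice in a row; with PyTorch's current CPU implementation the index at the tie is the later position, and the indices are `int64` for both floating-point and integer inputs.

```python
import torch

x = torch.tensor([1.0, 3.0, 3.0, 2.0])
values, indices = torch.cummax(x, dim=0)

print(values)         # tensor([1., 3., 3., 3.])
print(indices)        # tensor([0, 1, 2, 2])
print(indices.dtype)  # torch.int64

# An integer input still yields int64 indices
_, int_indices = torch.cummax(torch.tensor([5, 2, 7], dtype=torch.int32), dim=0)
print(int_indices.dtype)  # torch.int64
```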
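As an illustration of tracking historical peak values, the sketch below computes the drawdown of a series: how far each point sits below the highest value seen so far. The series and the name `prices` are made up for the example.

```python
import torch

# Hypothetical series, e.g. a metric logged over time
prices = torch.tensor([100.0, 110.0, 105.0, 120.0, 90.0, 95.0])

# Highest value seen up to each step, and where it occurred
peak, peak_idx = torch.cummax(prices, dim=0)

# Fractional drop from the running peak (0 at a new high)
drawdown = (peak - prices) / peak

print(peak)                   # tensor([100., 110., 110., 120., 120., 120.])
print(peak_idx)               # tensor([0, 1, 1, 3, 3, 3])
print(drawdown.max().item())  # 0.25
```

Here the deepest drop (25%) occurs at position 4, where the value `90.0` sits below the running peak of `120.0` recorded at position 3.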