## PyTorch torch.select_scatter
`torch.select_scatter` is a PyTorch function that embeds the values of a source tensor (`src`) into a copy of an input tensor (`input`) at a single index along a given dimension.
It is the out-of-place counterpart of the in-place assignment `input.select(dim, index).copy_(src)` (for `dim=0`, simply `input[index] = src`). Instead of mutating `input`, it returns a new tensor with that slice replaced. This avoids in-place modifications, which autograd can reject when `input` is a leaf that requires gradients or has been saved for the backward pass.
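The equivalence can be checked with a small sketch (the tensor values here are arbitrary): cloning `input` and copying `src` into the view returned by `select(dim, index)` produces the same values as `torch.select_scatter`, while the original tensor stays unchanged.

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
src = torch.full((4,), -1.0)

# Out-of-place: returns a new tensor and leaves x untouched
out = torch.select_scatter(x, src, dim=0, index=2)

# Manual equivalent: clone first, then write into the selected slice in place
manual = x.clone()
manual.select(0, 2).copy_(src)

print(torch.equal(out, manual))  # True
# x still holds its original third row
print(torch.equal(x[2], torch.tensor([8.0, 9.0, 10.0, 11.0])))  # True
```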
---
### Syntax
```python
torch.select_scatter(input, src, dim, index) -> Tensor
```
### Parameters
* **`input`** (*Tensor*): The base tensor. Its values are copied into the result, except for the slice at `index` along `dim`.
* **`src`** (*Tensor*): The source tensor containing the values to be scattered into the `input` tensor.
* **`dim`** (*int*): The dimension along which to index.
* **`index`** (*int*): The index in the specified dimension where the `src` tensor will be scattered.
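To make `dim` and `index` concrete, here is a sketch using `dim=1` on a 2D tensor (sizes chosen only for illustration). The selected slice is then a column, so `src` needs the shape of `input.select(1, index)`.

```python
import torch

x = torch.zeros(3, 4)

# With dim=1, select() picks a column; for a (3, 4) tensor that slice has shape (3,)
print(x.select(1, 0).shape)  # torch.Size([3])

src = torch.tensor([7.0, 8.0, 9.0])
out = torch.select_scatter(x, src, dim=1, index=0)

# The first column of the result now holds src; all other entries stay zero
print(out[:, 0])  # tensor([7., 8., 9.])
```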
### Return Value
* **`Tensor`**: A new tensor with the values of `src` copied into `input` at the specified `dim` and `index`. The original `input` tensor remains unchanged.
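A quick sketch to confirm that the result does not share memory with `input`: the storage pointers differ, and mutating the result leaves `input` as it was.

```python
import torch

x = torch.zeros(2, 3)
out = torch.select_scatter(x, torch.ones(3), dim=0, index=0)

# The result lives in its own storage, separate from x
print(out.data_ptr() != x.data_ptr())  # True

# Mutating the result does not affect x
out.fill_(5.0)
print(torch.count_nonzero(x).item())  # 0
```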
---
## Code Examples
### Example 1: Basic 2D Tensor Scattering
In this example, we scatter a 1D tensor of ones into the second row (index `1` along dimension `0`) of a 4x4 zero matrix.
```python
import torch
# Create input and source tensors
input_tensor = torch.zeros(4, 4)
src = torch.ones(4)
# Scatter src into input at dimension 0, index 1 (the second row)
output = torch.select_scatter(input_tensor, src, dim=0, index=1)
print("Input Tensor:")
print(input_tensor)
print("\nSource Tensor:")
print(src)
print("\nResult after scattering at index 1:")
print(output)
```
**Output:**
```text
Input Tensor:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

Source Tensor:
tensor([1., 1., 1., 1.])

Result after scattering at index 1:
tensor([[0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
```
---
### Example 2: Scattering in a 3D Tensor
This example replaces a slice of a 3D tensor. Selecting index `2` along `dim=1` of a `(3, 4, 5)` tensor yields a slice of shape `(3, 5)`, so `src` must have that shape as well.
```python
import torch
# Create a 3D tensor of zeros with shape (batch=3, rows=4, cols=5)
input_tensor = torch.zeros(3, 4, 5)
# src must match input_tensor.select(1, 2).shape, which is (3, 5)
src = torch.ones(3, 5)
# Scatter along the second dimension (dim=1) at index 2
output = torch.select_scatter(input_tensor, src, dim=1, index=2)
print("Input Shape:", input_tensor.shape)
print("Source Shape:", src.shape)
print("Output Shape:", output.shape)
print("\nFirst batch element, output[0] (row 2 holds the scattered values):")
print(output[0])
```
**Output:**
```text
Input Shape: torch.Size([3, 4, 5])
Source Shape: torch.Size([3, 5])
Output Shape: torch.Size([3, 4, 5])

First batch element, output[0] (row 2 holds the scattered values):
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.]])
```
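In the 3D case, `src` carries one row per batch element: for a `(3, 4, 5)` input with `dim=1` and `index=2`, `src` has shape `(3, 5)`, and row `b` of `src` lands in `output[b, 2]`. The sketch below uses distinct `arange` values (purely illustrative) so that mapping is visible.

```python
import torch

input_tensor = torch.zeros(3, 4, 5)
# Build src with the exact shape of the slice being replaced: (3, 5)
src = torch.arange(15, dtype=torch.float32).reshape(3, 5)
print(src.shape == input_tensor.select(1, 2).shape)  # True

output = torch.select_scatter(input_tensor, src, dim=1, index=2)

# Row b of src ends up at output[b, 2, :] for every batch element b
for b in range(3):
    print(b, torch.equal(output[b, 2], src[b]))  # prints True for each b
```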
---
### Example 3: Using Negative Indexing
Like standard Python and PyTorch indexing, `torch.select_scatter` supports negative integers to index relative to the end of a dimension.
```python
import torch
# Create a 4x4 sequential tensor
input_tensor = torch.arange(16).reshape(4, 4).float()
src = torch.tensor([100.0, 100.0, 100.0, 100.0])
# Scatter into the last row (index -1)
output = torch.select_scatter(input_tensor, src, dim=0, index=-1)
print("Original Tensor:")
print(input_tensor)
print("\nResult after scattering at index -1:")
print(output)
```
**Output:**
```text
Original Tensor:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])

Result after scattering at index -1:
tensor([[  0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.],
        [100., 100., 100., 100.]])
```
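As a quick sketch, `index=-1` addresses the same slice as `index=input.size(dim) - 1`:

```python
import torch

x = torch.arange(16).reshape(4, 4).float()
src = torch.full((4,), 100.0)

last_neg = torch.select_scatter(x, src, dim=0, index=-1)
last_pos = torch.select_scatter(x, src, dim=0, index=x.size(0) - 1)

print(torch.equal(last_neg, last_pos))  # True
```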
---
## Important Considerations
1. **Out-of-Place Operation**: `torch.select_scatter` does **not** modify the original `input` tensor; it allocates and returns a new tensor. For an in-place update, write into the selected view instead, e.g. `input.select(dim, index).copy_(src)` (or `input[index] = src` when `dim=0`).
2. **Shape Matching**: `src` must have exactly the shape of the slice being replaced, i.e. `input.select(dim, index).shape`. If the shapes differ, PyTorch raises a `RuntimeError`.
3. **Gradient Tracking**: The operation is differentiable with respect to both `input` and `src`. During backpropagation, `input` receives the upstream gradient with the replaced slice zeroed out, and `src` receives the upstream gradient's slice at `index` along `dim`. This makes it usable inside custom layers where an in-place slice assignment would interfere with autograd.
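The shape requirement and gradient flow can be checked with a short sketch (tensor sizes and random values are arbitrary). Because one row of the output comes entirely from `src`, that row contributes nothing to the gradient of `input` and everything to the gradient of `src`.

```python
import torch

# 1) Shape check: input.select(1, 2) has shape (3, 5), so a src of shape (5,) is rejected
try:
    torch.select_scatter(torch.zeros(3, 4, 5), torch.ones(5), dim=1, index=2)
except RuntimeError as err:
    print("RuntimeError:", err)

# 2) Gradient flow through both arguments
x = torch.randn(4, 4, requires_grad=True)
src = torch.randn(4, requires_grad=True)

out = torch.select_scatter(x, src, dim=0, index=1)
out.sum().backward()

# Row 1 of x was overwritten, so its gradient is zero; every other entry gets 1
print(x.grad)
# src supplied row 1 of the output, so each of its entries gets gradient 1
print(src.grad)
```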