## PyTorch torch.broadcast_to
The `torch.broadcast_to` function broadcasts an input tensor to a specified target shape. It is equivalent to calling `input.expand(shape)`: the result is a view of the original tensor, so no new memory is allocated for the broadcasted elements. Because several elements of the view can share one memory location, treat the result as read-only. Broadcasting follows the same rules as NumPy.
---
### Syntax
```python
torch.broadcast_to(input, shape)
```
### Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| `input` | `torch.Tensor` | The source tensor to be broadcasted. |
| `shape` | `tuple`, `list`, or `torch.Size` | The target shape to broadcast the tensor to. |
### Return Value
* Returns a **tensor view** with the specified target shape.
* Because it returns a view, the operation is highly efficient ($O(1)$ time and extra memory), but the returned tensor should be treated as read-only; clone it before writing to it.
---
## Broadcasting Rules
To broadcast a tensor to a target shape, PyTorch evaluates the dimensions starting from the trailing (rightmost) dimension:
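1. **Dimension Compatibility**: An input dimension is compatible with the corresponding target dimension if the two are equal, or if the input dimension is `1`. The reverse does not hold: a target dimension of `1` cannot shrink a larger input dimension.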
2. **Prepending Dimensions**: If the target shape has more dimensions than the input tensor, PyTorch prepends dimensions of size `1` to the input tensor's shape until the number of dimensions matches.
3. **Expansion**: Dimensions of size `1` are expanded to match the target dimension size by virtually duplicating the data without copying it in memory.
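The three rules above can be sketched as a small pure-Python shape check. `can_broadcast_to` is a hypothetical helper written for illustration, not a PyTorch function, and it assumes plain non-negative sizes (it does not handle a `-1` placeholder in the target shape).

```python
def can_broadcast_to(input_shape, target_shape):
    """Return True if input_shape can be broadcast to target_shape.

    Hypothetical helper for illustration only; not part of PyTorch.
    """
    # The target may add leading dimensions but can never have fewer.
    if len(input_shape) > len(target_shape):
        return False
    # Rule 2: left-pad the input shape with 1s to match the target rank.
    missing = len(target_shape) - len(input_shape)
    padded = (1,) * missing + tuple(input_shape)
    # Rules 1 and 3: each input dimension must equal the target or be 1.
    return all(i == t or i == 1 for i, t in zip(padded, target_shape))


print(can_broadcast_to((3,), (3, 3)))       # True
print(can_broadcast_to((), (2, 3, 4)))      # True
print(can_broadcast_to((2, 2), (3, 2, 2)))  # True
print(can_broadcast_to((3,), (4,)))         # False
print(can_broadcast_to((3, 1), (1, 1)))     # False
```

The last case shows the one-directional nature of the check: an input dimension of `3` cannot be matched to a target dimension of `1`.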
---
## Code Examples
### 1. Basic Usage: Broadcasting a 1D Tensor to 2D
In this example, a 1D tensor of shape `(3,)` is broadcasted to a 2D shape of `(3, 3)`.
```python
import torch
# Original 1D tensor
x = torch.tensor([1, 2, 3])
# Broadcast to a 2D shape (3, 3)
y = torch.broadcast_to(x, (3, 3))
print("Original Tensor:")
print(x)
print("Shape:", x.shape)
print("\nBroadcasted Tensor:")
print(y)
print("Shape:", y.shape)
# Output:
# Original Tensor:
# tensor([1, 2, 3])
# Shape: torch.Size([3])
#
# Broadcasted Tensor:
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])
# Shape: torch.Size([3, 3])
```
---
### 2. Broadcasting a Scalar to a Multi-Dimensional Tensor
A scalar (0D tensor) can be broadcasted to any arbitrary shape.
```python
import torch
# Create a scalar tensor
x = torch.tensor(5)
# Broadcast the scalar to a 3D shape (2, 3, 4)
y = torch.broadcast_to(x, (2, 3, 4))
print("Scalar broadcasted to (2, 3, 4):")
print("Shape:", y.shape)
print("Value:\n", y)
```
---
### 3. Broadcasting a 2D Tensor to 3D
Here, a 2D tensor of shape `(2, 2)` is broadcasted to a 3D shape of `(3, 2, 2)`. PyTorch automatically prepends a dimension of size `1` to the input shape, making it `(1, 2, 2)`, and then expands it to `(3, 2, 2)`.
```python
import torch
# Original 2D tensor of shape (2, 2)
x = torch.tensor([[1, 2], [3, 4]])
# Broadcast to 3D shape (3, 2, 2)
y = torch.broadcast_to(x, (3, 2, 2))
print("Original Shape:", x.shape)
print("Broadcasted Shape:", y.shape)
print("\nBroadcasted Tensor:")
print(y)
# Output:
# Original Shape: torch.Size([2, 2])
# Broadcasted Shape: torch.Size([3, 2, 2])
#
# Broadcasted Tensor:
# tensor([[[1, 2],
#          [3, 4]],
#
#         [[1, 2],
#          [3, 4]],
#
#         [[1, 2],
#          [3, 4]]])
```
---
## Important Considerations
### 1. Memory Efficiency (Views vs. Copies)
`torch.broadcast_to` returns a **view** of the original tensor. It does not copy the underlying data in memory. Instead, it sets the stride of the broadcasted dimensions to `0`.
Because it is a view, modifying the original tensor will affect the broadcasted tensor:
```python
import torch
x = torch.tensor([1, 2, 3])
y = torch.broadcast_to(x, (3, 3))
# Modify the original tensor in place
# (rebinding the name with `x = 99` would leave y unchanged)
x[0] = 99
print(y)
# Output:
# tensor([[99,  2,  3],
#         [99,  2,  3],
#         [99,  2,  3]])
```
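The zero-stride claim can be checked directly. The following is a minimal sketch using the standard `Tensor.stride()`, `Tensor.data_ptr()`, and `Tensor.contiguous()` methods; the variable names are illustrative.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.broadcast_to(x, (3, 3))

# The new leading dimension has stride 0: every row reads the same memory.
print(x.stride())  # (1,)
print(y.stride())  # (0, 1)

# Both tensors start at the same address, so no data was copied.
print(y.data_ptr() == x.data_ptr())  # True

# contiguous() materializes a real (3, 3) copy with ordinary strides.
z = y.contiguous()
print(z.stride())  # (3, 1)
print(z.data_ptr() == x.data_ptr())  # False
```

Calling `.contiguous()` (or `.clone()`) is the point at which memory for all nine elements is actually allocated.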
### 2. Read-Only Views
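The returned tensor should be treated as read-only. Several elements of the view share a single memory location, so in-place operations on it can raise a `RuntimeError` or change every aliased element at once:
```python
# In-place arithmetic on the broadcasted view raises a RuntimeError:
# y.add_(1)
```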
If you need a writable copy, chain the operation with `.clone()`:
```python
y_writable = torch.broadcast_to(x, (3, 3)).clone()
```
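As a rough illustration of the difference between the view and its clone, the sketch below attempts an in-place addition on both. It assumes a recent PyTorch release, where in-place arithmetic on a tensor with overlapping memory is rejected; the `raised` flag is only there to make the outcome visible.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.broadcast_to(x, (3, 3))

# In-place arithmetic on the overlapping view is expected to be rejected.
raised = False
try:
    y.add_(1)
except RuntimeError as err:
    raised = True
    print("In-place op rejected:", err)

# The clone owns its own memory, so writing to it is safe.
y_writable = y.clone()
y_writable.add_(1)

print(y_writable)  # every row is [2, 3, 4]
print(x)           # the original is unchanged: tensor([1, 2, 3])
```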
### 3. Invalid Broadcasting Shapes
If the target shape is incompatible with the input tensor's shape according to broadcasting rules, PyTorch will raise a `RuntimeError`. For example, you cannot broadcast a tensor of shape `(3,)` to `(4,)`.
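A short sketch of the failure case, using the `(3,)` to `(4,)` example from the text plus one extra shape, `(3, 1)`, added here for illustration. The exact error message may differ between PyTorch versions, so it is printed rather than shown.

```python
import torch

x = torch.tensor([1, 2, 3])  # shape (3,)

failures = []
# (4,): 3 is neither equal to 4 nor 1.
# (3, 1): the trailing target dimension 1 cannot absorb an input dimension of 3.
for shape in [(4,), (3, 1)]:
    try:
        torch.broadcast_to(x, shape)
    except RuntimeError as err:
        failures.append(shape)
        print(f"Cannot broadcast (3,) to {shape}: {err}")

# A compatible target for comparison: the size-3 dimension lines up on the right.
ok = torch.broadcast_to(x, (2, 3))
print(ok.shape)  # torch.Size([2, 3])
```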