## PyTorch torch.linalg.matmul
`torch.linalg.matmul` is a function in PyTorch's linear algebra module (`torch.linalg`) used to perform matrix multiplication. It is an alias for `torch.matmul`, providing a more unified and consistent interface within the linear algebra namespace.
---
### Function Definition
```python
torch.linalg.matmul(input, other, *, out=None) -> Tensor
```
### Parameters
* **`input`** *(Tensor)*: The first input tensor to be multiplied.
* **`other`** *(Tensor)*: The second input tensor to be multiplied.
* **`out`** *(Tensor, optional)*: The output tensor where the result will be stored. Defaults to `None`.
### Returns
* **`Tensor`**: A tensor representing the result of the matrix multiplication.
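
As a minimal sketch of the `out` parameter described above, the snippet below writes the product into a pre-allocated buffer. The names `buf` and `result` are illustrative, and the buffer is assumed to already have the shape and dtype of the result.

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)

# Pre-allocate a buffer with the shape and dtype of the expected result
buf = torch.empty(3, 5)

# The product is written into `buf`; the call also returns the result tensor
result = torch.linalg.matmul(A, B, out=buf)

print("Shape of result:", result.shape)
```

Reusing a buffer this way can avoid a fresh allocation on each call, which may matter in tight loops.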
---
### Behavior and Dimensions
The behavior of `torch.linalg.matmul` depends on the dimensionality of the input tensors:
1. **1D Tensors (Vectors)**: If both tensors are 1-dimensional, the dot product is returned as a 0-dimensional (scalar) tensor.
2. **2D Tensors (Matrices)**: If both tensors are 2-dimensional, the standard matrix-matrix product is returned.
3. **Mixed Dimensions**:
* If the first argument is 1D and the second argument is 2D, a $1$ is temporarily prepended to the first argument's shape for the purpose of the matrix multiply and removed after multiplication, so the result is 1D.
* If the first argument is 2D and the second argument is 1D, a matrix-vector product is returned.
4. **High-Dimensional Tensors (Batched)**: If both arguments are at least 1-dimensional and at least one of them has more than two dimensions, the operation is treated as a batched matrix multiply. The last two dimensions are the matrix dimensions, and the remaining leading (batch) dimensions are broadcast against each other. For example, if `input` is of shape $(J \times 1 \times N \times M)$ and `other` is of shape $(K \times M \times P)$, the output tensor will have the shape $(J \times K \times N \times P)$.
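
The four cases above can be checked with a short sketch that only inspects the resulting shapes. The tensor names and the concrete sizes chosen for `J`, `K`, `N`, `M`, and `P` are illustrative assumptions.

```python
import torch

v = torch.randn(4)
w = torch.randn(4)
mat_a = torch.randn(4, 5)
mat_b = torch.randn(3, 4)

# Case 1: 1D @ 1D gives a 0-dimensional tensor (the dot product)
dot = torch.linalg.matmul(v, w)

# Case 3a: 1D @ 2D, (4,) @ (4, 5) gives a 1D result of length 5
vec_mat = torch.linalg.matmul(v, mat_a)

# Case 3b: 2D @ 1D, (3, 4) @ (4,) gives a 1D result of length 3
mat_vec = torch.linalg.matmul(mat_b, v)

# Case 4: batch dimensions (J, 1) and (K,) broadcast to (J, K)
J, K, N, M, P = 2, 3, 4, 5, 6
X = torch.randn(J, 1, N, M)
Y = torch.randn(K, M, P)
batched = torch.linalg.matmul(X, Y)

print("dot:", dot.shape)          # 0-dimensional
print("vec @ mat:", vec_mat.shape)
print("mat @ vec:", mat_vec.shape)
print("batched:", batched.shape)  # (J, K, N, P)
```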
---
## Code Examples
### Example 1: Standard 2D Matrix Multiplication
This example demonstrates basic matrix multiplication of two 2D tensors.
```python
import torch
# Create two 2D matrices
A = torch.randn(3, 4)
B = torch.randn(4, 5)
# Perform matrix multiplication
C = torch.linalg.matmul(A, B)
print("Shape of A:", A.shape)
print("Shape of B:", B.shape)
print("Shape of C:", C.shape)
```
**Output:**
```text
Shape of A: torch.Size([3, 4])
Shape of B: torch.Size([4, 5])
Shape of C: torch.Size([3, 5])
```
### Example 2: Batched Matrix Multiplication (3D Tensors)
This example shows how `torch.linalg.matmul` handles batched inputs, which is highly useful in deep learning pipelines (e.g., attention mechanisms).
```python
import torch
# Create two batched matrices (Batch size = 10)
A = torch.randn(10, 3, 4)
B = torch.randn(10, 4, 5)
# Perform batched matrix multiplication
C = torch.linalg.matmul(A, B)
print("Shape of batched A:", A.shape)
print("Shape of batched B:", B.shape)
print("Shape of batched C:", C.shape)
```
**Output:**
```text
Shape of batched A: torch.Size([10, 3, 4])
Shape of batched B: torch.Size([10, 4, 5])
Shape of batched C: torch.Size([10, 3, 5])
```
---
## Considerations and Best Practices
* **Alias Relationship**: Since `torch.linalg.matmul` is a direct alias of `torch.matmul`, they share identical performance characteristics and underlying implementations. Using `torch.linalg.matmul` is recommended when you want to write clean, explicit linear algebra code that aligns with standard libraries like NumPy's `np.linalg`.
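* **Broadcasting Rules**: Broadcasting applies only to the leading batch dimensions; the matrix dimensions themselves are never broadcast. For a matrix multiplication of $A \times B$, the number of columns in $A$ (its last dimension) must equal the number of rows in $B$ (its second-to-last dimension).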
* **Data Types**: Both input tensors must have the same data type (e.g., `torch.float32`). If they do not match, cast one of them before multiplying, for example with `.to(dtype)` or, for `torch.float32` specifically, `.float()`.
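
The alias and data-type points above can be illustrated with a small sketch. It assumes a `float32` and a `float64` tensor as the mismatched pair. The `try`/`except` only records whether the mismatched call is rejected, since the exact error behavior may vary between PyTorch versions.

```python
import torch

A = torch.randn(3, 4, dtype=torch.float32)
B = torch.randn(4, 5, dtype=torch.float64)

# Record whether multiplying mismatched dtypes raises an error
try:
    torch.linalg.matmul(A, B)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Cast B to A's dtype before multiplying
B32 = B.to(A.dtype)
C = torch.linalg.matmul(A, B32)

# The alias and torch.matmul should agree on the same inputs
matches_matmul = torch.allclose(C, torch.matmul(A, B32))

print("Mismatch raised an error:", mismatch_raised)
print("Result dtype:", C.dtype)
print("Matches torch.matmul:", matches_matmul)
```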