
## PyTorch torch.linalg.matmul

`torch.linalg.matmul` is a function in PyTorch's linear algebra module (`torch.linalg`) that performs matrix multiplication. It is an alias for `torch.matmul`, exposed under the `torch.linalg` namespace so that matrix products can be written alongside the other linear algebra routines.

---

### Function Definition

```python
torch.linalg.matmul(input, other, *, out=None) -> Tensor
```

### Parameters

* **`input`** *(Tensor)*: The first tensor to be multiplied.
* **`other`** *(Tensor)*: The second tensor to be multiplied.
* **`out`** *(Tensor, optional)*: A keyword-only output tensor in which the result is stored. Defaults to `None`.

### Returns

* **`Tensor`**: The result of the matrix multiplication. Its shape depends on the dimensions of the inputs, as described below.

---

### Behavior and Dimensions

The behavior of `torch.linalg.matmul` depends on the dimensionality of the input tensors:

1. **1D Tensors (Vectors)**: If both tensors are 1-dimensional, their dot product is returned as a 0-dimensional (scalar) tensor.
2. **2D Tensors (Matrices)**: If both tensors are 2-dimensional, the standard matrix-matrix product is returned.
3. **Mixed Dimensions**:
   * If the first argument is 1D and the second argument is 2D, a $1$ is temporarily prepended to the first argument's shape for the purpose of the matrix multiply, and the prepended dimension is removed afterward.
   * If the first argument is 2D and the second argument is 1D, the matrix-vector product is returned.
4. **High-Dimensional Tensors (Batched)**: If either argument is at least 3-dimensional, a batched matrix multiply is performed: the last two dimensions of each tensor are treated as matrices, and the remaining (batch) dimensions are broadcast. For example, if `input` has shape $(J \times 1 \times N \times M)$ and `other` has shape $(K \times M \times P)$, the output has shape $(J \times K \times N \times P)$.

---

## Code Examples

### Example 1: Standard 2D Matrix Multiplication

This example demonstrates basic matrix multiplication of two 2D tensors.
```python
import torch

# Create two 2D matrices
A = torch.randn(3, 4)
B = torch.randn(4, 5)

# Perform matrix multiplication
C = torch.linalg.matmul(A, B)

print("Shape of A:", A.shape)
print("Shape of B:", B.shape)
print("Shape of C:", C.shape)
```

**Output:**

```text
Shape of A: torch.Size([3, 4])
Shape of B: torch.Size([4, 5])
Shape of C: torch.Size([3, 5])
```

### Example 2: Batched Matrix Multiplication (3D Tensors)

This example shows how `torch.linalg.matmul` handles batched inputs, which are common in deep learning pipelines (e.g., attention mechanisms).

```python
import torch

# Create two batched matrices (batch size = 10)
A = torch.randn(10, 3, 4)
B = torch.randn(10, 4, 5)

# Perform batched matrix multiplication
C = torch.linalg.matmul(A, B)

print("Shape of batched A:", A.shape)
print("Shape of batched B:", B.shape)
print("Shape of batched C:", C.shape)
```

**Output:**

```text
Shape of batched A: torch.Size([10, 3, 4])
Shape of batched B: torch.Size([10, 4, 5])
Shape of batched C: torch.Size([10, 3, 5])
```

---

## Considerations and Best Practices

* **Alias Relationship**: Because `torch.linalg.matmul` is an alias of `torch.matmul`, both calls use the same underlying implementation and have the same performance. Choosing `torch.linalg.matmul` is mainly a matter of style: it keeps matrix products next to the other `torch.linalg` routines, in the spirit of NumPy's `np.linalg` namespace.
* **Inner Dimensions**: The inner dimensions of the matrices must match. For a product $A \times B$, the number of columns of $A$ must equal the number of rows of $B$; otherwise a `RuntimeError` is raised.
* **Broadcasting Rules**: Only the batch dimensions (all dimensions except the last two) are broadcast. They must be broadcast-compatible: compared from the trailing side, each pair of sizes must be equal, or one of them must be $1$ or missing.
* **Data Types**: Both input tensors must have the same data type (e.g., `torch.float32`); mismatched types are not promoted automatically. If they differ, cast one of them with `.to()` or `.float()` before performing the multiplication.
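The 1D, 2D, and mixed-dimension rules listed under Behavior and Dimensions can be checked with a short script. This is a minimal sketch using small hand-picked tensors (the names `v`, `w`, `M`, and `N` are illustrative only) whose products are easy to verify by hand; each case is also compared against the explicit operation the rule describes (`unsqueeze`/`squeeze` for the 1D x 2D case and `torch.mv` for the 2D x 1D case).

```python
import torch

# Small hand-checkable inputs (values chosen arbitrarily for illustration)
v = torch.tensor([1.0, 2.0, 3.0])                        # shape (3,)
w = torch.tensor([4.0, 5.0, 6.0])                        # shape (3,)
M = torch.arange(6, dtype=torch.float32).reshape(3, 2)   # [[0,1],[2,3],[4,5]]
N = torch.arange(6, dtype=torch.float32).reshape(2, 3)   # [[0,1,2],[3,4,5]]

# Rule 1: 1D x 1D -> dot product returned as a 0-dimensional tensor
dot = torch.linalg.matmul(v, w)
print(dot)        # tensor(32.)
print(dot.dim())  # 0

# Rule 3, first case: 1D x 2D -> v is treated as shape (1, 3), result has shape (2,)
vM = torch.linalg.matmul(v, M)
print(vM)         # tensor([16., 22.])
# Equivalent to prepending the dimension explicitly and removing it afterward
print(torch.equal(vM, (v.unsqueeze(0) @ M).squeeze(0)))  # True

# Rule 3, second case: 2D x 1D -> matrix-vector product with shape (2,)
Nv = torch.linalg.matmul(N, v)
print(Nv)         # tensor([ 8., 26.])
print(torch.equal(Nv, torch.mv(N, v)))  # True
```

Integer-valued floats are used so that every result is exact and can be compared with `torch.equal` rather than a tolerance.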
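The batched shape example from item 4 of Behavior and Dimensions, $(J \times 1 \times N \times M)$ times $(K \times M \times P)$, can be reproduced directly. In this sketch the values assigned to `J`, `K`, `N`, `M`, and `P` are arbitrary choices; the loop confirms that each output slice is an ordinary 2D product of the corresponding broadcast batch entries.

```python
import torch

torch.manual_seed(0)
J, K, N, M, P = 2, 5, 3, 4, 6   # arbitrary sizes chosen for illustration

a = torch.randn(J, 1, N, M)     # batch dims (J, 1), matrices of shape N x M
b = torch.randn(K, M, P)        # batch dims (K,),   matrices of shape M x P

out = torch.linalg.matmul(a, b)
print(out.shape)  # torch.Size([2, 5, 3, 6])

# Batch dims (J, 1) and (K,) broadcast to (J, K), so every output slice is a
# plain 2D product: out[j, k] == a[j, 0] @ b[k] (up to float rounding)
ok = all(
    torch.allclose(out[j, k], a[j, 0] @ b[k], atol=1e-5)
    for j in range(J)
    for k in range(K)
)
print(ok)  # True
```

`torch.allclose` is used instead of exact equality because the batched and unbatched paths may accumulate floating-point sums slightly differently.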
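The data-type and inner-dimension requirements from Considerations and Best Practices can be observed with a sketch like the one below. It assumes CPU tensors filled with random values; because the exact wording of the `RuntimeError` messages differs between PyTorch versions, the sketch only records whether an error was raised rather than matching its text.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)                       # float32 (the default dtype)
B = torch.randn(4, 5, dtype=torch.float64)  # float64

# Mismatched dtypes are not promoted automatically: this call raises RuntimeError
dtype_error = None
try:
    torch.linalg.matmul(A, B)
except RuntimeError as err:
    dtype_error = err

# Casting one operand so that both are float32 makes the product valid
C = torch.linalg.matmul(A, B.float())
print(C.dtype)  # torch.float32

# Mismatched inner dimensions (4 columns vs 5 rows here) also raise RuntimeError
shape_error = None
try:
    torch.linalg.matmul(A, torch.randn(5, 2))
except RuntimeError as err:
    shape_error = err

# As an alias, torch.linalg.matmul gives the same values as torch.matmul
print(torch.allclose(C, torch.matmul(A, B.float())))  # True
```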