
Pytorch Torch Linalg Solve

## PyTorch torch.linalg.solve

The `torch.linalg.solve` function is a core utility in PyTorch's linear algebra module (`torch.linalg`) for solving systems of linear equations. Specifically, it computes the solution $X$ of the linear system:

$$AX = B$$

where $A$ is a square, invertible matrix (or a batch of such matrices) and $B$ is the right-hand side matrix or vector.

---

### Function Signature

```python
torch.linalg.solve(A, B, *, left=True, out=None)
```

### Parameters

* **`A` (Tensor)**: The coefficient tensor, of shape `(*, M, M)`, where `*` denotes zero or more batch dimensions. The two innermost dimensions must form a square matrix.
* **`B` (Tensor)**: The right-hand side tensor.
  * If `left=True`, its shape must be `(*, M)` (a vector, or a batch of vectors) or `(*, M, K)` (a matrix with `K` right-hand-side columns, or a batch of such matrices). A `B` of shape `(M,)` or `(M, K)` is broadcast across the batch dimensions of `A`.
  * If `left=False`, pass `B` as a matrix of shape `(*, K, M)`.
* **`left` (bool, optional)**: Selects which system of equations to solve.
  * If `True` (default), it solves $AX = B$.
  * If `False`, it solves $XA = B$.
* **`out` (Tensor, optional)**: The output tensor. Ignored if `None` (the default).

### Returns

* **`Tensor`**: The solution $X$, with the same dtype and device as the inputs. Its shape matches that of `B`, with any batch dimensions broadcast against those of `A`.
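To make the shape rules above concrete, here is a minimal sketch that checks the output shape for each accepted form of `B`. It assumes a recent PyTorch release (it uses `Tensor.mT` for a batched transpose) and builds random float64 matrices with `3 * I` added to the diagonal, an illustrative choice that simply makes them very likely to be well conditioned. The variable names are hypothetical. The last case also checks that `left=False` agrees with solving the transposed system $A^\top X^\top = B^\top$.

```python
import torch

torch.manual_seed(0)

# A batch of 4 random 3x3 matrices; adding 3*I (illustrative assumption)
# makes each matrix very likely to be invertible and well conditioned.
A = torch.randn(4, 3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)

# 1) Batch of vectors: B has shape (*, M) = (4, 3), so X has shape (4, 3)
b_vec = torch.randn(4, 3, dtype=torch.float64)
x_vec = torch.linalg.solve(A, b_vec)
print(x_vec.shape)  # torch.Size([4, 3])

# 2) Multiple right-hand sides: B has shape (*, M, K) = (4, 3, 2), so X has shape (4, 3, 2)
B_mat = torch.randn(4, 3, 2, dtype=torch.float64)
X_mat = torch.linalg.solve(A, B_mat)
print(X_mat.shape)  # torch.Size([4, 3, 2])

# 3) Broadcasting: a single (M,) vector is reused for every system in the batch
b_shared = torch.randn(3, dtype=torch.float64)
x_shared = torch.linalg.solve(A, b_shared)
print(x_shared.shape)  # torch.Size([4, 3])

# 4) left=False solves XA = B, with B of shape (*, K, M) = (4, 2, 3)
C = torch.randn(4, 2, 3, dtype=torch.float64)
X_right = torch.linalg.solve(A, C, left=False)

# Same answer via transposes: XA = C  <=>  A^T X^T = C^T
X_via_T = torch.linalg.solve(A.mT, C.mT).mT
print(torch.allclose(X_right, X_via_T))  # True
```

float64 is used here only so the `allclose` comparisons are not sensitive to float32 rounding.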
---

## Code Examples

### Example 1: Solving a Basic 2×2 Linear System ($AX = B$)

This example solves a simple system of two linear equations in two unknowns:

$$
\begin{aligned}
x_1 + 2x_2 &= 5 \\
3x_1 + 4x_2 &= 11
\end{aligned}
$$

```python
import torch

# Create the coefficient matrix A and the right-hand side vector B
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([5.0, 11.0])

# Solve the system AX = B
X = torch.linalg.solve(A, B)

print("Coefficient Matrix A:")
print(A)
print("\nRight-hand Side Vector B:")
print(B)
print("\nSolution X:")
print(X)

# Verify the solution: A @ X should equal B
print("\nVerification (A @ X):")
print(A @ X)
```

**Output:**

```text
Coefficient Matrix A:
tensor([[1., 2.],
        [3., 4.]])

Right-hand Side Vector B:
tensor([ 5., 11.])

Solution X:
tensor([1., 2.])

Verification (A @ X):
tensor([ 5., 11.])
```

---

### Example 2: Solving Batched Systems of Equations

`torch.linalg.solve` handles batch dimensions natively, so a single call solves many independent linear systems at once, on either CPU or GPU, without a Python loop.

```python
import torch

# Create a batch of two 2x2 matrices (batch size = 2)
A = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
                  [[5.0, 6.0], [7.0, 8.0]]])

# Create a batch of two right-hand side vectors, one per matrix
B = torch.tensor([[5.0, 11.0],
                  [17.0, 23.0]])

# Solve the batched system
X = torch.linalg.solve(A, B)

print("Batched Solution X:")
print(X)

# Verify with batched matrix multiplication; unsqueeze(-1) turns each solution into a column vector
print("\nBatched Verification (A @ X):")
print(A @ X.unsqueeze(-1))
```

**Output** (both systems happen to have the solution $x_1 = 1$, $x_2 = 2$; the exact digits printed may vary slightly with floating-point rounding):

```text
Batched Solution X:
tensor([[1.0000, 2.0000],
        [1.0000, 2.0000]])

Batched Verification (A @ X):
tensor([[[ 5.0000],
         [11.0000]],

        [[17.0000],
         [23.0000]]])
```

---

### Example 3: Solving $XA = B$ (Using `left=False`)

To solve for $X$ in the equation $XA = B$, set the `left` parameter to `False`.
```python
import torch

# Define matrices
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[5.0, 11.0], [7.0, 15.0]])

# Solve XA = B by setting left=False
X = torch.linalg.solve(A, B, left=False)

print("Solution X:")
print(X)

# Verify the solution: X @ A should equal B
print("\nVerification (X @ A):")
print(X @ A)
```

**Output:**

```text
Solution X:
tensor([[ 6.5000, -0.5000],
        [ 8.5000, -0.5000]])

Verification (X @ A):
tensor([[ 5.0000, 11.0000],
        [ 7.0000, 15.0000]])
```

---

## Important Considerations

1. **Matrix Invertibility**: The coefficient matrix $A$ must be square and non-singular (invertible). If $A$, or any matrix in a batched $A$, is singular, the function raises a `RuntimeError`. Singularity is detected during the LU factorization, so a matrix that is only numerically close to singular may not raise and can instead return an inaccurate solution; `torch.linalg.cond` can help flag such ill-conditioned inputs.
2. **Numerical Stability**: `torch.linalg.solve(A, B)` is faster and more numerically stable than explicitly computing the inverse and multiplying (`torch.linalg.inv(A) @ B`).
3. **Data Types**: This function supports `float`, `double`, `cfloat`, and `cdouble` tensors. `A` and `B` must have the same dtype: mismatched dtypes raise an error instead of being promoted, so cast one input explicitly (for example, `B.to(A.dtype)`).
4. **Hardware Acceleration**: On CUDA-enabled GPUs, PyTorch dispatches to optimized backends such as cuSOLVER and MAGMA. Because `torch.linalg.solve` checks for singular inputs, calling it on CUDA tensors synchronizes the device with the CPU; `torch.linalg.solve_ex` is the variant that skips this synchronization.
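The considerations above can be checked directly. The following is a minimal sketch assuming a recent PyTorch release; the matrices, the random seed, and variable names such as `singular_raised` are illustrative choices, not part of the API. It uses `torch.linalg.solve_ex`, PyTorch's documented variant that returns `(result, info)` instead of raising, along with `torch.linalg.inv` and `torch.linalg.norm` to compare residuals. The random 200×200 matrix is assumed to be invertible, which holds with probability one for Gaussian entries.

```python
import torch

# 1. An exactly singular matrix makes torch.linalg.solve raise RuntimeError.
#    The second row is twice the first, so det(S) = 0.
S = torch.tensor([[1.0, 2.0],
                  [2.0, 4.0]])
b = torch.tensor([1.0, 2.0])

singular_raised = False
try:
    torch.linalg.solve(S, b)
except RuntimeError as err:
    singular_raised = True
    print("solve raised:", err)

# 2. torch.linalg.solve_ex does not raise by default; it returns (result, info).
#    info == 0 means success; a positive info marks a zero pivot in the LU
#    factorization, in which case the returned result should not be trusted.
X_bad, info = torch.linalg.solve_ex(S, b)
print("info from solve_ex:", info.item())

# 3. solve vs. explicit inverse on a random system (assumed invertible).
torch.manual_seed(0)
A = torch.randn(200, 200, dtype=torch.float64)
B = torch.randn(200, 5, dtype=torch.float64)

X_solve = torch.linalg.solve(A, B)
X_inv = torch.linalg.inv(A) @ B

# Residual norm ||AX - B|| of each answer
res_solve = torch.linalg.norm(A @ X_solve - B).item()
res_inv = torch.linalg.norm(A @ X_inv - B).item()
print(f"residual via solve:   {res_solve:.2e}")
print(f"residual via inv @ B: {res_inv:.2e}")

# 4. Mixed dtypes (float32 A, float64 B) are rejected rather than promoted.
dtype_raised = False
try:
    torch.linalg.solve(torch.eye(2), torch.ones(2, dtype=torch.float64))
except RuntimeError:
    dtype_raised = True
print("mixed float32/float64 raised:", dtype_raised)
```

The residuals printed in step 3 depend on the random matrix, so treat them as an illustration rather than a benchmark: on well-conditioned inputs both approaches are usually accurate, and the advantage of `solve` tends to show up as $A$ becomes ill-conditioned.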