## PyTorch torch.linalg.solve
The `torch.linalg.solve` function is a core utility in PyTorch's linear algebra module (`torch.linalg`) designed to solve systems of linear equations. Specifically, it computes the solution $X$ for the linear system:
$$AX = B$$
where $A$ is a square matrix (or a batch of square matrices) and $B$ is the right-hand side matrix or vector.
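When $A$ is invertible, the solution has a closed form, which is handy for checking results by hand. PyTorch does not form $A^{-1}$ explicitly (see the notes on numerical stability below). As a worked instance, take the system from Example 1, $A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}$ and $B = \begin{pmatrix} 5 \\ 11 \end{pmatrix}$, where $\det A = 1 \cdot 4 - 2 \cdot 3 = -2$:

$$X = A^{-1}B = \frac{1}{-2}\begin{pmatrix} 4 & -2 \\ -3 & 1 \end{pmatrix}\begin{pmatrix} 5 \\ 11 \end{pmatrix} = \frac{1}{-2}\begin{pmatrix} -2 \\ -4 \end{pmatrix} = \begin{pmatrix} 1 \\ 2 \end{pmatrix}$$

For the transposed problem $XA = B$ (selected with `left=False`), the analogous closed form is $X = BA^{-1}$.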
---
### Function Signature
```python
torch.linalg.solve(A, B, *, left=True, out=None)
```
### Parameters
* **`A` (Tensor)**: The coefficient tensor. It must have shape `(*, M, M)`, where `*` represents zero or more batch dimensions, so that the two innermost dimensions form a square matrix.
* **`B` (Tensor)**: The right-hand side tensor.
  * If `left=True`, its shape must be `(*, M, K)` or `(*, M)`. A `B` of shape `(M,)` or `(M, K)` is broadcast across the batch dimensions of `A`.
  * If `left=False`, its shape must be `(*, K, M)`.
* **`left` (bool, optional)**: Controls which system of equations to solve.
  * If `True` (default), it solves $AX = B$.
  * If `False`, it solves $XA = B$.
* **`out` (Tensor, optional)**: The output tensor. Ignored if `None`.
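The shape rules above can be checked by solving a few random systems and printing the result shapes. This is a sketch under stated assumptions: it uses `float64`, and the hypothetical helper `random_invertible` builds each matrix as $GG^\top + MI$ purely so that every $A$ is guaranteed to be invertible; `torch.linalg.solve` itself imposes no such structure.

```python
import torch

torch.manual_seed(0)
M, K = 3, 2

def random_invertible(*batch):
    # Hypothetical helper: G @ G^T + M*I is symmetric positive definite, hence invertible
    G = torch.randn(*batch, M, M, dtype=torch.float64)
    return G @ G.mT + M * torch.eye(M, dtype=torch.float64)

A = random_invertible()         # shape (M, M)
A_batch = random_invertible(4)  # shape (4, M, M)

# left=True: B of shape (M,) or (M, K); X has the same shape as B
x = torch.linalg.solve(A, torch.randn(M, dtype=torch.float64))
X = torch.linalg.solve(A, torch.randn(M, K, dtype=torch.float64))
print(x.shape, X.shape)  # torch.Size([3]) torch.Size([3, 2])

# Batched: each of the 4 matrices gets its own (M, K) right-hand side
Xb = torch.linalg.solve(A_batch, torch.randn(4, M, K, dtype=torch.float64))
print(Xb.shape)  # torch.Size([4, 3, 2])

# Broadcasting: a single (M,) vector is shared by every matrix in the batch
xb = torch.linalg.solve(A_batch, torch.randn(M, dtype=torch.float64))
print(xb.shape)  # torch.Size([4, 3])

# left=False: B of shape (K, M); X satisfies X @ A = B
B_row = torch.randn(K, M, dtype=torch.float64)
X_left = torch.linalg.solve(A, B_row, left=False)
print(X_left.shape)                       # torch.Size([2, 3])
print(torch.allclose(X_left @ A, B_row))  # True
```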
### Returns
* **`Tensor`**: The solution tensor $X$. It has the same shape as `B` (after broadcasting batch dimensions against `A`) and the same dtype and device as the inputs.
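A short sketch of the returned tensor's metadata, together with the optional `out=` argument. The 2×2 system here is an arbitrary illustration, and the script falls back to the CPU when no GPU is present.

```python
import torch

# Use a GPU if one is available; otherwise everything runs on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

A = torch.tensor([[4.0, 1.0], [2.0, 3.0]], dtype=torch.float64, device=device)
B = torch.tensor([1.0, 2.0], dtype=torch.float64, device=device)

X = torch.linalg.solve(A, B)
# X takes its dtype and device from the inputs and its shape from B
print(X.dtype, X.device, X.shape)

# out= writes the result into a preallocated tensor of matching shape/dtype/device
out = torch.empty_like(B)
torch.linalg.solve(A, B, out=out)
print(torch.allclose(out, X))  # True
```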
---
## Code Examples
### Example 1: Solving a Basic 2D Linear System ($AX = B$)
This example demonstrates how to solve a simple system of two linear equations with two variables:
$$1.0x_1 + 2.0x_2 = 5.0$$
$$3.0x_1 + 4.0x_2 = 11.0$$
```python
import torch
# Create the coefficient matrix A and the right-hand side vector B
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([5.0, 11.0])
# Solve the system AX = B
X = torch.linalg.solve(A, B)
print("Coefficient Matrix A:")
print(A)
print("\nRight-hand Side Vector B:")
print(B)
print("\nSolution X:")
print(X)
# Verify the solution: A @ X should equal B
print("\nVerification (A @ X):")
print(A @ X)
```
**Output:**
```text
Coefficient Matrix A:
tensor([[1., 2.],
        [3., 4.]])

Right-hand Side Vector B:
tensor([ 5., 11.])

Solution X:
tensor([1., 2.])

Verification (A @ X):
tensor([ 5., 11.])
```
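Passing a matrix `B` of shape `(M, K)` solves `K` right-hand sides against the same `A` in one call: column `j` of `X` solves $A\,x_j = b_j$. A minimal sketch reusing the matrix above; the second right-hand side `[1, 1]` is an arbitrary choice for illustration.

```python
import torch

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# Each column of B is one right-hand side: [5, 11] from above and an arbitrary [1, 1]
B = torch.tensor([[5.0, 1.0],
                  [11.0, 1.0]])

X = torch.linalg.solve(A, B)  # shape (2, 2); column j solves A @ X[:, j] = B[:, j]
print(X)

# Solving each column separately gives the same result
for j in range(B.shape[1]):
    print(torch.allclose(X[:, j], torch.linalg.solve(A, B[:, j])))
```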
---
### Example 2: Solving Batched Systems of Equations
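`torch.linalg.solve` supports batched inputs out of the box: every leading dimension of `A` is treated as a batch dimension, so a single call solves many independent systems on CPU or GPU, which is typically much faster than looping over them in Python. Here `B` has shape `(2, 2)`, which equals `A.shape[:-1]`, so it is read as one right-hand-side vector per matrix.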
```python
import torch
# Create a batch of two 2x2 matrices (Batch size = 2)
A = torch.tensor([[[1.0, 2.0],
[3.0, 4.0]],
[[5.0, 6.0],
[7.0, 8.0]]])
# Create a batch of two right-hand side vectors
B = torch.tensor([[5.0, 11.0],
[17.0, 23.0]])
# Solve the batched system
X = torch.linalg.solve(A, B)
print("Batched Solution X:")
print(X)
# Verify batched multiplication using the @ operator
print("\nBatched Verification (A @ X):")
print(A @ X.unsqueeze(-1)) # Unsqueeze to match matrix dimensions for verification
```
**Output** (last digits may vary with floating-point round-off):
```text
Batched Solution X:
tensor([[1.0000, 2.0000],
        [1.0000, 2.0000]])

Batched Verification (A @ X):
tensor([[[ 5.0000],
         [11.0000]],

        [[17.0000],
         [23.0000]]])
```
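Batch dimensions also broadcast, so one right-hand side can be shared by every matrix in the batch. The sketch below reuses the matrices from this example (cast to `float64` so the `allclose` checks are tight). It also covers a shape pitfall: because a `B` whose shape equals `A.shape[:-1]` is treated as a batch of vectors, a right-hand-side *matrix* meant to be shared needs an explicit leading batch dimension of size 1. The `B_mat` values are illustrative.

```python
import torch

# float64 keeps round-off small for the allclose checks below
A = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
                  [[5.0, 6.0], [7.0, 8.0]]], dtype=torch.float64)  # shape (2, 2, 2)

# One right-hand-side vector of shape (2,) is broadcast to every matrix in the batch
b = torch.tensor([5.0, 11.0], dtype=torch.float64)
X = torch.linalg.solve(A, b)  # shape (2, 2); row i solves A[i] @ x = b
print(torch.allclose((A @ X.unsqueeze(-1)).squeeze(-1), b.expand(2, 2)))  # True

# Share one 2x2 right-hand-side matrix: add a batch dimension of size 1,
# otherwise a plain (2, 2) tensor would be read as a batch of two vectors
B_mat = torch.tensor([[5.0, 1.0], [11.0, 1.0]], dtype=torch.float64)
Y = torch.linalg.solve(A, B_mat.unsqueeze(0))  # (1, 2, 2) broadcast to (2, 2, 2)
print(Y.shape)  # torch.Size([2, 2, 2])
print(torch.allclose(A @ Y, B_mat.expand(2, 2, 2)))  # True
```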
---
### Example 3: Solving $XA = B$ (Using `left=False`)
When you need to solve for $X$ in the equation $XA = B$, you can set the `left` parameter to `False`.
```python
import torch
# Define matrices
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[5.0, 11.0], [7.0, 15.0]])
# Solve XA = B by setting left=False
X = torch.linalg.solve(A, B, left=False)
print("Solution X:")
print(X)
# Verify the solution: X @ A should equal B
print("\nVerification (X @ A):")
print(X @ A)
```
**Output:**
```text
Solution X:
tensor([[ 6.5000, -0.5000],
        [ 8.5000, -0.5000]])

Verification (X @ A):
tensor([[ 5.0000, 11.0000],
        [ 7.0000, 15.0000]])
```
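Because $XA = B$ is equivalent to $A^\top X^\top = B^\top$, the `left=False` result can also be reproduced with an ordinary left solve on transposed inputs. This sketch checks that identity on the matrices from this example (in `float64`); `.mT` transposes only the last two dimensions, so the same approach applies to batches.

```python
import torch

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
B = torch.tensor([[5.0, 11.0], [7.0, 15.0]], dtype=torch.float64)

# X A = B  <=>  A^T X^T = B^T, so a left solve on transposed inputs gives the same X
X_right = torch.linalg.solve(A, B, left=False)
X_via_T = torch.linalg.solve(A.mT, B.mT).mT  # .mT transposes the last two dims (batch-safe)
print(torch.allclose(X_right, X_via_T))  # True
```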
---
## Important Considerations
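1. **Matrix Invertibility**: The coefficient matrix $A$ must be square and non-singular (invertible). If $A$ is singular (its determinant is zero, i.e., it is not full rank), the function raises a `RuntimeError` (more specifically `torch.linalg.LinAlgError` in recent releases). Only exact singularity (a zero pivot during the LU factorization) is detected; a nearly singular, ill-conditioned $A$ returns a result without any error, but that result can be highly inaccurate.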
2. **Numerical Stability**: Using `torch.linalg.solve(A, B)` is numerically more stable and faster than explicitly calculating the matrix inverse and multiplying (`torch.linalg.inv(A) @ B`).
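3. **Data Types**: This function supports `float`, `double`, `cfloat`, and `cdouble` inputs; integer tensors are rejected. `A` and `B` must have the same dtype: mixing, for example, `float32` and `float64` raises an error instead of promoting, so cast explicitly (e.g., `B.to(A.dtype)`) when they differ.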
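4. **Hardware Acceleration**: On CUDA-enabled GPUs, PyTorch dispatches the computation to cuSOLVER or MAGMA (the preferred backend can be set with `torch.backends.cuda.preferred_linalg_library()`). The speedup is usually largest for large matrices or large batches; for a handful of tiny systems, kernel-launch overhead can dominate.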
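Point 1 can be reproduced with a matrix whose second row is twice its first. This sketch catches `RuntimeError`, which `torch.linalg.LinAlgError` subclasses in recent releases, and also calls `torch.linalg.solve_ex`, a variant available in recent PyTorch versions that reports failure through a nonzero `info` tensor instead of raising.

```python
import torch

# Second row is 2 * first row, so this matrix is exactly singular
S = torch.tensor([[1.0, 2.0], [2.0, 4.0]])
b = torch.tensor([1.0, 1.0])

raised = False
try:
    torch.linalg.solve(S, b)
except RuntimeError as err:  # torch.linalg.LinAlgError is a RuntimeError subclass
    raised = True
    print("solve failed:", err)

# solve_ex does not raise; a nonzero info code flags the failed factorization
X, info = torch.linalg.solve_ex(S, b)
print("info:", info.item())
```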