## PyTorch torch.kron
The `torch.kron` function in PyTorch computes the **Kronecker product** of two tensors. For matrices, the Kronecker product takes two matrices of arbitrary size and produces a block matrix: each element of the first matrix is multiplied by the entire second matrix, and the scaled copies are arranged as blocks in the same layout as the elements of the first matrix.
---
### Mathematical Definition
If $\mathbf{A}$ is an $m \times n$ matrix and $\mathbf{B}$ is a $p \times q$ matrix, then the Kronecker product $\mathbf{A} \otimes \mathbf{B}$ is an $mp \times nq$ block matrix:
$$
\mathbf{A} \otimes \mathbf{B} = \begin{bmatrix} a_{11} \mathbf{B} & \cdots & a_{1n} \mathbf{B} \\ \vdots & \ddots & \vdots \\ a_{m1} \mathbf{B} & \cdots & a_{mn} \mathbf{B} \end{bmatrix}
$$
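
The block structure in the definition can be reproduced by hand. The sketch below is only an illustration: `kron_by_blocks` is a hypothetical helper written for this page, not part of PyTorch, and it handles only 2D inputs. It scales `B` by each element of `A`, joins the blocks of each block-row side by side, then stacks the block-rows.

```python
import torch

def kron_by_blocks(a, b):
    """Hypothetical helper: build A (x) B for 2D tensors block by block."""
    block_rows = []
    for i in range(a.shape[0]):
        # Blocks a[i, 0] * B, ..., a[i, n-1] * B placed side by side
        row = torch.cat([a[i, j] * b for j in range(a.shape[1])], dim=1)
        block_rows.append(row)
    # Stack the block-rows vertically
    return torch.cat(block_rows, dim=0)

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])   # m x n = 2 x 3
B = torch.tensor([[1, 0],
                  [0, 1]])      # p x q = 2 x 2

manual = kron_by_blocks(A, B)
print(manual.shape)                          # torch.Size([4, 6]), i.e. mp x nq
print(torch.equal(manual, torch.kron(A, B))) # True
```

In practice `torch.kron` should be preferred; the loop version is there only to make the $mp \times nq$ layout concrete.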
---
### Syntax
```python
torch.kron(input, other, *, out=None) -> Tensor
```
#### Parameters
* **`input`** *(Tensor)*: The first input tensor (often referred to as the left-hand side tensor, $\mathbf{A}$).
* **`other`** *(Tensor)*: The second input tensor (often referred to as the right-hand side tensor, $\mathbf{B}$).
* **`out`** *(Tensor, optional)*: The output tensor.
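
A minimal sketch of the `out` argument, assuming the buffer is preallocated with the shape and dtype the result would have anyway (here 4 elements of `torch.float32`). The name `buf` is chosen for illustration.

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0])

# Preallocate a buffer matching the result: 2 * 2 = 4 elements, float32
buf = torch.empty(4)

# Write the Kronecker product into the existing buffer
torch.kron(a, b, out=buf)
print(buf)  # tensor([3., 4., 6., 8.])
```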
---
### Code Examples
#### Example 1: 1D Tensors (Vectors)
When applied to 1D tensors, `torch.kron` multiplies each element of the first vector by the entire second vector and concatenates the results.
```python
import torch
# Create two 1D tensors
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
# Compute the Kronecker product
y = torch.kron(a, b)
print(y)
```
**Output:**
```text
tensor([3, 4, 6, 8])
```
*Explanation:*
* The first element of `a` ($1$) is multiplied by `b` ($[3, 4]$) $\rightarrow [3, 4]$
* The second element of `a` ($2$) is multiplied by `b` ($[3, 4]$) $\rightarrow [6, 8]$
* Concatenating these yields $[3, 4, 6, 8]$.
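
As a cross-check on the steps above, the 1D result can also be read as the outer product of the two vectors, flattened row by row. The sketch below uses `torch.outer`, which this page does not otherwise cover, purely to verify that reading.

```python
import torch

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])

# Row i of the outer product is a[i] * b; flattening concatenates the rows
outer_flat = torch.outer(a, b).flatten()

print(outer_flat)                                 # tensor([3, 4, 6, 8])
print(torch.equal(outer_flat, torch.kron(a, b)))  # True
```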
---
#### Example 2: 2D Tensors (Matrices)
When applied to 2D matrices, the output is a block matrix where each block is the matrix `b` scaled by an element of matrix `a`.
```python
import torch
# Create two 2D tensors
a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])
# Compute the Kronecker product
y = torch.kron(a, b)
print(y)
```
**Output:**
```text
tensor([[ 5,  6, 10, 12],
        [ 7,  8, 14, 16],
        [15, 18, 20, 24],
        [21, 24, 28, 32]])
```
*Explanation:*
* Top-left block: $1 \times \mathbf{B} = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}$
* Top-right block: $2 \times \mathbf{B} = \begin{bmatrix} 10 & 12 \\ 14 & 16 \end{bmatrix}$
* Bottom-left block: $3 \times \mathbf{B} = \begin{bmatrix} 15 & 18 \\ 21 & 24 \end{bmatrix}$
* Bottom-right block: $4 \times \mathbf{B} = \begin{bmatrix} 20 & 24 \\ 28 & 32 \end{bmatrix}$
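
The four blocks listed above can be checked by slicing the result. This is a small verification sketch; the loop variables and the name `block` are illustrative.

```python
import torch

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])
y = torch.kron(a, b)

p, q = b.shape  # each block has the shape of b
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        # Block (i, j) occupies rows i*p..(i+1)*p and columns j*q..(j+1)*q
        block = y[i * p:(i + 1) * p, j * q:(j + 1) * q]
        assert torch.equal(block, a[i, j] * b)

# Bottom-left block, which should equal 3 * b
print(y[2:4, 0:2])
# tensor([[15, 18],
#         [21, 24]])
```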
---
### Important Considerations
1. **Dimensionality**:
If the input tensors have different numbers of dimensions, the shape of the tensor with fewer dimensions is padded with leading $1$s (the tensor is unsqueezed) until both have the same number of dimensions. For example, if `input` has shape $(r_0, r_1)$ and `other` has shape $(s_0, s_1, s_2)$, `input` is treated as having shape $(1, r_0, r_1)$ before the product is computed. This only aligns the number of dimensions; the sizes themselves do not need to match or be broadcastable.
2. **Output Shape**:
If, after any such padding, `input` has shape $(a_0, a_1, \dots, a_k)$ and `other` has shape $(b_0, b_1, \dots, b_k)$, the result has shape:
$$(a_0 \times b_0, a_1 \times b_1, \dots, a_k \times b_k)$$
3. **Data Types**:
`torch.kron` supports real-valued and complex-valued inputs, including the standard numeric data types (e.g., `torch.float32`, `torch.float64`, `torch.int32`, `torch.int64`). If the input tensors have different data types, PyTorch promotes them to a common type.
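
The considerations above can be checked with a short sketch. The shapes and values are arbitrary choices for illustration, and the dtype result assumes PyTorch's usual type promotion rules (integer combined with `float32` gives `float32`).

```python
import torch

# Different numbers of dimensions: shape (2, 3) is treated as (1, 2, 3)
x = torch.ones(2, 3)
z = torch.ones(4, 5, 6)
out = torch.kron(x, z)
print(out.shape)  # torch.Size([4, 10, 18]), i.e. (1*4, 2*5, 3*6)

# Mixed data types: int64 combined with float32
i = torch.tensor([1, 2])        # dtype torch.int64
f = torch.tensor([0.5, 1.5])    # dtype torch.float32
mixed = torch.kron(i, f)
print(mixed.dtype)  # torch.float32
```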