## PyTorch torch.kron

The `torch.kron` function in PyTorch computes the **Kronecker product** of two tensors. For two matrices of arbitrary size, the Kronecker product is a block matrix: each element of the first matrix scales a full copy of the second, and the scaled copies are laid out in a grid that mirrors the positions of the elements in the first matrix. `torch.kron` applies the same rule to tensors with any number of dimensions.

---

### Mathematical Definition

If $\mathbf{A}$ is an $m \times n$ matrix and $\mathbf{B}$ is a $p \times q$ matrix, then the Kronecker product $\mathbf{A} \otimes \mathbf{B}$ is an $mp \times nq$ block matrix:

$$
\mathbf{A} \otimes \mathbf{B} =
\begin{bmatrix}
a_{11} \mathbf{B} & \cdots & a_{1n} \mathbf{B} \\
\vdots & \ddots & \vdots \\
a_{m1} \mathbf{B} & \cdots & a_{mn} \mathbf{B}
\end{bmatrix}
$$

---

### Syntax

```python
torch.kron(input, other, *, out=None) -> Tensor
```

#### Parameters

* **`input`** *(Tensor)*: The first input tensor (the left operand, $\mathbf{A}$).
* **`other`** *(Tensor)*: The second input tensor (the right operand, $\mathbf{B}$).
* **`out`** *(Tensor, optional)*: A tensor to write the result into.

---

### Code Examples

#### Example 1: 1D Tensors (Vectors)

When applied to 1D tensors, `torch.kron` multiplies each element of the first vector by the entire second vector and concatenates the results.

```python
import torch

# Create two 1D tensors
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])

# Compute the Kronecker product
y = torch.kron(a, b)
print(y)
```

**Output:**

```text
tensor([3, 4, 6, 8])
```

*Explanation:*

* The first element of `a` ($1$) multiplied by `b` ($[3, 4]$) gives $[3, 4]$.
* The second element of `a` ($2$) multiplied by `b` ($[3, 4]$) gives $[6, 8]$.
* Concatenating these yields $[3, 4, 6, 8]$.

---

#### Example 2: 2D Tensors (Matrices)

When applied to 2D matrices, the output is a block matrix in which each block is the matrix `b` scaled by one element of matrix `a`.
```python
import torch

# Create two 2D tensors
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# Compute the Kronecker product
y = torch.kron(a, b)
print(y)
```

**Output:**

```text
tensor([[ 5,  6, 10, 12],
        [ 7,  8, 14, 16],
        [15, 18, 20, 24],
        [21, 24, 28, 32]])
```

*Explanation:*

* Top-left block: $1 \times \mathbf{B} = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}$
* Top-right block: $2 \times \mathbf{B} = \begin{bmatrix} 10 & 12 \\ 14 & 16 \end{bmatrix}$
* Bottom-left block: $3 \times \mathbf{B} = \begin{bmatrix} 15 & 18 \\ 21 & 24 \end{bmatrix}$
* Bottom-right block: $4 \times \mathbf{B} = \begin{bmatrix} 20 & 24 \\ 28 & 32 \end{bmatrix}$

---

### Important Considerations

1. **Dimensionality**: If the input tensors have different numbers of dimensions, the one with fewer dimensions is unsqueezed (size-1 dimensions are prepended) until both have the same number of dimensions. For example, if `input` has shape $(r_0, r_1)$ and `other` has shape $(s_0, s_1, s_2)$, `input` is treated as having shape $(1, r_0, r_1)$ before the product is computed. Unlike broadcasting, no dimension is expanded to match the other tensor; the sizes are multiplied, as described in the next item.
2. **Output Shape**: If `input` has shape $(a_0, a_1, \dots, a_k)$ and `other` has shape $(b_0, b_1, \dots, b_k)$ (after any unsqueezing), the result has shape:

   $$(a_0 \times b_0, a_1 \times b_1, \dots, a_k \times b_k)$$

3. **Data Types**: `torch.kron` supports the standard real-valued numeric dtypes (e.g. float, double, int, long) as well as complex dtypes. If the inputs have different dtypes, the result follows PyTorch's usual type promotion rules; for example, an `int64` tensor combined with a `float32` tensor gives a `float32` result.
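The considerations above can be checked with a short sketch. The shapes and dtypes below are arbitrary choices for illustration: `a` is 2D and integer, `b` is 3D and `float32`, so `a` should be unsqueezed to shape $(1, 2, 3)$ and the result should be promoted to `float32`.

```python
import torch

# `a` has 2 dimensions and `b` has 3, so `a` is treated as shape (1, 2, 3).
a = torch.arange(6).reshape(2, 3)             # int64, shape (2, 3)
b = torch.ones(4, 1, 2, dtype=torch.float32)  # float32, shape (4, 1, 2)

y = torch.kron(a, b)

# Sizes are multiplied dimension by dimension: (1 * 4, 2 * 1, 3 * 2)
print(y.shape)  # torch.Size([4, 2, 6])

# int64 combined with float32 is promoted to float32
print(y.dtype)  # torch.float32

# Unsqueezing `a` by hand gives exactly the same tensor
print(torch.equal(y, torch.kron(a.unsqueeze(0), b)))  # True
```

Note that the size-1 dimension of `b` does not stretch to match anything: the middle output size is simply $2 \times 1 = 2$.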
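To connect the block-matrix definition to the indices PyTorch produces, the sketch below rebuilds the 2D Kronecker product with broadcasting and a reshape, then compares it with `torch.kron` on the matrices from Example 2. `kron_2d` is a hypothetical helper written only for this illustration (it is not a PyTorch API) and handles 2D inputs only. It relies on the index relation $(\mathbf{A} \otimes \mathbf{B})_{ip + k,\, jq + l} = a_{ij} b_{kl}$ (zero-based), which is the block layout shown in the definition.

```python
import torch

def kron_2d(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Kronecker product of two 2D tensors via broadcasting."""
    m, n = A.shape
    p, q = B.shape
    # A[:, None, :, None] has shape (m, 1, n, 1); B[None, :, None, :] has shape (1, p, 1, q).
    # The product has shape (m, p, n, q) with entry [i, k, j, l] = A[i, j] * B[k, l].
    blocks = A[:, None, :, None] * B[None, :, None, :]
    # Merging (m, p) into rows and (n, q) into columns maps [i, k, j, l]
    # to [i * p + k, j * q + l], i.e. A[i, j] * B lands in block (i, j).
    return blocks.reshape(m * p, n * q)

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

y = torch.kron(a, b)
print(torch.equal(y, kron_2d(a, b)))  # True

# Block (i, j) of the result is a[i, j] * b; here the top-right block is 2 * b.
p, q = b.shape
print(torch.equal(y[0:p, q:2 * q], a[0, 1] * b))  # True
```

The broadcasting form is used here only because it makes the index mapping explicit; for real code, `torch.kron` itself is the straightforward choice.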