## PyTorch torch.dequantize
`torch.dequantize` is a PyTorch function used to dequantize a quantized tensor back into a standard floating-point tensor. This operation is essential in quantization workflows, especially when you need to convert quantized weights or activations back to high-precision representations for validation, debugging, or mixed-precision operations.
---
### Function Definition
```python
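torch.dequantize(input) -> Tensor
torch.dequantize(tensors) -> sequence of Tensors
```
#### Parameters
* **input** (*Tensor*): The input quantized tensor. This tensor must have a quantized data type (such as `torch.quint8`, `torch.qint8`, or `torch.qint32`).
* **tensors** (*sequence of Tensors*): For the second overload, a list of quantized tensors to dequantize in a single call.
#### Returns
* **Tensor**: A `torch.float32` tensor with the same shape as the input, containing the dequantized values. The sequence overload returns one such tensor per input.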
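A quick way to confirm the return type is to inspect the result directly. The sketch below assumes a PyTorch build with quantization support on CPU; the second tensor `q2` and its parameters are hypothetical values chosen only to exercise the sequence overload.

```python
import torch

x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)

# Single-tensor form: the result is an ordinary float32 tensor with the same shape
dq = torch.dequantize(q)
print(dq.dtype, dq.shape, dq.is_quantized)

# Sequence form: pass several quantized tensors, get back one float tensor per input
q2 = torch.quantize_per_tensor(torch.tensor([0.5, 1.5]), scale=0.5, zero_point=0, dtype=torch.qint8)
outs = torch.dequantize([q, q2])
print(len(outs), [t.dtype for t in outs])
```

Each input keeps its own scale and zero-point, so tensors quantized with different parameters can be dequantized together.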
---
### Mathematical Background
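Affine quantization maps a floating-point value $x$ to an integer $q$ using a scale factor ($S$) and a zero-point ($Z$), then clamps the result to the range $[q_{\min}, q_{\max}]$ of the target integer type (for example, $[0, 255]$ for `torch.quint8`):
$$q = \text{clamp}\left(\text{round}\left(\frac{x}{S}\right) + Z,\; q_{\min},\; q_{\max}\right)$$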
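The forward mapping can be checked by hand against `torch.quantize_per_tensor`. This is a minimal sketch with hypothetical sample values: some lie on the 0.1 grid, one lies between grid points, and two fall outside what `torch.quint8` can represent for this scale and zero-point. The manual version clamps to $[0, 255]$, the range of `torch.quint8`. The values are chosen away from .5 ties, so the comparison does not depend on tie-breaking rules.

```python
import torch

# Hypothetical sample values: on-grid, off-grid, and out-of-range inputs
x = torch.tensor([-1.0, 0.0, 1.0, 2.0, 0.73, 30.0, -2.0])
scale, zero_point = 0.1, 10
qmin, qmax = 0, 255  # representable range of torch.quint8

# Manual affine quantization: round(x / S) + Z, then clamp to [qmin, qmax]
manual = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.uint8)

# PyTorch's own quantizer, for comparison
q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.quint8)

print(manual.tolist())        # integer codes computed by hand
print(q.int_repr().tolist())  # integer codes stored by PyTorch
```

Both lines should print `[0, 10, 20, 30, 17, 255, 0]`: the value 30.0 would map to 310 and saturates at 255, while -2.0 would map to -10 and saturates at 0.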
The `torch.dequantize` function reverses this process to reconstruct the floating-point approximation:
$$x_{\text{approx}} = (q - Z) \times S$$
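The dequantization formula can be applied manually using the metadata stored in the quantized tensor (`q_scale()`, `q_zero_point()`, and the raw codes from `int_repr()`). The sketch below compares that manual computation with `torch.dequantize`, reusing the scale and zero-point from the example in the next section.

```python
import torch

x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)

# Read the affine parameters stored inside the quantized tensor
S = q.q_scale()       # Python float
Z = q.q_zero_point()  # Python int

# Apply x_approx = (q - Z) * S to the raw integer codes.
# Convert to float first: subtracting Z from uint8 codes would wrap around below 0.
manual = (q.int_repr().to(torch.float32) - Z) * S

# Compare with the library routine
reference = torch.dequantize(q)
print(manual)
print(torch.allclose(manual, reference))
```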
---
### Code Example
The following example demonstrates how to quantize a floating-point tensor and then use `torch.dequantize` to convert it back.
```python
import torch
# 1. Create a standard floating-point input tensor
input_tensor = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
# 2. Quantize the tensor using per-tensor quantization
# Scale = 0.1, Zero-point = 10, Data type = 8-bit unsigned integer
quantized = torch.quantize_per_tensor(input_tensor, scale=0.1, zero_point=10, dtype=torch.quint8)
print("Quantized Tensor:")
print(quantized)
print("Quantized Integer Representation:")
print(quantized.int_repr()) # Displays the underlying integer values
print("-" * 40)
# 3. Dequantize back to a floating-point tensor
dequantized = torch.dequantize(quantized)
print("Dequantized Floating-Point Tensor:")
print(dequantized)
```
#### Output
```text
Quantized Tensor:
tensor([[-1.,  0.],
        [ 1.,  2.]], size=(2, 2), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10)
Quantized Integer Representation:
tensor([[ 0, 10],
        [20, 30]], dtype=torch.uint8)
----------------------------------------
Dequantized Floating-Point Tensor:
tensor([[-1.,  0.],
        [ 1.,  2.]])
```
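The inputs in the example above lie exactly on the quantization grid (multiples of 0.1 between -1.0 and 24.5 for this scale and zero-point), so the round trip looks lossless. The sketch below uses hypothetical off-grid values to show the approximation error. For inputs inside the representable range, round-to-nearest keeps the error within $S/2$; inputs outside it are saturated and can be off by much more.

```python
import torch

# Hypothetical values chosen to sit between grid points, plus one out-of-range value
x = torch.tensor([0.73, -0.46, 1.234, 30.0])
S, Z = 0.1, 10
q = torch.quantize_per_tensor(x, scale=S, zero_point=Z, dtype=torch.quint8)
x_approx = torch.dequantize(q)

err = (x - x_approx).abs()
print(x_approx)
print(err)

# Representable range for quint8 codes [0, 255] is [(0 - Z) * S, (255 - Z) * S]
in_range = (x >= (0 - Z) * S) & (x <= (255 - Z) * S)
# Inside that range the error is at most S / 2 (small tolerance for float32 rounding)
print(bool((err[in_range] <= S / 2 + 1e-6).all()))
```

Here 0.73 comes back as 0.7 and 1.234 as 1.2, while 30.0 is clamped to the largest representable value, 24.5.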
---
### Key Considerations
1. **Precision Loss**: Quantization is a lossy process. While `torch.dequantize` restores the tensor to a floating-point format, the precision lost during the initial quantization step cannot be recovered. For inputs inside the representable range $[(q_{\min} - Z) \times S,\ (q_{\max} - Z) \times S]$, each dequantized value differs from the original by at most $S/2$; inputs outside that range were clamped during quantization and come back as the nearest endpoint.
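2. **Method Form**: You can also call `.dequantize()` directly as a method on a quantized tensor object (e.g., `quantized_tensor.dequantize()`). Like the function form, it returns a new float tensor and leaves the quantized tensor unchanged; it is not an in-place operation.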
3. **Supported Quantization Schemes**: `torch.dequantize` works on both per-tensor and per-channel quantized tensors. It reads the scale and zero-point (or the per-channel scales and zero-points) stored in the quantized tensor, so you never pass them explicitly.
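The last two points can be checked together. This sketch quantizes a hypothetical 2x3 weight matrix per output channel (axis 0) with `torch.quantize_per_channel`, then dequantizes it with both the method form and the function form. The scales and zero-points are illustrative values, not recommendations.

```python
import torch

# Hypothetical 2x3 weight matrix, quantized per output channel (axis 0)
w = torch.tensor([[0.10, -0.20, 0.30],
                  [1.00, -2.00, 3.00]])
scales = torch.tensor([0.01, 0.1])  # one scale per row
zero_points = torch.tensor([0, 0])  # one zero-point per row (int64)
qw = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

# Method form and function form produce the same float32 tensor
a = qw.dequantize()
b = torch.dequantize(qw)
print(torch.equal(a, b))

# Per-channel metadata travels with the tensor; dequantize applies it row by row
print(qw.qscheme())  # torch.per_channel_affine
print(qw.q_per_channel_scales())
print(a)
```

Each row is rescaled with its own scale, so the small values in the first row keep their resolution even though the second row spans a much wider range.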