## PyTorch torch.fake_quantize_per_channel_affine
`torch.fake_quantize_per_channel_affine` performs channel-wise fake quantization on an input tensor: each slice along a chosen axis is quantized and immediately dequantized with its own scale and zero point. It is widely used in Quantization-Aware Training (QAT) to simulate the effects of low-precision quantization (such as INT8) in the forward pass while still letting gradients flow in the backward pass.
---
## What is Fake Quantization?
Fake quantization simulates the precision loss of a quantize-dequantize round trip without leaving floating point. Each value is divided by the scale, rounded to the nearest integer, shifted by the zero point, and clamped to the target integer range; the result is then mapped back to floating point. The output keeps the input's floating-point dtype, but its values can only lie on the quantization grid.
By applying this operation during training, the model learns to tolerate the quantization noise, which typically yields noticeably higher accuracy once the model is converted to a fully quantized format (such as INT8) for deployment.
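To make the round trip concrete, the sketch below pushes three hand-picked values through a single channel, using an assumed scale of 0.05 and zero point of 128 on the unsigned 8-bit range [0, 255] (illustrative numbers, not defaults). Two values pick up small rounding errors, and the third is clamped because it lies outside the representable range.

```python
import torch

# One channel holding three sample values; shape [1, 3] with the channel on axis 0
x = torch.tensor([[0.137, -0.42, 9.0]])
scale = torch.tensor([0.05])
zero_point = torch.tensor([128], dtype=torch.int32)

# Round trip onto the unsigned 8-bit grid [0, 255] and back to float
y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, 0, 255)
print(y)
# 0.137 -> 0.15   (0.137 / 0.05 = 2.74 rounds to 3)
# -0.42 -> -0.40  (-0.42 / 0.05 = -8.4 rounds to -8)
# 9.0   -> 6.35   (180 + 128 = 308 is clamped to 255, i.e. (255 - 128) * 0.05)
```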
---
## Mathematical Formulation
For each channel $i$ along the specified `axis`, let $\text{scale}_i$ and $\text{zero\_point}_i$ be the $i$-th entries of `scale` and `zero_point`. Every element $X$ of `input` whose index along `axis` is $i$ is transformed as:
$$X_{\text{quant}} = \text{clamp}\left(\text{round}\left(\frac{X}{\text{scale}_i}\right) + \text{zero\_point}_i,\ \text{quant\_min},\ \text{quant\_max}\right)$$
$$Y = (X_{\text{quant}} - \text{zero\_point}_i) \times \text{scale}_i$$
Where:
* **$\text{round}(\cdot)$** rounds to the nearest integer, with ties rounded to the nearest even integer (round half to even).
* **$\text{clamp}(x, a, b)$** limits $x$ to the range $[a, b]$, here $[\text{quant\_min}, \text{quant\_max}]$.
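The formula can be cross-checked with ordinary tensor operations. The helper `fake_quant_per_channel_ref` below is a hypothetical name and a minimal sketch of the math, not PyTorch's internal implementation: it reshapes the 1D parameters so they broadcast along `axis`, then applies the two equations. The scales are assumed to be powers of two so that the division is exact and the comparison with the built-in does not hinge on floating-point details of how the built-in divides.

```python
import torch

def fake_quant_per_channel_ref(x, scale, zero_point, axis, quant_min, quant_max):
    """Hypothetical pure-tensor sketch of the per-channel fake-quantize formula."""
    # Reshape the 1D per-channel parameters so they broadcast along `axis`
    shape = [1] * x.dim()
    shape[axis] = -1
    s = scale.reshape(shape)
    z = zero_point.reshape(shape).to(x.dtype)
    # Quantize: divide, round (torch.round rounds half to even), shift, clamp
    x_quant = torch.clamp(torch.round(x / s) + z, quant_min, quant_max)
    # Dequantize back to floating point
    return (x_quant - z) * s

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)
# Powers of two keep x / s exact (assumption made for a bit-exact comparison)
scale = torch.tensor([0.125, 0.0625, 0.03125])
zero_point = torch.tensor([128, 120, 136], dtype=torch.int32)

y_ref = fake_quant_per_channel_ref(x, scale, zero_point, 1, 0, 255)
y_builtin = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, 0, 255)
print(torch.equal(y_ref, y_builtin))
```

With arbitrary (non power-of-two) scales, a tolerance-based check such as `torch.allclose` is the safer comparison.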
---
## Syntax
```python
torch.fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor
```
### Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| **`input`** | `Tensor` | The input floating-point tensor to be fake-quantized. |
| **`scale`** | `Tensor` | A 1D float tensor containing the scale factors for each channel. Its size must match the size of the dimension specified by `axis`. |
| **`zero_point`** | `Tensor` | A 1D tensor containing one zero-point offset per channel (typically `torch.int32`). Its shape must match `scale`, and each value must lie within `[quant_min, quant_max]`. |
| **`axis`** | `int` | The dimension (channel axis) along which the quantization parameters (`scale` and `zero_point`) are applied. |
| **`quant_min`** | `int` | The minimum bound of the target quantization range (e.g., `0` for unsigned 8-bit, `-128` for signed 8-bit). |
| **`quant_max`** | `int` | The maximum bound of the target quantization range (e.g., `255` for unsigned 8-bit, `127` for signed 8-bit). |
### Returns
* **`Tensor`**: A tensor with the same shape and floating-point dtype as `input` (not a quantized dtype), containing the fake-quantized values.
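Per-channel fake quantization is most often applied to weights, with one scale per output channel (axis 0 for a convolution weight of shape `[out_channels, in_channels, kH, kW]`). The sketch below assumes a simple symmetric min/max scheme, which is only one of several ways to choose the parameters: each channel's scale maps its largest absolute weight to 127, and every zero point is 0. It also shows that the result keeps the shape and dtype of the input.

```python
import torch

torch.manual_seed(0)
# Hypothetical convolution weight: [out_channels=4, in_channels=3, kH=3, kW=3]
w = torch.randn(4, 3, 3, 3)

# Symmetric signed 8-bit range (assumption: -128 is left unused for symmetry)
quant_min, quant_max = -127, 127

# One scale per output channel (axis 0): the largest |w| in the channel maps to 127
max_abs = w.abs().amax(dim=(1, 2, 3))
scale = (max_abs / quant_max).clamp(min=1e-8)  # guard against an all-zero channel
zero_point = torch.zeros(w.size(0), dtype=torch.int32)

w_fq = torch.fake_quantize_per_channel_affine(
    w, scale, zero_point, axis=0, quant_min=quant_min, quant_max=quant_max
)

print(w_fq.shape, w_fq.dtype)  # same shape and dtype as w
# Per-channel worst-case rounding error, which should stay at about scale / 2 or below
print((w - w_fq).abs().amax(dim=(1, 2, 3)))
print(scale / 2)
```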
---
## Code Example
The following example demonstrates how to apply channel-wise fake quantization to a 4D tensor (commonly representing `[Batch, Channels, Height, Width]` in computer vision models).
```python
import torch
# Create a 4D input tensor (e.g., Batch=2, Channels=3, Height=4, Width=5)
x = torch.randn(2, 3, 4, 5)

# Define one scale factor and one zero point for each of the 3 channels.
# With zero_point=128 on the unsigned range [0, 255], negative inputs stay
# representable; e.g. channel 0 covers roughly [-6.4, 6.35].
scale = torch.tensor([0.05, 0.04, 0.03])
zero_point = torch.tensor([128, 128, 128], dtype=torch.int32)

# Perform channel-wise fake quantization along axis 1 (the channel dimension)
# Target range: [0, 255] (unsigned 8-bit integer quantization)
y = torch.fake_quantize_per_channel_affine(
    x,
    scale,
    zero_point,
    axis=1,
    quant_min=0,
    quant_max=255,
)

print("Input shape:", x.shape)
print("Quantized shape:", y.shape)
print("Are shapes identical?", x.shape == y.shape)
```
### Output
```text
Input shape: torch.Size([2, 3, 4, 5])
Quantized shape: torch.Size([2, 3, 4, 5])
Are shapes identical? True
```
---
## Key Considerations
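1. **Gradient Propagation**: During the backward pass, `torch.fake_quantize_per_channel_affine` uses a **Straight-Through Estimator (STE)**: gradients pass through the non-differentiable rounding step as if it were the identity function, so model weights can still be updated via backpropagation. The estimator is masked, however: elements whose quantized value fell outside `[quant_min, quant_max]` and were clamped receive zero gradient. Gradients flow only to `input`; `scale` and `zero_point` are treated as constants.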
2. **Channel Alignment**: `scale` and `zero_point` must be 1D tensors whose length exactly matches the size of the dimension selected by `axis`; otherwise PyTorch raises a `RuntimeError`.
3. **Hardware Acceleration**: The operator has native CPU and CUDA kernels and runs on the device that holds `input`; place `scale` and `zero_point` on the same device.
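The gradient behavior can be checked directly. In the sketch below (values chosen by hand, with power-of-two scales so the arithmetic is exact), a signed 4-bit range `[-8, 7]` is used so that some elements are clamped. Assuming the masked straight-through behavior of current PyTorch releases, elements that stay within range receive a gradient of 1 from `y.sum()`, and clamped elements receive 0.

```python
import torch

# Two channels (rows) along axis 0
x = torch.tensor([[-3.0, 0.5, 1.0],
                  [1.0, -2.0, 5.0]], requires_grad=True)
scale = torch.tensor([0.25, 0.5])  # powers of two keep every product exact
zero_point = torch.tensor([0, 0], dtype=torch.int32)

# Signed 4-bit range, chosen small so that some elements get clamped
y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, -8, 7)
y.sum().backward()

print(y)       # fake-quantized values, still float32
print(x.grad)  # 1 where the element stayed in range, 0 where it was clamped
```

Row 0 loses its first element (-3.0 / 0.25 = -12 < -8) and row 1 loses its last (5.0 / 0.5 = 10 > 7), which is exactly where the zero gradients appear.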