
## PyTorch torch.fake_quantize_per_channel_affine

`torch.fake_quantize_per_channel_affine` is a PyTorch function that performs per-channel fake quantization on an input tensor. It is widely used in Quantization-Aware Training (QAT) to simulate the effect of low-precision quantization (such as INT8) in the forward pass while still letting gradients flow in the backward pass.

---

## What is Fake Quantization?

Fake quantization simulates the precision loss of a quantize-dequantize round trip. Each floating-point value is mapped to the nearest point on the quantized integer grid and then immediately scaled back to floating point. The output is therefore still a floating-point tensor, but it can only take values that the quantized model could represent. Applying this operation during training lets the model adapt to quantization noise, which typically yields higher accuracy once the model is converted to a truly quantized format (such as INT8) for deployment.

---

## Mathematical Formulation

For each channel $i$ along the specified `axis`, with per-channel parameters $\text{scale}_i$ and $\text{zero\_point}_i$, the fake quantization operation is defined as:

$$X_{\text{quant}} = \text{clamp}\left(\text{round}\left(\frac{X}{\text{scale}_i}\right) + \text{zero\_point}_i,\ \text{quant\_min},\ \text{quant\_max}\right)$$

$$Y = (X_{\text{quant}} - \text{zero\_point}_i) \times \text{scale}_i$$

where $X$ denotes the elements of `input` that belong to channel $i$, and:

* **$\text{round}(\cdot)$** rounds the value to the nearest integer (ties are rounded to even).
* **$\text{clamp}(x, \min, \max)$** limits the values to the range $[\text{quant\_min}, \text{quant\_max}]$.

---

## Syntax

```python
torch.fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor
```

### Parameters

| Parameter | Type | Description |
| :--- | :--- | :--- |
| **`input`** | `Tensor` | The input floating-point tensor (typically `torch.float32`) to be fake-quantized. |
| **`scale`** | `Tensor` | A 1D float tensor containing the scale factor for each channel. Its length must equal the size of the dimension specified by `axis`. |
| **`zero_point`** | `Tensor` | A 1D tensor (typically `torch.int32`) containing the zero-point offset for each channel. Its shape must match `scale`. |
| **`axis`** | `int` | The dimension (channel axis) along which the per-channel `scale` and `zero_point` values are applied. |
| **`quant_min`** | `int` | The lower bound of the quantized integer range (e.g., `0` for unsigned 8-bit, `-128` for signed 8-bit). |
| **`quant_max`** | `int` | The upper bound of the quantized integer range (e.g., `255` for unsigned 8-bit, `127` for signed 8-bit). |

### Returns

* **`Tensor`**: A tensor with the same shape and data type as `input`, containing the fake-quantized values.

---

## Code Example

The following example applies per-channel fake quantization to a 4D tensor laid out as `[Batch, Channels, Height, Width]`, as is common in computer vision models.

```python
import torch

# Create a 4D input tensor (Batch=2, Channels=3, Height=4, Width=5)
x = torch.randn(2, 3, 4, 5)

# Define one scale factor and one zero point for each of the 3 channels
scale = torch.tensor([1.0, 1.2, 1.5])
zero_point = torch.tensor([0, 0, 0], dtype=torch.int32)

# Fake-quantize along axis 1 (the channel dimension).
# Target range: [0, 255] (unsigned 8-bit). With zero_point = 0,
# every negative input is clamped to 0.
y = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis=1, quant_min=0, quant_max=255
)

print("Input shape:", x.shape)
print("Quantized shape:", y.shape)
print("Are shapes identical?", x.shape == y.shape)
```

### Output

```text
Input shape: torch.Size([2, 3, 4, 5])
Quantized shape: torch.Size([2, 3, 4, 5])
Are shapes identical? True
```

---

## Key Considerations

1. **Gradient Propagation**: In the backward pass, `torch.fake_quantize_per_channel_affine` uses a **Straight-Through Estimator (STE)**: the non-differentiable rounding step is treated as the identity, so the incoming gradient passes through unchanged for every element whose quantized value lies within `[quant_min, quant_max]`. Elements that were clamped receive a zero gradient. This lets the model weights be updated via backpropagation despite the rounding step.
2. **Channel Alignment**: `scale` and `zero_point` must be 1D tensors whose length equals the size of the dimension specified by `axis`; otherwise a runtime error is raised.
3. **Hardware Support**: The function has native kernels for both CPU and CUDA tensors, so it can be used during training on either device.
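To make the formula and the clipped straight-through gradient concrete, here is a minimal sketch that re-implements the per-channel formula with ordinary tensor operations and compares it with the built-in function. The helper `reference_fake_quant_per_channel` is hypothetical and written only for this illustration; it is not PyTorch's implementation. The sketch assumes CPU `float32` tensors and that the built-in rounding agrees with `torch.round` (round half to even); the scales and the signed 8-bit range are arbitrary illustrative values, chosen small enough that some elements get clamped.

```python
import torch

torch.manual_seed(0)  # fixed seed so the comparison is reproducible


def reference_fake_quant_per_channel(x, scale, zero_point, axis, quant_min, quant_max):
    """Hypothetical pure-PyTorch version of the per-channel formula (illustration only).

    Returns the fake-quantized tensor and a boolean mask of elements that were NOT clamped.
    """
    # Reshape the 1D per-channel parameters so they broadcast along `axis`
    # (e.g. shape [1, C, 1, 1] for axis=1 on a 4D tensor).
    shape = [1] * x.dim()
    shape[axis] = -1
    s = scale.reshape(shape)
    zp = zero_point.reshape(shape).to(x.dtype)
    # Quantize: divide by the scale, round to the nearest integer, shift by the zero point.
    q = torch.round(x / s) + zp
    # Elements whose integer value already lies inside [quant_min, quant_max].
    in_range = (q >= quant_min) & (q <= quant_max)
    # Clamp to the integer range, then dequantize back to floating point.
    y = (torch.clamp(q, quant_min, quant_max) - zp) * s
    return y, in_range


# Small per-channel scales so that some values fall outside the signed 8-bit range.
scale = torch.tensor([0.01, 0.02, 0.05])
zero_point = torch.zeros(3, dtype=torch.int32)
x = torch.randn(2, 3, 4, 5, requires_grad=True)

# Built-in function: axis=1, quant_min=-128, quant_max=127.
y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, -128, 127)
ref, in_range = reference_fake_quant_per_channel(x.detach(), scale, zero_point, 1, -128, 127)
matches = torch.allclose(y.detach(), ref)
print("Matches reference formula:", matches)

# Backward pass: d(sum(y))/dx should be 1 for unclamped elements and 0 for clamped ones.
y.sum().backward()
grad_ok = torch.equal(x.grad, in_range.to(x.dtype))
n_clamped = int((~in_range).sum())
print("Gradient equals in-range mask:", grad_ok)
print("Clamped elements:", n_clamped)
```

Reshaping `scale` and `zero_point` to `[1, C, 1, 1]` lets ordinary broadcasting apply each channel's parameters to every element of that channel, which is exactly the per-channel behavior described by `axis`.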