From 0a8fd5c8ebd04f78aff0bf95826e51cdfcfe757e Mon Sep 17 00:00:00 2001 From: Shen Xu Date: Thu, 20 Nov 2025 17:24:08 -0800 Subject: [PATCH] Add 16a4w_block QAT config (#15878) Summary: Introduce a FakeQuantizer subclass. It falls back to LPBQ observer's `convert`. `_derived_bias_quant_spec` also looks for it to correctly derive bias scale. Reviewed By: viveknayakatmeta Differential Revision: D87194388 --- .../observers/per_block_param_observer.py | 56 ++++++++++- backends/qualcomm/quantizer/qconfig.py | 98 ++++++++++++++++++- backends/qualcomm/quantizer/quantizer.py | 14 +++ backends/qualcomm/tests/test_qnn_delegate.py | 19 ++++ 4 files changed, 184 insertions(+), 3 deletions(-) diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index b3f854db527..4c4dd1c911e 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,12 +7,13 @@ from typing import Tuple import torch -from torchao.quantization.pt2e import MappingType, PerBlock +from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock from torchao.quantization.pt2e._affine_quantization import ( _get_reduction_params, AffineQuantizedMinMaxObserver, choose_qparams_affine_with_min_max, ) +from torchao.quantization.quant_primitives import _fake_quantize_affine class PerBlockParamObserver(AffineQuantizedMinMaxObserver): @@ -89,3 +90,56 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: self.preserve_zero, self.zero_point_domain, ) + + +class PerBlockParamFakeQuantize(FakeQuantize): + def __init__( + self, + dtype: torch.dtype = torch.int8, + block_size: torch.Size = None, + quant_min: int = None, + quant_max: int = None, + eps: float = torch.finfo(torch.float32).eps, + **kwargs, + ): + super().__init__() + assert ( + block_size is not None + ), "block_size must be provided for per-block quantization" + + self.activation_post_process = PerBlockParamObserver( + dtype=dtype, + block_size=block_size, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + **kwargs, + ) + self.dtype = dtype + self.block_size = block_size + self.quant_min = quant_min if quant_min is not None else torch.iinfo(dtype).min + self.quant_max = quant_max if quant_max is not None else torch.iinfo(dtype).max + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + + self.activation_post_process(x) + scale, zero_point = self.activation_post_process.calculate_qparams() + + return _fake_quantize_affine( + x, + self.block_size, + scale, + zero_point, + quant_dtype=self.dtype, + quant_min=self.quant_min, + quant_max=self.quant_max, + ) + + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.activation_post_process.calculate_qparams() + + def convert(self, model, observer_node): + self.activation_post_process.convert(model, observer_node) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 593eb77961a..b5121226ca8 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -10,6 +10,7 @@ import torch from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import ( + PerBlockParamFakeQuantize, PerBlockParamObserver, ) from torch import Tensor @@ -71,11 +72,18 @@ def _derive_bias_qparams_fn( derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to( 
torch.int32 ) - if isinstance(weight_obs_or_fq, PerBlockParamObserver): + + # Handle per-block quantization for both observer and fake quantize + weight_observer = weight_obs_or_fq + if isinstance(weight_obs_or_fq, PerBlockParamFakeQuantize): + # Extract the underlying observer from the fake quantize wrapper + weight_observer = weight_obs_or_fq.activation_post_process + + if isinstance(weight_observer, PerBlockParamObserver): # keep maximum scale of each channel for bias derived_scale = ( derived_scale.view(derived_scale.size(0), -1).amax(dim=-1) - / weight_obs_or_fq.num_steps + / weight_observer.num_steps ) derived_zero = derived_zero.view(derived_zero.size(0), -1).amax(dim=-1) return (derived_scale, derived_zero) @@ -468,6 +476,92 @@ def get_ptq_per_block_quant_config( ) +def get_qat_per_block_quant_config( + act_dtype=torch.uint8, + weight_dtype=torch.int8, + act_observer=MovingAverageMinMaxObserver, + act_symmetric: bool = False, + ch_axis: int = 0, +) -> QuantizationConfig: + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + supported_weight_dtypes = {torch.int4, torch.int8} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch does not support uint16 quantization, use int32 to bypass + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + qscheme=torch.per_tensor_symmetric, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + else: + act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = PerBlockParamFakeQuantize.with_args( + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=ch_axis, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=ch_axis, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = 
QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + # TODO merge qat and ptq to a function, and use a bool flag to control it def get_8a8w_qnn_qat_config( act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 9ca9a7dad6c..0d54b250bfd 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -29,6 +29,7 @@ get_8a8w_qnn_qat_config, get_ptq_per_block_quant_config, get_ptq_per_channel_quant_config, + get_qat_per_block_quant_config, get_qat_per_channel_quant_config, QuantizationConfig, ) @@ -131,6 +132,19 @@ class QuantDtype(IntEnum): ), None, ), + (QuantDtype.use_16a4w_block, True): ( + get_16a4w_qnn_qat_config, + partial( + get_qat_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int4, + ), + partial( + get_qat_per_block_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int4, + ), + ), (QuantDtype.use_8a8w, True): ( get_8a8w_qnn_qat_config, partial(get_qat_per_channel_quant_config), diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index a9403f98b17..68d16eaeb8c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2054,6 +2054,25 @@ def test_qnn_backend_16a4w_conv2d_qat(self): ) self.lower_module_and_test_output(converted, sample_input) + def test_qnn_backend_16a4w_block_conv2d_qat(self): + modules = [ + Conv2dSingle(in_channel=64, out_channel=64), + Conv2dSingle(bias=False), + ] # noqa: F405 + sample_input = (torch.randn([64, 64, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + prepared = self.get_prepared_qat_module( + module, + sample_input, + quant_dtype=QuantDtype.use_16a4w_block, + block_size_map={"conv2d": (1, 32, 1, 1)}, + ) + converted = self.get_converted_sgd_trained_module( + module, prepared, sample_input + ) + self.lower_module_and_test_output(converted, sample_input) + def test_qnn_backend_16a4w_layer_norm(self): module = LayerNorm() # noqa: F405 sample_input = (torch.randn(196, 768),)
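
Usage sketch (editor's note, not part of the patch): the snippet below exercises the new PerBlockParamFakeQuantize directly, using the same int4-as-int8 range ([-7, 7]) and the (1, 32, 1, 1) block size that test_qnn_backend_16a4w_block_conv2d_qat passes via block_size_map. The standalone shapes and the final print are illustrative assumptions, assuming executorch and torchao are installed; in the real QAT flow the module is constructed by get_qat_per_block_quant_config and selected through the new (QuantDtype.use_16a4w_block, True) entry in quantizer.py.

import torch
from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
    PerBlockParamFakeQuantize,
)

# int4 weights are emulated as int8 with the narrowed [-7, 7] range,
# matching get_qat_per_block_quant_config.
fake_quant = PerBlockParamFakeQuantize(
    dtype=torch.int8,
    block_size=torch.Size([1, 32, 1, 1]),  # 1 output channel x 32 input channels per block
    quant_min=-7,
    quant_max=7,
)

weight = torch.randn(64, 64, 3, 3)  # (out_ch, in_ch, kH, kW); 64 in-channels split into 32-wide blocks
fake_weight = fake_quant(weight)    # observer records min/max, then weights are fake-quantized per block
scale, zero_point = fake_quant.calculate_qparams()
print(scale.shape, zero_point.shape)

At convert time, PerBlockParamFakeQuantize.convert delegates to the wrapped PerBlockParamObserver (the LPBQ observer), and _derived_bias_quant_spec unwraps activation_post_process so the bias scale is still derived from the per-block weight scales.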