backends/qualcomm/quantizer/observers/per_block_param_observer.py
@@ -7,12 +7,13 @@
from typing import Tuple

import torch
from torchao.quantization.pt2e import MappingType, PerBlock
from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock
from torchao.quantization.pt2e._affine_quantization import (
_get_reduction_params,
AffineQuantizedMinMaxObserver,
choose_qparams_affine_with_min_max,
)
from torchao.quantization.quant_primitives import _fake_quantize_affine


class PerBlockParamObserver(AffineQuantizedMinMaxObserver):
@@ -89,3 +90,56 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
self.preserve_zero,
self.zero_point_domain,
)


class PerBlockParamFakeQuantize(FakeQuantize):
def __init__(
self,
dtype: torch.dtype = torch.int8,
block_size: torch.Size = None,
quant_min: int = None,
quant_max: int = None,
eps: float = torch.finfo(torch.float32).eps,
**kwargs,
):
super().__init__()
assert (
block_size is not None
), "block_size must be provided for per-block quantization"

self.activation_post_process = PerBlockParamObserver(
dtype=dtype,
block_size=block_size,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
**kwargs,
)
self.dtype = dtype
self.block_size = block_size
self.quant_min = quant_min if quant_min is not None else torch.iinfo(dtype).min
self.quant_max = quant_max if quant_max is not None else torch.iinfo(dtype).max
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0:
[Collaborator review comment] Would it be simpler to call torchao.quantization.quant_primitives._fake_quantize_affine directly?
return x

self.activation_post_process(x)
scale, zero_point = self.activation_post_process.calculate_qparams()

return _fake_quantize_affine(
x,
self.block_size,
scale,
zero_point,
quant_dtype=self.dtype,
quant_min=self.quant_min,
quant_max=self.quant_max,
)

def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
return self.activation_post_process.calculate_qparams()

def convert(self, model, observer_node):
self.activation_post_process.convert(model, observer_node)
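A minimal usage sketch of the new fake-quantize module defined above; the import path is inferred from the qconfig.py hunk below, and the weight shape, block size, and quant range are illustrative assumptions rather than values fixed by this change:

```python
import torch

# Assumed import path, inferred from the qconfig.py import in this PR.
from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
    PerBlockParamFakeQuantize,
)

# Illustrative conv weight: (out_channels, in_channels, kh, kw).
weight = torch.randn(64, 64, 3, 3)

# One block spans 1 output channel x 32 input channels x 1 x 1 (assumed layout);
# int4 is emulated with int8 plus the tight [-7, 7] range used in qconfig.py.
fq = PerBlockParamFakeQuantize(
    dtype=torch.int8,
    block_size=torch.Size([1, 32, 1, 1]),
    quant_min=-7,
    quant_max=7,
)

# forward() observes the tensor, derives per-block qparams, and returns the
# quantize->dequantize ("fake quantized") tensor used during QAT.
fq_weight = fq(weight)
scale, zero_point = fq.calculate_qparams()
```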
98 changes: 96 additions & 2 deletions backends/qualcomm/quantizer/qconfig.py
@@ -10,6 +10,7 @@
import torch

from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
PerBlockParamFakeQuantize,
PerBlockParamObserver,
)
from torch import Tensor
@@ -71,11 +72,18 @@ def _derive_bias_qparams_fn(
derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
torch.int32
)
if isinstance(weight_obs_or_fq, PerBlockParamObserver):

# Handle per-block quantization for both observer and fake quantize
weight_observer = weight_obs_or_fq
if isinstance(weight_obs_or_fq, PerBlockParamFakeQuantize):
# Extract the underlying observer from the fake quantize wrapper
weight_observer = weight_obs_or_fq.activation_post_process

if isinstance(weight_observer, PerBlockParamObserver):
# keep maximum scale of each channel for bias
derived_scale = (
derived_scale.view(derived_scale.size(0), -1).amax(dim=-1)
/ weight_obs_or_fq.num_steps
/ weight_observer.num_steps
)
derived_zero = derived_zero.view(derived_zero.size(0), -1).amax(dim=-1)
return (derived_scale, derived_zero)
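To make the bias-scale reduction above concrete, here is a small shape sketch with made-up numbers; the block layout and the num_steps value are assumptions chosen only for illustration:

```python
import torch

# Pretend derived_scale holds act_scale * per-block weight scales for a conv with
# 4 output channels and 2 blocks along the input-channel dimension.
derived_scale = torch.rand(4, 2, 1, 1)
num_steps = 2**4  # assumed step count of the per-block weight observer

# Collapse all blocks of each output channel into a single bias scale by keeping
# the maximum block scale, mirroring the reduction in _derive_bias_qparams_fn.
per_channel_bias_scale = (
    derived_scale.view(derived_scale.size(0), -1).amax(dim=-1) / num_steps
)
print(per_channel_bias_scale.shape)  # torch.Size([4])
```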
@@ -468,6 +476,92 @@ def get_ptq_per_block_quant_config(
)


def get_qat_per_block_quant_config(
act_dtype=torch.uint8,
weight_dtype=torch.int8,
act_observer=MovingAverageMinMaxObserver,
act_symmetric: bool = False,
ch_axis: int = 0,
) -> QuantizationConfig:
supported_act_types = {
torch.uint8,
torch.uint16,
torch.int8,
torch.int16,
}
supported_weight_dtypes = {torch.int4, torch.int8}
assert (
act_dtype in supported_act_types
), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

assert (
weight_dtype in supported_weight_dtypes
), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization; int32 is used to bypass this
    if act_symmetric:
        # If zero_point is 128, HTP can apply optimizations.
        # If quant_min and quant_max are left as None, the observer defaults to a zero_point of 128.
        # If we provide the uint8 quant_min/max, it will use 127 as the zero_point, which is undesired.
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
qscheme=torch.per_tensor_symmetric,
observer=act_observer,
)
act_quantization_spec = QuantizationSpec(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
qscheme=torch.per_tensor_symmetric,
ch_axis=0,
observer_or_fake_quant_ctr=act_fake_quant_ctr,
)
else:
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
quant_min=torch.iinfo(act_dtype).min,
quant_max=torch.iinfo(act_dtype).max,
qscheme=torch.per_tensor_affine,
observer=act_observer,
)
act_quantization_spec = QuantizationSpec(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
quant_min=torch.iinfo(act_dtype).min,
quant_max=torch.iinfo(act_dtype).max,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=act_fake_quant_ctr,
)

weight_fake_quant_ctr = PerBlockParamFakeQuantize.with_args(
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
quant_min=(
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
),
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
qscheme=torch.per_channel_symmetric,
ch_axis=ch_axis,
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
quant_min=(
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
),
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
qscheme=torch.per_channel_symmetric,
ch_axis=ch_axis,
observer_or_fake_quant_ctr=weight_fake_quant_ctr,
)

bias_quantization_spec = _derived_bias_quant_spec

quantization_config = QuantizationConfig(
input_activation=act_quantization_spec,
output_activation=act_quantization_spec,
weight=weight_quantization_spec,
bias=bias_quantization_spec,
)

return quantization_config
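For reference, a hedged sketch of how this helper is expected to be invoked, mirroring the partial(...) wiring added to quantizer.py below; the attribute reads at the end assume the standard QuantizationSpec/QuantizationConfig fields used elsewhere in this file:

```python
import torch

from executorch.backends.qualcomm.quantizer.qconfig import (
    get_qat_per_block_quant_config,
)

# Same arguments the quantizer binds for QuantDtype.use_16a4w_block.
qat_block_config = get_qat_per_block_quant_config(
    act_dtype=torch.uint16,
    weight_dtype=torch.int4,
)

# uint16 activations are carried with int32 qparams, per the comment above.
print(qat_block_config.input_activation.dtype)  # torch.int32
# The weight spec wraps PerBlockParamFakeQuantize; the concrete block_size is
# presumably injected later from the quantizer's block_size_map (see the test
# below), not by this helper.
print(qat_block_config.weight.observer_or_fake_quant_ctr)
```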


# TODO merge qat and ptq to a function, and use a bool flag to control it
def get_8a8w_qnn_qat_config(
act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
14 changes: 14 additions & 0 deletions backends/qualcomm/quantizer/quantizer.py
@@ -29,6 +29,7 @@
get_8a8w_qnn_qat_config,
get_ptq_per_block_quant_config,
get_ptq_per_channel_quant_config,
get_qat_per_block_quant_config,
get_qat_per_channel_quant_config,
QuantizationConfig,
)
@@ -131,6 +132,19 @@ class QuantDtype(IntEnum):
),
None,
),
(QuantDtype.use_16a4w_block, True): (
get_16a4w_qnn_qat_config,
partial(
get_qat_per_channel_quant_config,
act_dtype=torch.uint16,
weight_dtype=torch.int4,
),
partial(
get_qat_per_block_quant_config,
act_dtype=torch.uint16,
weight_dtype=torch.int4,
),
),
(QuantDtype.use_8a8w, True): (
get_8a8w_qnn_qat_config,
partial(get_qat_per_channel_quant_config),
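A sketch of how the new table entry would be consumed; the table name QUANT_CONFIG_DICT and the tuple ordering are assumptions inferred from the surrounding entries rather than something shown in this hunk:

```python
# Assumed names: QuantDtype is defined in this file; QUANT_CONFIG_DICT is taken to
# be the module-level table these (QuantDtype, is_qat) entries belong to.
from executorch.backends.qualcomm.quantizer.quantizer import (
    QUANT_CONFIG_DICT,
    QuantDtype,
)

# Assumed ordering: (per-tensor config fn, per-channel partial, per-block partial);
# True selects the QAT path.
quant_config_fn, per_channel_fn, per_block_fn = QUANT_CONFIG_DICT[
    (QuantDtype.use_16a4w_block, True)
]

per_channel_cfg = per_channel_fn()  # uint16 activations, int4 per-channel weights
per_block_cfg = per_block_fn()      # uint16 activations, int4 per-block weights
```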
19 changes: 19 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -2054,6 +2054,25 @@ def test_qnn_backend_16a4w_conv2d_qat(self):
)
self.lower_module_and_test_output(converted, sample_input)

def test_qnn_backend_16a4w_block_conv2d_qat(self):
modules = [
Conv2dSingle(in_channel=64, out_channel=64),
Conv2dSingle(bias=False),
] # noqa: F405
sample_input = (torch.randn([64, 64, 3, 3]),)
for i, module in enumerate(modules):
with self.subTest(i=i):
prepared = self.get_prepared_qat_module(
module,
sample_input,
quant_dtype=QuantDtype.use_16a4w_block,
block_size_map={"conv2d": (1, 32, 1, 1)},
)
converted = self.get_converted_sgd_trained_module(
module, prepared, sample_input
)
self.lower_module_and_test_output(converted, sample_input)

def test_qnn_backend_16a4w_layer_norm(self):
module = LayerNorm() # noqa: F405
sample_input = (torch.randn(196, 768),)