
Commit 0a8fd5c

sxu authored and facebook-github-bot committed
Add 16a4w_block QAT config (#15878)

Summary: Introduce a `FakeQuantize` subclass for per-block weight quantization. It falls back to the LPBQ observer's `convert`, and `_derived_bias_quant_spec` also looks for it to correctly derive the bias scale.

Reviewed By: viveknayakatmeta

Differential Revision: D87194388
1 parent b4d72f1 commit 0a8fd5c

File tree

4 files changed: +184 −3 lines

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 55 additions & 1 deletion
@@ -7,12 +7,13 @@
 from typing import Tuple
 
 import torch
-from torchao.quantization.pt2e import MappingType, PerBlock
+from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock
 from torchao.quantization.pt2e._affine_quantization import (
     _get_reduction_params,
     AffineQuantizedMinMaxObserver,
     choose_qparams_affine_with_min_max,
 )
+from torchao.quantization.quant_primitives import _fake_quantize_affine
 
 
 class PerBlockParamObserver(AffineQuantizedMinMaxObserver):
@@ -89,3 +90,56 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.preserve_zero,
             self.zero_point_domain,
         )
+
+
+class PerBlockParamFakeQuantize(FakeQuantize):
+    def __init__(
+        self,
+        dtype: torch.dtype = torch.int8,
+        block_size: torch.Size = None,
+        quant_min: int = None,
+        quant_max: int = None,
+        eps: float = torch.finfo(torch.float32).eps,
+        **kwargs,
+    ):
+        super().__init__()
+        assert (
+            block_size is not None
+        ), "block_size must be provided for per-block quantization"
+
+        self.activation_post_process = PerBlockParamObserver(
+            dtype=dtype,
+            block_size=block_size,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            **kwargs,
+        )
+        self.dtype = dtype
+        self.block_size = block_size
+        self.quant_min = quant_min if quant_min is not None else torch.iinfo(dtype).min
+        self.quant_max = quant_max if quant_max is not None else torch.iinfo(dtype).max
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.numel() == 0:
+            return x
+
+        self.activation_post_process(x)
+        scale, zero_point = self.activation_post_process.calculate_qparams()
+
+        return _fake_quantize_affine(
+            x,
+            self.block_size,
+            scale,
+            zero_point,
+            quant_dtype=self.dtype,
+            quant_min=self.quant_min,
+            quant_max=self.quant_max,
+        )
+
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.activation_post_process.calculate_qparams()
+
+    def convert(self, model, observer_node):
+        self.activation_post_process.convert(model, observer_node)
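For reference, a minimal sketch of exercising the new fake quantizer on its own. The constructor keywords mirror the values get_qat_per_block_quant_config passes for 4-bit weights (int8 container, range [-7, 7]); the block_size value is only an example, and the sketch assumes PerBlockParamObserver accepts the keywords forwarded above:

import torch

from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
    PerBlockParamFakeQuantize,
)

# Example settings: block_size (1, 32, 1, 1) means one scale per group of 32
# input channels of a conv weight, matching the block_size_map in the new test.
fake_quant = PerBlockParamFakeQuantize(
    dtype=torch.int8,
    block_size=(1, 32, 1, 1),
    quant_min=-7,
    quant_max=7,
)

weight = torch.randn(64, 64, 3, 3)
# forward() first feeds the tensor to the wrapped PerBlockParamObserver,
# then applies _fake_quantize_affine with the observed per-block qparams.
fq_weight = fake_quant(weight)
scale, zero_point = fake_quant.calculate_qparams()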

backends/qualcomm/quantizer/qconfig.py

Lines changed: 96 additions & 2 deletions
@@ -10,6 +10,7 @@
 import torch
 
 from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
+    PerBlockParamFakeQuantize,
     PerBlockParamObserver,
 )
 from torch import Tensor
@@ -71,11 +72,18 @@ def _derive_bias_qparams_fn(
     derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
         torch.int32
     )
-    if isinstance(weight_obs_or_fq, PerBlockParamObserver):
+
+    # Handle per-block quantization for both observer and fake quantize
+    weight_observer = weight_obs_or_fq
+    if isinstance(weight_obs_or_fq, PerBlockParamFakeQuantize):
+        # Extract the underlying observer from the fake quantize wrapper
+        weight_observer = weight_obs_or_fq.activation_post_process
+
+    if isinstance(weight_observer, PerBlockParamObserver):
         # keep maximum scale of each channel for bias
         derived_scale = (
             derived_scale.view(derived_scale.size(0), -1).amax(dim=-1)
-            / weight_obs_or_fq.num_steps
+            / weight_observer.num_steps
         )
         derived_zero = derived_zero.view(derived_zero.size(0), -1).amax(dim=-1)
     return (derived_scale, derived_zero)
@@ -468,6 +476,92 @@ def get_ptq_per_block_quant_config(
     )
 
 
+def get_qat_per_block_quant_config(
+    act_dtype=torch.uint8,
+    weight_dtype=torch.int8,
+    act_observer=MovingAverageMinMaxObserver,
+    act_symmetric: bool = False,
+    ch_axis: int = 0,
+) -> QuantizationConfig:
+    supported_act_types = {
+        torch.uint8,
+        torch.uint16,
+        torch.int8,
+        torch.int16,
+    }
+    supported_weight_dtypes = {torch.int4, torch.int8}
+    assert (
+        act_dtype in supported_act_types
+    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
+
+    assert (
+        weight_dtype in supported_weight_dtypes
+    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
+
+    # torch does not support uint16 quantization, use int32 to bypass
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+
+    weight_fake_quant_ctr = PerBlockParamFakeQuantize.with_args(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=ch_axis,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=ch_axis,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+
+    bias_quantization_spec = _derived_bias_quant_spec
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 # TODO merge qat and ptq to a function, and use a bool flag to control it
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
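A minimal sketch of building the new QAT per-block config directly, using the same dtypes the quantizer table below wires up for 16a4w_block; the per-node block_size is presumably supplied later by the quantizer from its block_size_map (as in the PTQ per-block path), not by this function:

import torch

from executorch.backends.qualcomm.quantizer.qconfig import (
    get_qat_per_block_quant_config,
)

qat_block_config = get_qat_per_block_quant_config(
    act_dtype=torch.uint16,   # 16-bit activations, carried as int32 internally
    weight_dtype=torch.int4,  # 4-bit weights, fake-quantized in an int8 container
)
# qat_block_config.weight.observer_or_fake_quant_ctr is
# PerBlockParamFakeQuantize.with_args(...), and qat_block_config.bias is
# _derived_bias_quant_spec, which now unwraps the fake quantize to reach
# the underlying per-block observer when deriving the bias scale.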

backends/qualcomm/quantizer/quantizer.py

Lines changed: 14 additions & 0 deletions
@@ -29,6 +29,7 @@
     get_8a8w_qnn_qat_config,
     get_ptq_per_block_quant_config,
     get_ptq_per_channel_quant_config,
+    get_qat_per_block_quant_config,
     get_qat_per_channel_quant_config,
     QuantizationConfig,
 )
@@ -131,6 +132,19 @@ class QuantDtype(IntEnum):
         ),
         None,
     ),
+    (QuantDtype.use_16a4w_block, True): (
+        get_16a4w_qnn_qat_config,
+        partial(
+            get_qat_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int4,
+        ),
+        partial(
+            get_qat_per_block_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int4,
+        ),
+    ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
         partial(get_qat_per_channel_quant_config),
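Assuming the existing QnnQuantizer helpers set_default_quant_config and set_block_size_map (not part of this diff), the new (QuantDtype.use_16a4w_block, is_qat=True) table entry could be selected roughly like this:

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype

quantizer = QnnQuantizer()
# Hypothetical call shapes, shown only to illustrate how the new QAT
# per-block entry would be reached; argument names are assumptions.
quantizer.set_default_quant_config(
    quant_dtype=QuantDtype.use_16a4w_block,
    is_qat=True,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Per-op block sizes, matching the value used in the new test below.
quantizer.set_block_size_map({"conv2d": (1, 32, 1, 1)})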

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 19 additions & 0 deletions
@@ -2054,6 +2054,25 @@ def test_qnn_backend_16a4w_conv2d_qat(self):
                 )
                 self.lower_module_and_test_output(converted, sample_input)
 
+    def test_qnn_backend_16a4w_block_conv2d_qat(self):
+        modules = [
+            Conv2dSingle(in_channel=64, out_channel=64),
+            Conv2dSingle(bias=False),
+        ]  # noqa: F405
+        sample_input = (torch.randn([64, 64, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                prepared = self.get_prepared_qat_module(
+                    module,
+                    sample_input,
+                    quant_dtype=QuantDtype.use_16a4w_block,
+                    block_size_map={"conv2d": (1, 32, 1, 1)},
+                )
+                converted = self.get_converted_sgd_trained_module(
+                    module, prepared, sample_input
+                )
+                self.lower_module_and_test_output(converted, sample_input)
+
     def test_qnn_backend_16a4w_layer_norm(self):
         module = LayerNorm()  # noqa: F405
         sample_input = (torch.randn(196, 768),)
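The new test relies on existing helpers (get_prepared_qat_module, get_converted_sgd_trained_module); a rough sketch of the flow they wrap, assuming the torchao PT2E entry points prepare_qat_pt2e / convert_pt2e and a QnnQuantizer configured for use_16a4w_block QAT as above:

import torch

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

# `quantizer`, `module`, and `sample_input` are assumed from the surrounding
# test setup (QnnQuantizer configured for QuantDtype.use_16a4w_block QAT,
# a Conv2dSingle-style nn.Module, and a matching input tuple).
exported = torch.export.export(module, sample_input).module()
prepared = prepare_qat_pt2e(exported, quantizer)

# ... a few SGD training steps on `prepared` to exercise the fake quantizers ...

converted = convert_pt2e(prepared)  # fake quantizers fold into q/dq ops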
