
Commit eb2e9f9

sxu authored and facebook-github-bot committed
Add 16a4w_block QAT config
Summary: Introduce a FakeQuantize subclass, PerBlockParamFakeQuantize, for per-block weight QAT. Its `convert` falls back to the LPBQ observer's `convert`, and `_derived_bias_quant_spec` also looks for it to correctly derive the bias scale.

Differential Revision: D87194388
1 parent 529a265 commit eb2e9f9

File tree

3 files changed: +171 −3 lines changed

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 60 additions & 1 deletion
@@ -7,7 +7,7 @@
 from typing import Tuple

 import torch
-from torchao.quantization.pt2e import MappingType, PerBlock
+from torchao.quantization.pt2e import FakeQuantize, MappingType, PerBlock
 from torchao.quantization.pt2e._affine_quantization import (
     _get_reduction_params,
     AffineQuantizedMinMaxObserver,
@@ -89,3 +89,62 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.preserve_zero,
             self.zero_point_domain,
         )
+
+
+class PerBlockParamFakeQuantize(FakeQuantize):
+    def __init__(
+        self,
+        dtype: torch.dtype = torch.int8,
+        block_size: torch.Size = None,
+        quant_min: int = None,
+        quant_max: int = None,
+        eps: float = torch.finfo(torch.float32).eps,
+        **kwargs,
+    ):
+        super().__init__()
+        assert block_size is not None, "block_size must be provided for per-block quantization"
+
+        self.activation_post_process = PerBlockParamObserver(
+            dtype=dtype,
+            block_size=block_size,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            **kwargs,
+        )
+        self.dtype = dtype
+        self.block_size = block_size
+        self.quant_min = quant_min if quant_min is not None else torch.iinfo(dtype).min
+        self.quant_max = quant_max if quant_max is not None else torch.iinfo(dtype).max
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.numel() == 0:
+            return x
+
+        self.activation_post_process(x)
+        scale, zero_point = self.activation_post_process.calculate_qparams()
+
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            self.block_size, x.size()
+        )
+        x_reshaped = x.view(shape_for_reduction)
+
+        # Unsqueeze scale and zero_point to match x_reshaped.
+        for dim in reduction_dims:
+            scale = scale.unsqueeze(dim)
+            zero_point = zero_point.unsqueeze(dim)
+
+        x_quant = ((x_reshaped / scale).round() + zero_point).clamp(
+            self.quant_min, self.quant_max
+        )
+        x_dequant = (x_quant - zero_point) * scale
+
+        x_fake_quant = x_dequant.view(x.size())
+        return x_fake_quant
+
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.activation_post_process.calculate_qparams()
+
+    def convert(self, model, observer_node):
+        self.activation_post_process.convert(model, observer_node)
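
To make the block-wise round trip in `forward` concrete, here is a minimal standalone sketch of the same quantize/clamp/dequantize math on a toy 2D weight. The block layout, scale rule, and int4-style range below are illustrative assumptions; the real class derives its qparams from `PerBlockParamObserver` and its reshape from torchao's `_get_reduction_params`.

# Minimal, self-contained sketch (not the class's exact qparams): per-block
# symmetric fake quantization of a 4x8 weight with blocks of shape (1, 4).
import torch

weight = torch.randn(4, 8)                # (out_channels, in_channels)
quant_min, quant_max = -7, 7              # int4-style symmetric range

# With block_size (1, 4), the (4, 8) weight reshapes to (4, 1, 2, 4); dims 1
# and 3 are reduced per block, mirroring what _get_reduction_params would
# produce here.
x_reshaped = weight.view(4, 1, 2, 4)

# Illustrative per-block scale: max(|x|) / quant_max, symmetric so zero_point = 0.
eps = torch.finfo(torch.float32).eps
scale = (x_reshaped.abs().amax(dim=(1, 3), keepdim=True) / quant_max).clamp_min(eps)
zero_point = torch.zeros_like(scale)

# Quantize, clamp to the target range, dequantize, and restore the input shape.
x_quant = ((x_reshaped / scale).round() + zero_point).clamp(quant_min, quant_max)
x_dequant = (x_quant - zero_point) * scale
x_fake_quant = x_dequant.view(weight.size())

print((weight - x_fake_quant).abs().max())  # small per-block quantization error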

backends/qualcomm/quantizer/qconfig.py

Lines changed: 97 additions & 2 deletions
@@ -10,6 +10,7 @@
 import torch

 from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
+    PerBlockParamFakeQuantize,
     PerBlockParamObserver,
 )
 from torch import Tensor
@@ -71,11 +72,18 @@ def _derive_bias_qparams_fn(
     derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
         torch.int32
     )
-    if isinstance(weight_obs_or_fq, PerBlockParamObserver):
+
+    # Handle per-block quantization for both observer and fake quantize
+    weight_observer = weight_obs_or_fq
+    if isinstance(weight_obs_or_fq, PerBlockParamFakeQuantize):
+        # Extract the underlying observer from the fake quantize wrapper
+        weight_observer = weight_obs_or_fq.activation_post_process
+
+    if isinstance(weight_observer, PerBlockParamObserver):
         # keep maximum scale of each channel for bias
         derived_scale = (
             derived_scale.view(derived_scale.size(0), -1).amax(dim=-1)
-            / weight_obs_or_fq.num_steps
+            / weight_observer.num_steps
         )
         derived_zero = derived_zero.view(derived_zero.size(0), -1).amax(dim=-1)
     return (derived_scale, derived_zero)
@@ -468,6 +476,93 @@ def get_ptq_per_block_quant_config(
     )


+def get_qat_per_block_quant_config(
+    act_dtype=torch.uint8,
+    weight_dtype=torch.int8,
+    act_observer=MovingAverageMinMaxObserver,
+    act_symmetric: bool = False,
+    ch_axis: int = 0,
+) -> QuantizationConfig:
+    supported_act_types = {
+        torch.uint8,
+        torch.uint16,
+        torch.int8,
+        torch.int16,
+    }
+    supported_weight_dtypes = {torch.int4, torch.int8}
+    assert (
+        act_dtype in supported_act_types
+    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
+
+    assert (
+        weight_dtype in supported_weight_dtypes
+    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
+
+    # torch does not support uint16 quantization, use int32 to bypass
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+
+    weight_fake_quant_ctr = PerBlockParamFakeQuantize.with_args(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=ch_axis,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=ch_axis,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+
+    bias_quantization_spec = _derived_bias_quant_spec
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 # TODO merge qat and ptq to a function, and use a bool flag to control it
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
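
A hedged usage sketch of the new factory, bound with the same dtypes that quantizer.py wires up for use_16a4w_block below; how block_size ultimately reaches PerBlockParamFakeQuantize is not shown in this diff, so it is omitted here.

# Hedged sketch: calling the new QAT per-block factory with the dtypes used by
# the use_16a4w_block entry. block_size plumbing is not shown in this diff.
import torch
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_qat_per_block_quant_config,
)

qat_block_config = get_qat_per_block_quant_config(
    act_dtype=torch.uint16,   # "16a": 16-bit activations (int32 observer dtype under the hood)
    weight_dtype=torch.int4,  # "4w": 4-bit weights, quant range [-7, 7]
    act_symmetric=False,
    ch_axis=0,
)
# qat_block_config.weight uses PerBlockParamFakeQuantize.with_args(...) as its
# observer_or_fake_quant_ctr; qat_block_config.bias is _derived_bias_quant_spec,
# which unwraps the fake quantize to reach the underlying per-block observer.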

backends/qualcomm/quantizer/quantizer.py

Lines changed: 14 additions & 0 deletions
@@ -28,6 +28,7 @@
     get_8a8w_qnn_ptq_config,
     get_8a8w_qnn_qat_config,
     get_ptq_per_block_quant_config,
+    get_qat_per_block_quant_config,
     get_ptq_per_channel_quant_config,
     get_qat_per_channel_quant_config,
     QuantizationConfig,
@@ -131,6 +132,19 @@ class QuantDtype(IntEnum):
         ),
         None,
     ),
+    (QuantDtype.use_16a4w_block, True): (
+        get_16a4w_qnn_qat_config,
+        partial(
+            get_qat_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int4,
+        ),
+        partial(
+            get_qat_per_block_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int4,
+        ),
+    ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
         partial(get_qat_per_channel_quant_config),
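
For context, a hedged end-to-end sketch of selecting the new entry. Only `QuantDtype.use_16a4w_block` paired with QAT comes from this commit; the `set_default_quant_config` call below is an assumption about the QnnQuantizer surface, and the actual entry point may differ.

# Hedged sketch: opting into the new 16a4w block-quant QAT path. The
# set_default_quant_config name/kwargs are assumptions, not part of this diff.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype

quantizer = QnnQuantizer()
quantizer.set_default_quant_config(  # assumed configuration hook
    quant_dtype=QuantDtype.use_16a4w_block,
    is_qat=True,
)
# With (use_16a4w_block, True) selected, the table above supplies the per-tensor
# QAT config plus the per-channel and per-block partials bound to uint16/int4.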
