 import torch

 from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
+    PerBlockParamFakeQuantize,
     PerBlockParamObserver,
 )
 from torch import Tensor
@@ -71,11 +72,18 @@ def _derive_bias_qparams_fn(
     derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
         torch.int32
     )
-    if isinstance(weight_obs_or_fq, PerBlockParamObserver):
+
+    # Handle per-block quantization for both observer and fake quantize
+    weight_observer = weight_obs_or_fq
+    if isinstance(weight_obs_or_fq, PerBlockParamFakeQuantize):
+        # Extract the underlying observer from the fake quantize wrapper
+        weight_observer = weight_obs_or_fq.activation_post_process
+
+    if isinstance(weight_observer, PerBlockParamObserver):
         # keep maximum scale of each channel for bias
         derived_scale = (
             derived_scale.view(derived_scale.size(0), -1).amax(dim=-1)
-            / weight_obs_or_fq.num_steps
+            / weight_observer.num_steps
         )
         derived_zero = derived_zero.view(derived_zero.size(0), -1).amax(dim=-1)
     return (derived_scale, derived_zero)
@@ -468,6 +476,93 @@ def get_ptq_per_block_quant_config(
     )


+def get_qat_per_block_quant_config(
+    act_dtype=torch.uint8,
+    weight_dtype=torch.int8,
+    act_observer=MovingAverageMinMaxObserver,
+    act_symmetric: bool = False,
+    ch_axis: int = 0,
+) -> QuantizationConfig:
+    supported_act_types = {
+        torch.uint8,
+        torch.uint16,
+        torch.int8,
+        torch.int16,
+    }
+    supported_weight_dtypes = {torch.int4, torch.int8}
+    assert (
+        act_dtype in supported_act_types
+    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
+
+    assert (
+        weight_dtype in supported_weight_dtypes
+    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
+
+    # torch does not support uint16 quantization, use int32 to bypass
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+
+    weight_fake_quant_ctr = PerBlockParamFakeQuantize.with_args(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=ch_axis,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=ch_axis,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+
+    bias_quantization_spec = _derived_bias_quant_spec
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
+
 # TODO merge qat and ptq to a function, and use a bool flag to control it
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
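For context (not part of the patch): a minimal sketch of how the new QAT per-block config might be exercised, assuming the helpers live in executorch.backends.qualcomm.quantizer.qconfig (the file path is not shown in this diff) and that torch.int4 is available in the targeted PyTorch build.

import torch

# Assumed module path; the diff does not show the file name.
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_qat_per_block_quant_config,
)

# 4-bit per-block weights are stored as int8 with a symmetric [-7, 7] range,
# per the quant_min/quant_max logic in the patch above.
config = get_qat_per_block_quant_config(weight_dtype=torch.int4)
print(config.weight.dtype)      # torch.int8
print(config.weight.quant_min)  # -7
print(config.weight.quant_max)  # 7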