Skip to content

Commit 023e494

Browse files
rewu93 and copybara-github
authored and committed
Refactor blockwise quantization granularity.
PiperOrigin-RevId: 813968719
1 parent 6e125c5 commit 023e494

File tree

10 files changed

+71
-90
lines changed

10 files changed

+71
-90
lines changed

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,7 @@ def convert_stable_diffusion_to_tflite(
138138
if not os.path.exists(output_dir):
139139
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
140140

141-
quant_config = (
142-
quant_recipes.full_int8_weight_only_recipe() if quantize else None
143-
)
141+
quant_config = quant_recipes.full_weight_only_recipe() if quantize else None
144142

145143
# TODO(yichunk): convert to multi signature tflite model.
146144
# CLIP text encoder

ai_edge_torch/generative/quantize/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def main():
3333
kv = kv_utils.KVCache.from_model_config(config)
3434

3535
# Create a quantization recipe to be applied to the model
36-
quant_config = quant_recipes.full_int8_dynamic_recipe()
36+
quant_config = quant_recipes.full_dynamic_recipe()
3737
print(quant_config)
3838

3939
# Convert with quantization

ai_edge_torch/generative/quantize/quant_attrs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ class Granularity(enum.Enum):
6363
NONE: Granularity not applicable to this quantization scheme.
6464
CHANNELWISE: Or per-channel quantization. Each channel of relevant tensors
6565
is quantized independently of one another.
66+
BLOCKWISE_32: Blockwise quantization with block size 32.
67+
BLOCKWISE_64: Blockwise quantization with block size 64.
68+
BLOCKWISE_128: Blockwise quantization with block size 128.
69+
BLOCKWISE_256: Blockwise quantization with block size 256.
6670
"""
6771

6872
NONE = enum.auto()
6973
CHANNELWISE = enum.auto()
70-
BLOCKWISE = enum.auto()
74+
BLOCKWISE_32 = enum.auto()
75+
BLOCKWISE_64 = enum.auto()
76+
BLOCKWISE_128 = enum.auto()
77+
BLOCKWISE_256 = enum.auto()

ai_edge_torch/generative/quantize/quant_recipe.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,13 @@ class LayerQuantRecipe:
3939
mode: Type of quantization.
4040
algorithm: Algorithm for calculating quantization parameters.
4141
granularity: Granularity of quantization.
42-
block_size: Size of the block for blockwise quantization.
4342
"""
4443

4544
activation_dtype: quant_attrs.Dtype
4645
weight_dtype: quant_attrs.Dtype
4746
mode: quant_attrs.Mode
4847
algorithm: quant_attrs.Algorithm
4948
granularity: quant_attrs.Granularity
50-
block_size: int = 0
5149

5250
def __str__(self):
5351
base_str = (
@@ -56,7 +54,6 @@ def __str__(self):
5654
f'{self.mode.name}, '
5755
f'{self.algorithm.name}, '
5856
f'{self.granularity.name}, '
59-
f'{self.block_size}'
6057
)
6158
return f'{base_str})'
6259

@@ -77,16 +74,6 @@ def verify(self):
7774
and self.algorithm == supported[3]
7875
and self.granularity == supported[4]
7976
):
80-
if self.block_size > 0:
81-
if (
82-
self.block_size % 32 == 0
83-
and self.granularity == quant_attrs.Granularity.BLOCKWISE
84-
):
85-
is_valid = True
86-
break
87-
else:
88-
is_valid = False
89-
break
9077
is_valid = True
9178
break
9279

ai_edge_torch/generative/quantize/quant_recipe_utils.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,29 @@
3232
from ai_edge_torch.generative.quantize import quant_recipe
3333

3434

35-
def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
35+
def create_layer_quant_dynamic(
36+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
37+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
38+
) -> quant_recipe.LayerQuantRecipe:
3639
return quant_recipe.LayerQuantRecipe(
3740
activation_dtype=quant_attrs.Dtype.FP32,
38-
weight_dtype=quant_attrs.Dtype.INT8,
41+
weight_dtype=weight_dtype,
3942
mode=quant_attrs.Mode.DYNAMIC_RANGE,
4043
algorithm=quant_attrs.Algorithm.MIN_MAX,
41-
granularity=quant_attrs.Granularity.CHANNELWISE,
44+
granularity=granularity,
4245
)
4346

4447

45-
def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
48+
def create_layer_quant_weight_only(
49+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
50+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
51+
) -> quant_recipe.LayerQuantRecipe:
4652
return quant_recipe.LayerQuantRecipe(
4753
activation_dtype=quant_attrs.Dtype.FP32,
48-
weight_dtype=quant_attrs.Dtype.INT8,
54+
weight_dtype=weight_dtype,
4955
mode=quant_attrs.Mode.WEIGHT_ONLY,
5056
algorithm=quant_attrs.Algorithm.MIN_MAX,
51-
granularity=quant_attrs.Granularity.CHANNELWISE,
57+
granularity=granularity,
5258
)
5359

5460

@@ -60,16 +66,3 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
6066
algorithm=quant_attrs.Algorithm.FLOAT_CAST,
6167
granularity=quant_attrs.Granularity.NONE,
6268
)
63-
64-
65-
def create_layer_quant_int4_dynamic_block(
66-
block_size: int,
67-
) -> quant_recipe.LayerQuantRecipe:
68-
return quant_recipe.LayerQuantRecipe(
69-
activation_dtype=quant_attrs.Dtype.FP32,
70-
weight_dtype=quant_attrs.Dtype.INT4,
71-
mode=quant_attrs.Mode.DYNAMIC_RANGE,
72-
algorithm=quant_attrs.Algorithm.MIN_MAX,
73-
granularity=quant_attrs.Granularity.BLOCKWISE,
74-
block_size=block_size,
75-
)

ai_edge_torch/generative/quantize/quant_recipes.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,37 @@
2929

3030
from typing import Optional
3131
from ai_edge_torch.generative.layers import model_config
32+
from ai_edge_torch.generative.quantize import quant_attrs
3233
from ai_edge_torch.generative.quantize import quant_recipe
3334
from ai_edge_torch.generative.quantize import quant_recipe_utils
3435
from ai_edge_torch.quantize import quant_config
3536

3637

37-
def full_int8_dynamic_recipe(
38+
def full_dynamic_recipe(
3839
mcfg: Optional[model_config.ModelConfig] = None,
40+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
41+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
3942
) -> quant_config.QuantConfig:
4043
return quant_config.QuantConfig(
4144
generative_recipe=quant_recipe.GenerativeQuantRecipe(
42-
default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
45+
default=quant_recipe_utils.create_layer_quant_dynamic(
46+
weight_dtype, granularity
47+
),
4348
_model_config=mcfg,
4449
)
4550
)
4651

4752

48-
def full_int8_weight_only_recipe(
53+
def full_weight_only_recipe(
4954
mcfg: Optional[model_config.ModelConfig] = None,
55+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
56+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
5057
) -> quant_config.QuantConfig:
5158
return quant_config.QuantConfig(
5259
generative_recipe=quant_recipe.GenerativeQuantRecipe(
53-
default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
60+
default=quant_recipe_utils.create_layer_quant_weight_only(
61+
weight_dtype, granularity
62+
),
5463
_model_config=mcfg,
5564
)
5665
)
@@ -65,17 +74,3 @@ def full_fp16_recipe(
6574
_model_config=mcfg,
6675
)
6776
)
68-
69-
70-
def all_supported_int4_dynamic_block_recipe(
71-
block_size: int,
72-
mcfg: Optional[model_config.ModelConfig] = None,
73-
) -> quant_config.QuantConfig:
74-
return quant_config.QuantConfig(
75-
generative_recipe=quant_recipe.GenerativeQuantRecipe(
76-
default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
77-
block_size
78-
),
79-
_model_config=mcfg,
80-
)
81-
)

ai_edge_torch/generative/quantize/supported_schemes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,7 @@ def get_supported_layer_schemes():
2929
(_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
3030
(_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
3131
(_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
32-
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE),
32+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_32),
33+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_64),
34+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_128),
3335
]

ai_edge_torch/generative/test/test_quantize.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@ def test_verify_invalid_recipes(
7979
Dtype.INT4,
8080
Mode.DYNAMIC_RANGE,
8181
Algorithm.MIN_MAX,
82-
Granularity.BLOCKWISE,
83-
32,
82+
Granularity.BLOCKWISE_32,
83+
),
84+
(
85+
Dtype.FP32,
86+
Dtype.INT4,
87+
Mode.DYNAMIC_RANGE,
88+
Algorithm.MIN_MAX,
89+
Granularity.BLOCKWISE_128,
8490
),
8591
])
8692
def test_verify_valid_recipes(
87-
self,
88-
activation,
89-
weight,
90-
mode,
91-
algo,
92-
granularity,
93-
block_size=None,
93+
self, activation, weight, mode, algo, granularity
9494
):
9595
quant_recipe.LayerQuantRecipe(
9696
activation, weight, mode, algo, granularity
@@ -108,21 +108,21 @@ def setUp(self):
108108
def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
109109
return quant_config.QuantConfig(
110110
generative_recipe=quant_recipe.GenerativeQuantRecipe(
111-
attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
111+
attention=quant_recipe_utils.create_layer_quant_dynamic(),
112112
)
113113
)
114114

115115
def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
116116
return quant_config.QuantConfig(
117117
generative_recipe=quant_recipe.GenerativeQuantRecipe(
118-
feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
118+
feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
119119
)
120120
)
121121

122122
@parameterized.parameters([
123123
(quant_recipes.full_fp16_recipe()),
124-
(quant_recipes.full_int8_dynamic_recipe()),
125-
(quant_recipes.full_int8_weight_only_recipe()),
124+
(quant_recipes.full_dynamic_recipe()),
125+
(quant_recipes.full_weight_only_recipe()),
126126
(_attention_int8_dynamic_recipe()),
127127
(_feedforward_int8_dynamic_recipe()),
128128
])
@@ -148,7 +148,7 @@ def test_quantize_convert_toy_weight_sharing(self):
148148
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
149149
input_pos = torch.arange(0, 100, dtype=torch.int)
150150

151-
quant_config = quant_recipes.full_int8_dynamic_recipe()
151+
quant_config = quant_recipes.full_dynamic_recipe()
152152
quantized_model = ai_edge_torch.convert(
153153
pytorch_model, (idx, input_pos), quant_config=quant_config
154154
)
@@ -164,7 +164,9 @@ def test_quantize_convert_toy_blockwise(self):
164164
pytorch_model = toy_model.ToySingleLayerModel(config)
165165
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
166166
input_pos = torch.arange(0, 100, dtype=torch.int)
167-
quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)
167+
quant_config = quant_recipes.full_dynamic_recipe(
168+
weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
169+
)
168170
quantized_model = ai_edge_torch.convert(
169171
pytorch_model, (idx, input_pos), quant_config=quant_config
170172
)
@@ -175,17 +177,6 @@ def test_quantize_convert_toy_blockwise(self):
175177
"Quantized model isn't smaller than F32 model.",
176178
)
177179

178-
def test_unsupported_block_size(self):
179-
config = toy_model.get_model_config()
180-
pytorch_model = toy_model.ToySingleLayerModel(config)
181-
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
182-
input_pos = torch.arange(0, 100, dtype=torch.int)
183-
self.assertRaises(
184-
ValueError,
185-
quant_recipes.all_supported_int4_dynamic_block_recipe,
186-
36,
187-
)
188-
189180
def test_quantize_convert_compare_toy(self):
190181
self.skipTest("b/338288901")
191182
config = toy_model_with_kv_cache.get_model_config()

ai_edge_torch/generative/utilities/converter.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ai_edge_torch.generative.layers import kv_cache as kv_utils
2626
from ai_edge_torch.generative.layers import lora as lora_utils
2727
import ai_edge_torch.generative.layers.model_config as cfg
28+
from ai_edge_torch.generative.quantize import quant_attrs
2829
from ai_edge_torch.generative.quantize import quant_recipes
2930
from ai_edge_torch.generative.utilities import export_config as export_config_lib
3031
from ai_edge_torch.generative.utilities import litertlm_builder
@@ -193,18 +194,22 @@ def get_quant_recipe_from_flag(
193194
case QuantizationName.NONE:
194195
return None
195196
case QuantizationName.DYNAMIC_INT8:
196-
return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
197+
return quant_recipes.full_dynamic_recipe(mcfg=model_config)
197198
case QuantizationName.WEIGHT_ONLY_INT8:
198-
return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
199+
return quant_recipes.full_weight_only_recipe(mcfg=model_config)
199200
case QuantizationName.FP16:
200201
return quant_recipes.full_fp16_recipe()
201202
case QuantizationName.DYNAMIC_INT4_BLOCK32:
202-
return quant_recipes.all_supported_int4_dynamic_block_recipe(
203-
32, mcfg=model_config
203+
return quant_recipes.full_dynamic_recipe(
204+
mcfg=model_config,
205+
weight_dtype=quant_attrs.Dtype.INT4,
206+
granularity=quant_attrs.Granularity.BLOCKWISE_32,
204207
)
205208
case QuantizationName.DYNAMIC_INT4_BLOCK128:
206209
return quant_recipes.all_supported_int4_dynamic_block_recipe(
207-
128, mcfg=model_config
210+
mcfg=model_config,
211+
weight_dtype=quant_attrs.Dtype.INT4,
212+
granularity=quant_attrs.Granularity.BLOCKWISE_128,
208213
)
209214
case _:
210215
raise ValueError(f'Unsupported quantization flag: {quantize}')

ai_edge_torch/lowertools/translate_recipe.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,12 @@ def _get_granularity(
8080
return _QuantGranularity.CHANNELWISE
8181
if granularity == quant_attrs.Granularity.NONE:
8282
return _QuantGranularity.TENSORWISE
83-
if granularity == quant_attrs.Granularity.BLOCKWISE:
84-
return _QuantGranularity.BLOCKWISE
83+
if granularity == quant_attrs.Granularity.BLOCKWISE_32:
84+
return _QuantGranularity.BLOCKWISE_32
85+
if granularity == quant_attrs.Granularity.BLOCKWISE_64:
86+
return _QuantGranularity.BLOCKWISE_64
87+
if granularity == quant_attrs.Granularity.BLOCKWISE_128:
88+
return _QuantGranularity.BLOCKWISE_128
8589
raise ValueError('Unimplemented granularity')
8690

8791

@@ -108,7 +112,6 @@ def _set_quant_config(
108112
symmetric=True,
109113
granularity=_get_granularity(layer_recipe.granularity),
110114
dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
111-
block_size=layer_recipe.block_size,
112115
),
113116
compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
114117
explicit_dequantize=_get_explicit_dequant_from_mode(

0 commit comments

Comments (0)