@@ -79,18 +79,18 @@ def test_verify_invalid_recipes(
          Dtype.INT4,
          Mode.DYNAMIC_RANGE,
          Algorithm.MIN_MAX,
-         Granularity.BLOCKWISE,
-         32,
+         Granularity.BLOCKWISE_32,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.INT4,
+          Mode.DYNAMIC_RANGE,
+          Algorithm.MIN_MAX,
+          Granularity.BLOCKWISE_128,
       ),
   ])
   def test_verify_valid_recipes(
-      self,
-      activation,
-      weight,
-      mode,
-      algo,
-      granularity,
-      block_size=None,
+      self, activation, weight, mode, algo, granularity
   ):
     quant_recipe.LayerQuantRecipe(
         activation, weight, mode, algo, granularity
@@ -108,21 +108,21 @@ def setUp(self):
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+            attention=quant_recipe_utils.create_layer_quant_dynamic(),
         )
     )

   def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+            feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
         )
     )

   @parameterized.parameters([
       (quant_recipes.full_fp16_recipe()),
-      (quant_recipes.full_int8_dynamic_recipe()),
-      (quant_recipes.full_int8_weight_only_recipe()),
+      (quant_recipes.full_dynamic_recipe()),
+      (quant_recipes.full_weight_only_recipe()),
       (_attention_int8_dynamic_recipe()),
       (_feedforward_int8_dynamic_recipe()),
   ])
@@ -148,7 +148,7 @@ def test_quantize_convert_toy_weight_sharing(self):
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)

-    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quant_config = quant_recipes.full_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -164,7 +164,9 @@ def test_quantize_convert_toy_blockwise(self):
     pytorch_model = toy_model.ToySingleLayerModel(config)
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)
-    quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)
+    quant_config = quant_recipes.full_dynamic_recipe(
+        weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
+    )
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -175,17 +177,6 @@ def test_quantize_convert_toy_blockwise(self):
175177 "Quantized model isn't smaller than F32 model." ,
176178 )
177179
178- def test_unsupported_block_size (self ):
179- config = toy_model .get_model_config ()
180- pytorch_model = toy_model .ToySingleLayerModel (config )
181- idx = torch .unsqueeze (torch .arange (0 , 100 , dtype = torch .int ), 0 )
182- input_pos = torch .arange (0 , 100 , dtype = torch .int )
183- self .assertRaises (
184- ValueError ,
185- quant_recipes .all_supported_int4_dynamic_block_recipe ,
186- 36 ,
187- )
188-
189180 def test_quantize_convert_compare_toy (self ):
190181 self .skipTest ("b/338288901" )
191182 config = toy_model_with_kv_cache .get_model_config ()
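For context, a minimal sketch (outside the diff) of how the renamed recipe helper exercised by these tests is invoked. The recipe call, model setup, and convert call are taken directly from the hunks above; the import paths are assumptions inferred from the module aliases used in the test file, not shown in this diff.

# Sketch only: import paths are assumed from the module aliases in the test
# above and may differ from the actual package layout.
import ai_edge_torch
import torch
from ai_edge_torch.generative.examples.test_models import toy_model
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.quantize.quant_attrs import Dtype, Granularity

# Toy model and sample inputs, mirroring the test bodies in the diff.
config = toy_model.get_model_config()
pytorch_model = toy_model.ToySingleLayerModel(config)
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
input_pos = torch.arange(0, 100, dtype=torch.int)

# Per the diff, full_dynamic_recipe() with no arguments replaces the old
# full_int8_dynamic_recipe(), and the keyword arguments below select the int4
# blockwise variant that previously came from
# all_supported_int4_dynamic_block_recipe(block_size).
quant_cfg = quant_recipes.full_dynamic_recipe(
    weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
)
quantized_model = ai_edge_torch.convert(
    pytorch_model, (idx, input_pos), quant_config=quant_cfg
)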