# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Optional
+
from ..activations import ACT2FN
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging


logger = logging.get_logger(__name__)


+class FbgemmFp8Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        from ..integrations import FbgemmFp8Llama4TextExperts
+
+        module, tensor_name = get_module_from_name(model, target_key)
+
+        if isinstance(module, FbgemmFp8Llama4TextExperts):
+            if tensor_name == "gate_up_proj":
+                # Process each expert separately
+                # Transpose the second and third dimensions
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize per row of the transposed weights (per column of the original)
+                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to the original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
+            elif tensor_name == "down_proj":
+                # Process each expert separately
+                # Transpose the weights for proper quantization
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize per row of the transposed weights (per column of the original)
+                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to the original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
+        else:
+            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
+            weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
+
+        return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
+
+
class FbgemmFp8Linear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
+    def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
        super().__init__(in_features, out_features, bias)
        self.in_features = in_features
        self.out_features = out_features

-        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
        self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

        if bias:
-            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
        else:
            self.bias = None

@@ -154,90 +215,11 @@ def forward(self, hidden_states):
        return next_states.view(-1, self.hidden_size)


-def _replace_with_fbgemm_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
-):
-    """
-    Private method that wraps the recursion for module replacement.
-
-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-
-    import re
-
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model._modules[name] = FbgemmFp8Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                    )
-                    has_been_replaced = True
-
-                # Force requires grad to False to avoid unexpected errors
-                model._modules[name].requires_grad_(False)
-                # set non persistent buffer outside of init_empty_weights
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub],
-                    dtype=torch.float,
-                )
-        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
-                    model._modules[name] = FbgemmFp8Llama4TextExperts(
-                        config.text_config,
-                    )
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub], dtype=torch.float
-                )
-
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-                pre_quantized=pre_quantized,
-                config=config,
-                tp_plan=tp_plan,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
def replace_with_fbgemm_fp8_linear(
    model,
    modules_to_not_convert=None,
-    current_key_name=None,
    quantization_config=None,
    pre_quantized=False,
-    config=None,
    tp_plan=None,
):
    """
@@ -254,26 +236,45 @@ def replace_with_fbgemm_fp8_linear(
        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
-        current_key_name (`list[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
    """

-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        pre_quantized=pre_quantized,
-        config=config,
-        tp_plan=tp_plan,
-    )
+    has_been_replaced = False
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+
+        new_module = None
+        with init_empty_weights(include_buffers=True):
+            if module.__class__.__name__ == "Llama4TextExperts":
+                # TODO: make sure tp works later
+                # if tp_plan is not None:
+                #     tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
+                #     tp_plan[tp_key] = None
+                text_config = getattr(model.config, "text_config", model.config)
+                new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
+            elif isinstance(module, nn.Linear):
+                new_module = FbgemmFp8Linear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **module_kwargs,
+                )
+                new_module.requires_grad_(False)
+
+        if new_module is None:
+            continue
+
+        if hasattr(new_module, "input_scale_ub"):
+            new_module.input_scale_ub = torch.tensor(
+                [quantization_config.activation_scale_ub],
+                dtype=torch.float,
+            )
+
+        model.set_submodule(module_name, new_module)
+        has_been_replaced = True
+
    if not has_been_replaced:
        logger.warning(
            "You are loading your model using FP8 quantization but no linear modules were found in your model."
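For orientation, a minimal sketch (illustrative only, not part of the commit above) of driving the rewritten `replace_with_fbgemm_fp8_linear` on a toy model. It assumes a recent PyTorch (for `nn.Module.set_submodule` and `float8_e4m3fn`) and accelerate are installed; the `SimpleNamespace` is a stand-in for the real `FbgemmFp8Config`, since the helper only reads `activation_scale_ub` here.

import types

import torch
from torch import nn

from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear

# Toy model: only nn.Linear children outside `modules_to_not_convert` get swapped.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
# Stand-in config (assumption): only `activation_scale_ub` is read by the helper.
cfg = types.SimpleNamespace(activation_scale_ub=1200.0)

replace_with_fbgemm_fp8_linear(toy, modules_to_not_convert=[], quantization_config=cfg, pre_quantized=True)

assert isinstance(toy[0], FbgemmFp8Linear)           # nn.Linear layers were replaced
assert toy[0].weight.dtype == torch.float8_e4m3fn    # fp8 storage plus a per-row `weight_scale`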