@@ -115,6 +115,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
         **kwargs,
     ) -> dict[str, list[torch.Tensor]]:
@@ -138,6 +140,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, list[torch.Tensor]]:
         tensors = next(iter(value.values()))
@@ -163,6 +167,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, torch.Tensor]:
         if len(target_keys) != 1:
@@ -191,6 +197,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, torch.Tensor]:
         merged: dict[str, torch.Tensor] = {}
@@ -220,6 +228,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, list[torch.Tensor]]:
         if len(value) != len(self.sizes):
@@ -258,6 +268,8 @@ def convert(
         source_keys: list[str],
         target_keys: list[str],
         full_layer_name: str,
+        model,
+        missing_keys,
         config,
     ) -> dict[str, list[torch.Tensor]]:
         self.config = config
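These six hunks make the same mechanical change to every `convert` override: two new parameters, `model` and `missing_keys`, threaded in between `full_layer_name` and `config`. A minimal sketch of an operation written against the new signature (the class name and the keyword-only marker are assumptions based on the call sites further down, not taken from the diff):

```python
import torch

# Hypothetical no-op operation illustrating the new signature; `model` is
# presumably the model being loaded and `missing_keys` a mutable set of
# keys not yet resolved (both inferred from the call sites below).
class IdentityOp:
    def convert(
        self,
        value: dict[str, list[torch.Tensor]],
        *,
        source_keys: list[str],
        target_keys: list[str],
        full_layer_name: str,
        model,
        missing_keys,
        config,
        **kwargs,
    ) -> dict[str, list[torch.Tensor]]:
        # A real op would fuse, split, transpose, or dequantize here.
        return value

out = IdentityOp().convert(
    {"weight": [torch.zeros(2, 2)]},
    source_keys=["w"],
    target_keys=["weight"],
    full_layer_name="layer.weight",
    model=None,
    missing_keys=set(),
    config=None,
)
```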
@@ -298,21 +310,28 @@ def add_tensor(self, target_key: str, source_key: str, source_pattern: str, futu
 class WeightRenaming(WeightTransform):
     # Special case of WeightTransform that only renames keys without any conversion.
 
-    def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
+    def convert(
+        self,
+        layer_name: str,
+        model=None,
+        config=None,
+        hf_quantizer=None,
+        missing_keys: Optional[MutableSet[str]] = None,
+    ):
         misc = {}
         for pattern, futures in self.collected_tensors.items():
             self.collected_tensors[pattern] = [future.result() for future in futures]
 
         collected_tensors = self.collected_tensors
-        if quantizer is not None and self.quantization_operation is not None:
+        if hf_quantizer is not None and self.quantization_operation is not None:
             with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
                 collected_tensors = self.quantization_operation.convert(
                     self.collected_tensors,
                     source_keys=self.source_keys,
                     target_keys=self.target_keys,
                     full_layer_name=layer_name,
+                    model=model,
                     config=config,
-                    quant_config=quantizer.quantization_config,
                     missing_keys=missing_keys,
                 )
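`WeightRenaming.convert` picks up the same two arguments, renames `quantizer` to `hf_quantizer`, and stops forwarding `quant_config`: since the op now receives the whole `model`, it can presumably look up whatever context it needs itself. A hedged sketch of that idea, using only the standard `nn.Module.get_submodule` API (the op class and its logic are hypothetical):

```python
import torch
from torch import nn

# Hypothetical quantization op: with `model` threaded through, it can
# locate the target module directly, which would make an explicit
# `quant_config=` argument redundant at the call site.
class QuantOpSketch:
    def convert(self, value, *, source_keys, target_keys, full_layer_name,
                model, missing_keys, config, **kwargs):
        module_path = full_layer_name.rsplit(".", 1)[0]
        module = model.get_submodule(module_path)  # standard nn.Module API
        # Inspect `module` (e.g. its class or dtype) to decide how to pack
        # `value`; returned unchanged in this sketch.
        return value

model = nn.Sequential(nn.Linear(4, 4))
out = QuantOpSketch().convert(
    {"0.weight": [torch.zeros(4, 4)]},
    source_keys=["0.weight"],
    target_keys=["0.weight"],
    full_layer_name="0.weight",
    model=model,
    missing_keys=set(),
    config=None,
)
```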
@@ -332,7 +351,14 @@ def __post_init__(self):
         if not self.operations:
             raise ValueError("WeightConverter requires at least one operation.")
 
-    def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None):
+    def convert(
+        self,
+        layer_name: str,
+        model=None,
+        config=None,
+        hf_quantizer=None,
+        missing_keys: Optional[MutableSet[str]] = None,
+    ):
         misc = {}
         for pattern, futures in self.collected_tensors.items():
             self.collected_tensors[pattern] = [future.result() for future in futures]
@@ -345,17 +371,19 @@ def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Op
                     source_keys=self.source_keys,
                     target_keys=self.target_keys,
                     full_layer_name=layer_name,
+                    model=model,
                     config=config,
+                    missing_keys=missing_keys,
                 )
-        if quantizer is not None and self.quantization_operation is not None:
+        if hf_quantizer is not None and self.quantization_operation is not None:
             with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_keys=self.source_keys,
                     target_keys=self.target_keys,
                     full_layer_name=layer_name,
                     config=config,
-                    quant_config=quantizer.quantization_config,
+                    model=model,
                     missing_keys=missing_keys,
                 )
         return collected_tensors, misc
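Note the extra change here relative to `WeightRenaming`: the regular conversion operations now also receive `missing_keys` (the added `missing_keys=missing_keys` in the first call), not only the quantization op. A hedged sketch of why an op might want it, with hypothetical names throughout:

```python
import torch

# Hypothetical op that synthesizes a default tensor for a target key and
# clears it from the missing set, so the loader does not report it later.
class FillMissingOp:
    def convert(self, value, *, source_keys, target_keys, full_layer_name,
                model, missing_keys, config, **kwargs):
        for key in target_keys:
            if missing_keys is not None and key in missing_keys:
                value.setdefault(key, [torch.zeros(1)])  # illustrative default
                missing_keys.discard(key)
        return value

missing = {"bias"}
out = FillMissingOp().convert(
    {}, source_keys=[], target_keys=["bias"],
    full_layer_name="layer.bias", model=None,
    missing_keys=missing, config=None,
)
assert "bias" not in missing
```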
@@ -626,7 +654,6 @@ def convert_and_load_state_dict_in_model(
     ```
 
     """
-
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
@@ -750,7 +777,11 @@ def convert_and_load_state_dict_in_model(
             pbar.refresh()
         try:
             realized_value, misc = mapping.convert(
-                first_param_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys
+                first_param_name,
+                model=model,
+                config=model.config,
+                hf_quantizer=hf_quantizer,
+                missing_keys=missing_keys,
             )
             for target_name, param in realized_value.items():
                 param = param[0] if isinstance(param, list) else param
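At the call site the one-line call becomes an all-keyword call matching the renamed `hf_quantizer` parameter, with `model` passed through. The two context lines after it show how each realized value is unwrapped; a self-contained stand-in of that unwrapping (the dict contents are made up):

```python
import torch

# Stand-in for `mapping.convert`'s first return value: target names mapped
# to a tensor or a list of shards, of which the first element is taken.
realized_value = {"model.embed_tokens.weight": [torch.zeros(4, 8)]}
for target_name, param in realized_value.items():
    param = param[0] if isinstance(param, list) else param
    print(target_name, tuple(param.shape))  # prints: model.embed_tokens.weight (4, 8)
```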