1414import importlib
1515import re
1616import types
17- from collections import defaultdict
1817from typing import TYPE_CHECKING , Optional , Union
1918
2019from packaging import version
3837if is_torchao_available ():
3938 import torchao
4039
41- if version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.14 .0" ):
40+ if version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.15 .0" ):
4241 from torchao .prototype .safetensors .safetensors_support import (
4342 flatten_tensor_state_dict ,
4443 unflatten_tensor_state_dict ,
@@ -87,6 +86,9 @@ def _linear_extra_repr(self):
8786 SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
8887 torchao .quantization .Float8WeightOnlyConfig ,
8988 torchao .quantization .Float8DynamicActivationFloat8WeightConfig ,
89+ torchao .quantization .Int4WeightOnlyConfig ,
90+ torchao .quantization .IntxWeightOnlyConfig ,
91+ torchao .quantization .Int8DynamicActivationIntxWeightConfig ,
9092 ]
9193
9294 TORCHAO_VERSION = version .parse (importlib .metadata .version ("torchao" ))
@@ -104,20 +106,6 @@ class TorchAoHfQuantizer(HfQuantizer):
def __init__(self, quantization_config, **kwargs):
    """Set up the quantizer by delegating to the base-class constructor.

    Args:
        quantization_config: Quantization configuration, forwarded unchanged
            to the parent initializer.
        **kwargs: Extra keyword arguments, forwarded unchanged to the parent
            initializer.
    """
    # No per-instance key lists are kept here: the key-matching checks in this
    # class test for the "_weight_" marker directly instead of a precomputed
    # list of serialized tensor suffixes.
    super().__init__(quantization_config, **kwargs)

121109 def validate_environment (self , * args , ** kwargs ):
122110 if not is_torchao_available ():
123111 raise ImportError ("Loading an torchao quantized model requires torchao library (`pip install torchao`)" )
@@ -168,11 +156,11 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool]
168156 the safetensors format.
169157 """
170158 if type (self .quantization_config .quant_type ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization :
171- if TORCHAO_VERSION >= version .parse ("0.14 .0" ):
159+ if TORCHAO_VERSION >= version .parse ("0.15 .0" ):
172160 return flatten_tensor_state_dict (model .state_dict ())
173161 else :
174162 raise RuntimeError (
175- f"In order to use safetensors with torchao, please use torchao version >= 0.14 .0. Current version: { TORCHAO_VERSION } "
163+ f"In order to use safetensors with torchao, please use torchao version >= 0.15 .0. Current version: { TORCHAO_VERSION } "
176164 )
177165 else :
178166 return None , {}
@@ -234,7 +222,7 @@ def _process_model_before_weight_loading(
234222 return
235223
def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
    """Filter out checkpoint keys that carry the ``_weight_`` marker.

    Args:
        model: Unused; kept so the signature stays compatible with callers.
        unexpected_keys: Key names reported as unexpected during loading.

    Returns:
        A new list with every key containing the substring ``"_weight_"``
        removed; the relative order of the remaining keys is preserved.
    """
    # NOTE(review): "_weight_" presumably marks the flattened torchao tensor
    # components (e.g. qdata / scale) stored in safetensors checkpoints —
    # confirm against torchao's flatten_tensor_state_dict naming.
    return [k for k in unexpected_keys if "_weight_" not in k]
238226
239227 def param_needs_quantization (self , model : "PreTrainedModel" , param_name : str , ** kwargs ) -> bool :
240228 if self .quantization_config .quant_type == "autoquant" :
@@ -243,7 +231,7 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
243231 # check if the param_name is not in self.modules_to_not_convert
244232 if any (key + "." in param_name or key == param_name for key in self .modules_to_not_convert ):
245233 return False
246- elif any ( param_name . endswith ( f": { x } " ) for x in self . full_ao_keys ) :
234+ elif "_weight_" in param_name :
247235 return True
248236 else :
249237 # we only quantize the weight of nn.Linear and nn.Embedding
@@ -267,42 +255,12 @@ def create_quantized_param(
267255 """
268256 from torchao .quantization import quantize_
269257
270- full_name = param_name
271- # Those are the pre quantized weights
272- if ":" in param_name :
273- param_name = param_name .rsplit (":" , 1 )[0 ]
274258 module , tensor_name = get_module_from_name (model , param_name )
275-
276259 if self .pre_quantized :
277- # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
278- # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
279- is_unsafe_serialization = ":" not in full_name
280- if tensor_name == "bias" or is_unsafe_serialization :
281- module ._parameters [tensor_name ] = torch .nn .Parameter (
282- param_value .to (target_device ), requires_grad = param_value .requires_grad
283- )
284- return
285- # Sanity check for the new serialization format
286- elif not (TORCHAO_VERSION >= version .parse ("0.14.0" ) and is_metadata_torchao (self .metadata )):
287- raise ValueError ("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed" )
288-
289- # Save the states for later quantization when they are all gathered
290- if not hasattr (self , "ao_params" ):
291- self .ao_params = defaultdict (dict )
292- self .ao_params [param_name ].update ({full_name : param_value })
293-
294- # We are ready for quantization in this case (we retrieved all the needed keys)
295- if len (self .ao_params [param_name ]) == len (self .weight_ao_keys ):
296- new_param = unflatten_tensor_state_dict (self .ao_params [param_name ], self .metadata )[param_name ]
297- # Set it
298- module ._parameters [tensor_name ] = torch .nn .Parameter (
299- new_param .to (target_device ), requires_grad = new_param .requires_grad
300- )
301-
302- # Free memory
303- del self .ao_params [param_name ]
260+ module ._parameters [tensor_name ] = torch .nn .Parameter (
261+ param_value .to (target_device ), requires_grad = param_value .requires_grad
262+ )
304263
305- # Add repr to the module
306264 if isinstance (module , nn .Linear ):
307265 module .extra_repr = types .MethodType (_linear_extra_repr , module )
308266 else :
@@ -368,6 +326,32 @@ def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpo
368326
369327 def _process_model_after_weight_loading (self , model , ** kwargs ):
370328 """No process required for torchao quantized model"""
329+ if TORCHAO_VERSION >= version .parse ("0.15.0" ) and is_metadata_torchao (self .metadata ):
330+ updated_state_dict = unflatten_tensor_state_dict (model .state_dict (), self .metadata )
331+
332+ weights_to_register = set (updated_state_dict .keys ())
333+
334+ for name , param in list (model .named_parameters ()):
335+ module_fqn , weight_name = name .rsplit ("." , 1 )
336+ module = model .get_submodule (module_fqn )
337+ weight = getattr (module , weight_name )
338+
339+ device = weight .device
340+ requires_grad = weight .requires_grad
341+
342+ if "_weight_" in weight_name :
343+ delattr (module , weight_name )
344+
345+ if name in weights_to_register :
346+ new_param_value = updated_state_dict [name ]
347+ new_param = torch .nn .Parameter (new_param_value .to (device ), requires_grad = requires_grad )
348+ module .register_parameter (weight_name , new_param )
349+
350+ weights_to_register .remove (name )
351+
352+ model .load_state_dict (updated_state_dict , strict = False )
353+ return
354+
371355 if self .quantization_config .quant_type == "autoquant" :
372356 from torchao import autoquant
373357 from torchao .quantization import ALL_AUTOQUANT_CLASS_LIST
@@ -386,11 +370,11 @@ def is_serializable(self, safe_serialization=None) -> bool:
386370 if safe_serialization :
387371 _is_torchao_serializable = type (
388372 self .quantization_config .quant_type
389- ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version .parse ("0.14 .0" )
373+ ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version .parse ("0.15 .0" )
390374 if not _is_torchao_serializable :
391375 logger .warning (
392376 f"torchao quantized model only supports safe serialization for { SUPPORTED_SAFE_SERIALIZATION_CONFIGS } , \
393- and torchao version >= 0.14 .0, please set `safe_serialization` to False for \
377+ and torchao version >= 0.15 .0, please set `safe_serialization` to False for \
394378 { type (self .quantization_config .quant_type )} and { TORCHAO_VERSION } ."
395379 )
396380 return _is_torchao_serializable
0 commit comments