[torchao] safetensors #42529
```diff
@@ -32,7 +32,7 @@
 if is_torchao_available():
     TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             unflatten_tensor_state_dict,
         )
```
```diff
@@ -215,56 +215,46 @@ def convert(
     missing_keys=None,
     **kwargs,
 ) -> dict[str, torch.Tensor]:
-    if isinstance(self.hf_quantizer.quantization_config.quant_type, str):
-        is_int_4 = "int4" in self.hf_quantizer.quantization_config.quant_type
-    else:
-        config_name = self.hf_quantizer.quantization_config.quant_type.__class__.__name__
-        is_int_4 = fuzzy_match_size(config_name) == "4"
-
-    # Simple case: if we gather layernorm weights, we can just return the value since they are not quantized
-    if "weight:_data" in input_dict.keys():
-        value = (
-            input_dict["weight:_data"][0]
-            if isinstance(input_dict["weight:_data"], list)
-            else input_dict["weight:_data"]
-        )
-        return {full_layer_name: value}
-
-    is_unsafe_serialization = ":" not in list(input_dict.keys())[0]
+    """
+    Consolidates tensor subclass components before reconstructing the object.
+
+    For example:
+        input_dict: {
+            "_weight_qdata": torch.Tensor,
+            "_weight_scale": torch.Tensor,
+        }
+        full_layer_name: "model.layers.0.self_attn.k_proj"
+
+    Given this, we reconstruct a Float8Tensor instance using the qdata and scale
+    and return it as a dictionary with the full_layer_name as the key and the
+    recovered Float8Tensor instance as the value.
+    """
+    is_unsafe_serialization = "_weight_" not in list(input_dict.keys())[0]

     param_data = {}
+    layer_name = ".".join(full_layer_name.split(".")[:-1])
     if is_unsafe_serialization:
         if isinstance(input_dict["weight"], list):
             weight = input_dict["weight"][0]
         else:
             weight = input_dict["weight"]
     else:
-        if isinstance(input_dict["weight:qdata"], list):
-            param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"][0]
-        else:
-            param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"]
+        for suffix in input_dict.keys():
+            assert len(input_dict[suffix]) == 1
```
Contributor: let's do an if/else and raise an error to follow the same pattern we use in transformers, but I'm not sure if it's necessary since if we have the …
```diff
+            param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]

-        if isinstance(input_dict["weight:scale"], list):
-            param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"][0]
-        else:
-            param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"]
-
-        if is_int_4:
-            if isinstance(input_dict["weight:zero_point"], list):
-                param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"][0]
-            else:
-                param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"]

-    # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-    # already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
+    # If it's unsafe-serialized (i.e. not safetensors), no need for anything
     if is_unsafe_serialization:
         return {full_layer_name: weight}
     # Sanity check for the new serialization format
-    elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
-        # print("metadata", self.hf_quantizer.metadata)
-        raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+    elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
+        raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")

-    new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
+    unflattened_state_dict, leftover_state_dict = unflatten_tensor_state_dict(
+        param_data, self.hf_quantizer.metadata
+    )
+    assert not leftover_state_dict  # there should be no unprocessed tensors
+    new_param = unflattened_state_dict[full_layer_name]

     module, _ = get_module_from_name(model, full_layer_name)
     # Add repr to the module
```
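To make the consolidation step concrete, here is a minimal, self-contained sketch of what the safetensors branch does with the gathered components. It uses plain tensors in place of torchao subclasses and an explicit raise instead of the assert, per the review comment above; `consolidate_components` and its inputs are hypothetical illustrations, not the converter's real API.

```python
import torch

def consolidate_components(input_dict: dict, full_layer_name: str) -> dict:
    """Consolidate flattened tensor-subclass components under their layer name,
    mirroring the safetensors branch of `convert` above. Gathered values arrive
    as single-element lists, as the assert in the diff expects."""
    layer_name = ".".join(full_layer_name.split(".")[:-1])
    param_data = {}
    for suffix, tensors in input_dict.items():
        # The review comment above suggests an explicit error instead of an
        # assert; this is what that if/raise pattern could look like.
        if len(tensors) != 1:
            raise ValueError(f"expected exactly one tensor for {suffix!r}, got {len(tensors)}")
        param_data[f"{layer_name}.{suffix}"] = tensors[0]
    return param_data

# Mirrors the docstring's example; the real code then hands `param_data` to
# `unflatten_tensor_state_dict` to rebuild a Float8Tensor.
components = {
    "_weight_qdata": [torch.randint(-128, 127, (16, 16), dtype=torch.int8)],
    "_weight_scale": [torch.rand(16, 1)],
}
print(consolidate_components(components, "model.layers.0.self_attn.k_proj").keys())
```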
```diff
@@ -2104,7 +2104,6 @@ def set_decoder(self, decoder):
     possible_module_names = ["language_model", "text_model", "decoder"]
     for name in possible_module_names:
         if hasattr(self, name):
-            print(name)
             setattr(self, name, decoder)
             return
```

Comment on lines -2107 to 2108 (Contributor): Thanks for catching this
```diff
@@ -40,9 +40,7 @@
 import torch.nn as nn

 if is_torchao_available():
-    import torchao
-
-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             flatten_tensor_state_dict,
             unflatten_tensor_state_dict,
```
```diff
@@ -88,11 +86,6 @@ def _linear_extra_repr(self):

 if is_torchao_available():
-    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
-        torchao.quantization.Float8WeightOnlyConfig,
-        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
-    ]
-
     TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
```
```diff
@@ -171,12 +164,12 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool | None = F
     If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
     the safetensors format.
     """
-    if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
-        if TORCHAO_VERSION >= version.parse("0.14.0"):
+    if safe_serialization:
+        if TORCHAO_VERSION >= version.parse("0.15.0"):
             return flatten_tensor_state_dict(model.state_dict())
         else:
             raise RuntimeError(
-                f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+                f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
             )
     else:
         return None, {}
```
```diff
@@ -241,6 +234,8 @@ def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]
     return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]

 def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
+    if "_weight_" in param_name:
+        return True
     if self.pre_quantized:
         return False
     if self.quantization_config.quant_type == "autoquant":
```
```diff
@@ -249,8 +244,6 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
     # check if the param_name is not in self.modules_to_not_convert
     if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
         return False
-    elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
```
Contributor: for this one, maybe just change it to …

Contributor (author): I think this is achieving the same thing as L237; I think we only need one check.
```diff
-        return True

     # we only quantize the weight of nn.Linear and nn.Embedding
     module, tensor_name = get_module_from_name(model, param_name)
```
```diff
@@ -306,8 +299,8 @@ def create_quantized_param(
         )
         return
     # Sanity check for the new serialization format
-    elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-        raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+    elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.metadata)):
+        raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")

     # Save the states for later quantization when they are all gathered
     if not hasattr(self, "ao_params"):
```
```diff
@@ -452,13 +445,10 @@ def _process_model_after_weight_loading(self, model, **kwargs):

     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
-            _is_torchao_serializable = type(
-                self.quantization_config.quant_type
-            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
-            if not _is_torchao_serializable:
+            _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
+            if not TORCHAO_VERSION >= version.parse("0.15.0"):
                 logger.warning(
-                    f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
-                    and torchao version >= 0.14.0, please set `safe_serialization` to False for \
+                    f"torchao quantized model only supports safe serialization for torchao version >= 0.15.0, please set `safe_serialization` to False for \
                     {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
                 )
             return _is_torchao_serializable
```
```diff
@@ -548,15 +538,18 @@ def get_weight_conversions(self):
     if self.pre_quantized:
         return [
             WeightConverter(
-                source_patterns=["weight:qdata", "weight:scale", "weight:zero_point"],
-                target_patterns="weight",
-                operations=[TorchAoDeserialize(self)],
-            ),
-            WeightConverter(
-                source_patterns=["weight:_data"],
+                # TODO: incr flexibility by generalizing the source patterns to match the format of "_weight_"
+                # note that the matching logic is greedy, so for ex, if _weight_scale is before
+                # _weight_scale_and_zero in this list, it will match _weight_scale always (this is incorrect)
+                # thus, the order of source_patterns is intentional
+                source_patterns=[
+                    "_weight_qdata",
+                    "_weight_scale_and_zero",
+                    "_weight_scale",
+                    "_weight_zero_point",
+                    "_weight_act_pre_scale",
+                ],
```
Comment on lines +539 to +548 (Contributor): Yes it is greedy, but to match for example …
||
| target_patterns="weight", | ||
| operations=[TorchAoDeserialize(self)], | ||
| ), | ||
| # used for unsafe serialization | ||
| ] | ||
| return [] | ||
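The ordering constraint flagged in the TODO above can be shown in a few lines. This is a sketch of greedy first-match resolution; `first_match` is a hypothetical stand-in for the converter's real matcher, assuming the first pattern found in a checkpoint key wins.

```python
# Mirrors the source_patterns list above: more specific patterns must come
# before any pattern that is a substring of them.
SOURCE_PATTERNS = [
    "_weight_qdata",
    "_weight_scale_and_zero",
    "_weight_scale",
    "_weight_zero_point",
    "_weight_act_pre_scale",
]

def first_match(checkpoint_key: str) -> str | None:
    # Greedy: the first pattern that appears in the key wins.
    for pattern in SOURCE_PATTERNS:
        if pattern in checkpoint_key:
            return pattern
    return None

key = "model.layers.0.self_attn.k_proj._weight_scale_and_zero"
print(first_match(key))  # -> "_weight_scale_and_zero"

# If "_weight_scale" were listed first, it would match this key too (it is a
# substring of "_weight_scale_and_zero"), routing the component incorrectly.
```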