
Commit a431b9a

fix safetensors
1 parent dd8f231 commit a431b9a

File tree

4 files changed: +63 -55 lines changed

src/transformers/modeling_utils.py
src/transformers/quantizers/base.py
src/transformers/quantizers/quantizer_torchao.py
tests/quantization/torchao_integration/test_torchao.py


src/transformers/modeling_utils.py

Lines changed: 12 additions & 1 deletion
@@ -112,6 +112,7 @@
     is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_xla_available,
+    is_torchao_available,
     logging,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@@ -148,6 +149,9 @@
 else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
+if is_torchao_available():
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
 
 logger = logging.get_logger(__name__)
 
@@ -545,6 +549,7 @@ def _infer_parameter_dtype(
         QuantizationMethod.QUARK,
         QuantizationMethod.MXFP4,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO,
     }:
         return True, None
     else:
@@ -659,7 +664,10 @@ def _load_state_dict_into_meta_model(
         if param_device == "disk":
             if not is_safetensors:
                 disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
-            elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
+            elif not is_quantized or (
+                not is_metadata_torchao(hf_quantizer.metadata)
+                and not hf_quantizer.param_needs_quantization(model, param_name)
+            ):
                 if is_fsdp_enabled():
                     param_device = "cpu" if is_local_dist_rank_0() else "meta"
 
@@ -4808,6 +4816,9 @@ def _load_pretrained_model(
             _error_msgs, disk_offload_index = load_shard_file(args)
             error_msgs += _error_msgs
 
+        if hf_quantizer:
+            hf_quantizer.update_model_with_metadata(model, hf_quantizer.metadata)
+
         # Save offloaded index if needed
         if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
             save_offload_index(disk_offload_index, disk_offload_folder)
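The notable change in this file is the extra metadata check in the disk-offload branch: when the checkpoint's safetensors header carries torchao metadata, every parameter must flow through the quantizer, even though the flattened tensor names would not match `param_needs_quantization`. A condensed sketch of the new predicate (a simplification of the hunk above, not the actual `_load_state_dict_into_meta_model` body):

from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao

def loads_as_plain_param(is_quantized, hf_quantizer, model, param_name) -> bool:
    # Sketch: mirrors the boolean logic of the new elif condition above.
    if not is_quantized:
        return True
    # New in this commit: torchao-flattened checkpoints always go through the
    # quantizer so the flattened pieces can be reassembled later.
    if is_metadata_torchao(hf_quantizer.metadata):
        return False
    return not hf_quantizer.param_needs_quantization(model, param_name)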

src/transformers/quantizers/base.py

Lines changed: 3 additions & 2 deletions
@@ -67,6 +67,7 @@ class HfQuantizer(ABC):
 
     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         self.quantization_config = quantization_config
+        self.metadata = {}
 
         # -- Handle extra kwargs below --
         self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
@@ -344,9 +345,9 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
         """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
         return None, {}
 
-    def update_state_dict_with_metadata(self, state_dict, metadata):
+    def update_model_with_metadata(self, model, metadata):
         """Update state dict with metadata. Default behaviour returns state_dict"""
-        return state_dict
+        pass
 
     @abstractmethod
     def is_serializable(self, safe_serialization=None): ...
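Note the contract change: the old `update_state_dict_with_metadata` returned a (possibly modified) state dict before loading, while the renamed `update_model_with_metadata` runs once after all shards are loaded (see the `_load_pretrained_model` hunk above) and mutates the model in place, returning nothing. A standalone sketch of the new hook shape, with hypothetical names (this is not the transformers implementation):

class QuantizerSketch:
    def __init__(self, quantization_config, **kwargs):
        self.quantization_config = quantization_config
        self.metadata = {}  # filled from the checkpoint's safetensors header at load time

    def update_model_with_metadata(self, model, metadata):
        # Default behaviour: leave the model untouched.
        pass


class MyQuantizer(QuantizerSketch):
    def update_model_with_metadata(self, model, metadata):
        if not metadata:  # nothing to reconstruct for this checkpoint
            return super().update_model_with_metadata(model, metadata)
        # Rebuild quantized tensors on `model` in place here.
        pass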

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 36 additions & 50 deletions
@@ -14,7 +14,6 @@
 import importlib
 import re
 import types
-from collections import defaultdict
 from typing import TYPE_CHECKING, Optional, Union
 
 from packaging import version
@@ -87,6 +86,9 @@ def _linear_extra_repr(self):
 SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
     torchao.quantization.Float8WeightOnlyConfig,
     torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
+    torchao.quantization.Int4WeightOnlyConfig,
+    torchao.quantization.IntxWeightOnlyConfig,
+    torchao.quantization.Int8DynamicActivationIntxWeightConfig,
 ]
 
 TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
@@ -104,20 +106,6 @@ class TorchAoHfQuantizer(HfQuantizer):
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
 
-        if isinstance(self.quantization_config.quant_type, str):
-            is_int_4 = "int4" in self.quantization_config.quant_type
-        else:
-            config_name = self.quantization_config.quant_type.__class__.__name__
-            is_int_4 = fuzzy_match_size(config_name) == "4"
-
-        # TODO: better way to get the serialized key names? Hard to read from torchao codebase
-        if is_int_4:
-            self.weight_ao_keys = ["qdata", "scale", "zero_point"]
-        else:
-            self.weight_ao_keys = ["qdata", "scale"]
-        # Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this
-        self.full_ao_keys = self.weight_ao_keys + ["_data"]
-
     def validate_environment(self, *args, **kwargs):
         if not is_torchao_available():
             raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
@@ -234,7 +222,7 @@ def _process_model_before_weight_loading(
         return
 
     def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
+        return [k for k in unexpected_keys if "_" not in k]
 
     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         if self.quantization_config.quant_type == "autoquant":
@@ -243,7 +231,7 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
         # check if the param_name is not in self.modules_to_not_convert
         if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
             return False
-        elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
+        elif "_" in param_name:
            return True
         else:
             # we only quantize the weight of nn.Linear and nn.Embedding
@@ -253,6 +241,34 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
             _QUANTIZABLE.append(torch.nn.Embedding)
         return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
 
+    def update_model_with_metadata(self, model, metadata):
+        if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata):
+            updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), metadata)
+
+            weights_to_register = set(updated_state_dict.keys())
+
+            for name, param in list(model.named_parameters()):
+                module_fqn, weight_name = name.rsplit(".", 1)
+                module = model.get_submodule(module_fqn)
+                weight = getattr(module, weight_name)
+
+                device = weight.device
+                requires_grad = weight.requires_grad
+
+                if "_weight_" in weight_name:
+                    delattr(module, weight_name)
+
+                if name in weights_to_register:
+                    new_param_value = updated_state_dict[name]
+                    new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad)
+                    module.register_parameter(weight_name, new_param)
+
+                    weights_to_register.remove(name)
+
+            model.load_state_dict(updated_state_dict, strict=False)
+        else:
+            return super().update_model_with_metadata(model, metadata)
+
     def create_quantized_param(
         self,
         model: "PreTrainedModel",
@@ -267,42 +283,12 @@
         """
         from torchao.quantization import quantize_
 
-        full_name = param_name
-        # Those are the pre quantized weights
-        if ":" in param_name:
-            param_name = param_name.rsplit(":", 1)[0]
         module, tensor_name = get_module_from_name(model, param_name)
-
         if self.pre_quantized:
-            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-            # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
-            is_unsafe_serialization = ":" not in full_name
-            if tensor_name == "bias" or is_unsafe_serialization:
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    param_value.to(target_device), requires_grad=param_value.requires_grad
-                )
-                return
-            # Sanity check for the new serialization format
-            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
-
-            # Save the states for later quantization when they are all gathered
-            if not hasattr(self, "ao_params"):
-                self.ao_params = defaultdict(dict)
-            self.ao_params[param_name].update({full_name: param_value})
-
-            # We are ready for quantization in this case (we retrieved all the needed keys)
-            if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
-                new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
-                # Set it
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    new_param.to(target_device), requires_grad=new_param.requires_grad
-                )
-
-                # Free memory
-                del self.ao_params[param_name]
+            module._parameters[tensor_name] = torch.nn.Parameter(
+                param_value.to(target_device), requires_grad=param_value.requires_grad
+            )
 
-            # Add repr to the module
             if isinstance(module, nn.Linear):
                 module.extra_repr = types.MethodType(_linear_extra_repr, module)
             else:
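The rewritten `create_quantized_param` now simply materializes whatever plain tensor arrives on the target device; reassembling torchao tensor subclasses happens once, in `update_model_with_metadata`, through `unflatten_tensor_state_dict`. A rough sketch of the flatten/unflatten round trip this relies on, assuming torchao >= 0.14.0; only `unflatten_tensor_state_dict` and `is_metadata_torchao` appear in the diff, so the flatten-side helper and its import path are assumptions:

import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_
# Assumed import path for the prototype helpers (torchao >= 0.14.0):
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(model, Float8WeightOnlyConfig())  # weights become torchao tensor subclasses

# Save side: tensor subclasses -> plain tensors (qdata, scale, ...) plus a
# metadata blob small enough for the safetensors header.
flat_state_dict, metadata = flatten_tensor_state_dict(model.state_dict())

# Load side: what update_model_with_metadata does once every shard is loaded.
restored_state_dict = unflatten_tensor_state_dict(flat_state_dict, metadata)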

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 12 additions & 2 deletions
@@ -536,7 +536,7 @@ def setUpClass(cls):
 
     def setUp(self):
         self.quant_config = TorchAoConfig(self.quant_scheme)
-        dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
+        dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             dtype=dtype,
@@ -552,7 +552,6 @@ def tearDown(self):
     def test_original_model_expected_output(self):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
@@ -578,6 +577,7 @@ class TorchAoSafeSerializationTest(TorchAoSerializationTest):
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
 
@@ -596,6 +596,16 @@ def tearDown(self):
             "What are we having for dinner?\n\nJess: (smiling) I",
         ),
         (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+        (Int4WeightOnlyConfig(), "What are we having for dinner?"),
+        (
+            Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"),
+            "What are we having for dinner?\nRed, white, and green beans,",
+        ),
+        (
+            torchao.quantization.Int8DynamicActivationIntxWeightConfig(),
+            "What are we having for dinner?\n\nJessica: (smiling)",
+        ),
+        (torchao.quantization.IntxWeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
     ]
     if is_torchao_available()
     else []
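With `Int4WeightOnlyConfig`, `IntxWeightOnlyConfig`, and `Int8DynamicActivationIntxWeightConfig` now whitelisted for safe serialization, these schemes can round-trip through safetensors. A minimal usage sketch (checkpoint name and output path are placeholders):

import torch
from torchao.quantization import Int4WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

model_name = "facebook/opt-125m"  # placeholder checkpoint
quant_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig())

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)

# Newly supported: the flattened torchao tensors plus their metadata are
# written into the safetensors file.
model.save_pretrained("opt-125m-int4", safe_serialization=True)

# Reloading exercises the unflatten path added in this commit.
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-int4", device_map="auto")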
