
Commit 9a309b0

fix safetensors
1 parent 020e713 commit 9a309b0

File tree

4 files changed: +53 −61 lines changed

src/transformers/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -545,6 +545,7 @@ def _infer_parameter_dtype(
         QuantizationMethod.QUARK,
         QuantizationMethod.MXFP4,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO,
     }:
         return True, None
     else:
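
Note: this hunk adds TORCHAO to the set of quantization methods whose checkpoint tensors keep their serialized dtype instead of being cast. A minimal, hypothetical sketch of the guard it extends (the real `_infer_parameter_dtype` has a different signature and the string values below are illustrative stand-ins for the enum members shown above):

import torch

def infer_parameter_dtype_sketch(quant_method: str, keep_in_fp32: bool = False):
    # Simplified stand-in, not the actual transformers helper.
    if quant_method in {"quark", "mxfp4", "bitsandbytes", "torchao"}:  # "torchao" is the new entry
        return True, None  # keep the dtype the tensor was serialized with, no cast target
    return False, (torch.float32 if keep_in_fp32 else None)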

src/transformers/quantizers/base.py

Lines changed: 1 addition & 4 deletions
@@ -67,6 +67,7 @@ class HfQuantizer(ABC):

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         self.quantization_config = quantization_config
+        self.metadata = {}

         # -- Handle extra kwargs below --
         self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])

@@ -344,10 +345,6 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
         """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
         return None, {}

-    def update_state_dict_with_metadata(self, state_dict, metadata):
-        """Update state dict with metadata. Default behaviour returns state_dict"""
-        return state_dict
-
     @abstractmethod
     def is_serializable(self, safe_serialization=None): ...
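
Note: every quantizer now starts with an empty `metadata` dict, replacing the removed `update_state_dict_with_metadata` hook. A hedged subclass sketch (illustrative only; other abstract hooks of `HfQuantizer` are omitted here):

from transformers.quantizers.base import HfQuantizer

class MetadataAwareQuantizer(HfQuantizer):
    # Illustrative subclass: `self.metadata` can be filled with safetensors header
    # metadata while the checkpoint is read, then consumed after loading.
    def is_serializable(self, safe_serialization=None):
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        if self.metadata:  # e.g. torchao tensor-layout info recorded at load time
            ...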

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 39 additions & 55 deletions
@@ -14,7 +14,6 @@
 import importlib
 import re
 import types
-from collections import defaultdict
 from typing import TYPE_CHECKING, Optional, Union

 from packaging import version

@@ -38,7 +37,7 @@
 if is_torchao_available():
     import torchao

-    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
         from torchao.prototype.safetensors.safetensors_support import (
             flatten_tensor_state_dict,
             unflatten_tensor_state_dict,

@@ -87,6 +86,9 @@ def _linear_extra_repr(self):
 SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
     torchao.quantization.Float8WeightOnlyConfig,
     torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
+    torchao.quantization.Int4WeightOnlyConfig,
+    torchao.quantization.IntxWeightOnlyConfig,
+    torchao.quantization.Int8DynamicActivationIntxWeightConfig,
 ]

 TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
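
Note: three int4/intx configs join the float8 ones as safetensors-serializable. A hedged usage sketch (placeholder model id and output dir; assumes torchao >= 0.15.0 and a transformers build containing this commit):

import torch
from torchao.quantization import Int4WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
)
model.save_pretrained("opt-125m-int4", safe_serialization=True)  # now allowed for int4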
@@ -104,20 +106,6 @@ class TorchAoHfQuantizer(HfQuantizer):
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)

-        if isinstance(self.quantization_config.quant_type, str):
-            is_int_4 = "int4" in self.quantization_config.quant_type
-        else:
-            config_name = self.quantization_config.quant_type.__class__.__name__
-            is_int_4 = fuzzy_match_size(config_name) == "4"
-
-        # TODO: better way to get the serialized key names? Hard to read from torchao codebase
-        if is_int_4:
-            self.weight_ao_keys = ["qdata", "scale", "zero_point"]
-        else:
-            self.weight_ao_keys = ["qdata", "scale"]
-        # Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this
-        self.full_ao_keys = self.weight_ao_keys + ["_data"]
-
     def validate_environment(self, *args, **kwargs):
         if not is_torchao_available():
             raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
@@ -168,11 +156,11 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool]
         the safetensors format.
         """
         if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
-            if TORCHAO_VERSION >= version.parse("0.14.0"):
+            if TORCHAO_VERSION >= version.parse("0.15.0"):
                 return flatten_tensor_state_dict(model.state_dict())
             else:
                 raise RuntimeError(
-                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+                    f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
                 )
         else:
             return None, {}
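
Note: `flatten_tensor_state_dict` appears to return the `(state_dict, metadata)` pair directly, which is why its result can be returned as-is here. A hedged round-trip sketch on a toy module (assumes torchao >= 0.15.0; the float8 conversion may also need a recent PyTorch):

import torch
from torch import nn
from torchao.quantization import Float8WeightOnlyConfig, quantize_
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)

model = nn.Sequential(nn.Linear(16, 16, dtype=torch.bfloat16))
quantize_(model, Float8WeightOnlyConfig())  # weights become torchao tensor subclasses
flat_state_dict, metadata = flatten_tensor_state_dict(model.state_dict())
# flat_state_dict holds plain tensors that safetensors can store; metadata records
# how to reassemble them into the original tensor subclasses:
restored = unflatten_tensor_state_dict(flat_state_dict, metadata)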
@@ -234,7 +222,7 @@ def _process_model_before_weight_loading(
         return

     def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
+        return [k for k in unexpected_keys if "_weight_" not in k]

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         if self.quantization_config.quant_type == "autoquant":

@@ -243,7 +231,7 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
         # check if the param_name is not in self.modules_to_not_convert
         if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
             return False
-        elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
+        elif "_weight_" in param_name:
             return True
         else:
             # we only quantize the weight of nn.Linear and nn.Embedding
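
Note: the hard-coded `:qdata`/`:scale`/`:zero_point`/`:_data` suffix lists are gone; flattened torchao entries are now recognized purely by the `_weight_` substring in the key. A toy illustration (the exact suffix names below are assumptions for illustration, not taken from torchao):

keys = [
    "model.layers.0.q_proj._weight_qdata",  # hypothetical flattened entry
    "model.layers.0.q_proj._weight_scale",  # hypothetical flattened entry
    "model.layers.0.q_proj.bias",
]
quantized_entries = [k for k in keys if "_weight_" in k]    # the first two
regular_entries = [k for k in keys if "_weight_" not in k]  # just the bias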
@@ -267,42 +255,12 @@ def create_quantized_param(
         """
         from torchao.quantization import quantize_

-        full_name = param_name
-        # Those are the pre quantized weights
-        if ":" in param_name:
-            param_name = param_name.rsplit(":", 1)[0]
         module, tensor_name = get_module_from_name(model, param_name)
-
         if self.pre_quantized:
-            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-            # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
-            is_unsafe_serialization = ":" not in full_name
-            if tensor_name == "bias" or is_unsafe_serialization:
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    param_value.to(target_device), requires_grad=param_value.requires_grad
-                )
-                return
-            # Sanity check for the new serialization format
-            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
-
-            # Save the states for later quantization when they are all gathered
-            if not hasattr(self, "ao_params"):
-                self.ao_params = defaultdict(dict)
-            self.ao_params[param_name].update({full_name: param_value})
-
-            # We are ready for quantization in this case (we retrieved all the needed keys)
-            if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
-                new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
-                # Set it
-                module._parameters[tensor_name] = torch.nn.Parameter(
-                    new_param.to(target_device), requires_grad=new_param.requires_grad
-                )
-
-                # Free memory
-                del self.ao_params[param_name]
+            module._parameters[tensor_name] = torch.nn.Parameter(
+                param_value.to(target_device), requires_grad=param_value.requires_grad
+            )

-            # Add repr to the module
             if isinstance(module, nn.Linear):
                 module.extra_repr = types.MethodType(_linear_extra_repr, module)
             else:
@@ -368,6 +326,32 @@ def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpo

     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
+        if TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.metadata):
+            updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), self.metadata)
+
+            weights_to_register = set(updated_state_dict.keys())
+
+            for name, param in list(model.named_parameters()):
+                module_fqn, weight_name = name.rsplit(".", 1)
+                module = model.get_submodule(module_fqn)
+                weight = getattr(module, weight_name)
+
+                device = weight.device
+                requires_grad = weight.requires_grad
+
+                if "_weight_" in weight_name:
+                    delattr(module, weight_name)
+
+                if name in weights_to_register:
+                    new_param_value = updated_state_dict[name]
+                    new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad)
+                    module.register_parameter(weight_name, new_param)
+
+                    weights_to_register.remove(name)
+
+            model.load_state_dict(updated_state_dict, strict=False)
+            return
+
         if self.quantization_config.quant_type == "autoquant":
             from torchao import autoquant
             from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST
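
Note: reassembly now happens once, after all weights are loaded: flat placeholders (spotted by `_weight_` in the attribute name) are dropped, and the unflattened tensors are re-registered under the original names. A self-contained toy of the lookup-and-reregister pattern used above:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4))
name = "0.weight"                              # fully qualified parameter name
module_fqn, weight_name = name.rsplit(".", 1)  # -> "0", "weight"
module = model.get_submodule(module_fqn)
old = getattr(module, weight_name)
# Re-register a replacement tensor under the same attribute name:
new_param = nn.Parameter(torch.zeros_like(old), requires_grad=old.requires_grad)
module.register_parameter(weight_name, new_param)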
@@ -386,11 +370,11 @@ def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
             _is_torchao_serializable = type(
                 self.quantization_config.quant_type
-            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
+            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.15.0")
             if not _is_torchao_serializable:
                 logger.warning(
                     f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
-                    and torchao version >= 0.14.0, please set `safe_serialization` to False for \
+                    and torchao version >= 0.15.0, please set `safe_serialization` to False for \
                     {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
                 )
             return _is_torchao_serializable
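
Note: end to end, the save path emits flat tensors plus a metadata header, and the load path reverses it. A hedged round-trip sketch (placeholder model id and output dir; assumes torchao >= 0.15.0 and a transformers build containing this commit):

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

config = TorchAoConfig(quant_type=Float8DynamicActivationFloat8WeightConfig())
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16, quantization_config=config
)
model.save_pretrained("opt-125m-fp8", safe_serialization=True)
# On reload, create_quantized_param assigns the flat tensors as-is and
# _process_model_after_weight_loading reassembles them from self.metadata.
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-fp8", torch_dtype=torch.bfloat16)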

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 12 additions & 2 deletions
@@ -552,7 +552,6 @@ def tearDown(self):
     def test_original_model_expected_output(self):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):

@@ -573,11 +572,12 @@ def test_serialization_expected_output(self):


 @require_torchao
-@require_torchao_version_greater_or_equal("0.14.0")
+@require_torchao_version_greater_or_equal("0.15.0")
 class TorchAoSafeSerializationTest(TorchAoSerializationTest):
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"

@@ -596,6 +596,16 @@ def tearDown(self):
             "What are we having for dinner?\n\nJess: (smiling) I",
         ),
         (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+        (Int4WeightOnlyConfig(), "What are we having for dinner?"),
+        (
+            Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"),
+            "What are we having for dinner?\nRed, white, and green beans,",
+        ),
+        (
+            torchao.quantization.Int8DynamicActivationIntxWeightConfig(),
+            "What are we having for dinner?\n\nJessica: (smiling)",
+        ),
+        (torchao.quantization.IntxWeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
     ]
     if is_torchao_available()
     else []
