66 changes: 28 additions & 38 deletions src/transformers/integrations/torchao.py
@@ -32,7 +32,7 @@

if is_torchao_available():
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
- if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
unflatten_tensor_state_dict,
)
@@ -215,56 +215,46 @@ def convert(
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
- if isinstance(self.hf_quantizer.quantization_config.quant_type, str):
-     is_int_4 = "int4" in self.hf_quantizer.quantization_config.quant_type
- else:
-     config_name = self.hf_quantizer.quantization_config.quant_type.__class__.__name__
-     is_int_4 = fuzzy_match_size(config_name) == "4"

- # Simple case if we gather layermsnorm weights, we can just return the value since they are not quantized
- if "weight:_data" in input_dict.keys():
-     value = (
-         input_dict["weight:_data"][0]
-         if isinstance(input_dict["weight:_data"], list)
-         else input_dict["weight:_data"]
-     )
-     return {full_layer_name: value}
-
- is_unsafe_serialization = ":" not in list(input_dict.keys())[0]
+ """
+ Consolidates tensor subclass components before reconstructing the object
+
+ For example:
+ input_dict: {
+     "_weight_qdata": torch.Tensor,
+     "_weight_scale": torch.Tensor,
+ }
+ full_layer_name: "model.layers.0.self_attn.k_proj"
+
Comment on lines +227 to +228 (Contributor):
Suggested change:
- full_layer_name: "model.layers.0.self_attn.k_proj"
+ full_layer_name: "model.layers.0.self_attn.k_proj.weight"
+ Given this, we reconstruct a Float8Tensor instance using the qdata and scale
+ and return it as a dictionary with the full_layer_name as the key and the recovered
+ Float8Tensor instance as the value.
+ """
+ is_unsafe_serialization = "_weight_" not in list(input_dict.keys())[0]

param_data = {}
+ layer_name = ".".join(full_layer_name.split(".")[:-1])
if is_unsafe_serialization:
if isinstance(input_dict["weight"], list):
weight = input_dict["weight"][0]
else:
weight = input_dict["weight"]
else:
-     if isinstance(input_dict["weight:qdata"], list):
-         param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"][0]
-     else:
-         param_data[f"{full_layer_name}:qdata"] = input_dict["weight:qdata"]
+     for suffix in input_dict.keys():
+         assert len(input_dict[suffix]) == 1
Comment (Contributor):
let's do an if/else and raise an error to follow the same pattern we use in transformers, but I'm not sure it's necessary, since if we have the suffix in input_dict it means we collected at least one tensor for that suffix

+         param_data[f"{layer_name}.{suffix}"] = input_dict[suffix][0]

-     if isinstance(input_dict["weight:scale"], list):
-         param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"][0]
-     else:
-         param_data[f"{full_layer_name}:scale"] = input_dict["weight:scale"]
-
-     if is_int_4:
-         if isinstance(input_dict["weight:zero_point"], list):
-             param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"][0]
-         else:
-             param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"]

- # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
- # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
+ # If it's unsafe-serialized (i.e. not safetensors), no need for anything
if is_unsafe_serialization:
return {full_layer_name: weight}
# Sanity check for the new serialization format
- elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
-     # print("metadata", self.hf_quantizer.metadata)
-     raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+ elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
+     raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")

- new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
+ unflattened_state_dict, leftover_state_dict = unflatten_tensor_state_dict(
+     param_data, self.hf_quantizer.metadata
+ )
+ assert not leftover_state_dict  # there should be no unprocessed tensors
+ new_param = unflattened_state_dict[full_layer_name]

module, _ = get_module_from_name(model, full_layer_name)
# Add repr to the module
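Reviewer note: to make the flattened format concrete, here is a minimal standalone sketch of the consolidation step above (the layer name, shapes, and dtype are invented for illustration; the real components and metadata come from torchao's flatten_tensor_state_dict):

import torch

# Hypothetical components gathered for one Float8-quantized linear layer.
# Under torchao >= 0.15.0 safetensors serialization, each tensor-subclass
# attribute is stored under a "_weight_<attr>" key.
full_layer_name = "model.layers.0.self_attn.k_proj.weight"
input_dict = {
    "_weight_qdata": [torch.zeros(64, 64).to(torch.float8_e4m3fn)],
    "_weight_scale": [torch.ones(64, 1)],
}

# Mirror of the consolidation loop: re-qualify each component under the
# module path so unflatten_tensor_state_dict can rebuild the Float8Tensor.
layer_name = ".".join(full_layer_name.split(".")[:-1])
param_data = {f"{layer_name}.{suffix}": tensors[0] for suffix, tensors in input_dict.items()}
print(sorted(param_data))
# ['model.layers.0.self_attn.k_proj._weight_qdata',
#  'model.layers.0.self_attn.k_proj._weight_scale']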
1 change: 0 additions & 1 deletion src/transformers/modeling_utils.py
@@ -2104,7 +2104,6 @@ def set_decoder(self, decoder):
possible_module_names = ["language_model", "text_model", "decoder"]
for name in possible_module_names:
if hasattr(self, name):
- print(name)
setattr(self, name, decoder)
Comment on lines -2107 to 2108 (Contributor):
Thanks for catching this

return

49 changes: 21 additions & 28 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -40,9 +40,7 @@
import torch.nn as nn

if is_torchao_available():
- import torchao
-
- if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
@@ -88,11 +86,6 @@ def _linear_extra_repr(self):


if is_torchao_available():
- SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
-     torchao.quantization.Float8WeightOnlyConfig,
-     torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
- ]
-
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))


@@ -171,12 +164,12 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool | None = F
If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
the safetensors format.
"""
- if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
-     if TORCHAO_VERSION >= version.parse("0.14.0"):
+ if safe_serialization:
+     if TORCHAO_VERSION >= version.parse("0.15.0"):
return flatten_tensor_state_dict(model.state_dict())
else:
raise RuntimeError(
-         f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+         f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
)
else:
return None, {}
@@ -241,6 +234,8 @@ def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]
return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]

def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
+ if "_weight_" in param_name:
+     return True
if self.pre_quantized:
return False
if self.quantization_config.quant_type == "autoquant":
@@ -249,8 +244,6 @@
# check if the param_name is not in self.modules_to_not_convert
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
return False
- elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
Comment (Contributor):
for this one, maybe just change it to f"_{x}" for x in self.full_ao_keys to be safe; at least with this we are confident it's correct, if there is no better general way to detect safetensors

Comment (Contributor, Author):
I think this achieves the same thing as L237; we only need one check

-     return True

# we only quantize the weight of nn.Linear and nn.Embedding
module, tensor_name = get_module_from_name(model, param_name)
@@ -306,8 +299,8 @@ def create_quantized_param(
)
return
# Sanity check for the new serialization format
- elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
-     raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+ elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.metadata)):
+     raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.15.0` installed")

# Save the states for later quantization when they are all gathered
if not hasattr(self, "ao_params"):
@@ -452,13 +445,10 @@ def _process_model_after_weight_loading(self, model, **kwargs):

def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
-     _is_torchao_serializable = type(
-         self.quantization_config.quant_type
-     ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
-     if not _is_torchao_serializable:
+     _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
+     if not TORCHAO_VERSION >= version.parse("0.15.0"):
logger.warning(
-             f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
-             and torchao version >= 0.14.0, please set `safe_serialization` to False for \
+             f"torchao quantized model only supports safe serialization for torchao version >= 0.15.0, please set `safe_serialization` to False for \
{type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
)
return _is_torchao_serializable
@@ -548,15 +538,18 @@ def get_weight_conversions(self):
if self.pre_quantized:
return [
WeightConverter(
-     source_patterns=["weight:qdata", "weight:scale", "weight:zero_point"],
-     target_patterns="weight",
-     operations=[TorchAoDeserialize(self)],
- ),
- WeightConverter(
-     source_patterns=["weight:_data"],
+     # TODO: incr flexibility by generalizing the source patterns to match the format of "_weight_"
+     # note that the matching logic is greedy, so for ex, if _weight_scale is before _weight_scale_and_zero in this list, it will match _weight_scale always (this is incorrect)
+     # thus, the order of source_patterns is intentional
+     source_patterns=[
+         "_weight_qdata",
+         "_weight_scale_and_zero",
+         "_weight_scale",
+         "_weight_zero_point",
+         "_weight_act_pre_scale",
+     ],
Comment on lines +539 to +548 (Contributor):
Yes it is greedy, but to match for example _weight_scale and not _weight_scale_and_zero you can just do something like _weight_scale$, but ordering the keys works as well 👍

target_patterns="weight",
operations=[TorchAoDeserialize(self)],
),
- # used for unsafe serialization
]
return []
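To make the greedy-matching concern concrete, here is a tiny self-contained sketch (this is not transformers' actual matcher, only an illustration of first-hit substring matching and the $-anchored alternative from the comment above):

import re

key = "model.layers.0.self_attn.k_proj._weight_scale_and_zero"

def first_match(key, patterns):
    # Greedy first-hit: return the first pattern that occurs in the key.
    for pattern in patterns:
        if pattern in key:
            return pattern
    return None

# Longest-first ordering (as in the PR) resolves the key correctly...
assert first_match(key, ["_weight_scale_and_zero", "_weight_scale"]) == "_weight_scale_and_zero"
# ...while the reversed order would mis-match on the shorter substring.
assert first_match(key, ["_weight_scale", "_weight_scale_and_zero"]) == "_weight_scale"

# The reviewer's alternative: anchoring with $ keeps "_weight_scale" from
# shadowing "_weight_scale_and_zero" regardless of list order.
assert re.search(r"_weight_scale$", key) is None
assert re.search(r"_weight_scale_and_zero$", key) is not None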
13 changes: 12 additions & 1 deletion tests/quantization/torchao_integration/test_torchao.py
@@ -725,6 +725,7 @@ def check_serialization_expected_output(self, device, expected_output, safe_seri
dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
+
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
tmpdirname, dtype=dtype, device_map=device, torch_dtype=dtype, use_safetensors=safe_serialization
)
@@ -738,7 +739,7 @@ def test_serialization_expected_output(self):


@require_torchao
- @require_torchao_version_greater_or_equal("0.14.0")
+ @require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSafeSerializationTest(TorchAoSerializationTest):
# called only once for all test in this class
@classmethod
@@ -763,6 +764,16 @@ def tearDown(self):
"What are we having for dinner?\n\nJess: (smiling) I",
),
(torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+ (Int4WeightOnlyConfig(), "What are we having for dinner?"),
+ (
+     Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"),
+     "What are we having for dinner?\nRed, white, and green beans,",
+ ),
+ (
+     torchao.quantization.Int8DynamicActivationIntxWeightConfig(),
+     "What are we having for dinner?\n\nJessica: (smiling)",
+ ),
+ (torchao.quantization.IntxWeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
]
if is_torchao_available()
else []
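For context, these serialization tests exercise a flow roughly like the following (a sketch assuming torchao >= 0.15.0; the model id is a placeholder, not the one the test suite uses):

import tempfile

from torchao.quantization import Float8WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig(quant_type=Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id
    quantization_config=quant_config,
)
with tempfile.TemporaryDirectory() as tmpdirname:
    # With safe_serialization=True the quantizer flattens tensor subclasses
    # into plain tensors plus metadata so they fit the safetensors format.
    model.save_pretrained(tmpdirname, safe_serialization=True)
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, use_safetensors=True)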