
Commit 661e342

committed
update torchao safetensors
1 parent 688d59c commit 661e342

File tree

5 files changed: +29 −125 lines changed


src/transformers/core_model_loading.py

Lines changed: 8 additions & 26 deletions
@@ -48,24 +48,16 @@
 logger = logging.get_logger(__name__)


-def extract_concrete_key(key: str, pattern: str, pattern_regex: re.Pattern) -> str:
+def extract_concrete_key_from_regex_pattern(key: str, pattern: str, pattern_regex: re.Pattern) -> str:
     match = pattern_regex.match(key)
     if not match:
         return pattern

     groups = match.groups()
-    wildcard_count = pattern.count("*")
-
-    if wildcard_count == 0:
-        return pattern
-    elif wildcard_count == 1:
-        return pattern.replace("*", groups[0])
-    else:
-        parts = pattern.split("*")
-        result = "*".join(parts[1:])
-        for i, captured in enumerate(groups[1:], start=0):
-            result = result.replace("*", str(captured), 1)
-        return result
+    parts = pattern.split("*")
+    result = "*".join(parts[1:])
+    result = result.replace("*", groups[1], 1)
+    return result


 def build_glob_alternation(
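Side note on the first hunk: the rewritten helper always drops everything in the glob before its first wildcard and fills the next wildcard from the second capture group. Below is a minimal, self-contained sketch of that behaviour; the regex and key are made-up stand-ins (the real glob-to-regex compilation lives elsewhere in core_model_loading.py), and the function body is copied from the hunk above.

import re

def extract_concrete_key_from_regex_pattern(key, pattern, pattern_regex):
    # Body copied from the hunk above (type hints omitted for brevity).
    match = pattern_regex.match(key)
    if not match:
        return pattern
    groups = match.groups()
    parts = pattern.split("*")
    result = "*".join(parts[1:])
    result = result.replace("*", groups[1], 1)
    return result

# Hypothetical inputs mirroring the torchao source glob "*_weight_*":
pattern = "*_weight_*"
pattern_regex = re.compile(r"(.*)_weight_(.*)")  # assumed stand-in for the compiled glob
key = "model.layers.0.self_attn.q_proj._weight_qdata"
print(extract_concrete_key_from_regex_pattern(key, pattern, pattern_regex))  # -> _weight_qdata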
@@ -469,7 +461,6 @@ def convert(
         )

         collected_tensors = self.collected_tensors
-
         for op in self.operations:
             with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op):
                 collected_tensors = op.convert(

@@ -552,7 +543,7 @@ def log_to_misc(
     try:
         yield
     except Exception as e:
-        print(f"error: {e}")
+
         def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
             if curr_op is None:
                 return None

@@ -567,7 +558,6 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         if isinstance(extras, tuple) and len(extras) == 2:
             values, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-            # print(values)
             misc[first_target_key] = (
                 f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
             )

@@ -616,7 +606,7 @@ def set_param_for_module(
             param_value = param_value.to_local()
         if param_name not in module_obj._buffers:
             param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
-        print(f"removing {target_name} from missing keys")
+
         # Remove from missing keys (it's either mismatched, or all good)
         missing_keys.discard(target_name)
         if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:

@@ -789,9 +779,6 @@ def convert_and_load_state_dict_in_model(
    ```

    """
-    print('in convert and load state dict')
-    print(f"model state_dict keys: {model.state_dict().keys()}")
-    print(f"state_dict keys: {state_dict}")
    prefix = model.base_model_prefix
    tp_plan = tp_plan or {}
    device_map = device_map or {"": "cpu"}

@@ -833,7 +820,6 @@

            # 2. finally, collect the tensor into the proper converter
            if renamed_key in missing_keys:
-                print(f"orignal key in state_dict: {original_key}, renamed_key: {renamed_key}, matched_pattern: {matched_pattern}")
                empty_param = meta_model_state_dict.get(renamed_key)
                # If we enter here, we have a WeightConverter operation to perform
                if source_pattern is not None:

@@ -863,7 +849,6 @@
                if matched_dtype_pattern is not None:
                    _dtype = dtype_plan[matched_dtype_pattern.group()]
                elif empty_param is not None and empty_param.dtype != _dtype:
-                    print("using empty param")
                    _dtype = empty_param.dtype  # usually correct when initializing

                # 4. Handle TP sharding or device_map placement -> scheduled materialization

@@ -891,7 +876,7 @@
                    # If disk, we need to materialize on cpu first
                    param_device = "cpu" if param_device == "disk" else param_device
                    future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
-                    print("adding tensor")
+
                    mapping.add_tensor(renamed_key, original_key, source_pattern, future)
            elif source_pattern is not None:  # add all target keys as unexpected
                mapping = pattern_to_converter[source_pattern]

@@ -903,7 +888,6 @@
    total_entries = len(param_name_to_load)
    with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
        for first_param_name, mapping in param_name_to_load.items():
-            print(f"first_param_name: {first_param_name}")
            pbar.update(1)
            pbar.set_postfix({"Materializing param": first_param_name})
            pbar.refresh()

@@ -917,7 +901,6 @@
                    misc=misc,
                )
                for target_name, param in realized_value.items():
-                    print(f"target_name: {target_name}")
                    param = param[0] if isinstance(param, list) else param
                    device_match = device_map_regex.match(target_name)
                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")

@@ -942,7 +925,6 @@
                # Cleanup the tensors
                mapping.reset()
            except SkipLayer:
-                print(f"skipping layer {first_param_name}")
                continue

    # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)

src/transformers/integrations/torchao.py

Lines changed: 2 additions & 48 deletions
@@ -215,29 +215,10 @@ def convert(
         missing_keys=None,
         **kwargs,
     ) -> dict[str, torch.Tensor]:
-        print(f"in deserialize: {input_dict.keys(), full_layer_name}")
-        if isinstance(self.hf_quantizer.quantization_config.quant_type, str):
-            is_int_4 = "int4" in self.hf_quantizer.quantization_config.quant_type
-        else:
-            config_name = self.hf_quantizer.quantization_config.quant_type.__class__.__name__
-            is_int_4 = fuzzy_match_size(config_name) == "4"
-
-        # Simple case if we gather layermsnorm weights, we can just return the value since they are not quantized
-        # if "._weight__data" in input_dict.keys():
-        #     value = (
-        #         input_dict["_weight__data"][0]
-        #         if isinstance(input_dict["._weight__data"], list)
-        #         else input_dict["_weight__data"]
-        #     )
-        #     return {full_layer_name: value}
-
-        print(list(input_dict.keys())[0])
         is_unsafe_serialization = "_weight_" not in list(input_dict.keys())[0]

         param_data = {}
         layer_name = '.'.join(full_layer_name.split(".")[:-1])
-        print(f"layer_name: {layer_name}")
-        print(is_unsafe_serialization)
         if is_unsafe_serialization:
             if isinstance(input_dict["weight"], list):
                 weight = input_dict["weight"][0]
@@ -250,41 +231,14 @@ def convert(
             else:
                 param_data[f"{layer_name}.{suffix}"] = input_dict[suffix]

-        # print("processing qdata")
-        # if isinstance(input_dict["_weight_qdata"], list):
-        #     param_data[f"{layer_name}._weight_qdata"] = input_dict["_weight_qdata"][0]
-        # else:
-        #     param_data[f"{layer_name}._weight_qdata"] = input_dict["_weight_qdata"]
-
-        # print("processing scale")
-        # if isinstance(input_dict["_weight_scale"], list):
-        #     param_data[f"{layer_name}._weight_scale"] = input_dict["_weight_scale"][0]
-        # else:
-        #     param_data[f"{layer_name}._weight_scale"] = input_dict["_weight_scale"]
-
-        # if is_int_4:
-        #     if isinstance(input_dict["weight:zero_point"], list):
-        #         param_data[f"{layer_name}:zero_point"] = input_dict["weight:zero_point"][0]
-        #     else:
-        #         param_data[f"{layer_name}:zero_point"] = input_dict["weight:zero_point"]
-
-        # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-        # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
+        # If it's unsafe-serialized (i.e. not safetensors), no need for anything
         if is_unsafe_serialization:
-            print("returning")
             return {full_layer_name: weight}
         # Sanity check for the new serialization format
-        elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
-            # print("metadata", self.hf_quantizer.metadata)
-            print("here")
-            print(is_metadata_torchao(self.hf_quantizer.metadata))
+        elif not (TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
             raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

-        print("calling unflatten")
-        print(param_data)
-        print(self.hf_quantizer.metadata)
         unflattened_state_dict, _ = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)
-        print(f"unflattened_state_dict: {unflattened_state_dict}")
         new_param = unflattened_state_dict[full_layer_name]

         module, _ = get_module_from_name(model, full_layer_name)
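To make the two code paths above easier to follow, here is a hedged, runnable illustration of the serialization check. The dictionary keys are hypothetical examples of the two on-disk formats (flattened torchao safetensors versus plain torch.save weights), not keys taken from a real checkpoint.

import torch

# New safetensors format: one logical parameter arrives as several flattened tensors
# sharing the "_weight_" marker; these are later regrouped by unflatten_tensor_state_dict.
safe_dict = {"_weight_qdata": torch.zeros(4, dtype=torch.int8), "_weight_scale": torch.ones(1)}
# Unsafe (torch.save) format: the parameter arrives under its plain name.
unsafe_dict = {"weight": torch.zeros(4)}

print("_weight_" not in list(safe_dict.keys())[0])    # False -> metadata + unflatten path
print("_weight_" not in list(unsafe_dict.keys())[0])  # True  -> return the raw weight directly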

src/transformers/modeling_utils.py

Lines changed: 14 additions & 20 deletions
@@ -2104,7 +2104,6 @@ def set_decoder(self, decoder):
         possible_module_names = ["language_model", "text_model", "decoder"]
         for name in possible_module_names:
             if hasattr(self, name):
-                print(name)
                 setattr(self, name, decoder)
                 return

@@ -3111,8 +3110,6 @@ def save_pretrained(
         metadata = {}
         if hf_quantizer is not None:
             state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
-            print("saving")
-            print(state_dict)
             metadata["format"] = "pt"

         # Only save the model itself if we are using distributed training
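For orientation only, a hedged sketch of the save path this hunk touches; the checkpoint id and output directory are placeholders. When an hf_quantizer is attached, it supplies both the serializable state dict and the safetensors metadata, and save_pretrained then just stamps the "pt" format entry.

from transformers import AutoModelForCausalLM

# Placeholder repo id: any model quantized with a torchao config and saved in this format.
model = AutoModelForCausalLM.from_pretrained("org/some-torchao-quantized-model")
model.save_pretrained("quantized-out", safe_serialization=True)  # metadata comes from the quantizer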
@@ -3949,6 +3946,19 @@ def from_pretrained(

         is_quantized = hf_quantizer is not None

+        weight_conversions: Optional[list[WeightConverter | WeightRenaming]] = None
+        model_type = getattr(config, "model_type", None)
+        if model_type is not None:
+            weight_conversions = get_checkpoint_conversion_mapping(model_type)
+        if weight_conversions is None:
+            weight_conversions = get_checkpoint_conversion_mapping("legacy")
+        if key_mapping is not None:
+            weight_conversions.extend(
+                [WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()]
+            )
+        if hf_quantizer is not None:
+            weight_conversions.extend(hf_quantizer.get_weight_conversions())
+
         if gguf_file:
             from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -3988,19 +3998,6 @@ def from_pretrained(
39883998
use_kernels=use_kernels,
39893999
)
39904000

3991-
weight_conversions: Optional[list[WeightConverter | WeightRenaming]] = None
3992-
model_type = getattr(config, "model_type", None)
3993-
if model_type is not None:
3994-
weight_conversions = get_checkpoint_conversion_mapping(model_type)
3995-
if weight_conversions is None:
3996-
weight_conversions = get_checkpoint_conversion_mapping("legacy")
3997-
if key_mapping is not None:
3998-
weight_conversions.extend(
3999-
[WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()]
4000-
)
4001-
if hf_quantizer is not None:
4002-
weight_conversions.extend(hf_quantizer.get_weight_conversions())
4003-
40044001
if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
40054002
model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
40064003

@@ -4136,13 +4133,11 @@ def _load_pretrained_model(
         # Checkpoints are safetensors
         if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
             merged_state_dict = {}
-            i = 0
             for file in checkpoint_files:
                 file_pointer = safe_open(file, framework="pt", device="cpu")
                 all_pointer.add(file_pointer)
                 for k in file_pointer.keys():
                     merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
-                i += 1
         # User passed an explicit state_dict
         elif state_dict is not None:
             merged_state_dict = state_dict
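A minimal sketch of the lazy-slice idea the loop above relies on, assuming a placeholder shard name and tensor key: safe_open hands back cheap slice handles, and a tensor is only materialized when its handle is actually indexed.

from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    lazy = {k: f.get_slice(k) for k in f.keys()}   # handles only, nothing read yet
    tensor = lazy["model.embed_tokens.weight"][:]  # indexing materializes just this tensor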
@@ -4466,14 +4461,13 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None:
         self.initialize_weights()

     def _adjust_missing_and_unexpected_keys(
-        self, missing_keys: set[str], unexpected_keys: set[str],
+        self, missing_keys: set[str], unexpected_keys: set[str]
     ) -> tuple[set[str], set[str]]:
         """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
         raising unneeded warnings/errors.
         Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
         be tied anyway.
         """
-
         # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
         # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
         # `_keys_to_ignore_on_load_unexpected` as it touches many models -> we add it manually to the existing patterns
src/transformers/quantizers/quantizer_torchao.py

Lines changed: 1 addition & 10 deletions
@@ -297,7 +297,7 @@ def create_quantized_param(
         First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
         """
         from torchao.quantization import quantize_
-        print("in create quantized param")
+
         full_name = param_name
         # Those are the pre quantized weights
         if ":" in param_name:
@@ -554,20 +554,11 @@ def get_weight_conversions(self):
         from ..integrations.torchao import TorchAoDeserialize

         if self.pre_quantized:
-            print("pre_quantized")
-            print(self.metadata)
             return [
                 WeightConverter(
-                    # source_keys=["_weight_qdata", "_weight_scale", "_weight_zero_point"],
                     source_keys=["*_weight_*"],
                     target_keys="*weight",
                     operations=[TorchAoDeserialize(self)],
                 ),
-                # WeightConverter(
-                #     source_keys=["._weight__data"],
-                #     target_keys=".weight",
-                #     operations=[TorchAoDeserialize(self)],
-                # ),
-                # used for unsafe serialization
             ]
         return []
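For readers new to the converter globs, a hedged illustration of what the single remaining WeightConverter is meant to gather; the layer name is made up, and the exact flattened suffixes depend on the torchao config.

# Flattened torchao safetensors keys for one quantized Linear (illustrative):
checkpoint_keys = [
    "model.layers.0.self_attn.q_proj._weight_qdata",
    "model.layers.0.self_attn.q_proj._weight_scale",
    "model.layers.0.self_attn.q_proj._weight_zero_point",  # int4-style configs only
]
# source_keys=["*_weight_*"] collects all of them into one input_dict, and
# TorchAoDeserialize is expected to rebuild the single parameter matching "*weight":
target_key = "model.layers.0.self_attn.q_proj.weight"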

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 4 additions & 21 deletions
@@ -712,11 +712,11 @@ def tearDown(self):
         backend_empty_cache(torch_device)
         gc.collect()

-    # def test_original_model_expected_output(self):
-    #     input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
-    #     output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+    def test_original_model_expected_output(self):
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

-    #     self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
         """
@@ -725,26 +725,9 @@ def check_serialization_expected_output(self, device, expected_output, safe_seri
         dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
-
-            original_state_dict = self.quantized_model.state_dict()
-            print(original_state_dict)
-
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
                 tmpdirname, dtype=dtype, device_map=device, torch_dtype=dtype, use_safetensors=safe_serialization
             )
-
-            loaded_state_dict = loaded_quantized_model.state_dict()
-            for key in original_state_dict:
-                if not hasattr(original_state_dict[key], "qdata"):
-                    print(torch.equal(original_state_dict[key], loaded_state_dict[key]))
-                    continue
-                print(original_state_dict[key].qdata)
-                print(loaded_state_dict[key].qdata)
-                if not torch.equal(original_state_dict[key].qdata, loaded_state_dict[key].qdata):
-                    print("not equal")
-                    print(f"key: {key}, {original_state_dict[key]}, {loaded_state_dict[key]}")
-                print("equal")
-
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)

             output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
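For context, the helper above is presumably driven by per-backend serialization tests along these lines; this is a hedged sketch only, since the actual callers are outside this hunk and their names may differ.

def test_serialization_safetensors(self):
    # Hypothetical caller: round-trip through save_pretrained / from_pretrained with
    # safetensors enabled, then compare generation against the expected output.
    self.check_serialization_expected_output(torch_device, self.EXPECTED_OUTPUT, safe_serialization=True)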
