Skip to content

Commit 688d59c

Browse files
committed
working
1 parent 945394e commit 688d59c

File tree

2 files changed

+1
-12
lines changed

2 files changed

+1
-12
lines changed

src/transformers/integrations/torchao.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def convert(
282282

283283
print("calling unflatten")
284284
print(param_data)
285+
print(self.hf_quantizer.metadata)
285286
unflattened_state_dict, _ = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)
286287
print(f"unflattened_state_dict: {unflattened_state_dict}")
287288
new_param = unflattened_state_dict[full_layer_name]

src/transformers/modeling_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4012,8 +4012,6 @@ def from_pretrained(
40124012
if dtype_orig is not None:
40134013
torch.set_default_dtype(dtype_orig)
40144014

4015-
print("calling load pretrained model")
4016-
40174015
# Finalize model weight initialization
40184016
model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
40194017
model,
@@ -4140,7 +4138,6 @@ def _load_pretrained_model(
41404138
merged_state_dict = {}
41414139
i = 0
41424140
for file in checkpoint_files:
4143-
print(f"getting file {i}")
41444141
file_pointer = safe_open(file, framework="pt", device="cpu")
41454142
all_pointer.add(file_pointer)
41464143
for k in file_pointer.keys():
@@ -4177,15 +4174,6 @@ def _load_pretrained_model(
41774174
for k in all_pointer:
41784175
k.__exit__(None, None, None)
41794176

4180-
# from torchao.prototype.safetensors.safetensors_support import (
4181-
# unflatten_tensor_state_dict,
4182-
# )
4183-
# from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
4184-
# if is_metadata_torchao(hf_quantizer.metadata):
4185-
# print('calling unflatten', model.state_dict().keys(), hf_quantizer.metadata)
4186-
# unflattened_state_dict, _ = unflatten_tensor_state_dict(model.state_dict(), hf_quantizer.metadata)
4187-
# model.load_state_dict(unflattened_state_dict, strict=False)
4188-
41894177
# Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
41904178
model.mark_tied_weights_as_initialized()
41914179

0 commit comments

Comments
 (0)