@@ -4012,8 +4012,6 @@ def from_pretrained(
40124012 if dtype_orig is not None :
40134013 torch .set_default_dtype (dtype_orig )
40144014
4015- print ("calling load pretrained model" )
4016-
40174015 # Finalize model weight initialization
40184016 model , missing_keys , unexpected_keys , mismatched_keys , offload_index , error_msgs = cls ._load_pretrained_model (
40194017 model ,
@@ -4140,7 +4138,6 @@ def _load_pretrained_model(
41404138 merged_state_dict = {}
41414139 i = 0
41424140 for file in checkpoint_files :
4143- print (f"getting file { i } " )
41444141 file_pointer = safe_open (file , framework = "pt" , device = "cpu" )
41454142 all_pointer .add (file_pointer )
41464143 for k in file_pointer .keys ():
@@ -4177,15 +4174,6 @@ def _load_pretrained_model(
41774174 for k in all_pointer :
41784175 k .__exit__ (None , None , None )
41794176
4180- # from torchao.prototype.safetensors.safetensors_support import (
4181- # unflatten_tensor_state_dict,
4182- # )
4183- # from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
4184- # if is_metadata_torchao(hf_quantizer.metadata):
4185- # print('calling unflatten', model.state_dict().keys(), hf_quantizer.metadata)
4186- # unflattened_state_dict, _ = unflatten_tensor_state_dict(model.state_dict(), hf_quantizer.metadata)
4187- # model.load_state_dict(unflattened_state_dict, strict=False)
4188-
41894177 # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
41904178 model .mark_tied_weights_as_initialized ()
41914179
0 commit comments