System Info
- `transformers` version: 4.57.1
- Python version: 3.11
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
This example uses PaliGemma, but any model will do:

```python
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "google/paligemma-3b-mix-224",
    device_map=torch.device("mps", index=0),
)
```

If you specify an MPS device with `index=0`, model loading fails with:
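As a workaround (my observation, not officially documented behavior): passing the device as a plain string appears to avoid the problem, since only `torch.device` instances get collapsed to a bare index in the loading path shown below.

```python
import torch
from transformers import AutoModelForImageTextToText

# Workaround sketch, assuming string entries in device_map are passed
# through unchanged: no bare integer index ever reaches the weight loader.
model = AutoModelForImageTextToText.from_pretrained(
    "google/paligemma-3b-mix-224",
    device_map="mps",
)
```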
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=torch.device("mps", index=0))

File transformers/models/auto/auto_factory.py:604, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    602 if model_class.config_class == config.sub_configs.get("text_config", None):
    603     config = config.get_text_config()
--> 604 return model_class.from_pretrained(
    605     pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    606 )
    607 raise ValueError(
    608     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    609     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
    610 )

File transformers/modeling_utils.py:277, in restore_default_dtype.<locals>._wrapper(*args, **kwargs)
    275 old_dtype = torch.get_default_dtype()
    276 try:
--> 277     return func(*args, **kwargs)
    278 finally:
    279     torch.set_default_dtype(old_dtype)

File transformers/modeling_utils.py:5048, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
   5038 if dtype_orig is not None:
   5039     torch.set_default_dtype(dtype_orig)
   5041 (
   5042     model,
   5043     missing_keys,
   5044     unexpected_keys,
   5045     mismatched_keys,
   5046     offload_index,
   5047     error_msgs,
-> 5048 ) = cls._load_pretrained_model(
   5049     model,
   5050     state_dict,
   5051     checkpoint_files,
   5052     pretrained_model_name_or_path,
   5053     ignore_mismatched_sizes=ignore_mismatched_sizes,
   5054     sharded_metadata=sharded_metadata,
   5055     device_map=device_map,
   5056     disk_offload_folder=offload_folder,
   5057     dtype=dtype,
   5058     hf_quantizer=hf_quantizer,
   5059     keep_in_fp32_regex=keep_in_fp32_regex,
   5060     device_mesh=device_mesh,
   5061     key_mapping=key_mapping,
   5062     weights_only=weights_only,
   5063 )
   5064 # make sure token embedding weights are still tied if needed
   5065 model.tie_weights()

File transformers/modeling_utils.py:5468, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
   5465 args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
   5467 for args in args_list:
-> 5468     _error_msgs, disk_offload_index = load_shard_file(args)
   5469     error_msgs += _error_msgs
   5471 # Save offloaded index if needed

File transformers/modeling_utils.py:843, in load_shard_file(args)
    841 # Skip it with fsdp on ranks other than 0
    842 elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
--> 843     disk_offload_index = _load_state_dict_into_meta_model(
    844         model,
    845         state_dict,
    846         shard_file,
    847         reverse_key_renaming_mapping,
    848         device_map=device_map,
    849         disk_offload_folder=disk_offload_folder,
    850         disk_offload_index=disk_offload_index,
    851         hf_quantizer=hf_quantizer,
    852         keep_in_fp32_regex=keep_in_fp32_regex,
    853         device_mesh=device_mesh,
    854     )
    856 return error_msgs, disk_offload_index

File torch/utils/_contextlib.py:120, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    117 @functools.wraps(func)
    118 def decorate_context(*args, **kwargs):
    119     with ctx_factory():
--> 120         return func(*args, **kwargs)

File transformers/modeling_utils.py:748, in _load_state_dict_into_meta_model(model, state_dict, shard_file, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, hf_quantizer, keep_in_fp32_regex, device_mesh)
    740 hf_quantizer.create_quantized_param(
    741     model,
    742     param,
    (...)
    745     **sharding_kwargs,
    746 )
    747 else:
--> 748     param = param[...]
    749 if casting_dtype is not None:
    750     param = param.to(casting_dtype)

File torch/cuda/__init__.py:403, in _lazy_init()
    398     raise RuntimeError(
    399         "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
    400         "multiprocessing, you must use the 'spawn' start method"
    401     )
    402 if not hasattr(torch._C, "_cuda_getDeviceCount"):
--> 403     raise AssertionError("Torch not compiled with CUDA enabled")
    404 if _cudart is None:
    405     raise AssertionError(
    406         "libcudart functions unavailable. It looks like you have a broken build?"
    407     )

AssertionError: Torch not compiled with CUDA enabled
```
It appears `_load_state_dict_into_meta_model` does

```python
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
```

which effectively assumes any `torch.device` carrying an index is a CUDA device, and moves on.
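For context (my understanding, not verified against every PyTorch version): once only the index is kept, the device *type* is lost, and a bare integer device has historically been interpreted as a CUDA ordinal, which is why loading ends up in `torch.cuda._lazy_init()` on a CUDA-less macOS build. A minimal sketch of the failure mode:

```python
import torch

dev = torch.device("mps", index=0)
print(dev.index)  # 0 -- the "mps" type is gone once only the index is kept

# A bare int device has traditionally meant "CUDA ordinal", e.g.
# torch.device(0) -> device(type='cuda', index=0) under the classic
# semantics. On a build without CUDA, moving a tensor to it fails:
#     torch.ones(1).to(0)  # AssertionError: Torch not compiled with CUDA enabled
```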
Expected behavior
Loading should complete without error: an explicit `torch.device("mps", 0)` passed via `device_map` should be honored as an MPS device rather than reinterpreted as CUDA.
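One possible fix sketch (hypothetical, not a reviewed proposal; the helper name `normalize_device` is mine): only collapse a device to its bare ordinal when its type is actually `cuda`, and otherwise keep the full `torch.device` so the type survives the loading path.

```python
import torch

def normalize_device(entry):
    """Sketch of index extraction that preserves non-CUDA device types.

    `entry` is a value from `device_map`; the real logic lives in
    `_load_state_dict_into_meta_model` and may differ in detail.
    """
    if isinstance(entry, torch.device):
        # Only a CUDA device can safely be reduced to a bare ordinal;
        # for mps/cpu/xpu etc., keep the device object intact.
        return entry.index if entry.type == "cuda" else entry
    return entry

print(normalize_device(torch.device("cuda", 0)))  # 0
print(normalize_device(torch.device("mps", 0)))   # mps:0
print(normalize_device("mps"))                    # mps
```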