
from_pretrained will fail if device_map is torch.device("mps", index=0) #41908

@oceanusxiv


System Info

transformers version: 4.57.1
Python version: 3.11

Who can help?

@Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

This example uses PaliGemma, but any model will reproduce it.

import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "google/paligemma-3b-mix-224",
    device_map=torch.device("mps", index=0),
)

If you specify an mps device with index=0, model loading fails with:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=torch.device("mps", index=0))

File transformers/models/auto/auto_factory.py:604, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    602     if model_class.config_class == config.sub_configs.get("text_config", None):
    603         config = config.get_text_config()
--> 604     return model_class.from_pretrained(
    605         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    606     )
    607 raise ValueError(
    608     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    609     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
    610 )

File transformers/modeling_utils.py:277, in restore_default_dtype.<locals>._wrapper(*args, **kwargs)
    275 old_dtype = torch.get_default_dtype()
    276 try:
--> 277     return func(*args, **kwargs)
    278 finally:
    279     torch.set_default_dtype(old_dtype)

File transformers/modeling_utils.py:5048, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
   5038     if dtype_orig is not None:
   5039         torch.set_default_dtype(dtype_orig)
   5041     (
   5042         model,
   5043         missing_keys,
   5044         unexpected_keys,
   5045         mismatched_keys,
   5046         offload_index,
   5047         error_msgs,
-> 5048     ) = cls._load_pretrained_model(
   5049         model,
   5050         state_dict,
   5051         checkpoint_files,
   5052         pretrained_model_name_or_path,
   5053         ignore_mismatched_sizes=ignore_mismatched_sizes,
   5054         sharded_metadata=sharded_metadata,
   5055         device_map=device_map,
   5056         disk_offload_folder=offload_folder,
   5057         dtype=dtype,
   5058         hf_quantizer=hf_quantizer,
   5059         keep_in_fp32_regex=keep_in_fp32_regex,
   5060         device_mesh=device_mesh,
   5061         key_mapping=key_mapping,
   5062         weights_only=weights_only,
   5063     )
   5064 # make sure token embedding weights are still tied if needed
   5065 model.tie_weights()

File transformers/modeling_utils.py:5468, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
   5465         args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
   5467     for args in args_list:
-> 5468         _error_msgs, disk_offload_index = load_shard_file(args)
   5469         error_msgs += _error_msgs
   5471 # Save offloaded index if needed

File transformers/modeling_utils.py:843, in load_shard_file(args)
    841 # Skip it with fsdp on ranks other than 0
    842 elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
--> 843     disk_offload_index = _load_state_dict_into_meta_model(
    844         model,
    845         state_dict,
    846         shard_file,
    847         reverse_key_renaming_mapping,
    848         device_map=device_map,
    849         disk_offload_folder=disk_offload_folder,
    850         disk_offload_index=disk_offload_index,
    851         hf_quantizer=hf_quantizer,
    852         keep_in_fp32_regex=keep_in_fp32_regex,
    853         device_mesh=device_mesh,
    854     )
    856 return error_msgs, disk_offload_index

File torch/utils/_contextlib.py:120, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    117 @functools.wraps(func)
    118 def decorate_context(*args, **kwargs):
    119     with ctx_factory():
--> 120         return func(*args, **kwargs)

File transformers/modeling_utils.py:748, in _load_state_dict_into_meta_model(model, state_dict, shard_file, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, hf_quantizer, keep_in_fp32_regex, device_mesh)
    740         hf_quantizer.create_quantized_param(
    741             model,
    742             param,
   (...)
    745             **sharding_kwargs,
    746         )
    747 else:
--> 748     param = param[...]
    749     if casting_dtype is not None:
    750         param = param.to(casting_dtype)

File torch/cuda/__init__.py:403, in _lazy_init()
    398     raise RuntimeError(
    399         "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
    400         "multiprocessing, you must use the 'spawn' start method"
    401     )
    402 if not hasattr(torch._C, "_cuda_getDeviceCount"):
--> 403     raise AssertionError("Torch not compiled with CUDA enabled")
    404 if _cudart is None:
    405     raise AssertionError(
    406         "libcudart functions unavailable. It looks like you have a broken build?"
    407     )

AssertionError: Torch not compiled with CUDA enabled

It appears _load_state_dict_into_meta_model does

tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]

which reduces the device to its bare integer index and discards the device type. PyTorch interprets a bare int device as a CUDA ordinal, so any torch.device constructed with an explicit index is effectively treated as CUDA, and materializing parameters onto it triggers CUDA lazy initialization.
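
For illustration, a minimal sketch of the failure mode (variable names here are mine, not from transformers), assuming a PyTorch build compiled without CUDA:

import torch

device_map = {"": torch.device("mps", index=0)}
entry = device_map[""]

# The line above keeps only the integer index and drops the device type:
tensor_device = entry.index if isinstance(entry, torch.device) else entry
print(tensor_device)    # 0 -- a bare int
print(torch.device(0))  # cuda:0 -- ints are interpreted as CUDA ordinals

# Materializing a parameter onto a bare int device therefore lazily
# initializes CUDA, which raises on a CUDA-less build:
# torch.empty(1).to(tensor_device)  # AssertionError: Torch not compiled with CUDA enabled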

Expected behavior

This should complete without error, loading the model onto the requested mps device.
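
One possible direction for a fix (just a sketch, not a tested patch against transformers): only reduce the device to its index when its type is actually cuda, and keep the torch.device otherwise:

import torch

def resolve_tensor_device(entry):
    # Hypothetical helper; the name and placement are illustrative only.
    if isinstance(entry, torch.device):
        # Bare ints are only unambiguous as CUDA ordinals.
        return entry.index if entry.type == "cuda" else entry
    return entry

assert resolve_tensor_device(torch.device("cuda", 0)) == 0
assert resolve_tensor_device(torch.device("mps", 0)) == torch.device("mps", 0)
assert resolve_tensor_device("cpu") == "cpu"

Whether every downstream consumer accepts a torch.device where it previously received an int would still need checking.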
