Open
Labels
enhancement (New feature or request)
Description
🌟 Feature Description
macOS MPS device support in benchmarks
Motivation
It would be great to add support for the MPS (Metal Performance Shaders) device on macOS.
The benchmark code currently selects CUDA if it is available and otherwise falls back to CPU, so Macs with MPS acceleration always end up running on the CPU.
The current device configuration code is:
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
To enable MPS support, we could modify the code as follows:
USE_CUDA = torch.cuda.is_available() and GPU >= 0
USE_MPS = torch.backends.mps.is_available()
self.device = torch.device(f'cuda:{GPU}' if USE_CUDA else ('mps' if USE_MPS else 'cpu'))
Alternatively, should we let users set the device directly through a YAML parameter for more flexibility?
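To make the two options concrete, here is a rough sketch of how they could be combined: an optional device value from the config that overrides auto-detection when present. The function names `pick_device_name` and `detect_device`, the `requested` parameter, and the `"auto"` value are all hypothetical and not part of the existing code. The `getattr` guard is there because `torch.backends.mps` is missing in older PyTorch releases. As in the snippet above, MPS is chosen even when `GPU` is negative; that may or may not be the desired behavior.

```python
from typing import Optional


def pick_device_name(requested: Optional[str], gpu: int,
                     cuda_ok: bool, mps_ok: bool) -> str:
    """Return a device string such as 'cuda:0', 'mps', or 'cpu'.

    `requested` stands in for a hypothetical YAML value;
    None or 'auto' means auto-detect.
    """
    if requested not in (None, "auto"):
        # An explicit user choice wins; torch.device() validates it later.
        return requested
    if cuda_ok and gpu >= 0:
        return f"cuda:{gpu}"
    if mps_ok:
        return "mps"
    return "cpu"


def detect_device(requested: Optional[str] = None, gpu: int = 0):
    """Thin wrapper that asks PyTorch which backends are available."""
    # Imported lazily so pick_device_name can be tested without torch.
    import torch

    # torch.backends.mps does not exist in older PyTorch releases.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    name = pick_device_name(requested, gpu, torch.cuda.is_available(), mps_ok)
    return torch.device(name)
```

With something like this, the constructor could call `self.device = detect_device(config.get("device"), GPU)`, where `config` is assumed to be the parsed YAML dictionary. Leaving the key out keeps today's behavior on CUDA machines.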