Skip to content

Commit 0005760

Browse files
committed
Blackwell compatibility
1 parent 021154d commit 0005760

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

modules/launch_utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,42 @@ def requirements_met(requirements_file):
313313
return True
314314

315315

316+
def get_cuda_comp_cap():
317+
"""
318+
Returns float of CUDA Compute Capability using nvidia-smi
319+
Returns 0.0 on error
320+
CUDA Compute Capability
321+
ref https://developer.nvidia.com/cuda-gpus
322+
ref https://en.wikipedia.org/wiki/CUDA
323+
Blackwell consumer GPUs should return 12.0 data-center GPUs should return 10.0
324+
"""
325+
try:
326+
return float(subprocess.check_output(['nvidia-smi', '--query-gpu=compute_cap', '--format=noheader,csv'], text=True))
327+
except Exception as _:
328+
return 0.0
329+
330+
331+
def early_access_blackwell_wheels():
332+
"""For Blackwell GPUs, use Early Access PyTorch Wheels provided by Nvidia"""
333+
if all([
334+
os.environ.get('TORCH_INDEX_URL') is None,
335+
sys.version_info.major == 3,
336+
sys.version_info.minor in (10, 11, 12),
337+
platform.system() == "Windows",
338+
get_cuda_comp_cap() >= 10, # Blackwell
339+
]):
340+
base_repo = 'https://huggingface.co/w-e-w/torch-2.6.0-cu128.nv/resolve/main/'
341+
ea_whl = {
342+
10: f'{base_repo}torch-2.6.0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=fef3de7ce8f4642e405576008f384304ad0e44f7b06cc1aa45e0ab4b6e70490d {base_repo}torchvision-0.20.0a0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=50841254f59f1db750e7348b90a8f4cd6befec217ab53cbb03780490b225abef',
343+
11: f'{base_repo}torch-2.6.0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=6665c36e6a7e79e7a2cb42bec190d376be9ca2859732ed29dd5b7b5a612d0d26 {base_repo}torchvision-0.20.0a0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=bbc0ee4938e35fe5a30de3613bfcd2d8ef4eae334cf8d49db860668f0bb47083',
344+
12: f'{base_repo}torch-2.6.0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=a3197f72379d34b08c4a4bcf49ea262544a484e8702b8c46cbcd66356c89def6 {base_repo}torchvision-0.20.0a0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=235e7be71ac4e75b0f8e817bae4796d7bac8a67146d2037ab96394f2bdc63e6c'
345+
}
346+
return f'pip install {ea_whl.get(sys.version_info.minor)}'
347+
348+
316349
def prepare_environment():
317350
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
318-
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
351+
torch_command = os.environ.get('TORCH_COMMAND', early_access_blackwell_wheels() or f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
319352
if args.use_ipex:
320353
if platform.system() == "Windows":
321354
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main

0 commit comments

Comments
 (0)