@@ -313,9 +313,42 @@ def requirements_met(requirements_file):
313313 return True
314314
315315
316+ def get_cuda_comp_cap ():
317+ """
318+ Returns float of CUDA Compute Capability using nvidia-smi
319+ Returns 0.0 on error
320+ CUDA Compute Capability
321+ ref https://developer.nvidia.com/cuda-gpus
322+ ref https://en.wikipedia.org/wiki/CUDA
323+ Blackwell consumer GPUs should return 12.0 data-center GPUs should return 10.0
324+ """
325+ try :
326+ return float (subprocess .check_output (['nvidia-smi' , '--query-gpu=compute_cap' , '--format=noheader,csv' ], text = True ))
327+ except Exception as _ :
328+ return 0.0
329+
330+
331+ def early_access_blackwell_wheels ():
332+ """For Blackwell GPUs, use Early Access PyTorch Wheels provided by Nvidia"""
333+ if all ([
334+ os .environ .get ('TORCH_INDEX_URL' ) is None ,
335+ sys .version_info .major == 3 ,
336+ sys .version_info .minor in (10 , 11 , 12 ),
337+ platform .system () == "Windows" ,
338+ get_cuda_comp_cap () >= 10 , # Blackwell
339+ ]):
340+ base_repo = 'https://huggingface.co/w-e-w/torch-2.6.0-cu128.nv/resolve/main/'
341+ ea_whl = {
342+ 10 : f'{ base_repo } torch-2.6.0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=fef3de7ce8f4642e405576008f384304ad0e44f7b06cc1aa45e0ab4b6e70490d { base_repo } torchvision-0.20.0a0+cu128.nv-cp310-cp310-win_amd64.whl#sha256=50841254f59f1db750e7348b90a8f4cd6befec217ab53cbb03780490b225abef' ,
343+ 11 : f'{ base_repo } torch-2.6.0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=6665c36e6a7e79e7a2cb42bec190d376be9ca2859732ed29dd5b7b5a612d0d26 { base_repo } torchvision-0.20.0a0+cu128.nv-cp311-cp311-win_amd64.whl#sha256=bbc0ee4938e35fe5a30de3613bfcd2d8ef4eae334cf8d49db860668f0bb47083' ,
344+ 12 : f'{ base_repo } torch-2.6.0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=a3197f72379d34b08c4a4bcf49ea262544a484e8702b8c46cbcd66356c89def6 { base_repo } torchvision-0.20.0a0+cu128.nv-cp312-cp312-win_amd64.whl#sha256=235e7be71ac4e75b0f8e817bae4796d7bac8a67146d2037ab96394f2bdc63e6c'
345+ }
346+ return f'pip install { ea_whl .get (sys .version_info .minor )} '
347+
348+
316349def prepare_environment ():
317350 torch_index_url = os .environ .get ('TORCH_INDEX_URL' , "https://download.pytorch.org/whl/cu121" )
318- torch_command = os .environ .get ('TORCH_COMMAND' , f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url { torch_index_url } " )
351+ torch_command = os .environ .get ('TORCH_COMMAND' , early_access_blackwell_wheels () or f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url { torch_index_url } " )
319352 if args .use_ipex :
320353 if platform .system () == "Windows" :
321354 # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
0 commit comments