@@ -39,9 +39,9 @@ def enum_display_devices():
3939 else :
4040 gpu_names = set ()
4141 out = subprocess .check_output (['nvidia-smi' , '-L' ])
42- for l in out .split (b'\n ' ):
43- if len (l ) > 0 :
44- gpu_names .add (l .decode ('utf-8' ).split (' (UUID' )[0 ])
42+ for line in out .split (b'\n ' ):
43+ if len (line ) > 0 :
44+ gpu_names .add (line .decode ('utf-8' ).split (' (UUID' )[0 ])
4545 return gpu_names
4646
4747blacklist = {"GeForce GTX TITAN X" , "GeForce GTX 980" , "GeForce GTX 970" , "GeForce GTX 960" , "GeForce GTX 950" , "GeForce 945M" ,
@@ -55,7 +55,7 @@ def enum_display_devices():
5555def cuda_malloc_supported ():
5656 try :
5757 names = get_gpu_names ()
58- except :
58+ except Exception :
5959 names = set ()
6060 for x in names :
6161 if "NVIDIA" in x :
@@ -82,16 +82,16 @@ def cuda_malloc_supported():
8282 version = module .__version__
8383 if int (version [0 ]) >= 2 : #enable by default for torch version 2.0 and up
8484 args .cuda_malloc = cuda_malloc_supported ()
85- except :
85+ except Exception :
8686 pass
8787
88+ def init_cuda_malloc ():
89+ if args .cuda_malloc and not args .disable_cuda_malloc :
90+ env_var = os .environ .get ('PYTORCH_CUDA_ALLOC_CONF' , None )
91+ if env_var is None :
92+ env_var = "backend:cudaMallocAsync"
93+ else :
94+ env_var += ",backend:cudaMallocAsync"
8895
89- if args .cuda_malloc and not args .disable_cuda_malloc :
90- env_var = os .environ .get ('PYTORCH_CUDA_ALLOC_CONF' , None )
91- if env_var is None :
92- env_var = "backend:cudaMallocAsync"
93- else :
94- env_var += ",backend:cudaMallocAsync"
95-
96- os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = env_var
97- print (f"Setup environment PYTORCH_CUDA_ALLOC_CONF={ env_var } " )
96+ os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = env_var
97+ print (f"Setup environment PYTORCH_CUDA_ALLOC_CONF={ env_var } " )
0 commit comments