Skip to content

Commit 0cc8146

Browse files
committed
lint, add init_cuda_malloc()
1 parent e78be27 commit 0cc8146

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

cuda_malloc.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def enum_display_devices():
3939
else:
4040
gpu_names = set()
4141
out = subprocess.check_output(['nvidia-smi', '-L'])
42-
for l in out.split(b'\n'):
43-
if len(l) > 0:
44-
gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
42+
for line in out.split(b'\n'):
43+
if len(line) > 0:
44+
gpu_names.add(line.decode('utf-8').split(' (UUID')[0])
4545
return gpu_names
4646

4747
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
@@ -55,7 +55,7 @@ def enum_display_devices():
5555
def cuda_malloc_supported():
5656
try:
5757
names = get_gpu_names()
58-
except:
58+
except Exception:
5959
names = set()
6060
for x in names:
6161
if "NVIDIA" in x:
@@ -82,16 +82,16 @@ def cuda_malloc_supported():
8282
version = module.__version__
8383
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
8484
args.cuda_malloc = cuda_malloc_supported()
85-
except:
85+
except Exception:
8686
pass
8787

88+
def init_cuda_malloc():
89+
if args.cuda_malloc and not args.disable_cuda_malloc:
90+
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
91+
if env_var is None:
92+
env_var = "backend:cudaMallocAsync"
93+
else:
94+
env_var += ",backend:cudaMallocAsync"
8895

89-
if args.cuda_malloc and not args.disable_cuda_malloc:
90-
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
91-
if env_var is None:
92-
env_var = "backend:cudaMallocAsync"
93-
else:
94-
env_var += ",backend:cudaMallocAsync"
95-
96-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
97-
print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}")
96+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
97+
print(f"Setup environment PYTORCH_CUDA_ALLOC_CONF={env_var}")

webui.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
startup_timer = timer.startup_timer
1111
startup_timer.record("launcher")
1212

13-
import cuda_malloc
13+
from cuda_malloc import init_cuda_malloc
14+
init_cuda_malloc()
1415
startup_timer.record("cuda_malloc")
1516

1617
initialize.imports()

0 commit comments

Comments
 (0)