
Commit ace899e

Prioritize fp16 compute when using allow_fp16_accumulation
1 parent aff1653 commit ace899e

1 file changed: +7 -0 lines changed

comfy/model_management.py

Lines changed: 7 additions & 0 deletions
@@ -256,9 +256,12 @@ def is_amd():
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
 
+
+PRIORITIZE_FP16 = False  # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
     if is_nvidia() and args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        PRIORITIZE_FP16 = True  # TODO: limit to cards where it actually boosts performance
 except:
     pass
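
The try/except around the assignment is there because torch.backends.cuda.matmul.allow_fp16_accumulation is only exposed on recent PyTorch builds; on older versions the assignment fails and PRIORITIZE_FP16 stays False. Below is a minimal standalone sketch (not part of the commit) of that guarded setup. The enable_fp16_fast_accum helper and its arguments are made up for illustration, and the broad except mirrors the commit's bare except.

import torch

PRIORITIZE_FP16 = False

def enable_fp16_fast_accum(fast_mode, on_nvidia):
    # Only try to flip the cuBLAS knob when fast mode is requested on an NVIDIA GPU.
    if not (fast_mode and on_nvidia):
        return False
    try:
        # Older PyTorch builds do not expose this attribute; the backend object
        # rejects unknown names, so the assignment raises and we fall through.
        torch.backends.cuda.matmul.allow_fp16_accumulation = True
        return True
    except Exception:
        return False

# Example wiring, standing in for ComfyUI's is_nvidia() and args.fast:
PRIORITIZE_FP16 = enable_fp16_fast_accum(fast_mode=True, on_nvidia=torch.cuda.is_available())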

@@ -681,6 +684,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype
 
+    if PRIORITIZE_FP16:
+        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
+            return torch.float16
+
     for dt in supported_dtypes:
         if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
             if torch.float16 in supported_dtypes:
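
The unet_dtype() change is where the flag takes effect: once fp16 accumulation is on, float16 is returned ahead of the model's usual supported_dtypes preference order (e.g. ahead of bfloat16), provided float16 is listed and should_use_fp16() approves the device. Below is a simplified, standalone sketch of the resulting selection order; pick_unet_dtype, fp16_ok, and bf16_ok are stand-ins for the real function, the fp8 branch, and the device probes, which are omitted here.

import torch

def pick_unet_dtype(prioritize_fp16,
                    supported_dtypes=(torch.float16, torch.bfloat16, torch.float32),
                    fp16_ok=True, bf16_ok=True):
    # New in this commit: an active fp16-accumulation flag short-circuits the scan.
    if prioritize_fp16 and torch.float16 in supported_dtypes and fp16_ok:
        return torch.float16

    # Pre-existing behavior: walk the model's preference order.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_ok:
            return torch.float16
        if dt == torch.bfloat16 and bf16_ok:
            return torch.bfloat16
    return torch.float32

# A model that prefers bf16 keeps getting bf16 without the flag,
# but switches to fp16 once PRIORITIZE_FP16 is set:
print(pick_unet_dtype(False, (torch.bfloat16, torch.float16)))  # torch.bfloat16
print(pick_unet_dtype(True, (torch.bfloat16, torch.float16)))   # torch.float16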
