Commit 7fd26bc
Commit message: up
Parent: ef6a483

src/diffusers/models/attention_dispatch.py: 10 additions, 8 deletions
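
In short: the commit retires the one-off module global flash_attn_3_func_hub and instead caches the lazily imported Flash Attention 3 hub kernel in a new _BACKEND_HANDLES dict, keyed by AttentionBackendName._FLASH_3_HUB. The loader consults the dict before importing, and the call site in _flash_attention_3_hub does a dict lookup with a lazy fallback instead of reading the global. A sketch of the pattern follows the diff below.
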
@@ -78,7 +78,7 @@
 flash_attn_3_func = None
 flash_attn_3_varlen_func = None
 
-flash_attn_3_func_hub = None
+_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {}
 _PREPARED_BACKENDS: Set["AttentionBackendName"] = set()
 
 if _CAN_USE_SAGE_ATTN:
@@ -446,17 +446,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
 
 
 def _ensure_flash_attn_3_func_hub_loaded():
-    global flash_attn_3_func_hub
-
-    if flash_attn_3_func_hub is not None:
-        return flash_attn_3_func_hub
+    cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
+    if cached is not None:
+        return cached
 
     from ..utils.kernels_utils import _get_fa3_from_hub
 
     flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    func = flash_attn_interface_hub.flash_attn_func
+    _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func
 
-    return flash_attn_3_func_hub
+    return func
 
 
 _BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
@@ -1348,7 +1348,9 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    func = flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded()
+    func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
+    if func is None:
+        func = _ensure_flash_attn_3_func_hub_loaded()
     out = func(
         q=query,
         k=key,
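
The change follows a standard lazy-import cache pattern. Below is a minimal, self-contained sketch of that pattern, not the real module: the AttentionBackendName enum is stubbed with a single member, and the expensive hub import (_get_fa3_from_hub in the diff) is replaced by a placeholder loader. Only the dict-keyed caching logic mirrors the commit.

from enum import Enum
from typing import Callable, Dict


class AttentionBackendName(Enum):
    # Stub: the real enum in attention_dispatch.py has many more members.
    _FLASH_3_HUB = "_flash_3_hub"


# One dict replaces per-backend module globals such as flash_attn_3_func_hub.
_BACKEND_HANDLES: Dict[AttentionBackendName, Callable] = {}


def _ensure_flash_attn_3_func_hub_loaded() -> Callable:
    # Cheap cache hit: return the previously loaded kernel if present.
    cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
    if cached is not None:
        return cached

    # Placeholder for the real one-time hub import via _get_fa3_from_hub().
    def flash_attn_func(*args, **kwargs):
        raise NotImplementedError("stand-in for the FA3 hub kernel")

    _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = flash_attn_func
    return flash_attn_func


# Call-site pattern from _flash_attention_3_hub: dict lookup, lazy fallback.
func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
if func is None:
    func = _ensure_flash_attn_3_func_hub_loaded()

One detail worth noting: the old call site used "flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded()", which falls through whenever the cached handle is falsy, while the explicit "is None" check only reloads when the handle is genuinely absent. The dict also scales to further backends without adding a new module global per kernel.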
