 flash_attn_3_func = None
 flash_attn_3_varlen_func = None
 
-flash_attn_3_func_hub = None
+_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {}
 _PREPARED_BACKENDS: Set["AttentionBackendName"] = set()
 
 if _CAN_USE_SAGE_ATTN:
@@ -446,17 +446,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
 
 
 def _ensure_flash_attn_3_func_hub_loaded():
-    global flash_attn_3_func_hub
-
-    if flash_attn_3_func_hub is not None:
-        return flash_attn_3_func_hub
+    cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
+    if cached is not None:
+        return cached
 
     from ..utils.kernels_utils import _get_fa3_from_hub
 
     flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    func = flash_attn_interface_hub.flash_attn_func
+    _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func
 
-    return flash_attn_3_func_hub
+    return func
 
 
 _BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
@@ -1348,7 +1348,9 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    func = flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded()
+    func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
+    if func is None:
+        func = _ensure_flash_attn_3_func_hub_loaded()
     out = func(
         q=query,
         k=key,
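
For readers skimming the diff, here is a minimal, self-contained sketch of the dict-based lazy-loading pattern this change adopts in place of the module-level global. The enum member and hub loader below are simplified stand-ins, not the real diffusers code; only the caching logic mirrors the diff:

```python
from enum import Enum
from typing import Callable, Dict

# Hypothetical stand-ins for the real enum and hub loader.
class AttentionBackendName(Enum):
    _FLASH_3_HUB = "_flash_3_hub"

def _get_fa3_from_hub():
    # Placeholder for the real kernel fetch; returns an object exposing
    # a `flash_attn_func` attribute.
    class _FakeInterface:
        @staticmethod
        def flash_attn_func(*args, **kwargs):
            return "flash-attn-3 output"
    return _FakeInterface

_BACKEND_HANDLES: Dict[AttentionBackendName, Callable] = {}

def _ensure_flash_attn_3_func_hub_loaded() -> Callable:
    # The first call populates the cache; later calls return the stored
    # handle without touching the loader again.
    cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
    if cached is not None:
        return cached

    func = _get_fa3_from_hub().flash_attn_func
    _BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func
    return func

# Repeated calls hand back the same cached callable.
assert _ensure_flash_attn_3_func_hub_loaded() is _ensure_flash_attn_3_func_hub_loaded()
```

Keeping the handle in a dict keyed by backend name avoids a separate `global` per backend and gives later backends a single place to cache their loaded functions.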