@@ -17,7 +17,7 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union

 import torch

@@ -78,17 +78,8 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
+flash_attn_3_func_hub = None
+_PREPARED_BACKENDS: Set["AttentionBackendName"] = set()

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -231,7 +222,9 @@ def decorator(func):

     @classmethod
     def get_active_backend(cls):
-        return cls._active_backend, cls._backends[cls._active_backend]
+        backend = cls._active_backend
+        _ensure_attention_backend_ready(backend)
+        return backend, cls._backends[backend]

     @classmethod
     def list_backends(cls):
@@ -258,7 +251,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
         raise ValueError(f"Backend {backend} is not registered.")

     backend = AttentionBackendName(backend)
-    _check_attention_backend_requirements(backend)
+    _ensure_attention_backend_ready(backend)

     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
@@ -452,6 +445,39 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             )


+def _ensure_flash_attn_3_func_hub_loaded():
+    global flash_attn_3_func_hub
+
+    if flash_attn_3_func_hub is not None:
+        return flash_attn_3_func_hub
+
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    return flash_attn_3_func_hub
+
+
+_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
+    AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded,
+}
+
+
+def _prepare_attention_backend(backend: AttentionBackendName) -> None:
+    preparer = _BACKEND_PREPARERS.get(backend)
+    if preparer is not None:
+        preparer()
+
+
+def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None:
+    if backend in _PREPARED_BACKENDS:
+        return
+    _check_attention_backend_requirements(backend)
+    _prepare_attention_backend(backend)
+    _PREPARED_BACKENDS.add(backend)
+
+
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
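The hunk above moves the Hub-kernel setup behind a prepare-once registry: requirement checks and the kernel download now run the first time a backend is activated, and `_PREPARED_BACKENDS` memoizes the result. A minimal standalone sketch of the same pattern, with illustrative names that are not part of this diff:

    from enum import Enum
    from typing import Callable, Dict, Set


    class Backend(str, Enum):
        FLASH_3_HUB = "_flash_3_hub"
        NATIVE = "native"


    _prepared: Set[Backend] = set()
    # Only backends that need one-time setup register a preparer.
    _preparers: Dict[Backend, Callable[[], None]] = {
        Backend.FLASH_3_HUB: lambda: print("fetching hub kernel once"),
    }


    def ensure_ready(backend: Backend) -> None:
        # Run checks and preparation on first use, then memoize.
        if backend in _prepared:
            return
        preparer = _preparers.get(backend)
        if preparer is not None:
            preparer()
        _prepared.add(backend)


    ensure_ready(Backend.FLASH_3_HUB)  # prints once
    ensure_ready(Backend.FLASH_3_HUB)  # no-op thereafter
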
@@ -1322,7 +1348,8 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded()
+    out = func(
         q=query,
         k=key,
         v=value,
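Net effect of the touched call sites (`get_active_backend`, `attention_backend`, `_flash_attention_3_hub`): the FA3 Hub kernel is resolved only when the `_FLASH_3_HUB` backend is actually selected or dispatched to, instead of at import time behind `DIFFUSERS_ENABLE_HUB_KERNELS`. A hedged usage sketch, assuming `attention_backend` is used as a context manager imported from `diffusers.models.attention_dispatch`; it only runs end to end with the `kernels` package installed and FA3-capable hardware:

    from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

    # Entering the context calls _ensure_attention_backend_ready(), which performs the
    # requirement check and the one-time Hub download for the FA3 kernel.
    with attention_backend(AttentionBackendName._FLASH_3_HUB):
        pass  # attention ops executed here dispatch through the lazily loaded Hub kernel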