Commit ef6a483

refactor how attention kernels from hub are used.
1 parent 8abc7ae · commit ef6a483

File tree

2 files changed: +44 additions, -17 deletions

src/diffusers/models/attention_dispatch.py
src/diffusers/models/modeling_utils.py


src/diffusers/models/attention_dispatch.py

Lines changed: 42 additions & 15 deletions
@@ -17,7 +17,7 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union
 
 import torch
 
@@ -78,17 +78,8 @@
 flash_attn_3_func = None
 flash_attn_3_varlen_func = None
 
-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
+flash_attn_3_func_hub = None
+_PREPARED_BACKENDS: Set["AttentionBackendName"] = set()
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -231,7 +222,9 @@ def decorator(func):
 
     @classmethod
     def get_active_backend(cls):
-        return cls._active_backend, cls._backends[cls._active_backend]
+        backend = cls._active_backend
+        _ensure_attention_backend_ready(backend)
+        return backend, cls._backends[backend]
 
     @classmethod
     def list_backends(cls):
@@ -258,7 +251,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
         raise ValueError(f"Backend {backend} is not registered.")
 
     backend = AttentionBackendName(backend)
-    _check_attention_backend_requirements(backend)
+    _ensure_attention_backend_ready(backend)
 
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
@@ -452,6 +445,39 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
         )
 
 
+def _ensure_flash_attn_3_func_hub_loaded():
+    global flash_attn_3_func_hub
+
+    if flash_attn_3_func_hub is not None:
+        return flash_attn_3_func_hub
+
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    return flash_attn_3_func_hub
+
+
+_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
+    AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded,
+}
+
+
+def _prepare_attention_backend(backend: AttentionBackendName) -> None:
+    preparer = _BACKEND_PREPARERS.get(backend)
+    if preparer is not None:
+        preparer()
+
+
+def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None:
+    if backend in _PREPARED_BACKENDS:
+        return
+    _check_attention_backend_requirements(backend)
+    _prepare_attention_backend(backend)
+    _PREPARED_BACKENDS.add(backend)
+
+
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
@@ -1322,7 +1348,8 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = flash_attn_3_func_hub or _ensure_flash_attn_3_func_hub_loaded()
+    out = func(
         q=query,
         k=key,
         v=value,
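
Taken together, this file's changes move the Hub FA3 kernel from import-time loading (previously gated by DIFFUSERS_ENABLE_HUB_KERNELS) to on-demand loading the first time a backend is prepared, with _PREPARED_BACKENDS memoizing which backends have already been checked and set up; the dispatch function additionally falls back to loading at call time if the cached function is still unset. Below is a minimal, standalone sketch of that "prepare once, then cache" pattern. Backend, _load_hub_kernel, and the placeholder kernel are illustrative stand-ins, not the library's actual identifiers.

# Standalone sketch of lazy, one-time backend preparation (illustrative names only).
from enum import Enum
from typing import Callable, Dict, Optional, Set

class Backend(Enum):
    NATIVE = "native"
    FLASH_3_HUB = "_flash_3_hub"

_hub_kernel: Optional[Callable] = None   # cached kernel entry point, fetched on demand
_PREPARED: Set[Backend] = set()          # backends that have already been prepared

def _load_hub_kernel() -> Callable:
    # First call stands in for downloading/importing the kernel; later calls hit the cache.
    global _hub_kernel
    if _hub_kernel is None:
        _hub_kernel = lambda *args, **kwargs: None  # placeholder for the real Hub kernel
    return _hub_kernel

# Only backends that need extra setup register a preparer; all others need no preparation.
_PREPARERS: Dict[Backend, Callable[[], object]] = {
    Backend.FLASH_3_HUB: _load_hub_kernel,
}

def ensure_backend_ready(backend: Backend) -> None:
    # Run the one-time preparer only on first use of a backend; later calls return early.
    if backend in _PREPARED:
        return
    preparer = _PREPARERS.get(backend)
    if preparer is not None:
        preparer()
    _PREPARED.add(backend)

ensure_backend_ready(Backend.FLASH_3_HUB)  # first call prepares (loads the kernel)
ensure_backend_ready(Backend.FLASH_3_HUB)  # subsequent calls are no-ops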

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
@@ -594,7 +594,7 @@ def set_attention_backend(self, backend: str) -> None:
         attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import AttentionBackendName, _ensure_attention_backend_ready
 
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -606,7 +606,7 @@ def set_attention_backend(self, backend: str) -> None:
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
         backend = AttentionBackendName(backend)
-        _check_attention_backend_requirements(backend)
+        _ensure_attention_backend_ready(backend)
 
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
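
For callers, set_attention_backend() now runs the requirement checks and any one-time preparation (such as the Hub kernel fetch) at the moment a backend is first selected, instead of relying on import-time setup. A hedged usage sketch follows; the checkpoint id and the presence of a `transformer` sub-module are assumptions, and the string "_flash_3_hub" assumes the value of AttentionBackendName._FLASH_3_HUB.

# Hedged usage sketch; the model id below is a placeholder, not a real checkpoint.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Requires `pip install kernels`. With this refactor, the requirement check and the
# one-time Hub kernel load happen here, when the backend is first selected, rather
# than at import time behind DIFFUSERS_ENABLE_HUB_KERNELS.
pipe.transformer.set_attention_backend("_flash_3_hub")

image = pipe("a prompt").images[0]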
