@@ -16,6 +16,7 @@
 import functools
 import inspect
 import math
+from dataclasses import dataclass
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
@@ -42,7 +43,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
 if TYPE_CHECKING:
@@ -82,24 +83,11 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
-
 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
 else:
     aiter_flash_attn_func = None
 
-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
-
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
@@ -261,6 +249,25 @@ def _is_context_parallel_available( |
     return supports_context_parallel
 
 
+@dataclass
+class _HubKernelConfig:
+    """Configuration for downloading and using a hub-based attention kernel."""
+
+    repo_id: str
+    function_attr: str
+    revision: Optional[str] = None
+    kernel_fn: Optional[Callable] = None
+
+
+# Registry for hub-based attention kernels
+_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+    )
+}
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
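
For context on the data flow above: each hub backend gets one `_HubKernelConfig` entry, and the downloaded callable is cached on that entry so the Hub is contacted at most once per process. A minimal sketch of reading the FA3 entry (field values taken from the registry definition above):

```python
# Sketch: inspecting the registry entry for the FA3 hub backend.
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
print(config.repo_id)        # "kernels-community/flash-attn3"
print(config.function_attr)  # "flash_attn_func"
print(config.kernel_fn)      # None until the kernel has been downloaded and cached
```
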
@@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None |
 
     # TODO: add support Hub variant of FA3 varlen later
     elif backend in [AttentionBackendName._FLASH_3_HUB]:
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
-            )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend == AttentionBackendName.AITER:
@@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): |
     return q_idx >= kv_idx
 
 
+# ===== Helpers for downloading kernels =====
+def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
+    if backend not in _HUB_KERNELS_REGISTRY:
+        return
+    config = _HUB_KERNELS_REGISTRY[backend]
+
+    if config.kernel_fn is not None:
+        return
+
+    try:
+        from kernels import get_kernel
+
+        kernel_module = get_kernel(config.repo_id, revision=config.revision)
+        kernel_func = getattr(kernel_module, config.function_attr)
+
+        # Cache the downloaded kernel function in the config object
+        config.kernel_fn = kernel_func
+
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
+        raise
+
+
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
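
A hedged sketch of how this helper is intended to be called before dispatch; the exact call site is an assumption, but the behavior (no-op for non-hub backends or if already cached, download-and-cache otherwise) follows from the code above:

```python
# Illustrative call site (hypothetical): warm up the hub kernel before the first attention call.
backend = AttentionBackendName._FLASH_3_HUB
_maybe_download_kernel_for_backend(backend)

# After a successful download, the callable is cached on the registry entry.
kernel_fn = _HUB_KERNELS_REGISTRY[backend].kernel_fn
assert kernel_fn is not None
```
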
@@ -1418,7 +1444,8 @@ def _flash_attention_3_hub( |
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
+    out = func(
         q=query,
         k=key,
         v=value,
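
As a usage note, a hedged sketch of opting into this backend from user code via the `attention_backend` context manager that appears earlier in the diff; whether the download is triggered on entering the context or lazily at the first attention call is an assumption here, as is the `model(...)` placeholder:

```python
# Hedged sketch: route attention through the Hub-provided FA3 kernel.
# Assumes supported hardware and that `kernels` is installed
# (enforced by _check_attention_backend_requirements above).
with attention_backend(AttentionBackendName._FLASH_3_HUB):
    out = model(...)  # attention calls inside the model use the cached hub kernel
```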