
Commit ab71f3c

Authored by sayakpaul and Dhruv Nair (DN6)

[core] Refactor hub attn kernels (#12475)

* refactor how attention kernels from hub are used.
* up
* refactor according to Dhruv's ideas.
* empty
* empty
* empty
* up

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: dn6 <[email protected]>
1 parent b7df4a5 commit ab71f3c

File tree: 5 files changed, +54 −46 lines


src/diffusers/models/attention_dispatch.py

Lines changed: 47 additions & 20 deletions

@@ -16,6 +16,7 @@
 import functools
 import inspect
 import math
+from dataclasses import dataclass
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

@@ -42,7 +43,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


 if TYPE_CHECKING:
@@ -82,24 +83,11 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

-
 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
 else:
     aiter_flash_attn_func = None

-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
-
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
@@ -261,6 +249,25 @@ def _is_context_parallel_available(
     return supports_context_parallel


+@dataclass
+class _HubKernelConfig:
+    """Configuration for downloading and using a hub-based attention kernel."""
+
+    repo_id: str
+    function_attr: str
+    revision: Optional[str] = None
+    kernel_fn: Optional[Callable] = None
+
+
+# Registry for hub-based attention kernels
+_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+    )
+}
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None

     # TODO: add support Hub variant of FA3 varlen later
     elif backend in [AttentionBackendName._FLASH_3_HUB]:
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
-            )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )

     elif backend == AttentionBackendName.AITER:
@@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
     return q_idx >= kv_idx


+# ===== Helpers for downloading kernels =====
+def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
+    if backend not in _HUB_KERNELS_REGISTRY:
+        return
+    config = _HUB_KERNELS_REGISTRY[backend]
+
+    if config.kernel_fn is not None:
+        return
+
+    try:
+        from kernels import get_kernel
+
+        kernel_module = get_kernel(config.repo_id, revision=config.revision)
+        kernel_func = getattr(kernel_module, config.function_attr)
+
+        # Cache the downloaded kernel function in the config object
+        config.kernel_fn = kernel_func
+
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
+        raise
+
+
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1418,7 +1444,8 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
+    out = func(
         q=query,
         k=key,
         v=value,
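The net effect of this file's changes: the Hub kernel is no longer fetched at import time behind the `DIFFUSERS_ENABLE_HUB_KERNELS` env var; instead a registry entry describes the kernel and `_maybe_download_kernel_for_backend` fetches and caches it on demand. A minimal sketch of that flow using the private helpers introduced above (illustrative only, these are not public API):

```python
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _HUB_KERNELS_REGISTRY,
    _maybe_download_kernel_for_backend,
)

backend = AttentionBackendName._FLASH_3_HUB
config = _HUB_KERNELS_REGISTRY[backend]
print(config.repo_id)    # "kernels-community/flash-attn3"
print(config.kernel_fn)  # None until the kernel has been downloaded

# The first call downloads the kernel via `kernels.get_kernel(repo_id, revision=...)`
# and caches the resolved callable on the registry entry; later calls return early
# because `config.kernel_fn` is already populated.
_maybe_download_kernel_for_backend(backend)
assert config.kernel_fn is not None  # now the Hub-provided `flash_attn_func`
```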

src/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 1 deletion

@@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
         attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _maybe_download_kernel_for_backend,
+        )

         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None:
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
         backend = AttentionBackendName(backend)
         _check_attention_backend_requirements(backend)
+        _maybe_download_kernel_for_backend(backend)

         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
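For users, the `modeling_utils.py` change means the Hub backend is enabled purely through `set_attention_backend`: the requirement check and the kernel download both run inside that call, and no environment variable is needed. A hedged usage sketch (the checkpoint name and the `"_flash_3_hub"` backend string are assumptions based on the enum member shown in the diff; `pip install kernels` is still required):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Validates the backend, downloads the FA3 kernel from the Hub on first use,
# and caches it -- no DIFFUSERS_ENABLE_HUB_KERNELS env var anymore.
pipe.transformer.set_attention_backend("_flash_3_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```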

src/diffusers/utils/constants.py

Lines changed: 0 additions & 1 deletion

@@ -46,7 +46,6 @@
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

src/diffusers/utils/kernels_utils.py

Lines changed: 0 additions & 23 deletions
This file was deleted.
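The deleted `kernels_utils.py` previously hosted the import-time Hub fetch (`_get_fa3_from_hub`) that the registry above replaces. Its contents are not shown on this page; based on its removed call sites in `attention_dispatch.py` (`_get_fa3_from_hub()` returned a module exposing `flash_attn_func`), it presumably looked roughly like this reconstruction, which is hypothetical and not the actual deleted code:

```python
# Hypothetical reconstruction, inferred from the removed call sites only.
from kernels import get_kernel

_FA3_HUB_REPO_ID = "kernels-community/flash-attn3"  # assumed constant name


def _get_fa3_from_hub():
    # Download the FA3 kernel module from the Hub at import time.
    return get_kernel(_FA3_HUB_REPO_ID)
```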

tests/others/test_attention_backends.py

Lines changed: 0 additions & 1 deletion

@@ -7,7 +7,6 @@

 ```bash
 export RUN_ATTENTION_BACKEND_TESTS=yes
-export DIFFUSERS_ENABLE_HUB_KERNELS=yes

 pytest tests/others/test_attention_backends.py
 ```
