118 commits
eb34c8d
add support methods to abstract
MatthewBonanni Oct 8, 2025
87edf38
remove is_attn_backend_supported
MatthewBonanni Oct 8, 2025
fc493ae
all backends are V1 now
MatthewBonanni Oct 8, 2025
9618979
use backend_to_class_str
MatthewBonanni Oct 8, 2025
8aeb461
add MLA backend support details
MatthewBonanni Oct 8, 2025
eb8426f
use backend_to_class_str
MatthewBonanni Oct 8, 2025
aba576c
add support details for standard attention backends
MatthewBonanni Oct 8, 2025
ff18a9a
update cuda logic
MatthewBonanni Oct 8, 2025
eaed800
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 8, 2025
9687c99
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 8, 2025
df49484
fix pre-commit
MatthewBonanni Oct 8, 2025
ff5ad7c
fix argument mismatch
MatthewBonanni Oct 8, 2025
712ae59
fix pre-commit
MatthewBonanni Oct 8, 2025
97e1a2c
use block size literals
MatthewBonanni Oct 8, 2025
8f86714
replace backend_name_to_enum with direct calls
MatthewBonanni Oct 9, 2025
50596d8
use DeviceCapability objects
MatthewBonanni Oct 9, 2025
03f6963
update max
MatthewBonanni Oct 9, 2025
3bee84e
Fix block size adjustment
MatthewBonanni Oct 9, 2025
a716f3a
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 10, 2025
15234bb
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 13, 2025
2433669
split priorities by capability, update flashinfer min capability
MatthewBonanni Oct 14, 2025
a3617d7
change to typing imports
MatthewBonanni Oct 15, 2025
81d1b7b
backends specify their required kv cache layout
MatthewBonanni Oct 15, 2025
adaf53b
flashinfer supports up to 12.1
MatthewBonanni Oct 15, 2025
d1f1362
is_mla is false in base class
MatthewBonanni Oct 15, 2025
abb8375
triton supports fp8
MatthewBonanni Oct 15, 2025
85d8719
use CacheDType
MatthewBonanni Oct 15, 2025
1ef0417
add todo
MatthewBonanni Oct 15, 2025
a2c902f
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 15, 2025
16f9373
is_quantized_kv_cache use CacheDType
MatthewBonanni Oct 15, 2025
8474a14
fix supports_sink
MatthewBonanni Oct 15, 2025
62e6290
fix priority list
MatthewBonanni Oct 15, 2025
22dd1b8
fix FA block sizes
MatthewBonanni Oct 16, 2025
4bf076d
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 16, 2025
121d442
fix import failure
MatthewBonanni Oct 16, 2025
963cc9f
fix import error
MatthewBonanni Oct 16, 2025
778cd98
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 20, 2025
de3f302
fix import error
MatthewBonanni Oct 20, 2025
bc10bee
fix import
MatthewBonanni Oct 20, 2025
05aab3e
fix type error
MatthewBonanni Oct 21, 2025
7936c47
add flashmla support test
MatthewBonanni Oct 21, 2025
4f0f955
clean up head size validation
MatthewBonanni Oct 21, 2025
feded36
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 21, 2025
d8b8043
use KVCacheLayoutType
MatthewBonanni Oct 21, 2025
a3ccbba
move selector layout change to same place as block size change
MatthewBonanni Oct 21, 2025
3285c2c
MLA only supports head size 576
MatthewBonanni Oct 21, 2025
6eab504
fix kv_cache_dtype support logic
MatthewBonanni Oct 21, 2025
5523dac
fix test
MatthewBonanni Oct 21, 2025
58fc888
skip FA MLA if test is run on hardware where it's not supported
MatthewBonanni Oct 21, 2025
17fd954
fix test
MatthewBonanni Oct 21, 2025
2b23712
fix pre-commit
MatthewBonanni Oct 21, 2025
68a63b7
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 22, 2025
fc1d3f3
fix head size
MatthewBonanni Oct 22, 2025
ecdef49
fix pre-commit
MatthewBonanni Oct 22, 2025
9008e56
flashinfer_mla only support blackwell (only uses TRTLLM kernels)
MatthewBonanni Oct 22, 2025
b756ceb
compute capability checks
MatthewBonanni Oct 22, 2025
afccece
remove reference to backend_name_to_enum
MatthewBonanni Oct 22, 2025
33cb1ef
fix default block size
MatthewBonanni Oct 22, 2025
3f5439e
improve logs
MatthewBonanni Oct 23, 2025
75fce85
fix block size support
MatthewBonanni Oct 23, 2025
ba51339
fix getting priority list
MatthewBonanni Oct 23, 2025
d49fbf9
remove redundant block size methods
MatthewBonanni Oct 24, 2025
dd31329
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 24, 2025
b18a193
fix import
MatthewBonanni Oct 24, 2025
0e0cb6d
raise error instead of implicitly changing backend
MatthewBonanni Oct 24, 2025
1a7b366
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 27, 2025
f147663
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 27, 2025
1eefe90
don't ignore block size
MatthewBonanni Oct 27, 2025
97bee04
move block_size update back to check_and_update_config
MatthewBonanni Oct 27, 2025
0812fac
fix import
MatthewBonanni Oct 27, 2025
ec39247
address missing case
MatthewBonanni Oct 27, 2025
e6497dd
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 28, 2025
860bfdb
fix flashmla_sparse support
MatthewBonanni Oct 28, 2025
df1cd64
fix hybrid models
MatthewBonanni Oct 28, 2025
758b3a5
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 29, 2025
01b43ff
return only mla or non-mla priorities
MatthewBonanni Oct 29, 2025
ee894ea
cleanup
MatthewBonanni Oct 29, 2025
842e89b
skip test on hopper
MatthewBonanni Oct 29, 2025
bd190e7
temp: apply fixes for test
MatthewBonanni Oct 29, 2025
5bf94f6
Revert "skip test on hopper"
MatthewBonanni Oct 29, 2025
7e34939
revert to old check_and_update_config block_size logic
MatthewBonanni Oct 29, 2025
3b1e92f
Revert "temp: apply fixes for test"
MatthewBonanni Oct 29, 2025
54dffe2
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 30, 2025
db6cc0f
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 30, 2025
d34eb77
add test_attention_selector to Blackwell Tests
MatthewBonanni Oct 31, 2025
48290ee
rename _Backend to AttentionBackendEnum, add class methods
MatthewBonanni Oct 31, 2025
1c71eab
get rid of get_min_compute_capability and get_max_compute_capability
MatthewBonanni Oct 31, 2025
6e9d1f1
fix pre-commit
MatthewBonanni Oct 31, 2025
d3cdda7
change methods to properties
MatthewBonanni Oct 31, 2025
925069c
device_capability not None
MatthewBonanni Oct 31, 2025
a0b56c5
query device_capability inside get_required_kv_cache_layout
MatthewBonanni Oct 31, 2025
fff453a
Update vllm/attention/backends/abstract.py
MatthewBonanni Oct 31, 2025
95aae78
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
530f356
class_path always None in decorator
MatthewBonanni Oct 31, 2025
933ee5f
type hint for value
MatthewBonanni Oct 31, 2025
255edc9
restore comment
MatthewBonanni Oct 31, 2025
6af36aa
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
c9d62f8
fix docs
MatthewBonanni Oct 31, 2025
93a0770
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
f6a5a32
add FLASHMLA_SPARSE to priority list
MatthewBonanni Oct 31, 2025
bc91050
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
0435eca
fix test
MatthewBonanni Nov 2, 2025
a098d82
fix flashmla_sparse
MatthewBonanni Nov 2, 2025
d8215e0
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 2, 2025
4452f5f
fix pre-commit
MatthewBonanni Nov 2, 2025
5d9627c
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 3, 2025
bc2ba2d
fix pre-commit
MatthewBonanni Nov 3, 2025
367d673
mock SM90 for initialization test in CI
MatthewBonanni Nov 3, 2025
cd24063
mock supports_compute_capability instead of get_device_capability
MatthewBonanni Nov 3, 2025
da3a9cd
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 3, 2025
4d1102a
skip test
MatthewBonanni Nov 4, 2025
09058fc
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 4, 2025
15a19d3
fix path
MatthewBonanni Nov 4, 2025
6515e4c
_get_backend_priorities return sorted list instead of dict
MatthewBonanni Nov 4, 2025
0b841ff
fix merge miss
MatthewBonanni Nov 4, 2025
69bb887
remove get_default_block_size
MatthewBonanni Nov 4, 2025
2f956f9
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 4, 2025
556fc3d
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 4, 2025
1 change: 1 addition & 0 deletions tests/compile/test_fusion_attn.py
@@ -314,6 +314,7 @@ def test_attention_quant_pattern(
custom_ops_list = custom_ops.split(",") if custom_ops else []

device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42)

vllm_config = VllmConfig(
75 changes: 45 additions & 30 deletions tests/kernels/attention/test_attention_selector.py
@@ -127,12 +127,13 @@ def test_env(

elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla:
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# and Blackwell GPUs (SM 10.x), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only
# (SM 10.x), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
@@ -141,58 +142,72 @@
if block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip("CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
if capability[0] != 10:
pytest.skip("CUTLASS MLA is not supported on this platform")
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER_MLA":
if capability[0] != 10:
pytest.skip(
"FlashInfer MLA is not supported on this platform"
)
if block_size not in [32, 64]:
# FlashInfer MLA only supports block_size 32 or 64
pytest.skip(
"FlashInfer MLA only supports block_size 32 or 64"
)
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported,
)
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported,
)

is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
backend = get_attn_backend(
576,
torch.float16,
None,
block_size,
use_mla=use_mla,
)
expected = name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
from vllm.attention.utils.fa_utils import (
flash_attn_supports_mla,
)

if not flash_attn_supports_mla():
pytest.skip(
"FlashAttention MLA not supported on this platform"
)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "TRITON_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER":
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
64, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER"
assert backend.get_name() == expected
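Note (reviewer sketch, not part of the diff): the MLA branches above now pass head size 576 to get_attn_backend because the MLA backends only accept the combined latent-plus-rope head size (512 + 64 = 576), while the non-MLA FLASHINFER case is queried with a regular head size. A minimal standalone call, assuming the same selector import this test uses:

    import torch
    from vllm.attention.selector import get_attn_backend

    # 576 = 512 latent dims + 64 rope dims; MLA backends reject other head sizes.
    backend = get_attn_backend(576, torch.float16, None, 64, use_mla=True)
    print(backend.get_name())  # e.g. "FLASHMLA" with block_size 64 on Hopper
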
21 changes: 14 additions & 7 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -425,13 +425,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks if GPUModelRunner initializes correctly when an attention
# backend enforces a non-default KV cache stride order.
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
expected_kv_cache_shape = [
2,
NUM_BLOCKS,
BLOCK_SIZE,
n_heads,
model_runner.model_config.get_head_size(),
]
head_size = model_runner.model_config.get_head_size()

# Get the expected shape from the backend's get_kv_cache_shape method
# to ensure compatibility with different backends (triton vs flexattention)
attn_backend = None
for attn_group in model_runner._attn_group_iterator():
attn_backend = attn_group.backend
break

assert attn_backend is not None, "No attention backend found"
expected_kv_cache_shape = list(
attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size)
)

# TODO mla test
default_stride = tuple(range(5))
# Permutation that gets you back to expected kv shape
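Editor's sketch (hypothetical subclass, not in the diff): the test above now derives the expected shape from the backend because a backend may override get_kv_cache_stride_order() to request a permuted physical layout while get_kv_cache_shape() still describes the logical shape the runner allocates. Roughly:

    from vllm.attention.backends.abstract import AttentionBackend

    class PermutedLayoutBackend(AttentionBackend):  # illustrative only
        @staticmethod
        def get_kv_cache_stride_order() -> tuple[int, ...]:
            # Physically store num_kv_heads before block_size, i.e. permute the
            # logical (2, num_blocks, block_size, num_kv_heads, head_size) cache.
            return (0, 1, 3, 2, 4)
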
172 changes: 169 additions & 3 deletions vllm/attention/backends/abstract.py
@@ -2,13 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Generic, Protocol, TypeVar
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast

import torch

from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey

if TYPE_CHECKING:
from vllm.config.cache import BlockSize, CacheDType
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType


class AttentionType:
"""
@@ -88,6 +93,167 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []

@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
supported_dtypes = cls.get_supported_dtypes()
return (not supported_dtypes) or dtype in supported_dtypes

@classmethod
def get_supported_kv_cache_dtypes(cls) -> list["CacheDType"]:
return ["auto"]

@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
if kv_cache_dtype is None:
return True
supported_kv_cache_dtypes = cls.get_supported_kv_cache_dtypes()
return (not supported_kv_cache_dtypes) or (
kv_cache_dtype in supported_kv_cache_dtypes
)

@classmethod
def get_supported_block_sizes(cls) -> list["BlockSize"]:
return []

@classmethod
def supports_block_size(cls, block_size: "BlockSize | None") -> bool:
from vllm.config.cache import BlockSize

if block_size is None:
return True
try:
block_size_literal = cast(BlockSize, block_size)
except ValueError:
return False
supported_block_sizes = cls.get_supported_block_sizes()
return (
not supported_block_sizes
) or block_size_literal in supported_block_sizes

@classmethod
def get_default_block_size(cls) -> "BlockSize":
supported_block_sizes = cls.get_supported_block_sizes()
if not supported_block_sizes:
raise ValueError(
f"Fallback failed, no explicitly supported block sizes for "
f"backend {cls.get_name()}"
)
return supported_block_sizes[0]

@classmethod
def is_mla(cls) -> bool:
return False

@classmethod
def supports_sink(cls) -> bool:
return False

@classmethod
def is_sparse(cls) -> bool:
return False

@classmethod
def get_min_compute_capability(cls) -> "DeviceCapability | None":
return None

@classmethod
def get_max_compute_capability(cls) -> "DeviceCapability | None":
return None

@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
min_capability = cls.get_min_compute_capability()
max_capability = cls.get_max_compute_capability()
return ((min_capability is None) or (capability >= min_capability)) and (
(max_capability is None) or (capability <= max_capability)
)

@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> str | None:
return None

@classmethod
def validate_configuration(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
invalid_reasons.append("head_size not supported")
if not cls.supports_dtype(dtype):
invalid_reasons.append("dtype not supported")
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
if combination_reason is not None:
invalid_reasons.append(combination_reason)
return invalid_reasons

@classmethod
def get_required_kv_cache_layout(
cls, capability: "DeviceCapability"
) -> "KVCacheLayoutType | None":
"""
Some backends require a specific kv cache layout.
This function returns the required layout if any.
"""
return None


class AttentionMetadata:
pass
@@ -160,8 +326,8 @@ def __init__(
) -> None:
raise NotImplementedError

@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]

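Reviewer sketch (hypothetical backend and values, not part of the diff): with the new classmethods, a concrete backend declares its support matrix declaratively, and the selector can collect every reason a configuration is rejected via validate_configuration(); get_required_kv_cache_layout() is consulted the same way when a backend needs a specific KV cache layout. Assuming the import paths referenced in this file:

    import torch
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.platforms.interface import DeviceCapability

    class FakeFlashBackend(AttentionBackend):  # hypothetical, for illustration
        @staticmethod
        def get_name() -> str:
            return "FAKE_FLASH"

        @classmethod
        def get_supported_head_sizes(cls) -> list[int]:
            return [64, 128, 256]

        @classmethod
        def get_min_compute_capability(cls) -> DeviceCapability | None:
            return DeviceCapability(8, 0)  # Ampere or newer

        @classmethod
        def supports_sink(cls) -> bool:
            return True

    # An empty list means the combination is supported; otherwise each entry
    # names one rejected constraint (head size, dtype, block size, capability, ...).
    reasons = FakeFlashBackend.validate_configuration(
        head_size=128,
        dtype=torch.bfloat16,
        kv_cache_dtype=None,
        block_size=16,
        use_mla=False,
        has_sink=True,
        use_sparse=False,
        device_capability=DeviceCapability(9, 0),
    )
    assert reasons == []
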
19 changes: 5 additions & 14 deletions vllm/attention/backends/registry.py
@@ -3,9 +3,13 @@
"""Attention backend registry"""

import enum
from typing import TYPE_CHECKING

from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend


class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
@@ -84,7 +88,7 @@ def backend_to_class_str(backend: _Backend) -> str:
return BACKEND_MAP[backend]


def backend_to_class(backend: _Backend) -> type:
def backend_to_class(backend: _Backend) -> "type[AttentionBackend]":
"""Get the backend class.

Args:
@@ -95,16 +99,3 @@ def backend_to_class(backend: _Backend) -> type:
"""
backend_class_name = backend_to_class_str(backend)
return resolve_obj_by_qualname(backend_class_name)


def backend_name_to_enum(backend_name: str) -> _Backend | None:
"""
Convert a string backend name to a _Backend enum value.

Returns:
_Backend: enum value if backend_name is a valid in-tree type
None: otherwise it's an invalid in-tree type or an out-of-tree platform
is loaded.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None
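
Editor's sketch (not part of the diff): with backend_name_to_enum removed, callers index the enum directly and resolve the class through backend_to_class, which now carries a type[AttentionBackend] hint. Assuming the registry module path shown above:

    from vllm.attention.backends.registry import _Backend, backend_to_class

    name = "FLASH_ATTN"  # example member; unknown names simply miss the lookup
    backend_enum = _Backend[name] if name in _Backend.__members__ else None
    if backend_enum is not None:
        backend_cls = backend_to_class(backend_enum)  # type[AttentionBackend]
        print(backend_cls.get_name())

A later commit in this PR (48290ee) renames _Backend to AttentionBackendEnum, so the enum name here follows the diff as shown.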