2 changes: 2 additions & 0 deletions vllm/model_executor/models/config.py
@@ -380,6 +380,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        if model_config.use_mla:
            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
+        elif envs.VLLM_ATTENTION_BACKEND == "FLASHINFER":
+            kernel_block_alignment_size = 32
        else:
            kernel_block_alignment_size = 16

15 changes: 7 additions & 8 deletions vllm/v1/attention/backends/flashinfer.py
@@ -170,10 +170,10 @@ def get_supported_head_sizes(cls) -> list[int]:

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        # Note: Not sure for all platforms,
-        # but on Blackwell, only support a page size of
-        # 16, 32, 64
-        return [16, 32, 64]
+        # Note(Chen): The FlashInfer backend supports other block sizes, but since
+        # the backend doesn't know which block_size was selected, we hardcode it
+        # to only support 32 for now.
+        return [32]

P0: Restricting FlashInfer to 32-token blocks breaks default configs

The backend now reports only 32 as a supported block size, but CUDA platforms still initialize cache_config.block_size to 16 by default. When a user runs any non-hybrid model with VLLM_ATTENTION_BACKEND=FLASHINFER, _find_compatible_block_sizes in the GPU model runner queries the backend and fails because 16 is not divisible by 32, raising `ValueError("No compatible block size for 16")` before the model starts. This regression removes support for the common 16-token block size that previously worked. Either the backend needs to continue advertising 16 (and 64), or the default cache block size must be bumped to 32 when FlashInfer is selected.
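For illustration, a minimal sketch of the check described above; the helper name and signature are illustrative, not vLLM's actual `_find_compatible_block_sizes` code:

```python
# Minimal sketch (not vLLM's actual implementation) of the compatibility check
# described above: the configured cache block size must be divisible by at
# least one kernel block size the backend advertises.
def find_compatible_block_sizes(cache_block_size: int,
                                supported_kernel_block_sizes: list[int]) -> list[int]:
    compatible = [k for k in supported_kernel_block_sizes
                  if cache_block_size % k == 0]
    if not compatible:
        raise ValueError(f"No compatible block size for {cache_block_size}")
    return compatible

# With the CUDA default of block_size=16:
#   find_compatible_block_sizes(16, [16, 32, 64])  -> [16]        (before this PR)
#   find_compatible_block_sizes(16, [32])          -> ValueError  (after this PR)
```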


Collaborator

Do you have any evidence that something is wrong with 64?

Collaborator Author

No problem with 64, but we need to allow only one block_size here. Happy to change it to 64 if that is better.


    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
@@ -291,6 +291,7 @@ def __init__(
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)
+        block_size = 32  # Note(Chen): Hardcode the block_size as 16 temporarily.
Contributor
high

The comment here states that the block size is hardcoded to 16, but the value is set to 32. This seems to be a typo in the comment. To avoid confusion, the comment should be updated to reflect the actual value.

Suggested change:
-        block_size = 32  # Note(Chen): Hardcode the block_size as 16 temporarily.
+        block_size = 32  # Note(Chen): Hardcode the block_size as 32 temporarily.


        if vllm_is_batch_invariant():
            self.decode_fixed_split_size = 2048
@@ -302,9 +303,7 @@
            self.disable_split_kv = False

        self.compilation_config = vllm_config.compilation_config
-        max_num_pages_per_req = cdiv(
-            self.model_config.max_model_len, self.kv_cache_spec.block_size
-        )
+        max_num_pages_per_req = cdiv(self.model_config.max_model_len, block_size)
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req
        speculative_config = vllm_config.speculative_config
@@ -333,7 +332,7 @@ def __init__(
        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        FlashInferBackend.validate_head_size(self.head_dim)
-        self.page_size = self.kv_cache_spec.block_size
+        self.page_size = block_size

        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):