diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 7150977e9266..edcd2bb1b3bc 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -363,7 +363,10 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 dtype=kv_cache_dtype,
             ).page_size_bytes
         else:
-            kernel_block_alignment_size = 16
+            if envs.VLLM_ATTENTION_BACKEND == "FLASHINFER":
+                kernel_block_alignment_size = 32
+            else:
+                kernel_block_alignment_size = 16
             attn_page_size_1_token = FullAttentionSpec(
                 block_size=1,
                 num_kv_heads=model_config.get_num_kv_heads(parallel_config),
@@ -389,6 +392,17 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         if mamba_page_size == 0:
             return
 
+        # Attention backend constraints:
+        # - FlashAttention (FA) requires block size to be multiple of 16
+        # - MLA (Multi-head Latent Attention) requires larger alignment:
+        #   * CUTLASS_MLA backend: 128-byte alignment
+        #   * Other MLA backends: 64-byte alignment
+        if model_config.use_mla:
+            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
+            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
+        else:
+            kernel_block_alignment_size = 16
+
         if cache_config.enable_prefix_caching:
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index e71d4ca4629d..35f4db452e3a 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -156,6 +156,12 @@ def trtllm_prefill_attn_kvfp8_dequant(
     return mock_kv_cache, mock_block_table
 
 
+# Note(Chen): the FlashInfer backend supports other block sizes, but since it
+# doesn't know which block_size was selected, we hardcode it to support only
+# 32 for now.
+FLASH_INFER_BLOCK_SIZE = 32
+
+
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
@@ -170,10 +176,7 @@ def get_supported_head_sizes(cls) -> list[int]:
 
     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        # Note: Not sure for all platforms,
-        # but on Blackwell, only support a page size of
-        # 16, 32, 64
-        return [16, 32, 64]
+        return [FLASH_INFER_BLOCK_SIZE]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
@@ -291,6 +294,7 @@ def __init__(
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
+        block_size = FLASH_INFER_BLOCK_SIZE  # Note(Chen): temporary hardcode for now.
 
         if vllm_is_batch_invariant():
             self.decode_fixed_split_size = 2048
@@ -302,9 +306,7 @@ def __init__(
             self.disable_split_kv = False
         self.compilation_config = vllm_config.compilation_config
 
-        max_num_pages_per_req = cdiv(
-            self.model_config.max_model_len, self.kv_cache_spec.block_size
-        )
+        max_num_pages_per_req = cdiv(self.model_config.max_model_len, block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
         speculative_config = vllm_config.speculative_config
@@ -333,7 +335,7 @@ def __init__(
         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
         self.head_dim = self.kv_cache_spec.head_size
         FlashInferBackend.validate_head_size(self.head_dim)
-        self.page_size = self.kv_cache_spec.block_size
+        self.page_size = block_size
 
         self.cache_dtype = self.cache_config.cache_dtype
         if self.cache_dtype.startswith("fp8"):
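
For reference, a minimal standalone sketch of the combined block-size alignment selection that the two config.py hunks introduce. The helper name select_kernel_block_alignment and the assertions below are illustrative only and are not part of the patch or of vLLM:

# Illustrative sketch only: mirrors the alignment selection added in the
# config.py hunks above. `select_kernel_block_alignment` is a hypothetical
# helper, not a vLLM API.

def select_kernel_block_alignment(attention_backend: str | None, use_mla: bool) -> int:
    """Return the block-size alignment the attention kernels require."""
    if use_mla:
        # CUTLASS_MLA needs 128-aligned blocks; other MLA backends need 64.
        return 128 if attention_backend == "CUTLASS_MLA" else 64
    if attention_backend == "FLASHINFER":
        # FlashInfer is pinned to a single supported block size of 32.
        return 32
    # FlashAttention and similar backends accept any multiple of 16.
    return 16


if __name__ == "__main__":
    assert select_kernel_block_alignment("FLASHINFER", use_mla=False) == 32
    assert select_kernel_block_alignment("CUTLASS_MLA", use_mla=True) == 128
    assert select_kernel_block_alignment("FLASH_ATTN", use_mla=True) == 64
    assert select_kernel_block_alignment(None, use_mla=False) == 16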