 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        EncoderOnlyAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
@@ -317,10 +319,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+                self.device, self.model_config.runner_type == "generate")
         else:
             self.attn_mask_builder = AttentionMaskBuilder(
-                self.model_config.max_model_len, self.dtype)
+                self.model_config.max_model_len,
+                self.dtype,
+                tril=self.model_config.runner_type == "generate")

         # Set up speculative decoding.
         self.spec_attn_mask = None
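Both branches now pass `self.model_config.runner_type == "generate"` to the mask builder (positionally on CANN 8.3, as `tril=` otherwise), so a causal lower-triangular mask is only pre-built for generative models, while encoder-only / pooling models attend bidirectionally. A minimal sketch of that distinction, with `build_base_mask` as a hypothetical helper rather than the real AttentionMaskBuilder API:

import torch

def build_base_mask(max_len: int, dtype: torch.dtype, tril: bool) -> torch.Tensor:
    # Hypothetical stand-in for what the new flag controls:
    # causal (lower-triangular) mask for generative decoding,
    # full bidirectional mask for encoder-only / pooling models.
    mask = torch.ones(max_len, max_len, dtype=dtype)
    return torch.tril(mask) if tril else mask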
@@ -1477,14 +1481,29 @@ def _prepare_inputs(
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()
-            slot_mapping = blk_table.slot_mapping_cpu[:
-                                                      total_num_scheduled_tokens]
-            self.slot_mapping[:total_num_scheduled_tokens].copy_(
-                slot_mapping[:total_num_scheduled_tokens],
-                non_blocking=True,
-            )
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                slot_mapping = torch.zeros(
+                    (total_num_scheduled_tokens, ),
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()
+                slot_mapping = blk_table.slot_mapping_cpu[:
+                                                          total_num_scheduled_tokens]
+                self.slot_mapping[:total_num_scheduled_tokens].copy_(
+                    slot_mapping[:total_num_scheduled_tokens],
+                    non_blocking=True,
+                )

         # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
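The dummy tensors above only have to satisfy the shape and dtype contract of the shared metadata path, since encoder-only attention never dereferences block ids or slot indices. A self-contained sketch with illustrative sizes (the real runner allocates these on its NPU device):

import torch

num_reqs, total_num_scheduled_tokens = 4, 128  # illustrative values
device = torch.device("cpu")  # the runner uses its NPU device

# One zero block id per request; never read by encoder-only attention.
blk_table_tensor = torch.zeros((num_reqs, 1), dtype=torch.int32, device=device)
# One zero slot per scheduled token; keeps slot_mapping's shape/dtype uniform.
slot_mapping = torch.zeros((total_num_scheduled_tokens, ), dtype=torch.int64, device=device)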
@@ -1533,6 +1552,11 @@ def _prepare_inputs(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
                     **extra_attn_metadata_args)
+            elif self.model_config.runner_type == "pooling":
+                attn_metadata_i = builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                    **extra_attn_metadata_args)
             else:
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
@@ -2639,6 +2663,33 @@ def _convert_torch_format(self, tensor):
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor

+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
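The new helper relies on the KV cache specs being hashable: identical `EncoderOnlyAttentionSpec` instances collapse into a single dictionary key, so all encoder-only layers land in one extra `KVCacheGroupSpec`. A minimal illustration of that grouping idiom, using a hypothetical frozen dataclass in place of the real spec:

from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeSpec:  # hypothetical stand-in for EncoderOnlyAttentionSpec
    block_size: int
    num_kv_heads: int
    head_size: int

specs: dict[FakeSpec, list[str]] = defaultdict(list)
for name in ("layers.0.attn", "layers.1.attn"):
    specs[FakeSpec(128, 8, 64)].append(name)

assert len(specs) == 1  # identical specs hash equally -> a single KV cache group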
@@ -2648,9 +2699,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
-        self.may_reinitialize_input_batch(kv_cache_config)

         if self.model_config.is_deepseek_mla:
             kv_caches = self.initialize_kv_cache_tensors_deepseek(