 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        EncoderOnlyAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
@@ -324,13 +326,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             self.block_size,
             use_mla=self.model_config.use_mla,
         )
+        pooler_config = self.model_config.pooler_config
+        tril = self.model_config.runner_type == "generate" or (
+            pooler_config is not None
+            and pooler_config.pooling_type.lower() == "last")
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+                self.device, tril)
         else:
             self.attn_mask_builder = AttentionMaskBuilder(
-                self.model_config.max_model_len, self.dtype)
+                self.model_config.max_model_len, self.dtype, tril=tril)

         # Set up speculative decoding.
         self.spec_attn_mask = None
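
The `tril` flag introduced above controls whether AttentionMaskBuilder prepares a causal (lower-triangular) mask: generate runners and "last"-token pooling only ever attend backwards, while other pooling types (e.g. mean or CLS) need full bidirectional visibility. A minimal sketch of the distinction, assuming a simple additive-mask convention rather than the builder's actual internals:

    import torch

    def build_attn_mask(seq_len: int, dtype: torch.dtype, tril: bool) -> torch.Tensor:
        # True where a query position may attend to a key position.
        visible = torch.ones(seq_len, seq_len, dtype=torch.bool)
        if tril:
            # Causal: position i attends only to positions <= i.
            visible = visible.tril()
        # Additive mask: 0.0 where visible, -inf where masked out.
        mask = torch.zeros(seq_len, seq_len, dtype=dtype)
        return mask.masked_fill_(~visible, float("-inf"))
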
@@ -1487,14 +1493,29 @@ def _prepare_inputs(
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()
-            slot_mapping = blk_table.slot_mapping_cpu[:
-                                                      total_num_scheduled_tokens]
-            self.slot_mapping[:total_num_scheduled_tokens].copy_(
-                slot_mapping[:total_num_scheduled_tokens],
-                non_blocking=True,
-            )
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have a KV cache, so we create a
+                # dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                slot_mapping = torch.zeros(
+                    (total_num_scheduled_tokens, ),
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()
+                slot_mapping = blk_table.slot_mapping_cpu[:
+                                                          total_num_scheduled_tokens]
+                self.slot_mapping[:total_num_scheduled_tokens].copy_(
+                    slot_mapping[:total_num_scheduled_tokens],
+                    non_blocking=True,
+                )

             # Make AscendCommonAttentionMetadata
             common_attn_metadata = AscendCommonAttentionMetadata(
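
Each KV cache group handed to AscendCommonAttentionMetadata must supply one row of block IDs per request and one slot index per scheduled token. Encoder-only layers never read or write the cache, so zero-filled placeholders with the right shapes and dtypes are enough. A standalone sketch of that shape contract (the names mirror the hunk above; the sizes are illustrative):

    import torch

    num_reqs, total_num_scheduled_tokens = 2, 7  # illustrative sizes
    # One dummy block per request; block id 0 is never dereferenced by
    # encoder-only attention.
    blk_table_tensor = torch.zeros((num_reqs, 1), dtype=torch.int32)
    # One slot per scheduled token; all point at slot 0.
    slot_mapping = torch.zeros((total_num_scheduled_tokens, ), dtype=torch.int64)
    assert blk_table_tensor.shape == (num_reqs, 1)
    assert slot_mapping.shape == (total_num_scheduled_tokens, )
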
@@ -1543,6 +1564,11 @@ def _prepare_inputs(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
                 **extra_attn_metadata_args)
+        elif self.model_config.runner_type == "pooling":
+            attn_metadata_i = builder.build(
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata,
+                **extra_attn_metadata_args)
         else:
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
@@ -2672,6 +2698,33 @@ def _convert_torch_format(self, tensor):
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor

+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only one encoder-only attention spec is supported now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
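
The new method uses the spec objects themselves as dictionary keys, so layers whose specs compare equal collapse into a single group, and the assertion then enforces exactly one such group. A self-contained sketch of the same grouping pattern, with a hypothetical frozen dataclass standing in for EncoderOnlyAttentionSpec (hashability is what makes the dict-key trick work):

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Spec:  # stand-in for EncoderOnlyAttentionSpec; frozen => hashable
        block_size: int
        num_kv_heads: int
        head_size: int

    groups: dict[Spec, list[str]] = defaultdict(list)
    layers = {"enc.0": Spec(128, 8, 64), "enc.1": Spec(128, 8, 64)}
    for name, spec in layers.items():
        groups[spec].append(name)

    assert len(groups) == 1               # identical specs collapse into one group
    spec, layer_names = groups.popitem()  # mirrors the popitem() above
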
@@ -2681,9 +2734,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
-        self.may_reinitialize_input_batch(kv_cache_config)

         if self.model_config.is_deepseek_mla:
             kv_caches = self.initialize_kv_cache_tensors_deepseek(
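
The reordering in the last hunk matters because backend setup consumes `kv_cache_config.kv_cache_groups` (one attention group per entry, as in upstream vLLM), so the encoder-only group must already be appended before the backends are built, and the input batch is rebuilt first so its per-group block tables stay aligned with the groups that actually own cache blocks (which is why the encoder-only branch in `_prepare_inputs` never indexes `input_batch.block_table`). A hedged sketch of the resulting call order, with `runner` standing in for the NPU model runner and all bodies elided:

    def initialize_kv_cache_order(runner, kv_cache_config) -> None:
        runner.kv_cache_config = kv_cache_config
        runner.may_reinitialize_input_batch(kv_cache_config)     # size per-group block tables first
        runner.may_add_encoder_only_layers_to_kv_cache_config()  # then append the cache-less group
        runner.initialize_attn_backend(kv_cache_config)          # walks kv_cache_groups last
        runner.use_hybrid_blocks = len(runner.attn_groups) > 1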