 from vllm.utils.jsontree import json_map_leaves
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
-    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
+    AttentionCGSupport, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
+                                                AscendAttentionState)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
 # yapf: disable
@@ -2653,6 +2655,7 @@ def _build_dummy_attn_metadata(
         max_query_len: int,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
+        num_scheduled_tokens: Optional[np.ndarray] = None,
     ) -> Optional[dict[str, Any]]:
         attn_metadata: Optional[dict[str, Any]] = None

@@ -2666,6 +2669,13 @@ def _build_dummy_attn_metadata(
             self.seq_lens_np[:num_reqs] = seq_lens
             self.seq_lens_np[num_reqs:] = 0

+            cu_num_tokens, arange = self._get_cumsum_and_arange(
+                num_scheduled_tokens)
+            query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
+                self.device).to(torch.int32)
+            self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
+            self.query_start_loc_cpu[1:num_reqs + 1] = torch.from_numpy(cu_num_tokens)
+
             num_computed_tokens_cpu = (
                 self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])

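For reference, a minimal standalone sketch of what the `query_start_loc` construction in the hunk above computes, with made-up token counts and a plain `np.cumsum` standing in for `self._get_cumsum_and_arange` (assumed here to return the running totals as a NumPy array):

```python
import numpy as np
import torch

# Hypothetical per-request scheduled token counts for a dummy batch of 3 requests.
num_scheduled_tokens = np.array([4, 2, 3], dtype=np.int32)

# Running totals give each request's end offset in the flattened token stream.
cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=np.int32)  # [4, 6, 9]

# query_start_loc has num_reqs + 1 entries; index 0 stays 0 so that
# query_start_loc[i + 1] - query_start_loc[i] is request i's token count.
num_reqs = len(num_scheduled_tokens)
query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int32)
query_start_loc[1:] = torch.from_numpy(cu_num_tokens)
print(query_start_loc)  # tensor([0, 4, 6, 9], dtype=torch.int32)
```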
@@ -2722,12 +2732,35 @@ def _build_dummy_attn_metadata(
                     self.speculative_config.method == "deepseek_mtp":
                 attn_state = AscendAttentionState.SpecDecoding

+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                             1],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                seq_lens=self.seq_lens_cpu[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=num_tokens,
+                block_table_tensor=block_table_tensor[:num_reqs],
+                slot_mapping=slot_mapping,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
+                max_query_len=max_query_len,
+                max_seq_len=seq_lens)
+
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 builder = attn_group.get_metadata_builder()
-                attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata, attn_state, self.get_model())
+                if isinstance(builder, AscendAttentionMetadataBuilder):
+                    attn_metadata_full_attention = builder.build_for_graph_capture(
+                        common_attn_metadata, attn_state, self.get_model())
+                elif isinstance(builder, GDNAttentionMetadataBuilder):
+                    attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
+                        common_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
-                    attn_metadata[layer_name] = attn_metadata_i
+                    if "linear_attn" in layer_name:
+                        attn_metadata[
+                            layer_name] = attn_metadata_gdn_attention
+                    else:
+                        attn_metadata[
+                            layer_name] = attn_metadata_full_attention

         return attn_metadata

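As a rough illustration of the per-layer routing added in the hunk above: hybrid models mix full-attention layers with GDN linear-attention layers in the same KV-cache group, so each layer name is mapped to the metadata produced by its matching builder. The sketch below uses placeholder metadata objects and invented layer names purely to show the dispatch rule:

```python
# Placeholder stand-ins for the objects built by the two metadata builders
# in the loop above (hypothetical values, not the real dataclasses).
attn_metadata_full_attention = {"kind": "full_attention"}
attn_metadata_gdn_attention = {"kind": "gdn_linear_attention"}

# Invented layer names for a hybrid model mixing both attention types.
layer_names = [
    "model.layers.0.self_attn.attn",
    "model.layers.1.linear_attn",
]

attn_metadata = {}
for layer_name in layer_names:
    # GDN (linear-attention) layers receive the GDN metadata; every other
    # layer falls back to the full-attention metadata.
    if "linear_attn" in layer_name:
        attn_metadata[layer_name] = attn_metadata_gdn_attention
    else:
        attn_metadata[layer_name] = attn_metadata_full_attention

print(attn_metadata)
```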
@@ -2902,6 +2935,7 @@ def _dummy_run(
                     max_query_len=max_query_len,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
                     force_attention=force_attention,
+                    num_scheduled_tokens=num_scheduled_tokens,
                 )

         need_dummy_logits = (not self.in_profile_run