diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py
index 2f54b5b23db..9fda522021a 100644
--- a/tests/e2e/multicard/test_qwen3_next.py
+++ b/tests/e2e/multicard/test_qwen3_next.py
@@ -36,3 +36,21 @@ def test_models_distributed_Qwen3_NEXT_TP4():
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
+    example_prompts = [
+        "Hello, my name is",
+    ] * 4
+    max_tokens = 5
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    enforce_eager=False,
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [1, 8, 24, 48, 60]
+                    }) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 41476ccc3d0..c8d1edee4bf 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -192,8 +192,10 @@ def __call__(self, *args, **kwargs):
 
 def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
-    # FIXME: Behold! We are using a temporary hack here to update the args
-    # for each layer's attention op in the graph.
+    # For Qwen3-next, since the kv_cache_config has already categorized
+    # linear_attn and self_attn, the attn_metadata is first arranged with
+    # self_attn followed by linear_attn. Therefore, using zip directly
+    # filters out the update operations for linear_attn.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,
@@ -289,9 +291,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
 
 def update_attn_dcp_pcp_params(update_stream, forward_context,
                                runtime_shape):
-    graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
+    graph_params = get_graph_params()
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,
diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index fda33df1798..f92696c9cc6 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -31,6 +31,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 
+import numpy as np
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
@@ -178,6 +179,7 @@ def _build_dummy_attn_metadata(
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -186,7 +188,7 @@ def _build_dummy_attn_metadata(
         if with_prefill or self.enable_shared_expert_dp:
             attn_metadata = super()._build_dummy_attn_metadata(
                 with_prefill, num_reqs, num_tokens, max_query_len,
-                aclgraph_runtime_mode, force_attention)
+                num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index faa222dd8ba..71b19d087f9 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -76,7 +76,8 @@
 from vllm.utils.jsontree import json_map_leaves
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
-    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
+    AttentionCGSupport, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -107,7 +108,8 @@
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
+                                                AscendAttentionState)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
 # yapf: disable
@@ -2651,6 +2653,7 @@ def _build_dummy_attn_metadata(
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -2666,6 +2669,14 @@ def _build_dummy_attn_metadata(
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
 
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)
+        query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
+            self.device).to(torch.int32)
+        self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
+        self.query_start_loc_cpu[1:num_reqs +
+                                 1] = torch.Tensor(cu_num_tokens)
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
 
@@ -2722,12 +2733,35 @@ def _build_dummy_attn_metadata(
                     self.speculative_config.method == "deepseek_mtp":
                 attn_state = AscendAttentionState.SpecDecoding
 
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                             1],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                seq_lens=self.seq_lens_cpu[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=num_tokens,
+                block_table_tensor=block_table_tensor[:num_reqs],
+                slot_mapping=slot_mapping,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
+                max_query_len=max_query_len,
+                max_seq_len=seq_lens)
+
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 builder = attn_group.get_metadata_builder()
-                attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata, attn_state, self.get_model())
+                if isinstance(builder, AscendAttentionMetadataBuilder):
+                    attn_metadata_full_attention = builder.build_for_graph_capture(
+                        common_attn_metadata, attn_state, self.get_model())
+                elif isinstance(builder, GDNAttentionMetadataBuilder):
+                    attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
+                        common_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
-                    attn_metadata[layer_name] = attn_metadata_i
+                    if "linear_attn" in layer_name:
+                        attn_metadata[
+                            layer_name] = attn_metadata_gdn_attention
+                    else:
+                        attn_metadata[
+                            layer_name] = attn_metadata_full_attention
 
         return attn_metadata
 
@@ -2902,6 +2936,7 @@ def _dummy_run(
                 max_query_len=max_query_len,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 force_attention=force_attention,
+                num_scheduled_tokens=num_scheduled_tokens,
             )
 
             need_dummy_logits = (not self.in_profile_run
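Note for reviewers: the core change in _build_dummy_attn_metadata above is that a hybrid KV-cache group now builds two dummy metadata objects (full attention and GDN linear attention) and routes them per layer by name. Below is a minimal standalone sketch of that routing only; FullAttentionMetadata, GDNAttentionMetadata, and route_dummy_metadata are illustrative stand-ins, not the actual vLLM classes or builders.

from dataclasses import dataclass


@dataclass
class FullAttentionMetadata:
    # Stand-in for the object produced by AscendAttentionMetadataBuilder.
    num_reqs: int


@dataclass
class GDNAttentionMetadata:
    # Stand-in for the object produced by GDNAttentionMetadataBuilder.
    num_reqs: int


def route_dummy_metadata(layer_names, num_reqs):
    # Mirror the per-layer dispatch added in the diff: layers whose name
    # contains "linear_attn" receive GDN (linear-attention) metadata,
    # every other layer receives regular full-attention metadata.
    full_md = FullAttentionMetadata(num_reqs=num_reqs)
    gdn_md = GDNAttentionMetadata(num_reqs=num_reqs)
    return {
        name: gdn_md if "linear_attn" in name else full_md
        for name in layer_names
    }


if __name__ == "__main__":
    layers = ["model.layers.0.self_attn.attn", "model.layers.1.linear_attn"]
    print(route_dummy_metadata(layers, num_reqs=4))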