Commit 7c66385

Author: wangxiaoxin-sherie
Parent: 0b9b6d7

QWEN3-NEXT: support FULL_DECODE_ONLY mode.

Signed-off-by: wangxiaoxin-sherie <[email protected]>

File tree

4 files changed: +60 additions, -8 deletions

tests/e2e/multicard/test_qwen3_next.py (18 additions, 0 deletions)

@@ -36,3 +36,21 @@ def test_models_distributed_Qwen3_NEXT_TP4():
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
+    example_prompts = [
+        "Hello, my name is",
+    ] * 4
+    max_tokens = 5
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    enforce_eager=False,
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [1, 8, 24, 48, 60]
+                    }) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
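The new test drives FULL_DECODE_ONLY graph capture end to end. For context, the same configuration can be reproduced outside the test harness; a minimal sketch using vLLM's offline API, with the model name and sizes taken from the test and greedy decoding emulated via temperature 0.0:

```python
# Sketch: enable decode-only full-graph capture via compilation_config,
# mirroring what the test passes through VllmRunner. Assumes a 4-NPU node
# able to hold Qwen3-Next-80B-A3B at TP4.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,
    max_model_len=4096,
    gpu_memory_utilization=0.8,
    compilation_config={
        # Capture full graphs for pure-decode batches only; prefill and
        # mixed batches fall back to eager execution.
        "cudagraph_mode": "FULL_DECODE_ONLY",
        # Batch sizes to capture; other sizes are padded up or run eagerly.
        "cudagraph_capture_sizes": [1, 8, 24, 48, 60],
    },
)
outputs = llm.generate(["Hello, my name is"] * 4,
                       SamplingParams(temperature=0.0, max_tokens=5))
```

Note `enforce_eager=False` in the new test: graphs are only captured when eager enforcement is off, which is why the existing eager TP4 test and this one exercise different code paths.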

vllm_ascend/compilation/acl_graph.py (0 additions, 2 deletions)

@@ -290,8 +290,6 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
 
 def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
-    # FIXME: Behold! We are using a temporary hack here to update the args
-    # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,

vllm_ascend/torchair/torchair_model_runner.py (3 additions, 1 deletion)

@@ -31,6 +31,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 
+import numpy as np
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform

@@ -180,13 +181,14 @@ def _build_dummy_attn_metadata(
         max_query_len: int,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
+        num_scheduled_tokens: np.array = None,
     ) -> Optional[dict[str, Any]]:
         # NOTE: If torchair graph mode and not with_prefill,
         # we can't skip_attn, it will cause graph recompile.
         if with_prefill or self.enable_shared_expert_dp:
             attn_metadata = super()._build_dummy_attn_metadata(
                 with_prefill, num_reqs, num_tokens, max_query_len,
-                aclgraph_runtime_mode, force_attention)
+                aclgraph_runtime_mode, force_attention, num_scheduled_tokens)
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
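A side note on the parameter threaded through both runners: `np.array` is NumPy's constructor function rather than a type, so `Optional[np.ndarray]` is presumably the intended annotation. The sketch below is an editorial illustration of the value the dummy-run path would hand in, not code from this commit; for a pure-decode capture batch, every request schedules exactly one token.

```python
# Editorial sketch (not from this commit): the per-request scheduled-token
# counts that a decode-only dummy batch would pass as num_scheduled_tokens.
from typing import Optional

import numpy as np


def dummy_num_scheduled_tokens(num_reqs: int) -> Optional[np.ndarray]:
    # One token per request in a pure-decode batch; speculative-decoding
    # batches would schedule more than one token per request.
    return np.ones(num_reqs, dtype=np.int32)


print(dummy_num_scheduled_tokens(4))  # [1 1 1 1]
```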

vllm_ascend/worker/model_runner_v1.py (39 additions, 5 deletions)

@@ -76,7 +76,8 @@
 from vllm.utils.jsontree import json_map_leaves
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
-    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
+    AttentionCGSupport, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable

@@ -107,7 +108,8 @@
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
+                                                AscendAttentionState)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
 # yapf: disable
@@ -2653,6 +2655,7 @@ def _build_dummy_attn_metadata(
         max_query_len: int,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
+        num_scheduled_tokens: np.array = None,
     ) -> Optional[dict[str, Any]]:
         attn_metadata: Optional[dict[str, Any]] = None
@@ -2666,6 +2669,13 @@
             self.seq_lens_np[:num_reqs] = seq_lens
             self.seq_lens_np[num_reqs:] = 0
 
+            cu_num_tokens, arange = self._get_cumsum_and_arange(
+                num_scheduled_tokens)
+            query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
+                self.device).to(torch.int32)
+            self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
+            self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens).to(torch.int32)
+
             num_computed_tokens_cpu = (
                 self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
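The added block seeds `query_start_loc` for graph capture, since the GDN builder below consumes real query offsets rather than zeros. The transformation is just a cumulative sum over per-request token counts; a standalone sketch, assuming `_get_cumsum_and_arange` behaves like the upstream helper (cumulative sum plus an intra-request arange):

```python
# Sketch of the query_start_loc update: per-request token counts become
# cumulative offsets into the flattened token dimension.
import numpy as np
import torch

num_scheduled_tokens = np.array([1, 1, 1, 1], dtype=np.int32)  # 4 decode reqs

# Request i's tokens occupy [cu_num_tokens[i-1], cu_num_tokens[i]).
cu_num_tokens = np.cumsum(num_scheduled_tokens)               # [1 2 3 4]
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])

# query_start_loc keeps a leading zero, so slots [1:num_reqs+1] receive the
# cumulative sum, on both the device copy and the CPU copy.
query_start_loc = torch.zeros(len(num_scheduled_tokens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.from_numpy(cu_num_tokens)
print(query_start_loc.tolist())  # [0, 1, 2, 3, 4]
```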

@@ -2722,12 +2732,35 @@
                 self.speculative_config.method == "deepseek_mtp":
                 attn_state = AscendAttentionState.SpecDecoding
 
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                             1],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                seq_lens=self.seq_lens_cpu[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=num_tokens,
+                block_table_tensor=block_table_tensor[:num_reqs],
+                slot_mapping=slot_mapping,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
+                max_query_len=max_query_len,
+                max_seq_len=seq_lens)
+
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 builder = attn_group.get_metadata_builder()
-                attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata, attn_state, self.get_model())
+                if isinstance(builder, AscendAttentionMetadataBuilder):
+                    attn_metadata_full_attention = builder.build_for_graph_capture(
+                        common_attn_metadata, attn_state, self.get_model())
+                elif isinstance(builder, GDNAttentionMetadataBuilder):
+                    attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
+                        common_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
-                    attn_metadata[layer_name] = attn_metadata_i
+                    if "linear_attn" in layer_name:
+                        attn_metadata[
+                            layer_name] = attn_metadata_gdn_attention
+                    else:
+                        attn_metadata[
+                            layer_name] = attn_metadata_full_attention
 
         return attn_metadata
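This routing is the heart of the change: Qwen3-Next interleaves Gated DeltaNet ("linear attention") blocks with full-attention blocks, so one dummy metadata object no longer fits every layer. Each builder type produces its own metadata, and layers pick one by name. A stripped-down sketch of the dispatch pattern with stand-in builder classes (the two real builders are AscendAttentionMetadataBuilder and GDNAttentionMetadataBuilder; everything else here is illustrative):

```python
# Stand-ins for the two metadata builders; only the dispatch pattern matters.
class FullAttnBuilder:
    def build_for_graph_capture(self) -> str:
        return "full-attention metadata"


class LinearAttnBuilder:
    def build_for_cudagraph_capture(self) -> str:
        return "gdn metadata"


def build_layer_metadata(builders: list, layer_names: list[str]) -> dict:
    full_md = gdn_md = None
    for builder in builders:
        if isinstance(builder, FullAttnBuilder):
            full_md = builder.build_for_graph_capture()
        elif isinstance(builder, LinearAttnBuilder):
            gdn_md = builder.build_for_cudagraph_capture()
    # Mirror the commit's name check: linear-attention layers take the GDN
    # metadata, every other layer takes the full-attention metadata.
    return {
        name: gdn_md if "linear_attn" in name else full_md
        for name in layer_names
    }


layers = ["model.layers.0.linear_attn", "model.layers.3.self_attn"]
print(build_layer_metadata([FullAttnBuilder(), LinearAttnBuilder()], layers))
# {'model.layers.0.linear_attn': 'gdn metadata',
#  'model.layers.3.self_attn': 'full-attention metadata'}
```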

@@ -2902,6 +2935,7 @@ def _dummy_run(
             max_query_len=max_query_len,
             aclgraph_runtime_mode=aclgraph_runtime_mode,
             force_attention=force_attention,
+            num_scheduled_tokens=num_scheduled_tokens,
         )
 
         need_dummy_logits = (not self.in_profile_run
