18 changes: 18 additions & 0 deletions tests/e2e/multicard/test_qwen3_next.py
@@ -36,3 +36,21 @@ def test_models_distributed_Qwen3_NEXT_TP4():
distributed_executor_backend="mp",
enforce_eager=True) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)


def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
example_prompts = [
"Hello, my name is",
] * 4
max_tokens = 5
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
max_model_len=4096,
gpu_memory_utilization=0.8,
distributed_executor_backend="mp",
enforce_eager=False,
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
}) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
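For reference, a rough offline-inference equivalent of the new test's configuration is sketched below; the constructor arguments are assumed to pass through to the engine the same way the `VllmRunner` kwargs do, and exact names and accepted types may vary across vLLM versions.

```python
# Sketch only: mirrors the test configuration with vLLM's offline LLM entrypoint.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,
    max_model_len=4096,
    gpu_memory_utilization=0.8,
    distributed_executor_backend="mp",
    enforce_eager=False,
    # Capture full graphs for decode-only batches at the listed batch sizes.
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",
        "cudagraph_capture_sizes": [1, 8, 24, 48, 60],
    },
)

# Greedy decoding, matching generate_greedy(example_prompts, max_tokens=5).
outputs = llm.generate(["Hello, my name is"] * 4,
                       SamplingParams(temperature=0.0, max_tokens=5))
```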
8 changes: 5 additions & 3 deletions vllm_ascend/compilation/acl_graph.py
@@ -192,8 +192,10 @@ def __call__(self, *args, **kwargs):

def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
# For Qwen3-Next, the kv_cache_config already separates linear_attn and
# self_attn, so attn_metadata is ordered with the self_attn layers first,
# followed by linear_attn. Zipping directly therefore skips the update
# operations for the linear_attn layers.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
@@ -289,9 +291,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,


def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
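A minimal sketch of the `zip` behaviour the new comment relies on, with hypothetical layer names and assuming `graph_params` only carries entries for the self_attn layers:

```python
# attn_metadata is ordered self_attn first, then linear_attn.
attn_metadata_keys = ["layers.0.self_attn", "layers.1.self_attn", "layers.2.linear_attn"]
graph_param_keys = ["layers.0.self_attn", "layers.1.self_attn"]

# zip stops at the shorter sequence, so the trailing linear_attn entry is
# never paired with a graph param and its update is simply skipped.
for metadata_key, param_key in zip(attn_metadata_keys, graph_param_keys):
    print(metadata_key, param_key)
```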
4 changes: 3 additions & 1 deletion vllm_ascend/torchair/torchair_model_runner.py
@@ -31,6 +31,7 @@
from vllm.forward_context import get_forward_context
from vllm.logger import logger

import numpy as np
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
@@ -178,6 +179,7 @@ def _build_dummy_attn_metadata(
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -186,7 +188,7 @@
if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_dummy_attn_metadata(
with_prefill, num_reqs, num_tokens, max_query_len,
aclgraph_runtime_mode, force_attention)
num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
45 changes: 40 additions & 5 deletions vllm_ascend/worker/model_runner_v1.py
@@ -76,7 +76,8 @@
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
AttentionCGSupport, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
@@ -107,7 +108,8 @@
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
AscendAttentionState)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
# yapf: disable
@@ -2651,6 +2653,7 @@ def _build_dummy_attn_metadata(
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -2666,6 +2669,14 @@
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0

cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
self.device).to(torch.int32)
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)

num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
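For reference, the `query_start_loc` construction in the hunk above amounts to the following; `_get_cumsum_and_arange` is assumed to return the running sum of scheduled tokens (plus a per-token arange that is unused here):

```python
import numpy as np

# Hypothetical dummy decode batch: 4 requests, 1 scheduled token each.
num_scheduled_tokens = np.array([1, 1, 1, 1], dtype=np.int32)

cu_num_tokens = np.cumsum(num_scheduled_tokens)          # [1, 2, 3, 4]

# query_start_loc keeps a leading 0; positions 1..num_reqs get the cumsum,
# giving per-request token offsets for the captured graph.
query_start_loc = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
query_start_loc[1:] = cu_num_tokens                       # [0, 1, 2, 3, 4]
```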

@@ -2722,12 +2733,35 @@ def _build_dummy_attn_metadata(
self.speculative_config.method == "deepseek_mtp":
attn_state = AscendAttentionState.SpecDecoding

common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)

for attn_group in self.attn_groups[kv_cache_group_id]:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata, attn_state, self.get_model())
if isinstance(builder, AscendAttentionMetadataBuilder):
attn_metadata_full_attention = builder.build_for_graph_capture(
common_attn_metadata, attn_state, self.get_model())
elif isinstance(builder, GDNAttentionMetadataBuilder):
attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
common_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if "linear_attn" in layer_name:
attn_metadata[
layer_name] = attn_metadata_gdn_attention
else:
attn_metadata[
layer_name] = attn_metadata_full_attention

return attn_metadata

@@ -2902,6 +2936,7 @@ def _dummy_run(
max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
num_scheduled_tokens=num_scheduled_tokens,
)

need_dummy_logits = (not self.in_profile_run