Commit f3f82da

[BugFix] Adapted Qwen3-Next to v0.11.2
Signed-off-by: drslark <[email protected]>
1 parent: 84d7f5a

File tree

2 files changed (+21, -13 lines)

tests/e2e/multicard/test_qwen3_next.py

Lines changed: 15 additions & 7 deletions
@@ -24,7 +24,6 @@
 import os
 from unittest.mock import patch
 
-import pytest
 from modelscope import snapshot_download  # type: ignore
 
 from tests.e2e.conftest import VllmRunner
@@ -64,7 +63,6 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
     del vllm_model
 
 
-@pytest.mark.skip(reason="Fix me, the accuracy is not correct")
 def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
     example_prompts = [
         "Hello, my name is",
@@ -74,11 +72,20 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
     ]
     max_tokens = 20
 
-    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                    tensor_parallel_size=4,
-                    max_model_len=4096,
-                    gpu_memory_utilization=0.8,
-                    distributed_executor_backend="mp") as vllm_model:
+    with VllmRunner(
+            "Qwen/Qwen3-Next-80B-A3B-Instruct",
+            tensor_parallel_size=4,
+            max_model_len=4096,
+            gpu_memory_utilization=0.8,
+            distributed_executor_backend="mp",
+            enforce_eager=True,
+            additional_config={
+                "ascend_scheduler_config": {
+                    "enabled": True,
+                    "enable_chunked_prefill": False
+                }
+            },
+    ) as vllm_model:
         ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 
@@ -87,6 +94,7 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
                     max_model_len=4096,
                     gpu_memory_utilization=0.8,
                     distributed_executor_backend="mp",
+                    enforce_eager=True,
                     additional_config={
                         "ascend_scheduler_config": {
                             "enabled": True,

vllm_ascend/models/qwen3_next.py

Lines changed: 6 additions & 6 deletions
@@ -675,7 +675,7 @@ def _forward_core(
         initial_state[~has_initial_state, ...] = 0
 
         batch_size = initial_state.shape[0]
-        core_attn_out = []
+        temp_core_attn_out = []
         last_recurrent_state = []
 
         for b_idx in range(batch_size):
@@ -702,18 +702,18 @@ def _forward_core(
                 use_qk_l2norm_in_kernel=True,
             )
 
-            core_attn_out.append(cur_core_attn_out_non_spec)
+            temp_core_attn_out.append(cur_core_attn_out_non_spec)
             last_recurrent_state.append(cur_last_recurrent_state)
 
-        tar_dtype = core_attn_out[0].dtype
-        tar_device = core_attn_out[0].device
-        tar_shape = list(core_attn_out[0].shape)
+        tar_dtype = temp_core_attn_out[0].dtype
+        tar_device = temp_core_attn_out[0].device
+        tar_shape = list(temp_core_attn_out[0].shape)
         tar_shape[1] = non_spec_query_start_loc[-1]
         core_attn_out_non_spec = torch.empty(tar_shape,
                                              dtype=tar_dtype,
                                              device=tar_device)
         for b_idx in range(batch_size):
-            cur_core_attn_out = core_attn_out[b_idx]
+            cur_core_attn_out = temp_core_attn_out[b_idx]
             start, end = non_spec_query_start_loc[
                 b_idx], non_spec_query_start_loc[b_idx + 1]
             core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
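
Note: the rename presumably keeps the temporary per-sequence Python list distinct from the core_attn_out name bound later in the same function, so the buffer no longer shadows it. Below is a self-contained sketch of the gather-then-scatter pattern this loop implements; the shapes, offsets, and the final assert are illustrative assumptions, not taken from the file.

import torch

# Cumulative token offsets per sequence (illustrative values only).
non_spec_query_start_loc = torch.tensor([0, 3, 7, 12])
hidden = 8

# Per-sequence attention outputs accumulated in a plain list,
# mirroring the renamed temp_core_attn_out.
temp_core_attn_out = [
    torch.randn(1, int(end - start), hidden)
    for start, end in zip(non_spec_query_start_loc[:-1],
                          non_spec_query_start_loc[1:])
]

# Preallocate the packed output, then scatter each sequence into its token slice.
tar_shape = list(temp_core_attn_out[0].shape)
tar_shape[1] = int(non_spec_query_start_loc[-1])
core_attn_out_non_spec = torch.empty(tar_shape,
                                     dtype=temp_core_attn_out[0].dtype,
                                     device=temp_core_attn_out[0].device)
for b_idx in range(len(temp_core_attn_out)):
    start, end = non_spec_query_start_loc[b_idx], non_spec_query_start_loc[b_idx + 1]
    core_attn_out_non_spec[:, start:end, ...] = temp_core_attn_out[b_idx]

assert core_attn_out_non_spec.shape == (1, 12, hidden)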
