Commit 6e44756

[Bugfix][Qwen3-Next] Fix Qwen3-Next with the latest maintained vllm commit id
Signed-off-by: MengqingCao <[email protected]>
Parent: 9e150e5

File tree

1 file changed:
vllm_ascend/models/qwen3_next.py

Lines changed: 28 additions & 9 deletions

@@ -51,6 +51,8 @@
 from vllm.transformers_utils.configs import Qwen3NextConfig
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
+from vllm_ascend.utils import vllm_version_is
+
 from vllm.model_executor.models.qwen3_next import ( # isort: skip
     Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
     Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
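
The only new import is vllm-ascend's own `vllm_version_is` helper, used below to gate the two metadata layouts. As a rough illustration only (not the actual implementation in vllm_ascend/utils.py), such a gate can be as simple as comparing the installed vLLM version string against a target:

    from importlib.metadata import version

    def vllm_version_is(target: str) -> bool:
        # Illustrative sketch: report whether the installed vllm
        # package is exactly the given version.
        return version("vllm") == target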

@@ -201,7 +203,13 @@ def _forward(
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         spec_sequence_masks = attn_metadata.spec_sequence_masks
-        spec_token_masks = attn_metadata.spec_token_masks
+        if vllm_version_is("0.11.0"):
+            spec_token_masks = attn_metadata.spec_token_masks
+            if spec_token_masks is not None:
+                spec_token_masks = spec_token_masks[:num_actual_tokens]
+        else:
+            spec_token_indx = attn_metadata.spec_token_indx
+            non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
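
Why the gate: on vLLM 0.11.0, GDNAttentionMetadata exposes a boolean `spec_token_masks` over the tokens, which the old code truncated to `num_actual_tokens` later on; on newer vLLM commits it instead exposes the precomputed index tensors `spec_token_indx` and `non_spec_token_indx`. The two layouts carry the same information. A minimal sketch of the correspondence with a made-up mask (names and shapes here are illustrative, not metadata-building code):

    import torch

    # Hypothetical boolean mask over actual tokens: True = speculative token.
    spec_token_masks = torch.tensor([True, False, True, False, False])

    # The index-tensor layout can be derived from the mask: positions of
    # speculative and non-speculative tokens respectively.
    spec_token_indx = torch.nonzero(spec_token_masks).squeeze(-1)
    non_spec_token_indx = torch.nonzero(~spec_token_masks).squeeze(-1)

    print(spec_token_indx)      # tensor([0, 2])
    print(non_spec_token_indx)  # tensor([1, 3, 4])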

@@ -216,8 +224,6 @@ def _forward(
 
         # 1. Set up dimensions for reshapes later
         projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
-        if spec_token_masks is not None:
-            spec_token_masks = spec_token_masks[:num_actual_tokens]
         projected_states_qkvz, projected_states_ba = torch.split(
             projected_states,
             [
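
This hunk drops the old truncation of `spec_token_masks`: on the 0.11.0 path that slice now happens up front where the mask is read from the metadata (see the first branch of the gate above), and the index-tensor path never needs it.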

@@ -293,10 +299,16 @@ def _forward(
                 g_non_spec = None
                 beta_non_spec = None
             else:
-                g_spec = g[:, spec_token_masks]
-                beta_spec = beta[:, spec_token_masks]
-                g_non_spec = g[:, ~spec_token_masks]
-                beta_non_spec = beta[:, ~spec_token_masks]
+                if vllm_version_is("0.11.0"):
+                    g_spec = g[:, spec_token_masks]
+                    beta_spec = beta[:, spec_token_masks]
+                    g_non_spec = g[:, ~spec_token_masks]
+                    beta_non_spec = beta[:, ~spec_token_masks]
+                else:
+                    g_spec = g.index_select(1, spec_token_indx)
+                    beta_spec = beta.index_select(1, spec_token_indx)
+                    g_non_spec = g.index_select(1, non_spec_token_indx)
+                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
         else:
             g_spec = None
             beta_spec = None
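
Boolean-mask indexing along dim 1 and `index_select` with the corresponding positions return the same rows, so the two branches are behaviorally equivalent given consistent metadata. A small self-contained check with toy shapes (illustrative only):

    import torch

    g = torch.arange(12, dtype=torch.float32).reshape(1, 6, 2)
    mask = torch.tensor([True, False, True, True, False, False])
    indx = torch.nonzero(mask).squeeze(-1)

    masked = g[:, mask]                 # 0.11.0 path: boolean-mask indexing
    selected = g.index_select(1, indx)  # newer path: gather by index tensor
    assert torch.equal(masked, selected)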

@@ -404,8 +416,15 @@ def _forward(
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out[:, spec_token_masks] = core_attn_out_spec
-            core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            if vllm_version_is("0.11.0"):
+                core_attn_out[:, spec_token_masks] = core_attn_out_spec
+                core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            else:
+                core_attn_out.index_copy_(1, spec_token_indx,
+                                          core_attn_out_spec)
+                core_attn_out.index_copy_(1, non_spec_token_indx,
+                                          core_attn_out_non_spec)
+
         elif spec_sequence_masks is not None:
             core_attn_out = core_attn_out_spec
         else:
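
The write-back side mirrors the read side: in-place `index_copy_` along dim 1 scatters the rows of its source into the positions named by the index tensor, matching the old boolean-mask assignment. A toy equivalence check (illustrative tensors only):

    import torch

    out = torch.zeros(1, 5, 3)
    mask = torch.tensor([True, False, True, False, False])
    indx = torch.nonzero(mask).squeeze(-1)
    src = torch.ones(1, 2, 3)

    out_mask = out.clone()
    out_mask[:, mask] = src             # 0.11.0 path: mask assignment

    out_indx = out.clone()
    out_indx.index_copy_(1, indx, src)  # newer path: indexed scatter
    assert torch.equal(out_mask, out_indx)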
