 from vllm.transformers_utils.configs import Qwen3NextConfig
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
+from vllm_ascend.utils import vllm_version_is
+
 from vllm.model_executor.models.qwen3_next import (  # isort: skip
     Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
     Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
@@ -201,7 +203,11 @@ def _forward(
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         spec_sequence_masks = attn_metadata.spec_sequence_masks
-        spec_token_masks = attn_metadata.spec_token_masks
+        if vllm_version_is("0.11.0"):
+            spec_token_masks = attn_metadata.spec_token_masks
+        else:
+            spec_token_indx = attn_metadata.spec_token_indx
+            non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -216,8 +222,9 @@ def _forward(
 
         # 1. Set up dimensions for reshapes later
         projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
-        if spec_token_masks is not None:
-            spec_token_masks = spec_token_masks[:num_actual_tokens]
+        if vllm_version_is("0.11.0"):
+            if spec_token_masks is not None:
+                spec_token_masks = spec_token_masks[:num_actual_tokens]
         projected_states_qkvz, projected_states_ba = torch.split(
             projected_states,
             [
@@ -242,8 +249,13 @@ def _forward(
                 mixed_qkv_spec = mixed_qkv
                 mixed_qkv_non_spec = None
             else:
-                mixed_qkv_spec = mixed_qkv[spec_token_masks]
-                mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
+                if vllm_version_is("0.11.0"):
+                    mixed_qkv_spec = mixed_qkv[spec_token_masks]
+                    mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
+                else:
+                    mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
+                    mixed_qkv_non_spec = mixed_qkv.index_select(
+                        0, non_spec_token_indx)
         else:
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv
@@ -293,10 +305,16 @@ def _forward(
                 g_non_spec = None
                 beta_non_spec = None
             else:
-                g_spec = g[:, spec_token_masks]
-                beta_spec = beta[:, spec_token_masks]
-                g_non_spec = g[:, ~spec_token_masks]
-                beta_non_spec = beta[:, ~spec_token_masks]
+                if vllm_version_is("0.11.0"):
+                    g_spec = g[:, spec_token_masks]
+                    beta_spec = beta[:, spec_token_masks]
+                    g_non_spec = g[:, ~spec_token_masks]
+                    beta_non_spec = beta[:, ~spec_token_masks]
+                else:
+                    g_spec = g.index_select(1, spec_token_indx)
+                    beta_spec = beta.index_select(1, spec_token_indx)
+                    g_non_spec = g.index_select(1, non_spec_token_indx)
+                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
         else:
             g_spec = None
             beta_spec = None
@@ -404,8 +422,14 @@ def _forward(
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out[:, spec_token_masks] = core_attn_out_spec
-            core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            if vllm_version_is("0.11.0"):
+                core_attn_out[:, spec_token_masks] = core_attn_out_spec
+                core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            else:
+                core_attn_out.index_copy_(1, spec_token_indx,
+                                          core_attn_out_spec)
+                core_attn_out.index_copy_(1, non_spec_token_indx,
+                                          core_attn_out_non_spec)
         elif spec_sequence_masks is not None:
             core_attn_out = core_attn_out_spec
         else:
@@ -673,4 +697,4 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_physical_experts = example_layer.n_physical_experts
         self.num_local_physical_experts = example_layer.n_local_physical_experts
         self.num_routed_experts = example_layer.n_routed_experts
-        self.num_redundant_experts = example_layer.n_redundant_experts
+        self.num_redundant_experts = example_layer.n_redundant_experts
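
The hunks above gate the speculative/non-speculative token split on the vLLM version: 0.11.0 exposes a boolean `spec_token_masks` over the token dimension, while newer versions expose precomputed `spec_token_indx` / `non_spec_token_indx` tensors, so boolean-mask indexing and masked assignment are replaced with `index_select` and `index_copy_`. A minimal sketch of why the two paths are interchangeable, using toy tensors and illustrative shapes rather than the real GDN attention metadata:

```python
import torch

# Illustrative stand-ins only; shapes and values are not taken from the model.
num_tokens, hidden = 6, 4
mixed_qkv = torch.randn(num_tokens, hidden)

# vLLM 0.11.0 path: a boolean mask over the token dimension.
spec_token_masks = torch.tensor([True, False, True, False, False, True])

# Newer path: index tensors enumerating the same positions.
spec_token_indx = spec_token_masks.nonzero(as_tuple=True)[0]
non_spec_token_indx = (~spec_token_masks).nonzero(as_tuple=True)[0]

# Gather: boolean indexing and index_select pick out the same rows.
assert torch.equal(mixed_qkv[spec_token_masks],
                   mixed_qkv.index_select(0, spec_token_indx))
assert torch.equal(mixed_qkv[~spec_token_masks],
                   mixed_qkv.index_select(0, non_spec_token_indx))

# Scatter back: masked assignment and index_copy_ write the same slots.
spec_part = torch.randn(1, len(spec_token_indx), hidden)
non_spec_part = torch.randn(1, len(non_spec_token_indx), hidden)

out_mask = torch.empty(1, num_tokens, hidden)
out_mask[:, spec_token_masks] = spec_part
out_mask[:, ~spec_token_masks] = non_spec_part

out_indx = torch.empty(1, num_tokens, hidden)
out_indx.index_copy_(1, spec_token_indx, spec_part)
out_indx.index_copy_(1, non_spec_token_indx, non_spec_part)

assert torch.equal(out_mask, out_indx)
```

One likely motivation, stated here as an assumption rather than something the diff itself claims, is that explicit index tensors keep the gather/scatter expressed as fixed-rank index operations, which tends to be friendlier to graph capture and NPU kernels than data-dependent boolean masking.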