@@ -51,6 +51,8 @@
 from vllm.transformers_utils.configs import Qwen3NextConfig
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
+from vllm_ascend.utils import vllm_version_is
+
 from vllm.model_executor.models.qwen3_next import (  # isort: skip
     Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
     Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
@@ -201,7 +203,13 @@ def _forward(
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         spec_sequence_masks = attn_metadata.spec_sequence_masks
-        spec_token_masks = attn_metadata.spec_token_masks
+        if vllm_version_is("0.11.0"):
+            spec_token_masks = attn_metadata.spec_token_masks
+            if spec_token_masks is not None:
+                spec_token_masks = spec_token_masks[:num_actual_tokens]
+        else:
+            spec_token_indx = attn_metadata.spec_token_indx
+            non_spec_token_indx = attn_metadata.non_spec_token_indx
         spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor  # noqa: E501
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -216,8 +224,6 @@ def _forward(
 
         # 1. Set up dimensions for reshapes later
         projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
-        if spec_token_masks is not None:
-            spec_token_masks = spec_token_masks[:num_actual_tokens]
         projected_states_qkvz, projected_states_ba = torch.split(
             projected_states,
             [
@@ -293,10 +299,16 @@ def _forward(
                 g_non_spec = None
                 beta_non_spec = None
             else:
-                g_spec = g[:, spec_token_masks]
-                beta_spec = beta[:, spec_token_masks]
-                g_non_spec = g[:, ~spec_token_masks]
-                beta_non_spec = beta[:, ~spec_token_masks]
+                if vllm_version_is("0.11.0"):
+                    g_spec = g[:, spec_token_masks]
+                    beta_spec = beta[:, spec_token_masks]
+                    g_non_spec = g[:, ~spec_token_masks]
+                    beta_non_spec = beta[:, ~spec_token_masks]
+                else:
+                    g_spec = g.index_select(1, spec_token_indx)
+                    beta_spec = beta.index_select(1, spec_token_indx)
+                    g_non_spec = g.index_select(1, non_spec_token_indx)
+                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
         else:
             g_spec = None
             beta_spec = None
@@ -404,8 +416,15 @@ def _forward(
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out[:, spec_token_masks] = core_attn_out_spec
-            core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            if vllm_version_is("0.11.0"):
+                core_attn_out[:, spec_token_masks] = core_attn_out_spec
+                core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
+            else:
+                core_attn_out.index_copy_(1, spec_token_indx,
+                                          core_attn_out_spec)
+                core_attn_out.index_copy_(1, non_spec_token_indx,
+                                          core_attn_out_non_spec)
+
         elif spec_sequence_masks is not None:
             core_attn_out = core_attn_out_spec
         else:
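
Below is a minimal standalone sketch (not part of the patch) of why the two code paths above are equivalent: on vLLM 0.11.0 the GDN attention metadata exposes a boolean spec_token_masks over the token dimension, while newer versions expose the precomputed index tensors spec_token_indx / non_spec_token_indx, so index_select and index_copy_ along dim 1 reproduce the old mask-based gather and scatter. Shapes and tensor names here are illustrative only.

# Standalone sketch, assuming a [heads, tokens, ...] layout as in the diff.
import torch

num_heads, num_tokens, head_dim = 2, 6, 4
g = torch.randn(num_heads, num_tokens)
core_attn_out = torch.zeros(num_heads, num_tokens, head_dim)

# 0.11.0-style boolean mask over the token dimension ...
spec_token_masks = torch.tensor([True, False, True, False, False, True])
# ... and the equivalent index tensors newer versions provide.
spec_token_indx = spec_token_masks.nonzero(as_tuple=True)[0]         # [0, 2, 5]
non_spec_token_indx = (~spec_token_masks).nonzero(as_tuple=True)[0]  # [1, 3, 4]

# Old path: boolean-mask gather on dim 1.
g_spec_old = g[:, spec_token_masks]
g_non_spec_old = g[:, ~spec_token_masks]

# New path: index_select on dim 1 yields identical tensors.
assert torch.equal(g_spec_old, g.index_select(1, spec_token_indx))
assert torch.equal(g_non_spec_old, g.index_select(1, non_spec_token_indx))

# Scatter back: index_copy_ replaces the masked assignments.
core_attn_out_spec = torch.randn(num_heads, spec_token_indx.numel(), head_dim)
core_attn_out_non_spec = torch.randn(num_heads, non_spec_token_indx.numel(),
                                     head_dim)
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)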