Skip to content

Commit 41a63f6

Browse files
committed
fix Qwen3NextGatedDeltaNet
Signed-off-by: Icey <[email protected]>
1 parent 9757e24 commit 41a63f6

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

vllm_ascend/models/qwen3_next.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,8 @@ def _forward(
201201
spec_query_start_loc = attn_metadata.spec_query_start_loc
202202
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
203203
spec_sequence_masks = attn_metadata.spec_sequence_masks
204-
spec_token_masks = attn_metadata.spec_token_masks
204+
spec_token_indx = attn_metadata.spec_token_indx
205+
non_spec_token_indx = attn_metadata.non_spec_token_indx
205206
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
206207
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
207208
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -216,8 +217,6 @@ def _forward(
216217

217218
# 1. Set up dimensions for reshapes later
218219
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
219-
if spec_token_masks is not None:
220-
spec_token_masks = spec_token_masks[:num_actual_tokens]
221220
projected_states_qkvz, projected_states_ba = torch.split(
222221
projected_states,
223222
[
@@ -242,8 +241,9 @@ def _forward(
242241
mixed_qkv_spec = mixed_qkv
243242
mixed_qkv_non_spec = None
244243
else:
245-
mixed_qkv_spec = mixed_qkv[spec_token_masks]
246-
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
244+
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
245+
mixed_qkv_non_spec = mixed_qkv.index_select(
246+
0, non_spec_token_indx)
247247
else:
248248
mixed_qkv_spec = None
249249
mixed_qkv_non_spec = mixed_qkv
@@ -293,10 +293,10 @@ def _forward(
293293
g_non_spec = None
294294
beta_non_spec = None
295295
else:
296-
g_spec = g[:, spec_token_masks]
297-
beta_spec = beta[:, spec_token_masks]
298-
g_non_spec = g[:, ~spec_token_masks]
299-
beta_non_spec = beta[:, ~spec_token_masks]
296+
g_spec = g.index_select(1, spec_token_indx)
297+
beta_spec = beta.index_select(1, spec_token_indx)
298+
g_non_spec = g.index_select(1, non_spec_token_indx)
299+
beta_non_spec = beta.index_select(1, non_spec_token_indx)
300300
else:
301301
g_spec = None
302302
beta_spec = None
@@ -404,8 +404,9 @@ def _forward(
404404
dtype=core_attn_out_non_spec.dtype,
405405
device=core_attn_out_non_spec.device,
406406
)
407-
core_attn_out[:, spec_token_masks] = core_attn_out_spec
408-
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
407+
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
408+
core_attn_out.index_copy_(1, non_spec_token_indx,
409+
core_attn_out_non_spec)
409410
elif spec_sequence_masks is not None:
410411
core_attn_out = core_attn_out_spec
411412
else:

0 commit comments

Comments
 (0)