@@ -201,7 +201,8 @@ def _forward(
201201 spec_query_start_loc = attn_metadata .spec_query_start_loc
202202 non_spec_query_start_loc = attn_metadata .non_spec_query_start_loc
203203 spec_sequence_masks = attn_metadata .spec_sequence_masks
204- spec_token_masks = attn_metadata .spec_token_masks
204+ spec_token_indx = attn_metadata .spec_token_indx
205+ non_spec_token_indx = attn_metadata .non_spec_token_indx
205206 spec_state_indices_tensor = attn_metadata .spec_state_indices_tensor # noqa: E501
206207 non_spec_state_indices_tensor = attn_metadata .non_spec_state_indices_tensor # noqa: E501
207208 self_kv_cache = self .kv_cache [forward_context .virtual_engine ]
@@ -216,8 +217,6 @@ def _forward(
216217
217218 # 1. Set up dimensions for reshapes later
218219 projected_states , _ = self .in_proj (hidden_states [:num_actual_tokens ])
219- if spec_token_masks is not None :
220- spec_token_masks = spec_token_masks [:num_actual_tokens ]
221220 projected_states_qkvz , projected_states_ba = torch .split (
222221 projected_states ,
223222 [
@@ -242,8 +241,9 @@ def _forward(
242241 mixed_qkv_spec = mixed_qkv
243242 mixed_qkv_non_spec = None
244243 else :
245- mixed_qkv_spec = mixed_qkv [spec_token_masks ]
246- mixed_qkv_non_spec = mixed_qkv [~ spec_token_masks ]
244+ mixed_qkv_spec = mixed_qkv .index_select (0 , spec_token_indx )
245+ mixed_qkv_non_spec = mixed_qkv .index_select (
246+ 0 , non_spec_token_indx )
247247 else :
248248 mixed_qkv_spec = None
249249 mixed_qkv_non_spec = mixed_qkv
@@ -293,10 +293,10 @@ def _forward(
293293 g_non_spec = None
294294 beta_non_spec = None
295295 else :
296- g_spec = g [:, spec_token_masks ]
297- beta_spec = beta [:, spec_token_masks ]
298- g_non_spec = g [:, ~ spec_token_masks ]
299- beta_non_spec = beta [:, ~ spec_token_masks ]
296+ g_spec = g . index_select ( 1 , spec_token_indx )
297+ beta_spec = beta . index_select ( 1 , spec_token_indx )
298+ g_non_spec = g . index_select ( 1 , non_spec_token_indx )
299+ beta_non_spec = beta . index_select ( 1 , non_spec_token_indx )
300300 else :
301301 g_spec = None
302302 beta_spec = None
@@ -404,8 +404,9 @@ def _forward(
404404 dtype = core_attn_out_non_spec .dtype ,
405405 device = core_attn_out_non_spec .device ,
406406 )
407- core_attn_out [:, spec_token_masks ] = core_attn_out_spec
408- core_attn_out [:, ~ spec_token_masks ] = core_attn_out_non_spec
407+ core_attn_out .index_copy_ (1 , spec_token_indx , core_attn_out_spec )
408+ core_attn_out .index_copy_ (1 , non_spec_token_indx ,
409+ core_attn_out_non_spec )
409410 elif spec_sequence_masks is not None :
410411 core_attn_out = core_attn_out_spec
411412 else :
0 commit comments