@@ -418,20 +418,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                    rope_dim,
                                    dtype=self.dtype,
                                    device=self.device)
-        # For GQA models.
-        elif not self.vllm_config.model_config.use_mla:
-            self.cos = torch.ones(1,
-                                  self.max_num_tokens,
-                                  1,
-                                  128,
-                                  dtype=self.dtype,
-                                  device=self.device)
-            self.sin = torch.zeros(1,
-                                   self.max_num_tokens,
-                                   1,
-                                   128,
-                                   dtype=self.dtype,
-                                   device=self.device)
         else:
             self.cos = None
             self.sin = None
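For reference, the branch deleted above preallocated full-length rotary-embedding buffers for GQA (non-MLA) models at init time, with a hard-coded head size of 128. A minimal standalone sketch of that allocation pattern (the helper name alloc_rope_buffers is hypothetical, not part of vllm-ascend):

    import torch

    def alloc_rope_buffers(max_num_tokens: int,
                           dtype: torch.dtype,
                           device: torch.device,
                           head_dim: int = 128):
        # BSNH layout: [batch=1, seq=max_num_tokens, heads=1, head_dim],
        # matching the buffers the removed __init__ branch created.
        cos = torch.ones(1, max_num_tokens, 1, head_dim, dtype=dtype, device=device)
        sin = torch.zeros(1, max_num_tokens, 1, head_dim, dtype=dtype, device=device)
        return cos, sin
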
@@ -2530,22 +2516,6 @@ def execute_model(
             aclgraph_runtime_mode, batch_descriptor = \
                 self.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)

-            # initialize rope
-            cos_sin_cache = self.model.model.layers[
-                self.model.model.
-                start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(
-                    0, positions)
-            last_dim = cos_sin_cache.size()[-1]
-            cos, sin = cos_sin_cache.reshape(-1, 2,
-                                             last_dim // 2).repeat(1, 1,
-                                                                   2).chunk(2,
-                                                                            dim=-2)
-            # BSNH
-            self.cos[:, :maybe_padded_num_tokens] = cos.view(
-                1, -1, 1, last_dim).contiguous()
-            self.sin[:, :maybe_padded_num_tokens] = sin.view(
-                1, -1, 1, last_dim).contiguous()
-
         # Run forward pass
         with ProfileExecuteDuration().capture_async("forward"):
             with set_ascend_forward_context(
@@ -2562,11 +2532,7 @@ def execute_model(
                     total_num_scheduled_tokens,
                     prefetch_stream=self.prefetch_stream,
                     model_instance=self.model,
-                    weight_prefetch_method=self.weight_prefetch_method,
-                    cos=self.cos[:, :maybe_padded_num_tokens]
-                    if self.cos is not None else None,
-                    sin=self.sin[:, :maybe_padded_num_tokens]
-                    if self.sin is not None else None):
+                    weight_prefetch_method=self.weight_prefetch_method):
                 self.maybe_setup_kv_connector(scheduler_output)

                 hidden_states = self._generate_process_reqs_hidden_states(
@@ -3274,20 +3240,6 @@ def dummy_drafter_compute_logits(hidden_states):
             return self.drafter.model.compute_logits(
                 hidden_states[dummy_indices])

-        # initialize rope
-        cos_sin_cache = self.model.model.layers[
-            self.model.model.
-            start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(
-                0, positions)
-        last_dim = cos_sin_cache.size()[-1]
-        cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat(
-            1, 1, 2).chunk(2, dim=-2)
-        # BSNH
-        self.cos[:, :num_tokens] = cos.view(1, -1, 1,
-                                            last_dim).contiguous()
-        self.sin[:, :num_tokens] = sin.view(1, -1, 1,
-                                            last_dim).contiguous()
-
         with set_ascend_forward_context(
                 attn_metadata,
                 self.vllm_config,
@@ -3302,11 +3254,7 @@ def dummy_drafter_compute_logits(hidden_states):
                 batch_descriptor=batch_descriptor,
                 prefetch_stream=self.prefetch_stream,
                 model_instance=self.model,
-                weight_prefetch_method=self.weight_prefetch_method,
-                cos=self.cos[:, :num_tokens]
-                if self.cos is not None else None,
-                sin=self.sin[:, :num_tokens]
-                if self.sin is not None else None):
+                weight_prefetch_method=self.weight_prefetch_method):
             hidden_states = self._generate_dummy_run_hidden_states(
                 with_prefill, input_ids, positions, attn_metadata,
                 num_tokens_padded, intermediate_tensors, inputs_embeds)
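
Both removed blocks (in execute_model and in the dummy-run path) did the same per-step work: gather cos/sin rows from the first local layer's (start_layer) self_attn.rotary_emb.cos_sin_cache at the scheduled positions, duplicate each half, and write the results into the preallocated BSNH buffers before entering set_ascend_forward_context, which then received them as the cos=/sin= keyword arguments dropped above. A standalone sketch of that computation, mirroring the deleted code (fill_rope_buffers is a hypothetical name, and it assumes positions holds exactly one entry per scheduled token):

    import torch

    def fill_rope_buffers(cos_buf: torch.Tensor,
                          sin_buf: torch.Tensor,
                          cos_sin_cache: torch.Tensor,
                          positions: torch.Tensor) -> None:
        # Gather the cached [cos | sin] rows for the scheduled token positions.
        cache = cos_sin_cache.index_select(0, positions)
        last_dim = cache.size(-1)
        # Each row packs a cos half and a sin half; duplicate each half so
        # cos and sin each span last_dim elements, then split them apart.
        cos, sin = cache.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
        num_tokens = positions.size(0)
        # Write into the preallocated BSNH buffers: [1, seq, 1, head_dim].
        cos_buf[:, :num_tokens] = cos.view(1, -1, 1, last_dim).contiguous()
        sin_buf[:, :num_tokens] = sin.view(1, -1, 1, last_dim).contiguous()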