@@ -415,20 +415,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
415415 rope_dim ,
416416 dtype = self .dtype ,
417417 device = self .device )
418- # For GQA models.
419- elif not self .vllm_config .model_config .use_mla :
420- self .cos = torch .ones (1 ,
421- self .max_num_tokens ,
422- 1 ,
423- 128 ,
424- dtype = self .dtype ,
425- device = self .device )
426- self .sin = torch .zeros (1 ,
427- self .max_num_tokens ,
428- 1 ,
429- 128 ,
430- dtype = self .dtype ,
431- device = self .device )
432418 else :
433419 self .cos = None
434420 self .sin = None
@@ -2508,22 +2494,6 @@ def execute_model(
25082494 aclgraph_runtime_mode , batch_descriptor = \
25092495 self .aclgraph_dispatcher .dispatch (num_tokens = num_input_tokens , uniform_decode = uniform_decode , has_lora = has_lora )
25102496
2511- # initialize rope
2512- cos_sin_cache = self .model .model .layers [
2513- self .model .model .
2514- start_layer ].self_attn .rotary_emb .cos_sin_cache .index_select (
2515- 0 , positions )
2516- last_dim = cos_sin_cache .size ()[- 1 ]
2517- cos , sin = cos_sin_cache .reshape (- 1 , 2 ,
2518- last_dim // 2 ).repeat (1 , 1 ,
2519- 2 ).chunk (2 ,
2520- dim = - 2 )
2521- # BSNH
2522- self .cos [:, :maybe_padded_num_tokens ] = cos .view (
2523- 1 , - 1 , 1 , last_dim ).contiguous ()
2524- self .sin [:, :maybe_padded_num_tokens ] = sin .view (
2525- 1 , - 1 , 1 , last_dim ).contiguous ()
2526-
25272497 # Run forward pass
25282498 with ProfileExecuteDuration ().capture_async ("forward" ):
25292499 with set_ascend_forward_context (
@@ -2540,11 +2510,7 @@ def execute_model(
25402510 total_num_scheduled_tokens ,
25412511 prefetch_stream = self .prefetch_stream ,
25422512 model_instance = self .model ,
2543- weight_prefetch_method = self .weight_prefetch_method ,
2544- cos = self .cos [:, :maybe_padded_num_tokens ]
2545- if self .cos is not None else None ,
2546- sin = self .sin [:, :maybe_padded_num_tokens ]
2547- if self .sin is not None else None ):
2513+ weight_prefetch_method = self .weight_prefetch_method ):
25482514 self .maybe_setup_kv_connector (scheduler_output )
25492515
25502516 hidden_states = self ._generate_process_reqs_hidden_states (
@@ -3252,20 +3218,6 @@ def dummy_drafter_compute_logits(hidden_states):
32523218 return self .drafter .model .compute_logits (
32533219 hidden_states [dummy_indices ])
32543220
3255- # initialize rope
3256- cos_sin_cache = self .model .model .layers [
3257- self .model .model .
3258- start_layer ].self_attn .rotary_emb .cos_sin_cache .index_select (
3259- 0 , positions )
3260- last_dim = cos_sin_cache .size ()[- 1 ]
3261- cos , sin = cos_sin_cache .reshape (- 1 , 2 , last_dim // 2 ).repeat (
3262- 1 , 1 , 2 ).chunk (2 , dim = - 2 )
3263- # BSNH
3264- self .cos [:, :num_tokens ] = cos .view (1 , - 1 , 1 ,
3265- last_dim ).contiguous ()
3266- self .sin [:, :num_tokens ] = sin .view (1 , - 1 , 1 ,
3267- last_dim ).contiguous ()
3268-
32693221 with set_ascend_forward_context (
32703222 attn_metadata ,
32713223 self .vllm_config ,
@@ -3280,11 +3232,7 @@ def dummy_drafter_compute_logits(hidden_states):
32803232 batch_descriptor = batch_descriptor ,
32813233 prefetch_stream = self .prefetch_stream ,
32823234 model_instance = self .model ,
3283- weight_prefetch_method = self .weight_prefetch_method ,
3284- cos = self .cos [:, :num_tokens ]
3285- if self .cos is not None else None ,
3286- sin = self .sin [:, :num_tokens ]
3287- if self .sin is not None else None ):
3235+ weight_prefetch_method = self .weight_prefetch_method ):
32883236 hidden_states = self ._generate_dummy_run_hidden_states (
32893237 input_ids , positions , num_tokens_padded ,
32903238 intermediate_tensors , inputs_embeds )
0 commit comments