 
 import torch
 import torch_npu
+from vllm.config import CUDAGraphMode
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
-                               get_ascend_device_type)
+                               get_ascend_device_type, is_vl_model)
 
 # Currently, rope ops used on npu require detached cos && sin as inputs.
 # However, RotaryEmbedding in vllm uses cos_sin_cache as a whole variable.
 # So we have to preprocess cos_sin_cache into cos && sin. In the future,
 # we shall implement a new rope op which accepts cos_sin_cache as input.
+# NOTE(Angazenn): MLA && SFA models use attn_metadata to pass cos && sin
+# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
+# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin via
+# attn_metadata, so rope in GQA models has to receive cos && sin through a
+# different mechanism.
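+# Module-level buffers: _cos_mla/_sin_mla are persistent cos/sin buffers for
+# the MLA full-decode graph; _cos/_sin are per-step buffers for GQA models
+# sized to max_num_batched_tokens; _cos_slice/_sin_slice are views over the
+# tokens of the current step, returned by get_cos_and_sin_slice().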
+_cos_mla: Optional[torch.Tensor] = None
+_sin_mla: Optional[torch.Tensor] = None
 _cos_sin_cache: Optional[torch.Tensor] = None
-_cos_cache: Optional[torch.Tensor] = None
-_sin_cache: Optional[torch.Tensor] = None
 _cos: Optional[torch.Tensor] = None
 _sin: Optional[torch.Tensor] = None
+_cos_slice: Optional[torch.Tensor] = None
+_sin_slice: Optional[torch.Tensor] = None
+
+
+def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
+                    device):
+    global _cos_mla
+    global _sin_mla
+    global _cos
+    global _sin
+
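+    # Allocate the buffers only once; if any of them already exists, this
+    # call is a no-op.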
+    if _cos_mla is not None or \
+        _sin_mla is not None or \
+        _cos is not None or \
+        _sin is not None:
+        return
+
+    compilation_config = vllm_config.compilation_config
+    model_config = vllm_config.model_config
+    head_dim = model_config.get_head_size()
+    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+
+    if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
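+        # Full-decode-only graph capture needs cos/sin tensors with stable
+        # addresses, so preallocate a [max_num_reqs * decode_token_per_req,
+        # 1, 1, rope_dim] buffer pair (cos=1, sin=0, i.e. identity rotation).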
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
+                              1,
+                              1,
+                              rope_dim,
+                              dtype=dtype,
+                              device=device)
+        _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
+                               1,
+                               1,
+                               rope_dim,
+                               dtype=dtype,
+                               device=device)
+    elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
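+        # GQA (non-MLA, non-VL) models use a [1, max_num_batched_tokens, 1,
+        # head_dim] buffer pair that update_cos_sin() refills every step.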
+        _cos = torch.ones(1,
+                          max_num_batched_tokens,
+                          1,
+                          head_dim,
+                          dtype=dtype,
+                          device=device)
+        _sin = torch.zeros(1,
+                           max_num_batched_tokens,
+                           1,
+                           head_dim,
+                           dtype=dtype,
+                           device=device)
+
+
+def get_cos_and_sin_mla():
+    return _cos_mla, _sin_mla
 
 
 def _record_cos_sin_cache(cos_sin_cache):
@@ -46,50 +105,28 @@ def _record_cos_sin_cache(cos_sin_cache):
     _cos_sin_cache = cos_sin_cache
 
 
-def initialize_cos_sin(vllm_config, dtype, device):
-    global _cos_cache
-    global _sin_cache
-
-    head_dim = vllm_config.model_config.get_head_size()
-    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-    _cos_cache = torch.ones(1,
-                            max_num_batched_tokens,
-                            1,
-                            head_dim,
-                            dtype=dtype,
-                            device=device)
-    _sin_cache = torch.zeros(1,
-                             max_num_batched_tokens,
-                             1,
-                             head_dim,
-                             dtype=dtype,
-                             device=device)
-
-
 def update_cos_sin(positions):
-    global _cos_cache
-    global _sin_cache
     global _cos
     global _sin
+    global _cos_slice
+    global _sin_slice
 
     if _cos_sin_cache is None or \
-        _cos_cache is None or \
-        _sin_cache is None:
+        _cos is None or \
+        _sin is None:
         return
 
     num_tokens = positions.size(0)
-    _cos_cache[:, :num_tokens] = _cos_sin_cache.index_select(
-        0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2,
-                                                                    dim=-2)[0]
-    _sin_cache[:, :num_tokens] = _cos_sin_cache.index_select(
-        0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2,
-                                                                    dim=-2)[1]
-    _cos = _cos_cache[:, :num_tokens]
-    _sin = _sin_cache[:, :num_tokens]
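+    # Each row of cos_sin_cache stores the cos half followed by the sin half;
+    # duplicate each half to the full head_dim, write them into the first
+    # num_tokens slots of the preallocated buffers, and expose views over the
+    # active tokens.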
+    _cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
+        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
+    _sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
+        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
+    _cos_slice = _cos[:, :num_tokens]
+    _sin_slice = _sin[:, :num_tokens]
 
 
-def get_cos_sin():
-    return _cos, _sin
+def get_cos_and_sin_slice():
+    return _cos_slice, _sin_slice
 
 
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
@@ -127,7 +164,7 @@ def _rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        cos, sin = get_cos_sin()
+        cos, sin = get_cos_and_sin_slice()
         if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
                 -1] == 128 and cos is not None and sin is not None:
             # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.