 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
                                get_ascend_device_type)
 
+# Currently, the rope ops used on NPU require detached cos and sin tensors as
+# inputs, whereas RotaryEmbedding in vLLM keeps them packed in a single
+# cos_sin_cache variable, so we have to preprocess cos_sin_cache into separate
+# cos and sin tensors. In the future, we should implement a new rope op that
+# accepts cos_sin_cache directly.
+_cos_sin_cache: Optional[torch.Tensor] = None
+_cos_cache: Optional[torch.Tensor] = None
+_sin_cache: Optional[torch.Tensor] = None
+_cos: Optional[torch.Tensor] = None
+_sin: Optional[torch.Tensor] = None
+
+
+def _record_cos_sin_cache(cos_sin_cache):
+    global _cos_sin_cache
+    if _cos_sin_cache is not None:
+        return
+    _cos_sin_cache = cos_sin_cache
+
+
+def initialize_cos_sin(vllm_config, dtype, device):
+    global _cos_cache
+    global _sin_cache
+
+    head_dim = vllm_config.model_config.get_head_size()
+    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+    # Preallocate per-step buffers. cos=1 / sin=0 is the identity rotation, so
+    # slots beyond the current num_tokens stay harmless.
+    _cos_cache = torch.ones(1,
+                            max_num_batched_tokens,
+                            1,
+                            head_dim,
+                            dtype=dtype,
+                            device=device)
+    _sin_cache = torch.zeros(1,
+                             max_num_batched_tokens,
+                             1,
+                             head_dim,
+                             dtype=dtype,
+                             device=device)
+
+
+def update_cos_sin(positions):
+    global _cos_cache
+    global _sin_cache
+    global _cos
+    global _sin
+
+    if _cos_sin_cache is None or \
+            _cos_cache is None or \
+            _sin_cache is None:
+        return
+
+    num_tokens = positions.size(0)
+    # cos_sin_cache packs cos in the first half and sin in the second half of
+    # its last dimension; split it once, repeating each half to the full
+    # rotary dim for the neox-style layout.
+    cos, sin = _cos_sin_cache.index_select(0, positions).view(
+        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
+    _cos_cache[:, :num_tokens] = cos
+    _sin_cache[:, :num_tokens] = sin
+    _cos = _cos_cache[:, :num_tokens]
+    _sin = _sin_cache[:, :num_tokens]
+
+
+def get_cos_sin():
+    return _cos, _sin
+
 
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
     return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
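
For orientation (not part of the diff): cos_sin_cache in vLLM's RotaryEmbedding concatenates cos and sin along its last dimension, which is why update_cos_sin can recover them with a view/repeat/chunk. A small self-contained sketch of that split, using a toy cache built the same way (sizes are placeholders):

import torch

# Toy cache laid out like vLLM's cos_sin_cache: [cos | sin] along the last dim.
max_pos, rotary_dim, num_tokens = 16, 8, 4
inv_freq = 1.0 / (10000**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # (max_pos, rotary_dim)

positions = torch.arange(num_tokens)
# Same transform as update_cos_sin: select rows, split into the two halves,
# then repeat each half so it spans the full rotary dim (neox layout).
cos, sin = cos_sin_cache.index_select(0, positions).view(
    num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
print(cos.shape, sin.shape)  # torch.Size([4, 1, 8]) torch.Size([4, 1, 8])

In the real helpers these slices are written into the preallocated _cos_cache/_sin_cache buffers of shape (1, max_num_batched_tokens, 1, head_dim), and get_cos_sin() hands the sliced views to the rope forward pass.
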
@@ -65,8 +128,9 @@ def _rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        if hasattr(self, "cos") and hasattr(self, "sin") and \
-                self.cos is not None and self.sin is not None:
+        cos, sin = get_cos_sin()
+        if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
+                -1] == 128 and cos is not None and sin is not None:
             # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
             # This method requires head_size and rotary_dim to equal 128 and neox_style to be True.
             query = query.contiguous().view(1, query.shape[0], -1,
@@ -75,7 +139,7 @@ def _rope_forward_oot(
             # Although this function modifies in-place, please retain the function's return value.
             # Otherwise, the graph fusion operation may fail.
             query, key = torch_npu.npu_apply_rotary_pos_emb(
-                query, key, forward_context.cos, forward_context.sin)
+                query, key, cos, sin)
         elif self.rotary_dim < self.head_size:
             num_tokens = query.shape[0]
             query = query.view(num_tokens, -1, self.head_size)
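
A note on the fused path above (again, not part of the diff): it is only taken when is_neox_style is true, head_size and the cached rotary dim are both 128, and cos/sin were produced beforehand by update_cos_sin; query is reshaped to (1, num_tokens, -1, head_size) so the (1, num_tokens, 1, head_dim) cos/sin broadcast across heads. Assuming torch_npu.npu_apply_rotary_pos_emb follows the standard neox-style formulation, a plain-PyTorch reference of the rotation it fuses would look like this (reference only, not the NPU kernel):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dim and negate the second one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def neox_rope_reference(q, k, cos, sin):
    # q, k: (1, num_tokens, num_heads, head_dim); cos, sin: (1, num_tokens, 1, head_dim),
    # already repeated to the full rotary dim as update_cos_sin does.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
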
@@ -125,10 +189,9 @@ def __init__(
         is_neox_style: bool,
         dtype: torch.dtype,
     ) -> None:
-        self.cos = None
-        self.sin = None
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
+        _record_cos_sin_cache(self.cos_sin_cache)
 
     def forward_oot(
         self,
@@ -162,8 +225,6 @@ def __init__(
         beta_fast: int = 32,
         beta_slow: int = 1,
     ) -> None:
-        self.cos = None
-        self.sin = None
         extra_kwargs = {
             "extrapolation_factor": extrapolation_factor,
             "attn_factor": attn_factor,
@@ -172,6 +233,7 @@ def __init__(
         }
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, scaling_factor, dtype, **extra_kwargs)
+        _record_cos_sin_cache(self.cos_sin_cache)
 
     def forward_oot(
         self,