Commit ec6b0df

fix
Signed-off-by: wxsIcey <[email protected]>
1 parent fff044b commit ec6b0df

File tree (5 files changed, +82 −78 lines):

- vllm_ascend/ascend_forward_context.py
- vllm_ascend/attention/mla_v1.py
- vllm_ascend/attention/sfa_v1.py
- vllm_ascend/ops/rotary_embedding.py
- vllm_ascend/worker/model_runner_v1.py


vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 31 deletions
@@ -200,34 +200,3 @@ def set_mc2_mask(vllm_config, device):
 
 def get_mc2_mask():
     return _reserved_mc2_mask
-
-
-def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
-                    device):
-    global _cos
-    global _sin
-    if _cos is not None:
-        return
-    compilation_config = vllm_config.compilation_config
-    model_config = vllm_config.model_config
-    if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-        rope_dim = model_config.hf_text_config.qk_rope_head_dim
-        _cos = torch.ones(max_num_reqs * decode_token_per_req,
-                          1,
-                          1,
-                          rope_dim,
-                          dtype=dtype,
-                          device=device)
-        _sin = torch.zeros(max_num_reqs * decode_token_per_req,
-                           1,
-                           1,
-                           rope_dim,
-                           dtype=dtype,
-                           device=device)
-    else:
-        _cos = None
-        _sin = None
-
-
-def get_cos_and_sin():
-    return _cos, _sin
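The helpers deleted here are not dropped from the project; they move into vllm_ascend/ops/rotary_embedding.py (see that file's hunk below), and get_cos_and_sin is renamed to get_cos_and_sin_mla. A minimal sketch of the resulting call-site change, mirroring the import swaps in mla_v1.py and sfa_v1.py:

```python
# Before this commit, MLA/SFA backends read the decode-time cos/sin pair from
# the forward-context module:
#     from vllm_ascend.ascend_forward_context import get_cos_and_sin
#     cos, sin = get_cos_and_sin()

# After this commit the same pair lives in the rope module under an
# MLA-specific name:
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla

cos, sin = get_cos_and_sin_mla()  # both None unless set_cos_and_sin() allocated them
```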

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,6 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import get_cos_and_sin
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
@@ -35,6 +34,7 @@
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
@@ -626,7 +626,7 @@ def build(
 
         decode_metadata = None
         if num_decodes > 0:
-            cos, sin = get_cos_and_sin()
+            cos, sin = get_cos_and_sin_mla()
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
                                                        1].tolist()
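As a side note, here is a hypothetical, self-contained illustration (not the actual metadata-builder code, which is outside this diff) of why the MLA buffers are preallocated at max_num_reqs * decode_token_per_req: any decode step can reuse a leading slice of that worst-case allocation.

```python
import torch

# Hypothetical sizes; the real values come from the vLLM config.
max_num_reqs, decode_token_per_req, rope_dim = 8, 2, 64

# Shaped like the _cos_mla/_sin_mla buffers allocated by set_cos_and_sin in
# vllm_ascend/ops/rotary_embedding.py (see the hunk further down).
cos_mla = torch.ones(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
sin_mla = torch.zeros(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)

# A step that decodes fewer tokens can take a leading view without any
# reallocation, which is the usual requirement for full-graph capture
# (an assumption here, not something this diff shows explicitly).
num_decode_tokens = 5
cos, sin = cos_mla[:num_decode_tokens], sin_mla[:num_decode_tokens]
print(cos.shape, sin.shape)  # torch.Size([5, 1, 1, 64]) torch.Size([5, 1, 1, 64])
```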

vllm_ascend/attention/sfa_v1.py

Lines changed: 2 additions & 2 deletions
@@ -16,12 +16,12 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import get_cos_and_sin
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
@@ -187,7 +187,7 @@ def build(
         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
 
-        cos, sin = get_cos_and_sin()
+        cos, sin = get_cos_and_sin_mla()
 
         assert self.cos_cache is not None and self.sin_cache is not None
         new_cos = self.cos_cache[input_positions][:, None, None]
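A small shape check for the indexing pattern in the unchanged context above, assuming self.cos_cache is a [max_position, rope_dim] lookup table (an assumption; its construction is not part of this diff):

```python
import torch

rope_dim = 64
cos_cache = torch.randn(4096, rope_dim)  # assumed [max_position, rope_dim] layout
input_positions = torch.tensor([0, 17, 17, 512])

# Gather one row per token, then add two singleton dims so the result lines up
# with per-token, per-head rope operands.
new_cos = cos_cache[input_positions][:, None, None]
print(new_cos.shape)  # torch.Size([4, 1, 1, 64])
```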

vllm_ascend/ops/rotary_embedding.py

Lines changed: 75 additions & 38 deletions
@@ -20,23 +20,82 @@
 
 import torch
 import torch_npu
+from vllm.config import CUDAGraphMode
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
-                               get_ascend_device_type)
+                               get_ascend_device_type, is_vl_model)
 
 # Currently, rope ops used on npu requires detached cos && sin as inputs.
 # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
 # So we have to preprocess cos_sin_cache int cos && sin. In the future,
 # we shall implement a new rope ops which accept cos_sin_cache as inputs.
+# NOTE(Angazenn): MLA && SFA models uses attn_metadata to pass cos && sin
+# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from
+# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by
+# attn_metadata. This causes that rope in GQA models must pass cos && sin
+# by different approaches.
+_cos_mla: Optional[torch.Tensor] = None
+_sin_mla: Optional[torch.Tensor] = None
 _cos_sin_cache: Optional[torch.Tensor] = None
-_cos_cache: Optional[torch.Tensor] = None
-_sin_cache: Optional[torch.Tensor] = None
 _cos: Optional[torch.Tensor] = None
 _sin: Optional[torch.Tensor] = None
+_cos_slice: Optional[torch.Tensor] = None
+_sin_slice: Optional[torch.Tensor] = None
+
+
+def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
+                    device):
+    global _cos_mla
+    global _sin_mla
+    global _cos
+    global _sin
+
+    if _cos_mla is not None or \
+        _sin_mla is not None or \
+        _cos is not None or \
+        _sin is not None:
+        return
+
+    compilation_config = vllm_config.compilation_config
+    model_config = vllm_config.model_config
+    head_dim = model_config.get_head_size()
+    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+
+    if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
+                              1,
+                              1,
+                              rope_dim,
+                              dtype=dtype,
+                              device=device)
+        _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
+                               1,
+                               1,
+                               rope_dim,
+                               dtype=dtype,
+                               device=device)
+    elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
+        _cos = torch.ones(1,
+                          max_num_batched_tokens,
+                          1,
+                          head_dim,
+                          dtype=dtype,
+                          device=device)
+        _sin = torch.zeros(1,
+                           max_num_batched_tokens,
+                           1,
+                           head_dim,
+                           dtype=dtype,
+                           device=device)
+
+
+def get_cos_and_sin_mla():
+    return _cos_mla, _sin_mla
 
 
 def _record_cos_sin_cache(cos_sin_cache):
@@ -46,50 +105,28 @@ def _record_cos_sin_cache(cos_sin_cache):
     _cos_sin_cache = cos_sin_cache
 
 
-def initialize_cos_sin(vllm_config, dtype, device):
-    global _cos_cache
-    global _sin_cache
-
-    head_dim = vllm_config.model_config.get_head_size()
-    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-    _cos_cache = torch.ones(1,
-                            max_num_batched_tokens,
-                            1,
-                            head_dim,
-                            dtype=dtype,
-                            device=device)
-    _sin_cache = torch.zeros(1,
-                             max_num_batched_tokens,
-                             1,
-                             head_dim,
-                             dtype=dtype,
-                             device=device)
-
-
 def update_cos_sin(positions):
-    global _cos_cache
-    global _sin_cache
     global _cos
     global _sin
+    global _cos_slice
+    global _sin_slice
 
     if _cos_sin_cache is None or \
-        _cos_cache is None or \
-        _sin_cache is None:
+        _cos is None or \
+        _sin is None:
        return
 
     num_tokens = positions.size(0)
-    _cos_cache[:, :num_tokens] = _cos_sin_cache.index_select(
-        0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2,
-                                                                    dim=-2)[0]
-    _sin_cache[:, :num_tokens] = _cos_sin_cache.index_select(
-        0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2,
-                                                                    dim=-2)[1]
-    _cos = _cos_cache[:, :num_tokens]
-    _sin = _sin_cache[:, :num_tokens]
+    _cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
+        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0]
+    _sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view(
+        num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1]
+    _cos_slice = _cos[:, :num_tokens]
+    _sin_slice = _sin[:, :num_tokens]
 
 
-def get_cos_sin():
-    return _cos, _sin
+def get_cos_and_sin_slice():
+    return _cos_slice, _sin_slice
 
 
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
@@ -127,7 +164,7 @@ def _rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        cos, sin = get_cos_sin()
+        cos, sin = get_cos_and_sin_slice()
         if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
                 -1] == 128 and cos is not None and sin is not None:
             # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
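To make the new data flow concrete: the runner calls set_cos_and_sin once at init, update_cos_sin refreshes the static _cos/_sin buffers each step, and _rope_forward_oot reads the per-step view via get_cos_and_sin_slice. Below is a minimal CPU sketch of the transform update_cos_sin applies, assuming vLLM's usual cos_sin_cache layout (cos half and sin half concatenated along the last dimension); the cache construction itself is outside this diff.

```python
import torch

# Toy cache in the assumed [max_position, rotary_dim] layout: first half cos,
# second half sin (mirroring vLLM's cos_sin_cache for neox-style rope).
max_position, rotary_dim = 16, 8
inv_freq = 1.0 / (10000**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

positions = torch.tensor([0, 3, 5])
num_tokens = positions.size(0)

# Same transform as update_cos_sin: select per-token rows, split into the
# cos/sin halves, and duplicate each half so the last dim matches head_dim.
selected = cos_sin_cache.index_select(0, positions)
cos, sin = selected.view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
print(cos.shape, sin.shape)  # torch.Size([3, 1, 8]) torch.Size([3, 1, 8])
```

In the real module these results are written into the preallocated _cos/_sin buffers and only a [:, :num_tokens] view is handed out through get_cos_and_sin_slice, so the storage stays stable across steps.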

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 5 deletions
@@ -85,7 +85,7 @@
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 get_mc2_tokens_capacity,
                                                 set_ascend_forward_context,
-                                                set_cos_and_sin, set_mc2_mask,
+                                                set_mc2_mask,
                                                 set_mc2_tokens_capacity)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -108,7 +108,8 @@
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
-from vllm_ascend.ops.rotary_embedding import initialize_cos_sin, update_cos_sin
+from vllm_ascend.ops.rotary_embedding import (initialize_cos_sin,
+                                              set_cos_and_sin, update_cos_sin)
 from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.platform import NPUPlatform
@@ -270,9 +271,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         set_cos_and_sin(vllm_config, self.max_num_reqs,
                         self.uniform_decode_query_len, self.dtype, self.device)
-        if not is_vl_model(self.vllm_config
-                           ) and not self.vllm_config.model_config.use_mla:
-            initialize_cos_sin(self.vllm_config, self.dtype, self.device)
         set_mc2_tokens_capacity(vllm_config, self.max_num_reqs,
                                 self.uniform_decode_query_len)
         set_mc2_mask(vllm_config, self.device)
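The model/config branching that used to sit in the runner's __init__ (the deleted is_vl_model/use_mla check) is now fully owned by set_cos_and_sin. A simplified, purely illustrative restatement of that dispatch, using a hypothetical helper name (which_rope_buffers is not part of the codebase):

```python
def which_rope_buffers(use_mla: bool, is_vl: bool,
                       full_decode_only_graph: bool) -> str:
    """Toy mirror of set_cos_and_sin in vllm_ascend/ops/rotary_embedding.py."""
    if use_mla and full_decode_only_graph:
        # [max_num_reqs * decode_token_per_req, 1, 1, qk_rope_head_dim]
        return "_cos_mla / _sin_mla"
    if not is_vl and not use_mla:
        # [1, max_num_batched_tokens, 1, head_dim]
        return "_cos / _sin"
    # Nothing is preallocated; the getters return None and _rope_forward_oot
    # falls back to its cos_sin_cache path.
    return "none"


print(which_rope_buffers(use_mla=True, is_vl=False, full_decode_only_graph=True))    # _cos_mla / _sin_mla
print(which_rope_buffers(use_mla=False, is_vl=False, full_decode_only_graph=False))  # _cos / _sin
```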
