Commit a41460e

fix eagle spec decode
Signed-off-by: wxsIcey <[email protected]>
1 parent 59f15a7 commit a41460e

File tree: 5 files changed, +99 -12 lines


vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ def __init__(self,
         self.enable_quantization_fusion = enable_quantization_fusion
         self.fuse_qknorm_rope = fuse_qknorm_rope
 
+
 class XliteGraphConfig:
     """
     Configuration Object for xlite_graph_config from additional_config

vllm_ascend/ops/rotary_embedding.py

Lines changed: 69 additions & 7 deletions
@@ -29,6 +29,69 @@
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
                                get_ascend_device_type)
 
+# Currently, the rope ops used on NPU require detached cos and sin tensors as
+# inputs, whereas RotaryEmbedding in vLLM keeps a single cos_sin_cache, so we
+# have to preprocess cos_sin_cache into cos and sin here. In the future, we
+# should implement a new rope op that accepts cos_sin_cache directly.
+_cos_sin_cache: Optional[torch.Tensor] = None
+_cos_cache: Optional[torch.Tensor] = None
+_sin_cache: Optional[torch.Tensor] = None
+_cos: Optional[torch.Tensor] = None
+_sin: Optional[torch.Tensor] = None
+
+
+def _record_cos_sin_cache(cos_sin_cache):
+    global _cos_sin_cache
+    if _cos_sin_cache is not None:
+        return
+    _cos_sin_cache = cos_sin_cache
+
+
+def initialize_cos_sin(vllm_config, dtype, device):
+    global _cos_cache
+    global _sin_cache
+
+    head_dim = vllm_config.model_config.get_head_size()
+    max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+    _cos_cache = torch.ones(1,
+                            max_num_batched_tokens,
+                            1,
+                            head_dim,
+                            dtype=dtype,
+                            device=device)
+    _sin_cache = torch.zeros(1,
+                             max_num_batched_tokens,
+                             1,
+                             head_dim,
+                             dtype=dtype,
+                             device=device)
+
+
+def update_cos_sin(positions):
+    global _cos_cache
+    global _sin_cache
+    global _cos
+    global _sin
+
+    if _cos_sin_cache is None or \
+            _cos_cache is None or \
+            _sin_cache is None:
+        return
+
+    num_tokens = positions.size(0)
+    _cos_cache[:, :num_tokens] = _cos_sin_cache.index_select(
+        0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2,
+                                                                    dim=-2)[0]
+    _sin_cache[:, :num_tokens] = _cos_sin_cache.index_select(
+        0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2,
+                                                                    dim=-2)[1]
+    _cos = _cos_cache[:, :num_tokens]
+    _sin = _sin_cache[:, :num_tokens]
+
+
+def get_cos_sin():
+    return _cos, _sin
+
 
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
     return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
@@ -65,8 +128,9 @@ def _rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        if hasattr(self, "cos") and hasattr(self, "sin") and \
-                self.cos is not None and self.sin is not None:
+        cos, sin = get_cos_sin()
+        if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
+                -1] == 128 and cos is not None and sin is not None:
             # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
             # This method requires head_size and rotary_dim equal 128 and neox_style is True
             query = query.contiguous().view(1, query.shape[0], -1,
@@ -75,7 +139,7 @@ def _rope_forward_oot(
             # Although this function modifies in-place, please retain the function's return value.
             # Otherwise, the graph fusion operation may fail.
             query, key = torch_npu.npu_apply_rotary_pos_emb(
-                query, key, forward_context.cos, forward_context.sin)
+                query, key, cos, sin)
         elif self.rotary_dim < self.head_size:
             num_tokens = query.shape[0]
             query = query.view(num_tokens, -1, self.head_size)
@@ -125,10 +189,9 @@ def __init__(
         is_neox_style: bool,
         dtype: torch.dtype,
     ) -> None:
-        self.cos = None
-        self.sin = None
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
+        _record_cos_sin_cache(self.cos_sin_cache)
 
     def forward_oot(
         self,
@@ -162,8 +225,6 @@ def __init__(
         beta_fast: int = 32,
         beta_slow: int = 1,
     ) -> None:
-        self.cos = None
-        self.sin = None
         extra_kwargs = {
             "extrapolation_factor": extrapolation_factor,
             "attn_factor": attn_factor,
@@ -172,6 +233,7 @@ def __init__(
         }
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, scaling_factor, dtype, **extra_kwargs)
+        _record_cos_sin_cache(self.cos_sin_cache)
 
     def forward_oot(
         self,
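
For readers who want to see what update_cos_sin computes, here is a minimal standalone sketch of the reshaping, assuming vLLM's neox-style cache layout cos_sin_cache = torch.cat((cos, sin), dim=-1) and rotary_dim == head_size == 128; the sizes, position values, and variable names below are illustrative, not taken from the commit.

import torch

head_dim = 128
max_positions = 16
inv_freq = 1.0 / (10000**(torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)
# Assumed cache layout: cos half and sin half concatenated along the last dim.
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # (max_positions, head_dim)

positions = torch.tensor([0, 3, 5])  # token positions for one forward step
num_tokens = positions.size(0)

selected = cos_sin_cache.index_select(0, positions)  # (num_tokens, head_dim)
# Split into the cos half and the sin half, duplicate each half so it spans the
# full head_dim (neox rope applies the same cos/sin to both halves of the head),
# then separate cos and sin again along the middle axis.
cos, sin = selected.view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)

# The npu rope kernel consumes (1, num_tokens, 1, head_dim) tensors, which is
# exactly the slice shape of _cos_cache[:, :num_tokens] above.
cos = cos.reshape(1, num_tokens, 1, head_dim)
sin = sin.reshape(1, num_tokens, 1, head_dim)
print(cos.shape, sin.shape)  # torch.Size([1, 3, 1, 128]) twice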

vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 import torch
 import triton
-import triton.language as tl # type: ignore
+import triton.language as tl  # type: ignore
 import triton.runtime.driver as driver
 from vllm.utils.torch_utils import direct_register_custom_op

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 17 additions & 3 deletions
@@ -23,6 +23,7 @@
 from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                 AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 
 PADDING_SLOT_ID = -1
@@ -124,13 +125,16 @@ def dummy_run(self,
                   batch_descriptor=None,
                   dummy_compute_logits=lambda hidden_states: None):
         moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
+        positions = self.positions[:num_tokens]
+        # update global cos, sin
+        update_cos_sin(positions)
         with set_ascend_forward_context(None,
                                         self.vllm_config,
                                         moe_comm_type=moe_comm_type,
                                         num_tokens=num_tokens):
             self.model(
                 input_ids=self.input_ids[:num_tokens],
-                positions=self.positions[:num_tokens],
+                positions=positions,
                 hidden_states=self.hidden_states[:num_tokens],
             )
             dummy_compute_logits(self.hidden_states)
@@ -464,13 +468,18 @@ def _propose(
         self.positions[:num_tokens] = target_positions.to(device)
         self.hidden_states[:num_tokens] = target_hidden_states
         attn_metadata.block_tables = block_table.to(device)
+
+        positions = self.positions[:num_input_tokens]
+        # update global cos, sin
+        update_cos_sin(positions)
+
         with set_ascend_forward_context(attn_metadata,
                                         self.vllm_config,
                                         moe_comm_type=moe_comm_type,
                                         num_tokens=num_input_tokens):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
+                positions=positions,
                 hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
@@ -573,14 +582,19 @@ def _propose(
             attn_metadata.attn_mask = attn_mask
             attn_metadata.block_tables = block_table.to(device)
             # Run the model.
+
+            positions = self.positions[:input_batch_size]
+            # update global cos, sin
+            update_cos_sin(positions)
+
             with set_ascend_forward_context(attn_metadata,
                                             self.vllm_config,
                                             moe_comm_type=moe_comm_type,
                                             num_tokens=input_batch_size):
 
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:input_batch_size],
-                    positions=self.positions[:input_batch_size],
+                    positions=positions,
                     hidden_states=self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
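
The gist of the fix, roughly stated: the rope op inside the draft model now reads cos and sin from the module-level cache, so the EAGLE proposer has to refresh that cache with the draft positions before every draft forward pass; otherwise the rope op would keep consuming the cos/sin left over from the previous target-model step. A hedged sketch of the resulting call pattern follows; the wrapper name and the assertion are illustrative, not part of the commit.

from vllm_ascend.ops.rotary_embedding import get_cos_sin, update_cos_sin

def run_draft_forward(model, input_ids, positions, hidden_states):
    # Refresh the global cos/sin for exactly these positions before the call;
    # _rope_forward_oot picks them up via get_cos_sin() inside the model.
    update_cos_sin(positions)
    cos, sin = get_cos_sin()
    if cos is not None:
        # cos has shape (1, num_tokens, 1, head_dim) after update_cos_sin.
        assert cos.shape[1] == positions.size(0)
    return model(input_ids=input_ids,
                 positions=positions,
                 hidden_states=hidden_states)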

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 1 deletion
@@ -136,6 +136,7 @@
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
+from vllm_ascend.ops.rotary_embedding import initialize_cos_sin, update_cos_sin
 from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.platform import NPUPlatform
@@ -149,7 +150,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendDeviceType, ProfileExecuteDuration,
                                enable_sp, get_ascend_device_type, is_enable_nz,
-                               is_moe_model, lmhead_tp_enable)
+                               is_moe_model, is_vl_model, lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
 if TYPE_CHECKING:
@@ -434,6 +435,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         else:
             self.cos = None
             self.sin = None
+        if not is_vl_model(self.vllm_config
+                           ) and not self.vllm_config.model_config.use_mla:
+            initialize_cos_sin(self.vllm_config, self.dtype, self.device)
 
         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -2025,6 +2029,9 @@ def _prepare_inputs(
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
+        # update global cos, sin
+        update_cos_sin(positions)
+
         if lmhead_tp_enable():
             max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
             logits_indices = nn.functional.pad(
@@ -3224,6 +3231,9 @@ def _dummy_run(
         else:
             positions = self.positions[:num_tokens_padded]
 
+        # update global cos, sin
+        update_cos_sin(positions)
+
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
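
Taken together, the model runner owns the lifecycle of the cached cos/sin: allocate once at init (only for non-VL, non-MLA models), then refresh per step from the positions about to be fed to the model. A condensed sketch of that lifecycle, with the helper names and the split into two functions invented here for illustration:

from vllm_ascend.ops.rotary_embedding import initialize_cos_sin, update_cos_sin
from vllm_ascend.utils import is_vl_model

def setup_rope_cache(vllm_config, dtype, device):
    # Mirrors the gate added in NPUModelRunner.__init__: multimodal (VL) and
    # MLA models take different rope paths, so the cache is skipped for them.
    if is_vl_model(vllm_config) or vllm_config.model_config.use_mla:
        return
    initialize_cos_sin(vllm_config, dtype, device)

def before_each_forward(positions):
    # Called from _prepare_inputs / _dummy_run (and the eagle proposer) right
    # before the model call, so the rope op sees cos/sin for these positions.
    update_cos_sin(positions)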
