Commit 81e9f34

[Performance] Improve the inference performance of Eagle3.
vLLM version: v0.11.0
vLLM main: vllm-project/vllm

Signed-off-by: liumail202512 <[email protected]>
1 parent 1b137d6 commit 81e9f34

File tree

1 file changed: +50 -80 lines changed


vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 50 additions & 80 deletions
@@ -20,8 +20,7 @@

 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType

@@ -68,9 +67,6 @@ def __init__(self,
              self.hidden_size),
             dtype=self.vllm_config.model_config.dtype,
             device=device)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -189,8 +185,10 @@ def generate_token_ids(self,
                 dtype=torch.int32,
                 device=self.device,
             )
-            cu_num_tokens, token_indices =\
-                self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
+            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+            cu_num_tokens, token_indices = self._prepare_inputs(
+                eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+                num_tokens)
             target_token_ids = self.runner.input_ids[token_indices]
             target_positions = positions[token_indices]
             if self.name == SpecDcodeType.EAGLE3:
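
The new call site precomputes the number of surviving tokens before calling `_prepare_inputs`, so the size of the `token_indices` buffer is known up front. A minimal sketch of that arithmetic, using illustrative numbers that are not taken from the commit:

# Illustrative only: three requests scheduled with 2 + 4 + 3 = 9 target tokens,
# of which the verifier rejected 0, 1 and 2 tokens respectively.
num_scheduled_tokens = 9
num_rejected_tokens = [0, 1, 2]
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
print(num_tokens)  # 6 tokens are left to feed the Eagle3 draft model
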
@@ -590,88 +588,60 @@ def _propose(

     def _prepare_inputs(
         self,
-        eagle_attn_metadata: AscendMetadata,
+        # [batch_size + 1]
+        cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        This function is used to prepare the inputs for the spec decode.
-        It updates to the common_attn_metadata to account for the rejected
-        tokens (and newly sampled tokens). It also returns the token indices
-        of the tokens that should be fed to the speculator.
-        """
-        # E.g.
-        # common_attn_metadata.query_start_loc{_cpu}:
-        #     [0, q1, q1 + q2, q1 + q2 + q3]
-        # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
-        # num_rejected_tokens: [n1, n2, n3]
-        # This function computes the intermediate values:
-        # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
-        # And returns:
-        # common_attn_metadata.query_start_loc{_cpu}:
-        #     [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
-        # common_attn_metadata.seq_lens{_cpu}:
-        #     [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
-        # token_indices: [0, 1, ..., q1 - n1 - 1,
-        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
-        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
-        num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
-        cu_target_query_lens = eagle_attn_metadata.query_start_loc
-        device = eagle_attn_metadata.query_start_loc.device
-        query_start_loc_cpu = cu_target_query_lens.to("cpu")
-
-        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
-        new_query_len_per_req = (query_start_loc_cpu[1:] -
-                                 query_start_loc_cpu[:-1])
-        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
-        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
-        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
-
-        # [q1 - n1, q2 - n2, q3 - n3] ->
-        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
-        new_query_start_loc_cpu = torch.zeros(
-            query_start_loc_cpu.shape,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available())
-        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
-        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
-
-        total_num_tokens = new_query_start_loc_np[-1]
-        # Example assuming num_tokens_per_req_np = [2, 4, 3]
-        # this implies that `new_query_start_locs` is:
-        #   [0, 2, 6, 9] ->
-        #   [0, 0, 2, 2, 2, 2, 6, 6, 6]
-        #    _r1_  ____r2____  ___r3__
-        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
-                                                  new_num_tokens_per_req_np)
-        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
-        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
-        #  _r1_  ____r2____  ___r3__
-        token_offests = self.token_arange_np[:total_num_tokens] \
-            - new_query_start_locs_expanded
-
-        # Expand starting positions to match token pattern
-        # [0, q1, q1 + q2] ->
-        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
-        #  _r1_  _____r2_______  ___________r3____________
-        old_query_start_locs_expanded = np.repeat(
-            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
-        # Final token indices are:
-        # [0, 1,                           // req 1
-        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
-        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
-        token_indices_np = token_offests + old_query_start_locs_expanded
-        token_indices = torch.from_numpy(token_indices_np).to(
-            device, non_blocking=True)
-
-        # need use npu
+        # cu_target_query_lens: [0, a, a + b, a + b + c]
+        # num_rejected_tokens: [n1, n2, n3]
+        # num_tokens_per_req: [a - n1, b - n2, c - n3]
+        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
+        # token_indices: [0, 1, ..., a - n1 - 1,
+        #                 a, a + 1, ..., a + b - n2 - 1,
+        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
+
+        # [0, a, a + b, a + b + c] -> [a, b, c]
         query_len_per_req = (cu_target_query_lens[1:] -
                              cu_target_query_lens[:-1])
+        # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens

         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-
+        token_indices = torch.empty(
+            num_tokens,
+            dtype=torch.int32,
+            device=cu_target_query_lens.device,
+        )
+        BLOCK_SIZE = 1024
+        self._prepare_eagle_input_sequential(
+            token_indices,
+            cu_target_query_lens,
+            cu_num_tokens,
+            block_size=BLOCK_SIZE,
+        )
         return cu_num_tokens, token_indices
+
+    def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
+                                        cu_query_lens: torch.Tensor,
+                                        cu_num_tokens: torch.Tensor, block_size: int):
+        device = cu_query_lens.device
+        dtype = out_tensor.dtype
+
+        offsets = torch.arange(block_size, device=device, dtype=dtype)
+        start_pos = cu_num_tokens[:-1]
+        end_pos = cu_num_tokens[1:]
+        num_tokens = end_pos - start_pos
+
+        global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
+        values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))
+
+        mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))
+
+        global_indices_flat = global_indices[mask]
+        values_flat = values[mask]
+        out_tensor[global_indices_flat] = values_flat
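
The new `_prepare_eagle_input_sequential` helper builds `token_indices` entirely with tensor ops on the device, replacing the removed pinned-CPU/NumPy round trip, which is where the speed-up is expected to come from. Below is a minimal, self-contained CPU sketch of the same masked-scatter logic; the function name `build_token_indices` and the toy inputs are assumptions for illustration only, not code from the commit (which runs on the NPU with BLOCK_SIZE = 1024):

import torch


def build_token_indices(cu_query_lens: torch.Tensor,
                        cu_num_tokens: torch.Tensor,
                        num_tokens: int,
                        block_size: int) -> torch.Tensor:
    # block_size must be at least the largest per-request token count,
    # otherwise the mask below would silently drop tokens.
    out = torch.empty(num_tokens, dtype=torch.int64)
    offsets = torch.arange(block_size, dtype=torch.int64)
    start_pos = cu_num_tokens[:-1]            # slice start of each request in `out`
    per_req = cu_num_tokens[1:] - start_pos   # accepted tokens per request

    # Row i holds the destination slots in `out` and the source positions in
    # the flattened target batch for request i.
    dest = start_pos.view(-1, 1) + offsets.view(1, -1)
    src = cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)

    # Keep only the first per_req[i] columns of row i, then scatter.
    mask = offsets.view(1, -1) < per_req.view(-1, 1)
    out[dest[mask]] = src[mask]
    return out


# Same toy batch as above: query lens [2, 4, 3] with [0, 1, 2] rejected tokens,
# so the accepted counts are [2, 3, 1] and cu_num_tokens is [0, 2, 5, 6].
cu_query_lens = torch.tensor([0, 2, 6, 9])
cu_num_tokens = torch.tensor([0, 2, 5, 6])
print(build_token_indices(cu_query_lens, cu_num_tokens, num_tokens=6, block_size=8))
# tensor([0, 1, 2, 3, 4, 6])

In the committed version the corresponding tensors (`token_indices`, `cu_target_query_lens`, `cu_num_tokens`) stay on the device the whole time, so no host-side NumPy buffers or non-blocking copies back to the NPU are needed.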
