Skip to content

Commit cb42564

Browse files
[BugFix] Fix eagle3 accuracy problem when enforce_eager=True (#4521)
### What this PR does / why we need it? ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? def main(): prompts = [ "The future of AI is", ] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. llm = LLM( model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=1, speculative_config={ "method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3 }, enforce_eager=True, ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) print(f"Outputs: {outputs}") for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - vLLM version: v0.12.0 - vLLM main: vllm-project/vllm@ad32e3e --------- Signed-off-by: zhaomingyu <[email protected]> Co-authored-by: wangxiyuan <[email protected]>
1 parent 3480094 commit cb42564

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_eagle_correctness(
110110
Compare the outputs of a original LLM and a speculative LLM
111111
should be the same when using eagle speculative decoding.
112112
'''
113-
pytest.skip("exist OOM error")
113+
pytest.skip("To be aligned with GPU")
114114
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
115115
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
116116
del ref_llm

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self,
7979
dtype=torch.int32)
8080
attn_mask_len = self.vllm_config.model_config.max_model_len
8181
self.attn_mask_builder = AttentionMaskBuilder(
82-
attn_mask_len, self.vllm_config.model_config.dtype)
82+
attn_mask_len, self.vllm_config.model_config.dtype, device=device)
8383

8484
def load_model(self, model: nn.Module) -> None:
8585
target_attn_layer_names = set(
@@ -430,9 +430,7 @@ def _propose(
430430

431431
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
432432
max_query_len = query_lens.max().item()
433-
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
434-
seq_lens, target_positions, self.vllm_config.model_config.dtype,
435-
self.device)
433+
attn_mask = self.runner.attn_mask
436434

437435
common_attn_metadata = AscendCommonAttentionMetadata(
438436
query_start_loc=cu_num_tokens.to(device),
@@ -507,9 +505,15 @@ def _propose(
507505
attn_metadata.num_actual_tokens = batch_size
508506
attn_metadata.max_query_len = 1
509507
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
508+
attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
509+
1:].tolist()
510+
attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
511+
attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
510512
query_lens.fill_(1)
511513
attn_metadata.query_lens = query_lens
512514

515+
attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
516+
attn_metadata.seq_lens_list = seq_lens.tolist()
513517
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
514518
for now_speculative in range(
515519
self.vllm_config.speculative_config.num_speculative_tokens -
@@ -536,6 +540,9 @@ def _propose(
536540
# TODO: Increment the sequence lengths.
537541

538542
attn_metadata.seq_lens += 1
543+
attn_metadata.seq_lens_list = [
544+
_ + 1 for _ in attn_metadata.seq_lens_list
545+
]
539546
# TODO: Consider max model length.
540547
# attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
541548
# self.max_model_len)

0 commit comments

Comments (0)