
Conversation

@wwl2755 (Contributor) commented Sep 3, 2025

Purpose

Add qwen2-eagle (and qwen2.5-eagle) support.

According to Slack, Qwen2 is not supported yet because the target and draft models have different numbers of KV heads. This PR addresses that by aligning the KV cache space with the larger of the two (though this can waste some space).
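
Roughly, the alignment works like the sketch below (a minimal illustration with made-up names and values, not the PR's actual code):

```python
def aligned_kv_cache_heads(target_kv_heads: int, draft_kv_heads: int) -> int:
    # Size the shared KV cache for the larger head count; the model with
    # fewer KV heads leaves part of each block unused (the "waste" above).
    return max(target_kv_heads, draft_kv_heads)


# Illustrative values only, not taken from the actual checkpoints:
print(aligned_kv_cache_heads(8, 2))  # -> 8
```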

Another thing this PR addresses is that some EAGLE models place "Eagle" at the end of their architecture name (e.g. Qwen2ForCausalLMEagle), which the registry does not recognize.
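
The normalization behaves roughly as in this sketch (the helper name is illustrative, not vLLM's actual function; see the "Normalized EAGLE architectures" log line in the test results below):

```python
def normalize_eagle_arch(arch: str) -> str:
    # Move a trailing "Eagle" to the front so the registry can match it,
    # e.g. "Qwen2ForCausalLMEagle" -> "EagleQwen2ForCausalLM".
    if arch.endswith("Eagle") and not arch.startswith("Eagle"):
        return "Eagle" + arch[: -len("Eagle")]
    return arch


assert normalize_eagle_arch("Qwen2ForCausalLMEagle") == "EagleQwen2ForCausalLM"
```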

This PR includes an important fix for test_spec_decode.py by @ekagra-ranjan in #23461. I can either wait for that PR to be merged first or add @ekagra-ranjan as a co-author in recognition of his fix. Update: #23461 (#24257) has been merged.

cc: @WoosukKwon @LiuXiaoxuanPKU @mgoin

Test

python examples/offline_inference/spec_decode.py --method eagle --model-dir Qwen/Qwen2-7B-Instruct --eagle-dir yuhuili/EAGLE-Qwen2-7B-Instruct --num-prompts 10

python examples/offline_inference/spec_decode.py --method eagle --model-dir Qwen/Qwen2.5-14B-Instruct --eagle-dir Zjcxy-SmartAI/Eagle-Qwen2.5-14B-Instruct --num-prompts 10

VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/test_initialization.py -k EagleQwen2ForCausalLM

pytest -v -s tests/v1/e2e/test_spec_decode.py -k qwen2_eagle

Test Result

Detailed results
~/vllm$ python examples/offline_inference/spec_decode.py --method eagle --model-dir Qwen/Qwen2.5-14B-Instruct --eagle-dir Zjcxy-SmartAI/Eagle-Qwen2.5-14B-Instruct --num-prompts 10
INFO 09-03 06:03:45 [__init__.py:241] Automatically detected platform cuda.
INFO 09-03 06:03:47 [datasets.py:509] Sampling input_len from [1024, 1024] and output_len from [128, 128]
INFO 09-03 06:03:47 [utils.py:328] non-default args: {'trust_remote_code': True, 'max_model_len': 8192, 'gpu_memory_utilization': 0.8, 'enforce_eager': True, 'limit_mm_per_prompt': {'image': 5}, 'enable_chunked_prefill': True, 'disable_chunked_mm_input': True, 'speculative_config': {'method': 'eagle', 'model': 'Zjcxy-SmartAI/Eagle-Qwen2.5-14B-Instruct', 'num_speculative_tokens': 2}, 'model': 'Qwen/Qwen2.5-14B-Instruct'}
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
INFO 09-03 06:03:54 [__init__.py:771] Resolved architecture: Qwen2ForCausalLM
INFO 09-03 06:03:54 [__init__.py:1800] Using max model len 8192
INFO 09-03 06:03:55 [__init__.py:665] Normalized EAGLE architectures: ['Qwen2ForCausalLMEagle'] -> ['EagleQwen2ForCausalLM']
INFO 09-03 06:04:02 [__init__.py:771] Resolved architecture: EagleQwen2ForCausalLM
INFO 09-03 06:04:02 [__init__.py:1800] Using max model len 32768
INFO 09-03 06:04:02 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 09-03 06:04:02 [__init__.py:3699] Cudagraph is disabled under eager mode
WARNING 09-03 06:04:02 [__init__.py:2959] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reasons: CUDA is initialized
INFO 09-03 06:04:07 [__init__.py:241] Automatically detected platform cuda.
(EngineCore_0 pid=3559222) INFO 09-03 06:04:09 [core.py:648] Waiting for init message from front-end.
(EngineCore_0 pid=3559222) INFO 09-03 06:04:09 [core.py:75] Initializing a V1 LLM engine (v0.1.dev9107+g862f2ef89) with config: model='Qwen/Qwen2.5-14B-Instruct', speculative_config=SpeculativeConfig(method='eagle', model='Zjcxy-SmartAI/Eagle-Qwen2.5-14B-Instruct', num_spec_tokens=2), tokenizer='Qwen/Qwen2.5-14B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen2.5-14B-Instruct, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":null,"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":0,"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":0,"local_cache_dir":null}
(EngineCore_0 pid=3559222) W0903 06:04:09.535000 3559222 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(EngineCore_0 pid=3559222) W0903 06:04:09.535000 3559222 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
(EngineCore_0 pid=3559222) 2025-09-03 06:04:09,537 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[W903 06:04:10.912594994 ProcessGroupNCCL.cpp:981] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_0 pid=3559222) INFO 09-03 06:04:11 [parallel_state.py:1134] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_0 pid=3559222) INFO 09-03 06:04:11 [topk_topp_sampler.py:58] Using FlashInfer for top-p & top-k sampling.
(EngineCore_0 pid=3559222) INFO 09-03 06:04:11 [gpu_model_runner.py:1928] Starting to load model Qwen/Qwen2.5-14B-Instruct...
(EngineCore_0 pid=3559222) INFO 09-03 06:04:11 [gpu_model_runner.py:1960] Loading model from scratch...
(EngineCore_0 pid=3559222) INFO 09-03 06:04:11 [cuda.py:328] Using Flash Attention backend on V1 engine.
(EngineCore_0 pid=3559222) INFO 09-03 06:04:11 [weight_utils.py:304] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/8 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  12% Completed | 1/8 [00:00<00:02,  2.73it/s]
Loading safetensors checkpoint shards:  25% Completed | 2/8 [00:00<00:02,  2.30it/s]
Loading safetensors checkpoint shards:  38% Completed | 3/8 [00:01<00:01,  2.89it/s]
Loading safetensors checkpoint shards:  50% Completed | 4/8 [00:01<00:01,  2.74it/s]
Loading safetensors checkpoint shards:  62% Completed | 5/8 [00:01<00:01,  2.53it/s]
Loading safetensors checkpoint shards:  75% Completed | 6/8 [00:02<00:00,  2.50it/s]
Loading safetensors checkpoint shards:  88% Completed | 7/8 [00:02<00:00,  2.45it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:03<00:00,  2.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:03<00:00,  2.43it/s]
(EngineCore_0 pid=3559222) 
(EngineCore_0 pid=3559222) INFO 09-03 06:04:15 [default_loader.py:267] Loading weights took 3.37 seconds
(EngineCore_0 pid=3559222) INFO 09-03 06:04:15 [gpu_model_runner.py:1970] Loading drafter model...
(EngineCore_0 pid=3559222) INFO 09-03 06:04:15 [weight_utils.py:304] Using model weights format ['*.safetensors', '*.bin', '*.pt']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.55it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.55it/s]
(EngineCore_0 pid=3559222) 
(EngineCore_0 pid=3559222) INFO 09-03 06:04:16 [default_loader.py:267] Loading weights took 0.72 seconds
(EngineCore_0 pid=3559222) INFO 09-03 06:04:16 [eagle.py:634] Assuming the EAGLE head shares the same vocab embedding with the target model.
(EngineCore_0 pid=3559222) INFO 09-03 06:04:16 [eagle.py:650] Loading EAGLE LM head weights from the target model.
(EngineCore_0 pid=3559222) INFO 09-03 06:04:16 [gpu_model_runner.py:1982] Model loading took 28.1795 GiB and 4.783674 seconds
(EngineCore_0 pid=3559222) 2025-09-03 06:04:18,306 - INFO - flashinfer.jit: Loading JIT ops: sampling
(EngineCore_0 pid=3559222) [rank0]:W0903 06:04:18.306000 3559222 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(EngineCore_0 pid=3559222) [rank0]:W0903 06:04:18.306000 3559222 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
(EngineCore_0 pid=3559222) [rank0]:W0903 06:04:18.313000 3559222 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(EngineCore_0 pid=3559222) [rank0]:W0903 06:04:18.313000 3559222 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
(EngineCore_0 pid=3559222) 2025-09-03 06:04:18,328 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
(EngineCore_0 pid=3559222) INFO 09-03 06:04:19 [gpu_worker.py:276] Available KV cache memory: 2.15 GiB
(EngineCore_0 pid=3559222) INFO 09-03 06:04:19 [kv_cache_utils.py:850] GPU KV cache size: 11,472 tokens
(EngineCore_0 pid=3559222) INFO 09-03 06:04:19 [kv_cache_utils.py:854] Maximum concurrency for 8,192 tokens per request: 1.40x
(EngineCore_0 pid=3559222) INFO 09-03 06:04:19 [core.py:217] init engine (profile, create kv cache, warmup model) took 2.60 seconds
(EngineCore_0 pid=3559222) INFO 09-03 06:04:19 [__init__.py:3699] Cudagraph is disabled under eager mode
INFO 09-03 06:04:20 [llm.py:285] Supported_tasks: ['generate']
INFO 09-03 06:04:20 [__init__.py:36] No IOProcessor plugins requested by the model
Adding requests:   0%|                                                                                                    | 0/10 [00:00<?, ?it/s](EngineCore_0 pid=3559222) WARNING 09-03 06:04:20 [cudagraph_dispatcher.py:102] cudagraph dispatching keys are not initialized. No cudagraph will be used.
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 1652.54it/s]
Processed prompts: 100%|████████████████████████████████| 10/10 [00:08<00:00,  1.17it/s, est. speed input: 1191.23 toks/s, output: 299.59 toks/s]
--------------------------------------------------
total_num_output_tokens: 2560
num_drafts: 1082
num_draft_tokens: 2164
num_accepted_tokens: 1477
mean acceptance length: 2.37
--------------------------------------------------
acceptance at token 0: 0.77
acceptance at token 1: 0.59
[rank0]:[W903 06:04:29.866627928 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added documentation Improvements or additions to documentation new-model Requests to new models performance Performance-related issues qwen Related to Qwen models speculative-decoding v1 labels Sep 3, 2025

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for qwen2-eagle models in speculative decoding. The key changes involve normalizing EAGLE model architecture names, accommodating different KV head counts between target and draft models by aligning the KV cache space, and implementing the EagleQwen2ForCausalLM model. The approach is sound and the implementation aligns well with the project's existing patterns. I have one high-severity suggestion to improve error handling, making it more specific to avoid masking potential bugs.

Comment on lines +2469 to +2465

Severity: high

The except Exception: block is overly broad and can hide important errors, such as a misconfigured model path or typos in model configuration attributes. This makes debugging difficult because the root cause is suppressed, and the user only sees a generic warning. It is better to catch more specific exceptions that are expected during configuration loading (e.g., OSError, ValueError) and to log the actual exception for better diagnostics. This ensures that unexpected errors are not silenced.

Suggested change

```diff
-except Exception:
-    # If we can't determine, assume they're different for safety
-    logger.warning(
-        "Could not determine KV heads compatibility for EAGLE")
-    return True
+except (OSError, ValueError) as e:
+    # If we can't determine, assume they're different for safety
+    logger.warning(
+        "Could not determine KV heads compatibility for EAGLE due to "
+        "an error when loading the draft model config: %s. "
+        "Assuming they are different for safety.", e)
+    return True
```

mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: wwl2755 <[email protected]>

```diff
     speculative_config=speculative_config,
     disable_log_stats=False,
-    max_model_len=16384,
+    max_model_len=8192,
```

A Member commented:

Is this change intended?

@wwl2755 (Contributor, Author) replied:

I can revert this right before merging. I just lowered it to avoid OOM on 40 GB HBM (as tested locally).

```python
logger = init_logger(__name__)


class Qwen2DecoderLayer(Qwen2DecoderLayer):
```

A Member commented:

To avoid confusion, we should use a different name for this class

@wwl2755 (Contributor, Author) replied:

Agreed. I defined it this way to keep it consistent with the existing llama_eagle. Do you think we should update that as well? https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama_eagle.py#L27
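
For illustration only, a sketch of the rename being discussed (the subclass name below is hypothetical, not what this PR or llama_eagle currently uses):

```python
# Hypothetical rename sketch; the PR currently shadows the base class name
# to stay consistent with llama_eagle.
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer


class EagleQwen2DecoderLayer(Qwen2DecoderLayer):
    """EAGLE draft variant of the Qwen2 decoder layer under a distinct name."""
```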

@DarkLight1337 (Member) commented:

@luccafong can you help review this in more detail? Thanks

mergify bot commented Sep 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--24187.org.readthedocs.build/en/24187/
