Skip to content

Commit 43f5f0e

Browse files
committed
[Feat] Eagle adapts to vLLM's speculative_config.enforce_eager
Signed-off-by: anon189Ty <[email protected]>
1 parent 4bd1030 commit 43f5f0e

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

tests/ut/spec_decode/test_eagle_proposer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_initialization_eagle(self):
3131
self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
3232
self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
3333
self.vllm_config.model_config.enforce_eager = False
34+
self.vllm_config.speculative_config.enforce_eager = False
3435

3536
proposer = EagleProposer(vllm_config=self.vllm_config,
3637
device=self.device,

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ def __init__(self,
4646
self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
4747
)
4848

49-
self.use_cuda_graph = (self.vllm_config.compilation_config.mode
50-
== CompilationMode.VLLM_COMPILE and
51-
not self.vllm_config.model_config.enforce_eager)
49+
self.use_cuda_graph = (
50+
self.vllm_config.compilation_config.mode
51+
== CompilationMode.VLLM_COMPILE
52+
and not self.vllm_config.model_config.enforce_eager
53+
and not self.vllm_config.speculative_config.enforce_eager)
5254

5355
self.cudagraph_batch_sizes = list(
5456
sorted(
@@ -125,6 +127,9 @@ def dummy_run(self,
125127
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
126128
batch_descriptor=None,
127129
dummy_compute_logits=lambda hidden_states: None):
130+
if not self.use_cuda_graph:
131+
# Used to adapt to the eagle-fullgraph pre-commit check.
132+
aclgraph_runtime_mode = CUDAGraphMode.NONE # noqa
128133
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
129134
with set_ascend_forward_context(None,
130135
self.vllm_config,
@@ -454,6 +459,8 @@ def _propose(
454459
builder = self.runner.attn_groups[0][0].get_metadata_builder()
455460
attn_metadata = builder.build(0, common_attn_metadata,
456461
self.runner.get_model())
462+
aclgraph_runtime_mode = CUDAGraphMode.NONE
463+
batch_descriptor = None
457464
if self.use_cuda_graph and \
458465
num_tokens <= self.cudagraph_batch_sizes[-1]:
459466
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -466,10 +473,19 @@ def _propose(
466473
self.positions[:num_tokens] = target_positions.to(device)
467474
self.hidden_states[:num_tokens] = target_hidden_states
468475
attn_metadata.block_tables = block_table.to(device)
469-
with set_ascend_forward_context(attn_metadata,
470-
self.vllm_config,
471-
moe_comm_type=moe_comm_type,
472-
num_tokens=num_input_tokens):
476+
# Make sure speculative_config.enforce_eager is validated to be compatible with the eagle full graph.
477+
# NOTE: set aclgraph_runtime_mode again to avoid an eagle-fullgraph pre-commit error.
478+
# Remove the reassignment once the eagle-fullgraph PR is merged.
479+
if not self.use_cuda_graph:
480+
aclgraph_runtime_mode = CUDAGraphMode.NONE
481+
batch_descriptor = None # noqa
482+
483+
with set_ascend_forward_context(
484+
attn_metadata,
485+
self.vllm_config,
486+
moe_comm_type=moe_comm_type,
487+
num_tokens=num_input_tokens,
488+
aclgraph_runtime_mode=aclgraph_runtime_mode):
473489
last_hidden_states, hidden_states = self.model(
474490
input_ids=self.input_ids[:num_input_tokens],
475491
positions=self.positions[:num_input_tokens],

0 commit comments

Comments
 (0)