Commit 9880e78

override the lm head

1 parent: 0c59e9d

File tree

1 file changed: +2, -1 lines changed

vllm/v1/spec_decode/eagle.py

Lines changed: 2 additions & 1 deletion
@@ -657,7 +657,8 @@ def load_model(self, target_model: nn.Module) -> None:
         self.hot_token_ids = load_draft_vocab_pruned(self.vllm_config.speculative_config.draft_vocab_pruned)
         device = next(self.model.model.parameters()).device
         self.hot_token_ids = self.hot_token_ids.to(device)
-        head = self.model.model.embed_tokens.weight
+        # self.model.model.embed_tokens.weight is the model head
+        self.model.model.embed_tokens.weight.data = self.model.model.embed_tokens.weight.data[self.hot_token_ids]


     @torch.inference_mode()
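For readers unfamiliar with the trick, the sketch below shows what the added line does in isolation: a tied embedding / LM-head weight matrix is sliced down to a pruned set of "hot" token ids, so the draft model only produces logits over that reduced vocabulary. The vocabulary size, hidden size, and `hot_token_ids` values here are made-up placeholders, not vLLM's actual configuration or API.

```python
import torch
import torch.nn as nn

# Illustrative sizes, not vLLM's real configuration.
full_vocab, hidden_size = 32_000, 4_096
embed_tokens = nn.Embedding(full_vocab, hidden_size)

# IDs kept in the pruned draft vocabulary (assumed precomputed elsewhere).
hot_token_ids = torch.tensor([0, 1, 2, 5, 9, 42, 1337])

# Overwrite the shared weight in place, as the diff does: the embedding table
# (and therefore the tied LM head) now only holds rows for the hot tokens,
# with shape (len(hot_token_ids), hidden_size).
embed_tokens.weight.data = embed_tokens.weight.data[hot_token_ids]

# Draft logits are now over the pruned vocabulary only.
hidden = torch.randn(1, hidden_size)
draft_logits = hidden @ embed_tokens.weight.T
print(draft_logits.shape)  # torch.Size([1, 7])
```

Note that logit index i in the pruned head corresponds to original token id `hot_token_ids[i]`, so draft outputs have to be mapped back to full-vocabulary ids elsewhere.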
