We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0c59e9d commit 9880e78Copy full SHA for 9880e78
vllm/v1/spec_decode/eagle.py
@@ -657,7 +657,8 @@ def load_model(self, target_model: nn.Module) -> None:
657
self.hot_token_ids = load_draft_vocab_pruned(self.vllm_config.speculative_config.draft_vocab_pruned)
658
device = next(self.model.model.parameters()).device
659
self.hot_token_ids = self.hot_token_ids.to(device)
660
- head = self.model.model.embed_tokens.weight
+ # self.model.model.embed_tokens.weight is the model head
661
+ self.model.model.embed_tokens.weight.data = self.model.model.embed_tokens.weight.data[self.hot_token_ids]
662
663
664
@torch.inference_mode()
0 commit comments