
Commit 0c59e9d

get model device
1 parent b046119 commit 0c59e9d

1 file changed (+3, -1 lines)


vllm/v1/spec_decode/eagle.py

Lines changed: 3 additions & 1 deletion
@@ -654,7 +654,9 @@ def load_model(self, target_model: nn.Module) -> None:
         self.hot_token_ids = None
         if self.vllm_config.speculative_config.draft_vocab_pruned:
             logger.info(f"Loading pruned draft model vocabulary from {self.vllm_config.speculative_config.draft_vocab_pruned}")
-            self.hot_token_ids = load_draft_vocab_pruned(self.vllm_config.speculative_config.draft_vocab_pruned).to(self.model.device)
+            self.hot_token_ids = load_draft_vocab_pruned(self.vllm_config.speculative_config.draft_vocab_pruned)
+            device = next(self.model.model.parameters()).device
+            self.hot_token_ids = self.hot_token_ids.to(device)
         head = self.model.model.embed_tokens.weight
 
 
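The change moves the hot-token-id tensor onto the draft model's device by reading the device from the model's parameters rather than from self.model.device, presumably because a torch.nn.Module does not expose a .device attribute, so next(module.parameters()).device is the usual way to find where a module's weights live. Below is a minimal, self-contained sketch of that idiom; the get_module_device helper and the toy Linear model are illustrative and not part of vLLM, and the sketch assumes the module has at least one parameter and is not sharded across devices.

# Minimal sketch (illustrative, not vLLM code): torch.nn.Module has no
# .device attribute, so the device is read from one of its parameters.
import torch
import torch.nn as nn


def get_module_device(module: nn.Module) -> torch.device:
    # Assumes the module has at least one parameter and that all of its
    # parameters live on the same device (true for an unsharded draft model).
    return next(module.parameters()).device


if __name__ == "__main__":
    model = nn.Linear(8, 8)                   # toy stand-in for the draft model
    hot_token_ids = torch.tensor([1, 5, 42])  # stand-in for the pruned vocab ids
    hot_token_ids = hot_token_ids.to(get_module_device(model))
    print(hot_token_ids.device)               # cpu here; cuda:N if the model were on a GPU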