Skip to content

Commit c0317c9

Browse files
committed
fix pin_memory error
Signed-off-by: Ronald1995 <[email protected]>
1 parent 430d371 commit c0317c9

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2122,7 +2122,7 @@ def _calc_spec_decode_metadata(
21222122
cu_num_scheduled_tokens - num_sampled_tokens,
21232123
num_sampled_tokens)
21242124
logits_indices_pcp += arange
2125-
logits_indices_pcp = torch.tensor(logits_indices_pcp, pin_memory=True).to(
2125+
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).pin_memory().to(
21262126
self.device, non_blocking=True)
21272127

21282128
# Compute the bonus logits indices.
@@ -2145,25 +2145,27 @@ def _calc_spec_decode_metadata(
21452145

21462146
# TODO: Optimize the CPU -> NPU copy.
21472147
cu_num_draft_tokens = (
2148-
torch.tensor(cu_num_draft_tokens, pin_memory=True)
2148+
torch.from_numpy(cu_num_draft_tokens)
2149+
.pin_memory()
21492150
.to(self.device, non_blocking=True)
21502151
)
21512152
cu_num_sampled_tokens = (
2152-
torch.tensor(cu_num_sampled_tokens, pin_memory=True)
2153+
torch.from_numpy(cu_num_sampled_tokens)
2154+
.pin_memory()
21532155
.to(self.device, non_blocking=True)
21542156
)
21552157
logits_indices = (
2156-
torch.tensor(logits_indices, pin_memory=True)
2158+
torch.from_numpy(logits_indices)
2159+
.pin_memory()
21572160
.to(self.device, non_blocking=True)
21582161
)
21592162
target_logits_indices = (
2160-
torch.tensor(target_logits_indices, pin_memory=True)
2161-
.to(self.device, non_blocking=True)
2162-
)
2163-
bonus_logits_indices = (
2164-
torch.tensor(bonus_logits_indices, pin_memory=True)
2163+
torch.from_numpy(target_logits_indices)
2164+
.pin_memory()
21652165
.to(self.device, non_blocking=True)
21662166
)
2167+
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
2168+
self.device, non_blocking=True)
21672169

21682170
# Compute the draft token ids.
21692171
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]

0 commit comments

Comments
 (0)