Skip to content

Commit dbc38d7

Browse files
committed
set pin_memory=True
Signed-off-by: Ronald1995 <[email protected]>
1 parent f2c8426 commit dbc38d7

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2122,7 +2122,7 @@ def _calc_spec_decode_metadata(
21222122
cu_num_scheduled_tokens - num_sampled_tokens,
21232123
num_sampled_tokens)
21242124
logits_indices_pcp += arange
2125-
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).pin_memory().to(
2125+
logits_indices_pcp = torch.tensor(logits_indices_pcp, pin_memory=True).to(
21262126
self.device, non_blocking=True)
21272127

21282128
# Compute the bonus logits indices.
@@ -2145,28 +2145,23 @@ def _calc_spec_decode_metadata(
21452145

21462146
# TODO: Optimize the CPU -> NPU copy.
21472147
cu_num_draft_tokens = (
2148-
torch.from_numpy(cu_num_draft_tokens)
2149-
.pin_memory()
2148+
torch.tensor(cu_num_draft_tokens, pin_memory=True)
21502149
.to(self.device, non_blocking=True)
21512150
)
21522151
cu_num_sampled_tokens = (
2153-
torch.from_numpy(cu_num_sampled_tokens)
2154-
.pin_memory()
2152+
torch.tensor(cu_num_sampled_tokens, pin_memory=True)
21552153
.to(self.device, non_blocking=True)
21562154
)
21572155
logits_indices = (
2158-
torch.from_numpy(logits_indices)
2159-
.pin_memory()
2156+
torch.tensor(logits_indices, pin_memory=True)
21602157
.to(self.device, non_blocking=True)
21612158
)
21622159
target_logits_indices = (
2163-
torch.from_numpy(target_logits_indices)
2164-
.pin_memory()
2160+
torch.tensor(target_logits_indices, pin_memory=True)
21652161
.to(self.device, non_blocking=True)
21662162
)
21672163
bonus_logits_indices = (
2168-
torch.from_numpy(bonus_logits_indices)
2169-
.pin_memory()
2164+
torch.tensor(bonus_logits_indices, pin_memory=True)
21702165
.to(self.device, non_blocking=True)
21712166
)
21722167

0 commit comments

Comments
 (0)