Skip to content

Commit b8ff04d

Browse files
committed
fix synchronize error in _calc_spec_decode_metadata
Signed-off-by: Ronald1995 <[email protected]>
1 parent 43efaaf commit b8ff04d

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def rejection_greedy_sample_pytorch(
383383
target_argmax[global_idx].to(output_token_ids.dtype),
384384
output_token_ids
385385
)
386-
output_token_ids_.copy_(output_token_ids)
386+
output_token_ids.copy_(output_token_ids_)
387387
# Fill bonus token.
388388
needs_bonus = is_greedy & (first_mismatch_pos_per_req
389389
>= draft_tokens_per_req)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2144,16 +2144,31 @@ def _calc_spec_decode_metadata(
21442144
target_logits_indices += arange
21452145

21462146
# TODO: Optimize the CPU -> NPU copy.
2147-
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
2148-
self.device, non_blocking=True)
2149-
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
2150-
self.device, non_blocking=True)
2151-
logits_indices = torch.from_numpy(logits_indices).to(self.device,
2152-
non_blocking=True)
2153-
target_logits_indices = torch.from_numpy(target_logits_indices).to(
2154-
self.device, non_blocking=True)
2155-
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
2156-
self.device, non_blocking=True)
2147+
cu_num_draft_tokens = (
2148+
torch.from_numpy(cu_num_draft_tokens)
2149+
.pin_memory()
2150+
.to(self.device, non_blocking=True)
2151+
)
2152+
cu_num_sampled_tokens = (
2153+
torch.from_numpy(cu_num_sampled_tokens)
2154+
.pin_memory()
2155+
.to(self.device, non_blocking=True)
2156+
)
2157+
logits_indices = (
2158+
torch.from_numpy(logits_indices)
2159+
.pin_memory()
2160+
.to(self.device, non_blocking=True)
2161+
)
2162+
target_logits_indices = (
2163+
torch.from_numpy(target_logits_indices)
2164+
.pin_memory()
2165+
.to(self.device, non_blocking=True)
2166+
)
2167+
bonus_logits_indices = (
2168+
torch.from_numpy(bonus_logits_indices)
2169+
.pin_memory()
2170+
.to(self.device, non_blocking=True)
2171+
)
21572172

21582173
# Compute the draft token ids.
21592174
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]

0 commit comments

Comments
 (0)