
Commit 43efaaf

fix synchronize error

Signed-off-by: Ronald1995 <[email protected]>

1 parent 6058eef commit 43efaaf

File tree: 2 files changed, +42 −35 lines

vllm_ascend/sample/rejection_sampler.py

Lines changed: 16 additions & 10 deletions
@@ -334,7 +334,7 @@ def rejection_greedy_sample_pytorch(
 
     start_indices = cu_num_draft_tokens - draft_tokens_per_req
     req_ids = torch.arange(batch_size, device=device)
-    total_draft_tokens = torch.sum(draft_tokens_per_req_cpu).item()
+    total_draft_tokens = sum(draft_tokens_per_req_cpu)
     token_req_ids = torch.repeat_interleave(
         req_ids, draft_tokens_per_req, output_size=total_draft_tokens
     )
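
For context on the one-line change above: torch.sum(...).item() forces the host to wait for the device before the Python int is available, while summing a host-side copy yields the same integer without a synchronize. A minimal sketch of the pattern with toy values, assuming draft_tokens_per_req_cpu is host-resident (the real object in the repo may be a CPU tensor rather than a list):

import torch

# Toy stand-ins; in the sampler these come from the scheduler/runner.
draft_tokens_per_req = torch.tensor([2, 3, 1])   # device tensor in practice
draft_tokens_per_req_cpu = [2, 3, 1]             # host-side copy

# torch.sum(device_tensor).item() would block until the device catches up;
# summing the host copy keeps kernel launches asynchronous.
total_draft_tokens = sum(draft_tokens_per_req_cpu)

req_ids = torch.arange(len(draft_tokens_per_req_cpu))
token_req_ids = torch.repeat_interleave(
    req_ids, draft_tokens_per_req, output_size=total_draft_tokens)
print(token_req_ids)  # tensor([0, 0, 1, 1, 1, 2])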
@@ -363,8 +363,11 @@ def rejection_greedy_sample_pytorch(
                                      max_spec_len * 2)
     first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
     no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
-    first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
-        no_mismatch_mask]
+    first_mismatch_pos_per_req = torch.where(
+        no_mismatch_mask,
+        draft_tokens_per_req,
+        first_mismatch_pos_per_req,
+    )
 
     # Copy matched target tokens into output.
     copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
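
Why the torch.where rewrite above helps: boolean-mask assignment (a[mask] = b[mask]) gathers a data-dependent number of elements, which on asynchronous devices typically triggers a host-device synchronize; torch.where is a fixed-shape elementwise select with the same result. A self-contained toy illustration (values are made up):

import torch

first_mismatch_pos_per_req = torch.tensor([5, 2, 7])
draft_tokens_per_req = torch.tensor([3, 4, 1])
no_mismatch_mask = torch.tensor([True, False, True])

# Equivalent to the old in-place form
#   first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask]
# but shape-static and free of mask-driven indexing.
first_mismatch_pos_per_req = torch.where(
    no_mismatch_mask, draft_tokens_per_req, first_mismatch_pos_per_req)
print(first_mismatch_pos_per_req)  # tensor([3, 2, 1])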
@@ -375,16 +378,19 @@ def rejection_greedy_sample_pytorch(
     greedy_mask = is_greedy.unsqueeze(1)
     final_copy_mask = copy_mask & greedy_mask
     global_idx = start_indices.unsqueeze(1) + copy_indices
-    output_token_ids[final_copy_mask] = target_argmax[
-        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    output_token_ids_ = torch.where(
+        final_copy_mask,
+        target_argmax[global_idx].to(output_token_ids.dtype),
+        output_token_ids
+    )
+    output_token_ids_.copy_(output_token_ids)
     # Fill bonus token.
     needs_bonus = is_greedy & (first_mismatch_pos_per_req
                                >= draft_tokens_per_req)
-    if torch.any(needs_bonus):
-        bonus_rows = torch.where(needs_bonus)[0]
-        bonus_cols = draft_tokens_per_req[bonus_rows]
-        bonus_token_ids = bonus_token_ids.squeeze(1)
-        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
+    bonus_rows = torch.where(needs_bonus)[0]
+    bonus_cols = draft_tokens_per_req[bonus_rows]
+    bonus_token_ids = bonus_token_ids.squeeze(1)
+    output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
 
 
 def rejection_random_sample_pytorch(
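
A note on the bonus-token block above: the removed if torch.any(needs_bonus) guard had to read a device boolean back to the host just to decide whether to run the scatter. Running the scatter unconditionally is safe because advanced indexing with empty index tensors is a no-op, as this toy example shows (illustrative values only):

import torch

output_token_ids = torch.full((3, 4), -1)
needs_bonus = torch.tensor([False, False, False])   # no request qualifies
bonus_token_ids = torch.tensor([[11], [12], [13]])
draft_tokens_per_req = torch.tensor([1, 2, 3])

bonus_rows = torch.where(needs_bonus)[0]            # empty index tensor
bonus_cols = draft_tokens_per_req[bonus_rows]       # also empty
output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids.squeeze(1)[bonus_rows]
print(output_token_ids)  # unchanged, still all -1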

vllm_ascend/worker/model_runner_v1.py

Lines changed: 26 additions & 25 deletions
@@ -2655,32 +2655,33 @@ def sample_tokens(
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        for req_idx in range(num_sampled_tokens):
-            sampled_ids: np.ndarray | None
-            if self.use_async_scheduling:
-                sampled_ids = (np.array([-1]) if req_idx
-                               not in invalid_req_indices_set else None)
-            else:
-                sampled_ids = valid_sampled_token_ids[req_idx]
-            if sampled_ids is None or sampled_ids.shape[0] == 0:
-                continue
+        if not self.use_async_scheduling:
+            for req_idx in range(num_sampled_tokens):
+                sampled_ids: np.ndarray | None
+                if self.use_async_scheduling:
+                    sampled_ids = (np.array([-1]) if req_idx
+                                   not in invalid_req_indices_set else None)
+                else:
+                    sampled_ids = valid_sampled_token_ids[req_idx]
+                if sampled_ids is None or sampled_ids.shape[0] == 0:
+                    continue
 
-            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-            end_idx = start_idx + sampled_ids.shape[0]
-            assert end_idx <= self.model_config.max_model_len, (
-                "Sampled token IDs exceed the max model length. "
-                f"Total number of tokens: {end_idx} > max_model_len: "
-                f"{self.model_config.max_model_len}")
-
-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
-            self.input_batch.is_token_ids[req_idx,
-                                          start_idx:end_idx] = True
-            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-            self.input_batch.num_tokens[req_idx] = end_idx
-            req_id = self.input_batch.req_ids[req_idx]
-            req_state = self.requests[req_id]
-            req_state.output_token_ids.extend(sampled_ids.tolist())
+                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+                end_idx = start_idx + sampled_ids.shape[0]
+                assert end_idx <= self.model_config.max_model_len, (
+                    "Sampled token IDs exceed the max model length. "
+                    f"Total number of tokens: {end_idx} > max_model_len: "
+                    f"{self.model_config.max_model_len}")
+
+                self.input_batch.token_ids_cpu[req_idx,
+                                               start_idx:end_idx] = sampled_ids
+                self.input_batch.is_token_ids[req_idx,
+                                              start_idx:end_idx] = True
+                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+                self.input_batch.num_tokens[req_idx] = end_idx
+                req_id = self.input_batch.req_ids[req_idx]
+                req_state = self.requests[req_id]
+                req_state.output_token_ids.extend(sampled_ids.tolist())
         self.input_batch.prev_sampled_token_ids = None
         def propose_draft_token_ids(sampled_token_ids):
             assert self.spec_decode_common_attn_metadata is not None
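
The model_runner_v1.py change only adds a guard: under async scheduling the old loop could only write -1 placeholders into the host-side batch, so the whole bookkeeping pass is now skipped in that mode. A stripped-down sketch of the resulting control flow, using hypothetical local names in place of the real ModelRunner attributes:

import numpy as np

# Hypothetical stand-ins for self.use_async_scheduling and
# valid_sampled_token_ids in the real runner.
use_async_scheduling = False
valid_sampled_token_ids = [np.array([101]), np.array([]), np.array([102, 103])]

if not use_async_scheduling:
    # Host-side bookkeeping runs only for synchronous scheduling; under async
    # scheduling the old code had only -1 placeholders to write here, so the
    # pass is skipped entirely.
    for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
        if sampled_ids is None or sampled_ids.shape[0] == 0:
            continue
        print(f"request {req_idx}: append {sampled_ids.tolist()} to host buffers")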
