@@ -2655,32 +2655,33 @@ def sample_tokens(
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        for req_idx in range(num_sampled_tokens):
-            sampled_ids: np.ndarray | None
-            if self.use_async_scheduling:
-                sampled_ids = (np.array([-1]) if req_idx
-                               not in invalid_req_indices_set else None)
-            else:
-                sampled_ids = valid_sampled_token_ids[req_idx]
-            if sampled_ids is None or sampled_ids.shape[0] == 0:
-                continue
+        if not self.use_async_scheduling:
+            for req_idx in range(num_sampled_tokens):
+                sampled_ids: np.ndarray | None
+                if self.use_async_scheduling:
+                    sampled_ids = (np.array([-1]) if req_idx
+                                   not in invalid_req_indices_set else None)
+                else:
+                    sampled_ids = valid_sampled_token_ids[req_idx]
+                if sampled_ids is None or sampled_ids.shape[0] == 0:
+                    continue
 
-            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-            end_idx = start_idx + sampled_ids.shape[0]
-            assert end_idx <= self.model_config.max_model_len, (
-                "Sampled token IDs exceed the max model length. "
-                f"Total number of tokens: {end_idx} > max_model_len: "
-                f"{self.model_config.max_model_len}")
-
-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
-            self.input_batch.is_token_ids[req_idx,
-                                          start_idx:end_idx] = True
-            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-            self.input_batch.num_tokens[req_idx] = end_idx
-            req_id = self.input_batch.req_ids[req_idx]
-            req_state = self.requests[req_id]
-            req_state.output_token_ids.extend(sampled_ids.tolist())
+                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+                end_idx = start_idx + sampled_ids.shape[0]
+                assert end_idx <= self.model_config.max_model_len, (
+                    "Sampled token IDs exceed the max model length. "
+                    f"Total number of tokens: {end_idx} > max_model_len: "
+                    f"{self.model_config.max_model_len}")
+
+                self.input_batch.token_ids_cpu[req_idx,
+                                               start_idx:end_idx] = sampled_ids
+                self.input_batch.is_token_ids[req_idx,
+                                              start_idx:end_idx] = True
+                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+                self.input_batch.num_tokens[req_idx] = end_idx
+                req_id = self.input_batch.req_ids[req_idx]
+                req_state = self.requests[req_id]
+                req_state.output_token_ids.extend(sampled_ids.tolist())
         self.input_batch.prev_sampled_token_ids = None
     def propose_draft_token_ids(sampled_token_ids):
         assert self.spec_decode_common_attn_metadata is not None
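The loop gated off above is plain host-side bookkeeping: it appends each request's newly sampled token IDs after its committed (non-speculative) tokens, bumps the token counters, and extends the per-request output list. Below is a minimal sketch of that bookkeeping in isolation, not the vLLM implementation; MiniBatch, append_sampled, and MAX_MODEL_LEN are hypothetical stand-ins for InputBatch and model_config.max_model_len. It illustrates why the whole loop can be skipped under async scheduling: every write here touches CPU-side arrays, work the async path defers (note prev_sampled_token_ids being reset right after the block).

import numpy as np

MAX_MODEL_LEN = 16  # assumed tiny limit, stands in for model_config.max_model_len

class MiniBatch:
    """Hypothetical stand-in for the persistent CPU-side input batch."""

    def __init__(self, num_reqs: int):
        self.token_ids_cpu = np.zeros((num_reqs, MAX_MODEL_LEN), dtype=np.int64)
        self.is_token_ids = np.zeros((num_reqs, MAX_MODEL_LEN), dtype=bool)
        self.num_tokens_no_spec = np.zeros(num_reqs, dtype=np.int64)
        self.num_tokens = np.zeros(num_reqs, dtype=np.int64)

def append_sampled(batch: MiniBatch, req_idx: int, sampled_ids: np.ndarray) -> None:
    # Append this step's sampled tokens right after the request's committed
    # (non-speculative) tokens, mirroring the loop body in the diff above.
    start_idx = int(batch.num_tokens_no_spec[req_idx])
    end_idx = start_idx + sampled_ids.shape[0]
    assert end_idx <= MAX_MODEL_LEN, (
        f"Total number of tokens: {end_idx} > max_model_len: {MAX_MODEL_LEN}")
    batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
    batch.is_token_ids[req_idx, start_idx:end_idx] = True
    batch.num_tokens_no_spec[req_idx] = end_idx
    batch.num_tokens[req_idx] = end_idx

batch = MiniBatch(num_reqs=2)
append_sampled(batch, 0, np.array([101]))   # one sampled token accepted
append_sampled(batch, 1, np.array([7, 8]))  # e.g. a spec-decode step accepting two
print(batch.num_tokens)  # -> [1 2]

One observable consequence of the change: gating the entire loop on use_async_scheduling, rather than branching per request inside it, leaves the synchronous path byte-for-byte unchanged while removing all per-request Python work from the async step; the inner `if self.use_async_scheduling:` branch becomes unreachable under the new outer guard.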