File tree Expand file tree Collapse file tree 2 files changed +3
-6
lines changed
Expand file tree Collapse file tree 2 files changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -807,10 +807,7 @@ def _propose(
807807
808808 num_indices = last_token_indices .shape [0 ]
809809 if lmhead_tp_enable ():
810- if not self .runner .with_prefill :
811- max_num_reqs_across_dp = num_input_tokens
812- else :
813- max_num_reqs_across_dp = self .vllm_config .scheduler_config .max_num_seqs
810+ max_num_reqs_across_dp = self .vllm_config .scheduler_config .max_num_seqs * self .runner .uniform_decode_query_len
814811 last_token_indices = nn .functional .pad (
815812 last_token_indices ,
816813 (0 , max_num_reqs_across_dp - num_indices ))
Original file line number Diff line number Diff line change @@ -1970,7 +1970,7 @@ def _prepare_inputs(
19701970 attn_metadata [layer_name ] = attn_metadata_i
19711971
19721972 if lmhead_tp_enable ():
1973- max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self .max_num_reqs
1973+ max_num_reqs_across_dp = self . max_num_reqs * self .uniform_decode_query_len
19741974 logits_indices = nn .functional .pad (
19751975 logits_indices ,
19761976 (0 , max_num_reqs_across_dp - logits_indices .shape [0 ]))
@@ -3113,7 +3113,7 @@ def _dummy_run(
31133113
31143114 need_dummy_logits = (not self .in_profile_run
31153115 and lmhead_tp_enable ())
3116- max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
3116+ max_num_reqs_across_dp = max_num_reqs * self . uniform_decode_query_len
31173117 dummy_indices = torch .zeros (max_num_reqs_across_dp ,
31183118 dtype = torch .int32 )
31193119
You can’t perform that action at this time.
0 commit comments