Skip to content

Commit 35265a3

Browse files
committed
fix lmhead_tp break
Signed-off-by: linfeng-yuan <[email protected]>
1 parent 776e14e commit 35265a3

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -807,10 +807,7 @@ def _propose(
807807

808808
num_indices = last_token_indices.shape[0]
809809
if lmhead_tp_enable():
810-
if not self.runner.with_prefill:
811-
max_num_reqs_across_dp = num_input_tokens
812-
else:
813-
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
810+
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
814811
last_token_indices = nn.functional.pad(
815812
last_token_indices,
816813
(0, max_num_reqs_across_dp - num_indices))

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,7 +1970,7 @@ def _prepare_inputs(
19701970
attn_metadata[layer_name] = attn_metadata_i
19711971

19721972
if lmhead_tp_enable():
1973-
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
1973+
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
19741974
logits_indices = nn.functional.pad(
19751975
logits_indices,
19761976
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
@@ -3113,7 +3113,7 @@ def _dummy_run(
31133113

31143114
need_dummy_logits = (not self.in_profile_run
31153115
and lmhead_tp_enable())
3116-
max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
3116+
max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
31173117
dummy_indices = torch.zeros(max_num_reqs_across_dp,
31183118
dtype=torch.int32)
31193119

0 commit comments

Comments (0)