
Commit 0686b32

yiz-liu and jianzs authored
[Fix] Fixes issues in MTP with async scheduling and ACL graph (#4963)
### What this PR does / why we need it?

Corrects the attention metadata size for MTP when both asynchronous scheduling and full ACL graph mode are enabled, preventing potential size mismatches during execution. Additionally, improves the robustness of the token sample index calculation by explicitly aligning tensor shapes. Finally, skips padding when the number of input tokens exceeds the maximum ACL graph batch size, to avoid out-of-bounds errors.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Corresponding test cases need to be added ASAP.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: Yizhou Liu <[email protected]>
Signed-off-by: Yizhou <[email protected]>
Co-authored-by: Jade Zheng <[email protected]>
1 parent 42ceaf0 commit 0686b32
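
To make the second fix concrete, here is a minimal, self-contained sketch of the shape-alignment idea (illustrative toy tensors only, not the vllm-ascend code; the actual change is in `prepare_inputs_padded` below):

```python
import torch

# Toy values, assumed for illustration: 3 live requests, so query_start_loc has
# num_reqs + 1 cumulative offsets, possibly followed by padding in the real runner.
query_start_loc = torch.tensor([0, 4, 7, 9])
num_rejected_tokens = torch.tensor([1, 0, 2])   # one entry per live request

# Slicing query_start_loc to num_rejected_tokens.shape[0] keeps the element-wise
# subtraction aligned even when query_start_loc carries extra (padded) entries.
per_request_end = query_start_loc[1:1 + num_rejected_tokens.shape[0]]
token_indices_to_sample = per_request_end - 1 - num_rejected_tokens
print(token_indices_to_sample)   # tensor([2, 6, 6])
```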

2 files changed: +16 -3 lines changed


vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 15 additions & 2 deletions

@@ -748,6 +748,7 @@ def _propose(
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
+        original_aclgraph_runtime_mode = aclgraph_runtime_mode
         if self.use_async_scheduling:
             # there is synchronization between mtp steps when enabling aclgraph,
             # disable aclgraph when use async scheduling to avoid the

@@ -802,6 +803,17 @@ def _propose(
             hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
                 hidden_states)

+        if original_aclgraph_runtime_mode == CUDAGraphMode.FULL and \
+                self.use_async_scheduling and attn_metadata[layer_name].decode is not None:
+            for layer_name in self.attn_layer_name:
+                actual_size = len(attn_metadata[layer_name].decode.
+                                  actual_seq_lengths_q)
+
+                attn_metadata[layer_name].decode.seq_lens_list = \
+                    attn_metadata[layer_name].decode.seq_lens_list[:actual_size]
+                attn_metadata[layer_name].decode.block_table = \
+                    attn_metadata[layer_name].decode.block_table[:actual_size]
+
         hidden_states = self.model(input_ids=input_ids,
                                    positions=positions,
                                    hidden_states=hidden_states)

@@ -1133,8 +1145,9 @@ def prepare_inputs_padded(
             num_computed_tokens_cpu,
             seq_lens=common_attn_metadata.seq_lens)

-        token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
-                                   1 - num_rejected_tokens_gpu)
+        query_start_loc = common_attn_metadata.query_start_loc[
+            1:1 + num_rejected_tokens_gpu.shape[0]]
+        token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu

         return spec_common_attn_metadata, token_indices, token_indices_to_sample
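
For readers unfamiliar with the decode-metadata trimming in the second hunk above, the sketch below shows the intent with a stand-in class (the field names mirror the diff, but the class itself is hypothetical, not the real Ascend metadata type):

```python
from dataclasses import dataclass

@dataclass
class FakeDecodeMetadata:
    # Stand-in for the per-layer decode metadata; field names follow the diff above.
    actual_seq_lengths_q: list
    seq_lens_list: list
    block_table: list

# Metadata built for a full-graph batch of 4, but only 2 decode queries are live
# when async scheduling shrinks the step.
meta = FakeDecodeMetadata(actual_seq_lengths_q=[5, 9],
                          seq_lens_list=[5, 9, 0, 0],
                          block_table=[[0], [1], [-1], [-1]])

# Trim the padded lists back to the number of queries actually present.
actual_size = len(meta.actual_seq_lengths_q)
meta.seq_lens_list = meta.seq_lens_list[:actual_size]
meta.block_table = meta.block_table[:actual_size]
assert meta.seq_lens_list == [5, 9] and len(meta.block_table) == 2
```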

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion

@@ -1019,7 +1019,7 @@ def _prepare_inputs(
         # TODO: We should make this official ASAP. Also note that if we pad here,
         # the builders won't need to add any extra padding.
         if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
-                uniform_decode:
+                uniform_decode and num_input_tokens <= self.cudagraph_batch_sizes[-1]:
             num_reqs_padded = num_input_tokens // self.uniform_decode_query_len
             pad_size = num_reqs_padded - num_reqs
             if pad_size > 0:
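
The guard added above can be read as the following condition (a hypothetical standalone helper, assuming `cudagraph_batch_sizes` is sorted in ascending order as in the runner):

```python
def should_pad_for_full_graph(num_input_tokens: int, uniform_decode: bool,
                              cudagraph_batch_sizes: list) -> bool:
    # Only pad up to a captured ACL graph shape when the batch still fits inside
    # the largest captured size; otherwise padding would index out of bounds.
    return uniform_decode and num_input_tokens <= cudagraph_batch_sizes[-1]

# e.g. with captured sizes [8, 16, 32]: 24 tokens can still be padded, 40 cannot.
assert should_pad_for_full_graph(24, True, [8, 16, 32]) is True
assert should_pad_for_full_graph(40, True, [8, 16, 32]) is False
```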
