Commit 916a9a1

fix synchronize error of exceeds_max_model_len d2h copy (#4708)
### What this PR does / why we need it?
There is a device-to-host (d2h) copy that blocks CPU operations in the MTP propose method, which causes a host-bound issue. This PR refactors that code path to compute the mask from a CPU tensor instead.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested against vLLM main at f5d3d93c40417c296c20dc301100e55708a17f3f.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
Parent commit: 2be0fe2

File tree: 1 file changed (+3, -4)

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -885,10 +885,9 @@ def _propose(
             attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
-            exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata_i.seq_lens.device, non_blocking=False)
-            attn_metadata_i.seq_lens[:batch_size].masked_fill_(
-                exceeds_max_model_len_cpu, 1)
+            exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > \
+                self.runner.model_config.max_model_len
+            attn_metadata_i.seq_lens[:batch_size].masked_fill_(exceeds_mask, 1)
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
```
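For context on why the removed lines stalled the host: `Tensor.to(..., non_blocking=False)` from a device to the CPU is a synchronizing copy, so the host thread waits for the device stream to drain before it can proceed. Below is a minimal standalone sketch of the before/after pattern, assuming PyTorch; the shapes, values, and the `device` fallback are illustrative, not taken from vllm-ascend, and `max_model_len` stands in for `self.runner.model_config.max_model_len`.

```python
import torch

# Illustrative values; not from the vllm-ascend codebase.
max_model_len = 2048
seq_lens = torch.randint(1, 4096, (8,))  # host (CPU) tensor, as in the new code

# Old pattern: the mask lived on the device and was copied to the host.
# The blocking copy makes the CPU wait until the device stream drains,
# which is the host-bound stall this commit removes.
device = "cuda" if torch.cuda.is_available() else "cpu"
exceeds_on_device = torch.randint(1, 4096, (8,), device=device) > max_model_len
exceeds_cpu = exceeds_on_device.to(seq_lens.device, non_blocking=False)  # d2h sync

# New pattern: derive the mask directly from the host-resident tensor,
# so no device-to-host copy (and no synchronization) is needed.
exceeds_mask = seq_lens > max_model_len
seq_lens.masked_fill_(exceeds_mask, 1)
```

Since `seq_lens` is host-resident (as the copy target in the old code suggests), the mask can be computed on the host directly, which is what the replacement lines do.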
