Skip to content

Commit f2c8426

Browse files
committed
fix sync error of seq_lens tolist
Signed-off-by: Ronald1995 <[email protected]>
1 parent 64abb03 commit f2c8426

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ def __init__(
144144
self.arange = torch.arange(max_num_slots_for_arange,
145145
device=device,
146146
dtype=torch.int32)
147+
self.arange_cpu = torch.arange(
148+
max_num_slots_for_arange, device="cpu", dtype=torch.int32
149+
)
147150

148151
self.inputs_embeds = torch.zeros(
149152
(self.max_num_tokens, self.hidden_size),
@@ -814,7 +817,7 @@ def _propose(
814817
# When disable_padded_drafter_batch=False, these params probably should not be updated.
815818
if self.speculative_config.disable_padded_drafter_batch or \
816819
aclgraph_runtime_mode != CUDAGraphMode.FULL:
817-
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
820+
attn_metadata_i.decode.actual_seq_lengths_q = self.arange_cpu[
818821
1:batch_size + 1].tolist()
819822
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
820823
attn_metadata_i.decode.actual_seq_lengths_q = \

0 commit comments

Comments
 (0)