Commit 7271f0d

[Feat] MTP support DeepSeekV3.2 (#4465)
### What this PR does / why we need it?

Currently, MTP does not support the DeepSeekV3.2 model. This PR enables that support.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: ZYang6263 <[email protected]>
1 parent 38bd952 commit 7271f0d
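With this change, a DeepseekV32ForCausalLM checkpoint can be served with MTP speculative decoding the same way DeepSeek-V3 already is. Below is a hedged setup sketch, assuming vLLM's dict-style speculative_config and the "deepseek_mtp" method; the model name, parallelism, and exact field names are illustrative and should be checked against the installed vLLM / vllm-ascend version rather than taken as the verified interface of this commit.

# Hedged usage sketch: enabling MTP-based speculative decoding for a DeepSeek-V3.2
# checkpoint. Field names follow vLLM's documented MTP setup for DeepSeek models;
# verify them against the vLLM / vllm-ascend version you actually run.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # assumption: any DeepseekV32ForCausalLM checkpoint
    tensor_parallel_size=16,                # illustrative value only
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,        # the MTP head drafts one extra token per step
    },
)

outputs = llm.generate(["Explain multi-token prediction in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)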

File tree

1 file changed: +31 -27 lines changed


vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 31 additions & 27 deletions
@@ -47,6 +47,8 @@
 
 _MTP_MODELS = {
     "DeepseekV3ForCausalLM":
+    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
+    "DeepseekV32ForCausalLM":
     ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
 }
 
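For context, _MTP_MODELS maps a HuggingFace architecture name to the (module path, class name) of the draft model used for multi-token prediction; both DeepseekV3ForCausalLM and the new DeepseekV32ForCausalLM entry resolve to the same DeepSeekMTP class. A minimal sketch of how such a registry can be resolved lazily via importlib; the resolve_mtp_model helper below is hypothetical, not the repository's actual loader:

import importlib

# Registry as in the diff above: architecture name -> (module path, class name).
_MTP_MODELS = {
    "DeepseekV3ForCausalLM":
    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
    "DeepseekV32ForCausalLM":
    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
}


def resolve_mtp_model(architecture: str):
    """Hypothetical helper: import and return the MTP draft-model class
    registered for the given HF architecture name."""
    if architecture not in _MTP_MODELS:
        raise ValueError(f"MTP is not supported for {architecture}")
    module_path, class_name = _MTP_MODELS[architecture]
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# e.g. resolve_mtp_model("DeepseekV32ForCausalLM") -> DeepSeekMTP (requires vLLM installed)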
@@ -813,26 +815,28 @@ def _propose(
             attn_metadata_i.slot_mapping.fill_(-1)
             attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
             last_token_indices = self.arange[:batch_size]
-            if attn_metadata_i.num_decode_tokens != 0:
+            if getattr(attn_metadata_i, "num_decode_tokens", 0):
                 attn_metadata_i.num_decode_tokens = batch_size
 
             input_ids = draft_token_ids_list[-1].int()
             positions += 1
 
+            decode_metadata = getattr(attn_metadata_i, "decode", None)
+            prefill_metadata = getattr(attn_metadata_i, "prefill", None)
             # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
-            if self.speculative_config.disable_padded_drafter_batch or \
-                aclgraph_runtime_mode != CUDAGraphMode.FULL:
-                attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+            if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
+                    aclgraph_runtime_mode != CUDAGraphMode.FULL):
+                decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
                     1:batch_size + 1].tolist()
                 if aclgraph_runtime_mode == CUDAGraphMode.FULL:
-                    attn_metadata_i.decode.actual_seq_lengths_q = \
+                    decode_metadata.actual_seq_lengths_q = \
                         builder.pad_actual_seq_len_q_mtp_disable_pad(
                             graph_pad_size - batch_size,
                             batch_size,
-                            attn_metadata_i.decode.actual_seq_lengths_q)
-                attn_metadata_i.decode.cos = builder.cos_cache[
+                            decode_metadata.actual_seq_lengths_q)
+                decode_metadata.cos = builder.cos_cache[
                     positions[:batch_size]].unsqueeze(1).unsqueeze(2)
-                attn_metadata_i.decode.sin = builder.sin_cache[
+                decode_metadata.sin = builder.sin_cache[
                     positions[:batch_size]].unsqueeze(1).unsqueeze(2)
             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
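The core change in this hunk is defensive attribute access: the decode/prefill sub-metadata and num_decode_tokens are now read with getattr(..., default), presumably because the attention metadata built for DeepSeekV3.2 on Ascend may not expose all of these fields. A small, self-contained sketch of that pattern under this assumption; the metadata classes below are illustrative stand-ins, not the repository's actual types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class DecodeMetadata:
    seq_lens_list: list[int]


@dataclass
class DenseAttnMetadata:           # exposes both decode metadata and num_decode_tokens
    decode: Optional[DecodeMetadata]
    num_decode_tokens: int = 0


@dataclass
class SparseAttnMetadata:          # illustrative: defines neither decode nor num_decode_tokens
    seq_lens_list: list[int]


def update_decode_fields(attn_metadata, batch_size: int) -> None:
    # getattr(..., default) lets one code path serve metadata classes that may or
    # may not define these attributes, instead of raising AttributeError.
    if getattr(attn_metadata, "num_decode_tokens", 0):
        attn_metadata.num_decode_tokens = batch_size

    decode_metadata = getattr(attn_metadata, "decode", None)
    if decode_metadata is not None:
        decode_metadata.seq_lens_list = [1] * batch_size


update_decode_fields(DenseAttnMetadata(DecodeMetadata([]), 4), batch_size=4)
update_decode_fields(SparseAttnMetadata([8, 8, 8, 8]), batch_size=4)   # no-op, but no crash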
@@ -870,32 +874,32 @@ def _propose(
                 self.input_ids[batch_size:num_input_tokens] = 0
                 self.hidden_states[batch_size:num_input_tokens].fill_(0)
 
-            if attn_metadata_i.prefill is not None:
-                attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
-                attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
+            if prefill_metadata is not None:
+                prefill_metadata.seq_lens = attn_metadata_i.seq_lens
+                prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist(
                 )
-                attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
-                attn_metadata_i.prefill.input_positions = self.positions[:
-                                                                         num_input_tokens]
-                attn_metadata_i.prefill.max_seq_lens += 1
-                attn_metadata_i.prefill.max_seq_lens = min(
-                    attn_metadata_i.prefill.max_seq_lens,
+                prefill_metadata.context_lens = attn_metadata_i.seq_lens
+                prefill_metadata.input_positions = self.positions[:
+                                                                  num_input_tokens]
+                prefill_metadata.max_seq_lens += 1
+                prefill_metadata.max_seq_lens = min(
+                    prefill_metadata.max_seq_lens,
                     self.runner.model_config.max_model_len)
-            if attn_metadata_i.decode is not None:
-                attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
-                attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
+            if decode_metadata is not None:
+                decode_metadata.seq_lens = attn_metadata_i.seq_lens
+                decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist(
                 )
-                decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
+                decode_seq_lens_list = decode_metadata.seq_lens_list
                 if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
                         self.speculative_config.disable_padded_drafter_batch:
-                    attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
+                    decode_metadata.seq_lens_list = decode_seq_lens_list + [
                         0
                     ] * (graph_pad_size - len(decode_seq_lens_list))
-                attn_metadata_i.decode.input_positions = self.positions[:
-                                                                        num_input_tokens]
-                attn_metadata_i.decode.max_seq_lens += 1
-                attn_metadata_i.decode.max_seq_lens = min(
-                    attn_metadata_i.decode.max_seq_lens,
+                decode_metadata.input_positions = self.positions[:
+                                                                 num_input_tokens]
+                decode_metadata.max_seq_lens += 1
+                decode_metadata.max_seq_lens = min(
+                    decode_metadata.max_seq_lens,
                     self.runner.model_config.max_model_len)
 
             # mtp>1: [batch_size, k]
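One detail worth calling out in this hunk: when ACL graph mode is FULL and padded drafter batches are disabled, the decode seq_lens_list is zero-padded out to graph_pad_size so its length matches the shape the captured graph expects. A tiny sketch of that padding step in isolation; graph_pad_size and the list values are made-up numbers:

def pad_seq_lens_for_full_graph(seq_lens_list: list[int],
                                graph_pad_size: int) -> list[int]:
    # The captured full graph expects a fixed-length list; pad the tail with
    # zero-length "dummy" sequences rather than re-capturing for every batch size.
    return seq_lens_list + [0] * (graph_pad_size - len(seq_lens_list))


# Example: 3 live decode requests padded up to a graph captured for 8 slots.
print(pad_seq_lens_for_full_graph([17, 5, 42], graph_pad_size=8))
# -> [17, 5, 42, 0, 0, 0, 0, 0]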
