|
47 | 47 |
|
48 | 48 | _MTP_MODELS = { |
49 | 49 | "DeepseekV3ForCausalLM": |
| 50 | + ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"), |
| 51 | + "DeepseekV32ForCausalLM": |
50 | 52 | ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP") |
51 | 53 | } |
52 | 54 |
|
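For context, the registry above maps a target-model architecture name to a (module path, class name) pair, so the V3.2 architecture reuses the same DeepSeekMTP draft class as V3. Below is a minimal sketch of how such an entry might be resolved into a class, assuming a lazy importlib-based lookup; the `resolve_mtp_class` helper is illustrative and not part of this change.

```python
import importlib

# Same mapping as in the diff: architecture name -> (module path, class name).
_MTP_MODELS = {
    "DeepseekV3ForCausalLM":
        ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
    "DeepseekV32ForCausalLM":
        ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
}


def resolve_mtp_class(architecture: str):
    """Illustrative helper: import the registered MTP draft-model class
    only when it is first needed."""
    module_name, class_name = _MTP_MODELS[architecture]
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```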
@@ -813,26 +815,28 @@ def _propose( |
813 | 815 | attn_metadata_i.slot_mapping.fill_(-1) |
814 | 816 | attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] |
815 | 817 | last_token_indices = self.arange[:batch_size] |
816 | | - if attn_metadata_i.num_decode_tokens != 0: |
| 818 | + if getattr(attn_metadata_i, "num_decode_tokens", 0): |
817 | 819 | attn_metadata_i.num_decode_tokens = batch_size |
818 | 820 |
|
819 | 821 | input_ids = draft_token_ids_list[-1].int() |
820 | 822 | positions += 1 |
821 | 823 |
|
| 824 | + decode_metadata = getattr(attn_metadata_i, "decode", None) |
| 825 | + prefill_metadata = getattr(attn_metadata_i, "prefill", None) |
822 | 826 | # When disable_padded_drafter_batch=False, these params should probably not be updated. |
823 | | - if self.speculative_config.disable_padded_drafter_batch or \ |
824 | | - aclgraph_runtime_mode != CUDAGraphMode.FULL: |
825 | | - attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ |
| 827 | + if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \ |
| 828 | + aclgraph_runtime_mode != CUDAGraphMode.FULL): |
| 829 | + decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ |
826 | 830 | 1:batch_size + 1].tolist() |
827 | 831 | if aclgraph_runtime_mode == CUDAGraphMode.FULL: |
828 | | - attn_metadata_i.decode.actual_seq_lengths_q = \ |
| 832 | + decode_metadata.actual_seq_lengths_q = \ |
829 | 833 | builder.pad_actual_seq_len_q_mtp_disable_pad( |
830 | 834 | graph_pad_size - batch_size, |
831 | 835 | batch_size, |
832 | | - attn_metadata_i.decode.actual_seq_lengths_q) |
833 | | - attn_metadata_i.decode.cos = builder.cos_cache[ |
| 836 | + decode_metadata.actual_seq_lengths_q) |
| 837 | + decode_metadata.cos = builder.cos_cache[ |
834 | 838 | positions[:batch_size]].unsqueeze(1).unsqueeze(2) |
835 | | - attn_metadata_i.decode.sin = builder.sin_cache[ |
| 839 | + decode_metadata.sin = builder.sin_cache[ |
836 | 840 | positions[:batch_size]].unsqueeze(1).unsqueeze(2) |
837 | 841 | # NOTE(woosuk): We should handle the case where the draft model |
838 | 842 | # generates tokens beyond the max model length. Since it is complex |
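The hunk above replaces direct accesses such as `attn_metadata_i.decode` and `attn_metadata_i.prefill` with `getattr(..., None)` lookups cached in locals, so metadata objects that lack these optional sub-structures no longer raise `AttributeError`. A self-contained sketch of the same defensive pattern, using illustrative stand-in classes rather than the real vLLM metadata types:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DecodeMeta:
    # Stand-in for a backend-specific decode-metadata object.
    actual_seq_lengths_q: List[int] = field(default_factory=list)


@dataclass
class AttnMeta:
    # Some attention backends expose a .decode sub-object, others do not.
    decode: Optional[DecodeMeta] = None


def update_decode_lengths(meta: AttnMeta, query_start_loc: List[int],
                          batch_size: int) -> None:
    # Fetch the optional sub-metadata once, then only touch it when present,
    # mirroring the guard added in the diff.
    decode_metadata = getattr(meta, "decode", None)
    if decode_metadata is not None:
        decode_metadata.actual_seq_lengths_q = query_start_loc[1:batch_size + 1]


meta = AttnMeta(decode=DecodeMeta())
update_decode_lengths(meta, query_start_loc=list(range(10)), batch_size=4)
print(meta.decode.actual_seq_lengths_q)  # [1, 2, 3, 4]
```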
@@ -870,32 +874,32 @@ def _propose( |
870 | 874 | self.input_ids[batch_size:num_input_tokens] = 0 |
871 | 875 | self.hidden_states[batch_size:num_input_tokens].fill_(0) |
872 | 876 |
|
873 | | - if attn_metadata_i.prefill is not None: |
874 | | - attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens |
875 | | - attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( |
| 877 | + if prefill_metadata is not None: |
| 878 | + prefill_metadata.seq_lens = attn_metadata_i.seq_lens |
| 879 | + prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist( |
876 | 880 | ) |
877 | | - attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens |
878 | | - attn_metadata_i.prefill.input_positions = self.positions[: |
879 | | - num_input_tokens] |
880 | | - attn_metadata_i.prefill.max_seq_lens += 1 |
881 | | - attn_metadata_i.prefill.max_seq_lens = min( |
882 | | - attn_metadata_i.prefill.max_seq_lens, |
| 881 | + prefill_metadata.context_lens = attn_metadata_i.seq_lens |
| 882 | + prefill_metadata.input_positions = self.positions[: |
| 883 | + num_input_tokens] |
| 884 | + prefill_metadata.max_seq_lens += 1 |
| 885 | + prefill_metadata.max_seq_lens = min( |
| 886 | + prefill_metadata.max_seq_lens, |
883 | 887 | self.runner.model_config.max_model_len) |
884 | | - if attn_metadata_i.decode is not None: |
885 | | - attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens |
886 | | - attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( |
| 888 | + if decode_metadata is not None: |
| 889 | + decode_metadata.seq_lens = attn_metadata_i.seq_lens |
| 890 | + decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist( |
887 | 891 | ) |
888 | | - decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list |
| 892 | + decode_seq_lens_list = decode_metadata.seq_lens_list |
889 | 893 | if aclgraph_runtime_mode == CUDAGraphMode.FULL and \ |
890 | 894 | self.speculative_config.disable_padded_drafter_batch: |
891 | | - attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [ |
| 895 | + decode_metadata.seq_lens_list = decode_seq_lens_list + [ |
892 | 896 | 0 |
893 | 897 | ] * (graph_pad_size - len(decode_seq_lens_list)) |
894 | | - attn_metadata_i.decode.input_positions = self.positions[: |
895 | | - num_input_tokens] |
896 | | - attn_metadata_i.decode.max_seq_lens += 1 |
897 | | - attn_metadata_i.decode.max_seq_lens = min( |
898 | | - attn_metadata_i.decode.max_seq_lens, |
| 898 | + decode_metadata.input_positions = self.positions[: |
| 899 | + num_input_tokens] |
| 900 | + decode_metadata.max_seq_lens += 1 |
| 901 | + decode_metadata.max_seq_lens = min( |
| 902 | + decode_metadata.max_seq_lens, |
899 | 903 | self.runner.model_config.max_model_len) |
900 | 904 |
|
901 | 905 | # mtp>1: [batch_size, k] |
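The decode branch above applies two small adjustments before the next drafting step: the per-request sequence-length list is padded with zeros up to the size the full graph was captured for, and the running maximum sequence length is incremented and clamped to the model's maximum length. A standalone sketch of just that arithmetic follows; the function is illustrative, whereas the diff applies these updates in place on the decode metadata.

```python
from typing import List, Tuple


def pad_and_clamp(seq_lens_list: List[int], graph_pad_size: int,
                  max_seq_lens: int, max_model_len: int) -> Tuple[List[int], int]:
    # Pad with zeros so the list length matches the captured graph batch size,
    # then advance the max length by one draft step and clamp it.
    padded = seq_lens_list + [0] * (graph_pad_size - len(seq_lens_list))
    clamped_max = min(max_seq_lens + 1, max_model_len)
    return padded, clamped_max


# Example: three requests padded to a graph captured for eight slots.
print(pad_and_clamp([17, 5, 42], graph_pad_size=8, max_seq_lens=42, max_model_len=4096))
# -> ([17, 5, 42, 0, 0, 0, 0, 0], 43)
```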
|