|
16 | 16 | from vllm.utils import cdiv, round_down |
17 | 17 | from vllm.v1.attention.backends.utils import AttentionCGSupport |
18 | 18 |
|
| 19 | +from vllm_ascend import envs |
19 | 20 | from vllm_ascend.ascend_config import get_ascend_config |
20 | 21 | from vllm_ascend.attention.attention_v1 import AscendAttentionState |
21 | 22 | from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, |
22 | 23 | maybe_save_kv_layer_to_connector, |
23 | 24 | split_decodes_and_prefills, |
| 25 | + trans_rope_weight, transdata, |
24 | 26 | wait_for_kv_layer_from_connector) |
25 | 27 | from vllm_ascend.compilation.acl_graph import get_graph_params |
26 | 28 | from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig |
@@ -639,6 +641,87 @@ def get_and_maybe_dequant_weights(layer: LinearBase): |
639 | 641 | # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) |
640 | 642 | # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) |
641 | 643 |
|
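| | + # With the fused MLA-preprocess (MLAPO) kernel enabled, pre-pack the fused
| | + # weights, scales and norm parameters once at weight-processing time.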
| 644 | + if envs.VLLM_ASCEND_ENABLE_MLAPO: |
| 645 | + self._process_weights_for_fused_mlapo(act_dtype) |
| 646 | + |
| 647 | + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): |
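| | + # Fuse the quantized down-projection weights of kv_a_proj_with_mqa and
| | + # q_a_proj into one matmul weight: trans_rope_weight reorders the rope
| | + # columns, transdata packs the matrix into (16, 32) blocks, and the cast
| | + # to NPU format 29 (FRACTAL_NZ) gives the layout the fused kernel expects.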
| 648 | + kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data |
| 649 | + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() |
| 650 | + kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) |
| 651 | + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() |
| 652 | + wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1) |
| 653 | + wd_qkv = wd_qkv.t().contiguous() |
| 654 | + wd_qkv = transdata(wd_qkv, |
| 655 | + block_size=(16, 32)).unsqueeze(0).contiguous() |
| 656 | + self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) |
| 657 | + |
| 658 | + kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale |
| 659 | + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( |
| 660 | + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() |
| 661 | + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, |
| 662 | + self.qk_rope_head_dim) |
| 663 | + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( |
| 664 | + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() |
| 665 | + self.deq_scale_qkv = torch.cat( |
| 666 | + (kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous() |
| 667 | + |
| 668 | + kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias |
| 669 | + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( |
| 670 | + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() |
| 671 | + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, |
| 672 | + self.qk_rope_head_dim) |
| 673 | + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( |
| 674 | + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() |
| 675 | + self.quant_bias_qkv = torch.cat( |
| 676 | + (kv_a_proj_qt_bias, self.q_a_proj.quant_bias), |
| 677 | + dim=-1).contiguous() |
| 678 | + |
| 679 | + wu_q = self.q_proj.weight.data |
| 680 | + wu_q = wu_q.t().reshape(self.num_heads, |
| 681 | + self.qk_nope_head_dim + self.qk_rope_head_dim, |
| 682 | + -1) |
| 683 | + wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) |
| 684 | + wu_q = wu_q.reshape( |
| 685 | + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), |
| 686 | + -1) |
| 687 | + wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() |
| 688 | + self.wu_q = torch_npu.npu_format_cast(wu_q, 29) |
| 689 | + |
| 690 | + qb_deq_scl = self.q_proj.deq_scale.data |
| 691 | + qb_deq_scl = qb_deq_scl.reshape( |
| 692 | + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) |
| 693 | + qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) |
| 694 | + self.qb_deq_scl = qb_deq_scl.reshape( |
| 695 | + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) |
| 696 | + |
| 697 | + qb_qt_bias = self.q_proj.quant_bias.data |
| 698 | + qb_qt_bias = qb_qt_bias.reshape( |
| 699 | + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) |
| 700 | + qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) |
| 701 | + self.qb_qt_bias = qb_qt_bias.reshape( |
| 702 | + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) |
| 703 | + |
| 704 | + device = self.q_a_proj.weight.device |
| 705 | + self.gamma0 = torch.ones( |
| 706 | + [self.q_a_proj.weight.shape[-1]], |
| 707 | + dtype=act_dtype, |
| 708 | + device=device, |
| 709 | + ) |
| 710 | + self.beta0 = torch.zeros( |
| 711 | + [self.q_a_proj.weight.shape[-1]], |
| 712 | + dtype=act_dtype, |
| 713 | + device=device, |
| 714 | + ) |
| 715 | + self.gamma1 = self.q_a_layernorm.weight.data |
| 716 | + self.beta1 = self.q_a_layernorm.bias.data |
| 717 | + self.gamma2 = self.kv_a_layernorm.weight.data |
| 718 | + self.quant_scale0 = self.q_a_proj.input_scale.data |
| 719 | + self.quant_offset0 = self.q_a_proj.input_offset.data |
| 720 | + self.quant_scale1 = self.q_proj.input_scale.data |
| 721 | + self.quant_offset1 = self.q_proj.input_offset.data |
| 722 | + self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) |
| 723 | + self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device) |
| 724 | + |
642 | 725 | def _compute_prefill_context( |
643 | 726 | self, |
644 | 727 | q_nope: torch.Tensor, |
@@ -961,6 +1044,68 @@ def _forward_decode( |
961 | 1044 | current_ms_metadata.before_comm_event.wait() |
962 | 1045 | return self._v_up_proj(attn_output) |
963 | 1046 |
|
| 1047 | + def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata): |
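| | + # Decode-only fast path: one torch.ops._C_ascend.mla_preprocess call fuses
| | + # the norms, quantized down/up projections, rope, the W_UK_T absorption of
| | + # q_nope and the KV-cache write that _mla_preprocess runs as separate ops.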
| 1048 | + bsz = attn_metadata.num_decode_tokens |
| 1049 | + hidden_states = hidden_states[:bsz] |
| 1050 | + |
| 1051 | + cos_shape = attn_metadata.decode.cos.shape |
| 1052 | + cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) |
| 1053 | + sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) |
| 1054 | + |
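| | + # kv_cache[0] / kv_cache[1] are the compressed-kv (ctkv) and rope-k caches;
| | + # the kernel also writes the current tokens into them via slot_mapping.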
| 1055 | + decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] |
| 1056 | + decode_q_nope = torch.empty( |
| 1057 | + (hidden_states.shape[0], self.W_UK_T.shape[0], |
| 1058 | + decode_k_nope.shape[-1]), |
| 1059 | + dtype=hidden_states.dtype, |
| 1060 | + device=hidden_states.device, |
| 1061 | + ) |
| 1062 | + decode_q_pe = torch.empty( |
| 1063 | + (hidden_states.shape[0], self.W_UK_T.shape[0], |
| 1064 | + decode_k_pe.shape[-1]), |
| 1065 | + dtype=hidden_states.dtype, |
| 1066 | + device=hidden_states.device, |
| 1067 | + ) |
| 1068 | + |
| 1069 | + torch.ops._C_ascend.mla_preprocess( |
| 1070 | + hidden_states, |
| 1071 | + self.gamma0, |
| 1072 | + self.beta0, |
| 1073 | + self.wd_qkv, |
| 1074 | + self.deq_scale_qkv, |
| 1075 | + self.gamma1, |
| 1076 | + self.beta1, |
| 1077 | + self.wu_q, |
| 1078 | + self.qb_deq_scl, |
| 1079 | + self.gamma2, |
| 1080 | + cos, |
| 1081 | + sin, |
| 1082 | + self.W_UK_T, |
| 1083 | + decode_k_nope, |
| 1084 | + decode_k_pe, |
| 1085 | + attn_metadata.slot_mapping[:bsz].flatten(), |
| 1086 | + quant_scale0=self.quant_scale0, |
| 1087 | + quant_offset0=self.quant_offset0, |
| 1088 | + bias0=self.quant_bias_qkv, |
| 1089 | + quant_scale1=self.quant_scale1, |
| 1090 | + quant_offset1=self.quant_offset1, |
| 1091 | + bias1=self.qb_qt_bias, |
| 1092 | + ctkv_scale=self.ctkv_scale, |
| 1093 | + q_nope_scale=self.q_nope_scale, |
| 1094 | + cache_mode="krope_ctkv", |
| 1095 | + quant_mode="per_tensor_quant_asymm", |
| 1096 | + q_out0=decode_q_nope, |
| 1097 | + kv_cache_out0=decode_k_nope, |
| 1098 | + q_out1=decode_q_pe, |
| 1099 | + kv_cache_out1=decode_k_pe, |
| 1100 | + ) |
| 1101 | + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, |
| 1102 | + self.kv_lora_rank) |
| 1103 | + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) |
| 1104 | + |
| 1105 | + decode_preprocess_res = DecodeMLAPreprocessResult( |
| 1106 | + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) |
| 1107 | + return decode_preprocess_res, None |
| 1108 | + |
964 | 1109 | def _mla_preprocess(self, layer_name, hidden_states, kv_cache, |
965 | 1110 | attn_metadata, need_gather_q_kv): |
966 | 1111 | # MLA Preprocess: |
@@ -1065,9 +1210,15 @@ def forward( |
1065 | 1210 | device=hidden_states.device) |
1066 | 1211 |
|
1067 | 1212 | # MLA Preprocess |
1068 | | - decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( |
1069 | | - layer_name, hidden_states, kv_cache, attn_metadata, |
1070 | | - need_gather_q_kv) |
| 1213 | + forward_context = get_forward_context() |
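| | + # The fused MLAPO path only handles decode, so take it when the step has
| | + # no prefill (attn_metadata is None on dummy/capture runs); otherwise fall
| | + # back to the generic _mla_preprocess.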
| 1214 | + if (envs.VLLM_ASCEND_ENABLE_MLAPO and |
| 1215 | + (attn_metadata is None or not forward_context.with_prefill)): |
| 1216 | + decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess( |
| 1217 | + hidden_states, kv_cache, attn_metadata) |
| 1218 | + else: |
| 1219 | + decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( |
| 1220 | + layer_name, hidden_states, kv_cache, attn_metadata, |
| 1221 | + need_gather_q_kv) |
1071 | 1222 |
|
1072 | 1223 | if decode_preprocess_res is not None: |
1073 | 1224 | # MLA Preprocess for decoding |