Skip to content

Commit 211d4b9

Browse files
authored
[BugFix] Fix mlapo accuracy problem related with weight processing. (#3857)
This PR fixes a mlapo accuracy problem related with weight processing. Furthermore, modify mlapo related e2e test with quantized deepseek model to make it effective. Signed-off-by: whx-sjtu <[email protected]>
1 parent d9249c9 commit 211d4b9

File tree

2 files changed

+2
-18
lines changed

2 files changed

+2
-18
lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,19 +111,3 @@ def test_mtp2_correctness_full_graph(
111111
model_name: str,
112112
):
113113
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
114-
115-
116-
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
117-
def test_mtp_correctness_piecewise_graph_with_mlapo_kernel(
118-
sampling_config: SamplingParams,
119-
model_name: str,
120-
):
121-
mtp_correctness(sampling_config, model_name, 1)
122-
123-
124-
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
125-
def test_mtp_correctness_full_graph_with_mlapo_kernel(
126-
sampling_config: SamplingParams,
127-
model_name: str,
128-
):
129-
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,9 +676,9 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
676676
..., self.q_lora_rank:].contiguous()
677677
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
678678
..., :self.q_lora_rank].contiguous()
679-
kv_a_proj_wt = kv_a_proj_wt.contiguous()
679+
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
680680
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
681-
kv_a_proj_wt = kv_a_proj_wt.contiguous()
681+
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
682682
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
683683
wd_qkv = wd_qkv.t().contiguous()
684684
wd_qkv = transdata(wd_qkv,

0 commit comments

Comments
 (0)