Skip to content

Commit f0876b5

Browse files
[Bugfix] Fix Dcp dimension mismatch when enable Mlapo (#4687)
### What this PR does / why we need it? After enabling Mlapo and DCP, since Mlapo has its own mla_preprocess logic and does not perform additional all_gather operations on the DCP group, this will lead to dimension mismatch during the subsequent forward proces ### Does this PR introduce _any_ user-facing change? N/A - vLLM version: v0.12.0 - vLLM main: vllm-project/vllm@ad32e3e Signed-off-by: zengran <[email protected]> Co-authored-by: wangxiyuan <[email protected]>
1 parent afe0050 commit f0876b5

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,13 @@ def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
14951495
self.kv_lora_rank)
14961496
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
14971497

1498+
if self.dcp_size > 1:
1499+
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
1500+
decode_q_no_split = get_dcp_group().all_gather(
1501+
decode_q_no_split, 1)
1502+
decode_q_nope, decode_q_pe = decode_q_no_split.split(
1503+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1504+
14981505
decode_preprocess_res = DecodeMLAPreprocessResult(
14991506
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
15001507
return decode_preprocess_res, None

0 commit comments

Comments
 (0)