Commit 432b861

Fix incorrect MLAPO weight release in PD mixed scenarios. (#4774)

### What this PR does / why we need it?

Fix incorrect MLAPO weight release in PD mixed scenarios.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: ZYang6263 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent b230e7e commit 432b861

File tree

1 file changed: +5, -4 lines changed


vllm_ascend/attention/sfa_v1.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -470,7 +470,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         if self.fused_qkv_a_proj is None or not isinstance(
                 quant_method, AscendW8A8LinearMethod):
             reasons.append(
-                "Currently mlapo only supports W8A8 quantization in MLA scenario."
+                "Currently mlapo only supports W8A8 quantization in SFA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
         if self.enable_sfa_cp:
@@ -597,8 +597,6 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., :self.q_lora_rank].contiguous()

-        self.fused_qkv_a_proj.weight = None
-
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
         kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
@@ -673,9 +671,12 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)

-        if self.vllm_config.kv_transfer_config is not None:
+        if self.vllm_config.kv_transfer_config is not None and \
+                self.vllm_config.kv_transfer_config.is_kv_consumer:
+            self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
+            self.q_proj.weight = None
             self.q_proj.deq_scale = None
             self.q_proj.quant_bias = None
         torch.npu.empty_cache()
```
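The key change is in the last hunk: instead of unconditionally dropping `self.fused_qkv_a_proj.weight` while the fused MLAPO weights are being built, the source tensors (`weight`, `deq_scale`, `quant_bias` of `fused_qkv_a_proj` and `q_proj`) are now released only when a `kv_transfer_config` is present and the worker is a KV consumer, so PD mixed instances keep the projections they still need. The sketch below illustrates that gating pattern only; `KVTransferConfig`, `Layer`, and `maybe_release_mlapo_source_weights` are simplified stand-ins for this example, not the actual vLLM/vllm-ascend classes.

```python
# Minimal sketch of the release gating applied above, using simplified
# stand-in types. KVTransferConfig, Layer and
# maybe_release_mlapo_source_weights are illustrative only.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class KVTransferConfig:
    # Mirrors only the single flag the guard above consults.
    is_kv_consumer: bool = False


@dataclass
class Layer:
    # Stand-in for the attributes that get released on a quantized layer.
    weight: Optional[torch.Tensor] = None
    deq_scale: Optional[torch.Tensor] = None
    quant_bias: Optional[torch.Tensor] = None


def maybe_release_mlapo_source_weights(
        kv_transfer_config: Optional[KVTransferConfig],
        fused_qkv_a_proj: Layer,
        q_proj: Layer) -> None:
    """Drop the original projection tensors only on pure KV consumers;
    other workers (e.g. PD mixed deployments) keep them."""
    if kv_transfer_config is not None and kv_transfer_config.is_kv_consumer:
        for layer in (fused_qkv_a_proj, q_proj):
            layer.weight = None
            layer.deq_scale = None
            layer.quant_bias = None


# A PD mixed worker (not a pure KV consumer) keeps its weights ...
mixed = Layer(weight=torch.zeros(4, 4))
maybe_release_mlapo_source_weights(KVTransferConfig(is_kv_consumer=False),
                                   mixed, Layer())
assert mixed.weight is not None

# ... while a KV-consumer (decode-only) worker releases them to free memory.
consumer = Layer(weight=torch.zeros(4, 4))
maybe_release_mlapo_source_weights(KVTransferConfig(is_kv_consumer=True),
                                   consumer, Layer())
assert consumer.weight is None
```

With this guard, a decode-only (KV consumer) worker can still free the now-redundant quantized projections after fusion, while a mixed prefill/decode worker leaves them in place.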
