Commit 5ebb9bd

【Bugfix】bugfix_for_bmm_transpose (#4899)
Due to shape limitations, the bmm_transpose operator in version 3.2 can only be used in the decoding stage, so its dispatch condition now also checks that the forward pass is not a prefill.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: ChrisGelhLan <[email protected]>
1 parent 78bf211 commit 5ebb9bd

File tree

1 file changed: +3 −1 lines changed


vllm_ascend/attention/sfa_v1.py

@@ -490,9 +490,11 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         self._process_weights_for_fused_mlapo(act_dtype)
 
     def _v_up_proj(self, x):
+        forward_context = get_forward_context()
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
-                and not self.enable_sfa_cp:
+                and not self.enable_sfa_cp \
+                and not forward_context.with_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
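For context, below is a minimal, hypothetical sketch (not the vllm-ascend source) of the dispatch this change produces: the fused torch.ops._C_ascend.batch_matmul_transpose kernel is taken only during decode, and anything else (prefill, context-parallel SFA, unsupported dtypes, or a build without the custom op) falls back to a plain torch.bmm up-projection. The function name v_up_proj_sketch, the weight name w_uv, and the (x, w_uv, res) call signature of the custom op are assumptions made for illustration.

import torch


def v_up_proj_sketch(x, w_uv, num_heads, kv_lora_rank, v_head_dim,
                     enable_sfa_cp, with_prefill):
    # Illustrative sketch only -- not the vllm-ascend source. w_uv and the
    # batch_matmul_transpose call signature below are assumptions.
    # x: (tokens, num_heads * kv_lora_rank)
    # w_uv: (num_heads, kv_lora_rank, v_head_dim)
    x = x.view(-1, num_heads, kv_lora_rank)
    use_fused = (
        x.dtype in (torch.float16, torch.bfloat16)
        and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
        and not enable_sfa_cp
        and not with_prefill  # the fix: the fused kernel is decode-only in 3.2
    )
    if use_fused:
        b = x.shape[0]
        res = torch.empty((b, num_heads, v_head_dim),
                          dtype=x.dtype, device=x.device)
        # Hypothetical call shape; the real op's signature may differ.
        torch.ops._C_ascend.batch_matmul_transpose(x, w_uv, res)
        out = res
    else:
        # Generic fallback: per-head bmm,
        # (H, tokens, R) @ (H, R, Dv) -> (H, tokens, Dv), then back to token-major.
        out = torch.bmm(x.transpose(0, 1), w_uv).transpose(0, 1)
    return out.reshape(-1, num_heads * v_head_dim)


# Example: 4 tokens, 8 heads, lora rank 16, head dim 32. On a machine without
# the Ascend extension, the hasattr guard fails and the bmm fallback runs.
x = torch.randn(4, 8 * 16)
w = torch.randn(8, 16, 32)
print(v_up_proj_sketch(x, w, 8, 16, 32,
                       enable_sfa_cp=False, with_prefill=False).shape)
# torch.Size([4, 256])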
