
Commit a433f32

[Op] DeepSeekV3.2: support bmm_transpose operator (vllm-project#4631)
### What this PR does / why we need it?

Add bmm_transpose operator support for DeepSeekV3.2.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: ZYang6263 <[email protected]>
Signed-off-by: ZYang6263 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent 0b65ac6 commit a433f32

File tree

1 file changed: +9 −8 lines changed

vllm_ascend/attention/sfa_v1.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -484,14 +484,15 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         self._process_weights_for_fused_mlapo(act_dtype)
 
     def _v_up_proj(self, x):
-        if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
-            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
-            x = torch_npu.npu_transpose_batchmatmul(x,
-                                                    self.W_UV,
-                                                    perm_x1=[1, 0, 2],
-                                                    perm_x2=[0, 1, 2],
-                                                    perm_y=[1, 0, 2])
-            x = x.reshape(-1, self.local_num_heads * self.v_head_dim)
+        if x.dtype in [torch.float16, torch.bfloat16] \
+                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
+            x = x.view(-1, self.num_heads, self.kv_lora_rank)
+            b, _, _ = x.shape
+            res = torch.empty((b, self.num_heads, self.v_head_dim),
+                              dtype=x.dtype,
+                              device=x.device)
+            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
+            x = res.reshape(-1, self.num_heads * self.v_head_dim)
         else:
             # Convert from (B, N, L) to (N, B, L)
             x = x.view(-1, self.local_num_heads,
```
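For context, the replaced `npu_transpose_batchmatmul` call spelled out the operator's semantics through its `perm_x1=[1, 0, 2]` and `perm_y=[1, 0, 2]` arguments: swap the batch and head axes, run a per-head batched matmul against `W_UV`, then swap back. Below is a minimal pure-PyTorch sketch of those semantics; the name `batch_matmul_transpose_ref` and the shape annotations are illustrative assumptions inferred from this diff, not the signature of the fused Ascend kernel (which, per the diff, writes into a preallocated output tensor).

```python
import torch


def batch_matmul_transpose_ref(x: torch.Tensor,
                               w_uv: torch.Tensor) -> torch.Tensor:
    """Hypothetical eager-mode reference for the fused bmm_transpose op.

    x:    (B, N, L)  per-token latent vectors, one row per attention head
    w_uv: (N, L, V)  per-head value up-projection weights
    returns: (B, N, V)
    """
    # (B, N, L) -> (N, B, L): heads become the batch dimension,
    # mirroring perm_x1=[1, 0, 2] in the replaced call.
    x_t = x.transpose(0, 1)
    # Per-head matmul: (N, B, L) @ (N, L, V) -> (N, B, V).
    y = torch.bmm(x_t, w_uv)
    # (N, B, V) -> (B, N, V), mirroring perm_y=[1, 0, 2].
    return y.transpose(0, 1)
```

Under this reading, the new fast path is equivalent to `res = batch_matmul_transpose_ref(x, self.W_UV)` followed by the final reshape; the custom `torch.ops._C_ascend.batch_matmul_transpose` kernel presumably fuses the two transposes and the matmul on the NPU, which is why the fast path is gated on the op being registered (the `hasattr` check) and on fp16/bf16 inputs.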
