vllm/v1/attention/backends/mla/common.py: 48 changes (25 additions & 23 deletions)
@@ -2007,14 +2007,29 @@ def forward(
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded

decode_q_shape = (
decode_q.shape[0],
self.num_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
)
decode_q = torch.empty(
decode_q_shape,
device=decode_q_nope.device,
dtype=decode_q_nope.dtype,
)

if self.is_aiter_triton_fp8_bmm_enabled:
# Aiter path adopts a strided write in the bmm to avoid an
# additional index_copy or cat
decode_q[..., self.kv_lora_rank :].copy_(decode_q_pe)
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
self.W_K,
self.W_K_scale,
group_size=128,
transpose_bm=True,
YQ=decode_q[..., : self.kv_lora_rank],
)
else:
# Pads the head_dim if necessary (for the underlying kernel)
@@ -2035,36 +2050,23 @@ def forward(
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

if fp8_attention:
ql_nope_shape = decode_ql_nope.shape
q_pe_shape = decode_q_pe.shape
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
decode_q_shape = (
ql_nope_shape[0],
ql_nope_shape[1],
ql_nope_shape[2] + q_pe_shape[2],
)
# Using empty and copy since torch.cat introduces significant overhead.
decode_q0 = torch.empty(
decode_q_shape,
device=decode_ql_nope.device,
dtype=decode_ql_nope.dtype,
)
decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
if fp8_attention or self.dcp_world_size > 1:
# Ensure a contiguous tensor for the fp8 or DCP case
decode_q[..., : decode_ql_nope.shape[2]].copy_(decode_ql_nope)
decode_q[..., decode_ql_nope.shape[2] :].copy_(decode_q_pe)
else:
# Otherwise, use a tuple to avoid an extra copy
decode_q = (decode_ql_nope, decode_q_pe)

if fp8_attention:
decode_q, _ = ops.scaled_fp8_quant(
decode_q0.view(decode_q_shape[0], -1),
decode_q.view(decode_q_shape[0], -1),
layer._q_scale,
)
decode_q = decode_q.view(decode_q_shape)
else:
decode_q = (decode_ql_nope, decode_q_pe)

if self.dcp_world_size > 1:
assert not fp8_attention, "DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q = torch.cat(decode_q, dim=-1)
# decode_q does all-gather along the head dim.
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
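
As a side note, here is a minimal standalone sketch of the pattern this diff adopts (not the vLLM code itself; the shapes, sizes, and variable names are illustrative): preallocate the fused query buffer once with torch.empty and copy_ each component into its slice instead of building it with torch.cat. A producer kernel with a strided output (like the aiter triton_fp8_bmm above, via its YQ argument) can then write directly into the nope slice, and the same buffer is reused for fp8 quantization and the DCP all-gather.

import torch

# Illustrative sizes: batch, heads, kv_lora_rank, qk_rope_head_dim
B, N, L, P = 4, 16, 512, 64

ql_nope = torch.randn(B, N, L)
q_pe = torch.randn(B, N, P)

# Baseline: torch.cat allocates the fused tensor and copies both inputs.
q_cat = torch.cat((ql_nope, q_pe), dim=-1)

# Pattern from the diff: allocate once, then write each part into its slice.
q = torch.empty(B, N, L + P, device=ql_nope.device, dtype=ql_nope.dtype)
q[..., :L].copy_(ql_nope)  # a kernel could also write here directly (strided output)
q[..., L:].copy_(q_pe)

assert torch.equal(q, q_cat)

# The flattened view used for quantization and the all-gather along the head
# dim both operate on this same buffer, so no separate concatenation is needed.
q_flat = q.view(B, -1)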
