diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0a5257a1d87d..f4a06a96b360 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -2007,7 +2007,21 @@ def forward(
                 decode_pe_padded.copy_(decode_q_pe)
                 decode_q_pe = decode_pe_padded
 
+            decode_q_shape = (
+                decode_q.shape[0],
+                self.num_heads,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+            decode_q = torch.empty(
+                decode_q_shape,
+                device=decode_q_nope.device,
+                dtype=decode_q_nope.dtype,
+            )
+
             if self.is_aiter_triton_fp8_bmm_enabled:
+                # The AITER path uses a strided write in the bmm to avoid an
+                # additional index_copy or cat
+                decode_q[..., self.kv_lora_rank :].copy_(decode_q_pe)
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                     decode_q_nope,
@@ -2015,6 +2029,7 @@ def forward(
                     self.W_K_scale,
                     group_size=128,
                     transpose_bm=True,
+                    YQ=decode_q[..., : self.kv_lora_rank],
                 )
             else:
                 # Pads the head_dim if necessary (for the underlying kernel)
@@ -2035,36 +2050,23 @@ def forward(
                 # Convert from (N, B, L) to (B, N, L)
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
-            if fp8_attention:
-                ql_nope_shape = decode_ql_nope.shape
-                q_pe_shape = decode_q_pe.shape
-                assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
-                assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
-                decode_q_shape = (
-                    ql_nope_shape[0],
-                    ql_nope_shape[1],
-                    ql_nope_shape[2] + q_pe_shape[2],
-                )
-                # Using empty and copy since torch.cat introduces significant overhead.
-                decode_q0 = torch.empty(
-                    decode_q_shape,
-                    device=decode_ql_nope.device,
-                    dtype=decode_ql_nope.dtype,
-                )
-                decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
-                decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
+            if fp8_attention or self.dcp_world_size > 1:
+                # Ensure a contiguous tensor for the fp8 and DCP cases
+                decode_q[..., : decode_ql_nope.shape[2]].copy_(decode_ql_nope)
+                decode_q[..., decode_ql_nope.shape[2] :].copy_(decode_q_pe)
+            else:
+                # Otherwise, use a tuple to avoid an extra copy
+                decode_q = (decode_ql_nope, decode_q_pe)
 
+            if fp8_attention:
                 decode_q, _ = ops.scaled_fp8_quant(
-                    decode_q0.view(decode_q_shape[0], -1),
+                    decode_q.view(decode_q_shape[0], -1),
                     layer._q_scale,
                 )
                 decode_q = decode_q.view(decode_q_shape)
-            else:
-                decode_q = (decode_ql_nope, decode_q_pe)
+
             if self.dcp_world_size > 1:
                 assert not fp8_attention, "DCP not support fp8 kvcache now."
-                # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
-                decode_q = torch.cat(decode_q, dim=-1)
                 # decode_q do allgather in head dim.
                 decode_q = get_dcp_group().all_gather(decode_q, dim=1)
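
Note on the allocation pattern above: the patch preallocates the concatenated decode query of shape (B, N, kv_lora_rank + qk_rope_head_dim) once, copies the rope half into the tail slice, and lets the AITER bmm write the nope half directly into the head slice through its YQ= output argument, so neither torch.cat nor an extra index_copy is needed on that path. The sketch below is a minimal, self-contained illustration of the idea, with plain torch.bmm standing in for rocm_aiter_ops.triton_fp8_bmm (here the nope half still needs one copy_; the fused kernel's strided write removes even that). All shape values are illustrative, not taken from any model config.

import torch

B, N = 4, 16                        # decode batch size and num_heads (illustrative)
P = 128                             # padded qk_nope head dim fed to the bmm
kv_lora_rank = 512                  # "L" in the (B, N, L) shape comments
qk_rope_head_dim = 64               # rope dim appended after the lora slice

decode_q_nope = torch.randn(N, B, P)        # (N, B, P)
W_UK_T = torch.randn(N, P, kv_lora_rank)    # (N, P, L)
decode_q_pe = torch.randn(B, N, qk_rope_head_dim)

# Single allocation for the concatenated query (B, N, L + rope).
decode_q = torch.empty(B, N, kv_lora_rank + qk_rope_head_dim)

# Rope half: one strided copy into the tail slice.
decode_q[..., kv_lora_rank:].copy_(decode_q_pe)

# Nope half: (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L), written into the
# head slice. The fused AITER kernel performs this write itself via YQ=.
decode_q[..., :kv_lora_rank].copy_(torch.bmm(decode_q_nope, W_UK_T).transpose(0, 1))

# Matches the cat-based construction that the patch removes.
reference = torch.cat(
    [torch.bmm(decode_q_nope, W_UK_T).transpose(0, 1), decode_q_pe], dim=-1
)
torch.testing.assert_close(decode_q, reference)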