
Commit d990316

pichangping and hwhaokun authored and committed
[long_seq_optim] BSND to TND and FA_UPDATE replacement (vllm-project#3778)
### What this PR does / why we need it?
We optimized long-sequence performance in two ways. First, we changed the input data format for attention computation: instead of the original BSND format, we removed the TND <-> BSND conversion logic and use the TND format directly. The TND input can be reused as-is, which shortens the data path; converting to BSND was an unnecessary processing step. Second, we replaced the output update built from concatenated small operators with the `npu_attention_update` fusion operator to improve performance.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@c9461e0

---------

Signed-off-by: pichangping <[email protected]>
Signed-off-by: hwhaokun <[email protected]>
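For reference, the output update that `torch_npu.npu_attention_update` now fuses is the log-sum-exp merge that the removed `_update_out_and_lse` helper assembled from small operators: `LSE_final = log(sum(exp(LSE_i)))` and `O_final = sum(exp(LSE_i - LSE_final) * O_i)`. Below is a minimal plain-PyTorch sketch of that merge; the function name, shapes, and example values are illustrative only, not part of this repository.

```python
import torch


def merge_attention_partials(out_list: torch.Tensor, lse_list: torch.Tensor):
    """Merge partial attention outputs via their log-sum-exp (LSE) statistics.

    out_list: [num_partials, T, N, D]  -- packed (TND) outputs of each partial attention
    lse_list: [num_partials, T, N, 1]  -- matching log-sum-exp values
    Returns (out_final [T, N, D], lse_final [T, N, 1]).
    """
    # LSE_final = log(sum_i exp(LSE_i))
    lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
    # O_final = sum_i exp(LSE_i - LSE_final) * O_i
    out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0)
    return out_final, lse_final


# Example: combine the "masked" and "no-mask" partial results for 8 packed tokens.
out = torch.randn(2, 8, 4, 16)  # [partials, T, N, D]
lse = torch.randn(2, 8, 4, 1)   # [partials, T, N, 1]
merged_out, merged_lse = merge_attention_partials(out, lse)
print(merged_out.shape, merged_lse.shape)  # torch.Size([8, 4, 16]) torch.Size([8, 4, 1])
```

In the new code, the fused call takes the list of LSE tensors, the list of output tensors, and an `update_type` flag, replacing the chain of `logsumexp`/`exp`/`sum` operators with a single kernel.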
1 parent 055bd9f commit d990316

File tree

2 files changed: +108 -112 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 105 additions & 107 deletions
@@ -23,7 +23,6 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -318,6 +317,18 @@ def build(
         pcp_metadata = None
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         if common_long_seq_metadata is not None:
+            attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
+            head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
+            tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
+            pcp_size = get_prefill_context_model_parallel_world_size(
+            ) if prefill_context_parallel_enable() else 1
+            if pcp_size > 1:
+                attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
+                                                 dim=0).tolist()
+                head_attn_nomask_seqlens = torch.cumsum(
+                    head_attn_nomask_seqlens[1], dim=0).tolist()
+                tail_attn_nomask_seqlens = torch.cumsum(
+                    tail_attn_nomask_seqlens[1], dim=0).tolist()
             pcp_metadata = AscendPCPMetadata(
                 q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
                 q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -329,12 +340,9 @@ def build(
                 kv_with_q_tail_nomask_idx_tensor,
                 kv_with_q_tail_mask_idx=common_long_seq_metadata.
                 kv_with_q_tail_mask_idx_tensor,
-                attn_mask_seqlens=common_long_seq_metadata.
-                attn_mask_seqlens,
-                head_attn_nomask_seqlens=common_long_seq_metadata.
-                head_attn_nomask_seqlens,
-                tail_attn_nomask_seqlens=common_long_seq_metadata.
-                tail_attn_nomask_seqlens,
+                attn_mask_seqlens=attn_mask_seqlens,
+                head_attn_nomask_seqlens=head_attn_nomask_seqlens,
+                tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
                 pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
             prefill_metadata = AscendMetadataForPrefill(
@@ -726,28 +734,6 @@ def _forward_v1_style(
             out=output)
         return output
 
-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -757,17 +743,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -776,38 +760,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
 
         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
             antiquant_mode=0,
             antiquant_scale=None,
             softmax_lse_flag=True,
             actual_seq_lengths_kv=kv_seqlens_mask,
             actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
 
         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
+                                                       update_type)
+            output = output.view(T, N, D)
 
         return output
 
@@ -832,29 +824,29 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
         # 1. Attention calculation in the first half of Q in load balancing
         output_head = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_head_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
             v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
-            kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=head_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)
 
         # 2. the Attention calculation in the latter half of Q in load balancing
         # pcp_rank0: Q3*KV0~KV2 + Q3*KV3
         # pcp_rank1: Q2*KV0~KV1 + Q2*KV2
         output_tail = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_tail_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
             v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
-            kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=tail_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)
 
         # 3. Combine the output of the first half and second half.
@@ -863,20 +855,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output
 
-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out
 
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -889,9 +897,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads
 
-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -902,7 +907,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -917,9 +922,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -935,16 +942,16 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                 workspace = graph_params.workspaces.get(num_tokens)
                 if workspace is None:
                     workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        q_nope, k_nope, value, **common_kwargs)
+                        query, k_nope, value, **common_kwargs)
                     update_graph_params_workspaces(num_tokens,
                                                    weak_ref_tensors(workspace))
-                attn_out = torch.empty_like(q_nope)
+                attn_out = torch.empty_like(query)
                 attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                        dtype=torch.float,
-                                       device=q_nope.device)
+                                       device=query.device)
 
                 graph_params.attn_params[num_tokens].append(
-                    (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                    (weak_ref_tensors(query), weak_ref_tensors(k_nope),
                      weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                      self.scale, attn_metadata.block_tables,
                      self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -954,7 +961,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                      self.pcp_rank, self.dcp_rank, self.dcp_size))
                 torch.npu.graph_task_group_begin(stream)
                 torch_npu.npu_fused_infer_attention_score.out(
-                    q_nope,
+                    query,
                     k_nope,
                     value,
                     **common_kwargs,
@@ -964,14 +971,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -980,35 +985,28 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                             group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
 
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
            attn_out_lse_list = [
                torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
            ]
            dist.all_gather(attn_out_lse_list,
                            attn_out_lse,
                            group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+        if self.dcp_size > 1 and self.pcp_size > 1:
+            attn_out_lse_list_pcp_dcp = []
+            for s in attn_out_lse_list:
+                attn_out_lse_list_split = list(
+                    torch.chunk(s, self.dcp_size, dim=1))
+                attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+            attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out
 
     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
