Commit a16d2de

BSND to TND and FA_UPDATE replacement
Signed-off-by: pichangping <[email protected]>
1 parent: 4312a92

2 files changed: +90 −106 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 89 additions & 103 deletions
@@ -701,28 +701,6 @@ def _forward_v1_style(
             out=output)
         return output
 
-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -732,17 +710,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -751,38 +727,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
 
         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
             antiquant_mode=0,
             antiquant_scale=None,
             softmax_lse_flag=True,
             actual_seq_lengths_kv=kv_seqlens_mask,
             actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
 
         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
+                                                       update_type)
+            output = output.view(T, N, D)
 
         return output
 
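
For reference, the two deleted helpers implemented a pad-to-max round trip between the packed TND layout (tokens of all sequences concatenated along the first axis) and the padded BSND layout (batch, max seq, heads, head dim). Below is a standalone sketch of that round trip in plain PyTorch; the names pack_tnd_to_bsnd / unpack_bsnd_to_tnd are illustrative stand-ins, not part of the codebase. With input_layout="TND" the kernel consumes the packed tensors plus actual_seq_lengths / actual_seq_lengths_kv directly, so neither the padding nor the un-padding of the returned out/LSE tensors is needed any more.

import torch
import torch.nn.functional as F
from typing import List

def pack_tnd_to_bsnd(x_tnd: torch.Tensor, lengths: List[int]) -> torch.Tensor:
    # TND: all tokens concatenated on dim 0. Pad each sequence to the batch
    # maximum and stack into BSND (batch, max_seq, heads, head_dim).
    max_len = max(lengths)
    padded = [F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0]))
              for s in torch.split(x_tnd, lengths, dim=0)]
    return torch.stack(padded, dim=0)

def unpack_bsnd_to_tnd(x_bsnd: torch.Tensor, lengths: List[int]) -> torch.Tensor:
    # Drop the padding again and re-concatenate only the valid tokens.
    return torch.cat([x_bsnd[i, :l] for i, l in enumerate(lengths)], dim=0)

lengths = [3, 5, 2]
x = torch.randn(sum(lengths), 8, 128)      # (T, N, D)
bsnd = pack_tnd_to_bsnd(x, lengths)        # (3, 5, 8, 128), zero-padded
assert torch.equal(unpack_bsnd_to_tnd(bsnd, lengths), x)
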
@@ -838,20 +822,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output
 
-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out
 
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
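
The docstring of the deleted _update_out_and_lse spells out the merge rule for partial attention results: LSE_final = log(sum_i exp(LSE_i)) and O_final = sum_i exp(LSE_i - LSE_final) * O_i. The new code flattens each out/LSE pair with _out_lse_reshape and hands the lists to torch_npu.npu_attention_update rather than doing this in Python. A small host-side sketch of the same math, with a consistency check against single-pass attention, follows (plain PyTorch; merge_out_and_lse and chunk_attn are illustrative helpers, not the NPU kernel):

import torch

def merge_out_and_lse(out_list: torch.Tensor, lse_list: torch.Tensor):
    """Combine per-chunk attention outputs using their log-sum-exp stats.

    out_list: [N, batch, num_heads, head_size] partial outputs
    lse_list: [N, batch, num_heads, 1] partial log-sum-exp values
    """
    lse_final = torch.logsumexp(lse_list, dim=0)
    out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0)
    return out_final, lse_final

def chunk_attn(q, k, v):
    # Attention over one KV chunk, returning output and log-sum-exp.
    scores = q @ k.transpose(-1, -2)                      # [b, h, 1, kv]
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)   # [b, h, 1, 1]
    out = torch.softmax(scores, dim=-1) @ v               # [b, h, 1, d]
    return out.squeeze(2), lse.squeeze(2)                 # [b, h, d], [b, h, 1]

# Attention over K/V split into two chunks, merged afterwards, matches
# attention over the full K/V.
torch.manual_seed(0)
q = torch.randn(1, 1, 1, 4)
k = torch.randn(1, 1, 6, 4)
v = torch.randn(1, 1, 6, 4)
outs, lses = zip(chunk_attn(q, k[:, :, :3], v[:, :, :3]),
                 chunk_attn(q, k[:, :, 3:], v[:, :, 3:]))
merged, _ = merge_out_and_lse(torch.stack(outs), torch.stack(lses))
full, _ = chunk_attn(q, k, v)
assert torch.allclose(merged, full, atol=1e-6)
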
@@ -864,9 +864,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads
 
-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -877,7 +874,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -892,9 +889,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -910,26 +909,23 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
-                update_graph_params_workspaces(num_tokens,
-                                               weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
+                    query, k_nope, value, **common_kwargs)
+                graph_params.workspaces[num_tokens] = workspace
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)
 
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
-                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                (query, k_nope, value, self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
-                 self.key_cache.shape[1], attn_metadata.decode_meta.
-                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                self.dcp_rank],
-                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
-                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+                 self.key_cache.shape[1], attn_metadata.decode.
+                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+                 workspace, attn_out, attn_lse, self.pcp_rank, self.dcp_rank,
+                 self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -939,14 +935,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -955,37 +949,29 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
 
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+        if self.dcp_size > 1 and self.pcp_size > 1:
+            attn_out_lse_list_pcp_dcp = []
+            for s in attn_out_lse_list:
+                attn_out_lse_list_split = list(
+                    torch.chunk(s, self.dcp_size, dim=1))
+                attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+            attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out
-
     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor, attn_metadata: AscendMetadata,
                          output: torch.Tensor) -> torch.Tensor:
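
In the decode path each rank packs its partial output and LSE into one tensor of shape [bs, num_heads, head_dim + 1], so a single all-to-all (DCP) or all-gather (PCP) moves both. When both groups are active, every gathered tensor still interleaves the head shards of the DCP peers and is chunked along the head axis before the fused update. A single-process, shape-only sketch of that plumbing (random tensors stand in for the collectives; sizes are made up):

import torch

pcp_size, dcp_size = 2, 2
bs, num_heads, head_dim = 4, 8, 128

# One packed [bs, num_heads, head_dim + 1] tensor per PCP peer, as if
# dist.all_gather had already filled the list.
attn_out_lse_list = [torch.randn(bs, num_heads, head_dim + 1)
                     for _ in range(pcp_size)]

# Split each gathered tensor on the head axis into its DCP shards.
chunks = []
for t in attn_out_lse_list:
    chunks += list(torch.chunk(t, dcp_size, dim=1))

# chunks now holds pcp_size * dcp_size partial results; each is split back
# into output and LSE before the merge.
out, lse = torch.split(chunks[0], [head_dim, 1], dim=-1)
print(len(chunks), out.shape, lse.shape)   # 4, [4, 4, 128], [4, 4, 1]
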

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 3 deletions
@@ -4118,7 +4118,6 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
-        num_prefills = num_reqs - num_decodes
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
@@ -4226,9 +4225,8 @@ def _list_to_tensor(lst, device, dtype=torch.int32):
                     device=self.device,
                     dtype=self.dtype), 1)
             else:
-                max_seq_len = max(seq_lens, default=0)
                 pcp_prefill_mask = torch.triu(
-                    torch.full((num_prefills, max_seq_len, max_seq_len),
+                    torch.full((2048, 2048),
                                True,
                                device=self.device,
                                dtype=torch.bool), 1)
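
With num_prefills and max_seq_len gone, the prefill mask is no longer materialized per request as (num_prefills, max_seq_len, max_seq_len); a single fixed 2048 x 2048 upper-triangular bool mask is built once. This appears to pair with the switch to sparse_mode=3 in attention_v1.py, where the fused kernel works from a compressed causal mask of that fixed size instead of a full per-sequence mask; treat that pairing as an inference from the diff, not something stated here. A minimal sketch of the mask being constructed:

import torch

# Fixed-size compressed causal mask: True above the diagonal (masked out),
# False on and below it. Built once instead of per prefill request.
pcp_prefill_mask = torch.triu(
    torch.full((2048, 2048), True, dtype=torch.bool), 1)

print(pcp_prefill_mask.shape, pcp_prefill_mask[:3, :3])
# torch.Size([2048, 2048])
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
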
