diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 62bca3098d9..e722881aea7 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -701,28 +701,6 @@ def _forward_v1_style(
                 out=output)
         return output
 
-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -732,17 +710,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -751,38 +727,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
 
         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
             antiquant_mode=0,
             antiquant_scale=None,
             softmax_lse_flag=True,
             actual_seq_lengths_kv=kv_seqlens_mask,
             actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
 
         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
+                                                       update_type)
+            output = output.view(T, N, D)
 
         return output
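The hunk above routes the masked/unmasked partial results through `torch_npu.npu_attention_update` instead of the host-side `_update_out_and_lse` helper that is removed further down. As a reference for checking numerics, here is a minimal pure-PyTorch sketch of the same log-sum-exp merge stated in the removed docstring; the function name, shapes, and float32 upcast below are illustrative only and not part of this change.

```python
# Reference sketch only (not the NPU kernel): merge partial attention results
# with the identity from the removed _update_out_and_lse docstring:
# lse = log(sum(exp(lse_i))), out = sum(exp(lse_i - lse) * out_i).
from typing import List, Tuple

import torch


def merge_partial_attention(
        outs: List[torch.Tensor],
        lses: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # outs[i]: [tokens, num_heads, head_dim]; lses[i]: [tokens, num_heads, 1]
    out_stack = torch.stack(outs, dim=0).float()
    lse_stack = torch.stack(lses, dim=0).float()
    lse_final = torch.logsumexp(lse_stack, dim=0)   # [tokens, num_heads, 1]
    weights = torch.exp(lse_stack - lse_final)      # per-source softmax weights
    out_final = (weights * out_stack).sum(dim=0)    # [tokens, num_heads, head_dim]
    return out_final, lse_final
```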
@@ -838,20 +822,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output
 
-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out
 
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
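For readers of the new `_out_lse_reshape` / `_npu_attention_update` helpers, a small shape walk-through of the cat/split/flatten round trip they rely on; the sizes below are made up for illustration and nothing here touches the NPU op itself. It also shows why `head_dim` is recovered as `shape[-1] - 1`.

```python
# Shape sketch only: out and lse travel as one concatenated tensor and are
# split and flattened again right before the fused update.
import torch

bs, num_heads, head_dim = 4, 8, 128   # example sizes, not from the diff
attn_out = torch.randn(bs, num_heads, head_dim)
attn_lse = torch.randn(bs, num_heads, 1)

attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)   # [bs, num_heads, head_dim + 1]
out_back, lse_back = torch.split(attn_out_lse, [head_dim, 1], dim=-1)

# _out_lse_reshape-style flattening before npu_attention_update:
out_flat = out_back.contiguous().view(bs * num_heads, head_dim)
lse_flat = lse_back.contiguous().view(bs * num_heads)
assert torch.equal(out_flat.view(bs, num_heads, head_dim), attn_out)
```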
@@ -864,9 +864,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads
 
-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -877,7 +874,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -892,9 +889,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -910,26 +909,23 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
-                update_graph_params_workspaces(num_tokens,
-                                               weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
+                    query, k_nope, value, **common_kwargs)
+                graph_params.workspaces[num_tokens] = workspace
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
-                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                (query, k_nope, value, self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
-                 self.key_cache.shape[1], attn_metadata.decode_meta.
-                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                self.dcp_rank],
-                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
-                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+                 self.key_cache.shape[1], attn_metadata.decode.
+                 num_computed_tokens_of_cp_dcp[:, self.cp_rank, self.dcp_rank],
+                 workspace, attn_out, attn_lse, self.cp_rank, self.dcp_rank,
+                 self.dcp_size))
 
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -939,14 +935,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -955,35 +949,28 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                             group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
 
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+            if self.dcp_size > 1 and self.pcp_size > 1:
+                attn_out_lse_list_pcp_dcp = []
+                for s in attn_out_lse_list:
+                    attn_out_lse_list_split = list(
+                        torch.chunk(s, self.dcp_size, dim=1))
+                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+                attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out
 
     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
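The decode path above now collects one concatenated out&lse tensor per context-parallel rank and hands the whole list to `_npu_attention_update` in a single call. The sketch below covers only the tensor bookkeeping of that list when both pcp and dcp are enabled; sizes are made up, and the real path's `dist.all_to_all_single` / `dist.all_gather` collectives are not run here.

```python
# Layout sketch (example sizes, no collectives): after the pcp all_gather,
# each gathered out&lse tensor still carries dcp_size partial results stacked
# along the head dimension (dim=1), so it is chunked again before the single
# fused update merges everything.
import torch

pcp_size, dcp_size = 2, 2
bs, num_heads, head_dim = 4, 8, 128

# One concatenated out&lse tensor per pcp rank, as all_gather would produce.
gathered = [torch.randn(bs, num_heads, head_dim + 1) for _ in range(pcp_size)]

partials = []
for t in gathered:
    partials += list(torch.chunk(t, dcp_size, dim=1))

# pcp_size * dcp_size partials, each [bs, num_heads // dcp_size, head_dim + 1]
assert len(partials) == pcp_size * dcp_size
assert partials[0].shape == (bs, num_heads // dcp_size, head_dim + 1)
```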
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f30a9a39b44..a1de7ed07a5 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -4118,7 +4118,6 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
-        num_prefills = num_reqs - num_decodes
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
@@ -4226,9 +4225,8 @@ def _list_to_tensor(lst, device, dtype=torch.int32):
                                device=self.device,
                                dtype=self.dtype), 1)
         else:
-            max_seq_len = max(seq_lens, default=0)
             pcp_prefill_mask = torch.triu(
-                torch.full((num_prefills, max_seq_len, max_seq_len),
+                torch.full((2048, 2048),
                            True,
                            device=self.device,
                            dtype=torch.bool), 1)
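On the model-runner side, the per-request `(num_prefills, max_seq_len, max_seq_len)` prefill mask is replaced by a fixed 2048 x 2048 upper-triangular mask, which goes together with the `sparse_mode=3` setting introduced in attention_v1.py above (as I read the op contract, that mode consumes a compressed causal mask rather than a full per-sequence one). A minimal standalone sketch of the mask and its causal property, with device handling omitted:

```python
# Illustration only: a fixed upper-triangular boolean mask (True = masked)
# instead of one sized per prefill request.
import torch

pcp_prefill_mask = torch.triu(
    torch.full((2048, 2048), True, dtype=torch.bool), diagonal=1)

# Row i keeps positions <= i visible and masks the strict future:
assert not pcp_prefill_mask[5, :6].any()   # past and current tokens visible
assert pcp_prefill_mask[5, 6:].all()       # future tokens masked
```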