191 changes: 89 additions & 102 deletions vllm_ascend/attention/attention_v1.py
@@ -701,28 +701,6 @@ def _forward_v1_style(
out=output)
return output

def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
max_len = max(lengths)
splits = torch.split(tensor_tnd, lengths, dim=0)

padded = []
for s in splits:
pad_len = max_len - s.shape[0]
s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
padded.append(s_pad)

tensor_bsnd = torch.stack(padded, dim=0)
return tensor_bsnd

def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
slices = []
for i, length in enumerate(lengths):
slices.append(tensor_bsnd[i, :length])
tensor_tnd = torch.cat(slices, dim=0)
return tensor_tnd
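
The two helpers removed here padded a flat token-major (TND) tensor into a per-batch padded (BSND) layout and back; since the fused attention calls below now pass input_layout="TND", that round trip is no longer needed. A minimal plain-PyTorch sketch of what the round trip did (names are illustrative, not part of this change):

import torch
import torch.nn.functional as F

def pack_tnd_to_bsnd(x_tnd: torch.Tensor, lengths: list) -> torch.Tensor:
    # x_tnd: (T, N, D) with T = sum(lengths); pad every sequence to the longest one.
    max_len = max(lengths)
    padded = [F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0]))
              for s in torch.split(x_tnd, lengths, dim=0)]
    return torch.stack(padded, dim=0)  # (B, S_max, N, D)

def unpack_bsnd_to_tnd(x_bsnd: torch.Tensor, lengths: list) -> torch.Tensor:
    # Drop the padding again and re-concatenate along the token axis.
    return torch.cat([x_bsnd[i, :length] for i, length in enumerate(lengths)], dim=0)

x = torch.randn(7, 4, 16)  # 7 tokens total, 4 heads, head_dim 16
lengths = [3, 4]
assert torch.equal(unpack_bsnd_to_tnd(pack_tnd_to_bsnd(x, lengths), lengths), x)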

def _attention_with_nomask_and_mask(self, q: torch.Tensor,
q_seqlens: List[int],
k_nomask: torch.Tensor,
@@ -732,17 +710,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
v_mask: torch.Tensor,
kv_seqlens_mask: List[int],
mask: torch.Tensor) -> torch.Tensor:
q = self._pack_tnd_2_bsnd(q, q_seqlens)

# nomask Attention
if k_nomask is not None:
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
k_nomask,
v_nomask,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
input_layout="TND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
@@ -751,38 +727,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_nomask,
actual_seq_lengths=q_seqlens)
attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
q_seqlens)
# (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
attn_lse_nomask = self._unpack_bsnd_2_tnd(
attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)

# mask Attention
attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
k_mask,
v_mask,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
input_layout="TND",
atten_mask=mask,
scale=self.scale,
sparse_mode=0,
sparse_mode=3,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_mask,
actual_seq_lengths=q_seqlens)
attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
attn_lse_mask = self._unpack_bsnd_2_tnd(
attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)

# update
output = attn_out_mask
if k_nomask is not None:
output, _ = self._update_out_and_lse(
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
T = attn_out_mask.shape[0]
N = attn_out_mask.shape[1]
D = attn_out_mask.shape[2]

attn_out_mask, attn_lse_mask = self._out_lse_reshape(
attn_out_mask, attn_lse_mask)
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
attn_out_nomask, attn_lse_nomask)
attn_out_mask = attn_out_mask.to(torch.float32)
attn_out_nomask = attn_out_nomask.to(torch.float32)
attn_lse_mask = attn_lse_mask.to(torch.float32)
attn_lse_nomask = attn_lse_nomask.to(torch.float32)

attn_output = [attn_out_nomask, attn_out_mask]
attn_lse = [attn_lse_nomask, attn_lse_mask]
update_type = 0
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
update_type)
output = output.view(T, N, D)

return output

@@ -838,20 +822,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
return output

def _update_out_and_lse(self, out_list: torch.Tensor,
lse_list: torch.Tensor) -> torch.Tensor:
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
Args:
out_list: shape = [N, batch_size, num_heads, head_size]
lse_list: shape = [N, batch_size, num_heads, 1]
Returns:
out_final: shape = [batch_size, num_heads, head_size]
lse_final: shape = [batch_size, num_heads, 1]
"""
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
dim=0)
return out_final, lse_final
def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse

def _npu_attention_update(
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
update_type = 0

batch = attn_out_lse_list[0].shape[0]
num_heads = attn_out_lse_list[0].shape[1]
head_dim = attn_out_lse_list[0].shape[2] - 1

attn_out_split_cp = []
attn_lse_split_cp = []

for i in attn_out_lse_list:
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
*torch.split(i, [self.head_size, 1], dim=-1))
attn_out_split_cp.append(attn_out_allgather)
attn_lse_split_cp.append(attn_lse_allgather)

attn_out, attn_lse = torch_npu.npu_attention_update(
attn_lse_split_cp, attn_out_split_cp, update_type)
attn_out = attn_out.view(batch, num_heads, head_dim)

return attn_out
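
The helper above flattens each partial output to (T*N, D) and each log-sum-exp to (T*N,) before torch_npu.npu_attention_update merges them. A minimal plain-PyTorch sketch of the merge this replaces — the formula documented on the removed _update_out_and_lse, LSE_final = log(sum(exp(LSE_i))) and O_final = sum(exp(LSE_i - LSE_final) * O_i) — with illustrative shapes and names:

import torch

def merge_partial_attention(outs, lses):
    # outs[i]: (T, N, D) partial attention output; lses[i]: (T, N, 1) its log-sum-exp.
    out_stack = torch.stack(outs, dim=0)           # (P, T, N, D)
    lse_stack = torch.stack(lses, dim=0)           # (P, T, N, 1)
    lse_final = torch.logsumexp(lse_stack, dim=0)  # (T, N, 1)
    weights = torch.exp(lse_stack - lse_final)     # per-partial correction weights
    return (weights * out_stack).sum(dim=0)        # (T, N, D)

out_mask, out_nomask = torch.randn(5, 8, 64), torch.randn(5, 8, 64)
lse_mask, lse_nomask = torch.randn(5, 8, 1), torch.randn(5, 8, 1)
merged = merge_partial_attention([out_nomask, out_mask], [lse_nomask, lse_mask])
assert merged.shape == (5, 8, 64)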

def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -864,9 +864,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
else:
num_heads = self.num_heads

# 1. Compute out&lse by "npu_fused_infer_attention_score"
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
k_nope = self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1)
value = self.value_cache.view(self.key_cache.shape[0],
@@ -877,7 +874,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
'num_key_value_heads':
self.num_kv_heads,
'input_layout':
"BSND",
"TND",
'atten_mask':
None,
'scale':
@@ -892,9 +889,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata.block_tables,
'block_size':
self.key_cache.shape[1],
"actual_seq_lengths_kv":
attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
'actual_seq_lengths_kv':
attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
'actual_seq_lengths':
attn_metadata.actual_seq_lengths_q[:attn_metadata.
num_decode_tokens]
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
@@ -910,26 +909,23 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, value, **common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
attn_out = torch.empty_like(q_nope)
query, k_nope, value, **common_kwargs)
graph_params.workspaces[num_tokens] = workspace
attn_out = torch.empty_like(query)
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
dtype=torch.float,
device=q_nope.device)
device=query.device)

graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
(query, k_nope, value, self.num_heads, self.num_kv_heads,
self.scale, attn_metadata.block_tables,
self.key_cache.shape[1], attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
self.dcp_rank],
weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
self.pcp_rank, self.dcp_rank, self.dcp_size))
self.key_cache.shape[1], attn_metadata.decode.
num_computed_tokens_of_cp_dcp[:, self.cp_rank, self.dcp_rank],
Comment on lines +922 to +923
Contributor
critical

There are several typos on these lines within the graph capture logic that will cause an AttributeError at runtime:

  1. attn_metadata.decode should be attn_metadata.decode_meta.
  2. num_computed_tokens_of_cp_dcp should be num_computed_tokens_of_pcp_dcp.
  3. self.cp_rank should be self.pcp_rank.

These seem to be typos introduced during refactoring, as the previous version of the code used the correct names.

Suggested change
- self.key_cache.shape[1], attn_metadata.decode.
- num_computed_tokens_of_cp_dcp[:, self.cp_rank, self.dcp_rank],
+ self.key_cache.shape[1], attn_metadata.decode_meta.
+ num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],

workspace, attn_out, attn_lse, self.cp_rank, self.dcp_rank,
self.dcp_size))
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
query,
k_nope,
value,
**common_kwargs,
@@ -939,14 +935,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
graph_params.handles[num_tokens].append(handle)
else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, value, **common_kwargs)
query, k_nope, value, **common_kwargs)

attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
attn_out.shape[3])
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
attn_out_lse_list = []
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -955,35 +949,28 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
attn_out_lse_split_on_seq = list(
if self.pcp_size > 1:
attn_out_lse = attn_out_lse_all2all.contiguous()
attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))

attn_out_lse_split_dcp = torch.stack(
attn_out_lse_split_on_seq,
dim=0) # [dcp, batch_size, num_heads, head_size+1]
# Update out&lse
attn_out_split_dcp, attn_lse_split_dcp = torch.split(
attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
attn_out, attn_lse = self._update_out_and_lse(
attn_out_split_dcp, attn_lse_split_dcp)
if self.pcp_size > 1:
# 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# 3. AllGather out&lse within CP group
# AllGather out&lse within CP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
# 4. Update out&lse
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
attn_out, _ = self._update_out_and_lse(attn_out_allgather,
attn_lse_allgather)
if self.dcp_size > 1 and self.pcp_size > 1:
attn_out_lse_list_pcp_dcp = []
for s in attn_out_lse_list:
attn_out_lse_list_split = list(
torch.chunk(s, self.dcp_size, dim=1))
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
attn_out_lse_list = attn_out_lse_list_pcp_dcp
# Update out&lse
attn_out = self._npu_attention_update(attn_out_lse_list)
return attn_out
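
Both the dcp all-to-all and the pcp all-gather above move the attention output and its log-sum-exp in one collective by concatenating them along the last dimension. A small single-process sketch of that packing, the permute used for the all-to-all, and the split back on the receiving side (illustrative only; no process group is created):

import torch

bs, num_heads, head_dim = 4, 8, 128
attn_out = torch.randn(bs, num_heads, head_dim)  # partial attention output
attn_lse = torch.randn(bs, num_heads, 1)         # its log-sum-exp

# Pack both into one (bs, num_heads, head_dim + 1) tensor so a single collective carries them.
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)

# For the dcp all-to-all the batch axis is moved last so each rank exchanges contiguous chunks.
for_all2all = attn_out_lse.permute(1, 2, 0).contiguous()  # (num_heads, head_dim + 1, bs)
restored = for_all2all.permute(2, 0, 1)                   # back to (bs, num_heads, head_dim + 1)

# Unpack on the receiving side.
out_back, lse_back = torch.split(restored, [head_dim, 1], dim=-1)
assert torch.equal(out_back, attn_out) and torch.equal(lse_back, attn_lse)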

def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
Expand Down
4 changes: 1 addition & 3 deletions vllm_ascend/worker/model_runner_v1.py
@@ -4118,7 +4118,6 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
num_prefills = num_reqs - num_decodes
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = AscendPrefillContextParallelMetadata(
@@ -4226,9 +4225,8 @@ def _list_to_tensor(lst, device, dtype=torch.int32):
device=self.device,
dtype=self.dtype), 1)
else:
max_seq_len = max(seq_lens, default=0)
pcp_prefill_mask = torch.triu(
torch.full((num_prefills, max_seq_len, max_seq_len),
torch.full((2048, 2048),
True,
device=self.device,
dtype=torch.bool), 1)
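
This pairs with the sparse_mode=3 switch in the attention diff above: instead of materialising a (num_prefills, max_seq_len, max_seq_len) causal mask per batch, a fixed 2048 x 2048 upper-triangular mask is passed, which appears to follow the compressed causal-mask convention of the Ascend fused-attention kernels (an assumption; the diff itself does not state it). A minimal sketch of the construction:

import torch

# Fixed 2048 x 2048 boolean mask with True strictly above the diagonal.
# With sparse_mode=3 the kernel is expected to apply this compressed causal mask
# internally, so it no longer has to scale with batch size or sequence length.
pcp_prefill_mask = torch.triu(
    torch.ones(2048, 2048, dtype=torch.bool), diagonal=1)
assert pcp_prefill_mask.shape == (2048, 2048) and not pcp_prefill_mask[0, 0]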