Skip to content

Commit 090bc49

Browse files
committed
BSND to TND and FA_UPDATE replacement
Signed-off-by: pichangping <[email protected]>
1 parent a16d2de commit 090bc49

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -910,18 +910,20 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     query, k_nope, value, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
             attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
                                    device=query.device)
 
             graph_params.attn_params[num_tokens].append(
-                (query, k_nope, value, self.num_heads, self.num_kv_heads,
+                (weak_ref_tensors(query), weak_ref_tensors(k_nope), weak_ref_tensors(value),
+                 self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
-                 self.key_cache.shape[1], attn_metadata.decode.
+                 self.key_cache.shape[1], attn_metadata.decode_meta.
                  num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-                 workspace, attn_out, attn_lse, self.pcp_rank, self.dcp_rank,
+                 weak_ref_tensors(workspace), weak_ref_tensors(attn_out),
+                 weak_ref_tensors(attn_lse), self.pcp_rank, self.dcp_rank,
                  self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(

0 commit comments

Comments
 (0)