@@ -917,14 +917,14 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
917917 device = query .device )
918918
919919 graph_params .attn_params [num_tokens ].append (
920- (weak_ref_tensors (query ), weak_ref_tensors (k_nope ), weak_ref_tensors ( value ),
921- self .num_heads , self .num_kv_heads ,
920+ (weak_ref_tensors (query ), weak_ref_tensors (k_nope ),
921+ weak_ref_tensors ( value ), self .num_heads , self .num_kv_heads ,
922922 self .scale , attn_metadata .block_tables ,
923923 self .key_cache .shape [1 ], attn_metadata .decode_meta .
924- num_computed_tokens_of_pcp_dcp [:, self .pcp_rank , self . dcp_rank ],
925- weak_ref_tensors ( workspace ), weak_ref_tensors ( attn_out ),
926- weak_ref_tensors (attn_lse ), self . pcp_rank , self . dcp_rank ,
927- self .dcp_size ))
924+ num_computed_tokens_of_pcp_dcp [:, self .pcp_rank ,
925+ self . dcp_rank ],
926+ weak_ref_tensors (attn_out ), weak_ref_tensors ( attn_lse ) ,
927+ self .pcp_rank , self . dcp_rank , self . dcp_size ))
928928 torch .npu .graph_task_group_begin (stream )
929929 torch_npu .npu_fused_infer_attention_score .out (
930930 query ,
0 commit comments