@@ -910,18 +910,20 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     query, k_nope, value, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
             attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
                                    device=query.device)
 
             graph_params.attn_params[num_tokens].append(
-                (query, k_nope, value, self.num_heads, self.num_kv_heads,
+                (weak_ref_tensors(query), weak_ref_tensors(k_nope), weak_ref_tensors(value),
+                 self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
-                 self.key_cache.shape[1], attn_metadata.decode.
+                 self.key_cache.shape[1], attn_metadata.decode_meta.
                  num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-                 workspace, attn_out, attn_lse, self.pcp_rank, self.dcp_rank,
+                 weak_ref_tensors(workspace), weak_ref_tensors(attn_out),
+                 weak_ref_tensors(attn_lse), self.pcp_rank, self.dcp_rank,
                  self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
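For context, the hunk swaps the strong tensor references stashed in `graph_params` for weak references: every tensor recorded for later graph replay (`query`, `k_nope`, `value`, the FIA `workspace`, `attn_out`, `attn_lse`) is wrapped with `weak_ref_tensors()` at the point where it is captured, so the long-lived per-batch-size registry does not keep those buffers alive between replays. The sketch below illustrates only this capture-site pattern; the `GraphParams` container, the `stash_decode_params` helper, and the `vllm.utils` import path for `weak_ref_tensors` are assumptions for illustration, not the repository's actual definitions.

```python
# Minimal sketch of the capture-side bookkeeping (assumptions: GraphParams and
# stash_decode_params are illustrative stand-ins; weak_ref_tensors is assumed
# to behave like the vLLM utility that returns a tensor sharing the same
# device memory without keeping the original allocation alive).
from dataclasses import dataclass, field
from typing import Any

import torch
from vllm.utils import weak_ref_tensors  # assumed import path


@dataclass
class GraphParams:
    # One entry per captured decode batch size (num_tokens).
    workspaces: dict[int, torch.Tensor] = field(default_factory=dict)
    attn_params: dict[int, list[tuple[Any, ...]]] = field(default_factory=dict)


def stash_decode_params(graph_params: GraphParams, num_tokens: int,
                        query: torch.Tensor, k_nope: torch.Tensor,
                        value: torch.Tensor, workspace: torch.Tensor,
                        attn_out: torch.Tensor,
                        attn_lse: torch.Tensor) -> None:
    """Record only weak references: the captured graph already owns the real
    buffers, so strong references here would only delay their reuse."""
    graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
    graph_params.attn_params.setdefault(num_tokens, []).append(
        (weak_ref_tensors(query), weak_ref_tensors(k_nope),
         weak_ref_tensors(value), weak_ref_tensors(attn_out),
         weak_ref_tensors(attn_lse)))
```

The diff applies the same idea in place, wrapping each stored tensor individually rather than the whole parameter tuple.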