File tree Expand file tree Collapse file tree 1 file changed +6
-7
lines changed Expand file tree Collapse file tree 1 file changed +6
-7
lines changed Original file line number Diff line number Diff line change @@ -976,26 +976,25 @@ def _sync_metadata_across_dp(
976976 # immediately once the other two flags are no longer needed.
977977 if self .dp_size == 1 :
978978 return num_tokens , None , with_prefill
979-
980979 # Sync num_tokens, with_prefill across dp ranks
981980 num_tokens_tensor = torch .tensor ([
982981 num_tokens if i == self .dp_rank else 0 for i in range (self .dp_size )
983982 ],
984- dtype = torch .int32 ,
985- device = "npu " )
983+ dtype = torch .int32 ,
984+ device = "cpu " )
986985
987986 flags_tensor = torch .tensor ([int (with_prefill )],
988987 dtype = torch .int32 ,
989- device = "npu " )
988+ device = "cpu " )
990989
991990 packed_tensor = torch .cat ([num_tokens_tensor , flags_tensor ])
992-
993- dist .all_reduce (packed_tensor , group = get_dp_group ().device_group )
991+ # use cpu_group to avoid cpu synchronization issue.
992 # it can be overlapped with main model execution on npu.
993+ dist .all_reduce (packed_tensor , group = get_dp_group ().cpu_group )
994994
995995 # Unpack the results
996996 num_tokens_across_dp = packed_tensor [:- 1 ]
997997 synced_flags = packed_tensor [- 1 :]
998-
999998 max_tokens_across_dp = torch .max (num_tokens_across_dp ).item ()
1000999 global_with_prefill = bool (synced_flags [0 ])
10011000
You can’t perform that action at this time.
0 commit comments