@@ -723,6 +723,7 @@ def save_to_aux_losses_tracker(
     num_layers: int,
     reduce_group: torch.distributed.ProcessGroup = None,
     avg_group: torch.distributed.ProcessGroup = None,
+    reduce_group_has_dp: bool = False,
 ):
     """Save the auxiliary loss for logging.
     Args:
@@ -731,7 +732,10 @@ def save_to_aux_losses_tracker(
         layer_number (int): Layer index of the loss.
         num_layers (int): The number of total layers.
         reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
-        mean_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        reduce_group_has_dp (bool): Whether the reduce group already contains the data
+            parallel ranks. Set this to True in that case so that aux loss tracking does
+            not perform a second reduction across data parallel ranks.
     """
     # Skip aux loss logging if layer_number is None.
     if layer_number is None:
@@ -744,6 +748,7 @@ def save_to_aux_losses_tracker(
     tracker[name]["values"][layer_number - 1] += loss.detach()  # Aggregate the loss for the layer.
     tracker[name]["reduce_group"] = reduce_group
     tracker[name]["avg_group"] = avg_group
+    tracker[name]["reduce_group_has_dp"] = reduce_group_has_dp


 def clear_aux_losses_tracker():
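For callers of `save_to_aux_losses_tracker`, the new flag only matters when the supplied `reduce_group` already spans data parallel ranks. Below is a minimal usage sketch, assuming the function lives in Megatron-LM's `moe_utils` module; the variable names, group choices, and loss name are illustrative, not part of this patch.

# Hedged usage sketch -- aux_loss, layer_number, num_layers, and the group choices
# below are assumptions for illustration only.
import torch

from megatron.core import parallel_state
from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker


def log_router_aux_loss(aux_loss: torch.Tensor, layer_number: int, num_layers: int):
    # Typical case: the loss is reduced only over tensor parallel ranks, so the tracker
    # still has to average it across data parallel ranks; keep the default
    # reduce_group_has_dp=False.
    save_to_aux_losses_tracker(
        "load_balancing_loss",
        aux_loss,
        layer_number,
        num_layers,
        reduce_group=parallel_state.get_tensor_model_parallel_group(),
    )


def log_aux_loss_reduced_over_dp(aux_loss, layer_number, num_layers, group_with_dp):
    # Hypothetical case: group_with_dp already contains the data parallel ranks, so the
    # tracker must not average over data parallel a second time.
    save_to_aux_losses_tracker(
        "load_balancing_loss",
        aux_loss,
        layer_number,
        num_layers,
        reduce_group=group_with_dp,
        reduce_group_has_dp=True,
    )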
@@ -768,16 +773,18 @@ def reduce_aux_losses_tracker_across_ranks(track_names: Optional[List[str]] = No
         # Reduce aux losses across ranks.
         if tracker[name].get('reduce_group') is not None:
             torch.distributed.all_reduce(values, group=tracker[name].get('reduce_group'))
+        # The aux loss also needs to be reduced across data parallel ranks. When the
+        # reduce_group does not already include them, do it manually here.
+        if not tracker[name].get('reduce_group_has_dp', False):
+            torch.distributed.all_reduce(
+                values,
+                group=parallel_state.get_data_parallel_group(with_context_parallel=False),
+                op=torch.distributed.ReduceOp.AVG,
+            )
         if tracker[name].get('avg_group') is not None:
             torch.distributed.all_reduce(
                 values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG
             )
-        # This ensures proper loss averaging across all ranks including CP ranks
-        torch.distributed.all_reduce(
-            values,
-            group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-            op=torch.distributed.ReduceOp.AVG,
-        )


 def track_moe_metrics(
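Read as a whole, the reduction for each tracked loss now proceeds in the order sketched below. This is a simplified, standalone rendering of the logic in the hunk above, not the actual library function; it is only meant to show when the manual data parallel averaging kicks in.

# Simplified sketch of the per-loss reduction order after this change (illustrative).
import torch

from megatron.core import parallel_state


def reduce_one_tracked_loss(entry: dict) -> torch.Tensor:
    values = entry["values"]
    # 1) Sum the per-layer losses over the caller-supplied reduce group (e.g. TP/EP ranks).
    if entry.get("reduce_group") is not None:
        torch.distributed.all_reduce(values, group=entry["reduce_group"])
    # 2) Average across data parallel ranks, but only if the reduce group above did not
    #    already include them; otherwise the loss would be averaged twice.
    if not entry.get("reduce_group_has_dp", False):
        torch.distributed.all_reduce(
            values,
            group=parallel_state.get_data_parallel_group(with_context_parallel=False),
            op=torch.distributed.ReduceOp.AVG,
        )
    # 3) Finally, average over any extra caller-supplied averaging group.
    if entry.get("avg_group") is not None:
        torch.distributed.all_reduce(
            values, group=entry["avg_group"], op=torch.distributed.ReduceOp.AVG
        )
    return values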
@@ -805,6 +812,7 @@ def track_moe_metrics(
                     tracker[key]["values"] = torch.zeros(num_layers, device="cuda")
                     tracker[key]["reduce_group"] = None
                     tracker[key]["avg_group"] = None
+                    tracker[key]["reduce_group_has_dp"] = False
     reduce_aux_losses_tracker_across_ranks(track_names)

     # Get number of MoE layers
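After this change, a force-initialized tracker entry carries the new field with a conservative default. A minimal sketch of the resulting layout follows; the values are placeholders and the tensor is shown on CPU for simplicity, whereas the real tracker allocates it on CUDA.

# Illustrative layout of one force-initialized tracker entry; values are placeholders.
import torch

num_layers = 4  # placeholder layer count
tracker_entry = {
    "values": torch.zeros(num_layers),  # per-layer accumulated aux loss
    "reduce_group": None,               # set later by save_to_aux_losses_tracker
    "avg_group": None,                  # optional extra averaging group
    "reduce_group_has_dp": False,       # default: the tracker still averages over data parallel itself
}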