
Commit 848bff1

Remove redundant reduce in aux_loss logging (#2095)

Signed-off-by: Li Tao <[email protected]>
Signed-off-by: lit <[email protected]>
Co-authored-by: Zijie Yan <[email protected]>

Authored by BestJuly and yanring · 1 parent 7f4df2c · commit 848bff1

File tree: 2 files changed (+30, -8 lines)


megatron/core/transformer/moe/moe_utils.py (15 additions, 7 deletions)

@@ -723,6 +723,7 @@ def save_to_aux_losses_tracker(
     num_layers: int,
     reduce_group: torch.distributed.ProcessGroup = None,
     avg_group: torch.distributed.ProcessGroup = None,
+    reduce_group_has_dp: bool = False,
 ):
     """Save the auxiliary loss for logging.
     Args:
@@ -731,7 +732,10 @@ def save_to_aux_losses_tracker(
         layer_number (int): Layer index of the loss.
         num_layers (int): The number of total layers.
         reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
-        mean_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        reduce_group_has_dp (bool): Whether the reduce group has data parallel ranks.
+            Set this to True if the reduce group has data parallel ranks. This flag is used to
+            ensure the correct reduction in aux loss tracking.
     """
     # Skip aux loss logging if layer_number is None.
     if layer_number is None:
@@ -744,6 +748,7 @@ def save_to_aux_losses_tracker(
     tracker[name]["values"][layer_number - 1] += loss.detach()  # Aggregate the loss for the layer.
     tracker[name]["reduce_group"] = reduce_group
     tracker[name]["avg_group"] = avg_group
+    tracker[name]["reduce_group_has_dp"] = reduce_group_has_dp


 def clear_aux_losses_tracker():
@@ -768,16 +773,18 @@ def reduce_aux_losses_tracker_across_ranks(track_names: Optional[List[str]] = None):
         # Reduce aux losses across ranks.
         if tracker[name].get('reduce_group') is not None:
             torch.distributed.all_reduce(values, group=tracker[name].get('reduce_group'))
+            # Need to conduct reduction across data parallel ranks. When the reduce_group
+            # does not have 'dp' attribute, do it manually.
+            if not tracker[name].get('reduce_group_has_dp', False):
+                torch.distributed.all_reduce(
+                    values,
+                    group=parallel_state.get_data_parallel_group(with_context_parallel=False),
+                    op=torch.distributed.ReduceOp.AVG,
+                )
         if tracker[name].get('avg_group') is not None:
             torch.distributed.all_reduce(
                 values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG
             )
-        # This ensures proper loss averaging across all ranks including CP ranks
-        torch.distributed.all_reduce(
-            values,
-            group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-            op=torch.distributed.ReduceOp.AVG,
-        )


 def track_moe_metrics(
@@ -805,6 +812,7 @@ def track_moe_metrics(
                 tracker[key]["values"] = torch.zeros(num_layers, device="cuda")
                 tracker[key]["reduce_group"] = None
                 tracker[key]["avg_group"] = None
+                tracker[key]["reduce_group_has_dp"] = False
     reduce_aux_losses_tracker_across_ranks(track_names)

     # Get number of MoE layers
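
Why the removed all-reduce was redundant: when the reduce group already spans the data-parallel ranks (the tp-dp-cp case now marked with reduce_group_has_dp=True), the SUM over that group leaves every rank holding the same value, so a follow-up AVG across the DP group only averages identical copies. The sketch below illustrates this with plain tensor arithmetic on a single process; the 2-way TP x 2-way DP layout and the loss values are invented, and no torch.distributed process groups are used.

import torch

# Hypothetical 2-way TP x 2-way DP layout; keys are (tp_rank, dp_rank) and the
# loss values are made up. Process groups are simulated with plain tensor math.
losses = {
    (0, 0): torch.tensor(1.0), (1, 0): torch.tensor(2.0),
    (0, 1): torch.tensor(3.0), (1, 1): torch.tensor(4.0),
}

# Case 1: the reduce group spans TP and DP (reduce_group_has_dp=True).
# A SUM over all four ranks already leaves every rank with the same value ...
summed_everywhere = sum(losses.values())                        # tensor(10.)
# ... so the old unconditional AVG across the DP group averaged identical
# copies and changed nothing: that is the redundant all-reduce being removed.
dp_avg_of_copies = (summed_everywhere + summed_everywhere) / 2  # still tensor(10.)
assert torch.equal(summed_everywhere, dp_avg_of_copies)

# Case 2: the reduce group does not contain DP (reduce_group_has_dp=False).
# Each DP replica only sums over its own TP ranks ...
sum_dp0 = losses[(0, 0)] + losses[(1, 0)]                       # tensor(3.)
sum_dp1 = losses[(0, 1)] + losses[(1, 1)]                       # tensor(7.)
# ... so the values still differ across DP and the conditional
# all_reduce(op=ReduceOp.AVG) over the data-parallel group is still needed.
logged_value = (sum_dp0 + sum_dp1) / 2                          # tensor(5.)
print(summed_everywhere.item(), logged_value.item())

Case 2 is why the flag exists at all: dropping the DP average unconditionally would under-reduce any loss whose reduce group excludes the data-parallel ranks.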

megatron/core/transformer/moe/router.py (15 additions, 1 deletion)

@@ -374,6 +374,7 @@ def _apply_global_aux_loss(
             global_aux_loss,
             "global_load_balancing_loss",
             self.tp_dp_cp_group,
+            reduce_group_has_dp=True,
         )
         return probs

@@ -384,8 +385,20 @@ def attach_and_log_load_balancing_loss(
         aux_loss: torch.Tensor,
         aux_loss_name: str,
         reduce_group: torch.distributed.ProcessGroup,
+        reduce_group_has_dp: bool = False,
     ):
-        """Attach aux loss function to activation and add to logging."""
+        """Attach aux loss function to activation and add to logging.
+
+        Args:
+            activation (torch.Tensor): The activation tensor to attach the loss to.
+            aux_loss_coeff (float): The coefficient for the auxiliary loss.
+            aux_loss (torch.Tensor): The auxiliary loss tensor.
+            aux_loss_name (str): The name of the auxiliary loss for logging.
+            reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
+            reduce_group_has_dp (bool): Whether the reduce group has data parallel ranks.
+                Set this to True if the reduce group has data parallel ranks. This flag is used to
+                ensure the correct reduction in aux loss tracking.
+        """
         # TODO (zijiey): fix the per_layer_logging for MTP, currently it will incorrectly
         # add the aux loss logging value to other layer's since it is difficult to get the
         # correct layer_number for MTP. It does not affect the correctness of the calculation
@@ -399,6 +412,7 @@ def attach_and_log_load_balancing_loss(
             self.layer_number,
             num_layers,
             reduce_group=reduce_group,
+            reduce_group_has_dp=reduce_group_has_dp,
         )
         if self.calculate_per_token_loss:
             # Scale the aux_loss by the number of tokens.
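
For callers other than the router, the new keyword is simply threaded through to the tracker. The helper below is a hypothetical sketch, not code from this commit: log_router_aux_losses, the seq_aux_loss and tp_cp_group arguments, and the "load_balancing_loss" tracker name are assumptions for illustration, while save_to_aux_losses_tracker and its reduce_group/reduce_group_has_dp parameters follow the moe_utils.py diff above.

import torch
from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker

def log_router_aux_losses(
    global_aux_loss: torch.Tensor,
    seq_aux_loss: torch.Tensor,
    layer_number: int,
    num_layers: int,
    tp_dp_cp_group: torch.distributed.ProcessGroup,
    tp_cp_group: torch.distributed.ProcessGroup,
):
    """Hypothetical helper showing both reduce-group configurations."""
    # The tp-dp-cp group already contains the data-parallel ranks, so the flag
    # is set and reduce_aux_losses_tracker_across_ranks() skips the DP average.
    save_to_aux_losses_tracker(
        "global_load_balancing_loss",
        global_aux_loss,
        layer_number=layer_number,
        num_layers=num_layers,
        reduce_group=tp_dp_cp_group,
        reduce_group_has_dp=True,
    )
    # A reduce group without DP ranks keeps the default reduce_group_has_dp=False,
    # so the conditional all_reduce(op=ReduceOp.AVG) over the DP group still runs.
    save_to_aux_losses_tracker(
        "load_balancing_loss",
        seq_aux_loss,
        layer_number=layer_number,
        num_layers=num_layers,
        reduce_group=tp_cp_group,
    )

Passing reduce_group_has_dp=True only when the group genuinely contains the data-parallel ranks is what lets reduce_aux_losses_tracker_across_ranks drop the extra communication without changing the logged values.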
