
Commit 9927a85

Add support for fake distributed process groups. (#2280)
1 parent e2bd0db commit 9927a85

File tree

4 files changed: +23 -3 lines changed

megatron/training/arguments.py

Lines changed: 12 additions & 0 deletions
@@ -1115,6 +1115,14 @@ def validate_args(args, defaults={}):
         assert not args.distrib_optim_fully_reshardable_mem_efficient, \
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
+    if args.fake_process_group:
+        assert args.moe_token_dispatcher_type != "flex", "Fake process group is not supported with flex token dispatcher."
+        # Disable nan check for fake process group
+        args.check_for_nan_in_loss_and_grad = False
+        warn_rank_0('check_for_nan_in_loss_and_grad is set to False for fake process group.')
+        # Disable gloo process groups for fake process group
+        args.enable_gloo_process_groups = False
+        warn_rank_0('enable_gloo_process_groups is set to False for fake process group.')
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
@@ -2770,6 +2778,10 @@ def _add_distributed_args(parser):
                        "and must be consistent across all ranks.")
     group.add_argument('--replication-factor', default=2, type=int,
                        help="Number of machines storing the replica of a given rank's data.")
+    group.add_argument('--fake-process-group', action='store_true', default=False,
+                       help='If set, initialize with a fake distributed process group; all distributed communication operations are skipped. \
+                            This is useful for profiling the memory usage of distributed training with just one GPU. \
+                            Set WORLD_SIZE and RANK to the values of the target distributed scale.')
     return parser
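For readers unfamiliar with the fake backend, here is a minimal, framework-free sketch (not part of this commit) of what the flag enables: pretend to be one rank of a much larger job by setting the usual WORLD_SIZE/RANK environment variables and initializing a 'fake' process group, so the whole code path runs in a single process on one GPU. The 128-rank world size below is an arbitrary example.

```python
# Minimal sketch (not Megatron code): simulate rank 0 of a 128-rank job in a
# single process using PyTorch's fake process group backend.
import os

import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore  # importing this registers the 'fake' backend

os.environ["WORLD_SIZE"] = "128"   # target distributed scale
os.environ["RANK"] = "0"           # the one process that actually runs

dist.init_process_group(
    backend="fake",
    store=FakeStore(),
    world_size=int(os.environ["WORLD_SIZE"]),
    rank=int(os.environ["RANK"]),
)

print(dist.get_world_size(), dist.get_rank())   # 128 0
dist.destroy_process_group()
```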

megatron/training/initialize.py

Lines changed: 6 additions & 0 deletions
@@ -346,6 +346,12 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
         'rank': args.rank,
         'timeout': timedelta(minutes=args.distributed_timeout_minutes),
     }
+    if args.fake_process_group:
+        assert is_torch_min_version("2.3.0"), "Fake process group is only supported with PyTorch 2.3.0 and above."
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+        store = FakeStore()
+        init_process_group_kwargs['backend'] = 'fake'
+        init_process_group_kwargs['store'] = store
 
     torch.distributed.init_process_group(**init_process_group_kwargs)
     inprocess_restart.maybe_force_nccl_backend_init(device_id)
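As a rough illustration of the behavior this branch relies on (assuming PyTorch >= 2.3, per the assert above): once the default group uses the 'fake' backend, collectives complete immediately without contacting any peers, and torch.distributed.get_backend() reports 'fake', which the training.py and utils.py changes below check for. A standalone sketch, with an arbitrary world size of 8:

```python
# Standalone sketch: collectives on a fake process group are no-ops.
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", store=FakeStore(), world_size=8, rank=0)

t = torch.ones(4)
dist.all_reduce(t)                       # returns immediately; no peers are contacted
assert dist.get_backend() == "fake"      # the condition checked later in this commit
dist.destroy_process_group()
```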

megatron/training/training.py

Lines changed: 4 additions & 2 deletions
@@ -1630,7 +1630,7 @@ def training_log(
                 mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
             )
         if iteration % args.log_interval == 0:
-            if args.record_memory_history and is_last_rank():
+            if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
                 snapshot = torch.cuda.memory._snapshot()
                 from pickle import dump
 
@@ -1711,7 +1711,9 @@ def training_log(
         num_microbatches = get_num_microbatches()
         report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
         report_memory(f'(after {iteration} iterations)')
-        report_memory_flag = False
+        if iteration > 1:
+            # Make sure the memory after the second iteration is reported to include optimizer state memory.
+            report_memory_flag = False
         # Write timers to wandb, don't reset the counts
         if args.log_timers_to_tensorboard:
             timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
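Since the motivation for the flag is single-GPU memory profiling, here is a rough sketch of the snapshot path that args.record_memory_history feeds (requires a CUDA device; the dummy allocation loop is a placeholder for real training steps and is not Megatron code):

```python
# Sketch: record allocator history and dump a snapshot viewable at pytorch.org/memory_viz.
from pickle import dump

import torch

torch.cuda.memory._record_memory_history()        # start tracking allocations

for _ in range(3):                                 # placeholder for training iterations
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
    del x, y

snapshot = torch.cuda.memory._snapshot()           # same call as in training_log above
with open("memory_snapshot.pickle", "wb") as f:
    dump(snapshot, f)
```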

megatron/training/utils.py

Lines changed: 1 addition & 1 deletion
@@ -410,7 +410,7 @@ def is_last_rank():
 
 def print_rank_last(message):
     """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
         if is_last_rank():
             print(message, flush=True)
     else:
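Why the extra backend check is needed: under a fake process group the single real process typically runs as rank 0 of a large world, so it is never the last rank and last-rank-only printing would be silent. A small sketch (the world size of 64 is arbitrary):

```python
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", store=FakeStore(), world_size=64, rank=0)

print(dist.get_rank() == dist.get_world_size() - 1)   # False: rank 0 is not the last rank
print(dist.get_backend() == "fake")                    # True: so print_rank_last falls back to printing
dist.destroy_process_group()
```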
