diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index b5f777a30c6..3aeade771c6 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1103,6 +1103,14 @@ def validate_args(args, defaults={}):
         assert not args.distrib_optim_fully_reshardable_mem_efficient, \
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
+    if args.fake_process_group:
+        assert args.moe_token_dispatcher_type != "flex", "Fake process group is not supported with flex token dispatcher."
+        # Disable nan check for fake process group
+        args.check_for_nan_in_loss_and_grad = False
+        warn_rank_0('check_for_nan_in_loss_and_grad is set to False for fake process group.')
+        # Disable gloo process groups for fake process group
+        args.enable_gloo_process_groups = False
+        warn_rank_0('enable_gloo_process_groups is set to False for fake process group.')
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
@@ -2746,6 +2754,10 @@ def _add_distributed_args(parser):
                        "and must be consistent across all ranks.")
     group.add_argument('--replication-factor', default=2, type=int,
                        help="Number of machines storing the replica of a given rank's data.")
+    group.add_argument('--fake-process-group', action='store_true', default=False,
+                       help='If set, initialize with a fake distributed process group; all distributed communication operations are skipped. \
+                       This is useful for profiling the memory usage of distributed training on a single GPU. \
+                       Set WORLD_SIZE and RANK to the values of the target distributed scale.')
 
     return parser
 
diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index 1e14926a2f9..96594b5194d 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -346,6 +346,12 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
             'rank': args.rank,
             'timeout': timedelta(minutes=args.distributed_timeout_minutes),
         }
+        if args.fake_process_group:
+            assert is_torch_min_version("2.3.0"), "Fake process group is only supported with PyTorch 2.3.0 and above."
+            from torch.testing._internal.distributed.fake_pg import FakeStore
+            store = FakeStore()
+            init_process_group_kwargs['backend'] = 'fake'
+            init_process_group_kwargs['store'] = store
         torch.distributed.init_process_group(**init_process_group_kwargs)
         inprocess_restart.maybe_force_nccl_backend_init(device_id)
 
diff --git a/megatron/training/training.py b/megatron/training/training.py
index b162aa87acf..b8265c54071 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1619,7 +1619,7 @@ def training_log(
             mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
         )
     if iteration % args.log_interval == 0:
-        if args.record_memory_history and is_last_rank():
+        if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
 
@@ -1700,7 +1700,9 @@ def training_log(
             num_microbatches = get_num_microbatches()
             report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
         report_memory(f'(after {iteration} iterations)')
-        report_memory_flag = False
+        if iteration > 1:
+            # Make sure memory is reported after the second iteration so that optimizer state memory is included.
+            report_memory_flag = False
     # Write timers to wandb, don't reset the counts
     if args.log_timers_to_tensorboard:
         timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
diff --git a/megatron/training/utils.py b/megatron/training/utils.py
index cef71160791..cc4560a7e3a 100644
--- a/megatron/training/utils.py
+++ b/megatron/training/utils.py
@@ -410,7 +410,7 @@ def is_last_rank():
 
 def print_rank_last(message):
     """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
         if is_last_rank():
             print(message, flush=True)
     else:
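
For reference, a minimal standalone sketch (not part of this patch) of the fake process group mechanism that the initialize.py hunk relies on, assuming PyTorch >= 2.3, where importing torch.testing._internal.distributed.fake_pg registers the 'fake' backend; the RANK/WORLD_SIZE values below are illustrative only:

# Minimal sketch, assuming PyTorch >= 2.3; importing fake_pg registers the 'fake' backend.
import os
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# Pretend to be rank 0 of a 64-process job while actually running a single process.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "64")

dist.init_process_group(
    backend="fake",
    store=FakeStore(),
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)

# Collectives return immediately without real communication, so memory profiling
# of a "distributed" configuration can run on a single device.
t = torch.ones(4)
dist.all_reduce(t)
print(dist.get_backend(), dist.get_world_size())  # -> fake 64
dist.destroy_process_group()

Within Megatron, the equivalent setup is to launch a single process with --fake-process-group and export WORLD_SIZE and RANK for the distributed scale being modeled.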