From 12bcd66916eb03e0ff5ac0800e9f4016afed91e9 Mon Sep 17 00:00:00 2001
From: "Dennis(Zhenhuan) Liu"
Date: Tue, 18 Nov 2025 23:02:40 +0800
Subject: [PATCH 1/3] [DEV] Add support of fake distributed process group (#2254)

---
 megatron/training/arguments.py  | 9 +++++++++
 megatron/training/initialize.py | 5 +++++
 megatron/training/training.py   | 6 ++++--
 megatron/training/utils.py      | 2 +-
 4 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index b5f777a30c6..a5daf1dc2f2 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1103,6 +1103,11 @@ def validate_args(args, defaults={}):
         assert not args.distrib_optim_fully_reshardable_mem_efficient, \
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
+    if args.fake_process_group:
+        # Disable nan check for fake process group
+        args.check_for_nan_in_loss_and_grad = False
+        # Disable gloo process groups for fake process group
+        args.enable_gloo_process_groups = False
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
@@ -2746,6 +2751,10 @@ def _add_distributed_args(parser):
                         "and must be consistent across all ranks.")
     group.add_argument('--replication-factor', default=2, type=int,
                        help="Number of machines storing the replica of a given rank's data.")
+    group.add_argument('--fake-process-group', action='store_true', default=False,
+                       help='If set, initialize a fake distributed process group so that all distributed communication operations are skipped. \
+                       This is useful for profiling the memory usage of distributed training with just one GPU. \
+                       Set WORLD_SIZE and RANK to the values of the target distributed scale.')
 
     return parser
 
diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index 1e14926a2f9..8b585fdd87b 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -346,6 +346,11 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
         'rank': args.rank,
         'timeout': timedelta(minutes=args.distributed_timeout_minutes),
     }
+    if args.fake_process_group:
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+        store = FakeStore()
+        init_process_group_kwargs['backend'] = 'fake'
+        init_process_group_kwargs['store'] = store
     torch.distributed.init_process_group(**init_process_group_kwargs)
 
     inprocess_restart.maybe_force_nccl_backend_init(device_id)
diff --git a/megatron/training/training.py b/megatron/training/training.py
index b162aa87acf..b8265c54071 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1619,7 +1619,7 @@ def training_log(
             mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
         )
 
     if iteration % args.log_interval == 0:
-        if args.record_memory_history and is_last_rank():
+        if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
@@ -1700,7 +1700,9 @@
             num_microbatches = get_num_microbatches()
             report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
             report_memory(f'(after {iteration} iterations)')
-            report_memory_flag = False
+            if iteration > 1:
+                # Make sure the memory after the second iteration is reported to include optimizer state memory.
+                report_memory_flag = False
         # Write timers to wandb, don't reset the counts
         if args.log_timers_to_tensorboard:
             timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
diff --git a/megatron/training/utils.py b/megatron/training/utils.py
index cef71160791..cc4560a7e3a 100644
--- a/megatron/training/utils.py
+++ b/megatron/training/utils.py
@@ -410,7 +410,7 @@ def is_last_rank():
 
 def print_rank_last(message):
     """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
         if is_last_rank():
             print(message, flush=True)
     else:

From 521bfcfed805fe9f7c07c3b07c40758121a12576 Mon Sep 17 00:00:00 2001
From: Dennis Liu
Date: Tue, 18 Nov 2025 19:58:25 -0800
Subject: [PATCH 2/3] Add version checks and warnings.

---
 megatron/training/arguments.py  | 3 +++
 megatron/training/initialize.py | 1 +
 2 files changed, 4 insertions(+)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index a5daf1dc2f2..1c8ec85deb6 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1104,10 +1104,13 @@ def validate_args(args, defaults={}):
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
     if args.fake_process_group:
+        assert args.moe_token_dispatcher_type != "flex", "Fake process group is not supported with flex token dispatcher."
         # Disable nan check for fake process group
         args.check_for_nan_in_loss_and_grad = False
+        print('Warning: check_for_nan_in_loss_and_grad is set to False for fake process group.')
         # Disable gloo process groups for fake process group
         args.enable_gloo_process_groups = False
+        print('Warning: enable_gloo_process_groups is set to False for fake process group.')
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index 8b585fdd87b..96594b5194d 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -347,6 +347,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
         'timeout': timedelta(minutes=args.distributed_timeout_minutes),
     }
     if args.fake_process_group:
+        assert is_torch_min_version("2.3.0"), "Fake process group is only supported with PyTorch 2.3.0 and above."
         from torch.testing._internal.distributed.fake_pg import FakeStore
         store = FakeStore()
         init_process_group_kwargs['backend'] = 'fake'

From ef9c9e79861c6b6d15913177ca0d3bcf70357672 Mon Sep 17 00:00:00 2001
From: Dennis Liu
Date: Tue, 18 Nov 2025 23:57:03 -0800
Subject: [PATCH 3/3] print to warn_rank_0

---
 megatron/training/arguments.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 1c8ec85deb6..3aeade771c6 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1107,10 +1107,10 @@ def validate_args(args, defaults={}):
         assert args.moe_token_dispatcher_type != "flex", "Fake process group is not supported with flex token dispatcher."
         # Disable nan check for fake process group
         args.check_for_nan_in_loss_and_grad = False
-        print('Warning: check_for_nan_in_loss_and_grad is set to False for fake process group.')
+        warn_rank_0('check_for_nan_in_loss_and_grad is set to False for fake process group.')
         # Disable gloo process groups for fake process group
         args.enable_gloo_process_groups = False
-        print('Warning: enable_gloo_process_groups is set to False for fake process group.')
+        warn_rank_0('enable_gloo_process_groups is set to False for fake process group.')
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
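
For context: --fake-process-group builds on PyTorch's FakeStore and 'fake' backend (from torch.testing._internal.distributed.fake_pg), which accept the usual process-group calls but whose collectives complete immediately without performing any communication, so a single GPU can stand in for an arbitrarily large job when profiling memory. The snippet below is a minimal standalone sketch of that initialization (not Megatron code), assuming PyTorch >= 2.3.0; the WORLD_SIZE/RANK values are illustrative placeholders for the target distributed scale.

    import os
    import torch
    import torch.distributed as dist
    # Importing from fake_pg makes the 'fake' backend available; FakeStore replaces the usual TCP/file store.
    from torch.testing._internal.distributed.fake_pg import FakeStore

    # Pretend to be rank 3 of a 64-rank job while running on a single device (illustrative values).
    os.environ.setdefault("WORLD_SIZE", "64")
    os.environ.setdefault("RANK", "3")

    dist.init_process_group(
        backend="fake",
        store=FakeStore(),
        world_size=int(os.environ["WORLD_SIZE"]),
        rank=int(os.environ["RANK"]),
    )

    # Collectives are no-ops, so a "distributed" training step can be memory-profiled on one GPU.
    t = torch.ones(4, device="cuda" if torch.cuda.is_available() else "cpu")
    dist.all_reduce(t)
    print(f"rank {dist.get_rank()}/{dist.get_world_size()}, backend={dist.get_backend()}")

    dist.destroy_process_group()

This is the same pattern the initialize.py hunk in PATCH 1/3 applies inside Megatron: when --fake-process-group is set, it switches the backend to 'fake' and passes a FakeStore to torch.distributed.init_process_group.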