12 changes: 12 additions & 0 deletions megatron/training/arguments.py
@@ -1103,6 +1103,14 @@ def validate_args(args, defaults={}):
         assert not args.distrib_optim_fully_reshardable_mem_efficient, \
             '--distrib-optim-fully-reshardable-mem-efficient requires -enable-gloo-process-groups'
 
+    if args.fake_process_group:
+        assert args.moe_token_dispatcher_type != "flex", "Fake process group is not supported with flex token dispatcher."
+        # Disable nan check for fake process group
+        args.check_for_nan_in_loss_and_grad = False
+        warn_rank_0('check_for_nan_in_loss_and_grad is set to False for fake process group.')
+        # Disable gloo process groups for fake process group
+        args.enable_gloo_process_groups = False
+        warn_rank_0('enable_gloo_process_groups is set to False for fake process group.')
 
     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
@@ -2746,6 +2754,10 @@ def _add_distributed_args(parser):
                             "and must be consistent across all ranks.")
     group.add_argument('--replication-factor', default=2, type=int,
                        help="Number of machines storing the replica of a given rank's data.")
+    group.add_argument('--fake-process-group', action='store_true', default=False,
+                       help='If set, initialize with a fake distributed process group; all distributed communication operations are skipped. '
+                            'This is useful for profiling the memory usage of distributed training with a single GPU. '
+                            'Set the WORLD_SIZE and RANK environment variables to the values of the target distributed scale.')
     return parser


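For orientation, a minimal launch sketch showing how the new flag could be used to profile memory at a simulated multi-GPU scale from a single process. It assumes a torchrun-style environment; the entry point and model flags below are abbreviated and illustrative, not part of this PR.

# Pretend to be rank 0 of a 64-rank job; no other processes are started and
# collectives are skipped by the fake backend.
import os
import subprocess

env = dict(
    os.environ,
    WORLD_SIZE="64",            # simulated job size
    RANK="0",                   # simulated rank of this single process
    MASTER_ADDR="localhost",
    MASTER_PORT="6000",
)

subprocess.run(
    [
        "python", "pretrain_gpt.py",          # remaining flags are illustrative
        "--fake-process-group",
        "--record-memory-history",
        "--tensor-model-parallel-size", "8",
        "--pipeline-model-parallel-size", "8",
    ],
    env=env,
    check=True,
)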
6 changes: 6 additions & 0 deletions megatron/training/initialize.py
@@ -346,6 +346,12 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
             'rank': args.rank,
             'timeout': timedelta(minutes=args.distributed_timeout_minutes),
         }
+        if args.fake_process_group:
+            assert is_torch_min_version("2.3.0"), "Fake process group is only supported with PyTorch 2.3.0 and above."
+            from torch.testing._internal.distributed.fake_pg import FakeStore
+            store = FakeStore()
+            init_process_group_kwargs['backend'] = 'fake'
+            init_process_group_kwargs['store'] = store
 
         torch.distributed.init_process_group(**init_process_group_kwargs)
         inprocess_restart.maybe_force_nccl_backend_init(device_id)
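Outside of Megatron, the fake-backend path added above boils down to the following sketch (assuming PyTorch >= 2.3 on a single process; the simulated world size is arbitrary):

# Initialize a fake process group and confirm that collectives return without communicating.
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore  # importing this registers the 'fake' backend

dist.init_process_group(backend="fake", store=FakeStore(), rank=0, world_size=64)

t = torch.ones(4)
dist.all_reduce(t)             # no-op under the fake backend; returns immediately
print(dist.get_backend())      # 'fake'
print(dist.get_world_size())   # 64, even though only one process exists
dist.destroy_process_group()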
6 changes: 4 additions & 2 deletions megatron/training/training.py
@@ -1619,7 +1619,7 @@ def training_log(
                 mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
             )
     if iteration % args.log_interval == 0:
-        if args.record_memory_history and is_last_rank():
+        if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
 
@@ -1700,7 +1700,9 @@ def training_log(
             num_microbatches = get_num_microbatches()
             report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
         report_memory(f'(after {iteration} iterations)')
-        report_memory_flag = False
+        if iteration > 1:
+            # Make sure memory is reported again after the second iteration so that optimizer state memory is included.
+            report_memory_flag = False
     # Write timers to wandb, don't reset the counts
     if args.log_timers_to_tensorboard:
         timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
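The snapshot branch touched above builds on PyTorch's CUDA memory-history recording. A hedged sketch of that mechanism in isolation (the output filename and max_entries value are arbitrary choices, not Megatron defaults):

# Record allocator events, then dump a snapshot that can be inspected with
# the memory_viz tool at https://pytorch.org/memory_viz.
import torch
from pickle import dump

torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run the iterations to be profiled here ...

snapshot = torch.cuda.memory._snapshot()
with open("memory_snapshot.pickle", "wb") as f:
    dump(snapshot, f)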
2 changes: 1 addition & 1 deletion megatron/training/utils.py
@@ -410,7 +410,7 @@ def is_last_rank():
 
 def print_rank_last(message):
     """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() != 'fake':
         if is_last_rank():
             print(message, flush=True)
     else: