
Conversation

@parthmannan
Contributor

@parthmannan parthmannan commented Nov 17, 2025

What does this PR do?

Dev MR for details - #2054

Design document discussed in MCore sync meeting - https://docs.google.com/document/d/1MnIPQ_VbpDNp-adtvcEv-SYx6A8rtt3-fDdxbcdrmk0/edit?usp=sharing

The first issue this MR addresses is the workload imbalance between DP ranks when using packed sequences (for example in SFT). While packing sequences helps reduce variability in total sequence length, it does not guarantee equal workload: attention compute is quadratic in sequence length, so a single 1k-token sequence requires 2x the compute of a packed sequence made of two 512-token samples. The problem gets much worse with very large sequences and/or large variation between sequence lengths.
This MR schedules a variable number of microbatches per rank in the DPxCP group to ensure a balanced workload.
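
To make the imbalance concrete, here is a minimal illustration (not code from this MR) comparing the quadratic attention cost of one long sample against a packed pair with the same total token count:

# Illustrative only: relative attention cost, assuming cost ~ (sequence length)^2.
def attention_cost(sample_lens):
    return sum(length ** 2 for length in sample_lens)

single_long = attention_cost([1024])      # one 1k-token sample
packed_pair = attention_cost([512, 512])  # two 512-token samples, same total tokens
print(single_long / packed_pair)          # -> 2.0: twice the compute for equal tokens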

The second issue this MR addresses is redundant CP communication. Our context parallel size is based on the full packed sequence length (usually the max seq length of all samples). For example, if a sequence of 1k requires CP2, we apply CP2 to a packed sequence of 2x512 as well. In reality, we can easily partition the packed sequence of 2x512 across 2 GPUs by separating the 2 samples, without any CP. This MR introduces dynamic context parallelism, where each sample is individually scheduled with a dynamic CP group.
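
As a rough sketch of the idea (the max_seq_per_gpu knob is hypothetical, not an actual Megatron-Core argument), a per-sample CP size could be chosen by doubling until the per-GPU chunk fits a budget:

def dynamic_cp_size(sample_len, max_seq_per_gpu, max_cp=16):
    cp = 1
    while sample_len / cp > max_seq_per_gpu and cp < max_cp:
        cp *= 2  # this milestone restricts CP group sizes to powers of 2
    return cp

print(dynamic_cp_size(128 * 1024, 64 * 1024))  # -> 2: a 128k sample needs CP2
print(dynamic_cp_size(512, 64 * 1024))         # -> 1: a 512-token sample needs no CP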

To achieve the above, we introduce a balanced scheduler and a dataloader wrapper.
The dataloader wrapper is responsible for collecting the metadata that informs the scheduler of the sequence length of each sample across the entire global batch. This dataloader breaks packed sequences up into individual samples, since samples are scheduled individually. Once we have the metadata, the balanced scheduler assigns each sample to a rank (across the DPxCP group) and a dynamic CP group size. To avoid deadlocks, we divide the schedule into groups (this replaces the notion of microbatches). Within each group, each rank is part of a fixed CP group; however, each rank may run a different number of samples so that all ranks have balanced compute.
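
For intuition, a greedy least-loaded assignment (a simplified stand-in for the actual balanced scheduler, which additionally forms CP groups and deadlock-free schedule groups) could look like this:

import heapq

def balance_samples(sample_lens, num_ranks):
    # Min-heap of (estimated load, rank, assigned sample lengths); cost ~ length^2.
    heap = [(0, rank, []) for rank in range(num_ranks)]
    heapq.heapify(heap)
    for length in sorted(sample_lens, reverse=True):
        load, rank, assigned = heapq.heappop(heap)
        assigned.append(length)
        heapq.heappush(heap, (load + length ** 2, rank, assigned))
    return sorted(heap, key=lambda entry: entry[1])

for load, rank, assigned in balance_samples([1024, 768, 512, 512, 256], 2):
    print(f"rank {rank}: load={load} samples={assigned}")
# rank 0 runs 2 samples and rank 1 runs 3, yet both end up with the same estimated load.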

Screenshot 2025-10-08 at 3 21 39 PM

We have run performance and correctness evaluations with this feature. Using an SFT packed dataset with a max seq len of 128k and a LLaMa3 8B dummy model, we see a 3x performance improvement. While there is room for improving the baseline itself, the speedup should remain in the 2-3x range.

This is what 128k seq len with CP16 looks like (without this feature). The GPU is bound by CP communications.
Screenshot 2025-10-08 at 3 28 38 PM

This is what 128k seq len with CP16 looks like (with this feature). The GPU is bound by attention compute, since all redundant comms have been removed.
Screenshot 2025-10-08 at 3 30 26 PM

Feature correctness (@xiaoyao0115)
hybrid_cp_loss_curve

This is the first milestone of this feature, and there are many improvements we want to make in future releases.

  1. The feature does not support pipeline parallelism or FSDP yet. We hope to add PP support next.
  2. The feature is limited to creating dynamic CP groups whose sizes are powers of 2 (see the sketch after this list). We hope to add fully dynamic support using changes in TransformerEngine DPA.
  3. The feature does not support CUDA graphs.
  4. The feature works best with FlashAttention rather than cuDNN FusedAttention, because the changing sequence lengths and CP sizes make cuDNN recompile its graph and all performance gains are lost. We'll advocate for dynamic shape support in cuDNN FusedAttention.
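
As a small illustration of limitation 2 (hypothetical helper, not from this MR): restricting CP group sizes to powers of 2 means a sample that would ideally span 3 GPUs gets rounded up to a CP group of 4.

import math

def round_up_to_power_of_2(ideal_cp):
    return 1 << math.ceil(math.log2(ideal_cp))

print(round_up_to_power_of_2(3))  # -> 4: one GPU's worth of capacity sits idle
print(round_up_to_power_of_2(5))  # -> 8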

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

parthmannan and others added 30 commits July 14, 2025 19:08
@parthmannan parthmannan requested review from a team as code owners November 17, 2025 23:37
@copy-pr-bot

copy-pr-bot bot commented Nov 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@duncanriach duncanriach self-requested a review November 18, 2025 00:47
@yanring yanring added the Expert Review label Nov 18, 2025
@yanring yanring added this to the Core 0.16 milestone Nov 18, 2025
Contributor

@asolergi-nv asolergi-nv left a comment


Looks good! Thanks for your contribution!

one_logger = get_one_logger()

if args.hybrid_context_parallel:
    train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config))
Contributor


Should we remove the __iter__ method, WDYT?

Suggested change
-    train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config))
+    train_data_iterator = HybridCPDataLoaderWrapper(train_data_iterator, config)
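
For context on this suggestion, here is a minimal sketch of an iterator-style wrapper (illustrative only, not the actual HybridCPDataLoaderWrapper code): when the wrapper defines __next__ and an __iter__ that returns self, the outer iter(...) call is a no-op and can be dropped.

class IteratorStyleWrapper:
    def __init__(self, data_iterator):
        self.data_iterator = data_iterator

    def __iter__(self):
        return self  # iter(wrapper) simply returns the wrapper itself

    def __next__(self):
        # Metadata collection / sample splitting would happen here in the real wrapper.
        return next(self.data_iterator)

wrapper = IteratorStyleWrapper(iter(range(3)))
assert iter(wrapper) is wrapper
print(list(wrapper))  # -> [0, 1, 2]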

