
@dxqb dxqb commented Dec 2, 2025

addresses #12776

What does this PR do?

This PR keeps the tuples, but moves the splitting of tensors into tuples of tensors into the transformer blocks, to avoid issues with gradient checkpointing. When a tensor is passed directly as an argument, torch.utils.checkpoint() identifies it as an input and saves it accordingly, without running a backward through it multiple times.
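A minimal sketch of the idea, assuming a hypothetical `Block` module and a combined `rotary_emb` tensor (names invented for illustration, not the actual diffusers code): the split into a tuple happens inside the checkpointed block, so `torch.utils.checkpoint()` only ever sees plain tensor inputs.

```python
import torch
import torch.utils.checkpoint as checkpoint


class Block(torch.nn.Module):
    # Hypothetical minimal transformer-block stand-in: it receives the
    # rotary embedding as a single tensor and splits it into (cos, sin)
    # inside forward(), instead of the caller passing a pre-split tuple.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, rotary_emb: torch.Tensor) -> torch.Tensor:
        # Split into a tuple *inside* the checkpointed region; the
        # checkpoint machinery only saw tensor arguments.
        cos, sin = rotary_emb.chunk(2, dim=-1)
        return self.proj(x * cos + x * sin)


dim = 8
block = Block(dim)
x = torch.randn(2, dim, requires_grad=True)
rotary = torch.randn(2, 2 * dim)

# Because x and rotary are passed directly as tensors, checkpoint()
# can identify and save them for recomputation in backward.
out = checkpoint.checkpoint(block, x, rotary, use_reentrant=False)
out.sum().backward()
```

This is only a sketch of the mechanism, not the PR's implementation; the actual change applies the same pattern to the transformer blocks referenced in #12776.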

This is a draft. If you agree with this change I can make it nicer. Among other things:

  • the type hints are now incorrect and need updating
  • the splitting might no longer be necessary, since the split tensors are used immediately afterwards

Who can review?

@yiyixuxu and @asomoza
