
Conversation

@molbap (Contributor) commented Dec 2, 2025

What does this PR do?

As the title indicates, #42542 and likely a few other models were broken by the merged #41993. This PR adds an embedding getter and extends test coverage for the feature.

Basically, what it does:

  • Stops hard-failing gradient_checkpointing_enable when a model lacks get_input_embeddings. We now just call enable_input_require_grads, let it attach hooks where it can, and issue a single warning if no embedding module is found.
  • Simplifies enable_input_require_grads (and the InternVL/MLCD overrides, plus a couple of other model adjustments) by making it responsible for the warning.
  • Adds a broad test to make sure all of that works (please take a look).

This should hopefully help gradient checkpointing with PEFT adapters for many VLMs (and plain text models too); a sketch of the hook mechanism follows.
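
For context, the hook-based mechanism this relies on looks roughly like the sketch below. This is a minimal illustration assuming a standard PyTorch forward hook on a Transformers-style model, not the exact code added in this PR; the warning text and function body are paraphrased.

```python
# Minimal sketch of how enable_input_require_grads can attach a forward hook
# to the input embeddings so that activations require grad even when the
# embedding weights are frozen (e.g. PEFT adapters + gradient checkpointing).
# Illustrative only; not the exact implementation in this PR.
from torch import nn


def enable_input_require_grads(model: nn.Module) -> None:
    # Assumes a Transformers-style model where get_input_embeddings exists
    # and returns None (or a module) rather than raising.
    embeddings = model.get_input_embeddings()
    if embeddings is None:
        # Tolerant behaviour described above: warn instead of hard-failing.
        print(f"{model.__class__.__name__} does not expose input embeddings; skipping hook.")
        return

    def make_inputs_require_grads(module, inputs, output):
        # Force the embedding output to require grad so checkpointed blocks
        # have a grad-requiring input to backpropagate through.
        output.requires_grad_(True)

    embeddings.register_forward_hook(make_inputs_require_grads)
```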

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@molbap (Contributor Author) commented Dec 2, 2025

Added a test as well, but I can't find a clean way around the models for which it's not relevant to have a getter method without causing as many side effects. WDYT @zucchini-nlp? Kind of stumped (try/except-ing at a higher level would always work but it hides a lot).

@zucchini-nlp (Member) left a comment

models for which it's not relevant to have a getter method without causing as many side effects

Would this also mean that we can't correctly support PEFT and GC with these models, or do they have a custom way to set grad on the inputs? We could raise an error with a better message saying that a model isn't supported unless it has a way to get its input embeddings, wdyt?

Comment on lines +985 to +987
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
    base_model = getattr(self, base_model, None)

Member

nit: self.base_model property has the same functionality

Contributor Author

true!
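
For reference, a small illustration of the equivalence behind the nit, assuming the usual base_model_prefix convention; the tiny checkpoint is just for demonstration. Note that the base_model property falls back to the model itself rather than None when no prefixed submodule exists.

```python
# Illustration of the reviewer's nit: PreTrainedModel.base_model already
# resolves base_model_prefix, so the manual getattr chain can be dropped.
# The checkpoint choice is arbitrary; any model with a base_model_prefix works.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Manual lookup, as in the diff hunk above (yields None if the prefix is missing).
prefix = getattr(model, "base_model_prefix", None)
backbone_manual = getattr(model, prefix, None) if prefix is not None else None

# Property-based lookup (falls back to the model itself if the prefix is missing).
backbone_property = model.base_model

assert backbone_manual is backbone_property
```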

_input_embed_layer = "embed_tokens" # default layer that holds input embeddings.

def get_input_embeddings(self) -> nn.Module:
def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:

Member
oh interesting, I was assuming the base get_input_embeddings already returns None

Contributor Author

well I ended up in so many little edge cases lol
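
For illustration, a non-raising getter along the lines of the diff above might look like the following; the _input_embed_layer fallback and the toy class are assumptions for the sketch, not the exact PR code.

```python
# Hypothetical sketch of a non-raising embedding getter, loosely following the
# diff above. The attribute name and fallback logic are illustrative only.
from typing import Optional

from torch import nn


class SketchModel(nn.Module):
    _input_embed_layer = "embed_tokens"  # default attribute holding input embeddings

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(32, 8)

    def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:
        # Return the embedding module if present, None otherwise, instead of
        # raising NotImplementedError the way the strict getter does.
        return getattr(self, self._input_embed_layer, None)

    def get_input_embeddings(self) -> nn.Module:
        embeddings = self._get_input_embeddings_no_raise()
        if embeddings is None:
            raise NotImplementedError(f"{type(self).__name__} has no input embeddings")
        return embeddings
```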

@molbap (Contributor Author) commented Dec 3, 2025

Yes, it's a good idea to raise/inform downstream users. I reverted a couple of things and will update the test so it actually checks that enabling GC works (and will probably add another test).

Contributor Author

This is mostly to fix a broken env situation that can arise around timm_wrapper (or timm_backbone?), so it protects a few imports.

@molbap (Contributor Author) commented Dec 4, 2025

I reverted a few models to inner positional embedding calls, as mentioned in #38913.

Modified a few other models, since the test I added (test_enable_input_require_grads_with_gradient_checkpointing) was a bit naive and I was just continue-ing; now it's a proper skip if the loss is undefined.

Hopefully that helps VLMs + GC and does not break adapters
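
A rough sketch of the check that test exercises is below; the tiny checkpoint, the frozen-parameter setup, and the exact skip condition are assumptions for illustration, not the test as written in the PR.

```python
# Rough sketch of the GC + enable_input_require_grads check described above.
# The checkpoint and frozen-base setup are illustrative; the real test differs.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Simulate a PEFT-style setup: base weights frozen, so gradients only flow if
# the embedding outputs are forced to require grad.
for param in model.parameters():
    param.requires_grad_(False)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.train()

inputs = tokenizer("hello world", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])

if outputs.loss is None:
    # Mirrors the "proper skip" mentioned above: nothing to backprop through.
    raise SystemExit("model returns no loss; skipping the check")

# With the embedding hook in place, the loss stays connected to the autograd
# graph even though every parameter is frozen.
assert outputs.loss.requires_grad
```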

Comment on lines +1987 to +1990
try:
    input_embeddings = module.get_input_embeddings()
except NotImplementedError:
    continue

Contributor Author

no simple way around this unfortunately
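
For context, the surrounding logic is presumably a loop over submodules along these lines; this is a hedged reconstruction from the two hunks quoted in this review, with assumed names, not the verbatim PR code.

```python
# Hedged reconstruction of the tolerant hook-attachment loop implied by the
# quoted hunks; the function name and exact structure are assumptions.
from torch import nn


def attach_require_grads_hooks(model: nn.Module) -> bool:
    def make_inputs_require_grads(module, inputs, output):
        output.requires_grad_(True)

    found_embeddings = False
    for module in model.modules():
        if not hasattr(module, "get_input_embeddings"):
            continue
        try:
            # Some submodules (vision towers, backbones, ...) legitimately have
            # no input embeddings and raise NotImplementedError here.
            input_embeddings = module.get_input_embeddings()
        except NotImplementedError:
            continue
        if isinstance(input_embeddings, nn.Module):
            input_embeddings.register_forward_hook(make_inputs_require_grads)
            found_embeddings = True
    return found_embeddings
```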

Comment on lines +2007 to +2011
if not found_embeddings:
    logger.warning_once(
        f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
        "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
        "support those features."

Contributor Author

at least we can warn users!

@github-actions bot commented Dec 4, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: align, altclip, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, mlcd, poolformer, siglip, siglip2, splinter, switch_transformers

@molbap changed the title from "Add embedding getter + test" to "Make gradient-checkpoint enabling tolerant of models without get_input_embeddings" on Dec 4, 2025