Make gradient-checkpoint enabling tolerant of models without get_input_embeddings #42558
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Added a test as well, but I can't find a clean way to handle the models for which a getter method is not relevant without causing as many side effects. WDYT @zucchini-nlp? Kind of stumped (try/excepting at a higher level would always work, but it hides a lot).
zucchini-nlp left a comment:
models for which it is not relevant to have a getter method and not causing as many side-effects
Would this also mean that we can't correctly support PEFT and GC with these models, or do they have a custom way to set grads on the inputs? We could raise an error with a better message saying that the model doesn't support this unless it has a way to get its input embeddings, wdyt?
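A rough sketch of the kind of early, explicit error being suggested here; the helper name and wording are hypothetical, not the PR's actual code:

```python
# Illustrative only: surface a clear error up front instead of failing deep
# inside gradient-checkpointing / PEFT setup.
def _require_input_embeddings(model):
    try:
        embeddings = model.get_input_embeddings()
    except NotImplementedError:
        embeddings = None
    if embeddings is None:
        raise ValueError(
            f"{model.__class__.__name__} does not expose input embeddings, so gradient "
            "checkpointing with PEFT cannot set requires_grad on the inputs. "
            "Override `get_input_embeddings` to support this."
        )
    return embeddings
```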
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
    base_model = getattr(self, base_model, None)
nit: self.base_model property has the same functionality
true!
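For context, the existing `base_model` property on `PreTrainedModel` already resolves the prefix, roughly like this (paraphrased from memory, not copied from the file; note it falls back to `self` rather than `None`):

```python
@property
def base_model(self) -> nn.Module:
    # Resolve the submodule named by base_model_prefix, falling back to the model itself.
    return getattr(self, self.base_model_prefix, self)
```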
src/transformers/modeling_utils.py
Outdated
_input_embed_layer = "embed_tokens"  # default layer that holds input embeddings.

def get_input_embeddings(self) -> nn.Module:
def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:
Oh interesting, I was assuming the base `get_input_embeddings` already returns None.
Well, I ended up in so many little edge cases lol.
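For readers following along, a non-raising getter along the lines of the diff above might look like this (a sketch, not the exact implementation in the PR):

```python
from typing import Optional

import torch.nn as nn

def _get_input_embeddings_no_raise(self) -> Optional[nn.Module]:
    # Like get_input_embeddings(), but returns None instead of raising
    # NotImplementedError when the model does not define an embedding getter.
    try:
        return self.get_input_embeddings()
    except NotImplementedError:
        return None
```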
Yes, it's a good idea to raise/inform for downstream users. I reverted a couple of things and will update the test so it actually checks that enabling GC works (probably by adding another test).
This is mostly to fix a broken-env situation that can occur around timm_wrapper (or timm_backbone?); it protects a few imports.
I reverted a few models to inner positional embedding calls, as mentioned in #38913. Modified a few other models as the test I added (…). Hopefully that helps VLMs + GC and does not break adapters.
try:
    input_embeddings = module.get_input_embeddings()
except NotImplementedError:
    continue
no simple way around this unfortunately
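In context, that try/except sits in a loop over submodules roughly like the following (a simplified sketch of the enabling path, following the existing `enable_input_require_grads` hook pattern, not the file's exact code):

```python
def make_inputs_require_grads(module, inputs, output):
    # Forward hook: mark the embedding output as requiring grad so gradients
    # can flow through checkpointed blocks back to the inputs.
    output.requires_grad_(True)

found_embeddings = False
for module in self.modules():
    # Plain nn.Module children don't define get_input_embeddings at all.
    if not hasattr(module, "get_input_embeddings"):
        continue
    try:
        input_embeddings = module.get_input_embeddings()
    except NotImplementedError:
        # Some wrappers (e.g. backbone-only models) never expose an embedding layer.
        continue
    if input_embeddings is not None:
        input_embeddings.register_forward_hook(make_inputs_require_grads)
        found_embeddings = True
```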
if not found_embeddings:
    logger.warning_once(
        f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
        "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
        "support those features."
    )
at least we can warn users!
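For downstream model authors who hit this warning, the usual fix is to point `get_input_embeddings` at whatever layer holds the token embeddings. A hypothetical example (the class, config fields, and `embed_tokens` attribute are placeholders):

```python
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class MyCustomModel(PreTrainedModel):  # hypothetical downstream model
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        # Assumes the config carries vocab_size / hidden_size.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

    def get_input_embeddings(self) -> nn.Module:
        # Exposing the embedding layer lets adapters / gradient checkpointing
        # hook it so gradients can flow back to the inputs.
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module):
        self.embed_tokens = value
```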
[For maintainers] Suggested jobs to run (before merge)
run-slow: align, altclip, clvp, falcon_mamba, fast_vlm, internvl, layoutlm, layoutlmv3, lilt, mamba, mlcd, poolformer, siglip, siglip2, splinter, switch_transformers
What does this PR do?
As the title indicates, #42542 and likely a few other models are broken by the merged #41993. This adds an embedding getter and attempts to test the feature with more coverage.
Basically, what it does should help GC with PEFT adapters for many VLMs (and normal models too).
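For reference, the downstream combination this is meant to keep working looks roughly like this (the checkpoint name and LoRA settings are placeholders):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some/causal-lm")  # placeholder checkpoint
model.gradient_checkpointing_enable()
# Needed with GC + adapters so the checkpointed blocks see inputs that require grad;
# internally this hooks model.get_input_embeddings(), which is what this PR makes robust.
model.enable_input_require_grads()

peft_model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))
```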