1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -33,6 +33,7 @@
is_librosa_available,
is_mistral_common_available,
is_mlx_available,
is_numba_available,
is_pretty_midi_available,
)

65 changes: 34 additions & 31 deletions src/transformers/modeling_utils.py
@@ -969,51 +969,51 @@ def get_input_embeddings(self) -> nn.Module:
`nn.Module`: A torch module mapping vocabulary to hidden states.
"""

# 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
# for most NLP models), and if so, return it.

name = getattr(self, "_input_embed_layer", "embed_tokens")

# 1) Direct attribute (most NLP models).
if (default_embedding := getattr(self, name, None)) is not None:
return default_embedding
# 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
# 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
return getattr(self.embeddings, name)
# 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
if hasattr(self, "model") and hasattr(self.model, name):
return getattr(self.model, name)

if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
return self.model.embed_tokens
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)
Comment on lines +984 to +986
Member: nit: self.base_model property has the same functionality
Contributor Author: true!

if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()

# 3) vanilla decoder‑only architectures
elif hasattr(self, "embed_tokens"):
return self.embed_tokens
else:
base_model = getattr(self, "base_model_prefix", None)
if base_model is not None:
base_model = getattr(self, base_model, None)
if base_model is not None and base_model is not self:
return base_model.get_input_embeddings()
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
"please override in the subclass."
)
raise NotImplementedError(
f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
)

def set_input_embeddings(self, value: nn.Module):
"""Fallback setter that handles **~70%** of models in the code-base.

Order of attempts:
1. `self.model.embed_tokens`
2. `self.embed_tokens`
3. delegate to the *base model* if one exists
4. otherwise raise `NotImplementedError` so subclasses still can (and
1. `self.<_input_embed_layer>` (direct attribute)
2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
3. `self.model.<_input_embed_layer>` (encoder/decoder models)
4. delegate to the *base model* if one exists
5. otherwise raise `NotImplementedError` so subclasses still can (and
should) override for exotic layouts.
"""

# 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
name = getattr(self, "_input_embed_layer", "embed_tokens")
if hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 2) as well as vanilla decoder‑only architectures
elif hasattr(self, name):
# 1) Direct attribute (most NLP models)
if hasattr(self, name):
setattr(self, name, value)
# 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
# 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
setattr(self.embeddings, name, value)
# 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
elif hasattr(self, "model") and hasattr(self.model, name):
setattr(self.model, name, value)
# 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
elif getattr(self, self.base_model_prefix, self) is not self:
base_model = getattr(self, self.base_model_prefix, self)
base_model.set_input_embeddings(value)
@@ -1983,9 +1983,12 @@ def make_inputs_require_grads(module, input, output):
if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
continue

input_embeddings = module.get_input_embeddings()
try:
input_embeddings = module.get_input_embeddings()
except NotImplementedError:
continue
Comment on lines +1987 to +1990
Contributor Author: no simple way around this unfortunately


if input_embeddings is None:
if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
continue

embedding_id = id(input_embeddings)
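For reviewers skimming the new fallback, here is a minimal standalone sketch of the lookup order introduced above. The resolve_input_embeddings and ToyVisionModel names are illustrative only and not part of the PR; the base-model recursion and error handling are omitted.

import torch.nn as nn


def resolve_input_embeddings(model: nn.Module):
    # Mirrors the new generic getter: the attribute name defaults to
    # "embed_tokens" unless the class overrides `_input_embed_layer`.
    name = getattr(model, "_input_embed_layer", "embed_tokens")
    # 1) Direct attribute (most NLP models).
    if (direct := getattr(model, name, None)) is not None:
        return direct
    # 2) Nested embeddings (vision/audio models, e.g. `embeddings.patch_embedding`).
    if hasattr(model, "embeddings") and hasattr(model.embeddings, name):
        return getattr(model.embeddings, name)
    # 3) Encoder/decoder wrappers exposing `self.model`.
    if hasattr(model, "model") and hasattr(model.model, name):
        return getattr(model.model, name)
    raise NotImplementedError("override get_input_embeddings in the subclass")


class ToyVisionEmbeddings(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, 8, kernel_size=4)


class ToyVisionModel(nn.Module):
    # A single class attribute replaces a per-model `get_input_embeddings` override.
    _input_embed_layer = "patch_embedding"

    def __init__(self):
        super().__init__()
        self.embeddings = ToyVisionEmbeddings()


print(resolve_input_embeddings(ToyVisionModel()))  # -> Conv2d(3, 8, kernel_size=(4, 4), ...)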
4 changes: 1 addition & 3 deletions src/transformers/models/align/modeling_align.py
@@ -976,6 +976,7 @@ class AlignVisionModel(AlignPreTrainedModel):
main_input_name = "pixel_values"
input_modalities = ("image",)
supports_gradient_checkpointing = False
_input_embed_layer = "convolution"

def __init__(self, config: AlignVisionConfig):
super().__init__(config)
@@ -994,9 +995,6 @@ def __init__(self, config: AlignVisionConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.convolution

@can_return_tuple
@auto_docstring
def forward(
2 changes: 1 addition & 1 deletion src/transformers/models/fast_vlm/modeling_fast_vlm.py
@@ -59,7 +59,7 @@ def forward(self, image_features):
class FastVlmPreTrainedModel(PreTrainedModel):
config: FastVlmConfig
base_model_prefix = "model"
input_modalities = ["image", "text"]
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

18 changes: 18 additions & 0 deletions src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -883,6 +883,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
@@ -982,6 +988,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
@@ -1101,6 +1113,12 @@ def __init__(self, config):

self.post_init()

def get_input_embeddings(self):
return self.layoutlmv3.get_input_embeddings()

def set_input_embeddings(self, value):
self.layoutlmv3.set_input_embeddings(value)

@auto_docstring
def forward(
self,
12 changes: 11 additions & 1 deletion src/transformers/models/poolformer/modeling_poolformer.py
@@ -268,7 +268,11 @@ def __init__(self, config):
self.post_init()

def get_input_embeddings(self):
return self.embeddings.patch_embeddings
# Input embeddings correspond to the very first patch-embedding stage.
return self.encoder.patch_embeddings[0]

def set_input_embeddings(self, value):
self.encoder.patch_embeddings[0] = value

@auto_docstring
def forward(
@@ -332,6 +336,12 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.poolformer.get_input_embeddings()

def set_input_embeddings(self, value):
self.poolformer.set_input_embeddings(value)

@auto_docstring
def forward(
self,
19 changes: 17 additions & 2 deletions src/transformers/models/siglip/modeling_siglip.py
@@ -502,9 +502,11 @@ def forward(
return BaseModelOutput(last_hidden_state=hidden_states)


class SiglipTextTransformer(nn.Module):
class SiglipTextTransformer(SiglipPreTrainedModel):
_input_embed_layer = "token_embedding"

def __init__(self, config: SiglipTextConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipTextEmbeddings(config)
@@ -614,6 +616,7 @@ def forward(


class SiglipVisionTransformer(SiglipPreTrainedModel):
_input_embed_layer = "patch_embedding"
_can_record_outputs = {
"hidden_states": SiglipEncoderLayer,
"attentions": SiglipAttention,
@@ -774,6 +777,12 @@ def __init__(self, config: SiglipConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding

def set_input_embeddings(self, value: nn.Module):
self.text_model.embeddings.token_embedding = value

@filter_out_non_signature_kwargs()
@auto_docstring
def get_text_features(
@@ -969,6 +978,12 @@ def __init__(self, config: SiglipConfig) -> None:
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

def set_input_embeddings(self, value: nn.Module):
self.vision_model.embeddings.patch_embedding = value

@check_model_inputs
@auto_docstring
def forward(
19 changes: 17 additions & 2 deletions src/transformers/models/siglip2/modeling_siglip2.py
@@ -484,6 +484,7 @@ def forward(


class Siglip2VisionTransformer(Siglip2PreTrainedModel):
_input_embed_layer = "patch_embedding"
_can_record_outputs = {
"hidden_states": Siglip2EncoderLayer,
"attentions": Siglip2Attention,
@@ -588,9 +589,11 @@ def forward(
return embeddings


class Siglip2TextTransformer(nn.Module):
class Siglip2TextTransformer(Siglip2PreTrainedModel):
_input_embed_layer = "token_embedding"

def __init__(self, config: Siglip2TextConfig):
super().__init__()
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2TextEmbeddings(config)
@@ -831,6 +834,12 @@ def __init__(self, config: Siglip2Config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding

def set_input_embeddings(self, value: nn.Module):
self.text_model.embeddings.token_embedding = value

@filter_out_non_signature_kwargs()
@auto_docstring
def get_text_features(
@@ -1048,6 +1057,12 @@ def __init__(self, config: Siglip2Config) -> None:
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

def set_input_embeddings(self, value: nn.Module):
self.vision_model.embeddings.patch_embedding = value

@check_model_inputs
@auto_docstring
def forward(
@@ -913,6 +913,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
_input_embed_layer = "shared"

def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
@@ -932,9 +933,6 @@ def __init__(self, config: SwitchTransformersConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
@@ -669,6 +669,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
_input_embed_layer = "shared"

def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
@@ -688,9 +689,6 @@ def __init__(self, config: SwitchTransformersConfig):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
14 changes: 7 additions & 7 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -136,6 +136,13 @@ def _timm_model_supports_gradient_checkpointing(self):
def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
self.timm_model.set_grad_checkpointing(enable)

def get_input_embeddings(self):
# TIMM backbones operate directly on images and do not expose token embeddings.
return None

def set_input_embeddings(self, value):
raise NotImplementedError("TimmWrapper models do not own token embeddings and cannot set them.")


class TimmWrapperModel(TimmWrapperPreTrainedModel):
"""
@@ -150,13 +157,6 @@ def __init__(self, config: TimmWrapperConfig):
self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
self.post_init()

def get_input_embeddings(self):
# Vision backbones from timm do not expose token embeddings, so there is nothing to return.
return None

def set_input_embeddings(self, value):
raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")

@auto_docstring
def forward(
self,
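Because the getter can now legitimately return None (TIMM backbones) or raise NotImplementedError (models with no generic layout), downstream utilities should probe defensively. A small caller-side sketch of that pattern, mirroring the hook-registration change earlier in this PR; the helper name is illustrative only.

def input_embeddings_or_none(module):
    """Return the model's input embedding module if it exposes one, else None."""
    try:
        embeddings = module.get_input_embeddings()
    except NotImplementedError:
        # The model neither matches the generic fallback nor overrides the getter.
        return None
    # TimmWrapper* models return None because timm backbones consume raw pixels;
    # anything without forward hooks is not usable as an embedding layer here.
    if embeddings is None or not hasattr(embeddings, "register_forward_hook"):
        return None
    return embeddings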
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -118,6 +118,7 @@
is_mistral_common_available,
is_natten_available,
is_nltk_available,
is_numba_available,
is_onnx_available,
is_openai_available,
is_optimum_available,
@@ -1386,6 +1387,13 @@ def require_pyctcdecode(test_case):
return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)


def require_numba(test_case):
"""
Decorator marking a test that requires numba
"""
return unittest.skipUnless(is_numba_available(), "test requires numba")(test_case)


def require_librosa(test_case):
"""
Decorator marking a test that requires librosa
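For completeness, the new marker is used like the other requirement decorators in the test suite. A hypothetical test sketch; the class name, test name, and body are placeholders, not tests from this PR.

import unittest

from transformers.testing_utils import require_numba


class NumbaDependentTests(unittest.TestCase):
    @require_numba
    def test_numba_is_importable(self):
        # Skipped with "test requires numba" when numba is not installed.
        import numba  # noqa: F401

        self.assertTrue(hasattr(numba, "njit"))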
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -169,6 +169,7 @@
is_ninja_available,
is_nltk_available,
is_num2words_available,
is_numba_available,
is_onnx_available,
is_openai_available,
is_optimum_available,