Merged
56 changes: 14 additions & 42 deletions src/transformers/models/clip/modeling_clip.py
@@ -22,7 +22,7 @@
 from torch import nn

 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -310,7 +310,6 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        causal_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -324,15 +323,6 @@
         queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
-        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
-        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
-        if self.config._attn_implementation == "flash_attention_2":
-            self.is_causal = causal_attention_mask is not None
-        else:
-            if attention_mask is not None and causal_attention_mask is not None:
-                attention_mask = attention_mask + causal_attention_mask
-            elif causal_attention_mask is not None:
-                attention_mask = causal_attention_mask

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -344,13 +334,12 @@
             keys,
             values,
             attention_mask,
-            is_causal=self.is_causal,
             scaling=self.scale,
             dropout=0.0 if not self.training else self.dropout,
             **kwargs,
         )

-        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
[Collaborator comment on the line above] usually we use -1 for the batch size as text can be ragged, but not an issue. (A short sketch of the two reshape styles follows this hunk.)
         attn_output = self.out_proj(attn_output)

         return attn_output, attn_weights
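A minimal, self-contained sketch of the two reshape styles mentioned in the comment above, with toy tensor sizes that are not part of the PR; either way exactly one dimension is left for PyTorch to infer, so the results match:

```python
import torch

# Toy sizes for illustration only.
batch_size, seq_length, num_heads, head_dim = 2, 5, 8, 64
embed_dim = num_heads * head_dim

# (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim), as in the attention forward above.
attn_output = torch.randn(batch_size, num_heads, seq_length, head_dim).transpose(1, 2)

flat_heads = attn_output.reshape(batch_size, seq_length, -1)  # infer the flattened head dim (this PR)
flat_batch = attn_output.reshape(-1, seq_length, embed_dim)   # infer the batch dim (the reviewer's usual pattern)
assert flat_heads.shape == flat_batch.shape == (batch_size, seq_length, embed_dim)
```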
@@ -384,16 +373,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-        causal_attention_mask: torch.Tensor,
         **kwargs: Unpack[TransformersKwargs],
     ) -> torch.FloatTensor:
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            causal_attention_mask=causal_attention_mask,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -497,7 +484,6 @@ def forward(
         self,
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
-        causal_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutput:
         r"""
@@ -512,21 +498,13 @@
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.

                 [What are attention masks?](../glossary#attention-mask)
-            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Causal mask for the text model. Mask values selected in `[0, 1]`:
-
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
-
-                [What are attention masks?](../glossary#attention-mask)
         """
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
-                causal_attention_mask,
                 **kwargs,
             )

@@ -563,17 +541,19 @@ def forward(

         hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

-        causal_attention_mask = _create_4d_causal_attention_mask(
-            input_shape, hidden_states.dtype, device=hidden_states.device
+        attention_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+            cache_position=torch.arange(hidden_states.shape[1], device=hidden_states.device),
+            past_key_values=None,
         )

-        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
+        kwargs.pop("is_causal", None)
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
             attention_mask=attention_mask,
-            causal_attention_mask=causal_attention_mask,
+            is_causal=True,
             **kwargs,
         )

@@ -618,7 +598,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
     input_modalities = "text"

     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
-    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

     def __init__(self, config: CLIPTextConfig):
         super().__init__(config)
@@ -632,8 +611,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -726,7 +704,6 @@ def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

     @check_model_inputs(tie_last_hidden_states=False)
-    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -766,7 +743,6 @@ def forward(
 class CLIPModel(CLIPPreTrainedModel):
     config: CLIPConfig
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
-    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

     def __init__(self, config: CLIPConfig):
         super().__init__(config)
@@ -966,7 +942,6 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
     config: CLIPTextConfig
     input_modalities = "text"

-    _supports_flash_attn = False
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

     def __init__(self, config: CLIPTextConfig):
@@ -986,8 +961,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -1049,7 +1023,6 @@ def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

     @check_model_inputs(tie_last_hidden_states=False)
-    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1117,8 +1090,7 @@ def __init__(self, config: CLIPConfig) -> None:
         # Initialize weights and apply final processing
         self.post_init()

-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
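For reading the masking hunk above out of context, here is a minimal sketch of the new text-side mask construction; the `create_causal_mask` keywords mirror the call shown in the diff, while the standalone config, toy sizes, and random hidden states are invented for illustration and are not part of the PR:

```python
import torch
from transformers import CLIPTextConfig
from transformers.masking_utils import create_causal_mask

config = CLIPTextConfig()  # default text config; only used to pick the mask style (eager/sdpa/...)

batch_size, seq_length = 2, 7
# Stand-in for the token embeddings produced by CLIPTextEmbeddings.
hidden_states = torch.randn(batch_size, seq_length, config.hidden_size)

# Same keyword arguments as the call in CLIPTextTransformer.forward after this change.
causal_mask = create_causal_mask(
    config=config,
    input_embeds=hidden_states,
    attention_mask=torch.ones(batch_size, seq_length, dtype=torch.long),
    cache_position=torch.arange(seq_length),
    past_key_values=None,
)

# Eager attention gets a 4D float mask; sdpa/flash backends may get None and rely on
# the is_causal flag that the text transformer now forwards to the encoder.
print(None if causal_mask is None else causal_mask.shape)
```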