56 changes: 17 additions & 39 deletions src/transformers/models/clip/modeling_clip.py
@@ -22,7 +22,7 @@
from torch import nn

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
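Note: the single `create_causal_mask` helper replaces the pair of mask utilities that this file previously combined by hand. As a rough, illustrative sketch (toy code, not the library implementation), the merged additive bias it stands in for behaves roughly like this:

```python
# Rough sketch, not the library code: how a combined causal + padding mask behaves.
# The replaced helpers built the two pieces separately; the new helper returns a
# single additive bias that accounts for both.
import torch

def toy_combined_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Illustrative only: merge a [batch, seq] padding mask with a causal mask."""
    seq_len = padding_mask.shape[1]
    neg_inf = torch.finfo(dtype).min
    # future (upper-triangular) positions get a large negative bias
    causal = torch.triu(torch.full((seq_len, seq_len), neg_inf, dtype=dtype), diagonal=1)
    # padded key positions (mask value 0) also get a large negative bias
    padding = (1.0 - padding_mask[:, None, None, :].to(dtype)) * neg_inf
    return causal[None, None, :, :] + padding  # shape [batch, 1, seq_len, seq_len]
```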
@@ -303,8 +303,8 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

@@ -317,15 +317,6 @@
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
if self.config._attn_implementation == "flash_attention_2":
self.is_causal = causal_attention_mask is not None
else:
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -337,10 +328,9 @@
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
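The `attention_interface` selection above follows the registry pattern used throughout transformers: the eager kernel is the default, and other backends are looked up by the configured implementation name. A simplified sketch of the idea (abbreviated names, not the exact library code):

```python
# Simplified sketch of the backend dispatch (abbreviated names, not the exact code).
from typing import Callable

def select_attention_interface(attn_implementation: str,
                               registry: dict[str, Callable],
                               eager_fn: Callable) -> Callable:
    """Default to the eager kernel; otherwise look the backend up by name."""
    if attn_implementation == "eager":
        return eager_fn
    # e.g. "sdpa", "flash_attention_2", "flex_attention" map to registered callables
    return registry[attn_implementation]
```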
@@ -379,8 +369,8 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.FloatTensor]:
"""
Args:
@@ -398,8 +388,8 @@
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -425,7 +415,7 @@ class CLIPPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
_supports_attention_backend = True
_supports_attention_backend = False # kwargs are not supported throughout all modules

def _init_weights(self, module):
"""Initialize the weights"""
@@ -503,9 +493,9 @@ def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseModelOutput:
r"""
Args:
@@ -519,13 +509,6 @@
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -545,14 +528,15 @@
all_attentions = () if output_attentions else None

hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = layer_outputs[0]
@@ -604,23 +588,20 @@ def forward(

hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _create_4d_causal_attention_mask(
input_shape, hidden_states.dtype, device=hidden_states.device
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=torch.arange(hidden_states.shape[1], device=hidden_states.device),
past_key_values=None,
)

# expand attention_mask
if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
is_causal=True,
)

last_hidden_state = encoder_outputs.last_hidden_state
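With the explicit `causal_attention_mask` argument gone, per-call flags such as the `is_causal=True` passed to the encoder above reach the attention kernel only through the `**kwargs` chain (encoder → layer → attention). A toy illustration of that propagation, with made-up function names:

```python
# Toy illustration (made-up names): a keyword passed once at the top level is
# forwarded untouched through **kwargs until the innermost attention call sees it.
def toy_attention(q, k, v, attention_mask=None, **kwargs):
    return kwargs  # e.g. {'is_causal': True}

def toy_layer(hidden_states, attention_mask=None, **kwargs):
    return toy_attention(hidden_states, hidden_states, hidden_states, attention_mask, **kwargs)

def toy_encoder(hidden_states, attention_mask=None, **kwargs):
    return toy_layer(hidden_states, attention_mask, **kwargs)

assert toy_encoder(None, None, is_causal=True) == {"is_causal": True}
```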
@@ -666,7 +647,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
input_modalities = "text"

_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
_supports_flash_attn = False # mask creation only accounts for sdpa/eager

def __init__(self, config: CLIPTextConfig):
super().__init__(config)
@@ -825,7 +805,6 @@ def forward(
class CLIPModel(CLIPPreTrainedModel):
config: CLIPConfig
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
_supports_flash_attn = False # mask creation only accounts for sdpa/eager

def __init__(self, config: CLIPConfig):
super().__init__(config)
@@ -1034,7 +1013,6 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
config: CLIPTextConfig
input_modalities = "text"

_supports_flash_attn = False
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

def __init__(self, config: CLIPTextConfig):
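Since `_supports_flash_attn = False` is dropped from the text-side classes, the whole model can now be requested with a non-eager backend at load time. A usage sketch (assumes flash-attn is installed and a CUDA device is available; the checkpoint name is just an example):

```python
# Usage sketch, assuming flash-attn is installed and a CUDA device is available.
import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",            # example checkpoint
    attn_implementation="flash_attention_2",   # previously rejected for the text tower
    torch_dtype=torch.float16,                 # FA2 requires fp16/bf16
).to("cuda")
```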