@@ -22,7 +22,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -310,7 +310,6 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        causal_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -324,15 +323,6 @@ def forward(
         queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
-        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
-        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
-        if self.config._attn_implementation == "flash_attention_2":
-            self.is_causal = causal_attention_mask is not None
-        else:
-            if attention_mask is not None and causal_attention_mask is not None:
-                attention_mask = attention_mask + causal_attention_mask
-            elif causal_attention_mask is not None:
-                attention_mask = causal_attention_mask
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -344,13 +334,12 @@ def forward(
             keys,
             values,
             attention_mask,
-            is_causal=self.is_causal,
             scaling=self.scale,
             dropout=0.0 if not self.training else self.dropout,
             **kwargs,
         )
 
-        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
         attn_output = self.out_proj(attn_output)
 
         return attn_output, attn_weights
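
The branch removed above merged two additive (0 / large-negative) masks by plain addition, with a special case for the flash-attention backend; that work now moves into the shared mask utilities. For reference, a self-contained sketch of the old merge in plain torch (the helper name and shapes are illustrative, not taken from the modeling file):

import torch

def merged_additive_mask(padding_mask, dtype=torch.float32):
    # padding_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    seq_len = padding_mask.shape[1]
    min_val = torch.finfo(dtype).min

    # Additive causal mask: 0 on and below the diagonal, a large negative value above it.
    causal = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :]                                  # (1, 1, L, L)

    # Expand the 2D padding mask to the same additive 4D form.
    padding = (1.0 - padding_mask[:, None, None, :].to(dtype)) * min_val

    # The old CLIP attention simply summed the two additive masks.
    return causal + padding

mask = merged_additive_mask(torch.tensor([[1, 1, 1, 0]]))
print(mask.shape)  # torch.Size([1, 1, 4, 4])
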
@@ -384,16 +373,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-        causal_attention_mask: torch.Tensor,
         **kwargs: Unpack[TransformersKwargs],
     ) -> torch.FloatTensor:
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            causal_attention_mask=causal_attention_mask,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -497,7 +484,6 @@ def forward(
         self,
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
-        causal_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutput:
         r"""
@@ -512,21 +498,13 @@ def forward(
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
 
-                [What are attention masks?](../glossary#attention-mask)
-            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Causal mask for the text model. Mask values selected in `[0, 1]`:
-
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
         """
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
-                causal_attention_mask,
                 **kwargs,
             )
 
@@ -563,17 +541,19 @@ def forward(
 
         hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
 
-        causal_attention_mask = _create_4d_causal_attention_mask(
-            input_shape, hidden_states.dtype, device=hidden_states.device
+        attention_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+            cache_position=torch.arange(hidden_states.shape[1], device=hidden_states.device),
+            past_key_values=None,
         )
 
-        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
+        kwargs.pop("is_causal", None)
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
             attention_mask=attention_mask,
-            causal_attention_mask=causal_attention_mask,
+            is_causal=True,
             **kwargs,
         )
 
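
The `create_causal_mask` call added above can be exercised on its own to see what the encoder now receives. A rough sketch, assuming the utility can be invoked standalone with exactly the keywords shown in the diff (toy shapes, default text config):

import torch
from transformers import CLIPTextConfig
from transformers.masking_utils import create_causal_mask

config = CLIPTextConfig()                              # default text config, eager attention
batch, seq_len = 2, 7
hidden_states = torch.randn(batch, seq_len, config.hidden_size)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask[1, 5:] = 0                              # pad the tail of the second sequence

mask = create_causal_mask(
    config=config,
    input_embeds=hidden_states,
    attention_mask=attention_mask,
    cache_position=torch.arange(seq_len),
    past_key_values=None,                              # the CLIP text encoder has no KV cache
)
# With eager/sdpa this is a 4D mask combining causality and padding; with
# flash-attention backends it may come back as None, causality being conveyed
# instead through the `is_causal=True` kwarg forced in the transformer above.
print(None if mask is None else mask.shape)
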
@@ -618,7 +598,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
     input_modalities = "text"
 
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
-    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager
 
     def __init__(self, config: CLIPTextConfig):
         super().__init__(config)
@@ -632,8 +611,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value
 
-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -726,7 +704,6 @@ def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
 
     @check_model_inputs(tie_last_hidden_states=False)
-    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -766,7 +743,6 @@ def forward(
 class CLIPModel(CLIPPreTrainedModel):
     config: CLIPConfig
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
-    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager
 
     def __init__(self, config: CLIPConfig):
         super().__init__(config)
@@ -966,7 +942,6 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
     config: CLIPTextConfig
     input_modalities = "text"
 
-    _supports_flash_attn = False
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
 
     def __init__(self, config: CLIPTextConfig):
@@ -986,8 +961,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value
 
-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -1049,7 +1023,6 @@ def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
 
     @check_model_inputs(tie_last_hidden_states=False)
-    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1117,8 +1090,7 @@ def __init__(self, config: CLIPConfig) -> None:
         # Initialize weights and apply final processing
         self.post_init()
 
-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
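
With the `_supports_flash_attn = False` overrides dropped above, the CLIP text classes accept a flash-attention backend at load time. A minimal usage sketch, assuming a CUDA device, the flash-attn package, and a transformers build that contains this change (the checkpoint is just the usual public one):

import torch
from transformers import AutoTokenizer, CLIPTextModel

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPTextModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    attn_implementation="flash_attention_2",   # previously rejected for the text model
    torch_dtype=torch.float16,
).to("cuda")

inputs = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
).to("cuda")
with torch.no_grad():
    pooled = model(**inputs).pooler_output     # (batch, hidden_size)
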