
Commit 7a833d1

🚨 [Clip] Fix masking and enable flash attention on all model types (#41750)
* fix
* make kwargs fully passed and adjust outputs accordingly
* propagate to MetaCLIP 2
* propagate to MLCD and fix test
* style
* fix repo consistency; need to add ignore rules as those are building blocks
* style
* oops
* fix MLCD
1 parent 8bde822 commit 7a833d1
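
The user-visible effect of this commit is that the CLIP text tower (and the models that copy from it, MetaCLIP 2 and MLCD) no longer opts out of FlashAttention via `_supports_flash_attn = False`. A minimal sketch of requesting it, assuming a CUDA device with the flash-attn package installed; the checkpoint name and example inputs are illustrative and not part of the diff:

```python
# Minimal sketch, assuming a CUDA device with flash-attn installed.
# The checkpoint name and inputs are illustrative, not part of this commit.
import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    attn_implementation="flash_attention_2",  # previously rejected for the text tower
    torch_dtype=torch.float16,
).to("cuda")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
).to("cuda")
with torch.no_grad():
    text_features = model.get_text_features(**inputs)
```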

File tree

6 files changed: +180 additions, -403 deletions


src/transformers/models/clip/modeling_clip.py

Lines changed: 14 additions & 42 deletions
@@ -22,7 +22,7 @@
 from torch import nn

 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -310,7 +310,6 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        causal_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -324,15 +323,6 @@ def forward(
         queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
-        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
-        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
-        if self.config._attn_implementation == "flash_attention_2":
-            self.is_causal = causal_attention_mask is not None
-        else:
-            if attention_mask is not None and causal_attention_mask is not None:
-                attention_mask = attention_mask + causal_attention_mask
-            elif causal_attention_mask is not None:
-                attention_mask = causal_attention_mask

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -344,13 +334,12 @@ def forward(
             keys,
             values,
             attention_mask,
-            is_causal=self.is_causal,
             scaling=self.scale,
             dropout=0.0 if not self.training else self.dropout,
             **kwargs,
         )

-        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
         attn_output = self.out_proj(attn_output)

         return attn_output, attn_weights
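
The branch deleted above used to merge `attention_mask` and `causal_attention_mask` additively inside every attention layer (or, for FlashAttention 2, only infer `is_causal`). After this commit the combined mask is built once, upstream in the text model, and the layer consumes it as-is. For readers unfamiliar with the additive convention, a rough illustration (not code from the repository; `create_causal_mask` handles dtype, batching, and min-value details internally):

```python
# Illustrative only: build the kind of additive causal mask the eager/SDPA paths consume.
import torch

seq_len, dtype = 4, torch.float32
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
additive = torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill(causal, torch.finfo(dtype).min)
# additive[i, j] == 0 where position i may attend to position j, and a large negative
# value otherwise; padding positions from `attention_mask` are folded into the same
# 4D tensor before it reaches the attention layers.
```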
@@ -384,16 +373,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-        causal_attention_mask: torch.Tensor,
         **kwargs: Unpack[TransformersKwargs],
     ) -> torch.FloatTensor:
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            causal_attention_mask=causal_attention_mask,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -497,7 +484,6 @@ def forward(
         self,
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
-        causal_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutput:
         r"""
@@ -512,21 +498,13 @@
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.

-                [What are attention masks?](../glossary#attention-mask)
-            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Causal mask for the text model. Mask values selected in `[0, 1]`:
-
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
         """
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
-                causal_attention_mask,
                 **kwargs,
             )

@@ -563,17 +541,19 @@ def forward(

         hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

-        causal_attention_mask = _create_4d_causal_attention_mask(
-            input_shape, hidden_states.dtype, device=hidden_states.device
+        attention_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+            cache_position=torch.arange(hidden_states.shape[1], device=hidden_states.device),
+            past_key_values=None,
         )

-        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
+        kwargs.pop("is_causal", None)
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
             attention_mask=attention_mask,
-            causal_attention_mask=causal_attention_mask,
+            is_causal=True,
             **kwargs,
         )

@@ -618,7 +598,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
     input_modalities = "text"

     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
-    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

     def __init__(self, config: CLIPTextConfig):
         super().__init__(config)
@@ -632,8 +611,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -726,7 +704,6 @@ def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

     @check_model_inputs(tie_last_hidden_states=False)
-    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -766,7 +743,6 @@ def forward(
 class CLIPModel(CLIPPreTrainedModel):
     config: CLIPConfig
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
-    _supports_flash_attn = False  # mask creation only accounts for sdpa/eager

     def __init__(self, config: CLIPConfig):
         super().__init__(config)
@@ -966,7 +942,6 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
     config: CLIPTextConfig
     input_modalities = "text"

-    _supports_flash_attn = False
     _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

     def __init__(self, config: CLIPTextConfig):
@@ -986,8 +961,7 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value

-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -1049,7 +1023,6 @@ def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

     @check_model_inputs(tie_last_hidden_states=False)
-    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1117,8 +1090,7 @@ def __init__(self, config: CLIPConfig) -> None:
         # Initialize weights and apply final processing
         self.post_init()

-    @check_model_inputs()
-    @can_return_tuple
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
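
Since the class-level `_supports_flash_attn = False` flags are gone, a quick way to confirm which backend a loaded checkpoint ended up with (a debugging sketch; `_attn_implementation` is a private attribute, so treat this as an informal check):

```python
# Debugging sketch: confirm the requested attention backend was applied.
from transformers import CLIPTextModelWithProjection

model = CLIPTextModelWithProjection.from_pretrained(
    "openai/clip-vit-base-patch32", attn_implementation="sdpa"
)
print(model.config._attn_implementation)  # expected: "sdpa"
```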
