Add SDPA and FlashAttention support to T5 #42453
base: main
Conversation
vasqu left a comment
Sorry to be so strict about this, but T5 is not a good candidate for flash attention / SDPA. The reason is that the relative attention bias has to be modeled inside the attention computation, and as of now that's not possible with base flash attention (it might be possible with SDPA, but that needs proper mask preparation). tl;dr: it will only support eager attention in the end.
We can still refactor this to use the attention-interface-like implementation, but only for eager in the end (i.e. _supports_sdpa/flash_attn remain False). Wdyt?
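For context, the incompatibility described here comes down to where the bias enters the computation. A minimal sketch (function name and shapes are illustrative, not the PR's code):

```python
import torch
from torch import nn


def t5_attention_scores(query, key, position_bias):
    # query/key: (batch, heads, seq_len, head_dim); position_bias: (batch or 1, heads, q_len, k_len).
    # Eager attention can simply add T5's learned relative-position bias to the raw score
    # matrix before the softmax. Flash attention kernels never materialize this matrix and
    # only accept causal/padding masks, so the bias cannot be injected there.
    scores = torch.matmul(query, key.transpose(3, 2))  # T5 applies no 1/sqrt(d) scaling here
    scores = scores + position_bias
    return nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
```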
Sounds reasonable to me!

Hey again @vasqu, I made the changes to restrict to eager attention only. Model tests are passing; only the repo consistency checks fail, as I mentioned above. The PR is ready for merge 😊
[For maintainers] Suggested jobs to run (before merge)

run-slow: t5
vasqu left a comment
Some initial comments. It would be nice if we could go further and include the recorder, avoiding the unnecessary code around output_xxx.
```python
    return hidden_states


def eager_attention_forward(
```
I would rather have the relative position bias within here, see #38301 or, more specifically, transformers/src/transformers/models/bert/modeling_bert.py, lines 121 to 176 at 1c3188f:
```python
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3))

    # Relative positional embeddings
    if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
        query_length, key_length = query.shape[2], key.shape[2]
        if use_cache:
            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
        else:
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
        position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
        distance = position_ids_l - position_ids_r

        positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility

        if module.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
            attn_weights = attn_weights + relative_position_scores
        elif module.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
            attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key

    # Scaling is shifted in case of embeddings being relative
    attn_weights = attn_weights * scaling

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
```
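For comparison, here is a rough sketch of what the T5 analogue could look like with the bias applied inside the helper; argument names and the mask handling are assumptions, not this PR's final code:

```python
from typing import Optional

import torch
from torch import nn


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    position_bias: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
    **kwargs,
):
    # T5 folds the 1/sqrt(d) factor into its weight initialization, so no explicit scaling here.
    attn_weights = torch.matmul(query, key.transpose(3, 2))
    if position_bias is not None:
        attn_weights = attn_weights + position_bias
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```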
| "when creating this class." | ||
| ) | ||
|
|
||
| self.scaling = self.d_model**-0.5 |
For completeness, we should have the is_causal flag here; you can look into Bart for this, i.e. encoder = False, decoder = True for self-attn and False for cross-attn.
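A small sketch of how that flag could be wired, following the Bart pattern; the extra constructor argument is an assumption about how this PR might do it:

```python
from typing import Optional

from torch import nn


class T5Attention(nn.Module):
    # Sketch only: everything except the is_causal wiring is elided.
    def __init__(
        self,
        config,
        has_relative_attention_bias: bool = False,
        layer_idx: Optional[int] = None,
        is_causal: bool = False,
    ):
        super().__init__()
        # Encoder self-attention  -> is_causal=False
        # Decoder self-attention  -> is_causal=True
        # Decoder cross-attention -> is_causal=False
        self.is_causal = is_causal
```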
```python
        if self.config._attn_implementation != "eager":
            logger.warning_once(
                "T5 uses relative position bias; SDPA/FlashAttention not supported, fall back to eager."
            )
```
This should never happen since we don't support anything other than eager. If anything, I would even raise an error here.
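A sketch of that stricter variant; the helper name and error message are just illustrative, not part of the PR:

```python
def _check_attn_implementation(config) -> None:
    # Fail fast instead of silently falling back when a non-eager backend is requested.
    if config._attn_implementation != "eager":
        raise ValueError(
            "T5 models the relative position bias inside the attention scores, "
            "so only the 'eager' attention implementation is supported."
        )
```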
```python
        hidden_states: torch.FloatTensor,
        key_value_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_bias: Optional[torch.FloatTensor] = None,
        query_length: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
```
Let's not rename here, this would break BC. The type annotations are fine by themselves.
```python
    config: T5Config
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _supports_attention_backend = True
```
Not supported: kwargs are not used everywhere so far, and the encoder-decoder path will need another look.
```python
        attn_output = attn_output.view(batch_size, -1, self.inner_dim).contiguous()
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)
```
It would be nice if we could refactor this as well in this PR. We have an OutputRecorder which can handle collecting the weights, so we no longer need to explicitly pass these kwargs. You can take a look at other models like Llama or t5gemma2, which do this. In essence, you need the decorators (check_model_input, can_return_tuple) and the respective flag _can_record_outputs.
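For reference, a rough sketch of that recorder pattern as it appears in Llama/T5Gemma-style models; the import path, the bare decorator usage, and the index argument are assumptions and should be checked against the current utils before copying:

```python
# Inside modeling_t5.py, where T5Block / T5Attention / T5Stack already exist.
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.generic import OutputRecorder, check_model_inputs


class T5PreTrainedModel(PreTrainedModel):
    # Maps output fields to the submodules that produce them, so hidden states and
    # attention weights are recorded automatically instead of being threaded through
    # output_hidden_states / output_attentions kwargs.
    _can_record_outputs = {
        "hidden_states": T5Block,
        "attentions": OutputRecorder(T5Attention, index=1),  # position of attn_weights in the return tuple
    }


class T5Stack(T5PreTrainedModel):
    @check_model_inputs
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # existing stack logic; the recorder collects per-layer outputs as the modules run
        ...
```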
I made some changes to the T5 modeling file to support the new attention interface. I rearranged things a bit to fold `position_bias` correctly into the attention mask. Fixes #26350

A note though: I ran `make fix-copies`, but it broke several related models such as `longt5` and `mt5`. Somehow the fix script didn't copy over the imports and couldn't grab the attention code correctly, so I skipped that part. If applicable, we can merge this PR and I can work on the related models in another PR, or I'm happy to take some hints to make the script work properly.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
@ArthurZucker @Cyrilvallez @vasqu