Description
Follow-up to #41668.
System Info
- `transformers` version: 4.57.0
- Platform: Linux-6.14.0-33-generic-x86_64-with-glibc2.39
- Python version: 3.12.3
- Huggingface_hub version: 1.0.0.rc4
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
```python
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
import torch
import os

model_path = os.getenv('MODEL_PATH')
print(model_path)

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_path,
    variant="fp16",
    torch_dtype=torch.float16
).to("cuda")
pipe.fuse_qkv_projections()

# Switch both CLIP text encoders to the flash-attention-3 implementation
pipe.text_encoder.config._attn_implementation = "flash_attention_3"
pipe.text_encoder_2.config._attn_implementation = "flash_attention_3"

prompt = ["A city at night with people walking around."] * 8
image = pipe(prompt, num_inference_steps=1).images[0]
```
Expected behavior
When using the default SDPA attention implementation with CLIP, `is_causal` is set to `True`, but `create_causal_mask` still generates a non-null attention mask. Because the mask is not `None`, the SDPA integration resets `is_causal` to `False`, so SDPA cannot lower into its own flash-attention kernel:

is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
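A minimal sketch of that condition (shapes are assumptions roughly matching the CLIP text encoder in the repro, not the actual transformers code path):

```python
import torch
import torch.nn.functional as F

# Assumed shapes: batch 16, 12 heads, sequence length 77, head dim 64 (CLIP text encoder).
q = k = v = torch.randn(16, 12, 77, 64)

# CLIP passes an explicit additive causal mask built by create_causal_mask,
# so attention_mask is not None even though the attention really is causal.
attention_mask = torch.full((77, 77), float("-inf")).triu(1)[None, None].expand(16, 1, 77, 77)

# The quoted condition: a non-null mask forces is_causal to False.
is_causal = q.shape[2] > 1 and attention_mask is None and True
print(is_causal)  # False

# SDPA is therefore called with an explicit mask instead of is_causal=True,
# which makes it ineligible for the flash-attention backend.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, is_causal=is_causal)
```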
When the attention implementation is set to FA3, the following error is thrown:
AssertionError: expected size 16==16, stride 177408==59136 at dim=0; expected size 77==77, stride 2304==768 at dim=1 Error in op: torch.ops.flash_attn.flash_attn_func.default
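The sizes and strides in the assertion are consistent with `query`/`key`/`value` being non-contiguous views into the fused QKV buffer produced by `fuse_qkv_projections()`. A minimal sketch of how those strides arise (shapes inferred from the error message, not taken from the actual model code):

```python
import torch

# Inferred shapes: batch 16, sequence length 77, hidden size 768, fused QKV width 3 * 768 = 2304.
B, S, D, H, Dh = 16, 77, 768, 12, 64
qkv = torch.randn(B, S, 3 * D)    # output of a fused QKV projection

q, k, v = qkv.chunk(3, dim=-1)    # views into the fused buffer, not copies
q = q.view(B, S, H, Dh)           # (batch, seq, heads, head_dim) layout passed to flash_attn_func

print(q.is_contiguous())          # False
print(q.stride())                 # (177408, 2304, 64, 1) -- the strides reported in the error
print(q.contiguous().stride())    # (59136, 768, 64, 1)   -- the strides the op expects
```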
Calling `.contiguous()` on q, k, and v resolves the error (a sketch of this change follows the quoted lines below):
transformers/src/transformers/integrations/flash_attention.py
Lines 56 to 58 in 37d48bb
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
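For illustration, the change would look roughly like this (a sketch against the quoted lines, not an actual patch):

```python
# transformers/src/transformers/integrations/flash_attention.py (sketch of the proposed fix)
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
```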
This is similar to what is done in `flash_paged`:
q.transpose(1, 2).squeeze(0).contiguous(),
and to what is done before calling `flash_attn_varlen_func`:
transformers/src/transformers/modeling_flash_attention_utils.py
Lines 389 to 391 in 37d48bb
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
but it is not done before calling `flash_varlen_fn`.