
CLIP issues with Flash Attention 3 pt.2 #42137

@akashpalla

Description


Follow-up to #41668

System Info

  • transformers version: 4.57.0
  • Platform: Linux-6.14.0-33-generic-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 1.0.0.rc4

Who can help?

@ArthurZucker @vasqu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
import torch
import os

model_path = os.getenv('MODEL_PATH')
print(model_path)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_path,
    variant="fp16",
    torch_dtype=torch.float16
).to("cuda")
pipe.fuse_qkv_projections()

# Force the CLIP text encoders to use Flash Attention 3
pipe.text_encoder.config._attn_implementation = "flash_attention_3"
pipe.text_encoder_2.config._attn_implementation = "flash_attention_3"

prompt = ["A city at night with people walking around."] * 8
image = pipe(prompt, num_inference_steps=1).images[0]

Expected behavior

When using the default SDPA attention implementation with CLIP, is_causal is set to True, but create_causal_mask still generates a non-None attention mask. The check below then forces is_causal back to False and passes the explicit mask to SDPA, which keeps SDPA from lowering into its own flash-attention kernel:

is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
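
As a hedged illustration of that dispatch behaviour (my own snippet, not from the issue; assumes torch >= 2.3 and a CUDA GPU where the flash SDPA backend is available, with shapes chosen loosely after the CLIP text encoder), restricting SDPA to its flash backend shows that passing any explicit mask, even a purely causal one, prevents that backend from being used:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# batch 16, 12 heads, 77 tokens, head_dim 64 (hidden 768)
q = torch.randn(16, 12, 77, 64, device="cuda", dtype=torch.float16)
k, v = q.clone(), q.clone()
causal_mask = torch.ones(77, 77, device="cuda", dtype=torch.bool).tril()

with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    # Fused flash path: fine with is_causal=True and no explicit mask
    F.scaled_dot_product_attention(q, k, v, is_causal=True)
    try:
        # What transformers ends up doing once create_causal_mask returns a
        # non-None mask: the flash backend cannot take an explicit attn_mask
        F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    except RuntimeError as e:
        print("flash backend rejected the explicit mask:", e)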

When the attention implementation is set to FA3 instead, the following error is thrown:

AssertionError: expected size 16==16, stride 177408==59136 at dim=0; expected size 77==77, stride 2304==768 at dim=1
Error in op: torch.ops.flash_attn.flash_attn_func.default
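
The strides in the assertion are consistent with query/key/value being narrow views of a fused [16, 77, 3 * 768] QKV projection rather than dense tensors; this CPU-only sketch (my own illustration, not taken from the traceback) reproduces the exact numbers:

import torch

qkv = torch.randn(16, 77, 3 * 768)   # fused QKV output: strides (177408, 2304, 1)
q = qkv[..., :768]                    # slicing out q keeps the parent's strides
print(q.shape, q.stride(), q.is_contiguous())
# torch.Size([16, 77, 768]) (177408, 2304, 1) False
print(q.contiguous().stride())
# (59136, 768, 1)  <- the strides flash_attn_func expects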

Appending .contiguous() to the query, key and value transposes resolves the error (a sketch of the patched lines follows the snippet below). The FA3 path currently only transposes:

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
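
A minimal sketch of that workaround, applied in place of the three lines above (my local patch, not an upstream change):

# make the tensors dense so flash_attn_func's stride checks pass
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()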

This is similar to what is already done in flash_paged:

q.transpose(1, 2).squeeze(0).contiguous(),

The same treatment is applied before calling flash_attn_varlen_func:

query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))

but no equivalent .contiguous() call is made before calling flash_varlen_fn.
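
For context on why the .contiguous() matters in that flattening step too: merging batch and sequence with .view() requires a dense layout, as this standalone snippet shows (shapes chosen by me to mimic the CLIP text encoder, not taken from the traceback):

import torch

# [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] via transpose: a non-contiguous view
x = torch.randn(16, 12, 77, 64).transpose(1, 2)
try:
    x.view(-1, x.size(-2), x.size(-1))  # fails: dims 0 and 1 cannot be merged without a copy
except RuntimeError as e:
    print(e)  # "view size is not compatible with input tensor's size and stride ..."
print(x.contiguous().view(-1, x.size(-2), x.size(-1)).shape)  # torch.Size([1232, 12, 64])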
