
Conversation

@vasqu
Contributor

vasqu commented Nov 26, 2025

Fixes #42405
Closes #42404, as it carries a lot of unnecessary logic and tests alongside it

Remaining TODOs:

  • Comparison tests against FA2 and FA3
    • Logits seem to differ a lot, not even 1e-1 suffices --> mind that it's a full 1B model though; local small tests pass
    • Generations stay the same though
    • No real speed benefit on Hopper, even slower than FA2...
  • Sanity check whether the usage of partial breaks compile support for older FA versions (see the sketch after this list)
    • Works on torch 2.9.1 (maybe check more versions)
  • Handle FA2 detection, as it clashes with FA4 atm
  • Refactor modeling utils? A lot of duplicated messages
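
A minimal sanity check of the partial + compile concern, assuming a plain SDPA wrapper rather than the PR's actual code; the function, shapes, and kwargs below are illustrative only:

# Illustrative sketch: wrap an attention call with functools.partial and make
# sure torch.compile still traces it without graph breaks.
from functools import partial

import torch
import torch.nn.functional as F

def attention(q, k, v, is_causal=False, scale=None):
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)

# freeze kwargs the way an attention-interface wrapper might
attention_fn = partial(attention, is_causal=True)

@torch.compile(fullgraph=True)
def forward(q, k, v):
    return attention_fn(q, k, v)

q = k = v = torch.randn(1, 8, 128, 64)
print(forward(q, k, v).shape)  # torch.Size([1, 8, 128, 64])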

First quick numbers (Hopper):

# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
Latency:
    With FA2: 493.31220703125
    With FA3: 455.055615234375
    With FA4: 525.8052734375

NOTE:
I'm not really sure about FA4 support (see the numbers above), so I'm keeping this as a draft for now! The overall logic should stay the same (even later on).
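
For anyone who wants to try the PR out, a hedged usage sketch; "flash_attention_4" below is a hypothetical placeholder for whatever attn_implementation string this PR ends up registering, and the model id is just an example:

# Hypothetical usage sketch: check the PR for the actual attn_implementation
# string it registers; "flash_attention_4" is a placeholder here.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # example model, any decoder LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_4",  # placeholder name
).to("cuda")

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))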


@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: edgetam, gemma3, modernbert, sam2, sam3, sam3_tracker

@vasqu
Contributor Author

vasqu commented Nov 28, 2025

FA4 support cc @stas00 if you wanna play around with this PR. It's pretty much ready, I'm just not convinced by the numbers, but I also don't have quick access to a Blackwell GPU (at least not today :D)

Contributor

@sfc-gh-sbekman left a comment


Thank you for working on this, Anton. Going to try it out.

To make it easier to try your PR, please add to the OP how to install FA4, since it's non-trivial to find.

# build and install the FA4 ("cute") wheel from source
git clone https://github.com/Dao-AILab/flash-attention/
cd flash-attention/flash_attn/cute
uv build --wheel . -v --no-build-isolation --out-dir flash-attention/wheels
uv pip install flash-attention/wheels/flash_attn_cute*.whl --prerelease=allow

@sfc-gh-sbekman
Contributor

sfc-gh-sbekman commented Dec 1, 2025

OK, gave it a test ride using your PR and the above comment's install of FA4 on a B200.

I did a quick test with Llama-8b and the integration worked smoothly, but the TFLOPS performance is much worse than FA2, 2-5x slower. Not sure if it's an issue with the integration, the FA4 code, or the PyTorch version; most likely upstream, since the integration is just a wrapper.

I tried pt-2.9.1-cu130 and pt-nightly-cu130 with the same outcome.

edit: SDPA in pt-nightly, which supposedly backported FA4, is about 3x faster than FA4 on its own using the same Llama-8b. Since they both should be using the same code, perhaps there is an issue with the integration?
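
Not the actual training benchmark, just a minimal sketch of the kind of comparison described above, timing SDPA with each backend forced explicitly on random tensors (made-up, roughly Llama-8B-like shapes; only relative timings on a single GPU are meaningful):

# Illustrative micro-benchmark of SDPA backends; shapes and iteration counts
# are arbitrary, numbers only serve as a rough relative comparison.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

def bench(backend, iters=20):
    with sdpa_kernel([backend]):
        for _ in range(3):  # warmup
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per forward

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION):
    try:
        print(f"{backend.name}: {bench(backend):.3f} ms")
    except RuntimeError as err:  # backend unavailable for this config
        print(f"{backend.name}: skipped ({err})")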

@vasqu
Contributor Author

vasqu commented Dec 1, 2025

Thanks for checking this out and all the pointers @sfc-gh-sbekman ❤️

To make it easier to try your PR, please add to the OP how to install FA4, since it's non-trivial to find.

For sure, I'll add some docs for FA4 before release. Maybe also FA3 in a different PR.

I did a quick test with Llama-8b and the integration worked smoothly, but the TFLOPS performance is much worse than FA2, 2-5x slower. Not sure if it's an issue with the integration, the FA4 code, or the PyTorch version; most likely upstream, since the integration is just a wrapper.

Shoot, so it wasn't a GPU arch issue... This is weird

SDPA in pt-nightly, which supposedly backported FA4, is about 3x faster than FA4 on its own using the same Llama-8b. Since they both should be using the same code, perhaps there is an issue with the integration?

Do you have a code snippet? There are so many edge cases with SDPA that maybe it's not even entering the FA backend path. That could be quickly checked by restricting SDPA to the flash attention backend with its context manager:

import torch  # torch.nn.attention namespace

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
    pass  # do your thing here; an SDPA call in here fails loudly if flash can't be used

I'm also unsure how FA4 is integrated into SDPA. Do we need to use a flag there? I remember that the cuDNN backend needed special treatment.

@stas00
Contributor

stas00 commented Dec 2, 2025

Shoot, so it wasn't a GPU arch issue... This is weird

Did you mean that you too have observed a similar slowdown?

Do you have a code snippet?

I was just using the normal SFT training recipe from https://github.com/snowflakedb/ArcticTraining/ where I tried different attention mechanisms. Just normal fwd/bwd/step, nothing special added.

I'm also unsure how FA4 is integrated into SDPA. Do we need to use a flag there? I remember that the cuDNN backend needed special treatment.

They copied/adapted the FA4 kernels, see #42435; you'd need pt nightly for that to work.

@vasqu
Contributor Author

vasqu commented Dec 2, 2025

Did you mean that you too have observed a similar slowdown?

I just did some quick numbers on inference, see the test I noted down in the PR description. I used an H100 there and, as you can see, it's slower (not by the same magnitude as in your samples; I'd say it's a mixture of model size / context size).

I was just using the normal SFT training recipe from https://github.com/snowflakedb/ArcticTraining/ where I tried different attention mechanisms. Just normal fwd/bwd/step, nothing special added.

Gotcha, I will try to separate our implementation from the base function of the FA library to see whether our wrappers cause this or whether some perf regression happened somewhere else.

They copied/adapted the FA4 kernels, see #42435; you'd need pt nightly for that to work.

Wrong link? My assumption / hunch was that maybe:

  1. Even nightly might not use FA4 by default and sticks to FA2, i.e. it might need some extra flags to enable that specific backend. But that's just my feeling, I need to look into it.
  2. Our implementation has some issues where attention masks are created even when they are not needed (full (causal) attention). If a mask is passed to SDPA, the FA backend can never be entered per its restrictions. So I thought that maybe we have this case (SDPA using xformers and still being faster than FA4; xformers is not so bad on short contexts <2k); see the sketch below.
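
A quick way to check point 2, as a sketch with made-up shapes: with SDPA restricted to its flash backend, an explicit attn_mask is rejected while is_causal=True goes through.

# Sketch: the flash backend of SDPA does not take an explicit attention mask,
# so restricting to it and passing one raises "No available kernel".
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
causal_mask = torch.ones(1024, 1024, device="cuda", dtype=torch.bool).tril()

with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)  # ok, flash backend used
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    except RuntimeError:
        print("explicit mask -> flash backend rejected, SDPA would have to fall back elsewhere")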

@stas00
Contributor

stas00 commented Dec 2, 2025

My apologies, here is the correct link: pytorch/pytorch#167348
