[FA4] Initial support
#42435
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

[For maintainers] Suggested jobs to run (before merge): run-slow: edgetam, gemma3, modernbert, sam2, sam3, sam3_tracker
FA4 support. cc @stas00 if you wanna play around with this PR. It's pretty much ready; I'm just not convinced by the numbers, but I also don't have quick access to a Blackwell GPU (at least today :D)
sfc-gh-sbekman left a comment
Thank you for working on this, Anton. Going to try it out.
To make it easier to try your PR, please add to the OP how to install FA4, since it's non-trivial to find:
git clone https://github.com/Dao-AILab/flash-attention/
cd flash-attention
cd flash_attn/cute
uv build --wheel . -v --no-build-isolation --out-dir flash-attention/wheels
uv pip install flash-attention/wheels/flash_attn_cute*.whl --prerelease=allow
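As a side note, here is a quick way to sanity-check the install afterwards. The module names below are my guess from the wheel name and are not verified against the actual package layout:

```python
# Minimal post-install smoke test; "flash_attn_cute" / "flash_attn.cute" are assumed
# module names derived from the wheel name, adjust if the package layout differs.
import importlib.util

candidates = ("flash_attn_cute", "flash_attn.cute")
for name in candidates:
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:
        spec = None  # parent package not installed
    if spec is not None:
        print(f"found module: {name}")
        break
else:
    print("no FA4/cute module found, check the install output")
```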
OK, gave it a test ride using your PR and the above comment's install of FA4 on B200. I did a quick test with Llama-8b and the integration worked smoothly, but the tflops performance is much worse than FA2 - 2-5x slower. Not sure if it's an issue with the integration, the FA4 code, or the pytorch version - most likely the upstream, since the integration is just a wrapper. I tried pt-2.9.1-cu130 and pt-nightly-cu130 - same outcome.
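For reference, a minimal fwd/bwd/step timing sketch of the kind of comparison described above (not the actual recipe used). The model id is a small placeholder, "sdpa" and "flash_attention_2" are existing attn_implementation values, and whatever string this PR registers for FA4 would be added to the list:

```python
# Rough fwd/bwd/step timing sketch, not the benchmark used above.
import time

import torch
from transformers import AutoModelForCausalLM

def bench(attn_impl, model_id="Qwen/Qwen2.5-0.5B", steps=10, seq_len=2048):
    # model_id is a small placeholder so the sketch fits on one GPU; swap in the model you care about
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
    ).cuda()
    opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
    ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")

    for _ in range(3):  # warmup
        model(input_ids=ids, labels=ids).loss.backward()
        opt.step()
        opt.zero_grad()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(steps):
        model(input_ids=ids, labels=ids).loss.backward()
        opt.step()
        opt.zero_grad()
    torch.cuda.synchronize()
    print(f"{attn_impl}: {(time.perf_counter() - t0) / steps * 1e3:.1f} ms/step")

for impl in ("sdpa", "flash_attention_2"):  # add the FA4 string from this PR to compare
    bench(impl)
```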
Thanks for checking this out and all the pointers @sfc-gh-sbekman ❤️
For sure, I'll add some docs for FA4 before release. Maybe also FA3 in a different PR.
Shoot, so it wasn't a GPU arch issue... This is weird.
Do you have a code snippet? There are so many edge cases with sdpa that maybe it's not even entering the FA backend path. It could be quickly checked by restricting the backend usage on SDPA with their context manager:

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
    pass  # do your thing here

I'm also unsure how FA4 is integrated in SDPA? Do we need to use a flag there? I remember that the cudnn backend needed special treatment.
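Spelled out as a standalone check (shapes and dtype below are arbitrary; the point is just that SDPA raises instead of silently falling back when it's pinned to the flash backend and can't use it):

```python
# Force SDPA onto the flash backend and see whether it actually runs there.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

try:
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    print("flash backend was used, output shape:", tuple(out.shape))
except RuntimeError as e:
    # With the backend restricted, SDPA errors out instead of falling back.
    print("flash backend could not be used:", e)
```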
Did you mean that you too have observed a similar slowdown?
I was just using the normal SFT training recipe from https://github.com/snowflakedb/ArcticTraining/ where I tried different attention mechanisms. Just normal fwd/bwd/step - nothing special added.

They copied/adapted the FA4 kernels, see: #42435 - you'd need pt nightly for that to work.
I just did some quick numbers on inference, see the test I noted down in the PR description. I used an H100 there and, as you can see, it's slower (not of the same magnitude as in your samples - I'd say it's a mixture of model size / context size).
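As a rough stand-in for that kind of quick inference timing (not the exact test from the PR description; the model id and generation length are placeholders, and the FA4 attn_implementation string from this PR would be added to the list):

```python
# Rough inference timing sketch using CUDA events; placeholder model and sizes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B"  # placeholder, swap in the model you care about
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok("The quick brown fox", return_tensors="pt").to("cuda")

for impl in ("sdpa", "flash_attention_2"):  # add the FA4 string from this PR to compare
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=impl, torch_dtype=torch.bfloat16
    ).cuda()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    model.generate(**inputs, max_new_tokens=16)  # warmup
    start.record()
    model.generate(**inputs, min_new_tokens=256, max_new_tokens=256)
    end.record()
    torch.cuda.synchronize()
    print(f"{impl}: {start.elapsed_time(end):.1f} ms for 256 new tokens")
```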
Gotcha, I will try to separate our implementation from the base fn of the FA library to see if our wrappers cause this, or maybe some perf regression happened somewhere else.
Wrong link? My assumption / hunch was that maybe
My apologies, here is the correct link: pytorch/pytorch#167348
Fixes #42405
Closes #42404 as it has a lot of unnecessary logic and tests alongside it
Remaining TODOs:
- partial destroys compile support for older fa versions

First quick numbers (hopper)

NOTE:
I'm not really sure about FA4 support (see numbers above), so keeping this as a draft for now! The overall logic should stay the same (even later on).