[TRITON] Add attention sink support to Triton MHA kernels #1576
Motivation
In the gpt-oss attention implementation, each attention head has a learned bias in the denominator of the softmax. This is similar to an attention sink, so we can enable gpt-oss by adding attention sink support to our AITER MHA kernels (both the forward and backward kernels). The target model is gpt-oss-20b.
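For concreteness, the per-head learned sink `s` adds `exp(s)` to the softmax denominator while contributing no value vector. Below is a minimal PyTorch reference of that formulation; it is a sketch for illustration only, and the function name and signature are mine, not the kernel API:

```python
import torch

def attention_with_sink(q, k, v, sinks, scale=None):
    """Naive reference: q, k, v are (batch, heads, seq, head_dim),
    sinks is (heads,) with one learned logit per head."""
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    sink = sinks.view(1, -1, 1, 1)                              # broadcast per head
    m = torch.maximum(scores.amax(dim=-1, keepdim=True), sink)  # stable max incl. sink
    num = torch.exp(scores - m)
    # The sink adds exp(sink) to the denominator but no value contribution.
    denom = num.sum(dim=-1, keepdim=True) + torch.exp(sink - m)
    return torch.einsum("bhqk,bhkd->bhqd", num / denom, v)
```

Note that with the sinks set to negative infinity the extra term vanishes and this reduces to standard softmax attention.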
Technical Details
- […] `fp8` data types.
- The proposed changes were tested with `bf16` and `fp32` data types, but they should also work with the `fp16` data type.
- […] `-inf`.
- […] `fp32` to enable atomic ops.
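My reading of the `-inf` and `fp32` notes above, shown as a sketch of how a per-head sink can be folded into a flash-attention-style online softmax for one query row; this mirrors the standard rescaling trick and is not a copy of the Triton kernel:

```python
import torch

def online_softmax_row_with_sink(score_blocks, value_blocks, sink):
    """score_blocks: list of (BLOCK_N,) score tensors for one query row.
    value_blocks: list of (BLOCK_N, head_dim) value tensors.
    sink: scalar learned logit for this head."""
    head_dim = value_blocks[0].shape[-1]
    m = torch.tensor(float("-inf"))          # running max starts at -inf
    l = torch.tensor(0.0)                    # running softmax denominator
    acc = torch.zeros(head_dim)              # running (unnormalized) output
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_new = torch.maximum(m, s_blk.max())
        alpha = torch.exp(m - m_new)         # rescale previous partial sums
        p = torch.exp(s_blk - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v_blk
        m = m_new
    # The sink only adds mass to the denominator; it has no value vector.
    sink = torch.as_tensor(float(sink))
    m_new = torch.maximum(m, sink)
    alpha = torch.exp(m - m_new)
    l = l * alpha + torch.exp(sink - m_new)
    acc = acc * alpha
    return acc / l
```

In the backward pass, every query block contributes a partial gradient to the same per-head sink, which is presumably why accumulating those partial sums with `fp32` atomics is useful; that is my interpretation of the note above, not a statement of the exact kernel logic.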
Test Plan

New tests were added to `op_tests/triton_tests/test_mha.py`: `test_mha_with_sink` and `test_mha_varlen_with_pe`. They cover 192 new cases and test both forward and backward passes.

Test Result
- The tests in `op_tests/triton_tests/test_mha.py` are passing on `gfx942` and `gfx950`.
- The pre-existing tests in `op_tests/triton_tests/test_mha.py` produce the same results as before the sink was added. This happens on both `gfx942` and `gfx950`, so we can conclude that the newly added sink feature didn't break anything that was already working.

Performance Assessment
Target attention shapes:
All shapes use `bf16`; […] for the `thd` layout and 1 for the `bshd` layout.

Forward performance on `gfx950`:

Backward performance on `gfx950`:

Conclusion: The attention sink feature doesn't change performance on `gfx950`. I did the same analysis on `gfx942` and got the same conclusion. I'm not publishing those numbers in the PR for the sake of brevity.

Submission Checklist