@brunomazzottiamd (Contributor) commented Dec 5, 2025

Motivation

In the gpt-oss attention implementation, each attention head has a learned bias in the denominator of the softmax. This is similar to an attention sink, so we can enable gpt-oss by adding attention sink support to our AITER MHA kernels (both the forward and backward kernels). The target model is gpt-oss-20b.

Technical Details

  • gpt-oss has one sink parameter per query head, so the sink tensor and its gradient must be 1D (source).
  • The attention sink feature does not support fp8 data types. The proposed changes were tested with the bf16 and fp32 data types, but they should also work with fp16.
  • Triton MHA backward has two implementations, "fused" and "one kernel". Sink support was added only to the "one kernel" implementation since it's the default one and it provides the best performance most of the time.
  • Changes in the forward kernel:
    • The running maximum is initialized with the head's sink value instead of -inf, so the sink logit takes part in the softmax normalization (see the reference sketch after this list).
  • Changes in the backward kernel:
    • The sink gradient is a sum reduction; it has to be computed and accumulated per query head.
    • Sink gradient computation was added to the outer loop that computes the Q gradient. A partial accumulation is done once per query head and sequence block.
    • The sum reduction is implemented with atomic adds because the kernel is parallelized across (KV heads, sequence length, batch), so multiple Triton programs compute partial sums of the same sink gradient element (there is one element per query head).
    • The sink gradient is stored in fp32 to enable atomic ops.
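
For context, here is a minimal eager-mode sketch of the sink semantics and of the sink-gradient reduction described in the list above. It is illustrative only and makes simplifying assumptions that the kernels do not: plain MHA (no GQA), a dense (batch, heads, seqlen, dim) layout instead of bshd/thd, no dropout, and fp32 math. The function names are hypothetical and are not part of the AITER API.

```python
import torch


def attn_with_sink_ref(q, k, v, sink, causal=False):
    """Eager reference for attention with a per-head sink logit.

    q, k, v: (batch, heads, seqlen, dim); sink: (heads,).
    Returns the output and the sink's softmax weight for each query row.
    """
    sm_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * sm_scale
    if causal:
        sq, sk = scores.shape[-2:]
        mask = torch.triu(torch.ones(sq, sk, device=q.device), 1).bool()
        scores = scores.masked_fill(mask, float("-inf"))
    # The sink enters the softmax as one extra logit column per query head...
    sink_col = sink.float().view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
    # ...but it has no value row, so only the real columns multiply V; the
    # sink merely shrinks the other weights through the denominator.
    out = torch.einsum("bhqk,bhkd->bhqd", probs[..., :-1], v.float())
    return out.to(q.dtype), probs[..., -1]


def sink_grad_closed_form(p_sink, out, dout):
    """Sink gradient under the assumptions above:
    dsink[h] = -sum over (batch, query) of p_sink * Delta,
    where Delta = rowsum(dout * out) is the usual attention-backward term.

    This is the per-query-head sum reduction the backward kernel accumulates;
    since programs are launched over (KV heads, sequence length, batch), each
    partial sum lands in one fp32 element per query head via an atomic add.
    """
    delta = (dout.float() * out.float()).sum(dim=-1)  # (batch, heads, seqlen)
    return -(p_sink * delta).sum(dim=(0, 2))          # (heads,)
```

Running loss.backward() through attn_with_sink_ref gives the same 1D sink gradient as the closed form above; the sketch is only meant as intuition for the kernel changes, not as the reference implementation used by the unit tests.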

Test Plan

  • Added two new unit tests for the sink case to op_tests/triton_tests/test_mha.py: test_mha_with_sink and test_mha_varlen_with_pe. They cover 192 new cases and exercise both the forward and backward passes.
  • The BSHD layout + causal + dropout case isn't supported in backward with sink, and neither is the THD layout + dropout case, because neither case is supported in the regular backward. I think it's best to fix the base backward implementation before adding sink to the mix.

Test Result

  • All sink tests from op_tests/triton_tests/test_mha.py pass on gfx942 and gfx950.
  • All other tests from op_tests/triton_tests/test_mha.py produce the same results as before the sink was added, on both gfx942 and gfx950, so the newly added sink feature didn't break anything that was already working.

Performance Assessment

Target attention shapes (spelled out as an example tensor setup after this list):

  • Data type: bf16
  • TP1: HQ = 64, HKV = 8, D = 64, SQ = SK = 8192.
  • TP8: HQ = 8, HKV = 1, D = 64, SQ = SK = 8192.
  • Batch sizes 8-16 for thd layout and 1 for bshd layout.
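
For concreteness, the TP1 / bshd / batch-1 shape above corresponds to tensors like the following (a hypothetical setup sketch that assumes bshd means (batch, seqlen, heads, head_dim); it is not the benchmark script):

```python
import torch

# TP1 target shape: HQ = 64, HKV = 8, D = 64, SQ = SK = 8192, batch 1, bf16.
B, SQ, SK, HQ, HKV, D = 1, 8192, 8192, 64, 8, 64
q = torch.randn(B, SQ, HQ, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, SK, HKV, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, SK, HKV, D, dtype=torch.bfloat16, device="cuda")
# One sink per query head; kept in fp32 so its gradient supports atomic adds.
sink = torch.randn(HQ, dtype=torch.float32, device="cuda", requires_grad=True)
```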

Forward performance on gfx950:

| TP | Batch size | Layout | Forward time without sink (ms) | Forward time with sink (ms) | Speedup |
|----|------------|--------|--------------------------------|-----------------------------|---------|
| 1 | 1 | bshd | 0.906 | 0.910 | 1.00 |
| 1 | 8 | thd | 7.169 | 7.115 | 1.01 |
| 1 | 9 | thd | 8.060 | 7.955 | 1.01 |
| 1 | 10 | thd | 8.856 | 8.827 | 1.00 |
| 1 | 11 | thd | 9.778 | 9.802 | 1.00 |
| 1 | 12 | thd | 10.694 | 10.736 | 1.00 |
| 1 | 13 | thd | 11.546 | 11.487 | 1.01 |
| 1 | 14 | thd | 12.414 | 12.362 | 1.00 |
| 1 | 15 | thd | 13.347 | 13.280 | 1.01 |
| 1 | 16 | thd | 14.193 | 14.188 | 1.00 |
| 8 | 1 | bshd | 0.191 | 0.190 | 1.00 |
| 8 | 8 | thd | 0.965 | 0.963 | 1.00 |
| 8 | 9 | thd | 1.079 | 1.081 | 1.00 |
| 8 | 10 | thd | 1.185 | 1.187 | 1.00 |
| 8 | 11 | thd | 1.299 | 1.317 | 0.99 |
| 8 | 12 | thd | 1.409 | 1.422 | 0.99 |
| 8 | 13 | thd | 1.528 | 1.524 | 1.00 |
| 8 | 14 | thd | 1.644 | 1.640 | 1.00 |
| 8 | 15 | thd | 1.760 | 1.757 | 1.00 |
| 8 | 16 | thd | 1.864 | 1.874 | 0.99 |
| | | | | Geomean | 1.00 |

Backward performance on gfx950:

| TP | Batch size | Layout | "One kernel" backward time without sink (ms) | "One kernel" backward time with sink (ms) | Speedup |
|----|------------|--------|----------------------------------------------|-------------------------------------------|---------|
| 1 | 1 | bshd | 6.284 | 6.342 | 0.99 |
| 1 | 8 | thd | 52.947 | 52.272 | 1.01 |
| 1 | 9 | thd | 59.567 | 58.790 | 1.01 |
| 1 | 10 | thd | 65.465 | 65.748 | 1.00 |
| 1 | 11 | thd | 72.319 | 72.383 | 1.00 |
| 1 | 12 | thd | 79.073 | 78.599 | 1.01 |
| 1 | 13 | thd | 85.665 | 84.760 | 1.01 |
| 1 | 14 | thd | 91.813 | 91.301 | 1.01 |
| 1 | 15 | thd | 98.538 | 98.074 | 1.00 |
| 1 | 16 | thd | 105.008 | 104.472 | 1.01 |
| 8 | 1 | bshd | 4.423 | 4.336 | 1.02 |
| 8 | 8 | thd | 9.548 | 9.479 | 1.01 |
| 8 | 9 | thd | 11.594 | 11.661 | 0.99 |
| 8 | 10 | thd | 12.819 | 13.071 | 0.98 |
| 8 | 11 | thd | 13.014 | 13.088 | 0.99 |
| 8 | 12 | thd | 13.958 | 14.067 | 0.99 |
| 8 | 13 | thd | 15.211 | 15.439 | 0.99 |
| 8 | 14 | thd | 16.498 | 16.555 | 1.00 |
| 8 | 15 | thd | 17.525 | 17.593 | 1.00 |
| 8 | 16 | thd | 17.827 | 17.768 | 1.00 |
| | | | | Geomean | 1.00 |

Conclusion: the attention sink feature doesn't change performance on gfx950.

I did the same analysis on gfx942 and reached the same conclusion. I'm not publishing those numbers in the PR for the sake of brevity.

Submission Checklist

@brunomazzottiamd self-assigned this Dec 5, 2025
@brunomazzottiamd added the enhancement (New feature or request) and triton labels Dec 5, 2025
@brunomazzottiamd marked this pull request as ready for review December 5, 2025 15:48