Commit d1abad0
[webgpu] Propagate rotary_interleaved parameter to GQA shader (#26758)
### Description
This PR fixes the last failing tests reported in
#26715 (comment),
which set `rotary_interleaved=1` in the GQA kernel. The root cause was that the
`rotary_interleaved` parameter was not propagated to the shader,
so it always defaulted to 0 in `FusedQKRotaryEmbeddingProgram`.
```
Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider']
=================================================================================
Prefill_ColdStart | In:3 Past:0 Total:3 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Prefill_ColdStart | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Decode_Early | In:1 Past:16 Total:17 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Decode_Deep | In:1 Past:64 Total:65 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Speculative_Dec | In:4 Past:20 Total:24 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Batch_Prefill | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Batch_Decode | In:1 Past:32 Total:33 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
GQA_Prefill | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Decode | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Batch_Dec | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Prefill | In:32 Past:0 Total:32 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Decode | In:1 Past:32 Total:33 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_MHA | In:1 Past:16 Total:17 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_GQA | In:1 Past:16 Total:17 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Odd_SeqLen | In:7 Past:13 Total:20 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Odd_Heads | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
HighHeads_MHA | In:1 Past:32 Total:33 H:32 KV:32 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_GQA | In:1 Past:32 Total:33 H:32 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_MQA | In:1 Past:32 Total:33 H:32 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgCtx_Prefill | In:128 Past:0 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 4.17e-07)
LgCtx_Decode | In:1 Past:127 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.98e-07)
TinyHead_MHA | In:4 Past:4 Total:8 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
TinyHead_GQA | In:4 Past:4 Total:8 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.49e-07)
LgHead_MHA | In:2 Past:2 Total:4 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
LgHead_GQA | In:2 Past:2 Total:4 H:4 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Ratio_5_1 | In:1 Past:10 Total:11 H:5 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_2 | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_3 | In:1 Past:10 Total:11 H:6 KV:3 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Ratio_12_4 | In:1 Past:10 Total:11 H:12 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Zero_Past | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Single_Token_Prefill | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Rotary_Cache_Test | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Small | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Large | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)
Window_Decode | In:1 Past:20 Total:21 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Softcap_Enabled | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 5.98e-05)
Scale_0.5 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Rotary_Interleaved | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Half | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary Interleaved 2 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Window | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)
🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS.
```
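
For readers unfamiliar with the flag, the following is a minimal pure-Python sketch of what `rotary_interleaved` usually selects in rotary position embeddings. It is not the WebGPU shader from this PR; the function name `rotate` and the `base` parameter are hypothetical and only illustrate why silently falling back to 0 produces different outputs when a model expects 1.

```python
import math

def rotate(x, pos, interleaved, base=10000.0):
    """Apply a rotary embedding to one head vector x at position pos.

    Hypothetical helper for illustration only, not the kernel code.
    """
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        # interleaved=True rotates adjacent elements (2i, 2i+1);
        # interleaved=False rotates element i with element i + half.
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out

x = [1.0, 2.0, 3.0, 4.0]
as_interleaved = rotate(x, pos=1, interleaved=True)
as_half_split = rotate(x, pos=1, interleaved=False)
# The two layouts pair different elements, so the results differ.
print(as_interleaved != as_half_split)
```

Both layouts are plain rotations, so they preserve the vector norm and agree at position 0; they diverge only once a nonzero position is applied, which is consistent with the failure showing up only in the tests that set `rotary_interleaved=1`.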
### Motivation and Context
cc @qjia7 @guschmue

1 parent 283f1d3 commit d1abad0
1 file changed: 1 addition & 0 deletions (one line added at line 248)