Commit d1abad0

[webgpu] Propagate rotary_interleaved parameter to GQA shader (#26758)
### Description

This PR fixes the last tests that were failing in #26715 (comment), where `rotary_interleaved=1` is used with the GQA kernel. The root cause was that the `rotary_interleaved` parameter was not propagated correctly, so it always defaulted to 0 in `FusedQKRotaryEmbeddingProgram`.

```
Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider']
=================================================================================
Prefill_ColdStart    | In:3 Past:0 Total:3 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Prefill_ColdStart    | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Decode_Early         | In:1 Past:16 Total:17 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Decode_Deep          | In:1 Past:64 Total:65 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Speculative_Dec      | In:4 Past:20 Total:24 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Batch_Prefill        | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Batch_Decode         | In:1 Past:32 Total:33 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
GQA_Prefill          | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Decode           | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Batch_Dec        | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Prefill          | In:32 Past:0 Total:32 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Decode           | In:1 Past:32 Total:33 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_MHA          | In:1 Past:16 Total:17 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_GQA          | In:1 Past:16 Total:17 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Odd_SeqLen           | In:7 Past:13 Total:20 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Odd_Heads            | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
HighHeads_MHA        | In:1 Past:32 Total:33 H:32 KV:32 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_GQA        | In:1 Past:32 Total:33 H:32 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_MQA        | In:1 Past:32 Total:33 H:32 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgCtx_Prefill        | In:128 Past:0 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 4.17e-07)
LgCtx_Decode         | In:1 Past:127 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.98e-07)
TinyHead_MHA         | In:4 Past:4 Total:8 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
TinyHead_GQA         | In:4 Past:4 Total:8 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.49e-07)
LgHead_MHA           | In:2 Past:2 Total:4 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
LgHead_GQA           | In:2 Past:2 Total:4 H:4 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Ratio_5_1            | In:1 Past:10 Total:11 H:5 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_2            | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_3            | In:1 Past:10 Total:11 H:6 KV:3 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Ratio_12_4           | In:1 Past:10 Total:11 H:12 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Zero_Past            | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Single_Token_Prefill | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Rotary_Cache_Test    | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary               | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Small         | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Large         | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)
Window_Decode        | In:1 Past:20 Total:21 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Softcap_Enabled      | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 5.98e-05)
Scale_0.5            | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Rotary_Interleaved   | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary               | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Half          | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary Interleaved 2 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Window        | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)

🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS.
```

### Motivation and Context

cc @qjia7 @guschmue
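Each line in the log above compares the WebGPU output against a CPU reference and reports the maximum absolute difference. A minimal sketch of that kind of cross-provider check, assuming a pre-built test model and input feeds (the model path, feed dict, and function name are hypothetical, not the actual harness from #26715):

```python
import numpy as np
import onnxruntime as ort

def max_diff_vs_cpu(model_path, feeds):
    """Run the same model on CPU and WebGPU and return the max abs output difference."""
    cpu = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    gpu = ort.InferenceSession(model_path, providers=["WebGpuExecutionProvider"])
    ref = cpu.run(None, feeds)[0].astype(np.float32)  # CPU output is the reference
    out = gpu.run(None, feeds)[0].astype(np.float32)
    return float(np.max(np.abs(out - ref)))

# Hypothetical usage with a GQA model exported with rotary_interleaved=1:
# diff = max_diff_vs_cpu("gqa_rotary_interleaved.onnx", feeds)
# print(f"Max Diff vs CPUExecutionProvider: {diff:.2e}")
```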
1 parent 283f1d3 commit d1abad0

1 file changed: 1 addition, 0 deletions

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 1 addition & 0 deletions
```diff
@@ -245,6 +245,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
                                scale_,
                                softcap_));
   params.use_smooth_softmax = use_smooth_softmax_;
+  params.rotary_interleaved = rotary_interleaved_;
 
   ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
                                                                                attention_bias,
```
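For context, the propagated attribute selects the pair layout used when rotating Q/K: with `rotary_interleaved=1` adjacent elements `(x[2i], x[2i+1])` form a rotation pair, while with the default value 0 the pairs are `(x[i], x[i + d/2])`. Because the flag always defaulted to 0, the WebGPU path previously applied the half-split layout even when interleaved rotation was requested. A minimal NumPy reference of the two layouts (an illustrative sketch, not the WebGPU shader or the ONNX Runtime CPU implementation):

```python
import numpy as np

def rotary_embedding(x, cos, sin, interleaved):
    """Apply rotary position embedding to x of shape (seq_len, head_dim).

    cos/sin have shape (seq_len, head_dim // 2). Reference sketch only.
    """
    half = x.shape[-1] // 2
    out = np.empty_like(x)
    if interleaved:
        # rotary_interleaved=1: rotate adjacent pairs (x[..., 2i], x[..., 2i+1]).
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
    else:
        # rotary_interleaved=0 (default): rotate split halves (x[..., i], x[..., i + half]).
        x1, x2 = x[..., :half], x[..., half:]
        out[..., :half] = x1 * cos - x2 * sin
        out[..., half:] = x1 * sin + x2 * cos
    return out
```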
