-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[webgpu] Fused SplitPackedQKV with FusedQKRotaryEmbedding #26447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
xiaofeihan1
merged 5 commits into
microsoft:main
from
xiaofeihan1:xiaofeihan/fused_splitQKV
Nov 11, 2025
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| #param interleaved | ||
|
|
||
| #use guardAgainstOutOfBoundsWorkgroupSizes | ||
| #use .setByIndices .getByIndices .getByOffset | ||
|
|
||
| // Splits the packed QKV buffer into the separate query / key / val outputs and | ||
| // applies the rotary rotation to Q and K in the same pass. | ||
| // For one (batch, token, head) each invocation handles either one rotary pair | ||
| // (two elements) or one pass-through element located after the rotary range. | ||
| $MAIN { | ||
| guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size); | ||
|
|
||
| // Dispatch: batch * seq * num_heads * (half_rotary_dim + need_copy_dim) | ||
| // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) | ||
| let work_per_head = uniforms.head_size - uniforms.half_rotary_dim; | ||
| let total_work = uniforms.num_heads * work_per_head; | ||
|
|
||
| // Decompose the flat invocation index into (batch, token, head, work item in head). | ||
| let batch_idx = global_idx / (uniforms.sequence_length * total_work); | ||
| let remainder1 = global_idx % (uniforms.sequence_length * total_work); | ||
| let seq_idx = remainder1 / total_work; | ||
| let remainder2 = remainder1 % total_work; | ||
| let head_idx = remainder2 / work_per_head; | ||
| let in_head_idx = remainder2 % work_per_head; | ||
|
|
||
| // Calculate base offset in packed_qkv for this token | ||
| // Layout per token: [Q(hidden_size), K(kv_hidden_size), V(kv_hidden_size)] | ||
| let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size; | ||
| let base_offset = batch_idx * uniforms.sequence_length * token_size + seq_idx * token_size; | ||
|
|
||
| // Offset of this head inside the Q, K or V section of a token; it is also the | ||
| // offset of the head along the last axis of the query / key / val outputs. | ||
| let head_offset = head_idx * uniforms.head_size; | ||
|
|
||
| if (in_head_idx < uniforms.half_rotary_dim) { | ||
| // Calculate position_id (needed for rotary embedding) | ||
| // NOTE(review): the `+ 1u` implies seqlens stores (total sequence length - 1) | ||
| // per batch entry - confirm against the caller. | ||
| let seqlen_i = seqlens.getByOffset(batch_idx); | ||
| let seqlen = u32(seqlen_i); | ||
| let total_seqlen = seqlen + 1u; | ||
| let past_seqlen = total_seqlen - uniforms.sequence_length; | ||
| let position_id = past_seqlen + seq_idx; | ||
| // Process a rotary pair (i, j); in_head_idx is the pair index and selects | ||
| // the cos/sin cache column. | ||
| let cos_v = cos_cache.getByIndices(vec2<u32>(position_id, in_head_idx)); | ||
| let sin_v = sin_cache.getByIndices(vec2<u32>(position_id, in_head_idx)); | ||
|
|
||
| // Calculate actual indices in the head for i and j | ||
| #if interleaved | ||
| // Interleaved layout: pair p is the adjacent elements (2p, 2p + 1). | ||
| // Using (p, p + 1) would make neighbouring invocations overlap on element | ||
| // p + 1 and leave the upper part of the rotary range unwritten. | ||
| let idx_i = 2u * in_head_idx; | ||
| let idx_j = idx_i + 1u; | ||
| #else | ||
| // Half-split layout: pair p is (p, p + half_rotary_dim). | ||
| let idx_i = in_head_idx; | ||
| let idx_j = in_head_idx + uniforms.half_rotary_dim; | ||
| #endif | ||
|
|
||
| // Process Q pair | ||
| let q_base = base_offset + head_offset; | ||
| let q_i_offset = q_base + idx_i; | ||
| let q_j_offset = q_base + idx_j; | ||
| let q_i = packed_qkv.getByOffset(q_i_offset); | ||
| let q_j = packed_qkv.getByOffset(q_j_offset); | ||
| let q_re = q_i * cos_v - q_j * sin_v; | ||
| let q_im = q_i * sin_v + q_j * cos_v; | ||
| query.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + idx_i), q_re); | ||
| query.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + idx_j), q_im); | ||
|
|
||
| // Process K and V pairs if within kv_num_heads | ||
| // (the dispatch covers num_heads heads; heads >= kv_num_heads only have a Q part) | ||
| if (head_idx < uniforms.kv_num_heads) { | ||
| let k_base = base_offset + uniforms.hidden_size + head_offset; | ||
| let k_i_offset = k_base + idx_i; | ||
| let k_j_offset = k_base + idx_j; | ||
| let k_i = packed_qkv.getByOffset(k_i_offset); | ||
| let k_j = packed_qkv.getByOffset(k_j_offset); | ||
| let k_re = k_i * cos_v - k_j * sin_v; | ||
| let k_im = k_i * sin_v + k_j * cos_v; | ||
| key.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + idx_i), k_re); | ||
| key.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + idx_j), k_im); | ||
|
|
||
| // V doesn't need rotary, just copy the pair | ||
| let v_base = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_offset; | ||
| let v_i = packed_qkv.getByOffset(v_base + idx_i); | ||
| let v_j = packed_qkv.getByOffset(v_base + idx_j); | ||
| val.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + idx_i), v_i); | ||
| val.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + idx_j), v_j); | ||
| } | ||
| } else { | ||
| // Process non-rotary elements (direct copy) | ||
| // in_head_idx is in [half_rotary_dim, work_per_head) here, so actual_idx | ||
| // covers [2 * half_rotary_dim, head_size): the elements after the rotary range. | ||
| let actual_idx = uniforms.half_rotary_dim + in_head_idx; | ||
|
|
||
| // Copy Q | ||
| let q_offset = base_offset + head_offset + actual_idx; | ||
| let q_data = packed_qkv.getByOffset(q_offset); | ||
| query.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + actual_idx), q_data); | ||
|
|
||
| // Copy K and V if within kv_num_heads | ||
| if (head_idx < uniforms.kv_num_heads) { | ||
| let k_offset = base_offset + uniforms.hidden_size + head_offset + actual_idx; | ||
| let k_data = packed_qkv.getByOffset(k_offset); | ||
| key.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + actual_idx), k_data); | ||
|
|
||
| let v_offset = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_offset + actual_idx; | ||
| let v_data = packed_qkv.getByOffset(v_offset); | ||
| val.setByIndices(vec3<u32>(batch_idx, seq_idx, head_offset + actual_idx), v_data); | ||
| } | ||
| } | ||
| } // MAIN | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.