Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 78 additions & 16 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "contrib_ops/webgpu/bert/flash_attention.h"

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/shader_helper.h"

using namespace onnxruntime::webgpu;
using namespace ::onnxruntime::common;
Expand Down Expand Up @@ -67,6 +68,51 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu
return context.RunProgram(program);
}

// Splits a packed QKV tensor into separate Q/K/V tensors while applying
// rotary position embedding to Q and K in the same GPU pass.
Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
                                            const WebgpuAttentionParameters& params,
                                            const Tensor* packedQKV,
                                            const Tensor* seqlen_k,
                                            const Tensor* cos_cache,
                                            const Tensor* sin_cache,
                                            Tensor* query,
                                            Tensor* key,
                                            Tensor* val) {
  // cos_cache dim 1 holds rotary_dim / 2 entries (one per rotary pair).
  const auto half_rotary_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);

  // One thread per unit of work. Per head that is:
  //   half_rotary_dim rotary pairs + (head_size - 2 * half_rotary_dim) copies
  //   = head_size - half_rotary_dim
  const auto per_head_work = params.head_size_ - half_rotary_dim;
  const auto thread_count = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * per_head_work);

  SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_);
  program
      .CacheHint(params.rotary_interleaved_)
      .AddInputs({{packedQKV, ProgramTensorMetadataDependency::Rank},
                  {seqlen_k, ProgramTensorMetadataDependency::Rank},
                  {cos_cache, ProgramTensorMetadataDependency::Rank},
                  {sin_cache, ProgramTensorMetadataDependency::Rank}})
      .AddOutputs({{query, ProgramTensorMetadataDependency::Rank},
                   {key, ProgramTensorMetadataDependency::Rank},
                   {val, ProgramTensorMetadataDependency::Rank}})
      .AddUniformVariables({{static_cast<uint32_t>(params.sequence_length_)},
                            {static_cast<uint32_t>(params.hidden_size_)},
                            {static_cast<uint32_t>(params.kv_hidden_size_)},
                            {static_cast<uint32_t>(params.num_heads_)},
                            {static_cast<uint32_t>(params.kv_num_heads_)},
                            {static_cast<uint32_t>(params.head_size_)},
                            {half_rotary_dim},
                            {thread_count}})
      .SetDispatchGroupSize((thread_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  return context.RunProgram(program);
}

// Fused Q/K rotary embedding
Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
const WebgpuAttentionParameters& params,
Expand Down Expand Up @@ -207,30 +253,46 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor qSplit;
Tensor kSplit;
Tensor vSplit;
if (parameters.is_packed_qkv_) {

Tensor qRotary;
Tensor kRotary;
if (parameters.is_packed_qkv_ && do_rotary_) {
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit));
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbedding(context, parameters,
query, seqlen_k,
cos_cache, sin_cache,
&qSplit, &kSplit, &vSplit));
parameters.is_packed_qkv_ = false;
parameters.qkv_format_ = Q_K_V_BSNH;
query = &qSplit;
key = &kSplit;
value = &vSplit;
}

Tensor qRotary;
Tensor kRotary;
if (do_rotary_) {
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
kRotary = context.CreateGPUTensor(key->DataType(), key->Shape());
ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters,
query, key,
seqlen_k,
cos_cache, sin_cache,
&qRotary, &kRotary));
query = &qRotary;
key = &kRotary;
} else {
// Original separate path
if (parameters.is_packed_qkv_) {
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit));
parameters.is_packed_qkv_ = false;
parameters.qkv_format_ = Q_K_V_BSNH;
query = &qSplit;
key = &kSplit;
value = &vSplit;
}
if (do_rotary_) {
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
kRotary = context.CreateGPUTensor(key->DataType(), key->Shape());
ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters,
query, key,
seqlen_k,
cos_cache, sin_cache,
&qRotary, &kRotary));
query = &qRotary;
key = &kRotary;
}
}

// Use a sliding window if the total sequence exceeds the window's length.
Expand Down
41 changes: 41 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,47 @@ class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
{"kv_hidden_size", ProgramUniformVariableDataType::Uint32});
};

// Compute program that splits a packed QKV input tensor into separate
// Q/K/V output tensors and fuses the rotary position embedding of Q and K
// into the same dispatch, avoiding a second pass over the data.
class SplitPackedQKVWithRotaryEmbeddingProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingProgram> {
 public:
  // `interleaved` selects the rotary pairing scheme:
  //   true  -> pairs are adjacent elements within the head,
  //   false -> pairs are split across the two halves of the rotary span.
  // `explicit` prevents accidental implicit bool -> program conversions.
  explicit SplitPackedQKVWithRotaryEmbeddingProgram(bool interleaved) : Program{"SplitPackedQKVWithRotaryEmbedding"}, interleaved_{interleaved} {}

  // Declares shader bindings and instantiates the WGSL template.
  Status GenerateShaderCode(ShaderHelper& sh) const override {
    // Inputs
    const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
    const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
    const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform);
    const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform);

    // Outputs
    const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform);
    const auto& key = sh.AddOutput("key", ShaderUsage::UseUniform);
    const auto& val = sh.AddOutput("val", ShaderUsage::UseUniform);

    return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding.wgsl.template",
                               WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
                               WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
                               WGSL_TEMPLATE_VARIABLE(key, key),
                               WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
                               WGSL_TEMPLATE_VARIABLE(query, query),
                               WGSL_TEMPLATE_VARIABLE(seqlens, seqlens),
                               WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache),
                               WGSL_TEMPLATE_VARIABLE(val, val));
  }

  // Uniforms must stay in sync (names and order) with the values pushed by
  // RunSplitPackedQKVWithRotaryEmbedding in group_query_attention.cc.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"sequence_length", ProgramUniformVariableDataType::Uint32},
      {"hidden_size", ProgramUniformVariableDataType::Uint32},
      {"kv_hidden_size", ProgramUniformVariableDataType::Uint32},
      {"num_heads", ProgramUniformVariableDataType::Uint32},
      {"kv_num_heads", ProgramUniformVariableDataType::Uint32},
      {"head_size", ProgramUniformVariableDataType::Uint32},
      {"half_rotary_dim", ProgramUniformVariableDataType::Uint32},
      {"dispatch_size", ProgramUniformVariableDataType::Uint32});

 private:
  const bool interleaved_;
};

class GroupQueryAttention final : public WebGpuKernel {
public:
GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#param interleaved

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .setByIndices .getByIndices .getByOffset

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size);

  // One thread per unit of work. Per head:
  //   work_per_head = half_rotary_dim (rotary pairs)
  //                 + (head_size - 2 * half_rotary_dim) (plain copies)
  //                 = head_size - half_rotary_dim
  let work_per_head = uniforms.head_size - uniforms.half_rotary_dim;
  let total_work = uniforms.num_heads * work_per_head;

  // Decompose the flat thread id into (batch, seq, head, in-head work item).
  let batch_idx = global_idx / (uniforms.sequence_length * total_work);
  let remainder1 = global_idx % (uniforms.sequence_length * total_work);
  let seq_idx = remainder1 / total_work;
  let remainder2 = remainder1 % total_work;
  let head_idx = remainder2 / work_per_head;
  let in_head_idx = remainder2 % work_per_head;

  // Base offset of this token inside packed_qkv.
  // Layout per token: [Q(hidden_size), K(kv_hidden_size), V(kv_hidden_size)]
  let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size;
  let base_offset = batch_idx * uniforms.sequence_length * token_size + seq_idx * token_size;

  if (in_head_idx < uniforms.half_rotary_dim) {
    // Rotary work item: rotate one (i, j) element pair of Q and K.
    // position_id = past sequence length + offset of this token in the batch.
    let seqlen_i = seqlens.getByOffset(batch_idx);
    let seqlen = u32(seqlen_i);
    let total_seqlen = seqlen + 1u;
    let past_seqlen = total_seqlen - uniforms.sequence_length;
    let position_id = past_seqlen + seq_idx;
    let cos_v = cos_cache.getByIndices(vec2<u32>(position_id, in_head_idx));
    let sin_v = sin_cache.getByIndices(vec2<u32>(position_id, in_head_idx));

    // Element indices of the rotary pair inside the head.
    // Interleaved pairs are adjacent elements (2k, 2k + 1); non-interleaved
    // pairs span the two halves of the rotary window (k, k + half_rotary_dim).
    // FIX: the interleaved case previously used (k, k + 1), which produced
    // overlapping pairs, double-writes, and left elements in
    // (half_rotary_dim, 2 * half_rotary_dim) unwritten.
#if interleaved
    let idx_i = in_head_idx * 2u;
    let idx_j = in_head_idx * 2u + 1u;
#else
    let idx_i = in_head_idx;
    let idx_j = in_head_idx + uniforms.half_rotary_dim;
#endif

    // Rotate the Q pair.
    let q_base = base_offset + head_idx * uniforms.head_size;
    let q_i_offset = q_base + idx_i;
    let q_j_offset = q_base + idx_j;
    let q_i = packed_qkv.getByOffset(q_i_offset);
    let q_j = packed_qkv.getByOffset(q_j_offset);
    let q_re = q_i * cos_v - q_j * sin_v;
    let q_im = q_i * sin_v + q_j * cos_v;
    query.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), q_re);
    query.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), q_im);

    // K is rotated (and V copied) only for the first kv_num_heads heads.
    if (head_idx < uniforms.kv_num_heads) {
      let k_base = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size;
      let k_i_offset = k_base + idx_i;
      let k_j_offset = k_base + idx_j;
      let k_i = packed_qkv.getByOffset(k_i_offset);
      let k_j = packed_qkv.getByOffset(k_j_offset);
      let k_re = k_i * cos_v - k_j * sin_v;
      let k_im = k_i * sin_v + k_j * cos_v;
      key.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), k_re);
      key.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), k_im);

      // V has no rotary component; copy the same pair straight through.
      let v_base = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size;
      let v_i = packed_qkv.getByOffset(v_base + idx_i);
      let v_j = packed_qkv.getByOffset(v_base + idx_j);
      val.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), v_i);
      val.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), v_j);
    }
  } else {
    // Pass-through work item: copy one element beyond the rotary span.
    // in_head_idx is in [half_rotary_dim, work_per_head), so actual_idx is
    // in [2 * half_rotary_dim, head_size) for both pairing schemes.
    let actual_idx = uniforms.half_rotary_dim + in_head_idx;

    // Copy Q
    let q_offset = base_offset + head_idx * uniforms.head_size + actual_idx;
    let q_data = packed_qkv.getByOffset(q_offset);
    query.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), q_data);

    // Copy K and V if within kv_num_heads
    if (head_idx < uniforms.kv_num_heads) {
      let k_offset = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size + actual_idx;
      let k_data = packed_qkv.getByOffset(k_offset);
      key.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), k_data);

      let v_offset = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size + actual_idx;
      let v_data = packed_qkv.getByOffset(v_offset);
      val.setByIndices(vec3<u32>(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), v_data);
    }
  }
} // MAIN
21 changes: 21 additions & 0 deletions onnxruntime/core/providers/webgpu/wgsl_templates/wgsl_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,25 @@ duk_ret_t ShaderVariable_GetByOffset(duk_context* ctx) {
return 1;
}

/** @brief JavaScript binding for ShaderVariableHelper::GetByIndices.
 *
 * JS arg 0: WGSL expression evaluating to the indices value.
 * Pushes the generated WGSL read expression back to the JS stack.
 */
duk_ret_t ShaderVariable_GetByIndices(duk_context* ctx) {
  const char* indices_expr = duk_require_string(ctx, 0);
  const auto* helper = GetHelperFromFunction<const ShaderVariableHelper>(ctx);
  const std::string wgsl = helper->GetByIndices(indices_expr);
  duk_push_string(ctx, wgsl.c_str());
  return 1;  // one return value for the JS caller
}

/** @brief JavaScript binding for ShaderVariableHelper::SetByIndices.
 *
 * JS arg 0: WGSL expression evaluating to the indices value.
 * JS arg 1: WGSL expression for the value to store.
 * Pushes the generated WGSL write statement back to the JS stack.
 */
duk_ret_t ShaderVariable_SetByIndices(duk_context* ctx) {
  const char* indices_expr = duk_require_string(ctx, 0);
  const char* value_expr = duk_require_string(ctx, 1);
  const auto* helper = GetHelperFromFunction<const ShaderVariableHelper>(ctx);
  const std::string wgsl = helper->SetByIndices(indices_expr, value_expr);
  duk_push_string(ctx, wgsl.c_str());
  return 1;  // one return value for the JS caller
}

/** @brief JavaScript binding for ShaderVariableHelper::Rank */
duk_ret_t ShaderVariable_Rank(duk_context* ctx) {
const ShaderVariableHelper* helper = GetHelperFromFunction<const ShaderVariableHelper>(ctx);
Expand Down Expand Up @@ -373,6 +392,8 @@ Status ApplyTemplateDynamic(ShaderHelper& shader_helper,
CreateShaderVariableMethod(ctx, "OffsetToIndices", ShaderVariable_OffsetToIndices, 1, var_helper);
CreateShaderVariableMethod(ctx, "SetByOffset", ShaderVariable_SetByOffset, 2, var_helper);
CreateShaderVariableMethod(ctx, "GetByOffset", ShaderVariable_GetByOffset, 1, var_helper);
CreateShaderVariableMethod(ctx, "GetByIndices", ShaderVariable_GetByIndices, 1, var_helper);
CreateShaderVariableMethod(ctx, "SetByIndices", ShaderVariable_SetByIndices, 2, var_helper);
CreateShaderVariableMethod(ctx, "Rank", ShaderVariable_Rank, 0, var_helper);
duk_put_prop_string(ctx, -2, arg.name.c_str());
}
Expand Down
Loading