
Commit aaecfb8

[webgpu] Support broadcast attention_bias
Fixed #26766
Parent: a83a158

File tree: 2 files changed, +33 -3 lines
  onnxruntime/contrib_ops/webgpu/bert/attention.cc
  onnxruntime/contrib_ops/webgpu/bert/attention.h

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 30 additions & 2 deletions
@@ -100,6 +100,21 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
                             << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
 
+  if (has_attention_bias_) {
+    shader.AdditionalImplementation() << "fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> f32 {\n"
+                                      << " // Handle broadcasting: if dimension size is 1, use index 0\n"
+                                      << " let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);\n"
+                                      << " let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);\n"
+                                      << " // Calculate flat offset with broadcasting applied\n"
+                                      << " // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, sequence_length, total_sequence_length]\n"
+                                      << " let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.M * uniforms.N +\n"
+                                      << " bias_head_idx * uniforms.M * uniforms.N +\n"
+                                      << " q_idx * uniforms.N +\n"
+                                      << " k_idx;\n"
+                                      << " return f32(attention_bias[offset]);\n"
+                                      << "}\n";
+  }
+
   shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
                             << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n"
                             << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n"
@@ -158,6 +173,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n"
                             << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n"
                             << " let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;\n"
+                            << " let head_idx = batch_head_idx % uniforms.num_heads;\n"
                             << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
 
   // Add causal masking for unidirectional attention
@@ -172,7 +188,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
 
   shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
   if (has_attention_bias_) {
-    shader.MainFunctionBody() << " + attention_bias[outputIdx]";
+    shader.MainFunctionBody() << " + loadAttentionBias(batch_idx, head_idx, m + local_id.y, n + local_id.x)";
   }
   shader.MainFunctionBody() << ";\n"
                             << "}\n";
@@ -214,6 +230,16 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
   const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
   const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
   program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile)
       .SetWorkgroupSize(tile_size, tile_size)
       .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.is_unidirectional_)
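Only the first two dimensions of the bias shape are forwarded to the shader; a size of 1 in either one is what enables the broadcast path. The toy values below (illustrative only, assuming the usual contrib-op layout [1 or batch_size, 1 or num_heads, sequence_length, total_sequence_length]) show what the two uniforms end up holding:

```cpp
// Illustrative only: attn_bias_dim0 / attn_bias_dim1 for typical attention_bias
// shapes with batch_size = 2 and num_heads = 12.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<std::vector<int64_t>> bias_shapes = {
      {1, 1, 128, 128},   // shared across batches and heads  -> dims (1, 1)
      {2, 1, 128, 128},   // per-batch, shared across heads   -> dims (2, 1)
      {1, 12, 128, 128},  // per-head, shared across batches  -> dims (1, 12)
      {2, 12, 128, 128},  // fully materialized, no broadcast -> dims (2, 12)
  };
  for (const auto& shape : bias_shapes) {
    const auto attn_bias_dim0 = static_cast<uint32_t>(shape[0]);
    const auto attn_bias_dim1 = static_cast<uint32_t>(shape[1]);
    std::printf("attn_bias_dim0=%u attn_bias_dim1=%u\n",
                static_cast<unsigned>(attn_bias_dim0), static_cast<unsigned>(attn_bias_dim1));
  }
  return 0;
}
```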
@@ -229,7 +255,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                              {static_cast<uint32_t>(parameters.n_reps)},
                              {static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)},
                              {num_total_seq_length_tile},
-                             {num_seq_length_tile}})
+                             {num_seq_length_tile},
+                             {attn_bias_dim0},
+                             {attn_bias_dim1}})
       .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
 
   return context.RunProgram(program);

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 3 additions & 1 deletion
@@ -53,7 +53,9 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"is_first_prompt", ProgramUniformVariableDataType::Uint32},
                                           {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
-                                          {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});
+                                          {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
 
   WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
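These declarations are what the shader reads as uniforms.attn_bias_dim0 and uniforms.attn_bias_dim1, appended at the position matching where ComputeAttentionProbs above appends the host-side values. As a rough sanity check on the arithmetic (illustrative only, toy sizes, not part of the commit), a brute-force sweep shows the broadcast offset never indexes past a fully broadcast [1, 1, M, N] bias buffer:

```cpp
// Illustrative only: exhaustive bounds check of the broadcast offset math for a
// bias broadcast over both batch and head (attn_bias_dim0 == attn_bias_dim1 == 1).
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t batch_size = 3, num_heads = 5, M = 7, N = 11;  // toy sizes
  const uint32_t attn_bias_dim0 = 1, attn_bias_dim1 = 1;        // fully broadcast bias
  const uint32_t bias_elements = attn_bias_dim0 * attn_bias_dim1 * M * N;
  for (uint32_t b = 0; b < batch_size; ++b) {
    for (uint32_t h = 0; h < num_heads; ++h) {
      for (uint32_t q = 0; q < M; ++q) {
        for (uint32_t k = 0; k < N; ++k) {
          const uint32_t bb = (b >= attn_bias_dim0) ? 0u : b;  // select(...) in the WGSL
          const uint32_t bh = (h >= attn_bias_dim1) ? 0u : h;
          const uint32_t offset = bb * attn_bias_dim1 * M * N + bh * M * N + q * N + k;
          assert(offset < bias_elements);  // stays inside the [1, 1, M, N] buffer
        }
      }
    }
  }
  return 0;
}
```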
