@@ -100,6 +100,21 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                                     << "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
                                     << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
 
+  if (has_attention_bias_) {
+    shader.AdditionalImplementation() << "fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> f32 {\n"
+                                      << "  // Handle broadcasting: if a dimension's size is 1, use index 0\n"
+                                      << "  let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);\n"
+                                      << "  let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);\n"
+                                      << "  // Calculate the flat offset with broadcasting applied\n"
+                                      << "  // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, sequence_length, total_sequence_length]\n"
+                                      << "  let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.M * uniforms.N +\n"
+                                      << "               bias_head_idx * uniforms.M * uniforms.N +\n"
+                                      << "               q_idx * uniforms.N +\n"
+                                      << "               k_idx;\n"
+                                      << "  return f32(attention_bias[offset]);\n"
+                                      << "}\n";
+  }
+
   shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
                             << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n"
                             << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n"
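
A quick way to sanity-check the broadcasting rule that loadAttentionBias encodes is to mirror the same arithmetic on the host. The sketch below is illustrative only (the helper name and the test values are not part of this change): a leading dimension of size 1 collapses its index to 0, and the flat offset is then computed against the bias tensor's own extents.

```cpp
#include <cstdint>
#include <cstdio>

// Host-side mirror of the WGSL loadAttentionBias indexing (illustrative only).
// attention_bias shape: [attn_bias_dim0, attn_bias_dim1, M, N], where dim0 is
// batch_size or 1 and dim1 is num_heads or 1.
uint32_t BiasOffset(uint32_t batch_idx, uint32_t head_idx, uint32_t q_idx, uint32_t k_idx,
                    uint32_t attn_bias_dim0, uint32_t attn_bias_dim1, uint32_t M, uint32_t N) {
  // Same test as the WGSL select(): the index can only exceed the dimension
  // when that dimension is broadcast (size 1), so it collapses to 0 in that case.
  const uint32_t bias_batch_idx = (batch_idx >= attn_bias_dim0) ? 0u : batch_idx;
  const uint32_t bias_head_idx = (head_idx >= attn_bias_dim1) ? 0u : head_idx;
  return bias_batch_idx * attn_bias_dim1 * M * N +
         bias_head_idx * M * N +
         q_idx * N +
         k_idx;
}

int main() {
  // Bias of shape [1, 1, 4, 8] read at batch 2, head 5: both leading indices
  // collapse to 0, so only q_idx and k_idx contribute to the offset.
  std::printf("%u\n", BiasOffset(2, 5, 3, 7, 1, 1, 4, 8));   // 3 * 8 + 7 = 31
  // Bias of shape [2, 12, 4, 8]: no broadcasting, full four-dimensional offset.
  std::printf("%u\n", BiasOffset(1, 5, 3, 7, 2, 12, 4, 8));  // 384 + 160 + 24 + 7 = 575
  return 0;
}
```
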
@@ -158,6 +173,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n"
                             << "  let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n"
                             << "  let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;\n"
+                            << "  let head_idx = batch_head_idx % uniforms.num_heads;\n"
                             << "  var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
 
   // Add causal masking for unidirectional attention
@@ -172,7 +188,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
 
   shader.MainFunctionBody() << "  output[outputIdx] = output_value_t(sum * uniforms.alpha)";
   if (has_attention_bias_) {
-    shader.MainFunctionBody() << " + attention_bias[outputIdx]";
+    shader.MainFunctionBody() << " + loadAttentionBias(batch_idx, head_idx, m + local_id.y, n + local_id.x)";
   }
   shader.MainFunctionBody() << ";\n"
                             << "}\n";
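
At the call site, the only new piece of state is the head index: batch_head_idx % uniforms.num_heads recovers the head within the flattened batch-head dispatch. As an illustrative example (values made up for clarity), with num_heads = 12 a batch_head_idx of 13 corresponds to head 1 in batch 1, which is the pair loadAttentionBias then collapses against attn_bias_dim0/attn_bias_dim1 when the bias is broadcast.
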
@@ -214,6 +230,16 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
   const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
   const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
   program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile)
       .SetWorkgroupSize(tile_size, tile_size)
       .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.is_unidirectional_)
@@ -229,7 +255,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)},
                             {num_total_seq_length_tile},
-                            {num_seq_length_tile}})
+                            {num_seq_length_tile},
+                            {attn_bias_dim0},
+                            {attn_bias_dim1}})
       .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
 
   return context.RunProgram(program);
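
On the host side the change reads the first two dimensions of the bias shape straight into the new uniforms, defaulting both to 1 when no bias is fed. A minimal standalone sketch of the same extraction, with an explicit rank check added for illustration (the function name and the plain-vector shape stand-in are assumptions, not code from this patch):

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative stand-in for pulling the broadcastable bias dimensions out of a
// shape. The real code reads them from the tensor's Shape(); a plain vector of
// extents models that here.
struct AttnBiasDims {
  uint32_t dim0;  // batch_size or 1
  uint32_t dim1;  // num_heads or 1
};

AttnBiasDims GetAttentionBiasDims(const std::vector<int64_t>& bias_shape) {
  // attention_bias is expected to be 4-D:
  // [batch_size or 1, num_heads or 1, sequence_length, total_sequence_length]
  if (bias_shape.size() != 4) {
    throw std::invalid_argument("attention_bias must be a 4-D tensor");
  }
  return {static_cast<uint32_t>(bias_shape[0]),
          static_cast<uint32_t>(bias_shape[1])};
}
```

When no bias input is present the dimensions stay at their default of 1, so the uniforms are always well-defined even though the shader never emits or calls loadAttentionBias in that case.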