@@ -76,7 +76,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
7676 } else {
7777 shader.MainFunctionBody () << " let total_seq_length = uniforms.total_sequence_length;\n " ;
7878 }
79- shader.MainFunctionBody () << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n " ;
79+ shader.MainFunctionBody () << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n " ;
80+ if (past_present_share_buffer_) {
81+ shader.MainFunctionBody () << " let present_offset = " << present_key.IndicesToOffset (" present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)" ) << " ;\n " ;
82+ } else {
83+ shader.MainFunctionBody () << " let present_offset = " << present_key.IndicesToOffset (" present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)" ) << " ;\n " ;
84+ }
8085
8186 // Add indirect dispatch logic for thread 0
8287 if (prepare_indirect_dispatch_) {
@@ -93,8 +98,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
9398 if (has_past_) {
9499 const auto & past_key = shader.AddInput (" past_key" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
95100 shader.AddInput (" past_value" , ShaderUsage::UseUniform);
96- shader.MainFunctionBody () << " let present_offset = global_idx;"
97- << " if (sequence_id < past_sequence_length) {\n "
101+ shader.MainFunctionBody () << " if (sequence_id < past_sequence_length) {\n "
98102 << " let pastOffset = " << past_key.IndicesToOffset (" past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)" ) << " ;\n "
99103 << " " << present_key.SetByOffset (" present_offset" , " past_key[pastOffset]" ) << " ;\n "
100104 << " " << present_value.SetByOffset (" present_offset" , " past_value[pastOffset]" ) << " ;\n "
@@ -104,8 +108,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
104108 << " " << present_value.SetByOffset (" present_offset" , " value[offset]" ) << " ;\n "
105109 << " }" ;
106110 } else {
107- shader.MainFunctionBody () << " let present_offset = " << present_key.IndicesToOffset (" present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)" ) << " ;\n "
108- << " let offset = " << key.IndicesToOffset (kv_BNSH_ ? " key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : " key_indices_t(batch, sequence_id, num_head_id, head_size_id)" ) << " ;\n "
111+ shader.MainFunctionBody () << " let offset = " << key.IndicesToOffset (kv_BNSH_ ? " key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : " key_indices_t(batch, sequence_id, num_head_id, head_size_id)" ) << " ;\n "
109112 << " " << present_key.SetByOffset (" present_offset" , " key[offset]" ) << " ;\n "
110113 << " " << present_value.SetByOffset (" present_offset" , " value[offset]" ) << " ;\n " ;
111114 }
@@ -134,10 +137,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
134137 // Determine if we need to prepare indirect dispatch
135138 bool prepare_indirect_dispatch = (indirect_buffer != nullptr );
136139 bool use_seqlen_k = (seqlen_k != nullptr );
137-
138- CopyKVCacheProgram program{" CopyKVCache" , has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH ,
140+ bool kv_BNSH = parameters. qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters. qkv_format_ == Q_K_V_BNSH;
141+ CopyKVCacheProgram program{" CopyKVCache" , has_past, kv_BNSH, parameters.past_present_share_buffer_ ,
139142 prepare_indirect_dispatch, use_seqlen_k};
140- if (parameters. qkv_format_ == Q_K_V_BSNH_BNSH_BNSH ) {
143+ if (kv_BNSH ) {
141144 program.AddInputs ({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
142145 {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
143146 } else {
@@ -207,6 +210,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
207210 WGSL_TEMPLATE_PARAMETER (is_qualcomm, is_qualcomm_),
208211 WGSL_TEMPLATE_PARAMETER (is_unidirectional, is_unidirectional_),
209212 WGSL_TEMPLATE_PARAMETER (prefer_subgroupshuffle, !is_nvidia_),
213+ WGSL_TEMPLATE_PARAMETER (q_BNSH, q_BNSH_),
210214 WGSL_TEMPLATE_PARAMETER (qkv_head_size, qkv_head_size_),
211215 WGSL_TEMPLATE_PARAMETER (qkv_num_heads, qkv_num_heads_),
212216 WGSL_TEMPLATE_PARAMETER (use_seqlen_k, use_seqlen_k_));
@@ -256,10 +260,20 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
256260 {metadata, ProgramTensorMetadataDependency::Rank, 2 }});
257261
258262 const uint32_t vectorized_head_size = parameters.head_size_ / components;
263+
264+ // Get attention bias dimensions for broadcasting
265+ uint32_t attn_bias_dim0 = 1 ;
266+ uint32_t attn_bias_dim1 = 1 ;
267+ if (has_attention_bias) {
268+ const auto & bias_shape = attention_bias->Shape ();
269+ attn_bias_dim0 = static_cast <uint32_t >(bias_shape[0 ]);
270+ attn_bias_dim1 = static_cast <uint32_t >(bias_shape[1 ]);
271+ }
272+
259273 if (use_indirect_dispatch) {
260274 program.SetIndirectDispatchTensor (indirect_buffer);
261275 } else {
262- program.SetDispatchGroupSize (parameters.num_heads_ * num_total_seq_length_tile);
276+ program.SetDispatchGroupSize (parameters.batch_size_ * parameters. num_heads_ * num_total_seq_length_tile);
263277 }
264278 program.SetWorkgroupSize (64 )
265279 .CacheHint (tile_size, has_attention_bias, use_indirect_dispatch)
@@ -269,7 +283,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
269283 present_sequence_length,
270284 {static_cast <uint32_t >(parameters.n_reps )},
271285 {num_present_sequence_length_tile},
272- {static_cast <uint32_t >(parameters.num_heads_ )}});
286+ {static_cast <uint32_t >(parameters.num_heads_ )},
287+ {static_cast <uint32_t >(parameters.batch_size_ )},
288+ {attn_bias_dim0},
289+ {attn_bias_dim1}});
273290
274291 return context.RunProgram (program);
275292}
@@ -313,11 +330,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
313330 {qk, ProgramTensorMetadataDependency::TypeAndRank},
314331 {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
315332 program.AddOutputs ({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size]
333+ const uint32_t batch_heads = static_cast <uint32_t >(parameters.batch_size_ * parameters.num_heads_ );
316334 if (use_indirect_dispatch) {
317335 program.AddInput ({seqlen_k, ProgramTensorMetadataDependency::None})
318336 .SetIndirectDispatchTensor (indirect_buffer);
319337 } else {
320- program.SetDispatchGroupSize (parameters. num_heads_ * num_total_seq_length_tile);
338+ program.SetDispatchGroupSize (batch_heads * num_total_seq_length_tile);
321339 }
322340 program.CacheHint (tile_size, head_size_vec, use_indirect_dispatch)
323341 .SetWorkgroupSize (64 )
@@ -326,7 +344,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
326344 present_sequence_length,
327345 {static_cast <uint32_t >(parameters.n_reps )},
328346 num_present_sequence_length_tile,
329- {static_cast < uint32_t >(parameters. num_heads_ ) }});
347+ {batch_heads }});
330348
331349 return context.RunProgram (program);
332350}
@@ -363,14 +381,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
363381 }
364382 program.AddOutputs ({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
365383 const uint32_t num_head_size_tile = static_cast <uint32_t >((parameters.v_head_size_ + tile_head_size - 1 ) / tile_head_size);
366- program.SetDispatchGroupSize (parameters.num_heads_ * num_head_size_tile)
384+ const uint32_t batch_heads = static_cast <uint32_t >(parameters.batch_size_ * parameters.num_heads_ );
385+ program.SetDispatchGroupSize (batch_heads * num_head_size_tile)
367386 .CacheHint (tile_size, seq_tile_size, use_indirect_dispatch)
368387 .SetWorkgroupSize (tile_size * tile_size)
369388 .AddUniformVariables ({{static_cast <uint32_t >(parameters.v_head_size_ / components)},
370389 num_total_seq_length_tile,
371390 num_present_sequence_length_tile,
372391 {num_head_size_tile},
373- {static_cast < uint32_t >(parameters. num_heads_ ) }});
392+ {batch_heads }});
374393
375394 return context.RunProgram (program);
376395}
@@ -429,6 +448,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
429448 bool is_qualcomm = context.AdapterInfo ().vendor == std::string_view{" qualcomm" };
430449 bool is_nvidia = context.AdapterInfo ().vendor == std::string_view{" nvidia" };
431450 bool is_fp16 = (Q->GetElementType () == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
451+ bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
432452 FlashAttentionProgram program{" FlashAttention" ,
433453 has_attention_bias,
434454 is_qualcomm,
@@ -437,6 +457,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
437457 parameters.num_heads_ ,
438458 parameters.is_unidirectional_ ,
439459 is_nvidia,
460+ q_BNSH,
440461 use_seqlen_k};
441462 program.AddInputs ({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4 },
442463 {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4 },
@@ -451,15 +472,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
451472 const float alpha = parameters.scale_ == 0 .0f ? 1 .f / sqrt (static_cast <float >(parameters.head_size_ ))
452473 : parameters.scale_ ;
453474 const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1 ) / tile_size;
454- program.SetDispatchGroupSize (parameters.num_heads_ * num_seq_tile)
475+
476+ // Get attention bias dimensions for broadcasting
477+ uint32_t attn_bias_dim0 = 1 ;
478+ uint32_t attn_bias_dim1 = 1 ;
479+ if (has_attention_bias) {
480+ const auto & bias_shape = attention_bias->Shape ();
481+ attn_bias_dim0 = static_cast <uint32_t >(bias_shape[0 ]);
482+ attn_bias_dim1 = static_cast <uint32_t >(bias_shape[1 ]);
483+ }
484+
485+ program.SetDispatchGroupSize (parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
455486 .SetWorkgroupSize (tile_size)
456- .CacheHint (has_attention_bias, parameters.head_size_ , parameters.num_heads_ , parameters.is_unidirectional_ , is_qualcomm, is_nvidia, use_seqlen_k)
487+ .CacheHint (has_attention_bias, parameters.head_size_ , parameters.num_heads_ , parameters.is_unidirectional_ , is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
457488 .AddUniformVariables ({{static_cast <uint32_t >(parameters.sequence_length_ )},
458489 {static_cast <uint32_t >(parameters.total_sequence_length_ )},
459490 {static_cast <uint32_t >(present_sequence_length)},
491+ {static_cast <uint32_t >(parameters.batch_size_ )},
460492 {static_cast <uint32_t >(parameters.n_reps )},
461493 {alpha},
462- {num_seq_tile}});
494+ {num_seq_tile},
495+ {attn_bias_dim0},
496+ {attn_bias_dim1}});
463497
464498 return context.RunProgram (program);
465499 }
@@ -500,8 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
500534
501535bool CanApplyFlashAttention (const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
502536 const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
503- return parameters.batch_size_ == 1 &&
504- !parameters.is_packed_qkv_ &&
537+ return !parameters.is_packed_qkv_ &&
505538 parameters.head_size_ == parameters.v_head_size_ &&
506539 bias == nullptr &&
507540 context.HasFeature (wgpu::FeatureName::Subgroups) &&
0 commit comments