@@ -379,53 +379,54 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
379379 Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
380380 const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
381381 const Tensor* cos_cache, const Tensor* sin_cache) {
382+ constexpr uint32_t tile_size = 64 ;
383+
382384 // Extract present_sequence_length directly from present_key tensor shape:
383385 // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
384386 const uint32_t present_sequence_length = static_cast <uint32_t >(present_key->Shape ()[2 ]);
385387
386388 const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled ();
387389
388- // Determine if we should use fused split packed QKV with rotary embedding based on cos_cache and sin_cache
389- const bool use_fused_split_rotary_copykv = (cos_cache != nullptr && sin_cache != nullptr );
390-
391390 // Declare query_output at function scope to ensure it persists throughout the function
392391 Tensor query_output;
393392
394393 // Create indirect dispatch buffer if using indirect dispatch
395394 Tensor* indirect_buffer_ptr = nullptr ;
396395 Tensor indirect_buffer;
397- // Handle fused split packed QKV with rotary embedding and copy KV if requested
398- if (use_fused_split_rotary_copykv) {
396+
397+ // Prepare indirect dispatch buffer for decode path with static KV cache
398+ const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
399+ parameters.past_present_share_buffer_ &&
400+ seqlen_k != nullptr &&
401+ context.IsGraphCaptureEnabled ();
402+ if (use_indirect_dispatch) {
403+ const TensorShape indirect_buffer_shape{3 }; // 3 uint32 values for dispatch dimensions
404+ indirect_buffer = context.CreateGPUTensor (DataTypeImpl::GetType<uint32_t >(), indirect_buffer_shape);
405+ indirect_buffer_ptr = &indirect_buffer;
406+ }
407+
408+ const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr );
409+
410+ if (do_rotary) {
411+ ORT_ENFORCE (parameters.is_packed_qkv_ , " Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input." );
412+ ORT_ENFORCE (parameters.past_present_share_buffer_ , " Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache." );
413+
399414 // Q points to the packed QKV tensor in this case, create query output tensor
400415 query_output = context.CreateGPUTensor (Q->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.hidden_size_ }));
401- // For decode path (sequence_length == 1), may prepare indirect dispatch if needed
402- // Prepare indirect dispatch buffer for decode path with static KV cache
403- const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
404- parameters.past_present_share_buffer_ &&
405- seqlen_k != nullptr &&
406- context.IsGraphCaptureEnabled ();
407- if (use_indirect_dispatch) {
408- const TensorShape indirect_buffer_shape{3 }; // 3 uint32 values for dispatch dimensions
409- indirect_buffer = context.CreateGPUTensor (DataTypeImpl::GetType<uint32_t >(), indirect_buffer_shape);
410- indirect_buffer_ptr = &indirect_buffer;
411- }
412416
413417 ORT_RETURN_IF_ERROR (RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV (context, parameters,
414418 Q, seqlen_k,
415419 cos_cache, sin_cache,
416420 &query_output, present_key, present_value,
417- indirect_buffer_ptr));
421+ indirect_buffer_ptr, tile_size ));
418422 Q = &query_output;
419423 K = present_key;
420424 V = present_value;
425+ } else {
426+ ORT_RETURN_IF_ERROR (CopyKVCache (context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr , use_indirect_dispatch ? indirect_buffer_ptr : nullptr ));
421427 }
422428
423429 if (parameters.sequence_length_ > 1 ) {
424- const uint32_t tile_size = 64 ;
425- // For encode path, copy KV if not using fused operation
426- if (!use_fused_split_rotary_copykv) {
427- ORT_RETURN_IF_ERROR (CopyKVCache (context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr , nullptr ));
428- }
429430 bool has_attention_bias = attention_bias != nullptr ;
430431 bool is_qualcomm = context.AdapterInfo ().vendor == std::string_view{" qualcomm" };
431432 bool is_nvidia = context.AdapterInfo ().vendor == std::string_view{" nvidia" };
@@ -470,28 +471,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
470471 parameters.sequence_length_ , present_sequence_length});
471472 const TensorShape qk_shape (qk_dims);
472473 Tensor qk = context.CreateGPUTensor (Q->DataType (), qk_shape);
473- constexpr uint32_t tile_size = 64 ;
474474 const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1 ) / tile_size;
475475 const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1 ) / tile_size;
476476
477- // Determine if we should use indirect dispatch
478- const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
479- seqlen_k != nullptr &&
480- context.IsGraphCaptureEnabled ();
481-
482- if (!use_fused_split_rotary_copykv) {
483- if (use_indirect_dispatch) {
484- const TensorShape indirect_buffer_shape{3 }; // 3 uint32 values for dispatch dimensions
485- indirect_buffer = context.CreateGPUTensor (DataTypeImpl::GetType<uint32_t >(), indirect_buffer_shape);
486- indirect_buffer_ptr = &indirect_buffer;
487- // Use the fused CopyKVCache that also prepares the indirect dispatch buffer
488- ORT_RETURN_IF_ERROR (CopyKVCache (context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
489- } else {
490- // Use the original CopyKVCache without indirect dispatch preparation
491- ORT_RETURN_IF_ERROR (CopyKVCache (context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, nullptr , nullptr ));
492- }
493- }
494-
495477 // The metadata is used to store the max and sum of each tile.
496478 const TensorShapeVector metadata_dims ({parameters.batch_size_ , parameters.num_heads_ ,
497479 num_present_sequence_length_tile, 2 });
@@ -539,7 +521,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
539521 Tensor* query,
540522 Tensor* present_key,
541523 Tensor* present_value,
542- Tensor* indirect_buffer) {
524+ Tensor* indirect_buffer,
525+ uint32_t tile_size) {
543526 const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t >(cos_cache->Shape ()[1 ]);
544527 const auto head_size = params.head_size_ ;
545528
@@ -567,8 +550,6 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
567550
568551 const bool prepare_indirect_dispatch = (indirect_buffer != nullptr );
569552
570- constexpr uint32_t tile_size = 64 ;
571-
572553 SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program (params.rotary_interleaved_ , prepare_indirect_dispatch);
573554 program
574555 .CacheHint (params.rotary_interleaved_ , prepare_indirect_dispatch)
0 commit comments