Skip to content

Commit 6b27fe2

Browse files
committed
Run the KV-cache copy (CopyKVCache) in advance: hoist it ahead of the attention dispatch so it executes once, before the encode/decode branching
1 parent c32777f commit 6b27fe2

File tree

3 files changed

+28
-46
lines changed

3 files changed

+28
-46
lines changed

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -379,53 +379,54 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
379379
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
380380
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
381381
const Tensor* cos_cache, const Tensor* sin_cache) {
382+
constexpr uint32_t tile_size = 64;
383+
382384
// Extract present_sequence_length directly from present_key tensor shape:
383385
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
384386
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
385387

386388
const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
387389

388-
// Determine if we should use fused split packed QKV with rotary embedding based on cos_cache and sin_cache
389-
const bool use_fused_split_rotary_copykv = (cos_cache != nullptr && sin_cache != nullptr);
390-
391390
// Declare query_output at function scope to ensure it persists throughout the function
392391
Tensor query_output;
393392

394393
// Create indirect dispatch buffer if using indirect dispatch
395394
Tensor* indirect_buffer_ptr = nullptr;
396395
Tensor indirect_buffer;
397-
// Handle fused split packed QKV with rotary embedding and copy KV if requested
398-
if (use_fused_split_rotary_copykv) {
396+
397+
// Prepare indirect dispatch buffer for decode path with static KV cache
398+
const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
399+
parameters.past_present_share_buffer_ &&
400+
seqlen_k != nullptr &&
401+
context.IsGraphCaptureEnabled();
402+
if (use_indirect_dispatch) {
403+
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
404+
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
405+
indirect_buffer_ptr = &indirect_buffer;
406+
}
407+
408+
const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr);
409+
410+
if (do_rotary) {
411+
ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input.");
412+
ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache.");
413+
399414
// Q points to the packed QKV tensor in this case, create query output tensor
400415
query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
401-
// For decode path (sequence_length == 1), may prepare indirect dispatch if needed
402-
// Prepare indirect dispatch buffer for decode path with static KV cache
403-
const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
404-
parameters.past_present_share_buffer_ &&
405-
seqlen_k != nullptr &&
406-
context.IsGraphCaptureEnabled();
407-
if (use_indirect_dispatch) {
408-
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
409-
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
410-
indirect_buffer_ptr = &indirect_buffer;
411-
}
412416

413417
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters,
414418
Q, seqlen_k,
415419
cos_cache, sin_cache,
416420
&query_output, present_key, present_value,
417-
indirect_buffer_ptr));
421+
indirect_buffer_ptr, tile_size));
418422
Q = &query_output;
419423
K = present_key;
420424
V = present_value;
425+
} else {
426+
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, use_indirect_dispatch ? indirect_buffer_ptr : nullptr));
421427
}
422428

423429
if (parameters.sequence_length_ > 1) {
424-
const uint32_t tile_size = 64;
425-
// For encode path, copy KV if not using fused operation
426-
if (!use_fused_split_rotary_copykv) {
427-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
428-
}
429430
bool has_attention_bias = attention_bias != nullptr;
430431
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
431432
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -470,28 +471,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
470471
parameters.sequence_length_, present_sequence_length});
471472
const TensorShape qk_shape(qk_dims);
472473
Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
473-
constexpr uint32_t tile_size = 64;
474474
const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
475475
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
476476

477-
// Determine if we should use indirect dispatch
478-
const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
479-
seqlen_k != nullptr &&
480-
context.IsGraphCaptureEnabled();
481-
482-
if (!use_fused_split_rotary_copykv) {
483-
if (use_indirect_dispatch) {
484-
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
485-
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
486-
indirect_buffer_ptr = &indirect_buffer;
487-
// Use the fused CopyKVCache that also prepares the indirect dispatch buffer
488-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
489-
} else {
490-
// Use the original CopyKVCache without indirect dispatch preparation
491-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, nullptr, nullptr));
492-
}
493-
}
494-
495477
// The metadata is used to store the max and sum of each tile.
496478
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
497479
num_present_sequence_length_tile, 2});
@@ -539,7 +521,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
539521
Tensor* query,
540522
Tensor* present_key,
541523
Tensor* present_value,
542-
Tensor* indirect_buffer) {
524+
Tensor* indirect_buffer,
525+
uint32_t tile_size) {
543526
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
544527
const auto head_size = params.head_size_;
545528

@@ -567,8 +550,6 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
567550

568551
const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
569552

570-
constexpr uint32_t tile_size = 64;
571-
572553
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
573554
program
574555
.CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)

onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
187187
Tensor* query,
188188
Tensor* present_key,
189189
Tensor* present_value,
190-
Tensor* indirect_buffer);
190+
Tensor* indirect_buffer,
191+
uint32_t tile_size);
191192
} // namespace webgpu
192193
} // namespace contrib
193194
} // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
292292

293293
if (parameters.is_packed_qkv_ && do_rotary_) {
294294
// Use the ultimate fused operation when FlashAttention and static KV cache is enabled. NOTE(review): the guard below was changed to `!parameters.past_present_share_buffer_`, which contradicts this comment and the `ORT_ENFORCE(parameters.past_present_share_buffer_, ...)` added inside the fused path in ApplyFlashAttention — with a non-shared buffer this branch would trip that ENFORCE at runtime. Confirm the `!` is intended.
295-
if (will_use_flash_attention && parameters.past_present_share_buffer_) {
295+
if (will_use_flash_attention && !parameters.past_present_share_buffer_) {
296296
// Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled
297297
// query points to packed QKV, K and V are nullptr since they're not needed
298298
return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value,

0 commit comments

Comments
 (0)