Commit 667f22b

Merge remote-tracking branch 'origin/main' into hari/mlas_silu

2 parents: a89867d + a83a158

45 files changed: +1273 -354 lines

dockerfiles/Dockerfile.migraphx

Lines changed: 4 additions & 9 deletions
@@ -5,22 +5,17 @@
 # Dockerfile to run ONNXRuntime with MIGraphX integration
 #--------------------------------------------------------------------------
 
-FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0
+FROM rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.9.1
 
-ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
+ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=main
 
-ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH}
-
-RUN apt-get update &&\
-    apt-get install -y migraphx
-
 WORKDIR /code
 
 # Prepare onnxruntime repository & build onnxruntime
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
     /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
    cd onnxruntime && pip install --upgrade pip &&\
-    /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \
-    --skip_tests --build_wheel --use_migraphx &&\
+    /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` \
+    --config Release --parallel --skip_tests --build_wheel --use_migraphx &&\
     pip install /code/onnxruntime/build/Linux/Release/dist/*.whl

include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h

Lines changed: 4 additions & 0 deletions
@@ -7,8 +7,10 @@
  * - `kDeviceId`: Specifies the GPU device ID to use.
  * - `kHasUserComputeStream`: Indicates whether a user-provided compute stream is used.
  * - `kUserComputeStream`: Specifies the user-provided compute stream.
+ * - `kUserAuxStreamArray`: Specifies the user-provided aux stream.
  * - `kMaxWorkspaceSize`: Sets the maximum workspace size for GPU memory allocation.
  * - 'kMaxSharedMemSize': Sets the maximum amount of shared memory that TensorRT kernels are allowed to use
+ * - `kLengthAuxStreamArray`: Specifies the length/size of the auxiliary streams array (kUserAuxStreamArray). Also sets the maximum number of auxiliary streams for TensorRT execution.
  * - `kDumpSubgraphs`: Enables or disables dumping of subgraphs for debugging.
  * - `kDetailedBuildLog`: Enables or disables detailed build logs for debugging.
  * - `kProfilesMinShapes`: Specifies the minimum shapes for profiling.
@@ -24,8 +26,10 @@ namespace provider_option_names {
 constexpr const char* kDeviceId = "device_id";
 constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
 constexpr const char* kUserComputeStream = "user_compute_stream";
+constexpr const char* kUserAuxStreamArray = "user_aux_stream_array";
 constexpr const char* kMaxWorkspaceSize = "nv_max_workspace_size";
 constexpr const char* kMaxSharedMemSize = "nv_max_shared_mem_size";
+constexpr const char* kLengthAuxStreamArray = "nv_length_aux_stream_array";
 constexpr const char* kDumpSubgraphs = "nv_dump_subgraphs";
 constexpr const char* kDetailedBuildLog = "nv_detailed_build_log";
 constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes";
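
For context, a minimal sketch of how a caller might supply the two new auxiliary-stream options through ONNX Runtime's generic key/value provider-options interface in the C++ API. The provider name "NvTensorRtRtx" and the address-as-string encoding (mirroring the existing user_compute_stream convention) are assumptions for illustration and are not part of this commit.

// Hypothetical usage sketch, not from this commit.
#include <onnxruntime_cxx_api.h>
#include <cuda_runtime.h>

#include <array>
#include <cstdint>
#include <string>
#include <unordered_map>

int main() {
  std::array<cudaStream_t, 2> aux_streams{};
  for (auto& s : aux_streams) {
    cudaStreamCreate(&s);
  }

  std::unordered_map<std::string, std::string> nv_options{
      {"device_id", "0"},
      // Caller-owned array of auxiliary streams, passed as a string-encoded address (assumed convention).
      {"user_aux_stream_array", std::to_string(reinterpret_cast<uintptr_t>(aux_streams.data()))},
      // Number of entries in the array; per the header comment, this also caps TensorRT's auxiliary streams.
      {"nv_length_aux_stream_array", std::to_string(aux_streams.size())}};

  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider("NvTensorRtRtx", nv_options);
  // ... create an Ort::Session with session_options as usual ...
  return 0;
}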

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 14 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "contrib_ops/webgpu/bert/attention.h"
 
 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#include "contrib_ops/webgpu/bert/flash_attention.h"
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
@@ -736,6 +737,19 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
   // Compute Q, K, V from input, weights, and bias
   ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));
 
+  // Check if we can use flash attention
+  // For Attention operator, we need to create present_key and present_value tensors for flash attention
+  // even though they are not exposed as outputs
+  TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.num_heads_,
+                                      parameters.total_sequence_length_, parameters.head_size_});
+  Tensor present_key = context.CreateGPUTensor(input->DataType(), present_kv_shape);
+  Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);
+
+  if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
+    return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
+                               parameters, context, nullptr);
+  }
+
   // Apply the actual attention computation
   return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
                         /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 52 additions & 19 deletions
@@ -76,7 +76,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   } else {
     shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n";
   }
-  shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
+  shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
+  if (past_present_share_buffer_) {
+    shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n";
+  } else {
+    shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n";
+  }
 
   // Add indirect dispatch logic for thread 0
   if (prepare_indirect_dispatch_) {
@@ -93,8 +98,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_past_) {
     const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
     shader.AddInput("past_value", ShaderUsage::UseUniform);
-    shader.MainFunctionBody() << "let present_offset = global_idx;"
-                              << "if (sequence_id < past_sequence_length) {\n"
+    shader.MainFunctionBody() << "if (sequence_id < past_sequence_length) {\n"
                               << " let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n"
                               << " " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n"
                               << " " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n"
@@ -104,8 +108,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"
                               << "}";
   } else {
-    shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"
-                              << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
+    shader.MainFunctionBody() << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
                               << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n"
                               << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n";
   }
@@ -134,10 +137,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
   bool use_seqlen_k = (seqlen_k != nullptr);
-
-  CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
+  bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH;
+  CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_,
                              prepare_indirect_dispatch, use_seqlen_k};
-  if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
+  if (kv_BNSH) {
     program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                        {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
   } else {
@@ -207,6 +210,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
                                 WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                                 WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
                                 WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
+                                WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
                                 WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
                                 WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
                                 WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
@@ -256,10 +260,20 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
                      {metadata, ProgramTensorMetadataDependency::Rank, 2}});
 
   const uint32_t vectorized_head_size = parameters.head_size_ / components;
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
   if (use_indirect_dispatch) {
     program.SetIndirectDispatchTensor(indirect_buffer);
   } else {
-    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+    program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
  }
   program.SetWorkgroupSize(64)
       .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
@@ -269,7 +283,10 @@
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {num_present_sequence_length_tile},
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {static_cast<uint32_t>(parameters.num_heads_)},
+                            {static_cast<uint32_t>(parameters.batch_size_)},
+                            {attn_bias_dim0},
+                            {attn_bias_dim1}});
 
   return context.RunProgram(program);
 }
@@ -313,11 +330,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
   program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});  // [B, N, split_k, head_size]
+  const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
   if (use_indirect_dispatch) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
         .SetIndirectDispatchTensor(indirect_buffer);
   } else {
-    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+    program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
   }
   program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
       .SetWorkgroupSize(64)
@@ -326,7 +344,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_present_sequence_length_tile,
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {batch_heads}});
 
   return context.RunProgram(program);
 }
@@ -363,14 +381,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
   }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
+  const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
+  program.SetDispatchGroupSize(batch_heads * num_head_size_tile)
       .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
                             num_total_seq_length_tile,
                             num_present_sequence_length_tile,
                             {num_head_size_tile},
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {batch_heads}});
 
   return context.RunProgram(program);
 }
@@ -429,6 +448,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
   bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
   bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+  bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
   FlashAttentionProgram program{"FlashAttention",
                                 has_attention_bias,
                                 is_qualcomm,
@@ -437,6 +457,7 @@
                                 parameters.num_heads_,
                                 parameters.is_unidirectional_,
                                 is_nvidia,
+                                q_BNSH,
                                 use_seqlen_k};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
@@ -451,15 +472,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
   const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
+  program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(present_sequence_length)},
+                            {static_cast<uint32_t>(parameters.batch_size_)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {alpha},
-                            {num_seq_tile}});
+                            {num_seq_tile},
+                            {attn_bias_dim0},
+                            {attn_bias_dim1}});
 
   return context.RunProgram(program);
 }
@@ -500,8 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
 
 bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                             const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
-  return parameters.batch_size_ == 1 &&
-         !parameters.is_packed_qkv_ &&
+  return !parameters.is_packed_qkv_ &&
         parameters.head_size_ == parameters.v_head_size_ &&
         bias == nullptr &&
         context.HasFeature(wgpu::FeatureName::Subgroups) &&
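
The two new attn_bias_dim0 / attn_bias_dim1 uniforms carry the leading dimensions of the attention bias so the shader can broadcast a bias of shape [batch or 1, num_heads or 1, seq, total_seq] across batch and head. A rough host-side illustration of that indexing follows, assuming standard ONNX attention-bias broadcasting semantics; the helper and its names are hypothetical, not the WGSL from this commit.

#include <cstdint>

// Illustrative sketch of the broadcast indexing enabled by attn_bias_dim0/attn_bias_dim1
// (hypothetical helper, not the shader code in this commit).
uint32_t BiasOffset(uint32_t batch, uint32_t head, uint32_t q_pos, uint32_t k_pos,
                    uint32_t attn_bias_dim0, uint32_t attn_bias_dim1,
                    uint32_t sequence_length, uint32_t total_sequence_length) {
  // A size-1 leading dimension reuses the same bias slice for every batch (dim0)
  // or every head (dim1).
  const uint32_t b = (attn_bias_dim0 == 1) ? 0u : batch;
  const uint32_t n = (attn_bias_dim1 == 1) ? 0u : head;
  return ((b * attn_bias_dim1 + n) * sequence_length + q_pos) * total_sequence_length + k_pos;
}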
