Skip to content

Commit d946ff6

Browse files
fs-eireankitm3k
authored andcommitted
[webgpu] Add implementation of BiasGelu (microsoft#26560)
### Description Add implementation of BiasGelu
1 parent 781e3b7 commit d946ff6

File tree

4 files changed

+136
-1
lines changed

4 files changed

+136
-1
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/shader_helper.h"
5+
#include "core/providers/webgpu/webgpu_supported_types.h"
6+
#include "core/providers/webgpu/math/unary_elementwise_ops.h"
7+
#include "contrib_ops/webgpu/bert/bias_gelu.h"
8+
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
9+
10+
namespace onnxruntime {
11+
namespace contrib {
12+
namespace webgpu {
13+
14+
ONNX_OPERATOR_KERNEL_EX(
15+
BiasGelu,
16+
kMSDomain,
17+
1,
18+
kWebGpuExecutionProvider,
19+
(*KernelDefBuilder::Create())
20+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
21+
BiasGelu);
22+
23+
Status BiasGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
24+
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
25+
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
26+
const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform);
27+
28+
shader.AdditionalImplementation() << onnxruntime::webgpu::ErfImpl;
29+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
30+
<< " var a = " << x.GetByOffset("global_idx") << ";\n";
31+
32+
// Add bias to input
33+
if (bias_components_ == 1) {
34+
shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n"
35+
" a += x_value_t("
36+
<< bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", "
37+
<< bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", "
38+
<< bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", "
39+
<< bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n";
40+
} else {
41+
shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n";
42+
}
43+
44+
// Apply GELU activation: 0.5 * a * (1.0 + erf(a * 0.7071067811865475))
45+
shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::GeluExpr);
46+
47+
return Status::OK();
48+
}
49+
50+
Status BiasGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
51+
const auto* input = context.Input(0);
52+
const auto* bias = context.Input(1);
53+
auto* output = context.Output(0, input->Shape());
54+
55+
uint32_t data_size = onnxruntime::narrow<uint32_t>(output->Shape().Size());
56+
if (data_size == 0) {
57+
return Status::OK();
58+
}
59+
60+
const auto& input_shape = input->Shape();
61+
const auto& bias_shape = bias->Shape();
62+
63+
// Validate inputs
64+
if (input_shape.NumDimensions() < 1 || bias_shape.NumDimensions() != 1) {
65+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
66+
"BiasGelu: input must have at least 1 dimension and bias must be 1-dimensional.");
67+
}
68+
69+
if (input_shape.GetDims().back() != bias_shape.GetDims().back()) {
70+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
71+
"BiasGelu: bias must match the last dimension of input.");
72+
}
73+
74+
const auto vec_size = (data_size + 3) / 4;
75+
uint32_t bias_size = onnxruntime::narrow<uint32_t>(bias->Shape().Size());
76+
int bias_components = 1;
77+
78+
if (bias_size % 4 == 0) {
79+
bias_components = 4;
80+
bias_size = bias_size / 4;
81+
}
82+
83+
BiasGeluProgram program{bias_components};
84+
program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
85+
.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components})
86+
.AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4})
87+
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
88+
.AddUniformVariable({vec_size});
89+
90+
return context.RunProgram(program);
91+
}
92+
93+
} // namespace webgpu
94+
} // namespace contrib
95+
} // namespace onnxruntime
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/program.h"
7+
#include "core/providers/webgpu/webgpu_kernel.h"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace webgpu {
12+
13+
using namespace onnxruntime::webgpu;
14+
using onnxruntime::webgpu::ComputeContext;
15+
16+
class BiasGeluProgram final : public Program<BiasGeluProgram> {
17+
public:
18+
BiasGeluProgram(int bias_components) : Program{"BiasGelu"}, bias_components_{bias_components} {
19+
}
20+
21+
Status GenerateShaderCode(ShaderHelper& sh) const override;
22+
23+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});
24+
25+
private:
26+
int bias_components_;
27+
};
28+
29+
class BiasGelu final : public WebGpuKernel {
30+
public:
31+
BiasGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}
32+
33+
Status ComputeInternal(ComputeContext& context) const override;
34+
};
35+
36+
} // namespace webgpu
37+
} // namespace contrib
38+
} // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace webgpu {
1212

1313
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention);
1414
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd);
15+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu);
1516
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu);
1617
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu);
1718
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv);
@@ -42,6 +43,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable
4243
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
4344
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
4445
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
46+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu)>,
4547
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
4648
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
4749
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized)>,

onnxruntime/test/contrib_ops/element_wise_ops_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ TEST(BiasGeluTest, Float) {
109109
RunBiasGeluTestFloat({2, 2333}, {2333});
110110
}
111111

112-
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
112+
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU)
113113
static void RunBiasGeluTestHalf(const std::vector<int64_t>& input_dims, const std::vector<int64_t>& bias_dims) {
114114
RandomValueGenerator random{2333};
115115
std::vector<float> input_data = random.Uniform<float>(input_dims, -1.0f, 1.0f);

0 commit comments

Comments
 (0)