Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/bias_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h"
#include "contrib_ops/webgpu/bert/bias_gelu.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
BiasGelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
BiasGelu);

Status BiasGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform);

shader.AdditionalImplementation() << onnxruntime::webgpu::ErfImpl;
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
<< " var a = " << x.GetByOffset("global_idx") << ";\n";

// Add bias to input
if (bias_components_ == 1) {
shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n"
" a += x_value_t("
<< bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", "
<< bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", "
<< bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", "
<< bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n";
} else {
shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n";
}

// Apply GELU activation: 0.5 * a * (1.0 + erf(a * 0.7071067811865475))
shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::GeluExpr);

return Status::OK();
}

Status BiasGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const auto* input = context.Input(0);
const auto* bias = context.Input(1);
auto* output = context.Output(0, input->Shape());

uint32_t data_size = onnxruntime::narrow<uint32_t>(output->Shape().Size());
if (data_size == 0) {
return Status::OK();
}

const auto& input_shape = input->Shape();
const auto& bias_shape = bias->Shape();

// Validate inputs
if (input_shape.NumDimensions() < 1 || bias_shape.NumDimensions() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"BiasGelu: input must have at least 1 dimension and bias must be 1-dimensional.");
}

if (input_shape.GetDims().back() != bias_shape.GetDims().back()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"BiasGelu: bias must match the last dimension of input.");
}

const auto vec_size = (data_size + 3) / 4;
uint32_t bias_size = onnxruntime::narrow<uint32_t>(bias->Shape().Size());
int bias_components = 1;

if (bias_size % 4 == 0) {
bias_components = 4;
bias_size = bias_size / 4;
}

BiasGeluProgram program{bias_components};
program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components})
.AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4})
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariable({vec_size});

return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
38 changes: 38 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

Check warning on line 13 in onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/bert/bias_gelu.h:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

class BiasGeluProgram final : public Program<BiasGeluProgram> {
public:
BiasGeluProgram(int bias_components) : Program{"BiasGelu"}, bias_components_{bias_components} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

private:
int bias_components_;
};

class BiasGelu final : public WebGpuKernel {
public:
BiasGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}

Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace webgpu {

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv);
Expand Down Expand Up @@ -42,6 +43,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized)>,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/contrib_ops/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ TEST(BiasGeluTest, Float) {
RunBiasGeluTestFloat({2, 2333}, {2333});
}

#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU)
static void RunBiasGeluTestHalf(const std::vector<int64_t>& input_dims, const std::vector<int64_t>& bias_dims) {
RandomValueGenerator random{2333};
std::vector<float> input_data = random.Uniform<float>(input_dims, -1.0f, 1.0f);
Expand Down
Loading