Skip to content

Commit edfb6f5

Browse files
fs-eireCopilot
andauthored
[webgpu] revise implementation of buffer split support (#26429)
### Description This PR addresses a few concerns: - revert `const ProgramBase&` -> `ProgramBase&`: this itself is not doing something wrong but gives much more pressure for who reads the code to understand whether/where the program object is modified. It also can introduce further unexpected modifications to the program object (for example the indirect dispatch code) - change bool option `"ep.webgpuexecutionprovider.smallStorageBufferBindingSizeForTesting"` to `"ep.webgpuexecutionprovider.maxStorageBufferBindingSize"` so now it's possible to set any value in option. (setting to <128MB will cause an assert failure) - segments are optional in cache key (only present when not equals to 1, which is the common case) - avoid some unnecessary API calls, which is OK for native but may affect web perf. - clean up the code a little bit and add a few comments --------- Co-authored-by: Copilot <[email protected]>
1 parent 4994ccc commit edfb6f5

File tree

14 files changed

+136
-86
lines changed

14 files changed

+136
-86
lines changed

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class ComputeContext {
123123
//
124124
// Run a compute shader program.
125125
//
126-
inline Status RunProgram(ProgramBase& program) {
126+
inline Status RunProgram(const ProgramBase& program) {
127127
return webgpu_context_.Run(*this, program);
128128
}
129129

onnxruntime/core/providers/webgpu/program.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t
362362

363363
ProgramBase& ProgramBase::SetIndirectDispatchTensor(const Tensor* indirect_dispatch_tensor) {
364364
indirect_dispatch_tensor_ = indirect_dispatch_tensor;
365+
AddInput({indirect_dispatch_tensor, ProgramTensorMetadataDependency::None});
365366
return *this;
366367
}
367368

onnxruntime/core/providers/webgpu/program.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ struct ProgramInput {
226226
ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component);
227227

228228
const Tensor* tensor;
229-
uint32_t segments = 1;
230229
ProgramTensorMetadataDependency dependency;
231230
ProgramVariableDataType var_type;
232231
bool use_override_shape;
@@ -246,7 +245,6 @@ struct ProgramOutput {
246245
ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component);
247246

248247
Tensor* tensor;
249-
uint32_t segments = 1;
250248
ProgramTensorMetadataDependency dependency;
251249
ProgramVariableDataType var_type;
252250
bool is_atomic;
@@ -348,18 +346,6 @@ class ProgramBase {
348346
inline const ProgramMetadata& Metadata() const { return metadata_; }
349347
inline const std::string& CacheHint() const { return cache_hint_; }
350348
inline const std::vector<ProgramInput>& Inputs() const { return inputs_; }
351-
inline void setSegmentsForInput(size_t index, uint32_t segments) {
352-
if (index >= inputs_.size()) {
353-
throw std::out_of_range("input index out of range");
354-
}
355-
inputs_[index].segments = segments;
356-
}
357-
inline void setSegmentsForOutput(size_t index, uint32_t segments) {
358-
if (index >= outputs_.size()) {
359-
throw std::out_of_range("output index out of range");
360-
}
361-
outputs_[index].segments = segments;
362-
}
363349
inline const std::vector<ProgramOutput>& Outputs() const { return outputs_; }
364350
inline const std::vector<TensorShape>& Indices() const { return indices_; }
365351
inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; }

onnxruntime/core/providers/webgpu/program_cache_key.cc

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ namespace webgpu {
1717

1818
namespace {
1919
// append the info of an input or output to the cachekey
20-
void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,
21-
bool& first, uint32_t segments = 1) {
20+
void AppendTensorInfo(std::ostream& ss,
21+
const TensorShape& tensor_shape,
22+
ProgramVariableDataType var_type,
23+
ProgramTensorMetadataDependency dependency,
24+
bool& first,
25+
uint32_t segments) {
2226
if (first) {
2327
first = false;
2428
} else {
@@ -34,7 +38,9 @@ void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, Program
3438
ss << ';';
3539
}
3640

37-
ss D("Segs=") << segments << ';';
41+
if (segments != 1) {
42+
ss D("Segs=") << segments << ';';
43+
}
3844

3945
if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) {
4046
ss D("Dims=") << tensor_shape.ToString();
@@ -44,7 +50,10 @@ void AppendTensorInfo(std::ostream& ss, const TensorShape& tensor_shape, Program
4450
}
4551
} // namespace
4652

47-
std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) {
53+
std::string CalculateProgramCacheKey(const ProgramBase& program,
54+
std::span<uint32_t> inputs_segments,
55+
std::span<uint32_t> outputs_segments,
56+
bool is_1d_dispatch) {
4857
SS(ss, kStringInitialSizeCacheKey);
4958

5059
// final key format:
@@ -56,7 +65,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp
5665
// <UNIFORMS> = <UNIFORMS_INFO_0>|<UNIFORMS_INFO_1>|...
5766
// <UNIFORMS_INFO_i> = <UNIFORM_LENGTH>
5867
// <INPUTS_INFO> = <INPUTS_INFO_0>|<INPUTS_INFO_1>|...
59-
// <INPUTS_INFO_i> = <TENSOR_ELEMENT_TYPE_OR_EMPTY>;<TENSOR_SHAPE_OR_RANK_OR_EMPTY>
68+
// <INPUTS_INFO_i> = <TENSOR_ELEMENT_TYPE_OR_EMPTY>;<TENSOR_SEGMENTS_OR_EMPTY>;<TENSOR_SHAPE_OR_RANK_OR_EMPTY>
6069
ss << program.Name();
6170

6271
// append custom cache hint if any
@@ -98,19 +107,26 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp
98107

99108
ss << ":" D("Inputs=");
100109
first = true;
101-
for (const auto& input : program.Inputs()) {
102-
AppendTensorInfo(ss, input.use_override_shape ? input.override_shape : input.tensor->Shape(), input.var_type, input.dependency, first, input.segments);
110+
for (size_t i = 0; i < program.Inputs().size(); i++) {
111+
const auto& input = program.Inputs()[i];
112+
AppendTensorInfo(ss,
113+
input.use_override_shape ? input.override_shape : input.tensor->Shape(),
114+
input.var_type,
115+
input.dependency,
116+
first,
117+
inputs_segments[i]);
103118
}
104119

105120
ss << ":" D("Outputs=");
106121
first = true;
107-
for (const auto& output : program.Outputs()) {
122+
for (size_t i = 0; i < program.Outputs().size(); i++) {
123+
const auto& output = program.Outputs()[i];
108124
AppendTensorInfo(ss,
109125
output.use_override_shape ? output.override_shape : output.tensor->Shape(),
110126
output.var_type,
111127
output.dependency,
112128
first,
113-
output.segments);
129+
outputs_segments[i]);
114130
}
115131

116132
if (!program.Indices().empty()) {

onnxruntime/core/providers/webgpu/program_cache_key.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33

44
#pragma once
55

6+
#include <span>
67
#include <string>
78

89
#include "core/providers/webgpu/program.h"
910

1011
namespace onnxruntime {
1112
namespace webgpu {
1213

13-
std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch);
14+
std::string CalculateProgramCacheKey(const ProgramBase& program,
15+
std::span<uint32_t> inputs_segments,
16+
std::span<uint32_t> outputs_segments,
17+
bool is_1d_dispatch);
1418

1519
} // namespace webgpu
1620
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/program_manager.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,35 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint
3939
return Status::OK();
4040
}
4141

42-
Status ProgramManager::CalculateSegmentsForInputsAndOutputs(ProgramBase& program) {
42+
Status ProgramManager::CalculateSegmentsForInputsAndOutputs(const ProgramBase& program, std::vector<uint32_t>& inputs_segments, std::vector<uint32_t>& outputs_segments) const {
43+
inputs_segments.resize(program.Inputs().size(), 1);
44+
outputs_segments.resize(program.Outputs().size(), 1);
45+
4346
const uint64_t maxStorageBufferBindingSize = webgpu_context_.DeviceLimits().maxStorageBufferBindingSize;
4447

4548
// Inputs
4649
for (size_t i = 0; i < program.Inputs().size(); ++i) {
4750
const auto& input = program.Inputs()[i];
4851
if (input.tensor && input.tensor->SizeInBytes() > maxStorageBufferBindingSize) {
4952
uint32_t segments = static_cast<uint32_t>((input.tensor->SizeInBytes() + maxStorageBufferBindingSize - 1) / maxStorageBufferBindingSize);
50-
program.setSegmentsForInput(i, segments);
53+
inputs_segments[i] = segments;
5154
}
5255
}
5356
// Outputs
5457
for (size_t i = 0; i < program.Outputs().size(); ++i) {
5558
const auto& output = program.Outputs()[i];
5659
if (output.tensor && output.tensor->SizeInBytes() > maxStorageBufferBindingSize) {
5760
uint32_t segments = static_cast<uint32_t>((output.tensor->SizeInBytes() + maxStorageBufferBindingSize - 1) / maxStorageBufferBindingSize);
58-
program.setSegmentsForOutput(i, segments);
61+
outputs_segments[i] = segments;
5962
}
6063
}
6164
return Status::OK();
6265
}
6366

6467
Status ProgramManager::Build(const ProgramBase& program,
6568
const ProgramMetadata& program_metadata,
69+
const std::span<uint32_t> inputs_segments,
70+
const std::span<uint32_t> outputs_segments,
6671
#ifndef NDEBUG // if debug build
6772
const std::string& program_key,
6873
#endif
@@ -74,6 +79,8 @@ Status ProgramManager::Build(const ProgramBase& program,
7479
auto& device = webgpu_context_.Device();
7580
ShaderHelper shader_helper{program,
7681
program_metadata,
82+
inputs_segments,
83+
outputs_segments,
7784
device,
7885
webgpu_context_.DeviceLimits(),
7986
normalized_dispatch_x,
@@ -83,8 +90,10 @@ Status ProgramManager::Build(const ProgramBase& program,
8390

8491
ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper));
8592

86-
// Finalize inputs after GenerateShaderCode() to ensure indirect buffer is added as the last input
87-
shader_helper.FinalizeInputs();
93+
// Add indirect buffer as the last shader input when using indirect dispatch.
94+
if (program.IndirectDispatchTensor() != nullptr) {
95+
shader_helper.AddInput("indirect_buffer", ShaderUsage::None);
96+
}
8897

8998
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs());
9099
ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs());

onnxruntime/core/providers/webgpu/program_manager.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <span>
67
#include <string>
78
#include <unordered_map>
89

@@ -38,10 +39,12 @@ class ProgramManager {
3839
ProgramManager(WebGpuContext& webgpu_context) : webgpu_context_(webgpu_context) {}
3940

4041
Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const;
41-
Status CalculateSegmentsForInputsAndOutputs(ProgramBase& program);
42+
Status CalculateSegmentsForInputsAndOutputs(const ProgramBase& program, std::vector<uint32_t>& inputs_segments, std::vector<uint32_t>& outputs_segments) const;
4243

4344
Status Build(const ProgramBase& program,
4445
const ProgramMetadata& metadata,
46+
const std::span<uint32_t> inputs_segments,
47+
const std::span<uint32_t> outputs_segments,
4548
#ifndef NDEBUG // if debug build
4649
const std::string& program_key,
4750
#endif

onnxruntime/core/providers/webgpu/shader_helper.cc

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ namespace webgpu {
1818

1919
ShaderHelper::ShaderHelper(const ProgramBase& program,
2020
const ProgramMetadata& program_metadata,
21+
const std::span<uint32_t> inputs_segments,
22+
const std::span<uint32_t> outputs_segments,
2123
const wgpu::Device& device,
2224
const wgpu::Limits& limits,
2325
uint32_t dispatch_group_size_x,
2426
uint32_t dispatch_group_size_y,
2527
uint32_t dispatch_group_size_z)
2628
: device_{device},
2729
limits_{limits},
30+
inputs_segments_{inputs_segments},
31+
outputs_segments_{outputs_segments},
2832
dispatch_group_size_x_{dispatch_group_size_x},
2933
dispatch_group_size_y_{dispatch_group_size_y},
3034
dispatch_group_size_z_{dispatch_group_size_z},
@@ -95,21 +99,14 @@ Status ShaderHelper::Init() {
9599
return Status::OK();
96100
}
97101

98-
void ShaderHelper::FinalizeInputs() {
99-
// Automatically add indirect buffer as the last shader input when using indirect dispatch.
100-
if (program_.IndirectDispatchTensor() != nullptr) {
101-
AddInput("indirect_buffer", ShaderUsage::None);
102-
}
103-
}
104-
105102
const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) {
106103
const size_t input_index = input_vars_.size();
107104
ORT_ENFORCE(input_index < program_.Inputs().size(),
108105
"Too many inputs in the program (", program_.Inputs().size(), ")");
109106

110107
const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape
111108
: program_.Inputs()[input_index].tensor->Shape();
112-
return AddVariableImpl(true, name, usage, dims, program_.Inputs()[input_index].segments);
109+
return AddVariableImpl(true, name, usage, dims, inputs_segments_[input_index]);
113110
}
114111

115112
const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, ShaderUsage usage) {
@@ -119,7 +116,7 @@ const ShaderVariableHelper& ShaderHelper::AddOutput(const std::string& name, Sha
119116

120117
const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape
121118
: program_.Outputs()[output_index].tensor->Shape();
122-
return AddVariableImpl(false, name, usage, dims, program_.Outputs()[output_index].segments);
119+
return AddVariableImpl(false, name, usage, dims, outputs_segments_[output_index]);
123120
}
124121

125122
const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, ShaderUsage usage) {

onnxruntime/core/providers/webgpu/shader_helper.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <span>
67
#include <sstream>
78

89
#include "core/providers/webgpu/webgpu_external_header.h"
@@ -67,6 +68,8 @@ class ShaderHelper final {
6768
public:
6869
ShaderHelper(const ProgramBase& program,
6970
const ProgramMetadata& program_metadata,
71+
const std::span<uint32_t> inputs_segments,
72+
const std::span<uint32_t> outputs_segments,
7073
const wgpu::Device& device,
7174
const wgpu::Limits& limits,
7275
uint32_t dispatch_group_size_x,
@@ -75,11 +78,6 @@ class ShaderHelper final {
7578

7679
Status Init();
7780

78-
// Finalize inputs by automatically adding the indirect buffer if needed.
79-
// This should be called after GenerateShaderCode() to ensure the indirect buffer
80-
// is registered as the last input.
81-
void FinalizeInputs();
82-
8381
// Add an input variable to the shader.
8482
//
8583
// depending on the usage of the variable, additional code may be generated.
@@ -164,6 +162,8 @@ class ShaderHelper final {
164162

165163
const wgpu::Device& device_;
166164
const wgpu::Limits& limits_;
165+
const std::span<uint32_t> inputs_segments_;
166+
const std::span<uint32_t> outputs_segments_;
167167
uint32_t dispatch_group_size_x_;
168168
uint32_t dispatch_group_size_y_;
169169
uint32_t dispatch_group_size_z_;

0 commit comments

Comments
 (0)