Skip to content

Commit ac5897d

Browse files
committed
[webgpu] update ComputeContext to make it work with PrePack
1 parent 91a9d02 commit ac5897d

File tree

10 files changed

+176
-42
lines changed

10 files changed

+176
-42
lines changed

onnxruntime/core/providers/webgpu/compute_context.cc

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,26 @@
66

77
namespace onnxruntime {
88
namespace webgpu {
9-
ComputeContext::ComputeContext(OpKernelContext& kernel_context,
10-
const OpKernel& op_kernel,
11-
const WebGpuExecutionProvider& ep,
12-
WebGpuContext& webgpu_context)
9+
10+
ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context,
11+
const WebGpuExecutionProvider& ep,
12+
const OpKernel& op_kernel)
1313
: webgpu_context_{webgpu_context},
14-
kernel_context_{kernel_context},
15-
op_kernel_{op_kernel},
16-
ep_{ep} {
14+
ep_{ep},
15+
op_kernel_{op_kernel} {
1716
}
1817

19-
const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) {
18+
const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) {
2019
return context.ep_.BufferManager();
2120
}
2221

22+
ComputeContext::ComputeContext(WebGpuContext& webgpu_context,
23+
const WebGpuExecutionProvider& ep,
24+
const OpKernel& op_kernel,
25+
OpKernelContext& kernel_context)
26+
: ComputeContextBase(webgpu_context, ep, op_kernel),
27+
kernel_context_{kernel_context} {
28+
}
29+
2330
} // namespace webgpu
2431
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ namespace webgpu {
2424
class WebGpuContext;
2525
class BufferManager;
2626

27-
class ComputeContext final {
27+
//
28+
// Class ComputeContextBase is designed to provide basic context information
29+
// for running a compute shader program.
30+
//
31+
// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created.
32+
//
33+
class ComputeContextBase {
2834
public:
2935
// Nested accessor class to provide controlled access to BufferManager
3036
class BufferManagerAccessor {
@@ -34,18 +40,31 @@ class ComputeContext final {
3440
friend class WebGpuContext;
3541

3642
private:
37-
static const webgpu::BufferManager& Get(const ComputeContext& context);
43+
static const webgpu::BufferManager& Get(const ComputeContextBase& context);
3844
};
3945

40-
ComputeContext(OpKernelContext& kernel_context,
41-
const OpKernel& op_kernel,
42-
const WebGpuExecutionProvider& ep,
43-
WebGpuContext& webgpu_context);
46+
ComputeContextBase(WebGpuContext& webgpu_context,
47+
const WebGpuExecutionProvider& ep,
48+
const OpKernel& op_kernel);
4449

45-
~ComputeContext() = default;
50+
~ComputeContextBase() = default;
51+
52+
//
53+
// Get the node name.
54+
//
55+
inline decltype(auto) NodeName() const {
56+
return op_kernel_.Node().Name();
57+
}
4658

4759
//
48-
// Get various information from the context.
60+
// Get the operator type.
61+
//
62+
inline decltype(auto) OpType() const {
63+
return op_kernel_.Node().OpType();
64+
}
65+
66+
//
67+
// Get various information from the WebGPU context.
4968
//
5069

5170
inline const wgpu::AdapterInfo& AdapterInfo() const {
@@ -57,27 +76,56 @@ class ComputeContext final {
5776
inline bool HasFeature(wgpu::FeatureName feature) const {
5877
return webgpu_context_.DeviceHasFeature(feature);
5978
}
60-
inline bool IsGraphCaptureEnabled() const {
61-
return ep_.IsGraphCaptureEnabled();
62-
}
6379
#if !defined(__wasm__)
6480
inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const {
6581
return webgpu_context_.SubgroupMatrixConfigs();
6682
}
6783
#endif
6884

6985
//
70-
// Get the kernel context.
86+
// Get whether graph capture is enabled.
7187
//
72-
inline OpKernelContext& KernelContext() {
73-
return kernel_context_;
88+
inline bool IsGraphCaptureEnabled() const {
89+
return ep_.IsGraphCaptureEnabled();
7490
}
7591

7692
//
7793
// Get the logger.
7894
//
7995
inline const logging::Logger& Logger() const {
80-
return kernel_context_.Logger();
96+
return *ep_.GetLogger();
97+
}
98+
99+
//
100+
// Run a compute shader program.
101+
//
102+
inline Status RunProgram(const ProgramBase& program) {
103+
return webgpu_context_.Run(*this, program);
104+
}
105+
106+
protected:
107+
WebGpuContext& webgpu_context_;
108+
const WebGpuExecutionProvider& ep_;
109+
const OpKernel& op_kernel_;
110+
};
111+
112+
//
113+
// Class ComputeContext provides all information a `ComputeContextBase` provides, and also
114+
// access to `OpKernelContext` for input and output tensors.
115+
class ComputeContext final : public ComputeContextBase {
116+
public:
117+
ComputeContext(WebGpuContext& webgpu_context,
118+
const WebGpuExecutionProvider& ep,
119+
const OpKernel& op_kernel,
120+
OpKernelContext& kernel_context);
121+
122+
~ComputeContext() = default;
123+
124+
//
125+
// Get the kernel context.
126+
//
127+
inline OpKernelContext& KernelContext() {
128+
return kernel_context_;
81129
}
82130

83131
//
@@ -145,18 +193,8 @@ class ComputeContext final {
145193
return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst);
146194
}
147195

148-
//
149-
// Run a compute shader program.
150-
//
151-
inline Status RunProgram(const ProgramBase& program) {
152-
return webgpu_context_.Run(*this, program);
153-
}
154-
155196
private:
156-
WebGpuContext& webgpu_context_;
157197
OpKernelContext& kernel_context_;
158-
const OpKernel& op_kernel_;
159-
const WebGpuExecutionProvider& ep_;
160198
};
161199

162200
} // namespace webgpu

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,58 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
217217
return context.RunProgram(conv2d_mm_program);
218218
}
219219

220+
template <bool is_channels_last, bool is_fused>
221+
Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& context,
222+
const Tensor& tensor,
223+
int input_idx,
224+
AllocatorPtr alloc,
225+
/*out*/ bool& is_packed,
226+
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
227+
is_packed = false;
228+
229+
printf("=== PrePack called for Conv kernel\n");
230+
printf(" input_idx: %d\n", input_idx);
231+
printf(" tensor shape: ");
232+
for (size_t i = 0; i < tensor.Shape().NumDimensions(); ++i) {
233+
printf("%lld ", tensor.Shape()[i]);
234+
}
235+
printf("\n");
236+
printf(" tensor location: %s\n", tensor.Location().ToString().c_str());
237+
printf(" allocator info: %s\n", alloc->Info().ToString().c_str());
238+
239+
if constexpr (is_channels_last) {
240+
if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) {
241+
// only deal with 4D NHWC weights
242+
243+
// Tensors and allocator should both be on GPU
244+
ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::GPU &&
245+
tensor.Location().mem_type == OrtMemType::OrtMemTypeDefault &&
246+
tensor.Location().name == WEBGPU_BUFFER,
247+
"Tensor must be a WebGPU buffer.");
248+
ORT_ENFORCE(alloc->Info().device.Type() == OrtDevice::GPU &&
249+
alloc->Info().name == WEBGPU_BUFFER,
250+
"Allocator must be for WebGPU.");
251+
252+
// Step.1 - calculate transposed weight shape
253+
TensorShape transposed_kernel_shape{1, 2, 3, 4}; // placeholder
254+
255+
// Step.2 - create transposed weight tensor
256+
transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
257+
258+
// Step.3 - do transpose
259+
size_t perm[] = {2, 3, 1, 0};
260+
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context,
261+
perm,
262+
tensor,
263+
*transposed_kernel_));
264+
265+
is_packed = true; // set this flag to true so that ORT will release the initializer tensor
266+
}
267+
}
268+
269+
return Status::OK();
270+
}
271+
220272
// Explicit template instantiation for FusedConv
221273
template class Conv<false, false>;
222274
template class Conv<false, true>;

onnxruntime/core/providers/webgpu/nn/conv.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@ class Conv : public WebGpuKernel {
2323
}
2424
Status ComputeInternal(ComputeContext& context) const override;
2525

26+
Status PrePackInternal(ComputeContextBase& context, const Tensor& tensor, int input_idx, AllocatorPtr alloc,
27+
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override;
28+
2629
protected:
2730
ConvAttributes conv_attrs_;
2831
Activation activation_;
32+
std::unique_ptr<Tensor> transposed_kernel_; // should only has value when `is_initializer` AND `is_4D` AND `is_NHWC`
2933
};
3034

3135
Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm);

onnxruntime/core/providers/webgpu/tensor/transpose.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
108108
return Status::OK();
109109
}
110110

111-
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context,
111+
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context,
112112
gsl::span<const size_t> permutations,
113113
const Tensor& input, Tensor& output) {
114114
const auto& input_shape = input.Shape();

onnxruntime/core/providers/webgpu/tensor/transpose.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
1616
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
1717
}
1818
Status ComputeInternal(ComputeContext& context) const override;
19-
static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
19+
static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);
2020

2121
constexpr static uint32_t TILE_SIZE = 16;
2222
};

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ Status WebGpuContext::Wait(wgpu::Future f) {
178178
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
179179
}
180180

181-
Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
181+
Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) {
182182
const auto& inputs = program.Inputs();
183183
const auto& outputs = program.Outputs();
184184

@@ -288,8 +288,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
288288
auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch);
289289

290290
if (is_profiling_) {
291-
PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(),
292-
context.KernelContext().GetOpType(),
291+
PendingKernelInfo pending_kernel_info(context.NodeName(),
292+
context.OpType(),
293293
program.Name(),
294294
key,
295295
inputs,
@@ -442,7 +442,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
442442
const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field;
443443

444444
WGPUBuffer uniform_buffer = nullptr;
445-
const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context);
445+
const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context);
446446
if (uniform_buffer_total_size > 0) {
447447
std::vector<uint8_t> uniform_data_buffer(uniform_buffer_total_size);
448448

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Tensor;
2121

2222
namespace webgpu {
2323
class WebGpuContext;
24-
class ComputeContext;
24+
class ComputeContextBase;
2525
class ProgramBase;
2626

2727
// Definition for CapturedCommandInfo in the webgpu namespace
@@ -168,7 +168,7 @@ class WebGpuContext final {
168168
//
169169
Status PopErrorScope();
170170

171-
Status Run(ComputeContext& context, const ProgramBase& program);
171+
Status Run(ComputeContextBase& context, const ProgramBase& program);
172172
void OnRunEnd();
173173

174174
private:

onnxruntime/core/providers/webgpu/webgpu_kernel.cc

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ WebGpuKernel::WebGpuKernel(const OpKernelInfo& info)
1616

1717
Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const {
1818
WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId());
19-
ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context};
19+
ComputeContext context{webgpu_context,
20+
ep_,
21+
*this,
22+
*p_op_kernel_context};
2023

2124
if (webgpu_context.ValidationMode() >= ValidationMode::Full) {
2225
webgpu_context.PushErrorScope();
@@ -31,5 +34,25 @@ Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const {
3134
return s;
3235
}
3336

37+
Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
38+
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) {
39+
WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId());
40+
ComputeContextBase context{webgpu_context,
41+
ep_,
42+
*this};
43+
44+
return PrePackInternal(context, tensor, input_idx, alloc, is_packed, prepacked_weights);
45+
}
46+
47+
Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/,
48+
const Tensor& /*tensor*/,
49+
int /*input_idx*/,
50+
AllocatorPtr /*alloc*/,
51+
/*out*/ bool& is_packed,
52+
/*out*/ PrePackedWeights* /*prepacked_weights*/) {
53+
is_packed = false;
54+
return Status::OK();
55+
}
56+
3457
} // namespace webgpu
3558
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_kernel.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ class WebGpuKernel : public OpKernel {
2323

2424
virtual Status ComputeInternal(ComputeContext& context) const = 0;
2525

26+
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
27+
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override;
28+
29+
virtual Status PrePackInternal(ComputeContextBase& context,
30+
const Tensor& tensor,
31+
int input_idx,
32+
AllocatorPtr alloc,
33+
/*out*/ bool& is_packed,
34+
/*out*/ PrePackedWeights* prepacked_weights);
35+
2636
private:
2737
const WebGpuExecutionProvider& ep_;
2838
};

0 commit comments

Comments
 (0)