@@ -24,7 +24,13 @@ namespace webgpu {
2424class WebGpuContext ;
2525class BufferManager ;
2626
27- class ComputeContext final {
27+ //
28+ // Class ComputeContextBase is designed to provide basic context information
29+ // for running a compute shader program.
30+ //
31+ // An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created.
32+ //
33+ class ComputeContextBase {
2834 public:
2935 // Nested accessor class to provide controlled access to BufferManager
3036 class BufferManagerAccessor {
@@ -34,18 +40,31 @@ class ComputeContext final {
3440 friend class WebGpuContext ;
3541
3642 private:
37- static const webgpu::BufferManager& Get (const ComputeContext & context);
43+ static const webgpu::BufferManager& Get (const ComputeContextBase & context);
3844 };
3945
40- ComputeContext (OpKernelContext& kernel_context,
41- const OpKernel& op_kernel,
42- const WebGpuExecutionProvider& ep,
43- WebGpuContext& webgpu_context);
46+ ComputeContextBase (WebGpuContext& webgpu_context,
47+ const WebGpuExecutionProvider& ep,
48+ const OpKernel& op_kernel);
4449
45- ~ComputeContext () = default ;
50+ ~ComputeContextBase () = default ;
51+
52+ //
53+ // Get the node name.
54+ //
55+ inline decltype (auto ) NodeName() const {
56+ return op_kernel_.Node ().Name ();
57+ }
4658
4759 //
48- // Get various information from the context.
60+ // Get the operator type.
61+ //
62+ inline decltype (auto ) OpType() const {
63+ return op_kernel_.Node ().OpType ();
64+ }
65+
66+ //
67+ // Get various information from the WebGPU context.
4968 //
5069
5170 inline const wgpu::AdapterInfo& AdapterInfo () const {
@@ -57,27 +76,56 @@ class ComputeContext final {
5776 inline bool HasFeature (wgpu::FeatureName feature) const {
5877 return webgpu_context_.DeviceHasFeature (feature);
5978 }
60- inline bool IsGraphCaptureEnabled () const {
61- return ep_.IsGraphCaptureEnabled ();
62- }
6379#if !defined(__wasm__)
6480 inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs () const {
6581 return webgpu_context_.SubgroupMatrixConfigs ();
6682 }
6783#endif
6884
6985 //
70- // Get the kernel context .
86+ // Get whether graph capture is enabled .
7187 //
72- inline OpKernelContext& KernelContext () {
73- return kernel_context_ ;
88+ inline bool IsGraphCaptureEnabled () const {
89+ return ep_. IsGraphCaptureEnabled () ;
7490 }
7591
7692 //
7793 // Get the logger.
7894 //
7995 inline const logging::Logger& Logger () const {
80- return kernel_context_.Logger ();
96+ return *ep_.GetLogger ();
97+ }
98+
99+ //
100+ // Run a compute shader program.
101+ //
102+ inline Status RunProgram (const ProgramBase& program) {
103+ return webgpu_context_.Run (*this , program);
104+ }
105+
106+ protected:
107+ WebGpuContext& webgpu_context_;
108+ const WebGpuExecutionProvider& ep_;
109+ const OpKernel& op_kernel_;
110+ };
111+
112+ //
113+ // Class ComputeContext provides all information a `ComputeContextBase` provides, and also
114+ // access to `OpKernelContext` for input and output tensors.
115+ class ComputeContext final : public ComputeContextBase {
116+ public:
117+ ComputeContext (WebGpuContext& webgpu_context,
118+ const WebGpuExecutionProvider& ep,
119+ const OpKernel& op_kernel,
120+ OpKernelContext& kernel_context);
121+
122+ ~ComputeContext () = default ;
123+
124+ //
125+ // Get the kernel context.
126+ //
127+ inline OpKernelContext& KernelContext () {
128+ return kernel_context_;
81129 }
82130
83131 //
@@ -145,18 +193,8 @@ class ComputeContext final {
145193 return op_kernel_.Info ().GetDataTransferManager ().CopyTensor (src, dst);
146194 }
147195
148- //
149- // Run a compute shader program.
150- //
151- inline Status RunProgram (const ProgramBase& program) {
152- return webgpu_context_.Run (*this , program);
153- }
154-
155196 private:
156- WebGpuContext& webgpu_context_;
157197 OpKernelContext& kernel_context_;
158- const OpKernel& op_kernel_;
159- const WebGpuExecutionProvider& ep_;
160198};
161199
162200} // namespace webgpu
0 commit comments