Skip to content

Commit 42d5e64

Browse files
committed
address comments
1 parent 925b620 commit 42d5e64

File tree

6 files changed

+122
-33
lines changed

6 files changed

+122
-33
lines changed

include/onnxruntime/core/session/environment.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,6 @@ class Environment {
154154
const DataTransferManager& GetDataTransferManager() const {
155155
return data_transfer_mgr_;
156156
}
157-
158-
// Register a data transfer for an execution provider with the environment's data transfer manager
159-
// This is needed for EPs like WebGPU where CopyTensors C API needs access to the data transfer
160-
Status RegisterDataTransferForEP(std::unique_ptr<IDataTransfer> data_transfer) {
161-
return data_transfer_mgr_.RegisterDataTransfer(std::move(data_transfer));
162-
}
163157
#endif // !defined(ORT_MINIMAL_BUILD)
164158

165159
// return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,48 @@
1212
#include "core/session/ort_apis.h"
1313

1414
#include "core/providers/webgpu/webgpu_provider_options.h"
15+
#include "core/providers/webgpu/data_transfer.h"
1516
using namespace onnxruntime::webgpu::options;
1617

1718
namespace onnxruntime {
19+
// Helper to get default context config, buffer cache config, backend type, and enable_pix_capture
20+
struct WebGpuContextParams {
21+
webgpu::WebGpuContextConfig context_config;
22+
webgpu::WebGpuBufferCacheConfig buffer_cache_config;
23+
int backend_type;
24+
bool enable_pix_capture;
25+
};
26+
27+
// Builds the parameter set used when a WebGPU context must be created with
// default settings for the given context_id.
static WebGpuContextParams GetDefaultWebGpuContextParams(int context_id) {
  WebGpuContextParams defaults;

  // Context configuration: no externally supplied instance/device/proc table.
  auto& ctx = defaults.context_config;
  ctx.context_id = context_id;
  ctx.instance = nullptr;
  ctx.device = nullptr;
  ctx.dawn_proc_table = nullptr;
  ctx.validation_mode = webgpu::ValidationMode::Basic;
  ctx.preserve_device = false;
  ctx.max_storage_buffer_binding_size = 0;
  ctx.power_preference = static_cast<int>(WGPUPowerPreference_HighPerformance);

  // Buffer cache modes, one per buffer category.
  auto& cache = defaults.buffer_cache_config;
  cache.storage.mode = webgpu::BufferCacheMode::Bucket;
  cache.uniform.mode = webgpu::BufferCacheMode::Simple;
  cache.query_resolve.mode = webgpu::BufferCacheMode::Disabled;
  cache.default_entry.mode = webgpu::BufferCacheMode::Disabled;

  // Backend selection: on Windows use D3D12 unless only Vulkan is enabled;
  // on every other platform the backend type is left at 0.
#if defined(_WIN32) && defined(DAWN_ENABLE_VULKAN) && !defined(DAWN_ENABLE_D3D12)
  defaults.backend_type = static_cast<int>(WGPUBackendType_Vulkan);
#elif defined(_WIN32)
  defaults.backend_type = static_cast<int>(WGPUBackendType_D3D12);
#else
  defaults.backend_type = 0;
#endif

  defaults.enable_pix_capture = false;
  return defaults;
}
1857

1958
struct WebGpuProviderFactory : IExecutionProviderFactory {
2059
WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config)
@@ -291,4 +330,73 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
291330
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));
292331
}
293332

333+
// WebGPU DataTransfer implementation wrapper for the C API
334+
struct WebGpuDataTransferImpl : OrtDataTransferImpl {
335+
WebGpuDataTransferImpl(const OrtApi& ort_api_in, webgpu::BufferManager& buffer_manager)
336+
: ort_api{ort_api_in},
337+
ep_api{*ort_api_in.GetEpApi()},
338+
data_transfer_{buffer_manager} {
339+
ort_version_supported = ORT_API_VERSION;
340+
CanCopy = CanCopyImpl;
341+
CopyTensors = CopyTensorsImpl;
342+
Release = ReleaseImpl;
343+
}
344+
345+
static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr,
346+
const OrtMemoryDevice* src_memory_device,
347+
const OrtMemoryDevice* dst_memory_device) noexcept {
348+
const auto& impl = *static_cast<const WebGpuDataTransferImpl*>(this_ptr);
349+
OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device);
350+
OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device);
351+
352+
// WebGPU supports GPU<->GPU, GPU<->CPU copies
353+
return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) ||
354+
(src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) ||
355+
(src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU);
356+
}
357+
358+
static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr,
359+
const OrtValue** src_tensors,
360+
OrtValue** dst_tensors,
361+
OrtSyncStream** /*streams*/,
362+
size_t num_tensors) noexcept {
363+
auto& impl = *static_cast<WebGpuDataTransferImpl*>(this_ptr);
364+
for (size_t idx = 0; idx < num_tensors; ++idx) {
365+
const OrtValue* src_tensor = src_tensors[idx];
366+
OrtValue* dst_tensor = dst_tensors[idx];
367+
auto status = impl.data_transfer_.CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>());
368+
if (!status.IsOK()) {
369+
// Convert common::Status to OrtStatus
370+
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str());
371+
}
372+
}
373+
return nullptr;
374+
}
375+
376+
static void ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept {
377+
delete static_cast<WebGpuDataTransferImpl*>(this_ptr);
378+
}
379+
380+
const OrtApi& ort_api;
381+
const OrtEpApi& ep_api;
382+
webgpu::DataTransfer data_transfer_;
383+
};
384+
385+
// Creates an OrtDataTransferImpl backed by the WebGPU context identified by
// context_id. If that context cannot be obtained, a context with default
// parameters is created and initialized first.
//
// Returns a heap-allocated object that the caller owns and must destroy via
// its Release callback, or nullptr if the context or the data transfer could
// not be created.
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id) {
  // This function is called from a noexcept factory method, so no exception
  // may escape: any failure is reported as a nullptr result instead.
  try {
    webgpu::WebGpuContext* context_ptr = nullptr;
    try {
      context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id);
    } catch (...) {
      // Context lookup failed; fall through and create a default one below.
      context_ptr = nullptr;
    }

    if (context_ptr == nullptr) {
      // Context doesn't exist, create a default one using shared helper
      WebGpuContextParams params = GetDefaultWebGpuContextParams(context_id);
      context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
      context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
    }

    // Ownership passes to the caller; the matching delete is in ReleaseImpl.
    return new WebGpuDataTransferImpl(*OrtApis::GetApi(ORT_API_VERSION), context_ptr->BufferManager());
  } catch (...) {
    return nullptr;
  }
}
401+
294402
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@
1010

1111
#include "core/providers/webgpu/webgpu_provider_options.h"
1212

13+
struct OrtDataTransferImpl;
14+
1315
namespace onnxruntime {
1416
struct ConfigOptions;
1517

1618
struct WebGpuProviderFactoryCreator {
1719
static std::shared_ptr<IExecutionProviderFactory> Create(const ConfigOptions& config_options);
1820
};
1921

22+
// C API to create data transfer for WebGPU EP
23+
// If no WebGPU context exists yet for context_id, one is created and initialized with default settings.
// Callers should still handle a nullptr result.
24+
// Caller takes ownership of the returned OrtDataTransferImpl*
25+
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id);
26+
2027
} // namespace onnxruntime

onnxruntime/core/session/inference_session.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -854,25 +854,10 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
854854
VLOGS(*session_logger_, 1) << "Adding execution provider of type: " << provider_type;
855855
auto p_data_xfr = p_exec_provider->GetDataTransfer();
856856
if (p_data_xfr) {
857-
// Register with session's data transfer manager
858857
auto st = data_transfer_mgr_.RegisterDataTransfer(std::move(p_data_xfr));
859858
if (!st.IsOK()) {
860859
return st;
861860
}
862-
863-
#if !defined(ORT_MINIMAL_BUILD)
864-
// For WebGPU EP, also register with environment's data transfer manager
865-
// so that CopyTensors C API can work (it only checks environment's DTM)
866-
if (provider_type == kWebGpuExecutionProvider) {
867-
auto p_data_xfr_env = p_exec_provider->GetDataTransfer();
868-
if (p_data_xfr_env) {
869-
auto st_env = const_cast<Environment&>(environment_).RegisterDataTransferForEP(std::move(p_data_xfr_env));
870-
if (!st_env.IsOK()) {
871-
LOGS(*session_logger_, WARNING) << "Failed to register WebGPU data transfer with environment: " << st_env.ErrorMessage();
872-
}
873-
}
874-
}
875-
#endif
876861
}
877862

878863
auto p_external_data_loader = p_exec_provider->GetExternalDataLoader();

onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,13 @@ OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* co
5757
return nullptr;
5858
}
5959

60-
/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of
61-
an InferenceSession.
62-
OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info,
63-
const OrtKeyValuePairs* allocator_options,
64-
OrtAllocator** allocator) noexcept override {
65-
*allocator = device_allocators[memory_info->device.Id()].get();
66-
}
67-
68-
OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override {
69-
// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors.
70-
*data_transfer = nullptr;
60+
OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept {
  // The data transfer is built by the WebGPU provider backend, which has
  // access to the WebGPU headers; this factory only forwards the request.
  constexpr int kDefaultContextId = 0;  // the default WebGPU context
  OrtDataTransferImpl* created = OrtWebGpuCreateDataTransfer(kDefaultContextId);
  *data_transfer = created;
  return nullptr;
}
73-
*/
66+
7467
} // namespace onnxruntime
7568

7669
#endif // USE_WEBGPU

onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class WebGpuEpFactory : public EpFactoryInternalImpl {
2929
const OrtSessionOptions* session_options,
3030
const OrtLogger* session_logger,
3131
std::unique_ptr<IExecutionProvider>* ep) noexcept override;
32+
33+
OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override;
3234
};
3335
} // namespace onnxruntime
3436

0 commit comments

Comments
 (0)