|
#include <exception>

#include "core/session/ort_apis.h"

#include "core/providers/webgpu/webgpu_provider_options.h"
#include "core/providers/webgpu/data_transfer.h"
15 | 16 | using namespace onnxruntime::webgpu::options; |
16 | 17 |
|
17 | 18 | namespace onnxruntime { |
| 19 | +// Helper to get default context config, buffer cache config, backend type, and enable_pix_capture |
| 20 | +struct WebGpuContextParams { |
| 21 | + webgpu::WebGpuContextConfig context_config; |
| 22 | + webgpu::WebGpuBufferCacheConfig buffer_cache_config; |
| 23 | + int backend_type; |
| 24 | + bool enable_pix_capture; |
| 25 | +}; |
| 26 | + |
| 27 | +static WebGpuContextParams GetDefaultWebGpuContextParams(int context_id) { |
| 28 | + WebGpuContextParams params; |
| 29 | + params.context_config.context_id = context_id; |
| 30 | + params.context_config.instance = nullptr; |
| 31 | + params.context_config.device = nullptr; |
| 32 | + params.context_config.dawn_proc_table = nullptr; |
| 33 | + params.context_config.validation_mode = webgpu::ValidationMode::Basic; |
| 34 | + params.context_config.preserve_device = false; |
| 35 | + params.context_config.max_storage_buffer_binding_size = 0; |
| 36 | + params.context_config.power_preference = static_cast<int>(WGPUPowerPreference_HighPerformance); |
| 37 | + |
| 38 | + params.buffer_cache_config.storage.mode = webgpu::BufferCacheMode::Bucket; |
| 39 | + params.buffer_cache_config.uniform.mode = webgpu::BufferCacheMode::Simple; |
| 40 | + params.buffer_cache_config.query_resolve.mode = webgpu::BufferCacheMode::Disabled; |
| 41 | + params.buffer_cache_config.default_entry.mode = webgpu::BufferCacheMode::Disabled; |
| 42 | + |
| 43 | +#ifdef _WIN32 |
| 44 | +#if defined(DAWN_ENABLE_D3D12) |
| 45 | + params.backend_type = static_cast<int>(WGPUBackendType_D3D12); |
| 46 | +#elif defined(DAWN_ENABLE_VULKAN) |
| 47 | + params.backend_type = static_cast<int>(WGPUBackendType_Vulkan); |
| 48 | +#else |
| 49 | + params.backend_type = static_cast<int>(WGPUBackendType_D3D12); |
| 50 | +#endif |
| 51 | +#else |
| 52 | + params.backend_type = 0; |
| 53 | +#endif |
| 54 | + params.enable_pix_capture = false; |
| 55 | + return params; |
| 56 | +} |
18 | 57 |
|
19 | 58 | struct WebGpuProviderFactory : IExecutionProviderFactory { |
20 | 59 | WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) |
@@ -291,4 +330,73 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create( |
291 | 330 | return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config)); |
292 | 331 | } |
293 | 332 |
|
| 333 | +// WebGPU DataTransfer implementation wrapper for the C API |
| 334 | +struct WebGpuDataTransferImpl : OrtDataTransferImpl { |
| 335 | + WebGpuDataTransferImpl(const OrtApi& ort_api_in, webgpu::BufferManager& buffer_manager) |
| 336 | + : ort_api{ort_api_in}, |
| 337 | + ep_api{*ort_api_in.GetEpApi()}, |
| 338 | + data_transfer_{buffer_manager} { |
| 339 | + ort_version_supported = ORT_API_VERSION; |
| 340 | + CanCopy = CanCopyImpl; |
| 341 | + CopyTensors = CopyTensorsImpl; |
| 342 | + Release = ReleaseImpl; |
| 343 | + } |
| 344 | + |
| 345 | + static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr, |
| 346 | + const OrtMemoryDevice* src_memory_device, |
| 347 | + const OrtMemoryDevice* dst_memory_device) noexcept { |
| 348 | + const auto& impl = *static_cast<const WebGpuDataTransferImpl*>(this_ptr); |
| 349 | + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); |
| 350 | + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); |
| 351 | + |
| 352 | + // WebGPU supports GPU<->GPU, GPU<->CPU copies |
| 353 | + return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) || |
| 354 | + (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) || |
| 355 | + (src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU); |
| 356 | + } |
| 357 | + |
| 358 | + static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr, |
| 359 | + const OrtValue** src_tensors, |
| 360 | + OrtValue** dst_tensors, |
| 361 | + OrtSyncStream** /*streams*/, |
| 362 | + size_t num_tensors) noexcept { |
| 363 | + auto& impl = *static_cast<WebGpuDataTransferImpl*>(this_ptr); |
| 364 | + for (size_t idx = 0; idx < num_tensors; ++idx) { |
| 365 | + const OrtValue* src_tensor = src_tensors[idx]; |
| 366 | + OrtValue* dst_tensor = dst_tensors[idx]; |
| 367 | + auto status = impl.data_transfer_.CopyTensor(src_tensor->Get<Tensor>(), *dst_tensor->GetMutable<Tensor>()); |
| 368 | + if (!status.IsOK()) { |
| 369 | + // Convert common::Status to OrtStatus |
| 370 | + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str()); |
| 371 | + } |
| 372 | + } |
| 373 | + return nullptr; |
| 374 | + } |
| 375 | + |
| 376 | + static void ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { |
| 377 | + delete static_cast<WebGpuDataTransferImpl*>(this_ptr); |
| 378 | + } |
| 379 | + |
| 380 | + const OrtApi& ort_api; |
| 381 | + const OrtEpApi& ep_api; |
| 382 | + webgpu::DataTransfer data_transfer_; |
| 383 | +}; |
| 384 | + |
| 385 | +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id) { |
| 386 | + webgpu::WebGpuContext* context_ptr = nullptr; |
| 387 | + try { |
| 388 | + context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id); |
| 389 | + } catch (...) { |
| 390 | + // Context doesn't exist, create a default one using shared helper |
| 391 | + WebGpuContextParams params = GetDefaultWebGpuContextParams(context_id); |
| 392 | + context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config); |
| 393 | + context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture); |
| 394 | + } |
| 395 | + if (context_ptr) { |
| 396 | + return new WebGpuDataTransferImpl(*OrtApis::GetApi(ORT_API_VERSION), context_ptr->BufferManager()); |
| 397 | + } |
| 398 | + |
| 399 | + return nullptr; |
| 400 | +} |
| 401 | + |
294 | 402 | } // namespace onnxruntime |
0 commit comments