Skip to content

Commit 9275dad

Browse files
committed
fix CI errors
1 parent f79a393 commit 9275dad

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,11 @@ WebGpuContext& WebGpuContextFactory::GetContext(int context_id) {
995995
return *it->second.context;
996996
}
997997

998+
bool WebGpuContextFactory::HasContext(int context_id) {
999+
std::lock_guard<std::mutex> lock(mutex_);
1000+
return contexts_.find(context_id) != contexts_.end();
1001+
}
1002+
9981003
void WebGpuContextFactory::ReleaseContext(int context_id) {
9991004
std::lock_guard<std::mutex> lock(mutex_);
10001005

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class WebGpuContextFactory {
6464

6565
static WebGpuContext& CreateContext(const WebGpuContextConfig& config);
6666
static WebGpuContext& GetContext(int context_id);
67+
static bool HasContext(int context_id);
6768

6869
static void ReleaseContext(int context_id);
6970

onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,19 +384,14 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl {
384384

385385
OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id) {
386386
webgpu::WebGpuContext* context_ptr = nullptr;
387-
try {
387+
if (webgpu::WebGpuContextFactory::HasContext(context_id)) {
388388
context_ptr = &webgpu::WebGpuContextFactory::GetContext(context_id);
389-
} catch (...) {
390-
// Context doesn't exist, create a default one using shared helper
389+
} else {
391390
WebGpuContextParams params = GetDefaultWebGpuContextParams(context_id);
392391
context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config);
393392
context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture);
394393
}
395-
if (context_ptr) {
396-
return new WebGpuDataTransferImpl(*OrtApis::GetApi(ORT_API_VERSION), context_ptr->BufferManager());
397-
}
398-
399-
return nullptr;
394+
return new WebGpuDataTransferImpl(*OrtApis::GetApi(ORT_API_VERSION), context_ptr->BufferManager());
400395
}
401396

402397
} // namespace onnxruntime

0 commit comments

Comments
 (0)