Skip to content

Commit ab442e1

Browse files
committed
Address code review comments
1 parent 27b7626 commit ab442e1

File tree

10 files changed

+112
-58
lines changed

10 files changed

+112
-58
lines changed

onnxruntime/core/framework/allocator_utils.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) {
5252
if (info.use_stream_aware_arena) {
5353
#ifdef ORT_ENABLE_STREAM
5454
return AllocatorPtr(
55-
std::make_unique<StreamAwareArena>(std::move(device_allocator),
56-
max_mem,
57-
arena_extend_str,
58-
initial_chunk_size_bytes,
59-
max_dead_bytes_per_chunk,
60-
initial_growth_chunk_size_bytes));
55+
std::make_unique<StreamAwareBFCArena>(std::move(device_allocator),
56+
max_mem,
57+
arena_extend_str,
58+
initial_chunk_size_bytes,
59+
max_dead_bytes_per_chunk,
60+
initial_growth_chunk_size_bytes));
6161
#else
62-
ORT_THROW("StreamAwareArena should be transparent to minimal build.");
62+
ORT_THROW("StreamAwareBFCArena should be transparent to minimal build.");
6363
#endif
6464
} else {
6565
return AllocatorPtr(

onnxruntime/core/framework/bfc_arena.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -826,13 +826,13 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla
826826
}
827827
}
828828

829-
StreamAwareArena::StreamAwareArena(std::unique_ptr<IAllocator> resource_allocator,
830-
size_t total_memory,
831-
ArenaExtendStrategy arena_extend_strategy,
832-
int initial_chunk_size_bytes,
833-
int max_dead_bytes_per_chunk,
834-
int initial_growth_chunk_size_bytes,
835-
int64_t max_power_of_two_extend_bytes)
829+
StreamAwareBFCArena::StreamAwareBFCArena(std::unique_ptr<IAllocator> resource_allocator,
830+
size_t total_memory,
831+
ArenaExtendStrategy arena_extend_strategy,
832+
int initial_chunk_size_bytes,
833+
int max_dead_bytes_per_chunk,
834+
int initial_growth_chunk_size_bytes,
835+
int64_t max_power_of_two_extend_bytes)
836836
: BFCArena(std::move(resource_allocator),
837837
total_memory,
838838
arena_extend_strategy,
@@ -842,11 +842,11 @@ StreamAwareArena::StreamAwareArena(std::unique_ptr<IAllocator> resource_allocato
842842
max_power_of_two_extend_bytes) {
843843
}
844844

845-
void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream) {
845+
void* StreamAwareBFCArena::AllocOnStream(size_t size, Stream* current_stream) {
846846
return AllocateRawInternal(size, false, current_stream);
847847
}
848848

849-
void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) {
849+
void StreamAwareBFCArena::ReleaseStreamBuffers(Stream* stream) {
850850
// since chunks on target stream will be reset to nullptr, trigger coalesce to see whether we can get bigger chunk.
851851
ResetChunkOnTargetStream(stream, true);
852852
}

onnxruntime/core/framework/bfc_arena.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace onnxruntime {
4343
#endif
4444
#endif
4545

46-
class StreamAwareArena;
46+
class StreamAwareBFCArena;
4747
// A memory allocator that implements a 'best-fit with coalescing'
4848
// algorithm. This is essentially a very simple version of Doug Lea's
4949
// malloc (dlmalloc).
@@ -502,15 +502,15 @@ class BFCArena : public IArena {
502502
};
503503

504504
#ifdef ORT_ENABLE_STREAM
505-
class StreamAwareArena : public BFCArena {
505+
class StreamAwareBFCArena : public BFCArena {
506506
public:
507-
StreamAwareArena(std::unique_ptr<IAllocator> resource_allocator,
508-
size_t total_memory,
509-
ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY,
510-
int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES,
511-
int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK,
512-
int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES,
513-
int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES);
507+
StreamAwareBFCArena(std::unique_ptr<IAllocator> resource_allocator,
508+
size_t total_memory,
509+
ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY,
510+
int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES,
511+
int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK,
512+
int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES,
513+
int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES);
514514

515515
bool IsStreamAware() const override { return true; }
516516

onnxruntime/core/framework/device_stream_collection.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DeviceStreamCollectionImpl {
3737
if (it.second->Info().device == stream->GetDevice() &&
3838
it.second->Info().alloc_type == OrtArenaAllocator) {
3939
if (it.second->IsStreamAware()) {
40-
// Previously we only had one StreamAwareArena. We need to guard
40+
// Previously we only had one StreamAwareBFCArena. We need to guard
4141
// against multiple allocators now.
4242
auto* arena_alloc = IArena::SafeArenaCast(it.second.get());
4343
if (arena_alloc) {

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,15 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi
154154

155155
return CreateAllocator(default_memory_info);
156156
} else {
157-
const bool use_cuda_mempool =
157+
const bool cuda_mempool_requested =
158158
default_memory_arena_cfg != nullptr && default_memory_arena_cfg->use_cuda_mempool == 1;
159159

160+
const bool use_cuda_mempool = cuda_mempool_requested && cuda::CudaMempoolArena::IsCudaVersionSupported();
161+
162+
if (cuda_mempool_requested && !use_cuda_mempool) {
163+
LOGS_DEFAULT(WARNING) << "CUDA memory pool requested but not supported on this device/driver. Falling back to default BFCArena with CUDA allocator.";
164+
}
165+
160166
if (use_cuda_mempool) {
161167
auto device = OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device_id);
162168
auto mem_info = OrtMemoryInfo("CUDAMemPoolArena", OrtAllocatorType::OrtArenaAllocator, device, OrtMemTypeDefault);

onnxruntime/core/providers/cuda/cuda_mempool_arena.cc

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,23 @@ CudaMempoolArena::~CudaMempoolArena() {
5757
for (auto& kv : alloc_map_) {
5858
void* p = kv.first;
5959
const cudaStream_t s = kv.second.stream;
60-
(void)cudaFreeAsync(p, s); // ignore errors in destructor
60+
ORT_IGNORE_RETURN_VALUE(cudaFreeAsync(p, s)); // ignore errors in destructor
6161
}
6262

63+
// 2) Synchronize all streams we know about (those that ever held allocations).
64+
SyncAllKnownStreams_NoThrow();
65+
6366
// Now it is safe to drop our bookkeeping.
6467
alloc_map_.clear();
6568
stream_map_.clear();
6669

67-
// 2) Synchronize all streams we know about (those that ever held allocations).
68-
SyncAllKnownStreams_NoThrow();
69-
7070
// 3) Safety barrier: ensure any frees enqueued on destroyed/unknown streams are completed.
71-
(void)cudaDeviceSynchronize(); // ignore errors in destructor
71+
ORT_IGNORE_RETURN_VALUE(cudaDeviceSynchronize()); // ignore errors in destructor
7272

7373
// 4) Trim to zero and destroy the pool.
7474
if (pool_) {
75-
(void)cudaMemPoolTrimTo(pool_, 0); // best-effort
76-
(void)cudaMemPoolDestroy(pool_);
75+
ORT_IGNORE_RETURN_VALUE(cudaMemPoolTrimTo(pool_, 0)); // best-effort
76+
ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool_));
7777
pool_ = nullptr;
7878
}
7979
}
@@ -93,7 +93,7 @@ void* CudaMempoolArena::Alloc(size_t size) {
9393
<< size << " bytes at " << p << " on default stream.";
9494

9595
// In case the default stream is busy.
96-
::cudaStreamSynchronize(kDefaultStream);
96+
ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(kDefaultStream));
9797

9898
{
9999
std::lock_guard<std::mutex> lock(mutex_);
@@ -231,8 +231,41 @@ void CudaMempoolArena::MaybeRehashLocked() {
231231
void CudaMempoolArena::SyncAllKnownStreams_NoThrow() {
232232
for (const auto& kv : stream_map_) {
233233
const cudaStream_t s = kv.first;
234-
(void)cudaStreamSynchronize(s); // ignore errors; device-wide sync follows
234+
ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(s)); // ignore errors; device-wide sync follows
235+
}
236+
}
237+
238+
bool CudaMempoolArena::IsCudaVersionSupported() noexcept {
239+
int ort_cuda_rt_version = 0;
240+
cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version);
241+
if (cuda_status != cudaSuccess) {
242+
return false;
243+
}
244+
245+
if (ort_cuda_rt_version < 11020) {
246+
return false;
247+
}
248+
249+
int ort_cuda_driver_version = 0;
250+
cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version);
251+
if (cuda_status != cudaSuccess) {
252+
return false;
253+
}
254+
255+
if (ort_cuda_driver_version < 11020) {
256+
return false;
235257
}
258+
259+
// Check if the driver version supports the runtime version
260+
if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) {
261+
return false;
262+
}
263+
264+
if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) {
265+
return false;
266+
}
267+
268+
return true;
236269
}
237270

238271
} // namespace cuda

onnxruntime/core/providers/cuda/cuda_mempool_arena.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,18 @@ namespace cuda {
2929
* - Creates a **process-local** CUDA mempool for a specific device (from `OrtMemoryInfo`).
3030
* - All allocations use **`cudaMallocFromPoolAsync()`** on either the legacy default stream (0) or a
3131
* caller-provided stream. The allocation stream is recorded for ordered free.
32-
* - `Free()` and `ReleaseStreamBuffers()` enqueue **`cudaFreeAsync()`** on the recorded stream to
32+
* - `Free()` enqueues **`cudaFreeAsync()`** on the recorded stream to
3333
* respect CUDA's stream-ordered semantics.
3434
* - `Shrink()` trims the pool with **`cudaMemPoolTrimTo(bytes_to_keep)`** and right-sizes the book-keeping maps
3535
* under lock.
3636
*
3737
* ### Tuning
38-
* - `pool_release_threshold`: if non-zero, sets `cudaMemPoolAttrReleaseThreshold`. **Recommended: 1 MB.**
39-
* - `initial_pool_size_bytes`: if > 0, pre‑reserve pool capacity by setting
38+
* - `pool_release_threshold`: if non-zero, sets `cudaMemPoolAttrReleaseThreshold`. **Recommended: 1 MB**, but
39+
* must be experimentally determined based on workload for optimal memory consumption vs performance.
4040
* `cudaMemPoolAttrReservedMemCurrent`. **Recommended: 10 MB.**
41+
* - `bytes_to_keep_on_shrink`: target size for `cudaMemPoolTrimTo()` on `Shrink()`. This is only relevant
42+
* if Shrink() is enabled. It usually costs performance, and strictly speaking is not necessary for CUDA mempools
43+
* since they release memory at synchronization points according to `pool_release_threshold`.
4144
*
4245
* ### Thread-safety
4346
* - All updates to internal maps and statistics are guarded by an internal `std::mutex`.
@@ -122,11 +125,16 @@ class CudaMempoolArena final : public IArena {
122125
// void ReleaseStreamBuffers(Stream* stream) override;
123126

124127
/**
125-
* @brief Trim the pool to `bytes_to_keep` (configured at construction) using `cudaMemPoolTrimTo()`.
128+
* @brief Trim the pool to `bytes_to_keep_on_shrink_` (configured at construction) using `cudaMemPoolTrimTo()`.
129+
* Memory still allocated is not affected. Shrink() may affect your performance and, contrary to BFCArena,
130+
* this allocator does not need Shrink. The CUDA mempool is capable of releasing memory automatically
131+
* according to pool_release_threshold_ set at construction.
126132
* Also rehashes internal maps under lock to keep them reasonably sized.
127133
*/
128134
Status Shrink() override;
129135

136+
static bool IsCudaVersionSupported() noexcept;
137+
130138
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CudaMempoolArena);
131139

132140
private:

onnxruntime/test/framework/bfc_arena_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ struct StreamMock : public Stream {
339339

340340
#ifdef ORT_ENABLE_STREAM
341341
TEST(StreamAwareArenaTest, TwoStreamAllocation) {
342-
StreamAwareArena a(std::unique_ptr<IAllocator>(new CPUAllocator()), 1 << 30);
342+
StreamAwareBFCArena a(std::unique_ptr<IAllocator>(new CPUAllocator()), 1 << 30);
343343
CheckStats(&a, 0, 0, 0, 0);
344344

345345
OrtDevice tmp;

onnxruntime/test/framework/session_state_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
405405
// One reserve call should have been made (for allocating memory for the sole initializer in the model)
406406
ASSERT_EQ(1, alloc_stats.num_reserves);
407407

408-
// This counter comes from Reserve(). The actual call for arena based allocator went to StreamAwareArena instance
408+
// This counter comes from Reserve(). The actual call for arena based allocator went to StreamAwareBFCArena instance
409409
ASSERT_EQ(1, alloc_stats.num_allocs);
410410
}
411411
}

onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,33 @@ namespace test {
2525
static bool IsCudaMemPoolSupported() {
2626
int ort_cuda_rt_version = 0;
2727
cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version);
28-
bool version_supported = (cuda_status == cudaSuccess && ort_cuda_rt_version >= 11020);
29-
if (!version_supported) {
28+
if (cuda_status != cudaSuccess) {
3029
return false;
3130
}
32-
// Creating a cuda mempool in some pipelines fails with
33-
// CUDA failure 801: operation not supported ; GPU=0 ; hostname=af14bbb1c000000 ;
34-
// Even though CUDA version may be 12.8 possibly due to the driver.
35-
cudaMemPoolProps props{};
36-
// Pinned is not the same as pinned allocator, cudaMemLocationTypeDevice actually does not exist
37-
// even though is present in some internet docs.
38-
props.allocType = cudaMemAllocationTypePinned;
39-
props.handleTypes = cudaMemHandleTypeNone; // local to process
40-
props.location.type = cudaMemLocationTypeDevice; // Device memory
41-
props.location.id = 0; // test device 0
42-
cudaMemPool_t pool;
43-
auto cuda_error = cudaMemPoolCreate(&pool, &props);
44-
if (cuda_error != cudaSuccess) {
31+
32+
if (ort_cuda_rt_version < 11020) {
33+
return false;
34+
}
35+
36+
int ort_cuda_driver_version = 0;
37+
cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version);
38+
if (cuda_status != cudaSuccess) {
39+
return false;
40+
}
41+
42+
if (ort_cuda_driver_version < 11020) {
4543
return false;
4644
}
47-
cuda_error = cudaMemPoolDestroy(pool);
45+
46+
// Check if the driver version supports the runtime version
47+
if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) {
48+
return false;
49+
}
50+
51+
if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) {
52+
return false;
53+
}
54+
4855
return true;
4956
}
5057

0 commit comments

Comments
 (0)