@@ -103,28 +103,38 @@ class DIPUCUDAAllocatorProxy : public CUDAAllocator {
103103
104104#else // # DIPU_TORCH20100 or higher
105105 void beginAllocateStreamToPool (int device, cudaStream_t stream,
106- MempoolId_t mempool_id) override {}
107- void endAllocateStreamToPool (int device, cudaStream_t stream) override {}
106+ MempoolId_t mempool_id) override {
107+ DIPU_PATCH_CUDA_ALLOCATOR ();
108+ }
109+ void endAllocateStreamToPool (int device, cudaStream_t stream) override {
110+ DIPU_PATCH_CUDA_ALLOCATOR ();
111+ }
108112
109113 void recordHistory (bool enabled, CreateContextFn context_recorder,
110114 size_t alloc_trace_max_entries,
111- RecordContext when) override {}
112- void releasePool (int device, MempoolId_t mempool_id) override {}
115+ RecordContext when) override {
116+ DIPU_PATCH_CUDA_ALLOCATOR ();
117+ }
118+ void releasePool (int device, MempoolId_t mempool_id) override {
119+ DIPU_PATCH_CUDA_ALLOCATOR ();
120+ }
113121
114- void enablePeerAccess (int dev, int dev_to_access) override {}
122+ void enablePeerAccess (int dev, int dev_to_access) override {
123+ DIPU_PATCH_CUDA_ALLOCATOR ();
124+ }
115125
116126 cudaError_t memcpyAsync (void * dst, int dstDevice, const void * src,
117127 int srcDevice, size_t count, cudaStream_t stream,
118128 bool p2p_enabled) override {
119- return cudaSuccess ;
129+ DIPU_PATCH_CUDA_ALLOCATOR () ;
120130 }
121131 std::shared_ptr<AllocatorState> getCheckpointState (int device,
122132 MempoolId_t id) override {
123- return {} ;
133+ DIPU_PATCH_CUDA_ALLOCATOR () ;
124134 }
125135 CheckpointDelta setCheckpointPoolState (
126136 int device, std::shared_ptr<AllocatorState> pps) override {
127- return {} ;
137+ DIPU_PATCH_CUDA_ALLOCATOR () ;
128138 }
129139#endif
130140
0 commit comments