@@ -66,6 +66,21 @@ AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard() {
6666 (void )cudaThreadExchangeStreamCaptureMode (&mode_);
6767}
6868
69+ CudaDeviceGuard::CudaDeviceGuard (int deviceId) : deviceId_(deviceId), origDeviceId_(-1 ) {
70+ if (deviceId_ >= 0 ) {
71+ MSCCLPP_CUDATHROW (cudaGetDevice (&origDeviceId_));
72+ if (origDeviceId_ != deviceId_) {
73+ MSCCLPP_CUDATHROW (cudaSetDevice (deviceId_));
74+ }
75+ }
76+ }
77+
78+ CudaDeviceGuard::~CudaDeviceGuard () {
79+ if (deviceId_ >= 0 && origDeviceId_ >= 0 && origDeviceId_ != deviceId_) {
80+ (void )cudaSetDevice (origDeviceId_);
81+ }
82+ }
83+
6984CudaStreamWithFlags::CudaStreamWithFlags () : stream_(nullptr ) { MSCCLPP_CUDATHROW (cudaGetDevice (&deviceId_)); }
7085
7186CudaStreamWithFlags::CudaStreamWithFlags (unsigned int flags) {
@@ -79,11 +94,8 @@ CudaStreamWithFlags::~CudaStreamWithFlags() {
7994
8095void CudaStreamWithFlags::set (unsigned int flags) {
8196 if (!empty ()) throw Error (" CudaStreamWithFlags already set" , ErrorCode::InvalidUsage);
82- int originalDeviceId;
83- MSCCLPP_CUDATHROW (cudaGetDevice (&originalDeviceId)); // Save the current device
84- MSCCLPP_CUDATHROW (cudaSetDevice (deviceId_));
97+ CudaDeviceGuard deviceGuard (deviceId_);
8598 MSCCLPP_CUDATHROW (cudaStreamCreateWithFlags (&stream_, flags));
86- MSCCLPP_CUDATHROW (cudaSetDevice (originalDeviceId)); // Restore the original device
8799}
88100
89101bool CudaStreamWithFlags::empty () const { return stream_ == nullptr ; }
@@ -123,6 +135,18 @@ namespace detail {
123135
124136CUmemAllocationHandleType nvlsCompatibleMemHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
125137
138+ int gpuIdFromAddress (void * ptr) {
139+ int deviceId;
140+ auto res = cuPointerGetAttribute (&deviceId, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, reinterpret_cast <CUdeviceptr>(ptr));
141+ if (res == CUDA_ERROR_INVALID_VALUE) {
142+ // not a GPU address
143+ return -1 ;
144+ } else {
145+ MSCCLPP_CUTHROW (res);
146+ }
147+ return deviceId;
148+ }
149+
126150// / set memory access permission to read-write
127151// / @param base Base memory pointer.
128152// / @param size Size of the memory.
0 commit comments