diff --git a/cpp/benchmarks/core/HashMap.cpp b/cpp/benchmarks/core/HashMap.cpp index a14ebcb9cc0..0361afcb9f5 100644 --- a/cpp/benchmarks/core/HashMap.cpp +++ b/cpp/benchmarks/core/HashMap.cpp @@ -253,8 +253,8 @@ void HashReserveInt(benchmark::State& state, class Int3 { public: - Int3() : x_(0), y_(0), z_(0) {}; - Int3(int k) : x_(k), y_(k * 2), z_(k * 4) {}; + Int3() : x_(0), y_(0), z_(0){}; + Int3(int k) : x_(k), y_(k * 2), z_(k * 4){}; bool operator==(const Int3& other) const { return x_ == other.x_ && y_ == other.y_ && z_ == other.z_; } diff --git a/cpp/open3d/core/CUDAUtils.cpp b/cpp/open3d/core/CUDAUtils.cpp index ab805bda049..168ed375423 100644 --- a/cpp/open3d/core/CUDAUtils.cpp +++ b/cpp/open3d/core/CUDAUtils.cpp @@ -11,6 +11,8 @@ #include "open3d/utility/Logging.h" #ifdef BUILD_CUDA_MODULE +#include + #include "open3d/core/MemoryManager.h" #endif @@ -141,41 +143,60 @@ static void SetDevice(int device_id) { OPEN3D_CUDA_CHECK(cudaSetDevice(device_id)); } -class CUDAStream { -public: - static CUDAStream& GetInstance() { - // The global stream state is given per thread like CUDA's internal - // device state. - static thread_local CUDAStream instance; - return instance; - } +void Synchronize(const CUDAStream& stream) { + OPEN3D_CUDA_CHECK(cudaStreamSynchronize(stream.Get())); +} - cudaStream_t Get() { return stream_; } - void Set(cudaStream_t stream) { stream_ = stream; } +#endif - static cudaStream_t Default() { return static_cast(0); } +} // namespace cuda -private: - CUDAStream() = default; - CUDAStream(const CUDAStream&) = delete; - CUDAStream& operator=(const CUDAStream&) = delete; +#ifdef BUILD_CUDA_MODULE - cudaStream_t stream_ = Default(); -}; +CUDAStream& CUDAStream::GetInstance() { + // The global stream state is given per thread like CUDA's internal + // device state. + thread_local CUDAStream instance = CUDAStream::Default(); + return instance; +} -cudaStream_t GetStream() { return CUDAStream::GetInstance().Get(); } +CUDAStream CUDAStream::CreateNew() { + CUDAStream stream; + OPEN3D_CUDA_CHECK(cudaStreamCreate(&stream.stream_)); + return stream; +} -static void SetStream(cudaStream_t stream) { - CUDAStream::GetInstance().Set(stream); +void CUDAStream::SetHostToDeviceMemcpyPolicy(CUDAMemoryCopyPolicy policy) { + OPEN3D_ASSERT(!IsDefaultStream()); + memcpy_from_host_to_device_ = policy; } -cudaStream_t GetDefaultStream() { return CUDAStream::Default(); } +CUDAMemoryCopyPolicy CUDAStream::GetHostToDeviceMemcpyPolicy() const { + return memcpy_from_host_to_device_; +} -#endif +CUDAMemoryCopyPolicy CUDAStream::GetDeviceToHostMemcpyPolicy() const { + return memcpy_from_device_to_host_; +} -} // namespace cuda +void CUDAStream::SetDeviceToHostMemcpyPolicy(CUDAMemoryCopyPolicy policy) { + OPEN3D_ASSERT(!IsDefaultStream()); + memcpy_from_device_to_host_ = policy; +} -#ifdef BUILD_CUDA_MODULE +bool CUDAStream::IsDefaultStream() const { + return stream_ == static_cast(nullptr); +} + +cudaStream_t CUDAStream::Get() const { return stream_; } + +void CUDAStream::Set(cudaStream_t stream) { stream_ = stream; } + +void CUDAStream::Destroy() { + OPEN3D_ASSERT(!IsDefaultStream()); + OPEN3D_CUDA_CHECK(cudaStreamDestroy(stream_)); + *this = CUDAStream::Default(); +} CUDAScopedDevice::CUDAScopedDevice(int device_id) : prev_device_id_(cuda::GetDevice()) { @@ -189,27 +210,22 @@ CUDAScopedDevice::CUDAScopedDevice(const Device& device) CUDAScopedDevice::~CUDAScopedDevice() { cuda::SetDevice(prev_device_id_); } -constexpr CUDAScopedStream::CreateNewStreamTag - CUDAScopedStream::CreateNewStream; - -CUDAScopedStream::CUDAScopedStream(const CreateNewStreamTag&) - : prev_stream_(cuda::GetStream()), owns_new_stream_(true) { - OPEN3D_CUDA_CHECK(cudaStreamCreate(&new_stream_)); - cuda::SetStream(new_stream_); -} - -CUDAScopedStream::CUDAScopedStream(cudaStream_t stream) - : prev_stream_(cuda::GetStream()), +CUDAScopedStream::CUDAScopedStream(CUDAStream stream, bool destroy_on_exit) + : prev_stream_(CUDAStream::GetInstance()), new_stream_(stream), - owns_new_stream_(false) { - cuda::SetStream(stream); + owns_new_stream_(destroy_on_exit) { + CUDAStream::GetInstance() = new_stream_; } CUDAScopedStream::~CUDAScopedStream() { if (owns_new_stream_) { - OPEN3D_CUDA_CHECK(cudaStreamDestroy(new_stream_)); + OPEN3D_ASSERT((prev_stream_.Get() != new_stream_.Get()) && + "CUDAScopedStream destroy_on_exit would destroy the same " + "stream which was in place before the scoped stream was " + "created."); + new_stream_.Destroy(); } - cuda::SetStream(prev_stream_); + CUDAStream::GetInstance() = prev_stream_; } CUDAState& CUDAState::GetInstance() { @@ -304,10 +320,35 @@ size_t GetCUDACurrentTotalMemSize() { namespace open3d { namespace core { +const std::unordered_set kProcessEndingErrors = { + cudaErrorAssert, + cudaErrorLaunchTimeout, + cudaErrorHardwareStackError, + cudaErrorIllegalInstruction, + cudaErrorMisalignedAddress, + cudaErrorInvalidAddressSpace, + cudaErrorInvalidPc, + cudaErrorTensorMemoryLeak, + cudaErrorMpsClientTerminated, + cudaErrorExternalDevice, + cudaErrorContained, + cudaErrorIllegalAddress, + cudaErrorLaunchFailure, + cudaErrorECCUncorrectable, + cudaErrorUnknown}; + void __OPEN3D_CUDA_CHECK(cudaError_t err, const char* file, const int line) { if (err != cudaSuccess) { - utility::LogError("{}:{} CUDA runtime error: {}", file, line, - cudaGetErrorString(err)); + if (kProcessEndingErrors.count(err)) { + utility::LogError( + "{}:{} CUDA runtime error: {}. This is a process-ending " + "error. All further operations will fail and the process " + "needs to be relaunched to be able to use CUDA.", + file, line, cudaGetErrorString(err)); + } else { + utility::LogError("{}:{} CUDA runtime error: {}", file, line, + cudaGetErrorString(err)); + } } } diff --git a/cpp/open3d/core/CUDAUtils.h b/cpp/open3d/core/CUDAUtils.h index 16a3841e2d1..20a0b5693c8 100644 --- a/cpp/open3d/core/CUDAUtils.h +++ b/cpp/open3d/core/CUDAUtils.h @@ -57,6 +57,124 @@ namespace core { #ifdef BUILD_CUDA_MODULE +/// \enum CUDAMemoryCopyPolicy +/// +/// Specifier for different behavior of memory copies between the host and +/// device. +/// +enum class CUDAMemoryCopyPolicy { + // Default. + // Ensure all memory copy operations are finished by synchronizing the CUDA + // stream on which the copy occurred. + Sync = 0, + // Asynchronous memory copies. Unmanaged. + // No memory safety at all - you are responsible for your own actions. + // There are no guaranteed about the lifetime of memory copied between the + // host and the device. If memory is freed before the copy finishes, you + // *will* have serious memory issues. + Async = 2 +}; + +/// \class CUDAStream +/// +/// An Open3D representation of a CUDA stream. +/// +class CUDAStream { +public: + static CUDAStream& GetInstance(); + + /// Creates a new CUDA stream. + /// The caller is responsible for eventually destroying the stream by + /// calling Destroy(). + static CUDAStream CreateNew(); + + /// Explicitly constructs a default stream. The default constructor could be + /// used, but this is clearer and closer to the old API. + static CUDAStream Default() { return {}; } + + /// Default constructor. Refers to the default CUDA stream. + CUDAStream() = default; + + /// Sets the behavior of memory copy operations device->host. + /// Sync by default. The default CUDA stream is implicitly synchronized with + /// every other stream. As such, it is invalid to call this function on the + /// default stream. + /// \param policy The desired behavior. + /// + /// Having non-synchronous memory + /// copy from device to host can result in memory corruption and various + /// other problems if you do not know what you are doing. Example: + /// ```cpp + /// void pokingTheBear() { + /// CUDAScopedStream scoped_stream(CUDAStream::CreateNew(), true); + /// CUDAStream::GetInstance().SetDeviceToHostMemcpyPolicy(CUDAMemoryCopyPolicy::AsyncUnmanaged); + /// Tensor foo = Tensor::Init({0.f}, "CUDA:0"); + /// Tensor foo_cpu = foo.To("CPU:0"); // launches an async copy from + /// device to cpu memory owned by foo_cpu. Until the async copy + /// completes, the memory will be uninitialized (random garbage). + /// // Any operations on foo_cpu will be undefined here, as you cannot + /// be sure the async memcpy has finished or not + /// cuda::Synchronize(CUDAStream::GetInstance()); // force a manual sync + /// // It is now safe to perform operations on foo_cpu + /// } + /// ``` + void SetDeviceToHostMemcpyPolicy(CUDAMemoryCopyPolicy policy); + + /// Returns the current value of the memory synchronization flag for + /// device->host memory copies. The default stream will always return Sync, + /// because it is implicitly synchronized. + CUDAMemoryCopyPolicy GetDeviceToHostMemcpyPolicy() const; + + /// Sets the behavior of memory copy operations host->device. + /// Sync by default. The default CUDA stream is implicitly synchronized with + /// every other stream. As such, it is invalid to call this function on the + /// default stream. + /// \param policy The desired behavior. + /// Having non-synchronous memory copy from host to device can result in + /// memory corruption and various other problems if you do not know what you + /// are doing. Example: + /// ```cpp + /// void pokingTheBear() { + /// CUDAScopedStream scoped_stream(CUDAStream::CreateNew(), true); + /// CUDAStream::GetInstance().SetHostToDeviceMemcpyPolicy(CUDAMemoryCopyPolicy::AsyncUnmanaged); + /// Tensor foo; + /// { + /// Tensor foo_cpu = Tensor::Init({-1.f}); + /// foo = foo_cpu.To("CUDA:0"); // launches async copy from foo_cpu + /// to foo + /// } + /// // fo_cpu goes out of scope, no guarantee that the data will be + /// // copied to the device memory pointed to by 'foo' before free is + /// // called on the host. CUDA may throw an illegal memory access + /// error. + /// } + /// ``` + void SetHostToDeviceMemcpyPolicy(CUDAMemoryCopyPolicy policy); + + /// Returns the current value of the memory synchronization flag for + /// host->device memory copies. The default stream will always return Sync, + /// because it is implicitly synchronized. + CUDAMemoryCopyPolicy GetHostToDeviceMemcpyPolicy() const; + + /// Returns true if this refers to the default CUDA stream. + bool IsDefaultStream() const; + + cudaStream_t Get() const; + void Set(cudaStream_t stream); + + /// Destroys the underlying CUDA stream. It is invalid to call this on the + /// default stream. After this call, this object refers to the default + /// stream. + void Destroy(); + +private: + cudaStream_t stream_ = static_cast(nullptr); + CUDAMemoryCopyPolicy memcpy_from_device_to_host_ = + CUDAMemoryCopyPolicy::Sync; + CUDAMemoryCopyPolicy memcpy_from_host_to_device_ = + CUDAMemoryCopyPolicy::Sync; +}; + /// \class CUDAScopedDevice /// /// Switch CUDA device id in the current scope. The device id will be reset @@ -135,20 +253,8 @@ class CUDAScopedDevice { /// } /// ``` class CUDAScopedStream { -private: - struct CreateNewStreamTag { - CreateNewStreamTag(const CreateNewStreamTag&) = delete; - CreateNewStreamTag& operator=(const CreateNewStreamTag&) = delete; - CreateNewStreamTag(CreateNewStreamTag&&) = delete; - CreateNewStreamTag& operator=(CreateNewStreamTag&&) = delete; - }; - public: - constexpr static CreateNewStreamTag CreateNewStream = {}; - - explicit CUDAScopedStream(const CreateNewStreamTag&); - - explicit CUDAScopedStream(cudaStream_t stream); + explicit CUDAScopedStream(CUDAStream stream, bool destroy_on_exit = false); ~CUDAScopedStream(); @@ -156,8 +262,8 @@ class CUDAScopedStream { CUDAScopedStream& operator=(const CUDAScopedStream&) = delete; private: - cudaStream_t prev_stream_; - cudaStream_t new_stream_; + CUDAStream prev_stream_; + CUDAStream new_stream_; bool owns_new_stream_ = false; }; @@ -265,8 +371,10 @@ bool SupportsMemoryPools(const Device& device); #ifdef BUILD_CUDA_MODULE int GetDevice(); -cudaStream_t GetStream(); -cudaStream_t GetDefaultStream(); + +/// Calls cudaStreamSynchronize() for the specified CUDA stream. +/// \param stream The stream to be synchronized. +void Synchronize(const CUDAStream& stream); #endif diff --git a/cpp/open3d/core/Indexer.h b/cpp/open3d/core/Indexer.h index 76f99acba5d..298997117b5 100644 --- a/cpp/open3d/core/Indexer.h +++ b/cpp/open3d/core/Indexer.h @@ -638,7 +638,7 @@ class Indexer { class IndexerIterator { public: struct Iterator { - Iterator() {}; + Iterator(){}; Iterator(const Indexer& indexer); Iterator(Iterator&& other) = default; diff --git a/cpp/open3d/core/MemoryManagerCUDA.cpp b/cpp/open3d/core/MemoryManagerCUDA.cpp index 8d15a8b8d45..44bce8535a7 100644 --- a/cpp/open3d/core/MemoryManagerCUDA.cpp +++ b/cpp/open3d/core/MemoryManagerCUDA.cpp @@ -22,7 +22,8 @@ void* MemoryManagerCUDA::Malloc(size_t byte_size, const Device& device) { #if CUDART_VERSION >= 11020 if (cuda::SupportsMemoryPools(device)) { OPEN3D_CUDA_CHECK(cudaMallocAsync(static_cast(&ptr), - byte_size, cuda::GetStream())); + byte_size, + CUDAStream::GetInstance().Get())); } else { OPEN3D_CUDA_CHECK(cudaMalloc(static_cast(&ptr), byte_size)); } @@ -43,7 +44,8 @@ void MemoryManagerCUDA::Free(void* ptr, const Device& device) { if (ptr && IsCUDAPointer(ptr, device)) { #if CUDART_VERSION >= 11020 if (cuda::SupportsMemoryPools(device)) { - OPEN3D_CUDA_CHECK(cudaFreeAsync(ptr, cuda::GetStream())); + OPEN3D_CUDA_CHECK( + cudaFreeAsync(ptr, CUDAStream::GetInstance().Get())); } else { OPEN3D_CUDA_CHECK(cudaFree(ptr)); } @@ -62,6 +64,7 @@ void MemoryManagerCUDA::Memcpy(void* dst_ptr, const void* src_ptr, const Device& src_device, size_t num_bytes) { + const CUDAStream& current_stream = CUDAStream::GetInstance(); if (dst_device.IsCUDA() && src_device.IsCPU()) { if (!IsCUDAPointer(dst_ptr, dst_device)) { utility::LogError("dst_ptr is not a CUDA pointer."); @@ -69,7 +72,12 @@ void MemoryManagerCUDA::Memcpy(void* dst_ptr, CUDAScopedDevice scoped_device(dst_device); OPEN3D_CUDA_CHECK(cudaMemcpyAsync(dst_ptr, src_ptr, num_bytes, cudaMemcpyHostToDevice, - cuda::GetStream())); + current_stream.Get())); + if (!current_stream.IsDefaultStream() && + current_stream.GetHostToDeviceMemcpyPolicy() == + CUDAMemoryCopyPolicy::Sync) { + OPEN3D_CUDA_CHECK(cudaStreamSynchronize(current_stream.Get())); + } } else if (dst_device.IsCPU() && src_device.IsCUDA()) { if (!IsCUDAPointer(src_ptr, src_device)) { utility::LogError("src_ptr is not a CUDA pointer."); @@ -77,7 +85,12 @@ void MemoryManagerCUDA::Memcpy(void* dst_ptr, CUDAScopedDevice scoped_device(src_device); OPEN3D_CUDA_CHECK(cudaMemcpyAsync(dst_ptr, src_ptr, num_bytes, cudaMemcpyDeviceToHost, - cuda::GetStream())); + current_stream.Get())); + if (!current_stream.IsDefaultStream() && + current_stream.GetDeviceToHostMemcpyPolicy() == + CUDAMemoryCopyPolicy::Sync) { + OPEN3D_CUDA_CHECK(cudaStreamSynchronize(current_stream.Get())); + } } else if (dst_device.IsCUDA() && src_device.IsCUDA()) { if (!IsCUDAPointer(dst_ptr, dst_device)) { utility::LogError("dst_ptr is not a CUDA pointer."); @@ -90,27 +103,18 @@ void MemoryManagerCUDA::Memcpy(void* dst_ptr, CUDAScopedDevice scoped_device(src_device); OPEN3D_CUDA_CHECK(cudaMemcpyAsync(dst_ptr, src_ptr, num_bytes, cudaMemcpyDeviceToDevice, - cuda::GetStream())); + current_stream.Get())); } else if (CUDAState::GetInstance().IsP2PEnabled(src_device.GetID(), dst_device.GetID())) { OPEN3D_CUDA_CHECK(cudaMemcpyPeerAsync( dst_ptr, dst_device.GetID(), src_ptr, src_device.GetID(), - num_bytes, cuda::GetStream())); + num_bytes, current_stream.Get())); } else { - void* cpu_buf = MemoryManager::Malloc(num_bytes, Device("CPU:0")); - { - CUDAScopedDevice scoped_device(src_device); - OPEN3D_CUDA_CHECK(cudaMemcpyAsync(cpu_buf, src_ptr, num_bytes, - cudaMemcpyDeviceToHost, - cuda::GetStream())); - } - { - CUDAScopedDevice scoped_device(dst_device); - OPEN3D_CUDA_CHECK(cudaMemcpyAsync(dst_ptr, cpu_buf, num_bytes, - cudaMemcpyHostToDevice, - cuda::GetStream())); - } - MemoryManager::Free(cpu_buf, Device("CPU:0")); + const Device cpu_device("CPU:0"); + void* cpu_buf = MemoryManager::Malloc(num_bytes, cpu_device); + Memcpy(cpu_buf, cpu_device, src_ptr, src_device, num_bytes); + Memcpy(dst_ptr, dst_device, cpu_buf, cpu_device, num_bytes); + MemoryManager::Free(cpu_buf, cpu_device); } } else { utility::LogError("Wrong cudaMemcpyKind."); diff --git a/cpp/open3d/core/ParallelFor.h b/cpp/open3d/core/ParallelFor.h index 9e917789947..8e1971e9038 100644 --- a/cpp/open3d/core/ParallelFor.h +++ b/cpp/open3d/core/ParallelFor.h @@ -61,8 +61,8 @@ void ParallelForCUDA_(const Device& device, int64_t n, const func_t& func) { int64_t grid_size = (n + items_per_block - 1) / items_per_block; ElementWiseKernel_ - <<>>( - n, func); + <<>>(n, func); OPEN3D_GET_LAST_CUDA_ERROR("ParallelFor failed."); } diff --git a/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBuffer.cu b/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBuffer.cu index 627095858ec..1acc74024bc 100644 --- a/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBuffer.cu +++ b/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBuffer.cu @@ -15,8 +15,10 @@ namespace open3d { namespace core { void CUDAResetHeap(Tensor &heap) { uint32_t *heap_ptr = heap.GetDataPtr(); - thrust::sequence(thrust::device, heap_ptr, heap_ptr + heap.GetLength(), 0); + thrust::sequence(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + heap_ptr, heap_ptr + heap.GetLength(), 0); OPEN3D_CUDA_CHECK(cudaGetLastError()); + cuda::Synchronize(CUDAStream::GetInstance()); } } // namespace core } // namespace open3d diff --git a/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBufferAccessor.h b/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBufferAccessor.h index eeb74bc2610..ec8f2d283ad 100644 --- a/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBufferAccessor.h +++ b/cpp/open3d/core/hashmap/CUDA/CUDAHashBackendBufferAccessor.h @@ -58,7 +58,9 @@ class CUDAHashBackendBufferAccessor { std::vector value_ptrs(n_values_); for (size_t i = 0; i < n_values_; ++i) { value_ptrs[i] = value_buffers[i].GetDataPtr(); - cudaMemset(value_ptrs[i], 0, capacity_ * value_dsizes_host[i]); + OPEN3D_CUDA_CHECK(cudaMemsetAsync( + value_ptrs[i], 0, capacity_ * value_dsizes_host[i], + core::CUDAStream::GetInstance().Get())); } values_ = static_cast( MemoryManager::Malloc(n_values_ * sizeof(uint8_t *), device)); @@ -66,7 +68,7 @@ class CUDAHashBackendBufferAccessor { n_values_ * sizeof(uint8_t *)); heap_top_ = hashmap_buffer.GetHeapTop().cuda.GetDataPtr(); - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); } diff --git a/cpp/open3d/core/hashmap/CUDA/SlabHashBackend.h b/cpp/open3d/core/hashmap/CUDA/SlabHashBackend.h index dbf665c795f..e0ae01618ad 100644 --- a/cpp/open3d/core/hashmap/CUDA/SlabHashBackend.h +++ b/cpp/open3d/core/hashmap/CUDA/SlabHashBackend.h @@ -97,15 +97,17 @@ void SlabHashBackend::Find(const void* input_keys, CUDAScopedDevice scoped_device(this->device_); if (count == 0) return; - OPEN3D_CUDA_CHECK(cudaMemset(output_masks, 0, sizeof(bool) * count)); - cuda::Synchronize(); + OPEN3D_CUDA_CHECK(cudaMemsetAsync(output_masks, 0, sizeof(bool) * count, + CUDAStream::GetInstance().Get())); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); const int64_t num_blocks = (count + kThreadsPerBlock - 1) / kThreadsPerBlock; - FindKernel<<>>( + FindKernel<<>>( impl_, input_keys, output_buf_indices, output_masks, count); - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); } @@ -116,8 +118,9 @@ void SlabHashBackend::Erase(const void* input_keys, CUDAScopedDevice scoped_device(this->device_); if (count == 0) return; - OPEN3D_CUDA_CHECK(cudaMemset(output_masks, 0, sizeof(bool) * count)); - cuda::Synchronize(); + OPEN3D_CUDA_CHECK(cudaMemsetAsync(output_masks, 0, sizeof(bool) * count, + CUDAStream::GetInstance().Get())); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); auto buf_indices = static_cast( MemoryManager::Malloc(sizeof(buf_index_t) * count, this->device_)); @@ -125,12 +128,12 @@ void SlabHashBackend::Erase(const void* input_keys, const int64_t num_blocks = (count + kThreadsPerBlock - 1) / kThreadsPerBlock; EraseKernelPass0<<>>( + core::CUDAStream::GetInstance().Get()>>>( impl_, input_keys, buf_indices, output_masks, count); EraseKernelPass1<<>>(impl_, buf_indices, - output_masks, count); - cuda::Synchronize(); + core::CUDAStream::GetInstance().Get()>>>( + impl_, buf_indices, output_masks, count); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); MemoryManager::Free(buf_indices, this->device_); @@ -142,18 +145,19 @@ int64_t SlabHashBackend::GetActiveIndices( CUDAScopedDevice scoped_device(this->device_); uint32_t* count = static_cast( MemoryManager::Malloc(sizeof(uint32_t), this->device_)); - OPEN3D_CUDA_CHECK(cudaMemset(count, 0, sizeof(uint32_t))); + OPEN3D_CUDA_CHECK(cudaMemsetAsync(count, 0, sizeof(uint32_t), + CUDAStream::GetInstance().Get())); - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); const int64_t num_blocks = (impl_.bucket_count_ * kWarpSize + kThreadsPerBlock - 1) / kThreadsPerBlock; GetActiveIndicesKernel<<>>( + core::CUDAStream::GetInstance().Get()>>>( impl_, output_buf_indices, count); - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); uint32_t ret; @@ -170,9 +174,10 @@ void SlabHashBackend::Clear() { this->buffer_->ResetHeap(); // Clear the linked list heads - OPEN3D_CUDA_CHECK(cudaMemset(impl_.bucket_list_head_, 0xFF, - sizeof(Slab) * this->bucket_count_)); - cuda::Synchronize(); + OPEN3D_CUDA_CHECK(cudaMemsetAsync(impl_.bucket_list_head_, 0xFF, + sizeof(Slab) * this->bucket_count_, + CUDAStream::GetInstance().Get())); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); // Clear the linked list nodes @@ -195,20 +200,23 @@ template std::vector SlabHashBackend::BucketSizes() const { CUDAScopedDevice scoped_device(this->device_); thrust::device_vector elems_per_bucket(impl_.bucket_count_); - thrust::fill(elems_per_bucket.begin(), elems_per_bucket.end(), 0); + thrust::fill(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + elems_per_bucket.begin(), elems_per_bucket.end(), 0); const int64_t num_blocks = (impl_.buffer_accessor_.capacity_ + kThreadsPerBlock - 1) / kThreadsPerBlock; CountElemsPerBucketKernel<<>>( + CUDAStream::GetInstance().Get()>>>( impl_, thrust::raw_pointer_cast(elems_per_bucket.data())); - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); std::vector result(impl_.bucket_count_); - thrust::copy(elems_per_bucket.begin(), elems_per_bucket.end(), + thrust::copy(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + elems_per_bucket.begin(), elems_per_bucket.end(), result.begin()); + cuda::Synchronize(CUDAStream::GetInstance()); return result; } @@ -231,20 +239,26 @@ void SlabHashBackend::Insert( /// Increase heap_top to pre-allocate potential memory increment and /// avoid atomicAdd in kernel. int prev_heap_top = this->buffer_->GetHeapTopIndex(); - *thrust::device_ptr(impl_.buffer_accessor_.heap_top_) = - prev_heap_top + count; + int new_value = prev_heap_top + count; + thrust::fill_n( + thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + thrust::device_pointer_cast(impl_.buffer_accessor_.heap_top_), 1, + new_value); const int64_t num_blocks = (count + kThreadsPerBlock - 1) / kThreadsPerBlock; InsertKernelPass0<<>>( + core::CUDAStream::GetInstance().Get()>>>( impl_, input_keys, output_buf_indices, prev_heap_top, count); InsertKernelPass1<<>>( + core::CUDAStream::GetInstance().Get()>>>( impl_, input_keys, output_buf_indices, output_masks, count); thrust::device_vector input_values_soa_device( - input_values_soa.begin(), input_values_soa.end()); + input_values_soa.size()); + thrust::copy(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + input_values_soa.begin(), input_values_soa.end(), + input_values_soa_device.begin()); int64_t n_values = input_values_soa.size(); const void* const* ptr_input_values_soa = @@ -253,11 +267,11 @@ void SlabHashBackend::Insert( impl_.buffer_accessor_.common_block_size_, [&]() { InsertKernelPass2 <<>>( + core::CUDAStream::GetInstance().Get()>>>( impl_, ptr_input_values_soa, output_buf_indices, output_masks, count, n_values); }); - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); } @@ -279,9 +293,10 @@ void SlabHashBackend::Allocate(int64_t capacity) { // Allocate linked list heads. impl_.bucket_list_head_ = static_cast(MemoryManager::Malloc( sizeof(Slab) * this->bucket_count_, this->device_)); - OPEN3D_CUDA_CHECK(cudaMemset(impl_.bucket_list_head_, 0xFF, - sizeof(Slab) * this->bucket_count_)); - cuda::Synchronize(); + OPEN3D_CUDA_CHECK(cudaMemsetAsync(impl_.bucket_list_head_, 0xFF, + sizeof(Slab) * this->bucket_count_, + CUDAStream::GetInstance().Get())); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); impl_.Setup(this->bucket_count_, node_mgr_->impl_, buffer_accessor_); diff --git a/cpp/open3d/core/hashmap/CUDA/SlabNodeManager.h b/cpp/open3d/core/hashmap/CUDA/SlabNodeManager.h index 9668c31005a..47e1755a424 100644 --- a/cpp/open3d/core/hashmap/CUDA/SlabNodeManager.h +++ b/cpp/open3d/core/hashmap/CUDA/SlabNodeManager.h @@ -233,17 +233,19 @@ class SlabNodeManager { ~SlabNodeManager() { MemoryManager::Free(impl_.super_blocks_, device_); } void Reset() { - OPEN3D_CUDA_CHECK(cudaMemset( + OPEN3D_CUDA_CHECK(cudaMemsetAsync( impl_.super_blocks_, 0xFF, - kUIntsPerSuperBlock * kSuperBlocks * sizeof(uint32_t))); + kUIntsPerSuperBlock * kSuperBlocks * sizeof(uint32_t), + CUDAStream::GetInstance().Get())); for (uint32_t i = 0; i < kSuperBlocks; i++) { // setting bitmaps into zeros: - OPEN3D_CUDA_CHECK(cudaMemset( + OPEN3D_CUDA_CHECK(cudaMemsetAsync( impl_.super_blocks_ + i * kUIntsPerSuperBlock, 0x00, - kBlocksPerSuperBlock * kSlabsPerBlock * sizeof(uint32_t))); + kBlocksPerSuperBlock * kSlabsPerBlock * sizeof(uint32_t), + CUDAStream::GetInstance().Get())); } - cuda::Synchronize(); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); } @@ -251,23 +253,29 @@ class SlabNodeManager { const uint32_t num_super_blocks = kSuperBlocks; thrust::device_vector slabs_per_superblock(kSuperBlocks); - thrust::fill(slabs_per_superblock.begin(), slabs_per_superblock.end(), + thrust::fill(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + slabs_per_superblock.begin(), slabs_per_superblock.end(), 0); // Counting total number of allocated memory units. int num_mem_units = kBlocksPerSuperBlock * 32; int num_cuda_blocks = (num_mem_units + kThreadsPerBlock - 1) / kThreadsPerBlock; - CountSlabsPerSuperblockKernel<<>>( + CountSlabsPerSuperblockKernel<<< + num_cuda_blocks, kThreadsPerBlock, 0, + core::CUDAStream::GetInstance().Get()>>>( impl_, thrust::raw_pointer_cast(slabs_per_superblock.data())); - cuda::Synchronize(); OPEN3D_CUDA_CHECK(cudaGetLastError()); std::vector result(num_super_blocks); - thrust::copy(slabs_per_superblock.begin(), slabs_per_superblock.end(), - result.begin()); - + OPEN3D_CUDA_CHECK(cudaMemcpyAsync( + result.data(), + thrust::raw_pointer_cast(slabs_per_superblock.data()), + num_super_blocks * sizeof(int), cudaMemcpyDeviceToHost, + CUDAStream::GetInstance().Get())); + if (!CUDAStream::GetInstance().IsDefaultStream()) { + cuda::Synchronize(CUDAStream::GetInstance()); + } return result; } diff --git a/cpp/open3d/core/hashmap/CUDA/StdGPUHashBackend.h b/cpp/open3d/core/hashmap/CUDA/StdGPUHashBackend.h index d6707a6ce17..dd9021aa8d2 100644 --- a/cpp/open3d/core/hashmap/CUDA/StdGPUHashBackend.h +++ b/cpp/open3d/core/hashmap/CUDA/StdGPUHashBackend.h @@ -200,10 +200,11 @@ void StdGPUHashBackend::Find(const void* input_keys, uint32_t threads = 128; uint32_t blocks = (count + threads - 1) / threads; - STDGPUFindKernel<<>>( + STDGPUFindKernel<<>>( impl_, buffer_accessor_, static_cast(input_keys), output_buf_indices, output_masks, count); - cuda::Synchronize(this->device_); + cuda::Synchronize(CUDAStream::GetInstance()); } // Need an explicit kernel for non-const access to map @@ -244,10 +245,11 @@ void StdGPUHashBackend::Erase(const void* input_keys, buf_index_t* output_buf_indices = static_cast(toutput_buf_indices.GetDataPtr()); - STDGPUEraseKernel<<>>( + STDGPUEraseKernel<<>>( impl_, buffer_accessor_, static_cast(input_keys), output_buf_indices, output_masks, count); - cuda::Synchronize(this->device_); + cuda::Synchronize(CUDAStream::GetInstance()); } template @@ -364,24 +366,27 @@ void StdGPUHashBackend::Insert( CUDAScopedDevice scoped_device(this->device_); uint32_t threads = 128; uint32_t blocks = (count + threads - 1) / threads; + int64_t n_values = input_values_soa.size(); - thrust::device_vector input_values_soa_device( - input_values_soa.begin(), input_values_soa.end()); + thrust::device_vector input_values_soa_device(n_values); + thrust::copy(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + input_values_soa.begin(), input_values_soa.end(), + input_values_soa_device.begin()); - int64_t n_values = input_values_soa.size(); const void* const* ptr_input_values_soa = thrust::raw_pointer_cast(input_values_soa_device.data()); DISPATCH_DIVISOR_SIZE_TO_BLOCK_T( buffer_accessor_.common_block_size_, [&]() { STDGPUInsertKernel - <<>>( + <<>>( impl_, buffer_accessor_, static_cast(input_keys), ptr_input_values_soa, output_buf_indices, output_masks, count, n_values); }); - cuda::Synchronize(this->device_); + cuda::Synchronize(CUDAStream::GetInstance()); } template @@ -398,12 +403,12 @@ void StdGPUHashBackend::Allocate(int64_t capacity) { // stdgpu initializes on the default stream. Set the current stream to // ensure correct behavior. { - CUDAScopedStream scoped_stream(cuda::GetDefaultStream()); + CUDAScopedStream scoped_stream(CUDAStream::Default()); impl_ = InternalStdGPUHashBackend::createDeviceObject( this->capacity_, InternalStdGPUHashBackendAllocator(this->device_.GetID())); - cuda::Synchronize(this->device_); + cuda::Synchronize(CUDAStream::GetInstance()); } } @@ -416,7 +421,7 @@ void StdGPUHashBackend::Free() { // stdgpu initializes on the default stream. Set the current stream to // ensure correct behavior. { - CUDAScopedStream scoped_stream(cuda::GetDefaultStream()); + CUDAScopedStream scoped_stream(CUDAStream::Default()); InternalStdGPUHashBackend::destroyDeviceObject(impl_); } diff --git a/cpp/open3d/core/kernel/NonZeroCUDA.cu b/cpp/open3d/core/kernel/NonZeroCUDA.cu index c68b927581e..c4400e20be9 100644 --- a/cpp/open3d/core/kernel/NonZeroCUDA.cu +++ b/cpp/open3d/core/kernel/NonZeroCUDA.cu @@ -60,8 +60,6 @@ Tensor NonZeroCUDA(const Tensor& src) { CUDAScopedDevice scoped_device(src.GetDevice()); Tensor src_contiguous = src.Contiguous(); const int64_t num_elements = src_contiguous.NumElements(); - const int64_t num_bytes = - num_elements * src_contiguous.GetDtype().ByteSize(); thrust::counting_iterator index_first(0); thrust::counting_iterator index_last = index_first + num_elements; @@ -72,9 +70,11 @@ Tensor NonZeroCUDA(const Tensor& src) { thrust::device_ptr src_ptr( static_cast(src_contiguous.GetDataPtr())); - auto it = thrust::copy_if(index_first, index_last, src_ptr, - non_zero_indices.begin(), - NonZeroFunctor()); + auto it = thrust::copy_if( + thrust::cuda::par.on(CUDAStream::GetInstance().Get()), + index_first, index_last, src_ptr, non_zero_indices.begin(), + NonZeroFunctor()); + cuda::Synchronize(CUDAStream::GetInstance()); non_zero_indices.resize(thrust::distance(non_zero_indices.begin(), it)); }); @@ -88,13 +88,14 @@ Tensor NonZeroCUDA(const Tensor& src) { TensorIterator result_iter(result); index_last = index_first + num_non_zeros; - thrust::for_each(thrust::device, + thrust::for_each(thrust::cuda::par.on(CUDAStream::GetInstance().Get()), thrust::make_zip_iterator(thrust::make_tuple( index_first, non_zero_indices.begin())), thrust::make_zip_iterator(thrust::make_tuple( index_last, non_zero_indices.end())), FlatIndexTransformFunctor(result_iter, num_non_zeros, num_dims, shape)); + cuda::Synchronize(CUDAStream::GetInstance()); return result; } diff --git a/cpp/open3d/core/kernel/ReductionCUDA.cu b/cpp/open3d/core/kernel/ReductionCUDA.cu index e6504940c20..9f9ec5a6be4 100644 --- a/cpp/open3d/core/kernel/ReductionCUDA.cu +++ b/cpp/open3d/core/kernel/ReductionCUDA.cu @@ -959,8 +959,9 @@ private: std::make_unique(config.SemaphoreSize(), device); buffer = buffer_blob->GetDataPtr(); semaphores = semaphores_blob->GetDataPtr(); - OPEN3D_CUDA_CHECK( - cudaMemset(semaphores, 0, config.SemaphoreSize())); + OPEN3D_CUDA_CHECK(cudaMemsetAsync(semaphores, 0, + config.SemaphoreSize(), + CUDAStream::GetInstance().Get())); } OPEN3D_ASSERT(can_use_32bit_indexing); @@ -979,8 +980,8 @@ private: int shared_memory = config.SharedMemorySize(); ReduceKernel <<>>(reduce_op); - cuda::Synchronize(); + core::CUDAStream::GetInstance().Get()>>>(reduce_op); + cuda::Synchronize(CUDAStream::GetInstance()); OPEN3D_CUDA_CHECK(cudaGetLastError()); } diff --git a/cpp/open3d/core/linalg/AddMMCUDA.cpp b/cpp/open3d/core/linalg/AddMMCUDA.cpp index 122f8f178b1..7ed491044dd 100644 --- a/cpp/open3d/core/linalg/AddMMCUDA.cpp +++ b/cpp/open3d/core/linalg/AddMMCUDA.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: MIT // ---------------------------------------------------------------------------- +#include "open3d/core/CUDAUtils.h" #include "open3d/core/linalg/AddMM.h" #include "open3d/core/linalg/BlasWrapper.h" #include "open3d/core/linalg/LinalgUtils.h" @@ -29,6 +30,7 @@ void AddMMCUDA(void* A_data, Dtype dtype, const Device& device) { cublasHandle_t handle = CuBLASContext::GetInstance().GetHandle(device); + cublasSetStream(handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { scalar_t alpha_ = scalar_t(alpha); scalar_t beta_ = scalar_t(beta); diff --git a/cpp/open3d/core/linalg/InverseCUDA.cpp b/cpp/open3d/core/linalg/InverseCUDA.cpp index dca871c9bc3..85f54d00057 100644 --- a/cpp/open3d/core/linalg/InverseCUDA.cpp +++ b/cpp/open3d/core/linalg/InverseCUDA.cpp @@ -6,6 +6,7 @@ // ---------------------------------------------------------------------------- #include "open3d/core/Blob.h" +#include "open3d/core/CUDAUtils.h" #include "open3d/core/linalg/Inverse.h" #include "open3d/core/linalg/LapackWrapper.h" #include "open3d/core/linalg/LinalgUtils.h" @@ -21,6 +22,7 @@ void InverseCUDA(void* A_data, const Device& device) { cusolverDnHandle_t handle = CuSolverContext::GetInstance().GetHandle(device); + cusolverDnSetStream(handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { int len; diff --git a/cpp/open3d/core/linalg/LUCUDA.cpp b/cpp/open3d/core/linalg/LUCUDA.cpp index 6a5f9f89544..1ada0a20559 100644 --- a/cpp/open3d/core/linalg/LUCUDA.cpp +++ b/cpp/open3d/core/linalg/LUCUDA.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: MIT // ---------------------------------------------------------------------------- +#include "open3d/core/CUDAUtils.h" #include "open3d/core/linalg/LUImpl.h" #include "open3d/core/linalg/LapackWrapper.h" #include "open3d/core/linalg/LinalgUtils.h" @@ -20,6 +21,7 @@ void LUCUDA(void* A_data, const Device& device) { cusolverDnHandle_t handle = CuSolverContext::GetInstance().GetHandle(device); + cusolverDnSetStream(handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { int len; OPEN3D_CUSOLVER_CHECK( diff --git a/cpp/open3d/core/linalg/LeastSquaresCUDA.cpp b/cpp/open3d/core/linalg/LeastSquaresCUDA.cpp index 82f59ec2c34..a3caee46eba 100644 --- a/cpp/open3d/core/linalg/LeastSquaresCUDA.cpp +++ b/cpp/open3d/core/linalg/LeastSquaresCUDA.cpp @@ -30,8 +30,10 @@ void LeastSquaresCUDA(void* A_data, const Device& device) { cusolverDnHandle_t cusolver_handle = CuSolverContext::GetInstance().GetHandle(device); + cusolverDnSetStream(cusolver_handle, CUDAStream::GetInstance().Get()); cublasHandle_t cublas_handle = CuBLASContext::GetInstance().GetHandle(device); + cublasSetStream(cublas_handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { int len_geqrf, len_ormqr, len; diff --git a/cpp/open3d/core/linalg/MatmulCUDA.cpp b/cpp/open3d/core/linalg/MatmulCUDA.cpp index f5a517ef4bd..2e755ac8791 100644 --- a/cpp/open3d/core/linalg/MatmulCUDA.cpp +++ b/cpp/open3d/core/linalg/MatmulCUDA.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: MIT // ---------------------------------------------------------------------------- +#include "open3d/core/CUDAUtils.h" #include "open3d/core/linalg/BlasWrapper.h" #include "open3d/core/linalg/LinalgUtils.h" #include "open3d/core/linalg/Matmul.h" @@ -22,6 +23,7 @@ void MatmulCUDA(void* A_data, Dtype dtype, const Device& device) { cublasHandle_t handle = CuBLASContext::GetInstance().GetHandle(device); + cublasSetStream(handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { scalar_t alpha = 1, beta = 0; OPEN3D_CUBLAS_CHECK( diff --git a/cpp/open3d/core/linalg/SVDCUDA.cpp b/cpp/open3d/core/linalg/SVDCUDA.cpp index a7acbf07b99..60882d5fb9e 100644 --- a/cpp/open3d/core/linalg/SVDCUDA.cpp +++ b/cpp/open3d/core/linalg/SVDCUDA.cpp @@ -6,6 +6,7 @@ // ---------------------------------------------------------------------------- #include "open3d/core/Blob.h" +#include "open3d/core/CUDAUtils.h" #include "open3d/core/linalg/LapackWrapper.h" #include "open3d/core/linalg/LinalgUtils.h" #include "open3d/core/linalg/SVD.h" @@ -24,6 +25,7 @@ void SVDCUDA(const void* A_data, const Device& device) { cusolverDnHandle_t handle = CuSolverContext::GetInstance().GetHandle(device); + cusolverDnSetStream(handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { int len; diff --git a/cpp/open3d/core/linalg/SolveCUDA.cpp b/cpp/open3d/core/linalg/SolveCUDA.cpp index 1e2c656103b..09625efbc52 100644 --- a/cpp/open3d/core/linalg/SolveCUDA.cpp +++ b/cpp/open3d/core/linalg/SolveCUDA.cpp @@ -26,6 +26,7 @@ void SolveCUDA(void* A_data, const Device& device) { cusolverDnHandle_t handle = CuSolverContext::GetInstance().GetHandle(device); + cusolverDnSetStream(handle, CUDAStream::GetInstance().Get()); DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() { int len; diff --git a/cpp/open3d/core/nns/FixedRadiusIndex.cpp b/cpp/open3d/core/nns/FixedRadiusIndex.cpp index 38742fa3609..fb50820bb70 100644 --- a/cpp/open3d/core/nns/FixedRadiusIndex.cpp +++ b/cpp/open3d/core/nns/FixedRadiusIndex.cpp @@ -15,7 +15,7 @@ namespace open3d { namespace core { namespace nns { -FixedRadiusIndex::FixedRadiusIndex() {}; +FixedRadiusIndex::FixedRadiusIndex(){}; FixedRadiusIndex::FixedRadiusIndex(const Tensor &dataset_points, double radius) { @@ -31,7 +31,7 @@ FixedRadiusIndex::FixedRadiusIndex(const Tensor &dataset_points, SetTensorData(dataset_points, radius, index_dtype); }; -FixedRadiusIndex::~FixedRadiusIndex() {}; +FixedRadiusIndex::~FixedRadiusIndex(){}; bool FixedRadiusIndex::SetTensorData(const Tensor &dataset_points, double radius, diff --git a/cpp/open3d/core/nns/FixedRadiusSearchImpl.cuh b/cpp/open3d/core/nns/FixedRadiusSearchImpl.cuh index c4eb31c8532..c76d47d35c9 100644 --- a/cpp/open3d/core/nns/FixedRadiusSearchImpl.cuh +++ b/cpp/open3d/core/nns/FixedRadiusSearchImpl.cuh @@ -787,6 +787,12 @@ void BuildSpatialHashTableCUDA(const cudaStream_t& stream, count_tmp.first, hash_table_cell_splits, count_tmp.second, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second); if (!get_temp_size) { @@ -825,7 +831,8 @@ void BuildSpatialHashTableCUDA(const cudaStream_t& stream, } template -void SortPairs(void* temp, +void SortPairs(const cudaStream_t& stream, + void* temp, size_t& temp_size, int texture_alignment, int64_t num_indices, @@ -850,7 +857,12 @@ void SortPairs(void* temp, sort_temp.first, sort_temp.second, distances_unsorted, distances_sorted, indices_unsorted, indices_sorted, num_indices, num_segments, query_neighbors_row_splits, - query_neighbors_row_splits + 1); + query_neighbors_row_splits + 1, 0, sizeof(T) * 8, stream); + // MUST synchronize non-default streams because InclusiveSum writes to the + // value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } sort_temp = mem_temp.Alloc(sort_temp.second); if (!get_temp_size) { @@ -858,7 +870,7 @@ void SortPairs(void* temp, sort_temp.first, sort_temp.second, distances_unsorted, distances_sorted, indices_unsorted, indices_sorted, num_indices, num_segments, query_neighbors_row_splits, - query_neighbors_row_splits + 1); + query_neighbors_row_splits + 1, 0, sizeof(T) * 8, stream); } mem_temp.Free(sort_temp); @@ -963,9 +975,7 @@ void FixedRadiusSearchCUDA(const cudaStream_t& stream, cudaMemcpyAsync(&last_prefix_sum_entry, query_neighbors_row_splits + num_queries, sizeof(int64_t), cudaMemcpyDeviceToHost, stream); - // wait for the async copies - while (cudaErrorNotReady == cudaStreamQuery(stream)) { /*empty*/ - } + cudaStreamSynchronize(stream); } mem_temp.Free(inclusive_scan_temp); } diff --git a/cpp/open3d/core/nns/FixedRadiusSearchOps.cu b/cpp/open3d/core/nns/FixedRadiusSearchOps.cu index 6f2a7e1df9e..693f5bb6501 100644 --- a/cpp/open3d/core/nns/FixedRadiusSearchOps.cu +++ b/cpp/open3d/core/nns/FixedRadiusSearchOps.cu @@ -25,7 +25,7 @@ void BuildSpatialHashTableCUDA(const Tensor& points, Tensor& hash_table_index, Tensor& hash_table_cell_splits) { CUDAScopedDevice scoped_device(points.GetDevice()); - const cudaStream_t stream = 0; + cudaStream_t stream = CUDAStream::GetInstance().Get(); int texture_alignment = 512; void* temp_ptr = nullptr; @@ -74,7 +74,7 @@ void FixedRadiusSearchCUDA(const Tensor& points, Tensor& neighbors_row_splits, Tensor& neighbors_distance) { CUDAScopedDevice scoped_device(points.GetDevice()); - const cudaStream_t stream = 0; + cudaStream_t stream = CUDAStream::GetInstance().Get(); int texture_alignment = 512; Device device = points.GetDevice(); @@ -135,8 +135,8 @@ void FixedRadiusSearchCUDA(const Tensor& points, Tensor distances_sorted = Tensor::Empty({num_indices}, dtype, device); // Determine temp_size for sorting - impl::SortPairs(temp_ptr, temp_size, texture_alignment, num_indices, - num_segments, + impl::SortPairs(stream, temp_ptr, temp_size, texture_alignment, + num_indices, num_segments, neighbors_row_splits.GetDataPtr(), indices_unsorted.GetDataPtr(), distances_unsorted.GetDataPtr(), @@ -147,8 +147,8 @@ void FixedRadiusSearchCUDA(const Tensor& points, temp_ptr = temp_tensor.GetDataPtr(); // Actually run the sorting. - impl::SortPairs(temp_ptr, temp_size, texture_alignment, num_indices, - num_segments, + impl::SortPairs(stream, temp_ptr, temp_size, texture_alignment, + num_indices, num_segments, neighbors_row_splits.GetDataPtr(), indices_unsorted.GetDataPtr(), distances_unsorted.GetDataPtr(), @@ -174,7 +174,7 @@ void HybridSearchCUDA(const Tensor& points, Tensor& neighbors_count, Tensor& neighbors_distance) { CUDAScopedDevice scoped_device(points.GetDevice()); - const cudaStream_t stream = 0; + cudaStream_t stream = CUDAStream::GetInstance().Get(); Device device = points.GetDevice(); diff --git a/cpp/open3d/core/nns/KnnSearchOps.cu b/cpp/open3d/core/nns/KnnSearchOps.cu index 8fbaf466b5c..371a47203bf 100644 --- a/cpp/open3d/core/nns/KnnSearchOps.cu +++ b/cpp/open3d/core/nns/KnnSearchOps.cu @@ -34,7 +34,7 @@ void KnnSearchCUDABruteForce(const Tensor& points, OUTPUT_ALLOCATOR& output_allocator, Tensor& query_neighbors_row_splits) { CUDAScopedDevice scoped_device(points.GetDevice()); - const cudaStream_t stream = cuda::GetStream(); + const cudaStream_t stream = CUDAStream::GetInstance().Get(); int num_points = points.GetShape(0); int num_queries = queries.GetShape(0); @@ -100,6 +100,7 @@ void KnnSearchCUDAOptimized(const Tensor& points, OUTPUT_ALLOCATOR& output_allocator, Tensor& query_neighbors_row_splits) { CUDAScopedDevice scoped_device(points.GetDevice()); + cudaStream_t stream = CUDAStream::GetInstance().Get(); int num_points = points.GetShape(0); int num_queries = queries.GetShape(0); int dim = points.GetShape(1); @@ -157,8 +158,6 @@ void KnnSearchCUDAOptimized(const Tensor& points, buf_distances.Slice(0, 0, num_queries_i); Tensor buf_indices_row_view = buf_indices.Slice(0, 0, num_queries_i); { - CUDAScopedStream scoped_stream(CUDAScopedStream::CreateNewStream); - cudaStream_t cur_stream = cuda::GetStream(); for (int j = 0; j < num_points; j += tile_cols) { int num_points_j = std::min(tile_cols, num_points - j); int col_j = j / tile_cols; @@ -185,7 +184,7 @@ void KnnSearchCUDAOptimized(const Tensor& points, output_allocator.NeighborsDistance_() .View({num_queries, knn}) .Slice(0, i, i + num_queries_i); - runL2SelectMin(cur_stream, temp_distances_view, + runL2SelectMin(stream, temp_distances_view, point_norms_j, out_distances_view, out_indices_view, knn, num_cols, tile_cols); @@ -193,7 +192,7 @@ void KnnSearchCUDAOptimized(const Tensor& points, query_norms_i.View({num_queries_i, 1})); } else { runL2SelectMin( - cur_stream, temp_distances_view, point_norms_j, + stream, temp_distances_view, point_norms_j, buf_distances_col_view, buf_indices_col_view, knn, num_cols, tile_cols); buf_distances_col_view.Add_( @@ -202,10 +201,10 @@ void KnnSearchCUDAOptimized(const Tensor& points, } // Write results to output tensor. if (tile_cols != num_points) { - runIncrementIndex(cur_stream, buf_indices_row_view, knn, + runIncrementIndex(stream, buf_indices_row_view, knn, tile_cols); runBlockSelectPair( - cur_stream, buf_distances_row_view.GetDataPtr(), + stream, buf_distances_row_view.GetDataPtr(), buf_indices_row_view.GetDataPtr(), distances_ptr + knn * i, indices_ptr + knn * i, false, knn, buf_distances_row_view.GetShape(1), diff --git a/cpp/open3d/core/nns/NanoFlannIndex.cpp b/cpp/open3d/core/nns/NanoFlannIndex.cpp index 353ec89a0a9..b4d39299a7f 100644 --- a/cpp/open3d/core/nns/NanoFlannIndex.cpp +++ b/cpp/open3d/core/nns/NanoFlannIndex.cpp @@ -19,7 +19,7 @@ namespace open3d { namespace core { namespace nns { -NanoFlannIndex::NanoFlannIndex() {}; +NanoFlannIndex::NanoFlannIndex(){}; NanoFlannIndex::NanoFlannIndex(const Tensor &dataset_points) { SetTensorData(dataset_points); @@ -30,7 +30,7 @@ NanoFlannIndex::NanoFlannIndex(const Tensor &dataset_points, SetTensorData(dataset_points, index_dtype); }; -NanoFlannIndex::~NanoFlannIndex() {}; +NanoFlannIndex::~NanoFlannIndex(){}; bool NanoFlannIndex::SetTensorData(const Tensor &dataset_points, const Dtype &index_dtype) { diff --git a/cpp/open3d/core/nns/NearestNeighborSearch.cpp b/cpp/open3d/core/nns/NearestNeighborSearch.cpp index bc4ac751538..ebf38ec05b3 100644 --- a/cpp/open3d/core/nns/NearestNeighborSearch.cpp +++ b/cpp/open3d/core/nns/NearestNeighborSearch.cpp @@ -13,7 +13,7 @@ namespace open3d { namespace core { namespace nns { -NearestNeighborSearch::~NearestNeighborSearch() {}; +NearestNeighborSearch::~NearestNeighborSearch(){}; bool NearestNeighborSearch::SetIndex() { nanoflann_index_.reset(new NanoFlannIndex()); diff --git a/cpp/open3d/io/PointCloudIO.h b/cpp/open3d/io/PointCloudIO.h index 0747efdd81f..a2b27971f22 100644 --- a/cpp/open3d/io/PointCloudIO.h +++ b/cpp/open3d/io/PointCloudIO.h @@ -44,7 +44,7 @@ struct ReadPointCloudOption { remove_nan_points(remove_nan_points), remove_infinite_points(remove_infinite_points), print_progress(print_progress), - update_progress(update_progress) {}; + update_progress(update_progress){}; ReadPointCloudOption(std::function up) : ReadPointCloudOption() { update_progress = up; @@ -101,7 +101,7 @@ struct WritePointCloudOption { write_ascii(write_ascii), compressed(compressed), print_progress(print_progress), - update_progress(update_progress) {}; + update_progress(update_progress){}; // for compatibility WritePointCloudOption(bool write_ascii, bool compressed = false, @@ -110,7 +110,7 @@ struct WritePointCloudOption { : write_ascii(IsAscii(write_ascii)), compressed(Compressed(compressed)), print_progress(print_progress), - update_progress(update_progress) {}; + update_progress(update_progress){}; // for compatibility WritePointCloudOption(std::string format, bool write_ascii, @@ -121,7 +121,7 @@ struct WritePointCloudOption { write_ascii(IsAscii(write_ascii)), compressed(Compressed(compressed)), print_progress(print_progress), - update_progress(update_progress) {}; + update_progress(update_progress){}; WritePointCloudOption(std::function up) : WritePointCloudOption() { update_progress = up; diff --git a/cpp/open3d/io/rpc/ConnectionBase.h b/cpp/open3d/io/rpc/ConnectionBase.h index f041d10f5e4..94d0d67b0b0 100644 --- a/cpp/open3d/io/rpc/ConnectionBase.h +++ b/cpp/open3d/io/rpc/ConnectionBase.h @@ -21,8 +21,8 @@ namespace rpc { /// Base class for all connections class ConnectionBase { public: - ConnectionBase() {}; - virtual ~ConnectionBase() {}; + ConnectionBase(){}; + virtual ~ConnectionBase(){}; /// Function for sending data wrapped in a zmq message object. virtual std::shared_ptr Send(zmq::message_t& send_msg) = 0; diff --git a/cpp/open3d/io/sensor/RGBDSensor.h b/cpp/open3d/io/sensor/RGBDSensor.h index 001d5773712..0119916828b 100644 --- a/cpp/open3d/io/sensor/RGBDSensor.h +++ b/cpp/open3d/io/sensor/RGBDSensor.h @@ -23,7 +23,7 @@ class RGBDSensor { public: RGBDSensor() {} virtual bool Connect(size_t sensor_index) = 0; - virtual ~RGBDSensor() {}; + virtual ~RGBDSensor(){}; /// Capture one frame, return an RGBDImage. /// If \p enable_align_depth_to_color is true, the depth image will be diff --git a/cpp/open3d/ml/contrib/RoiPoolKernel.cu b/cpp/open3d/ml/contrib/RoiPoolKernel.cu index 208772851c9..c07d980b032 100644 --- a/cpp/open3d/ml/contrib/RoiPoolKernel.cu +++ b/cpp/open3d/ml/contrib/RoiPoolKernel.cu @@ -325,8 +325,9 @@ void roipool3dLauncher(int batch_size, cudaFree(pts_assign); cudaFree(pts_idx); -#ifdef DEBUG - core::cuda::Synchronize(); // for using printf in kernel function +#if defined(DEBUG) && BUILD_CUDA_MODULE + core::cuda::Synchronize( + CUDAStream::GetInstance()); // for using printf in kernel function #endif } diff --git a/cpp/open3d/ml/impl/misc/InvertNeighborsList.cuh b/cpp/open3d/ml/impl/misc/InvertNeighborsList.cuh index 20123b4375f..86cb6002962 100644 --- a/cpp/open3d/ml/impl/misc/InvertNeighborsList.cuh +++ b/cpp/open3d/ml/impl/misc/InvertNeighborsList.cuh @@ -274,6 +274,12 @@ void InvertNeighborsListCUDA(const cudaStream_t& stream, tmp_neighbors_count.first, out_neighbors_row_splits + 1, out_num_queries, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second); if (!get_temp_size) { diff --git a/cpp/open3d/ml/impl/misc/Voxelize.cuh b/cpp/open3d/ml/impl/misc/Voxelize.cuh index 85158f9afc5..cf9ffe1cfe1 100644 --- a/cpp/open3d/ml/impl/misc/Voxelize.cuh +++ b/cpp/open3d/ml/impl/misc/Voxelize.cuh @@ -656,6 +656,11 @@ void VoxelizeCUDA(const cudaStream_t& stream, cub::DeviceRadixSort::SortPairs( sort_pairs_temp.first, sort_pairs_temp.second, hashes_dbuf, point_indices_dbuf, num_points, 0, sizeof(int64_t) * 8, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } sort_pairs_temp = mem_temp.Alloc(sort_pairs_temp.second); if (!get_temp_size) { cub::DeviceRadixSort::SortPairs(sort_pairs_temp.first, @@ -684,6 +689,12 @@ void VoxelizeCUDA(const cudaStream_t& stream, unique_hashes.first, unique_hashes_count.first, num_voxels_mem.first, num_points, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + encode_temp = mem_temp.Alloc(encode_temp.second); if (!get_temp_size) { cub::DeviceRunLengthEncode::Encode( @@ -700,8 +711,7 @@ void VoxelizeCUDA(const cudaStream_t& stream, hashes_dbuf.Current() + hashes.second - 1, sizeof(int64_t), cudaMemcpyDeviceToHost, stream); // wait for the async copies - while (cudaErrorNotReady == cudaStreamQuery(stream)) { /*empty*/ - } + cudaStreamSynchronize(stream); } mem_temp.Free(encode_temp); } @@ -724,6 +734,12 @@ void VoxelizeCUDA(const cudaStream_t& stream, unique_hashes_count.first, unique_hashes_count_prefix_sum.first, unique_hashes_count.second, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second); if (!get_temp_size) { // We only need the prefix sum for the first num_voxels. @@ -766,6 +782,11 @@ void VoxelizeCUDA(const cudaStream_t& stream, encode_temp.first, encode_temp.second, unique_hashes_batch_id, unique_batches.first, unique_batches_count.first, num_batches_mem.first, num_voxels, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } encode_temp = mem_temp.Alloc(encode_temp.second); if (!get_temp_size) { cub::DeviceRunLengthEncode::Encode( @@ -788,7 +809,8 @@ void VoxelizeCUDA(const cudaStream_t& stream, std::pair num_voxels_per_batch = mem_temp.Alloc(batch_size); if (!get_temp_size) { - cudaMemset(num_voxels_per_batch.first, 0, batch_size * sizeof(int64_t)); + cudaMemsetAsync(num_voxels_per_batch.first, 0, + batch_size * sizeof(int64_t), stream); ComputeVoxelPerBatch(stream, num_voxels_per_batch.first, unique_batches_count.first, unique_batches.first, num_batches); @@ -808,6 +830,12 @@ void VoxelizeCUDA(const cudaStream_t& stream, num_voxels_per_batch.first, num_voxels_prefix_sum.first, num_voxels_per_batch.second, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second); if (!get_temp_size) { if (num_voxels > max_voxels) { @@ -844,6 +872,12 @@ void VoxelizeCUDA(const cudaStream_t& stream, num_voxels_per_batch.first, out_batch_splits + 1, num_voxels_per_batch.second, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second); if (!get_temp_size) { @@ -912,6 +946,12 @@ void VoxelizeCUDA(const cudaStream_t& stream, points_count.first, out_voxel_row_splits + 1, num_valid_voxels, stream); + // MUST synchronize non-default streams because InclusiveSum writes to + // the value we will be using + if (stream != nullptr) { + cudaStreamSynchronize(stream); + } + inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second); if (!get_temp_size) { cub::DeviceScan::InclusiveSum( diff --git a/cpp/open3d/pipelines/registration/ColoredICP.h b/cpp/open3d/pipelines/registration/ColoredICP.h index e8abdfbe154..b401cdd02d9 100644 --- a/cpp/open3d/pipelines/registration/ColoredICP.h +++ b/cpp/open3d/pipelines/registration/ColoredICP.h @@ -27,7 +27,7 @@ class RegistrationResult; class TransformationEstimationForColoredICP : public TransformationEstimation { public: - ~TransformationEstimationForColoredICP() override {}; + ~TransformationEstimationForColoredICP() override{}; TransformationEstimationType GetTransformationEstimationType() const override { diff --git a/cpp/open3d/t/geometry/RGBDImage.h b/cpp/open3d/t/geometry/RGBDImage.h index ec4f9d9790d..d0db11824ee 100644 --- a/cpp/open3d/t/geometry/RGBDImage.h +++ b/cpp/open3d/t/geometry/RGBDImage.h @@ -53,7 +53,7 @@ class RGBDImage : public Geometry { return color_device; } - ~RGBDImage() override {}; + ~RGBDImage() override{}; /// Clear stored data. RGBDImage &Clear() override; diff --git a/cpp/open3d/t/geometry/kernel/NPPImage.cpp b/cpp/open3d/t/geometry/kernel/NPPImage.cpp index eb47ec9ba30..c766aa4ebe6 100644 --- a/cpp/open3d/t/geometry/kernel/NPPImage.cpp +++ b/cpp/open3d/t/geometry/kernel/NPPImage.cpp @@ -23,7 +23,7 @@ namespace npp { static NppStreamContext MakeNPPContext() { NppStreamContext context; - context.hStream = core::cuda::GetStream(); + context.hStream = core::CUDAStream::GetInstance().Get(); context.nCudaDeviceId = core::cuda::GetDevice(); cudaDeviceProp device_prop; @@ -52,8 +52,7 @@ static NppStreamContext MakeNPPContext() { // to expose this member variable. #if NPP_VERSION >= 11100 unsigned int stream_flags; - OPEN3D_CUDA_CHECK( - cudaStreamGetFlags(core::cuda::GetStream(), &stream_flags)); + OPEN3D_CUDA_CHECK(cudaStreamGetFlags(context.hStream, &stream_flags)); context.nStreamFlags = stream_flags; #endif diff --git a/cpp/open3d/t/geometry/kernel/PointCloudImpl.h b/cpp/open3d/t/geometry/kernel/PointCloudImpl.h index a6efbc0f9a5..6f33ec080d2 100644 --- a/cpp/open3d/t/geometry/kernel/PointCloudImpl.h +++ b/cpp/open3d/t/geometry/kernel/PointCloudImpl.h @@ -125,8 +125,8 @@ void UnprojectCPU int total_pts_count = (*count_ptr).load(); #endif -#ifdef __CUDACC__ - core::cuda::Synchronize(); +#if BUILD_CUDA_MODULE + core::cuda::Synchronize(core::CUDAStream::GetInstance()); #endif points = points.Slice(0, 0, total_pts_count); if (have_colors) { @@ -601,7 +601,9 @@ void EstimateCovariancesUsingHybridSearchCPU }); }); - core::cuda::Synchronize(points.GetDevice()); +#if BUILD_CUDA_MODULE + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } #if defined(__CUDACC__) @@ -650,7 +652,9 @@ void EstimateCovariancesUsingRadiusSearchCPU }); }); - core::cuda::Synchronize(points.GetDevice()); +#if BUILD_CUDA_MODULE + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } #if defined(__CUDACC__) @@ -703,7 +707,9 @@ void EstimateCovariancesUsingKNNSearchCPU }); }); - core::cuda::Synchronize(points.GetDevice()); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } template @@ -1022,7 +1028,9 @@ void EstimateNormalsFromCovariancesCPU }); }); - core::cuda::Synchronize(covariances.GetDevice()); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } template @@ -1174,7 +1182,9 @@ void EstimateColorGradientsUsingHybridSearchCPU }); }); - core::cuda::Synchronize(points.GetDevice()); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } #if defined(__CUDACC__) @@ -1229,7 +1239,9 @@ void EstimateColorGradientsUsingKNNSearchCPU }); }); - core::cuda::Synchronize(points.GetDevice()); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } #if defined(__CUDACC__) @@ -1284,7 +1296,9 @@ void EstimateColorGradientsUsingRadiusSearchCPU }); }); - core::cuda::Synchronize(points.GetDevice()); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); +#endif } } // namespace pointcloud diff --git a/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h b/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h index 01bd9060b9e..fbe0b5c4f0d 100644 --- a/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h +++ b/cpp/open3d/t/geometry/kernel/VoxelBlockGridImpl.h @@ -288,8 +288,8 @@ void IntegrateCPU *weight_ptr = weight + 1; }); -#if defined(__CUDACC__) - core::cuda::Synchronize(); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); #endif } @@ -498,8 +498,8 @@ void EstimateRangeCPU #endif }); -#if defined(__CUDACC__) - core::cuda::Synchronize(); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); #endif if (needed_frag_count != frag_count) { @@ -1026,8 +1026,8 @@ void RayCastCPU } // surface-found }); -#if defined(__CUDACC__) - core::cuda::Synchronize(); +#if defined(BUILD_CUDA_MODULE) + core::cuda::Synchronize(core::CUDAStream::GetInstance()); #endif } @@ -1284,7 +1284,7 @@ void ExtractPointCloudCPU valid_size = total_count; #if defined(BUILD_CUDA_MODULE) && defined(__CUDACC__) - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); #endif } diff --git a/cpp/open3d/t/pipelines/kernel/RGBDOdometryCUDA.cu b/cpp/open3d/t/pipelines/kernel/RGBDOdometryCUDA.cu index 8709196ee4a..86d81722977 100644 --- a/cpp/open3d/t/pipelines/kernel/RGBDOdometryCUDA.cu +++ b/cpp/open3d/t/pipelines/kernel/RGBDOdometryCUDA.cu @@ -116,12 +116,12 @@ void ComputeOdometryResultPointToPlaneCUDA( const dim3 blocks((rows * cols + kBlockSize - 1) / kBlockSize); const dim3 threads(kBlockSize); - ComputeOdometryResultPointToPlaneCUDAKernel<<>>( + ComputeOdometryResultPointToPlaneCUDAKernel<<< + blocks, threads, 0, core::CUDAStream::GetInstance().Get()>>>( source_vertex_indexer, target_vertex_indexer, target_normal_indexer, ti, global_sum_ptr, rows, cols, depth_outlier_trunc, depth_huber_delta); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); DecodeAndSolve6x6(global_sum, delta, inlier_residual, inlier_count); } @@ -223,14 +223,14 @@ void ComputeOdometryResultIntensityCUDA( const dim3 blocks((cols * rows + kBlockSize - 1) / kBlockSize); const dim3 threads(kBlockSize); - ComputeOdometryResultIntensityCUDAKernel<<>>( + ComputeOdometryResultIntensityCUDAKernel<<< + blocks, threads, 0, core::CUDAStream::GetInstance().Get()>>>( source_depth_indexer, target_depth_indexer, source_intensity_indexer, target_intensity_indexer, target_intensity_dx_indexer, target_intensity_dy_indexer, source_vertex_indexer, ti, global_sum_ptr, rows, cols, depth_outlier_trunc, intensity_huber_delta); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); DecodeAndSolve6x6(global_sum, delta, inlier_residual, inlier_count); } @@ -344,15 +344,15 @@ void ComputeOdometryResultHybridCUDA(const core::Tensor& source_depth, const dim3 blocks((cols * rows + kBlockSize - 1) / kBlockSize); const dim3 threads(kBlockSize); - ComputeOdometryResultHybridCUDAKernel<<>>( + ComputeOdometryResultHybridCUDAKernel<<< + blocks, threads, 0, core::CUDAStream::GetInstance().Get()>>>( source_depth_indexer, target_depth_indexer, source_intensity_indexer, target_intensity_indexer, target_depth_dx_indexer, target_depth_dy_indexer, target_intensity_dx_indexer, target_intensity_dy_indexer, source_vertex_indexer, ti, global_sum_ptr, rows, cols, depth_outlier_trunc, depth_huber_delta, intensity_huber_delta); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); DecodeAndSolve6x6(global_sum, delta, inlier_residual, inlier_count); } @@ -422,11 +422,11 @@ void ComputeOdometryInformationMatrixCUDA(const core::Tensor& source_vertex_map, const dim3 blocks((cols * rows + kBlockSize - 1) / kBlockSize); const dim3 threads(kBlockSize); - ComputeOdometryInformationMatrixCUDAKernel<<>>( + ComputeOdometryInformationMatrixCUDAKernel<<< + blocks, threads, 0, core::CUDAStream::GetInstance().Get()>>>( source_vertex_indexer, target_vertex_indexer, ti, global_sum_ptr, rows, cols, square_dist_thr); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); // 21 => 6x6 const core::Device host(core::Device("CPU:0")); diff --git a/cpp/open3d/t/pipelines/kernel/RegistrationCUDA.cu b/cpp/open3d/t/pipelines/kernel/RegistrationCUDA.cu index 1fe76815a2f..d9522bc9b87 100644 --- a/cpp/open3d/t/pipelines/kernel/RegistrationCUDA.cu +++ b/cpp/open3d/t/pipelines/kernel/RegistrationCUDA.cu @@ -102,7 +102,8 @@ void ComputePosePointToPlaneCUDA(const core::Tensor &source_points, kernel.type_, scalar_t, kernel.scaling_parameter_, kernel.shape_parameter_, [&]() { ComputePosePointToPlaneKernelCUDA<<< - blocks, threads, 0, core::cuda::GetStream()>>>( + blocks, threads, 0, + core::CUDAStream::GetInstance().Get()>>>( source_points.GetDataPtr(), target_points.GetDataPtr(), target_normals.GetDataPtr(), @@ -111,7 +112,7 @@ void ComputePosePointToPlaneCUDA(const core::Tensor &source_points, }); }); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); DecodeAndSolve6x6(global_sum, pose, residual, inlier_count); } @@ -210,7 +211,8 @@ void ComputePoseColoredICPCUDA(const core::Tensor &source_points, kernel.type_, scalar_t, kernel.scaling_parameter_, kernel.shape_parameter_, [&]() { ComputePoseColoredICPKernelCUDA<<< - blocks, threads, 0, core::cuda::GetStream()>>>( + blocks, threads, 0, + core::CUDAStream::GetInstance().Get()>>>( source_points.GetDataPtr(), source_colors.GetDataPtr(), target_points.GetDataPtr(), @@ -224,7 +226,7 @@ void ComputePoseColoredICPCUDA(const core::Tensor &source_points, }); }); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); DecodeAndSolve6x6(global_sum, pose, residual, inlier_count); } @@ -349,7 +351,7 @@ void ComputePoseDopplerICPCUDA( sqrt_lambda_doppler / static_cast(period); PreComputeForDopplerICPKernelCUDA - <<<1, 1, 0, core::cuda::GetStream()>>>( + <<<1, 1, 0, core::CUDAStream::GetInstance().Get()>>>( R_S_to_V.GetDataPtr(), r_v_to_s_in_V.GetDataPtr(), w_v_in_V.GetDataPtr(), @@ -361,7 +363,8 @@ void ComputePoseDopplerICPCUDA( kernel_geometric.scaling_parameter_, kernel_doppler.type_, kernel_doppler.scaling_parameter_, [&]() { ComputePoseDopplerICPKernelCUDA<<< - blocks, threads, 0, core::cuda::GetStream()>>>( + blocks, threads, 0, + core::CUDAStream::GetInstance().Get()>>>( source_points.GetDataPtr(), source_dopplers.GetDataPtr(), source_directions.GetDataPtr(), @@ -382,7 +385,7 @@ void ComputePoseDopplerICPCUDA( }); }); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); DecodeAndSolve6x6(global_sum, output_pose, residual, inlier_count); } @@ -446,13 +449,13 @@ void ComputeInformationMatrixCUDA(const core::Tensor &target_points, DISPATCH_FLOAT_DTYPE_TO_TEMPLATE(dtype, [&]() { scalar_t *global_sum_ptr = global_sum.GetDataPtr(); - ComputeInformationMatrixKernelCUDA<<>>( + ComputeInformationMatrixKernelCUDA<<< + blocks, threads, 0, core::CUDAStream::GetInstance().Get()>>>( target_points.GetDataPtr(), correspondence_indices.GetDataPtr(), n, global_sum_ptr); - core::cuda::Synchronize(); + core::cuda::Synchronize(core::CUDAStream::GetInstance()); core::Tensor global_sum_cpu = global_sum.To(core::Device("CPU:0"), core::Float64); diff --git a/cpp/open3d/t/pipelines/kernel/TransformationConverter.cu b/cpp/open3d/t/pipelines/kernel/TransformationConverter.cu index 55fe8bcb14c..1c582f1dd7a 100644 --- a/cpp/open3d/t/pipelines/kernel/TransformationConverter.cu +++ b/cpp/open3d/t/pipelines/kernel/TransformationConverter.cu @@ -31,14 +31,16 @@ template <> void PoseToTransformationCUDA(float *transformation_ptr, const float *X_ptr) { PoseToTransformationKernel - <<<1, 1, 0, core::cuda::GetStream()>>>(transformation_ptr, X_ptr); + <<<1, 1, 0, core::CUDAStream::GetInstance().Get()>>>( + transformation_ptr, X_ptr); } template <> void PoseToTransformationCUDA(double *transformation_ptr, const double *X_ptr) { PoseToTransformationKernel - <<<1, 1, 0, core::cuda::GetStream()>>>(transformation_ptr, X_ptr); + <<<1, 1, 0, core::CUDAStream::GetInstance().Get()>>>( + transformation_ptr, X_ptr); } template @@ -57,14 +59,16 @@ template <> void TransformationToPoseCUDA(float *X_ptr, const float *transformation_ptr) { TransformationToPoseKernel - <<<1, 1, 0, core::cuda::GetStream()>>>(X_ptr, transformation_ptr); + <<<1, 1, 0, core::CUDAStream::GetInstance().Get()>>>( + X_ptr, transformation_ptr); } template <> void TransformationToPoseCUDA(double *X_ptr, const double *transformation_ptr) { TransformationToPoseKernel - <<<1, 1, 0, core::cuda::GetStream()>>>(X_ptr, transformation_ptr); + <<<1, 1, 0, core::CUDAStream::GetInstance().Get()>>>( + X_ptr, transformation_ptr); } } // namespace kernel diff --git a/cpp/open3d/t/pipelines/registration/TransformationEstimation.h b/cpp/open3d/t/pipelines/registration/TransformationEstimation.h index 73227014b22..ab72cc33f10 100644 --- a/cpp/open3d/t/pipelines/registration/TransformationEstimation.h +++ b/cpp/open3d/t/pipelines/registration/TransformationEstimation.h @@ -225,7 +225,7 @@ class TransformationEstimationPointToPlane : public TransformationEstimation { /// Float64, on CPU device for colored-icp method. class TransformationEstimationForColoredICP : public TransformationEstimation { public: - ~TransformationEstimationForColoredICP() override {}; + ~TransformationEstimationForColoredICP() override{}; /// \brief Constructor. /// @@ -307,7 +307,7 @@ class TransformationEstimationForColoredICP : public TransformationEstimation { /// Float64, on CPU device for DopplerICP method. class TransformationEstimationForDopplerICP : public TransformationEstimation { public: - ~TransformationEstimationForDopplerICP() override {}; + ~TransformationEstimationForDopplerICP() override{}; /// \brief Constructor. /// diff --git a/cpp/open3d/utility/Optional.h b/cpp/open3d/utility/Optional.h index 1ac46eb764b..26a6a1e4fb2 100644 --- a/cpp/open3d/utility/Optional.h +++ b/cpp/open3d/utility/Optional.h @@ -165,7 +165,7 @@ union storage_t { unsigned char dummy_; T value_; - constexpr storage_t(trivial_init_t) noexcept : dummy_() {}; + constexpr storage_t(trivial_init_t) noexcept : dummy_(){}; template constexpr storage_t(Args&&... args) @@ -179,7 +179,7 @@ union constexpr_storage_t { unsigned char dummy_; T value_; - constexpr constexpr_storage_t(trivial_init_t) noexcept : dummy_() {}; + constexpr constexpr_storage_t(trivial_init_t) noexcept : dummy_(){}; template constexpr constexpr_storage_t(Args&&... args) @@ -193,8 +193,7 @@ struct optional_base { bool init_; storage_t storage_; - constexpr optional_base() noexcept - : init_(false), storage_(trivial_init) {}; + constexpr optional_base() noexcept : init_(false), storage_(trivial_init){}; explicit constexpr optional_base(const T& v) : init_(true), storage_(v) {} @@ -225,7 +224,7 @@ struct constexpr_optional_base { constexpr_storage_t storage_; constexpr constexpr_optional_base() noexcept - : init_(false), storage_(trivial_init) {}; + : init_(false), storage_(trivial_init){}; explicit constexpr constexpr_optional_base(const T& v) : init_(true), storage_(v) {} @@ -315,8 +314,8 @@ class optional : private OptionalBase { typedef T value_type; // 20.5.5.1, constructors - constexpr optional() noexcept : OptionalBase() {}; - constexpr optional(nullopt_t) noexcept : OptionalBase() {}; + constexpr optional() noexcept : OptionalBase(){}; + constexpr optional(nullopt_t) noexcept : OptionalBase(){}; optional(const optional& rhs) : OptionalBase() { if (rhs.initialized()) { diff --git a/cpp/open3d/visualization/gui/Gui.h b/cpp/open3d/visualization/gui/Gui.h index 5eb53a15cbf..3ab40897a88 100644 --- a/cpp/open3d/visualization/gui/Gui.h +++ b/cpp/open3d/visualization/gui/Gui.h @@ -76,7 +76,7 @@ enum class FontStyle { class FontContext { public: - virtual ~FontContext() {}; + virtual ~FontContext(){}; virtual void* GetFont(FontId font_id) = 0; }; diff --git a/cpp/open3d/visualization/gui/ImguiFilamentBridge.cpp b/cpp/open3d/visualization/gui/ImguiFilamentBridge.cpp index 73d773b2cb7..21a81a77b4a 100644 --- a/cpp/open3d/visualization/gui/ImguiFilamentBridge.cpp +++ b/cpp/open3d/visualization/gui/ImguiFilamentBridge.cpp @@ -87,7 +87,7 @@ static Material* LoadMaterialTemplate(const std::string& path, Engine& engine) { class MaterialPool { public: - MaterialPool() {}; + MaterialPool(){}; MaterialPool(filament::Engine* engine, filament::Material* material_template) { diff --git a/cpp/open3d/visualization/gui/PickPointsInteractor.cpp b/cpp/open3d/visualization/gui/PickPointsInteractor.cpp index 0633e28ac67..0871ed11a5c 100644 --- a/cpp/open3d/visualization/gui/PickPointsInteractor.cpp +++ b/cpp/open3d/visualization/gui/PickPointsInteractor.cpp @@ -74,8 +74,7 @@ class SelectionIndexLookup { std::string name; size_t start_index; - Obj(const std::string &n, size_t start) - : name(n), start_index(start) {}; + Obj(const std::string &n, size_t start) : name(n), start_index(start){}; bool IsValid() const { return !name.empty(); } }; diff --git a/cpp/open3d/visualization/gui/WindowSystem.h b/cpp/open3d/visualization/gui/WindowSystem.h index 8dd5a849086..f77ce15886f 100644 --- a/cpp/open3d/visualization/gui/WindowSystem.h +++ b/cpp/open3d/visualization/gui/WindowSystem.h @@ -28,7 +28,7 @@ class WindowSystem { public: using OSWindow = void*; - virtual ~WindowSystem() {}; + virtual ~WindowSystem(){}; virtual void Initialize() = 0; virtual void Uninitialize() = 0; diff --git a/cpp/open3d/visualization/webrtc_server/PeerConnectionManager.h b/cpp/open3d/visualization/webrtc_server/PeerConnectionManager.h index 0120e20b4fe..e57f7891ce1 100644 --- a/cpp/open3d/visualization/webrtc_server/PeerConnectionManager.h +++ b/cpp/open3d/visualization/webrtc_server/PeerConnectionManager.h @@ -120,7 +120,7 @@ class PeerConnectionManager { webrtc::PeerConnectionInterface* pc, std::promise& promise) - : pc_(pc), promise_(promise) {}; + : pc_(pc), promise_(promise){}; private: webrtc::PeerConnectionInterface* pc_; @@ -153,7 +153,7 @@ class PeerConnectionManager { webrtc::PeerConnectionInterface* pc, std::promise& promise) - : pc_(pc), promise_(promise) {}; + : pc_(pc), promise_(promise){}; private: webrtc::PeerConnectionInterface* pc_; diff --git a/cpp/pybind/core/cuda_utils.cpp b/cpp/pybind/core/cuda_utils.cpp index cb82bb030c4..4f0aba9dbc2 100644 --- a/cpp/pybind/core/cuda_utils.cpp +++ b/cpp/pybind/core/cuda_utils.cpp @@ -35,11 +35,13 @@ void pybind_cuda_utils_definitions(py::module& m) { cuda::Synchronize(); } }, - "Synchronizes CUDA devices. If no device is specified, all CUDA " - "devices will be synchronized. No effect if the specified device " - "is not a CUDA device. No effect if Open3D is not compiled with " - "CUDA support.", - "device"_a = py::none()); + "Synchronizes a CUDA stream."); +#if BUILD_CUDA_MODULE + m_cuda.def( + "synchronize_stream", + [](const CUDAStream& stream) { cuda::Synchronize(stream); }, + "Synchronizes a CUDA stream."); +#endif } } // namespace core diff --git a/cpp/tests/core/CUDAUtils.cpp b/cpp/tests/core/CUDAUtils.cpp index 3a0634fe3d7..3532df648ab 100644 --- a/cpp/tests/core/CUDAUtils.cpp +++ b/cpp/tests/core/CUDAUtils.cpp @@ -44,41 +44,44 @@ TEST(CUDAUtils, InitState) { void CheckScopedStreamManually() { int current_device = core::cuda::GetDevice(); - ASSERT_EQ(core::cuda::GetStream(), core::cuda::GetDefaultStream()); + ASSERT_EQ(core::CUDAStream::GetInstance().Get(), + core::CUDAStream::Default().Get()); ASSERT_EQ(core::cuda::GetDevice(), current_device); - cudaStream_t stream; - OPEN3D_CUDA_CHECK(cudaStreamCreate(&stream)); + core::CUDAStream stream = core::CUDAStream::CreateNew(); { core::CUDAScopedStream scoped_stream(stream); - ASSERT_EQ(core::cuda::GetStream(), stream); - ASSERT_NE(core::cuda::GetStream(), core::cuda::GetDefaultStream()); + ASSERT_EQ(core::CUDAStream::GetInstance().Get(), stream.Get()); + ASSERT_FALSE(core::CUDAStream::GetInstance().IsDefaultStream()); ASSERT_EQ(core::cuda::GetDevice(), current_device); } - OPEN3D_CUDA_CHECK(cudaStreamDestroy(stream)); + stream.Destroy(); - ASSERT_EQ(core::cuda::GetStream(), core::cuda::GetDefaultStream()); + ASSERT_TRUE(core::CUDAStream::GetInstance().IsDefaultStream()); ASSERT_EQ(core::cuda::GetDevice(), current_device); } void CheckScopedStreamAutomatically() { int current_device = core::cuda::GetDevice(); - ASSERT_EQ(core::cuda::GetStream(), core::cuda::GetDefaultStream()); + ASSERT_EQ(core::CUDAStream::GetInstance().Get(), + core::CUDAStream::Default().Get()); ASSERT_EQ(core::cuda::GetDevice(), current_device); { - core::CUDAScopedStream scoped_stream( - core::CUDAScopedStream::CreateNewStream); + core::CUDAScopedStream scoped_stream(core::CUDAStream::CreateNew(), + true); - ASSERT_NE(core::cuda::GetStream(), core::cuda::GetDefaultStream()); + ASSERT_NE(core::CUDAStream::GetInstance().Get(), + core::CUDAStream::Default().Get()); ASSERT_EQ(core::cuda::GetDevice(), current_device); } - ASSERT_EQ(core::cuda::GetStream(), core::cuda::GetDefaultStream()); + ASSERT_EQ(core::CUDAStream::GetInstance().Get(), + core::CUDAStream::Default().Get()); ASSERT_EQ(core::cuda::GetDevice(), current_device); } diff --git a/cpp/tests/core/HashMap.cpp b/cpp/tests/core/HashMap.cpp index 72b2b2b1586..c3077ebef16 100644 --- a/cpp/tests/core/HashMap.cpp +++ b/cpp/tests/core/HashMap.cpp @@ -367,8 +367,8 @@ TEST_P(HashMapPermuteDevices, Clear) { class int3 { public: - int3() : x_(0), y_(0), z_(0) {}; - int3(int k) : x_(k), y_(k * 2), z_(k * 4) {}; + int3() : x_(0), y_(0), z_(0){}; + int3(int k) : x_(k), y_(k * 2), z_(k * 4){}; bool operator==(const int3 &other) const { return x_ == other.x_ && y_ == other.y_ && z_ == other.z_; } diff --git a/cpp/tests/core/ParallelFor.cu b/cpp/tests/core/ParallelFor.cu index b83517b8f45..d084e7f2d1d 100644 --- a/cpp/tests/core/ParallelFor.cu +++ b/cpp/tests/core/ParallelFor.cu @@ -45,5 +45,21 @@ TEST(ParallelFor, LambdaCUDA) { } } +TEST(ParallelFor, LambdaCUDA_NonDefaultStream) { + const core::Device device("CUDA:0"); + const size_t N = 10000000; + + core::CUDAScopedStream s(core::CUDAStream::CreateNew(), true); + + core::Tensor tensor({N, 1}, core::Int64, device); + + RunParallelForOn(tensor); + + core::Tensor tensor_cpu = tensor.To(core::Device("CPU:0")); + for (int64_t i = 0; i < tensor.NumElements(); ++i) { + ASSERT_EQ(tensor_cpu.GetDataPtr()[i], i); + } +} + } // namespace tests } // namespace open3d