Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/mscclpp/port_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#ifndef MSCCLPP_PORT_CHANNEL_HPP_
#define MSCCLPP_PORT_CHANNEL_HPP_

#include <set>

#include "core.hpp"
#include "port_channel_device.hpp"
#include "proxy.hpp"
Expand Down Expand Up @@ -45,6 +47,10 @@ class ProxyService : public BaseProxyService {
/// @return The ID of the memory region.
MemoryId addMemory(RegisteredMemory memory);

/// Unregister a memory region from the proxy service.
/// @param memoryId The ID of the memory region to unregister.
void removeMemory(MemoryId memoryId);

/// Get a semaphore by ID.
/// @param id The ID of the semaphore.
/// @return The semaphore.
Expand Down Expand Up @@ -72,6 +78,7 @@ class ProxyService : public BaseProxyService {
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
std::set<MemoryId> reusableMemoryIds_;
int deviceNumaNode;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;

Expand Down
5 changes: 4 additions & 1 deletion src/ib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
MSCCLPP_CUTHROW(cuCtxGetDevice(&dev));
MSCCLPP_CUTHROW(cuDeviceGetAttribute(&dmaBufSupported, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, dev));
#endif // !defined(__HIP_PLATFORM_AMD__)
if (cuMemAlloc && dmaBufSupported) {
if (cuMemAlloc) {
#if !defined(__HIP_PLATFORM_AMD__)
if (!dmaBufSupported) {
throw mscclpp::Error("Please make sure dma buffer is supported by the device", ErrorCode::InvalidUsage);
}
int fd;
MSCCLPP_CUTHROW(cuMemGetHandleForAddressRange(&fd, addr, pages * pageSize, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));

Expand Down
3 changes: 2 additions & 1 deletion src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ NvlsConnection::Impl::Impl(const std::vector<char>& data) {
}

NvlsConnection::Impl::~Impl() {
// we don't need to free multicast handle object according to NCCL.
// Please ensure that all memory mappings are unmapped from the handle before calling the connection destructor.
cuMemRelease(mcHandle_);
if (rootPid_ == getpid()) {
close(mcFileDesc_);
}
Expand Down
16 changes: 16 additions & 0 deletions src/port_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,26 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2Devi
}

/// Register a memory region with the proxy service.
///
/// If an ID released by removeMemory() is available, the smallest such ID is
/// reused and its slot is overwritten; otherwise the region is appended and
/// receives the index of the new last slot.
///
/// @param memory The memory region to register.
/// @return The ID under which the memory region is stored.
MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
  if (reusableMemoryIds_.empty()) {
    // No freed slot to recycle: grow the table and hand out the new last index.
    memories_.push_back(memory);
    return memories_.size() - 1;
  }

  // reusableMemoryIds_ is an ordered set, so begin() refers to the smallest freed ID.
  auto lowestFree = reusableMemoryIds_.begin();
  MemoryId recycledId = *lowestFree;
  reusableMemoryIds_.erase(lowestFree);
  memories_[recycledId] = memory;
  return recycledId;
}

MSCCLPP_API_CPP void ProxyService::removeMemory(MemoryId memoryId) {
if (reusableMemoryIds_.find(memoryId) != reusableMemoryIds_.end() || memoryId >= memories_.size()) {
WARN("Attempted to remove a memory that is not registered or already removed: %u", memoryId);
return;
}
memories_[memoryId] = RegisteredMemory();
Copy link
Contributor

@chhwang chhwang Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a resource lock so that the removal doesn't happen while there are unflushed triggers still in flight in the proxy. Also, we need a mechanism that prevents proxies on remote ranks from processing requests on this RegisteredMemory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think users are responsible for safely releasing the buffer: before removing it, they must ensure that no peers or device-side operations are still accessing this memory. Adding a spinlock here to prevent race conditions inside the proxy_channel.

reusableMemoryIds_.insert(memoryId);
}

/// Get a semaphore by ID.
/// @param id The ID of the semaphore; used directly as an index, with no range check here.
/// @return A shared pointer to the semaphore stored under @p id.
MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(SemaphoreId id) const {
  const auto& entry = semaphores_[id];
  return entry;
}
Expand Down
13 changes: 9 additions & 4 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,15 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {

RegisteredMemory::Impl::~Impl() {
// Close the CUDA IPC handle if it was opened during deserialization
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash) {
if (getPidHash() == this->pidHash) {
// For local registered memory
if (fileDesc >= 0) {
close(fileDesc);
fileDesc = -1;
}
return;
}
void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase;
if (this->isCuMemMapAlloc) {
CUmemGenericAllocationHandle handle;
Expand All @@ -288,9 +296,6 @@ RegisteredMemory::Impl::~Impl() {
MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR && fileDesc >= 0) {
close(fileDesc);
}
} else {
cudaError_t err = cudaIpcCloseMemHandle(base);
if (err != cudaSuccess) {
Expand Down
12 changes: 12 additions & 0 deletions test/unit/core_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
#include <gtest/gtest.h>

#include <mscclpp/core.hpp>
#include <mscclpp/port_channel.hpp>

// Test fixture that builds a TcpBootstrap, a Communicator on top of it, and a
// standalone ProxyService before every test case.
class LocalCommunicatorTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // NOTE(review): the (0, 1) arguments are presumably rank 0 of a 1-process group — confirm against TcpBootstrap.
    bootstrap = std::make_shared<mscclpp::TcpBootstrap>(0, 1);
    auto uniqueId = bootstrap->createUniqueId();
    bootstrap->initialize(uniqueId);

    // The communicator is created only after the bootstrap has been initialized.
    comm = std::make_shared<mscclpp::Communicator>(bootstrap);

    proxyService = std::make_shared<mscclpp::ProxyService>();
  }

  // Bootstrap created and initialized in SetUp().
  std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
  // Communicator constructed from `bootstrap`.
  std::shared_ptr<mscclpp::Communicator> comm;
  // Proxy service exercised by the add/remove memory test.
  std::shared_ptr<mscclpp::ProxyService> proxyService;
};

TEST_F(LocalCommunicatorTest, RegisterMemory) {
Expand All @@ -36,3 +39,12 @@ TEST_F(LocalCommunicatorTest, SendMemoryToSelf) {
EXPECT_EQ(sameMemory.size(), memory.size());
EXPECT_EQ(sameMemory.transports(), memory.transports());
}

// Checks that an ID released through removeMemory() is handed out again by the
// next addMemory() call.
TEST_F(LocalCommunicatorTest, ProxyServiceAddRemoveMemory) {
  const mscclpp::RegisteredMemory emptyMemory{};

  // First registration on a fresh service gets ID 0.
  const auto firstId = proxyService->addMemory(emptyMemory);
  EXPECT_EQ(firstId, 0);

  // Release the ID, then register again: the same ID must come back.
  proxyService->removeMemory(firstId);
  const auto recycledId = proxyService->addMemory(emptyMemory);
  EXPECT_EQ(recycledId, 0);
}
Loading