Skip to content

Commit 8b8593b

Browse files
authored
Fix Python bindings and tests (#690)
This is a minimal fix to make the Python bindings and tests work again. A more careful follow-up is needed to prevent nanobind from silently falling back when it fails to (properly) construct a C++ STL object that holds mscclpp instances.
1 parent 060c35f commit 8b8593b

File tree

8 files changed: 73 additions and 63 deletions.

python/csrc/core_py.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ void register_core(nb::module_& m) {
216216

217217
def_shared_future<RegisteredMemory>(m, "RegisteredMemory");
218218
def_shared_future<Connection>(m, "Connection");
219+
def_shared_future<Semaphore>(m, "Semaphore");
219220

220221
nb::class_<Communicator>(m, "Communicator")
221222
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
@@ -242,7 +243,7 @@ void register_core(nb::module_& m) {
242243
nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config"))
243244
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag"))
244245
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag"))
245-
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("local_flag"), nb::arg("remote_rank"),
246+
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("connection"), nb::arg("remote_rank"),
246247
nb::arg("tag") = 0)
247248
.def("remote_rank_of", &Communicator::remoteRankOf)
248249
.def("tag_of", &Communicator::tagOf)

python/csrc/memory_channel_py.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,20 @@ void register_memory_channel(nb::module_& m) {
2626

2727
nb::class_<MemoryChannel>(m, "MemoryChannel")
2828
.def(nb::init<>())
29-
.def("__init__",
30-
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
31-
RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
32-
.def("__init__",
33-
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
34-
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
35-
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
36-
})
29+
.def(
30+
"__init__",
31+
[](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
32+
RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
33+
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
34+
},
35+
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
36+
.def(
37+
"__init__",
38+
[](MemoryChannel* memoryChannel, const Semaphore& semaphore, RegisteredMemory dst, RegisteredMemory src,
39+
uintptr_t packet_buffer = 0) {
40+
new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
41+
},
42+
nb::arg("semaphore"), nb::arg("dst"), nb::arg("src"), nb::arg("packet_buffer") = 0)
3743
.def("device_handle", &MemoryChannel::deviceHandle);
3844

3945
nb::class_<MemoryChannel::DeviceHandle>(m, "MemoryChannelDeviceHandle")

python/mscclpp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
connect_nvls_collective,
4848
EndpointConfig,
4949
Fifo,
50+
Semaphore,
5051
Host2DeviceSemaphore,
5152
Host2HostSemaphore,
5253
numa,
@@ -79,6 +80,7 @@
7980
"connect_nvls_collective",
8081
"EndpointConfig",
8182
"Fifo",
83+
"Semaphore",
8284
"Host2DeviceSemaphore",
8385
"Host2HostSemaphore",
8486
"numa",

python/mscclpp/comm.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Connection,
1111
connect_nvls_collective,
1212
EndpointConfig,
13+
Semaphore,
1314
Host2DeviceSemaphore,
1415
Host2HostSemaphore,
1516
ProxyService,
@@ -133,18 +134,14 @@ def _register_memory_with_connections(
133134
all_registered_memories[rank] = future_memories[rank].get()
134135
return all_registered_memories
135136

136-
def make_semaphore(
137-
self,
138-
connections: dict[int, Connection],
139-
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
140-
) -> dict[int, Host2HostSemaphore]:
141-
semaphores = {}
137+
def make_semaphores(self, connections: dict[int, Connection]) -> dict[int, Semaphore]:
138+
future_semaphores = {}
142139
for rank in connections:
143-
semaphores[rank] = semaphore_type(self.communicator, connections[rank])
144-
return semaphores
140+
future_semaphores[rank] = self.communicator.build_semaphore(connections[rank], rank)
141+
return {rank: future.get() for rank, future in future_semaphores.items()}
145142

146143
def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, MemoryChannel]:
147-
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
144+
semaphores = self.make_semaphores(connections)
148145
registered_memories = self.register_tensor_with_connections(tensor, connections)
149146
channels = {}
150147
for rank in connections:
@@ -159,7 +156,7 @@ def make_memory_channels_with_scratch(
159156
registeredScratchBuffer: RegisteredMemory,
160157
connections: dict[int, Connection],
161158
) -> dict[int, MemoryChannel]:
162-
semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
159+
semaphores = self.make_semaphores(connections)
163160
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
164161
channels = {}
165162
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
@@ -177,7 +174,7 @@ def make_memory_channels_with_scratch(
177174
def make_port_channels(
178175
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
179176
) -> dict[int, PortChannel]:
180-
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
177+
semaphores = self.make_semaphores(connections)
181178
registered_memories = self.register_tensor_with_connections(tensor, connections)
182179
memory_ids = {}
183180
semaphore_ids = {}
@@ -210,7 +207,7 @@ def make_port_channels_with_scratch(
210207
)
211208
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
212209

213-
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
210+
semaphores = self.make_semaphores(connections)
214211
registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
215212
memory_ids = {}
216213
semaphore_ids = {}
@@ -229,7 +226,7 @@ def make_port_channels_with_scratch(
229226
def register_semaphore_with_proxy(
230227
self, proxy_service: ProxyService, connections: dict[int, Connection]
231228
) -> dict[int, PortChannel]:
232-
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
229+
semaphores = self.make_semaphores(connections)
233230
semaphore_ids = {}
234231
for rank in semaphores:
235232
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])

python/mscclpp_benchmark/mscclpp_op.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,10 @@ def __init__(
453453
)
454454

455455
# create a memory_channel for each remote neighbor
456-
self.semaphores = group.make_semaphore(self.nvlink_connections, MemoryDevice2DeviceSemaphore)
456+
self.semaphores = {
457+
rank: MemoryDevice2DeviceSemaphore(sema)
458+
for rank, sema in group.make_semaphores(self.nvlink_connections).items()
459+
}
457460
file_dir = os.path.dirname(os.path.abspath(__file__))
458461
self.kernel = KernelBuilder(
459462
file="allreduce.cu",

python/test/_cpp/proxy_test.cpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <mscclpp/core.hpp>
1111
#include <mscclpp/fifo.hpp>
1212
#include <mscclpp/gpu_utils.hpp>
13-
#include <mscclpp/numa.hpp>
1413
#include <mscclpp/proxy.hpp>
1514
#include <mscclpp/semaphore.hpp>
1615
#include <vector>
@@ -19,37 +18,39 @@ namespace nb = nanobind;
1918

2019
class MyProxyService {
2120
private:
22-
int deviceNumaNode_;
2321
int my_rank_, nranks_, dataSize_;
24-
std::vector<mscclpp::Connection> connections_;
25-
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
26-
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
22+
std::vector<mscclpp::RegisteredMemory> allRegMem_;
23+
std::vector<mscclpp::Host2DeviceSemaphore> semaphores_;
2724
mscclpp::Proxy proxy_;
2825

2926
public:
30-
MyProxyService(int my_rank, int nranks, int dataSize, std::vector<mscclpp::Connection> conns,
31-
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
32-
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
27+
MyProxyService(int my_rank, int nranks, int dataSize, nb::list allRegMemList, nb::list semaphoreList)
3328
: my_rank_(my_rank),
3429
nranks_(nranks),
3530
dataSize_(dataSize),
36-
connections_(conns),
37-
allRegMem_(allRegMem),
38-
semaphores_(semaphores),
3931
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {
40-
int cudaDevice;
41-
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
42-
deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
32+
allRegMem_.reserve(allRegMemList.size());
33+
for (size_t i = 0; i < allRegMemList.size(); ++i) {
34+
auto& regMem = nb::cast<const mscclpp::RegisteredMemory&>(allRegMemList[i]);
35+
allRegMem_.push_back(regMem);
36+
}
37+
semaphores_.reserve(semaphoreList.size());
38+
for (size_t i = 0; i < semaphoreList.size(); ++i) {
39+
auto& sema = nb::cast<const mscclpp::Semaphore&>(semaphoreList[i]);
40+
semaphores_.emplace_back(sema);
41+
}
4342
}
4443

4544
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger) {
4645
int dataSizePerRank = dataSize_ / nranks_;
4746
for (int r = 1; r < nranks_; ++r) {
4847
int nghr = (my_rank_ + r) % nranks_;
49-
connections_[nghr].write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
50-
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
51-
semaphores_[nghr]->signal();
52-
connections_[nghr].flush();
48+
auto& sema = semaphores_[nghr];
49+
auto& conn = sema.connection();
50+
conn.write(allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, allRegMem_[my_rank_],
51+
my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
52+
sema.signal();
53+
conn.flush();
5354
}
5455
return mscclpp::ProxyHandlerResult::Continue;
5556
}
@@ -61,16 +62,11 @@ class MyProxyService {
6162
mscclpp::FifoDeviceHandle fifoDeviceHandle() { return proxy_.fifo()->deviceHandle(); }
6263
};
6364

64-
void init_mscclpp_proxy_test_module(nb::module_ &m) {
65+
NB_MODULE(_ext, m) {
6566
nb::class_<MyProxyService>(m, "MyProxyService")
66-
.def(nb::init<int, int, int, std::vector<mscclpp::Connection>,
67-
std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
68-
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
69-
nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
70-
nb::arg("h2d_sem_vec"))
67+
.def(nb::init<int, int, int, nb::list, nb::list>(), nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"),
68+
nb::arg("reg_mem_list"), nb::arg("sem_list"))
7169
.def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
7270
.def("start", &MyProxyService::start)
7371
.def("stop", &MyProxyService::stop);
7472
}
75-
76-
NB_MODULE(_ext, m) { init_mscclpp_proxy_test_module(m); }

python/test/test_mscclpp.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
290290
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
291291
connections = {rank: conn.get() for rank, conn in connections.items()}
292292

293-
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
293+
semaphores = group.make_semaphores(connections)
294+
semaphores = {rank: Host2HostSemaphore(sema) for rank, sema in semaphores.items()}
294295
for rank in connections:
295296
semaphores[rank].signal()
296297

@@ -309,7 +310,8 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
309310
connections = {rank: group.communicator.connect(endpoint, rank) for rank in remote_nghrs}
310311
connections = {rank: conn.get() for rank, conn in connections.items()}
311312

312-
semaphores = group.make_semaphore(connections, Host2HostSemaphore)
313+
semaphores = group.make_semaphores(connections)
314+
semaphores = {rank: Host2HostSemaphore(sema) for rank, sema in semaphores.items()}
313315

314316
def target_wait(sems, conns):
315317
for rank in conns:
@@ -457,7 +459,8 @@ def signal(semaphores):
457459

458460
group, connections = create_group_and_connection(mpi_group, connection_type)
459461

460-
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
462+
semaphores = group.make_semaphores(connections)
463+
semaphores = {rank: Host2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
461464
kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
462465
kernel()
463466

@@ -473,7 +476,8 @@ def signal(semaphores):
473476
def test_d2d_semaphores(mpi_group: MpiGroup):
474477
group, connections = create_group_and_connection(mpi_group, "NVLink")
475478

476-
semaphores = group.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
479+
semaphores = group.make_semaphores(connections)
480+
semaphores = {rank: MemoryDevice2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
477481
group.barrier()
478482
kernel = MscclppKernel("d2d_semaphore", group.my_rank, group.nranks, semaphores)
479483
kernel()
@@ -545,29 +549,29 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
545549
group.barrier()
546550
all_reg_memories = group.register_tensor_with_connections(memory, connections)
547551

548-
semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
552+
semaphores = group.make_semaphores(connections)
549553

550-
list_conn = []
551554
list_sem = []
552555
list_reg_mem = []
553-
first_conn = next(iter(connections.values()))
554556
first_sem = next(iter(semaphores.values()))
555557
for rank in range(group.nranks):
556558
if rank in connections:
557-
list_conn.append(connections[rank])
558559
list_sem.append(semaphores[rank])
559560
else:
560-
list_conn.append(first_conn) # just for simplicity of indexing
561561
list_sem.append(first_sem)
562562

563563
list_reg_mem.append(all_reg_memories[rank])
564564

565-
proxy = _ext.MyProxyService(group.my_rank, group.nranks, nelem * memory.itemsize, list_conn, list_reg_mem, list_sem)
565+
proxy = _ext.MyProxyService(group.my_rank, group.nranks, nelem * memory.itemsize, list_reg_mem, list_sem)
566566

567567
fifo_device_handle = proxy.fifo_device_handle()
568568

569569
kernel = MscclppKernel(
570-
"proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=semaphores, fifo=fifo_device_handle
570+
"proxy",
571+
my_rank=group.my_rank,
572+
nranks=group.nranks,
573+
semaphore_or_channels={rank: Host2DeviceSemaphore(sema) for rank, sema in semaphores.items()},
574+
fifo=fifo_device_handle,
571575
)
572576
proxy.start()
573577
group.barrier()
@@ -632,7 +636,8 @@ def test_nvls(mpi_group: MpiGroup):
632636
mem_handle = nvls_connection.bind_allocated_memory(memory.data.ptr, memory.data.mem.size)
633637

634638
nvlinks_connections = create_connection(group, "NVLink")
635-
semaphores = group.make_semaphore(nvlinks_connections, MemoryDevice2DeviceSemaphore)
639+
semaphores = group.make_semaphores(nvlinks_connections)
640+
semaphores = {rank: MemoryDevice2DeviceSemaphore(sema) for rank, sema in semaphores.items()}
636641

637642
kernel = MscclppKernel(
638643
"nvls",

src/semaphore.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ struct Semaphore::Impl {
9797
Semaphore::Semaphore(const SemaphoreStub& localStub, const SemaphoreStub& remoteStub) {
9898
auto remoteMemImpl = remoteStub.memory().pimpl_;
9999
if (remoteMemImpl->hostHash == getHostHash() && remoteMemImpl->pidHash == getPidHash()) {
100-
pimpl_ = std::make_unique<Impl>(localStub, RegisteredMemory::deserialize(remoteStub.memory().serialize()));
100+
pimpl_ = std::make_shared<Impl>(localStub, RegisteredMemory::deserialize(remoteStub.memory().serialize()));
101101
} else {
102-
pimpl_ = std::make_unique<Impl>(localStub, remoteStub.memory());
102+
pimpl_ = std::make_shared<Impl>(localStub, remoteStub.memory());
103103
}
104104
}
105105

0 commit comments

Comments
 (0)