@@ -290,7 +290,8 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
290290 connections = {rank : group .communicator .connect (endpoint , rank ) for rank in remote_nghrs }
291291 connections = {rank : conn .get () for rank , conn in connections .items ()}
292292
293- semaphores = group .make_semaphore (connections , Host2HostSemaphore )
293+ semaphores = group .make_semaphores (connections )
294+ semaphores = {rank : Host2HostSemaphore (sema ) for rank , sema in semaphores .items ()}
294295 for rank in connections :
295296 semaphores [rank ].signal ()
296297
@@ -309,7 +310,8 @@ def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
309310 connections = {rank : group .communicator .connect (endpoint , rank ) for rank in remote_nghrs }
310311 connections = {rank : conn .get () for rank , conn in connections .items ()}
311312
312- semaphores = group .make_semaphore (connections , Host2HostSemaphore )
313+ semaphores = group .make_semaphores (connections )
314+ semaphores = {rank : Host2HostSemaphore (sema ) for rank , sema in semaphores .items ()}
313315
314316 def target_wait (sems , conns ):
315317 for rank in conns :
@@ -457,7 +459,8 @@ def signal(semaphores):
457459
458460 group , connections = create_group_and_connection (mpi_group , connection_type )
459461
460- semaphores = group .make_semaphore (connections , Host2DeviceSemaphore )
462+ semaphores = group .make_semaphores (connections )
463+ semaphores = {rank : Host2DeviceSemaphore (sema ) for rank , sema in semaphores .items ()}
461464 kernel = MscclppKernel ("h2d_semaphore" , group .my_rank , group .nranks , semaphores )
462465 kernel ()
463466
@@ -473,7 +476,8 @@ def signal(semaphores):
473476def test_d2d_semaphores (mpi_group : MpiGroup ):
474477 group , connections = create_group_and_connection (mpi_group , "NVLink" )
475478
476- semaphores = group .make_semaphore (connections , MemoryDevice2DeviceSemaphore )
479+ semaphores = group .make_semaphores (connections )
480+ semaphores = {rank : MemoryDevice2DeviceSemaphore (sema ) for rank , sema in semaphores .items ()}
477481 group .barrier ()
478482 kernel = MscclppKernel ("d2d_semaphore" , group .my_rank , group .nranks , semaphores )
479483 kernel ()
@@ -545,29 +549,29 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, connection_type: str):
545549 group .barrier ()
546550 all_reg_memories = group .register_tensor_with_connections (memory , connections )
547551
548- semaphores = group .make_semaphore (connections , Host2DeviceSemaphore )
552+ semaphores = group .make_semaphores (connections )
549553
550- list_conn = []
551554 list_sem = []
552555 list_reg_mem = []
553- first_conn = next (iter (connections .values ()))
554556 first_sem = next (iter (semaphores .values ()))
555557 for rank in range (group .nranks ):
556558 if rank in connections :
557- list_conn .append (connections [rank ])
558559 list_sem .append (semaphores [rank ])
559560 else :
560- list_conn .append (first_conn ) # just for simplicity of indexing
561561 list_sem .append (first_sem )
562562
563563 list_reg_mem .append (all_reg_memories [rank ])
564564
565- proxy = _ext .MyProxyService (group .my_rank , group .nranks , nelem * memory .itemsize , list_conn , list_reg_mem , list_sem )
565+ proxy = _ext .MyProxyService (group .my_rank , group .nranks , nelem * memory .itemsize , list_reg_mem , list_sem )
566566
567567 fifo_device_handle = proxy .fifo_device_handle ()
568568
569569 kernel = MscclppKernel (
570- "proxy" , my_rank = group .my_rank , nranks = group .nranks , semaphore_or_channels = semaphores , fifo = fifo_device_handle
570+ "proxy" ,
571+ my_rank = group .my_rank ,
572+ nranks = group .nranks ,
573+ semaphore_or_channels = {rank : Host2DeviceSemaphore (sema ) for rank , sema in semaphores .items ()},
574+ fifo = fifo_device_handle ,
571575 )
572576 proxy .start ()
573577 group .barrier ()
@@ -632,7 +636,8 @@ def test_nvls(mpi_group: MpiGroup):
632636 mem_handle = nvls_connection .bind_allocated_memory (memory .data .ptr , memory .data .mem .size )
633637
634638 nvlinks_connections = create_connection (group , "NVLink" )
635- semaphores = group .make_semaphore (nvlinks_connections , MemoryDevice2DeviceSemaphore )
639+ semaphores = group .make_semaphores (nvlinks_connections )
640+ semaphores = {rank : MemoryDevice2DeviceSemaphore (sema ) for rank , sema in semaphores .items ()}
636641
637642 kernel = MscclppKernel (
638643 "nvls" ,
0 commit comments