 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
+import asyncio
 import faulthandler
 import logging
 import signal
 from collections import deque
 from itertools import cycle
 from multiprocessing import Event
+from typing import List, Optional, Tuple
 
 import torch
 
 from megatron.core.inference.headers import Headers
+from megatron.core.utils import get_asyncio_loop, trace_async_exceptions
 
 try:
     import zmq
+    import zmq.asyncio
 
     HAVE_ZMQ = True
 except:
     HAVE_ZMQ = False
 
 try:
     import msgpack
 
     HAVE_MSGPACK = True
 except:
     HAVE_MSGPACK = False
 
-# Register faulthandler to emit stack traces upon process kill.
-faulthandler.enable()
-faulthandler.register(signal.SIGTERM, all_threads=False, chain=True)
-faulthandler.register(signal.SIGINT, all_threads=False, chain=True)
-
 
 class DataParallelInferenceCoordinator:
     """
@@ -65,7 +64,9 @@ class DataParallelInferenceCoordinator:
         next_request_id (int): A counter for generating unique server-side request IDs.
     """
 
-    def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
+    def __init__(
+        self, ready_event: Event, inference_coordinator_port: int, data_parallel_size: int
+    ):
         """
         Initializes the inference coordinator.
 
@@ -74,6 +75,8 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
         ranks to connect before proceeding.
 
         Args:
+            ready_event (Event): A threading or multiprocessing event object that is set()
+                once the coordinator is ready to accept connections.
             inference_coordinator_port (int): The TCP port number to bind the server to.
             data_parallel_size (int): The number of TP-coordinator workers that are
                 expected to connect.
@@ -86,7 +89,10 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
8689 "please install the messagepack library to use DataParallelInferenceCoordinator\n "
8790 "pip install msgpack"
8891 )
89- self .context = zmq .Context ()
92+ self .ready_event = ready_event
93+ self .data_parallel_size = data_parallel_size
94+
95+ self .context = zmq .asyncio .Context .instance ()
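+        # zmq.asyncio sockets return Futures from send()/recv() instead of
+        # blocking, which lets one event loop multiplex clients and engine ranks.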
 
         # This is the central router socket
         # 1. data parallel ranks connect to this socket to register themselves
@@ -96,24 +102,8 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
         # the user that had submitted the request originally.
 
         self.router_socket = self.context.socket(zmq.ROUTER)
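+        # ZMQ ROUTER sockets prepend the peer's identity frame to each incoming
+        # message; _irecv() uses this flag to decide whether to strip that frame.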
+        self.socket_uses_identity = True
         self.router_socket.bind(f"tcp://0.0.0.0:{inference_coordinator_port}")
-        self.data_parallel_size = data_parallel_size
-
-        logging.info("Inference Coordinator: waiting for connections from data parallel ranks...")
-        # First wait for all data parallel ranks to establish connections.
-        self.identities_of_data_parallel_ranks = deque([])
-        # time.sleep(5)  # Give data parallel ranks time to spawn and connect.
-        for _ in range(data_parallel_size):
-            identity, _ = self.router_socket.recv_multipart()
-            assert identity not in self.identities_of_data_parallel_ranks
-            self.identities_of_data_parallel_ranks.append(identity)
-        logging.info("Inference Coordinator: Connected with data parallel ranks...")
-        self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
-
-        self.request_id_to_client_id = {}
-        self.request_id_to_client_request_id = {}
-
-        self.next_request_id = 0
 
     def get_next_data_parallel_rank(self):
         """
@@ -124,7 +114,7 @@ def get_next_data_parallel_rank(self):
124114 """
125115 return next (self .data_parallel_rank_iterator )
126116
127- def start (self ):
117+ def start (self , loop : Optional [ asyncio . AbstractEventLoop ] = None ):
128118 """
129119 Starts the main event loop for the coordinator.
130120
@@ -134,45 +124,70 @@ def start(self):
         handling new client connections, forwarding requests, broadcasting
         control signals, or processing replies from the engines.
         """
+        logging.info("Inference Coordinator: waiting for connections from data parallel ranks...")
+        loop = get_asyncio_loop(loop)
+
+        self.ready_event.clear()
+        self.identities_of_data_parallel_ranks = deque([])
+        self.data_parallel_rank_iterator = cycle([])
+        self.known_clients = set()
+        self.request_id_to_client_id = {}
+        self.request_id_to_client_request_id = {}
+        self.next_request_id = 0
+
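+        # Sends are queued as futures and drained by a dedicated task below,
+        # so the receive loop never blocks waiting on a slow peer.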
+        self._send_awaitables = asyncio.Queue()
+        self.send_task = loop.create_task(self._send_task())
+        self.recv_task = loop.create_task(self._recv_task())
+
+    @trace_async_exceptions
+    async def _recv_task(self):
+        """Main loop of the inference coordinator."""
+
+        print("Inference Coordinator: waiting for connections from data parallel ranks...")
+        # First wait for all data parallel ranks to establish connections.
+        for _ in range(self.data_parallel_size):
+            identity, header, _ = await self._irecv()
+            assert header == Headers.CONNECT
+            assert identity not in self.identities_of_data_parallel_ranks
+            self.identities_of_data_parallel_ranks.append(identity)
+            print(f"Inference Coordinator: Data parallel rank connected: {identity}")
+        print("All data parallel ranks connected.")
+        logging.info("Inference Coordinator: Connected with data parallel ranks...")
+        self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
+        self.ready_event.set()
+        print("Inference Coordinator: Ready to accept client connections.")
+
         # Todo [Siddharth]: Make this more robust to handle invalid messages.
-        known_clients = set()
         while True:
-            sender_identity, serialized_payload = self.router_socket.recv_multipart()
-            deserialized_payload = msgpack.unpackb(serialized_payload, raw=False)
-            header = Headers(deserialized_payload[0])
+            identity, header, data = await self._irecv()
 
             if header == Headers.CONNECT:
-                if sender_identity in known_clients:
+                if identity in self.known_clients:
                     logging.info(
-                        f"Client {sender_identity} sent a duplicate connect request. Ignoring.."
+                        f"Client {identity} sent a duplicate connect request. Ignoring.."
                     )
                     continue
 
-                # print(f"New client connected: {sender_identity}")
-                known_clients.add(sender_identity)
-                self.router_socket.send_multipart(
-                    [sender_identity, msgpack.packb([Headers.ACK.value], use_bin_type=True)]
-                )
+                self.known_clients.add(identity)
+                self._isend(identity, Headers.ACK)
 
             elif header == Headers.SUBMIT_REQUEST:
                 # ToDo [Siddharth]: We might want to tokenize the prompt on the
                 # assigned data parallel rank for this process instead
                 # of the coordinator.
 
                 # Message from a known client
-                if sender_identity not in known_clients:
-                    logging.info(
-                        f"Received message from unknown client {sender_identity}. Ignoring."
-                    )
+                if identity not in self.known_clients:
+                    logging.info(f"Received message from unknown client {identity}. Ignoring.")
                     continue
                 # this is a message from a client.
                 # route it to a data parallel rank
-                client_request_id, prompt, sampling_params = deserialized_payload[1:]
+                client_request_id, prompt, sampling_params = data
                 # map client request_id to server request_id
                 # necessary because multiple clients might have the same request_id.
                 request_id = self.next_request_id
                 self.next_request_id += 1
-                self.request_id_to_client_id[request_id] = sender_identity
+                self.request_id_to_client_id[request_id] = identity
                 self.request_id_to_client_request_id[request_id] = client_request_id
 
                 # Serialize prompt.
@@ -184,28 +199,22 @@ def start(self):
                     raise Exception("specialize for <%s> prompt." % type(prompt).__name__)
 
                 next_data_parallel_rank_identity = self.get_next_data_parallel_rank()
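+                # Requests are handed out round-robin via the cycle() iterator;
+                # there is no load-aware balancing across data parallel ranks.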
-                self.router_socket.send_multipart(
-                    [
-                        next_data_parallel_rank_identity,
-                        msgpack.packb(
-                            [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params],
-                            use_bin_type=True,
-                        ),
-                    ]
+                self._isend(
+                    next_data_parallel_rank_identity,
+                    Headers.SUBMIT_REQUEST,
+                    [request_id, prompt, sampling_params],
                 )
             elif header in [Headers.PAUSE, Headers.UNPAUSE, Headers.STOP]:
                 # control signals for the engine
                 # broadcast to all data parallel ranks
-                if sender_identity not in known_clients:
+                if identity not in self.known_clients:
                     continue
                 for data_parallel_rank_id in self.identities_of_data_parallel_ranks:
-                    self.router_socket.send_multipart(
-                        [data_parallel_rank_id, msgpack.packb([header.value], use_bin_type=True)]
-                    )
+                    self._isend(data_parallel_rank_id, header)
             elif header == Headers.ENGINE_REPLY:
                 # This is the output of a single engine step on some data parallel rank.
-                assert sender_identity in self.identities_of_data_parallel_ranks
-                finished_requests = deserialized_payload[1]
+                assert identity in self.identities_of_data_parallel_ranks
+                finished_requests = data
 
                 for finished_request in finished_requests:
                     fid = finished_request["request_id"]
@@ -214,15 +223,68 @@ def start(self):
                     del self.request_id_to_client_id[fid]
                     del self.request_id_to_client_request_id[fid]
 
-                    self.router_socket.send_multipart(
-                        [
-                            client_identity,
-                            msgpack.packb(
-                                [client_request_identity, finished_request], use_bin_type=True
-                            ),
-                        ]
+                    self._isend(
+                        client_identity,
+                        Headers.ENGINE_REPLY,
+                        [client_request_identity, finished_request],
                     )
 
+    @trace_async_exceptions
+    async def _send_task(self):
+        """Pop send futures off a queue and await them.
+
+        For an explanation of why this works, see the zmq.asyncio documentation:
+        'Returns a Future that resolves when sending is complete.'
+        """
+        while True:
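+            # The inner await pops the next pending send future off the queue;
+            # the outer await resolves once zmq has actually flushed the frames.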
+            await (await self._send_awaitables.get())
+            self._send_awaitables.task_done()
+
+    def _isend(self, identity: bytes, header: Headers, data: Optional[List] = None) -> None:
+        """
+        Asynchronously send a message to a connected client or data parallel rank.
+
+        Args:
+            identity (bytes): The ZMQ identity of the recipient.
+            header (Headers): The signal header to send.
+            data (Optional[List]): The data payload to send.
+        """
+        to_send = [identity, header.value.to_bytes()]
+        if data is not None:
+            to_send.append(msgpack.packb(data, use_bin_type=True))
+        send_awaitable = self.router_socket.send_multipart(to_send)
+        self._send_awaitables.put_nowait(send_awaitable)
+
+    async def _irecv(
+        self, deserialize: bool = True
+    ) -> Tuple[Optional[bytes], Headers, List | bytes | None]:
+        """
+        Asynchronously receive a message from a connected client or data parallel rank.
+
+        Returns:
+            identity (Optional[bytes]): The source of the signal.
+            header (Headers): The signal header received.
+            data (List | bytes | None): The data payload received.
+        """
+        raw = await self.router_socket.recv_multipart()
+        if self.socket_uses_identity:
+            identity, header, *rest = raw
+        else:
+            header, *rest = raw
+            identity = None
+
+        header = Headers(int.from_bytes(header))
+        data = rest[0] if rest else None
+
+        if deserialize:
+            message = msgpack.unpackb(data, raw=False) if data is not None else None
+        else:
+            message = data
+
+        return identity, header, message
+
     @classmethod
     def entrypoint(
         cls, ready_event: Event, inference_coordinator_port: int, data_parallel_size: int
@@ -239,17 +301,29 @@ def entrypoint(
             inference_coordinator_port (int): The port to bind to.
             data_parallel_size (int): The number of expected TP-coordinators.
         """
-        coordinator = cls(inference_coordinator_port, data_parallel_size)
-        ready_event.set()
+        # Register faulthandler to emit stack traces upon process kill.
+        faulthandler.enable()
+        faulthandler.register(signal.SIGTERM, all_threads=False, chain=True)
+        faulthandler.register(signal.SIGINT, all_threads=False, chain=True)
+
+        print("Inference Coordinator: Initializing coordinator...")
+        coordinator = cls(ready_event, inference_coordinator_port, data_parallel_size)
+        print("Inference Coordinator: Starting coordinator...")
+        loop = get_asyncio_loop()
+        coordinator.start(loop=loop)
+        print("Inference Coordinator: Coordinator started.")
         try:
-            coordinator.start()
+            loop.run_forever()
         except KeyboardInterrupt:
             logging.info("Coordinator process interrupted. Exiting...")
+        finally:
             coordinator.stop()
 
     def stop(self):
         """
         Stops the inference coordinator, performing any necessary cleanup operations.
         """
+        self.send_task.cancel()
+        self.recv_task.cancel()
         self.router_socket.close()
         self.context.term()