Commit 8d874e3

Change ZMQ communication to use async ZMQ
1 parent 5c8c8cf commit 8d874e3

6 files changed (+408, -214 lines)

examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -39,7 +39,7 @@ async def main(
     # once you call engine.start_listening_to_data_parallel_coordinator,
     # the engine will start accepting requests from the data parallel coordinator.
     # and processing them in an asyncio coroutine.
-    await engine.start_listening_to_data_parallel_coordinator(
+    engine.start_listening_to_data_parallel_coordinator(
         inference_coordinator_port=port, launch_inference_coordinator=True
     )
     # if you want to use your own inference coordinator -
@@ -51,7 +51,7 @@ async def main(
     # 5. look at InferenceClient to see how we create requests with headers.
     if dist.get_rank() == 0:
         client = InferenceClient(port)  # submits requests to the inference coordinator
-        await client.start()
+        client.start()
         base_arrival_time = time.time_ns() / 10**9
         for request in requests:
            request.time_arrival = request.time_offset + base_arrival_time
```
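Both call sites stop awaiting `start()`: presumably, like the coordinator's `start()` in the diff further below, these methods now schedule their receive loops as asyncio tasks and return immediately. A minimal sketch of that fire-and-forget pattern (the `Listener` class and `_run` coroutine here are hypothetical illustrations, not the Megatron API):

```python
import asyncio

class Listener:
    """Hypothetical stand-in for the engine/client objects above."""

    def start(self, loop: asyncio.AbstractEventLoop | None = None) -> None:
        # Schedule the receive loop as a background task instead of running
        # it inline, so callers no longer need to await start().
        loop = loop or asyncio.get_event_loop()
        self._task = loop.create_task(self._run())

    async def _run(self) -> None:
        while True:
            await asyncio.sleep(1.0)  # stand-in for awaiting socket I/O

async def main() -> None:
    listener = Listener()
    listener.start()          # returns immediately
    await asyncio.sleep(3.0)  # the listener's task keeps running in the background

asyncio.run(main())
```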
Lines changed: 140 additions & 66 deletions
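The second changed file (its path is not listed above) contains `DataParallelInferenceCoordinator` itself and carries the bulk of the commit: the blocking `zmq` context and calls are swapped for `zmq.asyncio`, whose `send_multipart`/`recv_multipart` return awaitables instead of blocking. A self-contained primer on that pattern, independent of this repo (the port number is an arbitrary choice):

```python
import asyncio
import zmq
import zmq.asyncio

async def echo_once(port: int = 5555) -> None:
    ctx = zmq.asyncio.Context.instance()
    router = ctx.socket(zmq.ROUTER)
    router.bind(f"tcp://127.0.0.1:{port}")
    dealer = ctx.socket(zmq.DEALER)
    dealer.connect(f"tcp://127.0.0.1:{port}")

    await dealer.send_multipart([b"hello"])            # resolves once queued
    identity, payload = await router.recv_multipart()  # ROUTER prepends the sender identity
    await router.send_multipart([identity, payload])   # route the reply back by identity
    print(await dealer.recv_multipart())               # [b'hello']

asyncio.run(echo_once())
```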
```diff
@@ -1,18 +1,22 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
+import asyncio
 import faulthandler
 import logging
 import signal
 from collections import deque
 from itertools import cycle
 from multiprocessing import Event
+from typing import List, Optional, Tuple
 
 import torch
 
 from megatron.core.inference.headers import Headers
+from megatron.core.utils import get_asyncio_loop, trace_async_exceptions
 
 try:
     import zmq
+    import zmq.asyncio
 
     HAVE_ZMQ = True
 except:
@@ -25,11 +29,6 @@
 except:
     HAVE_MSGPACK = False
 
-# Register faulthandler to emit stack traces upon process kill.
-faulthandler.enable()
-faulthandler.register(signal.SIGTERM, all_threads=False, chain=True)
-faulthandler.register(signal.SIGINT, all_threads=False, chain=True)
-
 
 class DataParallelInferenceCoordinator:
     """
@@ -65,7 +64,9 @@ class DataParallelInferenceCoordinator:
         next_request_id (int): A counter for generating unique server-side request IDs.
     """
 
-    def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
+    def __init__(
+        self, ready_event: Event, inference_coordinator_port: int, data_parallel_size: int
+    ):
         """
         Initializes the inference coordinator.
 
@@ -74,6 +75,8 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
         ranks to connect before proceeding.
 
         Args:
+            ready_event (Event): A threading or multiprocessing event object that is set()
+                once the coordinator is ready to accept connections.
             inference_coordinator_port (int): The TCP port number to bind the server to.
             data_parallel_size (int): The number of TP-coordinator workers that are
                 expected to connect.
@@ -86,7 +89,10 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
                 "please install the messagepack library to use DataParallelInferenceCoordinator\n"
                 "pip install msgpack"
            )
-        self.context = zmq.Context()
+        self.ready_event = ready_event
+        self.data_parallel_size = data_parallel_size
+
+        self.context = zmq.asyncio.Context.instance()
 
         # This is the central router socket
         # 1. data parallel ranks connect to this socket to register themselves
@@ -96,24 +102,8 @@ def __init__(self, inference_coordinator_port: int, data_parallel_size: int):
         # the user that had submitted the request originally.
 
         self.router_socket = self.context.socket(zmq.ROUTER)
+        self.socket_uses_identity = True
         self.router_socket.bind(f"tcp://0.0.0.0:{inference_coordinator_port}")
-        self.data_parallel_size = data_parallel_size
-
-        logging.info("Inference Coordinator: waiting for connections from data parallel ranks...")
-        # First wait for all data parallel ranks to establish connections.
-        self.identities_of_data_parallel_ranks = deque([])
-        # time.sleep(5) # Give data parallel ranks time to spawn and connect.
-        for _ in range(data_parallel_size):
-            identity, _ = self.router_socket.recv_multipart()
-            assert identity not in self.identities_of_data_parallel_ranks
-            self.identities_of_data_parallel_ranks.append(identity)
-        logging.info("Inference Coordinator: Connected with data parallel ranks...")
-        self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
-
-        self.request_id_to_client_id = {}
-        self.request_id_to_client_request_id = {}
-
-        self.next_request_id = 0
 
     def get_next_data_parallel_rank(self):
         """
@@ -124,7 +114,7 @@ def get_next_data_parallel_rank(self):
         """
         return next(self.data_parallel_rank_iterator)
 
-    def start(self):
+    def start(self, loop: Optional[asyncio.AbstractEventLoop] = None):
         """
         Starts the main event loop for the coordinator.
 
@@ -134,45 +124,70 @@ def start(self):
         handling new client connections, forwarding requests, broadcasting
         control signals, or processing replies from the engines.
         """
+        logging.info("Inference Coordinator: waiting for connections from data parallel ranks...")
+        loop = get_asyncio_loop(loop)
+
+        self.ready_event.clear()
+        self.identities_of_data_parallel_ranks = deque([])
+        self.data_parallel_rank_iterator = cycle([])
+        self.known_clients = set()
+        self.request_id_to_client_id = {}
+        self.request_id_to_client_request_id = {}
+        self.next_request_id = 0
+
+        self._send_awaitables = asyncio.Queue()
+        self.send_task = loop.create_task(self._send_task())
+        self.recv_task = loop.create_task(self._recv_task())
+
+    @trace_async_exceptions
+    async def _recv_task(self):
+        """Main loop of the inference coordinator."""
+
+        print("Inference Coordinator: waiting for connections from data parallel ranks...")
+        # First wait for all data parallel ranks to establish connections.
+        for _ in range(self.data_parallel_size):
+            identity, header, _ = await self._irecv()
+            assert header == Headers.CONNECT
+            assert identity not in self.identities_of_data_parallel_ranks
+            self.identities_of_data_parallel_ranks.append(identity)
+            print(f"Inference Coordinator: Data parallel rank connected: {identity}")
+        print("All data parallel ranks connected.")
+        logging.info("Inference Coordinator: Connected with data parallel ranks...")
+        self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
+        self.ready_event.set()
+        print("Inference Coordinator: Ready to accept client connections.")
+
         # Todo [Siddharth]: Make this more robust to handle invalid messages.
-        known_clients = set()
         while True:
-            sender_identity, serialized_payload = self.router_socket.recv_multipart()
-            deserialized_payload = msgpack.unpackb(serialized_payload, raw=False)
-            header = Headers(deserialized_payload[0])
+            identity, header, data = await self._irecv()
 
             if header == Headers.CONNECT:
-                if sender_identity in known_clients:
+                if identity in self.known_clients:
                     logging.info(
-                        f"Client {sender_identity} sent a duplicate connect request. Ignoring .."
+                        f"Client {identity} sent a duplicate connect request. Ignoring .."
                     )
                     continue
 
-                # print(f"New client connected: {sender_identity}")
-                known_clients.add(sender_identity)
-                self.router_socket.send_multipart(
-                    [sender_identity, msgpack.packb([Headers.ACK.value], use_bin_type=True)]
-                )
+                self.known_clients.add(identity)
+                self._isend(identity, Headers.ACK)
 
             elif header == Headers.SUBMIT_REQUEST:
                 # ToDo [Siddharth]: We might want to tokenize the prompt on the
                 # assigned data parallel rank for this process instead
                 # of the coordinator.
 
                 # Message from a known client
-                if sender_identity not in known_clients:
-                    logging.info(
-                        f"Received message from unknown client {sender_identity}. Ignoring."
-                    )
+                if identity not in self.known_clients:
+                    logging.info(f"Received message from unknown client {identity}. Ignoring.")
                     continue
                 # this is a message from a client.
                 # route it to a data parallel rank
-                client_request_id, prompt, sampling_params = deserialized_payload[1:]
+                client_request_id, prompt, sampling_params = data
                 # map client request_id to server request_id
                 # necessary because multiple clients might have the same request_id.
                 request_id = self.next_request_id
                 self.next_request_id += 1
-                self.request_id_to_client_id[request_id] = sender_identity
+                self.request_id_to_client_id[request_id] = identity
                 self.request_id_to_client_request_id[request_id] = client_request_id
 
                 # Serialize prompt.
@@ -184,28 +199,22 @@ def start(self):
                     raise Exception("specialize for <%s> prompt." % type(prompt).__name__)
 
                 next_data_parallel_rank_identity = self.get_next_data_parallel_rank()
-                self.router_socket.send_multipart(
-                    [
-                        next_data_parallel_rank_identity,
-                        msgpack.packb(
-                            [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params],
-                            use_bin_type=True,
-                        ),
-                    ]
+                self._isend(
+                    next_data_parallel_rank_identity,
+                    Headers.SUBMIT_REQUEST,
+                    [request_id, prompt, sampling_params],
                 )
             elif header in [Headers.PAUSE, Headers.UNPAUSE, Headers.STOP]:
                 # control signals for the engine
                 # broadcast to all data parallel ranks
-                if sender_identity not in known_clients:
+                if identity not in self.known_clients:
                     continue
                 for data_parallel_rank_id in self.identities_of_data_parallel_ranks:
-                    self.router_socket.send_multipart(
-                        [data_parallel_rank_id, msgpack.packb([header.value], use_bin_type=True)]
-                    )
+                    self._isend(data_parallel_rank_id, header)
             elif header == Headers.ENGINE_REPLY:
                 # This is the output of a single engine step on some data parallel rank.
-                assert sender_identity in self.identities_of_data_parallel_ranks
-                finished_requests = deserialized_payload[1]
+                assert identity in self.identities_of_data_parallel_ranks
+                finished_requests = data
 
                 for finished_request in finished_requests:
                     fid = finished_request["request_id"]
@@ -214,15 +223,68 @@ def start(self):
                     del self.request_id_to_client_id[fid]
                     del self.request_id_to_client_request_id[fid]
 
-                    self.router_socket.send_multipart(
-                        [
-                            client_identity,
-                            msgpack.packb(
-                                [client_request_identity, finished_request], use_bin_type=True
-                            ),
-                        ]
+                    self._isend(
+                        client_identity,
+                        Headers.ENGINE_REPLY,
+                        [client_request_identity, finished_request],
                     )
 
+    @trace_async_exceptions
+    async def _send_task(self):
+        """Pop futures of sends out of a queue and await them.
+
+        For explanation why this works, refer to the documentation for zmq.asyncio:
+        'Returns a Future that resolves when sending is complete.'
+        """
+        while True:
+            await (await self._send_awaitables.get())
+            self._send_awaitables.task_done()
+
+    def _isend(
+        self, identity: bytes, header: Headers, data: Optional[List] = None
+    ) -> asyncio.Future:
+        """
+        Asynchronously send a signal to the inference coordinator.
+
+        Args:
+            identity (bytes): The ZMQ identity of the recipient.
+            header (Headers): The signal header to send.
+            data (Optional[List]): The data payload to send.
+        """
+        to_send = [identity, header.value.to_bytes()]
+        if data is not None:
+            to_send.append(msgpack.packb(data, use_bin_type=True))
+        send_awaitable = self.router_socket.send_multipart(to_send)
+        self._send_awaitables.put_nowait(send_awaitable)
+
+    async def _irecv(
+        self, deserialize: bool = True
+    ) -> Tuple[Optional[bytes], Headers, List | bytes | None]:
+        """
+        Asynchronously receive a signal from the inference coordinator.
+
+        Returns:
+            identity (Optional[bytes]): The source of the signal.
+            header (Headers): The signal header received.
+            data (List | bytes | None): The data payload received.
+        """
+        raw = await self.router_socket.recv_multipart()
+        if self.socket_uses_identity:
+            identity, header, *rest = raw
+        else:
+            header, *rest = raw
+            identity = None
+
+        header = Headers(int.from_bytes(header))
+        data = rest[0] if rest else None
+
+        if deserialize:
+            message = msgpack.unpackb(data, raw=False) if data is not None else None
+        else:
+            message = data
+
+        return identity, header, message
+
     @classmethod
     def entrypoint(
         cls, ready_event: Event, inference_coordinator_port: int, data_parallel_size: int
@@ -239,17 +301,29 @@ def entrypoint(
             inference_coordinator_port (int): The port to bind to.
             data_parallel_size (int): The number of expected TP-coordinators.
         """
-        coordinator = cls(inference_coordinator_port, data_parallel_size)
-        ready_event.set()
+        # Register faulthandler to emit stack traces upon process kill.
+        faulthandler.enable()
+        faulthandler.register(signal.SIGTERM, all_threads=False, chain=True)
+        faulthandler.register(signal.SIGINT, all_threads=False, chain=True)
+
+        print("Inference Coordinator: Initializing coordinator...")
+        coordinator = cls(ready_event, inference_coordinator_port, data_parallel_size)
+        print("Inference Coordinator: Starting coordinator...")
+        loop = get_asyncio_loop()
+        coordinator.start(loop=loop)
+        print("Inference Coordinator: Coordinator started.")
         try:
-            coordinator.start()
+            loop.run_forever()
         except KeyboardInterrupt:
            logging.info("Coordinator process interrupted. Exiting...")
+        finally:
             coordinator.stop()
 
     def stop(self):
         """
         Stops the inference coordinator, performing any necessary cleanup operations.
         """
+        self.send_task.cancel()
+        self.recv_task.cancel()
         self.router_socket.close()
         self.context.term()
```
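Note how sends are handled: `_isend` never awaits. Under `zmq.asyncio`, `send_multipart` returns a Future that resolves when sending completes; `_isend` pushes that Future onto `_send_awaitables`, and the dedicated `_send_task` drains the queue and awaits each one, so send completions and errors surface in a single place. A stripped-down sketch of the same pattern in plain asyncio (`fake_send` is a hypothetical stand-in for the socket; the real code queues `send_multipart`'s return value directly, without `ensure_future`):

```python
import asyncio

async def fake_send(msg: str) -> None:
    await asyncio.sleep(0.1)  # stand-in for a zmq send future resolving
    print("sent", msg)

def isend(queue: asyncio.Queue, msg: str) -> None:
    # Synchronous call site: kick off the send and queue its awaitable.
    queue.put_nowait(asyncio.ensure_future(fake_send(msg)))

async def drain(queue: asyncio.Queue) -> None:
    while True:
        await (await queue.get())  # pop the next awaitable, then await it
        queue.task_done()

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    drainer = asyncio.create_task(drain(queue))
    for i in range(3):
        isend(queue, f"message {i}")
    await queue.join()  # returns once every queued send has completed
    drainer.cancel()

asyncio.run(main())
```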

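The wire format also changes: previously the header rode inside the msgpack payload (`msgpack.packb([Headers.ACK.value, ...])`), whereas `_isend`/`_irecv` now put the header in its own single-byte frame via `header.value.to_bytes()`, followed by an optional msgpack payload frame. A sketch of that framing with a hypothetical `Headers` enum (the real one, with different values, lives in `megatron.core.inference.headers`); note that `int.to_bytes()`/`int.from_bytes()` without arguments default to one big-endian byte only on Python 3.11+:

```python
import msgpack
from enum import IntEnum

class Headers(IntEnum):  # hypothetical values, for illustration only
    ACK = 1
    SUBMIT_REQUEST = 2

def pack_frames(header: Headers, data=None) -> list[bytes]:
    # New framing: header as its own frame, optional msgpack payload after it.
    frames = [header.value.to_bytes()]
    if data is not None:
        frames.append(msgpack.packb(data, use_bin_type=True))
    return frames

def unpack_frames(frames: list[bytes]):
    header = Headers(int.from_bytes(frames[0]))
    data = msgpack.unpackb(frames[1], raw=False) if len(frames) > 1 else None
    return header, data

frames = pack_frames(Headers.SUBMIT_REQUEST, [17, "prompt text", {"top_k": 1}])
print(unpack_frames(frames))  # (<Headers.SUBMIT_REQUEST: 2>, [17, 'prompt text', {'top_k': 1}])
```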
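Finally, the process lifecycle in `entrypoint` inverts: before, `ready_event` was set right after construction and `coordinator.start()` blocked forever; now `start()` returns after scheduling tasks, `loop.run_forever()` drives them, and `ready_event` is set from inside `_recv_task` only once every data parallel rank has connected. A minimal sketch of that shape (hypothetical `serve` coroutine, no ZMQ):

```python
import asyncio
from multiprocessing import Event, Process

def entrypoint(ready_event) -> None:
    loop = asyncio.new_event_loop()

    async def serve() -> None:
        ready_event.set()  # signal readiness from inside the running task
        while True:
            await asyncio.sleep(1.0)  # stand-in for the coordinator's recv loop

    task = loop.create_task(serve())
    try:
        # start() only schedules tasks, so the process must drive the loop itself.
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        task.cancel()
        loop.close()

if __name__ == "__main__":
    ready = Event()
    worker = Process(target=entrypoint, args=(ready,))
    worker.start()
    ready.wait()  # parent blocks until the child is actually serving
    worker.terminate()
    worker.join()
```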