Skip to content

Commit 2ae1fa1

Browse files
committed
Restore blocking CONNECT functionality
1 parent 8d874e3 commit 2ae1fa1

File tree

4 files changed

+63
-30
lines changed

4 files changed

+63
-30
lines changed

megatron/core/inference/data_parallel_inference_coordinator.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,12 @@ def start(self, loop: Optional[asyncio.AbstractEventLoop] = None):
134134
self.request_id_to_client_id = {}
135135
self.request_id_to_client_request_id = {}
136136
self.next_request_id = 0
137-
138137
self._send_awaitables = asyncio.Queue()
138+
139+
# Attempt to connect, and do not allow any sends until we are connected.
140+
self.is_running = asyncio.Event()
141+
142+
self.startup_sends_task = loop.create_task(self._startup_sends_task())
139143
self.send_task = loop.create_task(self._send_task())
140144
self.recv_task = loop.create_task(self._recv_task())
141145

@@ -144,31 +148,30 @@ async def _recv_task(self):
144148
"""Main loop of the inference coordinator."""
145149

146150
print("Inference Coordinator: waiting for connections from data parallel ranks...")
147-
# First wait for all data parallel ranks to establish connections.
148-
for _ in range(self.data_parallel_size):
149-
identity, header, _ = await self._irecv()
150-
assert header == Headers.CONNECT
151-
assert identity not in self.identities_of_data_parallel_ranks
152-
self.identities_of_data_parallel_ranks.append(identity)
153-
print(f"Inference Coordinator: Data parallel rank connected: {identity}")
154-
print("All data parallel ranks connected.")
155-
logging.info("Inference Coordinator: Connected with data parallel ranks...")
156-
self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
157-
self.ready_event.set()
158-
print("Inference Coordinator: Ready to accept client connections.")
159-
160151
# Todo [Siddharth]: Make this more robust to handle invalid messages.
161152
while True:
162153
identity, header, data = await self._irecv()
163154

164-
if header == Headers.CONNECT:
155+
if header == Headers.ENGINE_CONNECT:
156+
assert identity not in self.identities_of_data_parallel_ranks
157+
self.identities_of_data_parallel_ranks.append(identity)
158+
print(f"Inference Coordinator: Data parallel rank connected: {identity}")
159+
if len(self.identities_of_data_parallel_ranks) == self.data_parallel_size:
160+
self.data_parallel_rank_iterator = cycle(self.identities_of_data_parallel_ranks)
161+
self.ready_event.set()
162+
self.is_running.set()
163+
print("All data parallel ranks connected.")
164+
logging.info("Inference Coordinator: Connected with data parallel ranks...")
165+
print("Inference Coordinator: Ready to accept client connections.")
166+
167+
elif header == Headers.CLIENT_CONNECT:
165168
if identity in self.known_clients:
166169
logging.info(
167170
f"Client {identity} sent a duplicate connect request. Ignoring .."
168171
)
169172
continue
170-
171173
self.known_clients.add(identity)
174+
# Due to the `startup_sends` logic, this will not be sent until we are connected.
172175
self._isend(identity, Headers.ACK)
173176

174177
elif header == Headers.SUBMIT_REQUEST:
@@ -240,6 +243,13 @@ async def _send_task(self):
240243
await (await self._send_awaitables.get())
241244
self._send_awaitables.task_done()
242245

246+
@trace_async_exceptions
async def _startup_sends_task(self):
    """Flush sends that were queued before all data parallel ranks connected.

    While ``is_running`` is unset, ``_isend`` appends its arguments to
    ``self._startup_sends`` instead of sending (see ``_isend``); once
    ``_recv_task`` sets the event after the last ENGINE_CONNECT, replay
    the queued sends in their original order.
    """
    # Bug fix: an asyncio.Event must be awaited via .wait(); calling the
    # Event object itself raises TypeError ('Event' object is not callable).
    await self.is_running.wait()
    # Bug fix: the coordinator queues 3-tuples (identity, header, data) in
    # _isend, and its _isend signature is (identity, header, data) — the
    # previous 2-tuple unpacking would have sent the header as the identity
    # frame. NOTE(review): self._startup_sends does not appear to be
    # initialized in this class's start() — confirm it is created before
    # the first _isend call.
    for identity, header, data in self._startup_sends:
        self._isend(identity, header, data)
252+
243253
def _isend(
244254
self, identity: bytes, header: Headers, data: Optional[List] = None
245255
) -> asyncio.Future:
@@ -251,6 +261,12 @@ def _isend(
251261
header (Headers): The signal header to send.
252262
data (Optional[List]): The data payload to send.
253263
"""
264+
# If we have not connected yet, wait on sends.
265+
if not self.is_running.is_set():
266+
self._startup_sends.append((identity, header, data))
267+
return
268+
269+
# Once we are connected, we do an atomic send and await its completion later.
254270
to_send = [identity, header.value.to_bytes()]
255271
if data is not None:
256272
to_send.append(msgpack.packb(data, use_bin_type=True))

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,8 @@ def start_listening_to_data_parallel_coordinator(
416416
self.socket_for_receiving_requests.setsockopt(zmq.IDENTITY, identity.encode('utf-8'))
417417
self.socket_for_receiving_requests.connect(dp_addr[0])
418418

419-
# send empty string. this is used to register with the coordinator.
420-
self._isend(self.socket_for_receiving_requests, Headers.CONNECT, b"")
419+
# Register with the coordinator.
420+
self._isend(self.socket_for_receiving_requests, Headers.ENGINE_CONNECT)
421421

422422
# 2. Create a publisher socket. This is used to publish or broadcast
423423
# requests within the model parallel group

megatron/core/inference/headers.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ class Headers(Enum):
88
Enum representing headers used for communication with the inference-coordinator.
99
"""
1010

11-
CONNECT = 0
12-
ACK = 1
13-
MICROBATCH_SYNC = 2
14-
SUBMIT_REQUEST = 3
15-
ENGINE_REPLY = 4
16-
PAUSE = 5
17-
UNPAUSE = 6
18-
STOP = 7
11+
ENGINE_CONNECT = 0
12+
CLIENT_CONNECT = 1
13+
ACK = 2
14+
MICROBATCH_SYNC = 3
15+
SUBMIT_REQUEST = 4
16+
ENGINE_REPLY = 5
17+
PAUSE = 6
18+
UNPAUSE = 7
19+
STOP = 8

megatron/core/inference/inference_client.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async def _recv_task(self):
121121
try:
122122
_, header, data = await self._irecv()
123123

124-
assert header == Headers.ACK or self.initial_reply
124+
assert header == Headers.ACK or self.is_running.is_set()
125125
if header == Headers.ENGINE_REPLY:
126126
request_id, reply = data
127127
reply['latency'] = time.perf_counter() - self.request_submission_times.pop(
@@ -130,7 +130,7 @@ async def _recv_task(self):
130130
completion_future = self.completion_futures.pop(request_id)
131131
completion_future.set_result(DynamicInferenceRequest.deserialize(reply))
132132
elif header == Headers.ACK:
133-
self.initial_reply = True
133+
self.is_running.set()
134134
except asyncio.CancelledError:
135135
break
136136

@@ -149,9 +149,12 @@ def start(self, loop: Optional[asyncio.AbstractEventLoop] = None):
149149
self.next_request_id = 0
150150
self._send_awaitables = asyncio.Queue()
151151

152-
self.initial_reply = False
153-
self._isend(Headers.CONNECT)
152+
# Attempt to connect, and do not allow any sends until we are connected.
153+
self.is_running = asyncio.Event()
154+
self._startup_sends = []
155+
self._isend(Headers.CLIENT_CONNECT)
154156

157+
self.startup_sends_task = loop.create_task(self._startup_sends_task())
155158
self.send_task = loop.create_task(self._send_task())
156159
self.recv_task = loop.create_task(self._recv_task())
157160

@@ -166,6 +169,13 @@ async def _send_task(self):
166169
await (await self._send_awaitables.get())
167170
self._send_awaitables.task_done()
168171

172+
@trace_async_exceptions
async def _startup_sends_task(self):
    """Flush sends that were queued before the coordinator acknowledged us.

    While ``is_running`` is unset, ``_isend`` appends ``(header, data)``
    pairs to ``self._startup_sends`` instead of sending; once the ACK from
    the coordinator sets the event in ``_recv_task``, replay the queued
    sends in their original order.
    """
    # Bug fix: an asyncio.Event must be awaited via .wait(); calling the
    # Event object itself raises TypeError ('Event' object is not callable).
    await self.is_running.wait()
    for header, data in self._startup_sends:
        self._isend(header, data)
178+
169179
def _isend(self, header: Headers, data: Optional[List] = None) -> asyncio.Future:
170180
"""
171181
Asynchronously send a signal to the inference coordinator.
@@ -174,6 +184,12 @@ def _isend(self, header: Headers, data: Optional[List] = None) -> asyncio.Future
174184
header (Headers): The signal header to send.
175185
data (Optional[List]): The data payload to send.
176186
"""
187+
# If we have not connected yet, wait on sends.
188+
if not self.is_running.is_set():
189+
self._startup_sends.append((header, data))
190+
return
191+
192+
# Once we are connected, we do an atomic send and await its completion later.
177193
to_send = [header.value.to_bytes()]
178194
if data is not None:
179195
to_send.append(msgpack.packb(data, use_bin_type=True))

0 commit comments

Comments
 (0)