@@ -134,8 +134,12 @@ def start(self, loop: Optional[asyncio.AbstractEventLoop] = None):
134134 self .request_id_to_client_id = {}
135135 self .request_id_to_client_request_id = {}
136136 self .next_request_id = 0
137-
138137 self ._send_awaitables = asyncio .Queue ()
138+
139+ # Attempt to connect, and do not allow any sends until we are connected.
140+ self .is_running = asyncio .Event ()
141+ self ._startup_sends = []
142+
143+ self .startup_sends_task = loop .create_task (self ._startup_sends_task ())
139143 self .send_task = loop .create_task (self ._send_task ())
140144 self .recv_task = loop .create_task (self ._recv_task ())
141145
@@ -144,31 +148,30 @@ async def _recv_task(self):
144148 """Main loop of the inference coordinator."""
145149
146150 print ("Inference Coordinator: waiting for connections from data parallel ranks..." )
147- # First wait for all data parallel ranks to establish connections.
148- for _ in range (self .data_parallel_size ):
149- identity , header , _ = await self ._irecv ()
150- assert header == Headers .CONNECT
151- assert identity not in self .identities_of_data_parallel_ranks
152- self .identities_of_data_parallel_ranks .append (identity )
153- print (f"Inference Coordinator: Data parallel rank connected: { identity } " )
154- print ("All data parallel ranks connected." )
155- logging .info ("Inference Coordinator: Connected with data parallel ranks..." )
156- self .data_parallel_rank_iterator = cycle (self .identities_of_data_parallel_ranks )
157- self .ready_event .set ()
158- print ("Inference Coordinator: Ready to accept client connections." )
159-
160151 # Todo [Siddharth]: Make this more robust to handle invalid messages.
161152 while True :
162153 identity , header , data = await self ._irecv ()
163154
164- if header == Headers .CONNECT :
155+ if header == Headers .ENGINE_CONNECT :
156+ assert identity not in self .identities_of_data_parallel_ranks
157+ self .identities_of_data_parallel_ranks .append (identity )
158+ print (f"Inference Coordinator: Data parallel rank connected: { identity } " )
159+ if len (self .identities_of_data_parallel_ranks ) == self .data_parallel_size :
160+ self .data_parallel_rank_iterator = cycle (self .identities_of_data_parallel_ranks )
161+ self .ready_event .set ()
162+ self .is_running .set ()
163+ print ("All data parallel ranks connected." )
164+ logging .info ("Inference Coordinator: Connected with data parallel ranks..." )
165+ print ("Inference Coordinator: Ready to accept client connections." )
166+
167+ elif header == Headers .CLIENT_CONNECT :
165168 if identity in self .known_clients :
166169 logging .info (
167170 f"Client { identity } sent a duplicate connect request. Ignoring .."
168171 )
169172 continue
170-
171173 self .known_clients .add (identity )
174+ # Due to the `startup_sends` logic, this will not be sent until we are connected.
172175 self ._isend (identity , Headers .ACK )
173176
174177 elif header == Headers .SUBMIT_REQUEST :
@@ -240,6 +243,13 @@ async def _send_task(self):
240243 await (await self ._send_awaitables .get ())
241244 self ._send_awaitables .task_done ()
242245
246+ @trace_async_exceptions
247+ async def _startup_sends_task (self ):
248+ """Before a connection is established, we queue up sends for later."""
249+ await self .is_running .wait ()
250+ for (identity , header , data ) in self ._startup_sends :
251+ self ._isend (identity , header , data )
252+
243253 def _isend (
244254 self , identity : bytes , header : Headers , data : Optional [List ] = None
245255 ) -> asyncio .Future :
@@ -251,6 +261,12 @@ def _isend(
251261 header (Headers): The signal header to send.
252262 data (Optional[List]): The data payload to send.
253263 """
264+ # If we have not connected yet, wait on sends.
265+ if not self .is_running .is_set ():
266+ self ._startup_sends .append ((identity , header , data ))
267+ return
268+
269+ # Once we are connected, we do an atomic send and await its completion later.
254270 to_send = [identity , header .value .to_bytes ()]
255271 if data is not None :
256272 to_send .append (msgpack .packb (data , use_bin_type = True ))
0 commit comments