@@ -39,7 +39,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
 def _launch_distributed_inference(
     model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
 ) -> tuple[List]:
-    # create programmatic elastic launch
+    # launch distributed inference workers; each worker gets a pipe to communicate with the main process
     logger.info("Launching distributed inference ...")

     num_processes_per_node = builder_args.pp * builder_args.tp
@@ -50,17 +50,25 @@ def _launch_distributed_inference(

     pipes = []
     procs = []
-    for rank in range(num_processes_per_node):
-        server_pipe, client_pipe = mp.Pipe(duplex=True)
-        pipes.append(server_pipe)
-        proc = mp.Process(
-            target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(model_name, builder_args, tokenizer_args, client_pipe),
-        )
-        proc.start()
+    try:
+        for rank in range(num_processes_per_node):
+            server_pipe, client_pipe = mp.Pipe(duplex=True)
+            pipes.append(server_pipe)
+            procs.append(
+                mp.Process(
+                    target=partial(_setup_env, num_processes_per_node, rank, main),
+                    args=(model_name, builder_args, tokenizer_args, client_pipe),
+                )
+            )
+            procs[-1].start()

-    for pipe in pipes:
-        response = pipe.recv()
+        for pipe in pipes:
+            assert pipe.recv() == "ready", "Starting the worker failed"
+    except Exception as e:
+        logger.error(f"Error during distributed inference: {str(e)}")
+        for p in procs:
+            p.kill()
+        raise e

     logger.info(
         f"Done launching distributed inference on {num_processes_per_node} GPUs."
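
Note: the hunk above turns a fire-and-forget launch into a guarded one: each worker receives the client end of a duplex pipe and must send back "ready" before the launcher proceeds, and any startup failure kills every process already started. A minimal standalone sketch of that handshake pattern, assuming an illustrative worker_fn and worker count (not torchchat code):

    import multiprocessing as mp

    def worker_fn(rank: int, pipe) -> None:
        # ... per-rank setup would happen here ...
        pipe.send("ready")  # tell the parent that startup succeeded

    if __name__ == "__main__":
        pipes, procs = [], []
        try:
            for rank in range(2):  # illustrative worker count
                parent_end, child_end = mp.Pipe(duplex=True)
                pipes.append(parent_end)
                procs.append(mp.Process(target=worker_fn, args=(rank, child_end)))
                procs[-1].start()
            for pipe in pipes:
                assert pipe.recv() == "ready", "Starting the worker failed"
        except Exception:
            for p in procs:
                p.kill()  # tear down rather than leak half-started workers
            raise
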
@@ -105,11 +113,13 @@ def __init__(
         self.loop = loop

     def schedule_request(self, req: Request):
+        # add the request to the queue and create a deque and async event for its response
         self.req_to_states[req.request_id] = asyncio.Event()
         self.req_to_results[req.request_id] = deque()
         self.request_queue.put(req)

     def process_requests_loop(self):
+        # Continuously process requests (one at a time for now); results are routed into the request's deque
         while True:
             req = self.request_queue.get()
             if req == "stop":
@@ -127,6 +137,7 @@ def process_requests_loop(self):
                 running &= not outputs[0].is_finished

     async def wait_for_request(self, req: Request) -> Output:
+        # Wait for the request to deliver results; an event signals arrival and results are read from the left of the deque
         is_finished = False
         while not is_finished:
             await self.req_to_states[req.request_id].wait()
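
These methods hand results from the scheduler thread to async callers through a per-request asyncio.Event plus deque: process_requests_loop appends an output and sets the event, and wait_for_request awaits the event and pops from the left. A self-contained sketch of that handoff under simplified assumptions (one shared deque instead of one per request; all names illustrative):

    import asyncio
    import threading
    from collections import deque

    results: deque = deque()

    def produce(loop: asyncio.AbstractEventLoop, event: asyncio.Event, n: int) -> None:
        # stands in for the scheduler thread pushing step() outputs
        for i in range(n):
            results.append(f"token-{i}")
            loop.call_soon_threadsafe(event.set)  # wake the waiting coroutine safely

    async def consume(event: asyncio.Event, n: int) -> None:
        # stands in for wait_for_request draining the deque
        seen = 0
        while seen < n:
            await event.wait()
            event.clear()
            while results:
                print(results.popleft())
                seen += 1

    async def main() -> None:
        loop = asyncio.get_running_loop()
        event = asyncio.Event()
        t = threading.Thread(target=produce, args=(loop, event, 3))
        t.start()
        await consume(event, 3)
        t.join()

    asyncio.run(main())
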
@@ -138,6 +149,7 @@ async def wait_for_request(self, req: Request) -> Output:
         del self.req_to_results[req.request_id]

     def step(self) -> List[Output]:
+        # Run one prefill or decode step and collect the results
         responses = []
         # TODO: Implement a scheduler to handle the requests
         if len(self.in_flight_requests) > 0:
@@ -166,6 +178,7 @@ def step(self) -> List[Output]:
             text, token_ids = v
             outputs.append(
                 Output(
+                    # TODO: Look for tokenizer.eos_id as well
                     is_finished=self.current_step >= self.generator_args.max_new_tokens,
                     text=text,
                     token=token_ids,
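
The new TODO flags that is_finished only checks the max_new_tokens budget; generation should presumably also stop once the tokenizer's end-of-sequence token appears. One hypothetical shape for that check (is_request_finished and a plain-int eos_id are assumptions, not torchchat's API):

    def is_request_finished(current_step: int, max_new_tokens: int,
                            new_token_ids: list[int], eos_id: int) -> bool:
        # hypothetical helper: finish when the step budget is exhausted
        # or the eos token shows up among the newly generated tokens
        return current_step >= max_new_tokens or eos_id in new_token_ids

    assert is_request_finished(5, 5, [42], eos_id=2)      # budget exhausted
    assert is_request_finished(1, 5, [2], eos_id=2)       # eos produced
    assert not is_request_finished(1, 5, [42], eos_id=2)  # still running
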
@@ -218,6 +231,7 @@ def __init__(
         atexit.register(self.shutdown)

     def shutdown(self):
+        # Stop all processes and threads
         self.scheduler.request_queue.put("stop")
         self.scheduler_thread.join()

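
shutdown() relies on a stop sentinel: putting the string "stop" on the request queue makes process_requests_loop break out, after which the scheduler thread can be joined and the worker processes killed. A standalone sketch of the sentinel pattern (names illustrative):

    import queue
    import threading

    q: queue.Queue = queue.Queue()

    def worker_loop() -> None:
        while True:
            item = q.get()
            if item == "stop":  # sentinel ends the loop cleanly
                break
            # ... process item ...

    t = threading.Thread(target=worker_loop)
    t.start()
    q.put("stop")  # request shutdown
    t.join()       # wait for the loop to exit
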
@@ -227,6 +241,7 @@ def shutdown(self):
             p.kill()

     def generate(self, text):
+        # Generate text from a prompt
         req = Request.new_request(text)
         self.scheduler.schedule_request(req)
