@@ -420,6 +420,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -433,6 +434,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         ...
 
@@ -449,6 +451,7 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -460,6 +463,9 @@ async def add_request_async(
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+        if priority != 0 and not self.scheduler_config.policy == "priority":
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
         if arrival_time is None:
             arrival_time = time.time()
 
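The guard added here mirrors the existing LoRA check just above it: a nonzero `priority` is rejected unless the engine's scheduler was configured with the `priority` policy, so existing callers that never pass `priority` (and thus get the default `0`) are unaffected under FCFS. A standalone sketch of the same check, using a hypothetical `SchedulerConfig` stand-in for `vllm.config.SchedulerConfig`:

```python
from dataclasses import dataclass

@dataclass
class SchedulerConfig:  # hypothetical stand-in for vllm.config.SchedulerConfig
    policy: str = "fcfs"  # default scheduling policy

def check_priority(priority: int, scheduler_config: SchedulerConfig) -> None:
    # Same condition as the guard added in the diff above.
    if priority != 0 and scheduler_config.policy != "priority":
        raise ValueError(f"Got priority {priority} but "
                         "Priority scheduling is not enabled.")

check_priority(0, SchedulerConfig())            # OK: default priority under FCFS
check_priority(1, SchedulerConfig("priority"))  # OK: priority policy enabled
try:
    check_priority(1, SchedulerConfig())        # rejected under FCFS
except ValueError as e:
    print(e)
```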
@@ -479,6 +485,7 @@ async def add_request_async(
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
+            priority=priority,
         )
 
     async def check_health_async(self) -> None:
@@ -829,6 +836,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -843,6 +851,7 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, EmbeddingRequestOutput], None]]:
         ...
@@ -860,6 +869,7 @@ async def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
@@ -877,6 +887,11 @@ async def add_request(
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")
 
+        if (priority != 0
+                and not self.engine.scheduler_config.policy == "priority"):
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
+
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
@@ -885,7 +900,9 @@ async def add_request(
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
+        )
 
         return stream.generator()
 
@@ -896,7 +913,8 @@ async def generate(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
@@ -913,6 +931,8 @@ async def generate(
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.
 
         Yields:
             The output `RequestOutput` objects from the LLMEngine
@@ -968,6 +988,7 @@ async def generate(
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
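End to end, the new parameter flows from `generate()` through `add_request()` into the request tracker and on to the engine. A minimal usage sketch follows; note that `scheduling_policy="priority"` on `AsyncEngineArgs` is an assumption here (this diff only shows the engine side, not the args plumbing), and the exact ordering semantics of the value belong to the scheduler, not to this file.

```python
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main() -> None:
    # scheduling_policy="priority" is assumed to be the EngineArgs knob that
    # sets scheduler_config.policy; without it, any nonzero priority raises
    # the ValueError added in this diff.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        scheduling_policy="priority"))

    final = None
    async for output in engine.generate(
            "Explain priority scheduling in one sentence.",
            SamplingParams(max_tokens=32),
            request_id="req-0",
            priority=-1,  # nonzero is only valid under the priority policy
    ):
        final = output  # partial RequestOutputs stream in; keep the last one
    assert final is not None
    print(final.outputs[0].text)

if __name__ == "__main__":
    asyncio.run(main())
```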