@@ -36,7 +36,7 @@
 from tqdm import tqdm
 
 import transformers
-from transformers import BitsAndBytesConfig, GenerationConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase
 from transformers.utils.import_utils import (
     is_fastapi_available,
     is_librosa_available,
@@ -823,9 +823,9 @@ def continuous_batching_chat_completion(self, req: dict, request_id: str) -> Str
             self.running_continuous_batching_manager.start()
 
         # TODO (Joao, Lysandre): this should also work with tool support
-        inputs = processor.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )["input_ids"][0]
+        inputs = processor.apply_chat_template(
+            req["messages"], return_tensors="pt", add_generation_prompt=True, return_dict=True
+        ).to(model.device)["input_ids"][0]
 
         def stream_chat_completion(request_id, decode_stream):
             from ..generation.continuous_batching import RequestStatus
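
Note on the hunk above: for a plain tokenizer, `apply_chat_template(..., return_tensors="pt")` returns the token-id tensor itself, so the trailing `["input_ids"]` subscript would raise; `return_dict=True` makes it return a `BatchEncoding` mapping, matching what multimodal processors produce. A minimal sketch of the two return shapes (the checkpoint name is only an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative checkpoint
    messages = [{"role": "user", "content": "Hello!"}]

    ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    # `ids` is the (1, seq_len) tensor itself; ids["input_ids"] would fail

    enc = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    # `enc` is a BatchEncoding, so enc["input_ids"][0] yields the 1-D prompt ids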
@@ -841,8 +841,13 @@ def stream_chat_completion(request_id, decode_stream):
 
                 if result.status == RequestStatus.FINISHED:
                     generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens
-                    final_token_is_eos = result == tokenizer.eos_token
-                    reason = "length" if (generated_all_tokens and not final_token_is_eos) else "stop"
+
+                    # If the tokenizer has an eos_token, we can have a more robust check.
+                    if hasattr(tokenizer, "eos_token"):
+                        final_token_is_eos = result == tokenizer.eos_token
+                        generated_all_tokens = generated_all_tokens and not final_token_is_eos
+
+                    reason = "length" if generated_all_tokens else "stop"
 
                     yield self.build_chat_completion_chunk(
                         request_id,
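
The old code read `tokenizer.eos_token` unconditionally, which fails for data processors that do not expose one; the `hasattr` guard keeps the `finish_reason` logic working either way ("length" when the token budget was exhausted, "stop" otherwise, matching the OpenAI-style chunk schema). A standalone paraphrase of the decision, not the server's actual helper:

    def finish_reason(last_text: str, n_generated: int, max_new_tokens: int, eos_token: str | None) -> str:
        hit_budget = n_generated >= max_new_tokens
        if eos_token is not None:
            # Ending on EOS means the model stopped on its own, even if it
            # also exhausted its budget, so report "stop" rather than "length".
            hit_budget = hit_budget and last_text != eos_token
        return "length" if hit_budget else "stop"

    assert finish_reason("</s>", 128, 128, "</s>") == "stop"
    assert finish_reason("world", 128, 128, "</s>") == "length"
    assert finish_reason("world", 42, 128, None) == "stop"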
@@ -921,7 +926,11 @@ def cancellation_wrapper_buffer(_request_id):
         return JSONResponse(json_chunk, media_type="application/json")
 
     @staticmethod
-    def get_model_modality(model: "PreTrainedModel") -> Modality:
+    def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality:
+        if processor is not None:
+            if isinstance(processor, PreTrainedTokenizerBase):
+                return Modality.LLM
+
         from transformers.models.auto.modeling_auto import (
             MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
             MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
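
A bare tokenizer implies text-only serving, so the method can short-circuit to `Modality.LLM` before consulting the auto-model mappings. `PreTrainedTokenizerBase` is the common ancestor of slow and fast tokenizers, so the `isinstance` check covers both; a small illustration (the checkpoint name is again only an example):

    from transformers import AutoTokenizer, PreTrainedTokenizerBase

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    assert isinstance(tokenizer, PreTrainedTokenizerBase)  # holds for slow and fast tokenizers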
@@ -1011,7 +1020,7 @@ def generate_chat_completion(self, req: dict) -> StreamingResponse | JSONRespons
         self.last_model = model_id_and_revision
         model, processor = self.load_model_and_processor(model_id_and_revision)
 
-        modality = self.get_model_modality(model)
+        modality = self.get_model_modality(model, processor=processor)
         processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality)
 
         # ====== TOOL PREPROCESSING LOGIC ======
@@ -1184,8 +1193,14 @@ def generate_with_cache(**kwargs):
             )
 
             generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens
-            final_token_is_eos = result == streamer.tokenizer.eos_token
-            reason = "length" if (generated_all_tokens and not final_token_is_eos) else "stop"
+
+            # If the tokenizer has an eos_token, we can have a more robust check.
+            if hasattr(streamer.tokenizer, "eos_token"):
+                final_token_is_eos = result == streamer.tokenizer.eos_token
+                generated_all_tokens = generated_all_tokens and not final_token_is_eos
+
+            reason = "length" if generated_all_tokens else "stop"
+
             yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision)
 
             thread.join()
@@ -1272,7 +1287,9 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
         else:
             raise TypeError("inputs should be a list, dict, or str")
 
-        inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")["input_ids"]
+        inputs = processor.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
+        )["input_ids"]
         inputs = inputs.to(model.device)
         request_id = req.get("previous_response_id", "req_0")
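
Same `return_dict=True` fix, applied to the Responses endpoint; here the ids are selected first and moved to the device afterwards, which is equivalent to the chat-completion hunk's `.to(model.device)["input_ids"]` order because `BatchEncoding.to()` moves every tensor it contains. A tiny check of that equivalence, with "cpu" standing in for `model.device`:

    import torch
    from transformers import BatchEncoding

    enc = BatchEncoding({"input_ids": torch.tensor([[1, 2, 3]])})
    device = "cpu"  # stand-in for model.device
    assert torch.equal(enc.to(device)["input_ids"], enc["input_ids"].to(device))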
@@ -1576,7 +1593,9 @@ def generate_response_non_streaming(self, req: dict) -> dict:
         else:
             raise ValueError("inputs should be a list, dict, or str")
 
-        inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")["input_ids"]
+        inputs = processor.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
+        )["input_ids"]
         inputs = inputs.to(model.device)
         request_id = req.get("previous_response_id", "req_0")
@@ -1775,11 +1794,22 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
         else:
             model_id, revision = model_id_and_revision, "main"
 
-        data_processor = AutoProcessor.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=self.trust_remote_code,
-        )
+        try:
+            data_processor = AutoProcessor.from_pretrained(
+                model_id,
+                revision=revision,
+                trust_remote_code=self.trust_remote_code,
+            )
+        except OSError:
+            try:
+                data_processor = AutoTokenizer.from_pretrained(
+                    model_id,
+                    revision=revision,
+                    trust_remote_code=self.trust_remote_code,
+                )
+            except OSError:
+                raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.")
+
         dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
         quantization_config = self.get_quantization_config()
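
The fallback prefers the full `AutoProcessor` (which may bundle image/audio components) and only then tries a plain tokenizer for text-only repos. A standalone sketch of the same pattern (`trust_remote_code` omitted for brevity); chaining with `from exc`, which this hunk does not do, would additionally keep the original cause in the traceback:

    from transformers import AutoProcessor, AutoTokenizer

    def load_data_processor(model_id: str, revision: str = "main"):
        try:
            return AutoProcessor.from_pretrained(model_id, revision=revision)
        except OSError:
            try:
                return AutoTokenizer.from_pretrained(model_id, revision=revision)
            except OSError as exc:
                # `from exc` preserves the underlying loading error as the cause.
                raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.") from exc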