
Commit 52b988d

Transformers serve fix (#42570)
* Fix: lacking EOS token + failing AutoProcessor
* Tests
* Tests
1 parent 3f17410 commit 52b988d

File tree

2 files changed: +59 −20 lines changed

src/transformers/cli/serve.py

Lines changed: 47 additions & 17 deletions
@@ -36,7 +36,7 @@
 from tqdm import tqdm

 import transformers
-from transformers import BitsAndBytesConfig, GenerationConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase
 from transformers.utils.import_utils import (
     is_fastapi_available,
     is_librosa_available,
@@ -823,9 +823,9 @@ def continuous_batching_chat_completion(self, req: dict, request_id: str) -> Str
         self.running_continuous_batching_manager.start()

         # TODO (Joao, Lysandre): this should also work with tool support
-        inputs = processor.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )["input_ids"][0]
+        inputs = processor.apply_chat_template(
+            req["messages"], return_tensors="pt", add_generation_prompt=True, return_dict=True
+        ).to(model.device)["input_ids"][0]

         def stream_chat_completion(request_id, decode_stream):
             from ..generation.continuous_batching import RequestStatus
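
For context, a minimal sketch of what `return_dict=True` changes here (not part of the commit; the checkpoint name is only illustrative, reused from the tests): `apply_chat_template` then returns a dict-like encoding, so the `["input_ids"]` indexing works the same whether `processor` is a plain tokenizer or a multimodal processor.

# Hedged sketch: tokenize a chat prompt and pull out the prompt token ids.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [{"role": "user", "content": "Hello, how are you?"}]
encoded = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
)
input_ids = encoded["input_ids"][0]  # 1D tensor of prompt token ids
print(input_ids.shape)
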
@@ -841,8 +841,13 @@ def stream_chat_completion(request_id, decode_stream):

                 if result.status == RequestStatus.FINISHED:
                     generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens
-                    final_token_is_eos = result == tokenizer.eos_token
-                    reason = "length" if (generated_all_tokens and not final_token_is_eos) else "stop"
+
+                    # If the tokenizer has an eos_token, we can have a more robust check.
+                    if hasattr(tokenizer, "eos_token"):
+                        final_token_is_eos = result == tokenizer.eos_token
+                        generated_all_tokens = generated_all_tokens and not final_token_is_eos
+
+                    reason = "length" if generated_all_tokens else "stop"

                     yield self.build_chat_completion_chunk(
                         request_id,
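
The finish_reason logic above, pulled out as a standalone sketch (the helper name and arguments are hypothetical, not from serve.py): a stream is reported as "length" only when the token budget was exhausted and, if the tokenizer defines an eos_token, the final piece of text is not that EOS token; everything else is "stop".

# Hedged sketch of the finish_reason decision.
def pick_finish_reason(n_tokens_generated, max_new_tokens, last_text, tokenizer):
    generated_all_tokens = n_tokens_generated >= max_new_tokens
    # Only trust the EOS comparison when the tokenizer actually defines an eos_token.
    if hasattr(tokenizer, "eos_token"):
        final_token_is_eos = last_text == tokenizer.eos_token
        generated_all_tokens = generated_all_tokens and not final_token_is_eos
    return "length" if generated_all_tokens else "stop"
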
@@ -921,7 +926,11 @@ def cancellation_wrapper_buffer(_request_id):
             return JSONResponse(json_chunk, media_type="application/json")

     @staticmethod
-    def get_model_modality(model: "PreTrainedModel") -> Modality:
+    def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality:
+        if processor is not None:
+            if isinstance(processor, PreTrainedTokenizerBase):
+                return Modality.LLM
+
         from transformers.models.auto.modeling_auto import (
             MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
             MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
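
A small illustration of the new early return (checkpoint name reused from the tests, not prescribed by the commit): when the loaded data processor is a bare tokenizer, `get_model_modality` can classify the model as text-only without consulting the auto-model mappings.

# Hedged sketch: a tokenizer is a PreTrainedTokenizerBase, so serve treats the model as Modality.LLM.
from transformers import AutoTokenizer, PreTrainedTokenizerBase

processor = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
print(isinstance(processor, PreTrainedTokenizerBase))  # True -> Modality.LLM
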
@@ -1011,7 +1020,7 @@ def generate_chat_completion(self, req: dict) -> StreamingResponse | JSONRespons
         self.last_model = model_id_and_revision
         model, processor = self.load_model_and_processor(model_id_and_revision)

-        modality = self.get_model_modality(model)
+        modality = self.get_model_modality(model, processor=processor)
         processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality)

         # ====== TOOL PREPROCESSING LOGIC ======
@@ -1184,8 +1193,14 @@ def generate_with_cache(**kwargs):
             )

             generated_all_tokens = n_tokens_generated >= generation_config.max_new_tokens
-            final_token_is_eos = result == streamer.tokenizer.eos_token
-            reason = "length" if (generated_all_tokens and not final_token_is_eos) else "stop"
+
+            # If the tokenizer has an eos_token, we can have a more robust check.
+            if hasattr(streamer.tokenizer, "eos_token"):
+                final_token_is_eos = result == streamer.tokenizer.eos_token
+                generated_all_tokens = generated_all_tokens and not final_token_is_eos
+
+            reason = "length" if generated_all_tokens else "stop"
+
             yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision)

             thread.join()
@@ -1272,7 +1287,9 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
         else:
             raise TypeError("inputs should be a list, dict, or str")

-        inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")["input_ids"]
+        inputs = processor.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
+        )["input_ids"]
         inputs = inputs.to(model.device)
         request_id = req.get("previous_response_id", "req_0")

@@ -1576,7 +1593,9 @@ def generate_response_non_streaming(self, req: dict) -> dict:
         else:
             raise ValueError("inputs should be a list, dict, or str")

-        inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")["input_ids"]
+        inputs = processor.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
+        )["input_ids"]
         inputs = inputs.to(model.device)
         request_id = req.get("previous_response_id", "req_0")

@@ -1775,11 +1794,22 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
         else:
             model_id, revision = model_id_and_revision, "main"

-        data_processor = AutoProcessor.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=self.trust_remote_code,
-        )
+        try:
+            data_processor = AutoProcessor.from_pretrained(
+                model_id,
+                revision=revision,
+                trust_remote_code=self.trust_remote_code,
+            )
+        except OSError:
+            try:
+                data_processor = AutoTokenizer.from_pretrained(
+                    model_id,
+                    revision=revision,
+                    trust_remote_code=self.trust_remote_code,
+                )
+            except OSError:
+                raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.")
+
         dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
         quantization_config = self.get_quantization_config()

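
The loader change above, reduced to a standalone sketch (illustrative only; the real code also threads `revision` and `trust_remote_code` through both calls): try `AutoProcessor` first and fall back to `AutoTokenizer` for repos where no processor can be resolved.

# Hedged sketch of the AutoProcessor -> AutoTokenizer fallback.
from transformers import AutoProcessor, AutoTokenizer

def load_data_processor(model_id: str):
    try:
        return AutoProcessor.from_pretrained(model_id)
    except OSError:
        try:
            return AutoTokenizer.from_pretrained(model_id)
        except OSError:
            raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.")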

tests/cli/test_serve.py

Lines changed: 12 additions & 3 deletions
@@ -329,7 +329,7 @@ def test_generation_config_in_request(self):

     def test_early_return_due_to_length(self):
         request = {
-            "model": "Qwen/Qwen3-0.6B",
+            "model": "Qwen/Qwen2.5-0.5B-Instruct",
             "messages": [{"role": "user", "content": "Hello, how are you?"}],
             "stream": True,
             "max_tokens": 3,
@@ -339,8 +339,17 @@ def test_early_return_due_to_length(self):
         last_payload = all_payloads[-1]
         self.assertTrue(last_payload.choices[0]["finish_reason"] == "length")

-    # TODO: one test for each request flag, to confirm it is working as expected
-    # TODO: speed-based test to confirm that KV cache is working across requests
+    def test_continues_until_stop(self):
+        request = {
+            "model": "Qwen/Qwen2.5-0.5B-Instruct",
+            "messages": [{"role": "user", "content": 'Please only answer with "Hi."'}],
+            "stream": True,
+            "max_tokens": 30,
+        }
+
+        all_payloads = self.run_server(request)
+        last_payload = all_payloads[-1]
+        self.assertTrue(last_payload.choices[0]["finish_reason"] == "stop")


 class ServeCompletionsGenerateMockTests(unittest.TestCase):
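
A hedged usage sketch of what the new test exercises end to end, assuming a locally running `transformers serve` instance; the port and the OpenAI-compatible chat completions route below are assumptions, not taken from this diff.

# Hedged sketch: replay the new test's request against a local server (endpoint/port assumed).
import requests

payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "messages": [{"role": "user", "content": 'Please only answer with "Hi."'}],
    "stream": False,  # non-streaming keeps the sketch simple
    "max_tokens": 30,
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["finish_reason"])  # expected: "stop"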
