Skip to content

Commit ba2b5ff

Browse files
aarnphm and eicherseiji
authored and committed
[Fix] [gpt-oss] fix non-tool calling path for chat completion (vllm-project#24324)
1 parent 2b6e1d2 commit ba2b5ff

File tree

2 files changed

+83
-38
lines changed

2 files changed

+83
-38
lines changed

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,41 @@ def monkeypatch_module():
3636
mpatch.undo()
3737

3838

39+
@pytest.fixture(scope="module",
40+
params=[True, False],
41+
ids=["with_tool_parser", "without_tool_parser"])
42+
def with_tool_parser(request) -> bool:
43+
return request.param
44+
45+
3946
@pytest.fixture(scope="module")
40-
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
41-
with monkeypatch_module.context() as m:
42-
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
43-
args = [
44-
"--enforce-eager",
45-
"--max-model-len",
46-
"8192",
47+
def default_server_args(with_tool_parser: bool):
48+
args = [
49+
# use half precision for speed and memory savings in CI environment
50+
"--enforce-eager",
51+
"--max-model-len",
52+
"4096",
53+
"--reasoning-parser",
54+
"openai_gptoss",
55+
"--gpu-memory-utilization",
56+
"0.8",
57+
]
58+
if with_tool_parser:
59+
args.extend([
4760
"--tool-call-parser",
4861
"openai",
49-
"--reasoning-parser",
50-
"openai_gptoss",
5162
"--enable-auto-tool-choice",
52-
]
53-
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
63+
])
64+
return args
65+
66+
67+
@pytest.fixture(scope="module")
68+
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
69+
default_server_args: list[str]):
70+
with monkeypatch_module.context() as m:
71+
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
72+
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
73+
default_server_args) as remote_server:
5474
yield remote_server
5575

5676

@@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
6181

6282

6383
@pytest.mark.asyncio
64-
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
84+
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
85+
with_tool_parser: bool):
6586
tools = [{
6687
"type": "function",
6788
"function": {
@@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
94115
]
95116

96117
stream = await gptoss_client.chat.completions.create(
97-
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
118+
model=GPT_OSS_MODEL_NAME,
119+
messages=messages,
120+
tools=tools if with_tool_parser else None,
121+
stream=True)
98122

99123
name = None
100124
args_buf = ""
125+
content_buf = ""
101126
async for chunk in stream:
102127
delta = chunk.choices[0].delta
103128
if delta.tool_calls:
@@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
106131
name = tc.function.name
107132
if tc.function and tc.function.arguments:
108133
args_buf += tc.function.arguments
109-
110-
assert name is not None
111-
assert len(args_buf) > 0
134+
if getattr(delta, "content", None):
135+
content_buf += delta.content
136+
if with_tool_parser:
137+
assert name is not None
138+
assert len(args_buf) > 0
139+
else:
140+
assert name is None
141+
assert len(args_buf) == 0
142+
assert len(content_buf) > 0
112143

113144

114145
@pytest.mark.asyncio
115-
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
146+
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
147+
with_tool_parser: bool):
148+
if not with_tool_parser:
149+
pytest.skip("skip non-tool for multi-turn tests")
116150
tools = [{
117151
"type": "function",
118152
"function": {
@@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
175209
)
176210
second_msg = second.choices[0].message
177211
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
178-
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
212+
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
179213

180214

181215
MODEL_NAME = "openai-community/gpt2"

vllm/entrypoints/openai/serving_chat.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import time
77
from collections.abc import AsyncGenerator, AsyncIterator
88
from collections.abc import Sequence as GenericSequence
9-
from typing import TYPE_CHECKING, Callable, Final, Optional, Union
9+
from typing import Callable, Final, Optional, Union
1010

1111
import jinja2
1212
import partial_json_parser
@@ -1174,6 +1174,7 @@ async def chat_completion_full_generator(
11741174
for output in final_res.outputs:
11751175
token_ids = output.token_ids
11761176
out_logprobs = output.logprobs
1177+
tool_call_info = None
11771178

11781179
if request.logprobs and request.top_logprobs is not None:
11791180
assert out_logprobs is not None, "Did not output logprobs"
@@ -1188,32 +1189,42 @@ async def chat_completion_full_generator(
11881189
logprobs = None
11891190

11901191
if self.use_harmony:
1191-
if TYPE_CHECKING:
1192-
assert self.tool_parser is not None
1193-
tool_parser = self.tool_parser(tokenizer)
1194-
# NOTE: We use token_ids for openai tool parser
1195-
tool_call_info = tool_parser.extract_tool_calls(
1196-
"",
1197-
request=request,
1198-
token_ids=token_ids, # type: ignore
1199-
)
1200-
reasoning_content, content = None, tool_call_info.content
1201-
if request.include_reasoning:
1192+
if self.tool_parser is not None:
1193+
tool_parser = self.tool_parser(tokenizer)
1194+
# NOTE: We use token_ids for openai tool parser
1195+
tool_call_info = tool_parser.extract_tool_calls(
1196+
"",
1197+
request=request,
1198+
token_ids=token_ids, # type: ignore
1199+
)
1200+
reasoning_content, content = None, tool_call_info.content
1201+
if request.include_reasoning:
1202+
reasoning_content, content, _ = parse_chat_output(
1203+
token_ids)
1204+
message = ChatMessage(
1205+
role=role,
1206+
reasoning_content=reasoning_content,
1207+
content=content,
1208+
tool_calls=tool_call_info.tool_calls,
1209+
)
1210+
else:
12021211
reasoning_content, content, _ = parse_chat_output(
12031212
token_ids)
1204-
message = ChatMessage(
1205-
role=role,
1206-
reasoning_content=reasoning_content,
1207-
content=content,
1208-
tool_calls=tool_call_info.tool_calls,
1209-
)
1213+
if not request.include_reasoning:
1214+
reasoning_content = None
1215+
message = ChatMessage(
1216+
role=role,
1217+
reasoning_content=reasoning_content,
1218+
content=content,
1219+
)
12101220

12111221
choice_data = ChatCompletionResponseChoice(
12121222
index=output.index,
12131223
message=message,
12141224
logprobs=logprobs,
1215-
finish_reason="tool_calls"
1216-
if tool_call_info.tools_called else
1225+
finish_reason="tool_calls" if
1226+
(tool_call_info is not None
1227+
and tool_call_info.tools_called) else
12171228
output.finish_reason if output.finish_reason else "stop",
12181229
stop_reason=output.stop_reason,
12191230
)

0 commit comments

Comments
 (0)