[BugFix] fix: skip check unstreamed tool arg tokens when tool call name is present #27806
Conversation
Code Review
This pull request correctly addresses a bug in tool call streaming. The fix ensures that the logic for handling unstreamed tool arguments is only triggered when the tool call name is not present in the streamed delta. This is achieved by adding a check for function.name in _should_check_for_unstreamed_tool_arg_tokens. The accompanying docstring update clearly explains the new behavior. The change is well-targeted and improves the robustness of tool call parsing. I find no issues with the implementation.
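To make the reviewed guard concrete, here is a minimal sketch of the condition being described, written as a free function for brevity; it omits the serving-side checks (auto-tool enablement, parser presence), and the exact attribute chain in vLLM may differ:

```python
# Minimal sketch of the guard described above; not the exact vLLM implementation.
# `delta_message` is assumed to be a DeltaMessage-like object and `output` a
# CompletionOutput-like object with a `finish_reason` attribute.
def _should_check_for_unstreamed_tool_arg_tokens(delta_message, output) -> bool:
    if output.finish_reason is None or not delta_message or not delta_message.tool_calls:
        return False
    function = delta_message.tool_calls[0].function
    if function is None or function.arguments is None:
        return False
    # The fix: if the delta already carries the tool call *name*, the parser
    # emitted a complete tool call, so skip the unstreamed-argument handling,
    # which exists only for incrementally streamed arguments.
    return function.name is None
```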
Force-pushed from c95da58 to 74af711. Commit: [BugFix] fix: skip check unstreamed tool arg tokens when tool call name is present (Signed-off-by: Sungjae Lee <[email protected]>)
@K-Mistele @bbrowning
@llsj14 Do you have an example or a test so I can see the before/after difference here? The code change is fairly small and seems reasonable on the surface. How can I reproduce the behavior that exhibited this bug before, to confirm it is now fixed?
I used our in-house custom reasoning/tool-calling parsers, so sharing an open-source example was difficult. The open-source parsers likely to show the same issue have other bugs, making a clean demonstration hard. (Before/after screenshots, labeled "as-is" and "to-be", were attached as images and are not reproduced here.)
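Since the in-house parser itself can't be shared, here is a hedged illustration of the two delta shapes this discussion distinguishes, shown as plain dicts in the style of the tool_call dict used in the test below (the real parsers emit DeltaMessage objects):

```python
# "as-is" trigger case: the parser emits the whole tool call in a single delta,
# name and arguments together. With this PR, the end-of-stream unstreamed-argument
# handling is skipped for such deltas because the name is present.
complete_call_delta = {
    "index": 0,
    "type": "function",
    "function": {"name": "foo", "arguments": '{"x": 1}'},
}

# The incremental case the original logic was written for: argument fragments
# only, with no name, streamed across several deltas and later reconciled
# against the parser's prev_tool_call_arr / streamed_args_for_tool state.
argument_fragment_delta = {
    "index": 0,
    "function": {"arguments": '{"x"'},
}
```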
Thanks for that additional detail. Because I don't have your tool call parser to test locally, I tested this with the following unit test addition:
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index d1367b4ee..b56d53b2f 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
+import json
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any
@@ -10,10 +11,12 @@ import pytest
import pytest_asyncio
from openai import OpenAI
+from vllm import CompletionOutput, RequestOutput
from vllm.config.multimodal import MultiModalConfig
-from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
@@ -366,7 +369,7 @@ class MockModelConfig:
return self.diff_sampling_param or {}
-def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
+def _build_serving_chat(engine: AsyncLLM, **kwargs) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
@@ -378,6 +381,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None,
+ **kwargs,
)
async def _fake_process_inputs(
@@ -651,3 +655,91 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_streaming_full_tool_calls():
+ """Test that we can stream back full tool calls at once from tool call parsers."""
+ mock_model_config = MockModelConfig()
+
+ mock_engine = MagicMock(spec=AsyncLLM)
+ mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+ mock_engine.errored = False
+ mock_engine.model_config = mock_model_config
+ mock_engine.processor = MagicMock()
+ mock_engine.io_processor = MagicMock()
+
+ async def mock_generator():
+ yield RequestOutput(
+ request_id="req-123",
+ prompt="",
+ prompt_token_ids=[],
+ prompt_logprobs=None,
+ outputs=[
+ CompletionOutput(
+ index=0,
+ text="foo",
+ token_ids=[1, 2, 3],
+ cumulative_logprob=None,
+ logprobs=None,
+ finish_reason="tool_calls",
+ ),
+ ],
+ finished=True,
+ )
+
+ mock_engine.generate.return_value = mock_generator()
+
+ tool_call = {
+ "id": "tool_123",
+ "index": 0,
+ "type": "function",
+ "function": {
+ "name": "foo",
+ "arguments": "{}",
+ },
+ }
+
+ # Mock tool parser that returns the full tool call at once when streaming
+ class MockToolParser(ToolParser):
+ def extract_tool_calls_streaming(self, *args, **kwargs):
+ tool_call_arr = [tool_call]
+ self.prev_tool_call_arr = tool_call_arr
+ self.streamed_args_for_tool = ["{}"]
+ return DeltaMessage(
+ role="assistant",
+ content=None,
+ tool_calls=tool_call_arr,
+ )
+
+ tool_parser = lambda _: MockToolParser(None)
+
+ serving_chat = _build_serving_chat(
+ mock_engine, enable_auto_tools=True, tool_parser="llama3_json"
+ )
+ # Override the real tool parser with our mock one
+ serving_chat.tool_parser = tool_parser
+
+ req = ChatCompletionRequest(
+ model=MODEL_NAME,
+ messages=[{"role": "user", "content": "what is 1+1?"}],
+ tools=[
+ {
+ "type": "function",
+ "function": {"name": "foo", "description": "foo", "parameters": {}},
+ }
+ ],
+ tool_choice="auto",
+ stream=True,
+ )
+
+ response = await serving_chat.create_chat_completion(req)
+ chunks = [chunk async for chunk in response]
+
+ assert len(chunks) > 1
+ tool_call_chunk = json.loads(chunks[1].removeprefix("data: "))
+ tool_call_choices = tool_call_chunk["choices"]
+ assert tool_call_choices
+ delta_tool_calls = tool_call_choices[0]["delta"]["tool_calls"]
+ assert delta_tool_calls
+ assert delta_tool_calls[0] == tool_call
That added test fails before your change and passes after your change. Does my … I haven't run full integration tests locally with your change, and I don't have the power to trigger CI to test with your change to ensure it doesn't have any other unintended consequences. Nor can I approve it. But it gets a thumbs up from me!
Purpose
The _should_check_for_unstreamed_tool_arg_tokens function should verify that the current delta doesn't contain a name before entering this branch, since this logic is specifically designed to handle cases where only arguments are being streamed incrementally, not complete tool calls with both name and arguments.
vllm/vllm/entrypoints/openai/serving_chat.py, lines 1157 to 1166 in af826e0
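For context, here is a rough, hypothetical sketch of the end-of-stream backfill that this guard controls, using the prev_tool_call_arr and streamed_args_for_tool parser state exercised in the test above; the helper name and exact logic are illustrative, not the code at the cited lines:

```python
import json
from typing import Optional


# Hypothetical helper, for illustration only: approximates the backfill the
# guard controls. If the parser streamed only argument fragments, compute the
# argument text that was never sent so it can be emitted in a final delta.
# When a delta already contained name + arguments (a complete tool call),
# this path should not run, which is what this PR enforces.
def remaining_tool_call_args(tool_parser, index: int) -> Optional[str]:
    expected = json.dumps(
        tool_parser.prev_tool_call_arr[index].get("arguments", {}),
        ensure_ascii=False,
    )
    streamed = tool_parser.streamed_args_for_tool[index]
    if expected.startswith(streamed) and len(expected) > len(streamed):
        return expected[len(streamed):]  # argument text not yet streamed
    return None
```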
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.