
Conversation

llsj14 commented Oct 30, 2025

Purpose

  • In most cases, tool call parsers support incremental streaming responses.
  • However, if a tool call parser doesn't support incremental tool call streaming, the function name can be omitted from the OpenAI chat completion API response.
  • Therefore, _should_check_for_unstreamed_tool_arg_tokens should verify that the current delta doesn't already contain a name before entering this branch, since this logic is specifically designed for the case where only arguments are streamed incrementally, not for complete tool calls that carry both the name and the arguments (a rough sketch of the resulting guard follows the code excerpt below).
  • If the tool call parser doesn't support incremental streaming responses, the following code replaces the delta with an arguments-only tool call, omitting the function name:
    delta_message = DeltaMessage(
        tool_calls=[
            DeltaToolCall(
                index=index,
                function=DeltaFunctionCall(
                    arguments=remaining_call
                ).model_dump(exclude_none=True),
            )
        ]
    )
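
For reference, a rough sketch of what the updated guard could look like; the surrounding conditions are paraphrased (the exact chain lives in serving_chat.py), and the final name check is the new part:

    def _should_check_for_unstreamed_tool_arg_tokens(self, delta_message, output) -> bool:
        return bool(
            # Only fall back to flushing leftover argument tokens when the
            # final delta carries arguments but no function name; a delta
            # that already has the name is a complete tool call and must
            # be passed through untouched.
            output.finish_reason is not None
            and self.tool_parser
            and delta_message
            and delta_message.tool_calls
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
            and not delta_message.tool_calls[0].function.name
        )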

Test Plan

Test Result

gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a bug in tool call streaming. The fix ensures that the logic for handling unstreamed tool arguments is only triggered when the tool call name is not present in the streamed delta. This is achieved by adding a check for function.name in _should_check_for_unstreamed_tool_arg_tokens. The accompanying docstring update clearly explains the new behavior. The change is well-targeted and improves the robustness of tool call parsing. I find no issues with the implementation.

llsj14 force-pushed the fix/stream-tool-response branch from c95da58 to 74af711 on October 30, 2025 at 10:16

llsj14 commented Oct 30, 2025

@K-Mistele @bbrowning
Could you also review this PR? Although not all tool call parsers require this fix, it resolves function name omission issues in custom parsers that lack incremental streaming support.

bbrowning commented:

@llsj14 Do you have an example or a test so I can see the before/after difference here? The code change is fairly small and seems reasonable on the surface. How can I reproduce the behavior that exhibited this bug, so I can compare it before and after the fix?

llsj14 commented Oct 31, 2025

I used our in-house custom reasoning/tool calling parsers, so sharing an open-source example was difficult. An open-source parser that is likely to show the same issue has other bugs of its own, making a clean demonstration hard.
However, when a tool calling parser doesn’t support incremental stream responses and sends a complete tool call at once, the function name can be omitted as shown below.

as-is)
"tool_calls":[{"index":0,"function":{"arguments":"{\"location\": \"Seoul, South Korea\", \"format\": \"celsius\"}"}}]

to-be)
"tool_calls":[{"index":0,"function":{"name":"get_current_weather","arguments":"{\"location\": \"Seoul, South Korea\", \"format\": \"celsius\"}"}}]

I added "and not delta_message.tool_calls[0].function.name" to _should_check_for_unstreamed_tool_arg_tokens so that this branch only runs when the name isn't already in the delta. This prevents name omission when parsers send complete tool calls at once.
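
To make the scenario concrete, here is a minimal, hypothetical sketch of such a parser; the class name and the weather arguments are illustrative only, and the real extract_tool_calls_streaming signature is abbreviated with *args/**kwargs:

    from vllm.entrypoints.openai.protocol import (
        DeltaFunctionCall,
        DeltaMessage,
        DeltaToolCall,
    )
    from vllm.entrypoints.openai.tool_parsers import ToolParser

    class OneShotToolParser(ToolParser):
        """Hypothetical parser that emits the whole tool call in one delta."""

        def extract_tool_calls_streaming(self, *args, **kwargs):
            # The complete call, name and arguments together, in a single
            # delta. Without the fix, the final-chunk fallback would replace
            # this with an arguments-only delta and drop the name.
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=0,
                        function=DeltaFunctionCall(
                            name="get_current_weather",
                            arguments='{"location": "Seoul, South Korea", "format": "celsius"}',
                        ).model_dump(exclude_none=True),
                    )
                ]
            )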

bbrowning commented:

Thanks for that additional detail. Because I don't have your tool call parser to test locally, I tested this with the following unit test addition:

diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index d1367b4ee..b56d53b2f 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
+import json
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import Any
@@ -10,10 +11,12 @@ import pytest
 import pytest_asyncio
 from openai import OpenAI
 
+from vllm import CompletionOutput, RequestOutput
 from vllm.config.multimodal import MultiModalConfig
-from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
+from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.v1.engine.async_llm import AsyncLLM
 
@@ -366,7 +369,7 @@ class MockModelConfig:
         return self.diff_sampling_param or {}
 
 
-def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
+def _build_serving_chat(engine: AsyncLLM, **kwargs) -> OpenAIServingChat:
     models = OpenAIServingModels(
         engine_client=engine,
         base_model_paths=BASE_MODEL_PATHS,
@@ -378,6 +381,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         chat_template=CHAT_TEMPLATE,
         chat_template_content_format="auto",
         request_logger=None,
+        **kwargs,
     )
 
     async def _fake_process_inputs(
@@ -651,3 +655,91 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
         await serving_chat.create_chat_completion(req)
     engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
     assert engine_prompt.get("cache_salt") == "test_salt"
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_streaming_full_tool_calls():
+    """Test that we can stream back full tool calls at once from tool call parsers."""
+    mock_model_config = MockModelConfig()
+
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+    mock_engine.model_config = mock_model_config
+    mock_engine.processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+
+    async def mock_generator():
+        yield RequestOutput(
+            request_id="req-123",
+            prompt="",
+            prompt_token_ids=[],
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text="foo",
+                    token_ids=[1, 2, 3],
+                    cumulative_logprob=None,
+                    logprobs=None,
+                    finish_reason="tool_calls",
+                ),
+            ],
+            finished=True,
+        )
+
+    mock_engine.generate.return_value = mock_generator()
+
+    tool_call = {
+        "id": "tool_123",
+        "index": 0,
+        "type": "function",
+        "function": {
+            "name": "foo",
+            "arguments": "{}",
+        },
+    }
+
+    # Mock tool parser that returns the full tool call at once when streaming
+    class MockToolParser(ToolParser):
+        def extract_tool_calls_streaming(self, *args, **kwargs):
+            tool_call_arr = [tool_call]
+            self.prev_tool_call_arr = tool_call_arr
+            self.streamed_args_for_tool = ["{}"]
+            return DeltaMessage(
+                role="assistant",
+                content=None,
+                tool_calls=tool_call_arr,
+            )
+
+    tool_parser = lambda _: MockToolParser(None)
+
+    serving_chat = _build_serving_chat(
+        mock_engine, enable_auto_tools=True, tool_parser="llama3_json"
+    )
+    # Override the real tool parser with our mock one
+    serving_chat.tool_parser = tool_parser
+
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+        tools=[
+            {
+                "type": "function",
+                "function": {"name": "foo", "description": "foo", "parameters": {}},
+            }
+        ],
+        tool_choice="auto",
+        stream=True,
+    )
+
+    response = await serving_chat.create_chat_completion(req)
+    chunks = [chunk async for chunk in response]
+
+    assert len(chunks) > 1
+    tool_call_chunk = json.loads(chunks[1].removeprefix("data: "))
+    tool_call_choices = tool_call_chunk["choices"]
+    assert tool_call_choices
+    delta_tool_calls = tool_call_choices[0]["delta"]["tool_calls"]
+    assert delta_tool_calls
+    assert delta_tool_calls[0] == tool_call

That added test fails before your change and passes after your change. Does my MockToolParser there look roughly equivalent to what you're describing? A streaming tool call parser that only returns entire tool calls at once as opposed to streaming the arguments out incrementally? If so, you're welcome to use that code or similar to add a test for this in the PR, as without some kind of test we can't prevent accidentally regressing on this behavior in the future.

I haven't run full integration tests locally with your change, and I don't have the power to trigger CI to make sure it doesn't have any other unintended consequences. Nor can I approve it. But it gets a thumbs up from me!
