diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py new file mode 100644 index 000000000000..5e6a4c85ff79 --- /dev/null +++ b/tests/entrypoints/test_context.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest +from openai_harmony import StreamState + +from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext +from vllm.outputs import CompletionOutput, RequestOutput + + +# Helper function for Python < 3.10 compatibility +async def async_next(async_iterator): + """Compatibility function equivalent to Python 3.10's anext().""" + return await async_iterator.__anext__() + + +def create_mock_request_output( + prompt_token_ids=None, + output_token_ids=None, + num_cached_tokens=0, + finished=True, +): + """Helper function to create a mock RequestOutput object for testing.""" + outputs = [] + token_ids = output_token_ids if output_token_ids is not None else [] + outputs = [ + CompletionOutput( + index=0, + text="Test output", + token_ids=token_ids, + cumulative_logprob=0.0, + logprobs=None, + finish_reason=None, + stop_reason=None, + ) + ] + + return RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + outputs=outputs, + finished=finished, + num_cached_tokens=num_cached_tokens, + ) + + +async def generate_mock_outputs(num_turns, + prompt_token_counts, + output_token_counts, + cached_token_counts=None): + """Generate a sequence of mock RequestOutput objects to simulate multiple + turns.""" + if cached_token_counts is None: + cached_token_counts = [0] * num_turns + + for i in range(num_turns): + # Create mock prompt token IDs and output token IDs + prompt_token_ids = list(range(1, prompt_token_counts[i] + 1)) + output_token_ids = list(range(1, output_token_counts[i] + 1)) + + # Create and yield the RequestOutput + yield create_mock_request_output( + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + num_cached_tokens=cached_token_counts[i], + ) + + +@pytest.fixture +def mock_parser(): + """Set up a mock parser for tests.""" + with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant" + ) as mock_parser_factory: + # Create a mock parser object + parser = MagicMock() + parser.messages = [] + parser.current_channel = None + parser.state = StreamState.EXPECT_START + mock_parser_factory.return_value = parser + yield parser + + +def test_single_turn_token_counting(): + """Test token counting behavior for a single turn.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a mock RequestOutput with specific token counts + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3, 4, 5], # 5 prompt tokens + output_token_ids=[6, 7, 8], # 3 output tokens + num_cached_tokens=2, # 2 cached tokens + ) + + # Append the output to the context + context.append_output(mock_output) + + # Verify the token counts + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_cached_tokens == 2 + assert context.num_tool_output_tokens == 0 # No tool tokens in first turn + + # Verify internal state tracking + assert not context.is_first_turn + assert context.previous_turn.input_tokens == 5 + assert context.previous_turn.output_tokens == 3 + + +@pytest.mark.asyncio +async def test_multi_turn_token_counting(): + """Test token counting behavior across multiple turns with tool output.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate a conversation with 3 turns + # Turn 1: prefill 5, decode 3, tool 7 + # Turn 2: prefill 15, cached 5, decode 4, tool 1 + # Turn 3: prefill 20, cached 15, decode 5 + prompt_token_counts = [5, 15, 20] + output_token_counts = [3, 4, 5] + cached_token_counts = [0, 5, 15] + mock_generator = generate_mock_outputs(3, prompt_token_counts, + output_token_counts, + cached_token_counts) + + # First turn - initial prompt and response + mock_output1 = await async_next(mock_generator) + context.append_output(mock_output1) + + # At this point, we should have 5 prompt tokens and 3 output tokens + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_tool_output_tokens == 0 + + # Second turn - after tool output + mock_output2 = await async_next(mock_generator) + context.append_output(mock_output2) + # Current prompt tokens (15) - last_turn_input_tokens (5) - + # last_turn_output_tokens (3) = 7 + expected_tool_output = 7 + + assert context.num_prompt_tokens == 5 + 15 + assert context.num_output_tokens == 3 + 4 + assert context.num_tool_output_tokens == expected_tool_output + assert context.num_cached_tokens == 5 + + # Third turn - final response + mock_output3 = await async_next(mock_generator) + context.append_output(mock_output3) + # Additional tool output tokens from third turn: + # Current prompt (20) - last_turn_input_tokens (15) - + # last_turn_output_tokens (4) = 1 + expected_tool_output = 7 + 1 + + assert context.num_prompt_tokens == 5 + 15 + 20 + assert context.num_output_tokens == 3 + 4 + 5 + assert context.num_tool_output_tokens == expected_tool_output + assert context.num_cached_tokens == 5 + 15 + + +def test_empty_output_tokens(): + """Test behavior when RequestOutput has empty output tokens.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a RequestOutput with empty output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[], # Empty output tokens list + num_cached_tokens=1, + ) + + context.append_output(mock_output) + + # Should handle empty outputs gracefully + assert context.num_prompt_tokens == 3 + assert context.num_output_tokens == 0 # No output tokens + assert context.num_cached_tokens == 1 + assert context.num_tool_output_tokens == 0 + + +def test_missing_prompt_token_ids(): + """Test behavior when RequestOutput has None prompt_token_ids.""" + context = HarmonyContext(messages=[], available_tools=[]) + + mock_output = create_mock_request_output( + prompt_token_ids=None, # No prompt token IDs + output_token_ids=[1, 2], # 2 output tokens + num_cached_tokens=0, + ) + + # Logger.error will be called, but we don't need to check for warnings + # here Just ensure it doesn't raise an exception + context.append_output(mock_output) + + # Should handle missing prompt tokens gracefully + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 2 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + + +def test_reasoning_tokens_counting(mock_parser): + """Test that reasoning tokens are counted correctly.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Mock parser to simulate reasoning channel + mock_parser.current_channel = "analysis" # Reasoning channel + + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], + output_token_ids=[4, 5, 6, 7], # 4 tokens, all in reasoning + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All output tokens should be counted as reasoning + assert context.num_reasoning_tokens == 4 + assert context.num_output_tokens == 4 + + +def test_zero_tokens_edge_case(): + """Test behavior with all zero token counts.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a request with empty lists (not None) for both prompt and + # output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[], # Empty prompt tokens + output_token_ids=[], # Empty output tokens + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All counts should be zero + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 0 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + assert context.num_reasoning_tokens == 0 + + +@pytest.mark.asyncio +async def test_single_turn_no_tool_output(): + """Test that first turn never generates tool output tokens.""" + context = HarmonyContext( + messages=[], + available_tools=["browser"] # Tools available + ) + + # Even with large prompt in first turn, no tool tokens should be counted + mock_output = create_mock_request_output( + prompt_token_ids=list(range(100)), # 100 tokens + output_token_ids=[1, 2, 3], + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # First turn should never have tool output tokens + assert context.num_tool_output_tokens == 0 + assert context.is_first_turn is False # Should be updated after first turn + + +@pytest.mark.asyncio +async def test_negative_tool_tokens_edge_case(): + """Test edge case where calculation could result in negative tool + tokens. We should log an error and clamp the value to 0.""" + # Use patch to check if logger.error was called + with patch("vllm.entrypoints.context.logger.error") as mock_log: + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # First turn + mock_output1 = create_mock_request_output( + prompt_token_ids=list(range(10)), # 10 tokens + output_token_ids=[1, 2, 3, 4, 5], # 5 tokens + ) + context.append_output(mock_output1) + + # Second turn with fewer new tokens than previous output + # This could happen in edge cases with aggressive caching + mock_output2 = create_mock_request_output( + prompt_token_ids=list(range(12)), # 12 tokens (only 2 new) + output_token_ids=[6, 7], # 2 tokens + ) + context.append_output(mock_output2) + + # Calculated negative tool tokens (12 - 10 - 5 = -3) should be clamped + # to 0 and an error should be logged + assert context.num_tool_output_tokens == 0 + assert context.num_prompt_tokens == 10 + 12 + assert context.num_output_tokens == 5 + 2 + + # Verify the error was logged properly + mock_log.assert_called_once() + + # Extract the actual log message and arguments from the call + args, _ = mock_log.call_args + log_message = args[0] + + # Check for key parts of the message + assert "Negative tool output tokens calculated" in log_message + assert "-3" in str(args) # Check that -3 is in the arguments + + +@pytest.mark.asyncio +async def test_streaming_multi_turn_token_counting(mock_parser): + """Test token counting for streaming multi-turn conversations. + + This test focuses on how StreamingHarmonyContext counts tokens in a + multi-turn conversation with streaming (token-by-token) outputs and + message boundaries. + """ + # Create a streaming context + context = StreamingHarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate three turns of conversation: + # Turn 1: stream tokens one by one, then finish the message + # Turn 2: new prompt, stream more tokens with a reasoning segment + # Turn 3: new prompt with tool output and cached tokens + + # First turn: 3 tokens streamed one by one + # First token of first turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[101], # Single token + num_cached_tokens=0, + finished=False, # Not end of message yet + )) + + # Second token of first turn + context.append_output( + create_mock_request_output( + output_token_ids=[102], + finished=False, + )) + + # Last token of first turn (finished=True signals end of message) + context.append_output( + create_mock_request_output( + output_token_ids=[103], + finished=True, # End of message + )) + + # Check token counts after first turn + assert context.num_prompt_tokens == 3 # Initial prompt tokens + assert context.num_output_tokens == 3 # Three output tokens + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 # No tool output in first turn + assert context.first_tok_of_message is True # Ready for next message + + # Second turn: reasoning tokens in analysis channel + mock_parser.current_channel = "analysis" # Set to reasoning channel + + # First token of second turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3, 101, 102, 103, 4, + 5], # 8 tokens (includes previous) + output_token_ids=[201], + num_cached_tokens=3, # Some tokens cached + finished=False, + )) + + # More tokens in reasoning channel + context.append_output( + create_mock_request_output( + output_token_ids=[202], + finished=False, + )) + + context.append_output( + create_mock_request_output( + output_token_ids=[203], + finished=True, # End of reasoning message + )) + + # Check counts after second turn (reasoning message) + assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt + assert context.num_output_tokens == 3 + 3 # First turn + second turn + assert context.num_reasoning_tokens == 3 # All tokens in analysis channel + assert context.num_cached_tokens == 3 # Cached tokens from second turn + + # Formula: this turn prompt tokens - last turn prompt - last turn output + expected_tool_tokens = 8 - 3 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens + + # Third turn: regular output channel + mock_parser.current_channel = "final" # Switch back to regular channel + + # Third turn (with more cached tokens) + context.append_output( + create_mock_request_output( + prompt_token_ids=[ + 1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7 + ], # 13 tokens + output_token_ids=[301], + num_cached_tokens=8, # More cached tokens + finished=False, + )) + + context.append_output( + create_mock_request_output( + output_token_ids=[302], + finished=True, + )) + + # Final token counts check + assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts + assert context.num_output_tokens == 3 + 3 + 2 # All outputs + assert context.num_reasoning_tokens == 3 # Unchanged from second turn + assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens + + # Additional tool tokens from third turn + # Formula: this turn prompt - last turn prompt - last turn output + additional_tool_tokens = 13 - 8 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens \ + + additional_tool_tokens diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index e4f2e800f94a..7723c5d5cbcf 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -3,7 +3,6 @@ import json import logging from abc import ABC, abstractmethod -from collections.abc import Sequence from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Optional, Union @@ -21,6 +20,23 @@ logger = logging.getLogger(__name__) +class TurnTokens: + """Tracks token counts for a single conversation turn.""" + + def __init__(self, input_tokens=0, output_tokens=0): + self.input_tokens = input_tokens + self.output_tokens = output_tokens + + def reset(self): + """Reset counters for a new turn.""" + self.input_tokens = 0 + self.output_tokens = 0 + + def copy(self): + """Create a copy of this turn's token counts.""" + return TurnTokens(self.input_tokens, self.output_tokens) + + class ConversationContext(ABC): @abstractmethod @@ -92,52 +108,124 @@ def __init__( self.num_init_messages = len(messages) self.num_prompt_tokens = 0 self.num_output_tokens = 0 - # TODO(woosuk): Implement the following fields. self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + self.num_tool_output_tokens = 0 - def _update_num_prompt_tokens(self, output: RequestOutput): - if output.prompt_token_ids and len(output.prompt_token_ids) > 0: - # NOTE: with built-in tools, there might be multiple rounds in - # the conversation, with the full conversation being resent - # as new prompt each time. Hence the sum. - self.num_prompt_tokens += len(output.prompt_token_ids) - - def _update_num_cached_tokens(self, output: RequestOutput): - if output.num_cached_tokens is not None: - #Similar to num_prompt_tokens - self.num_cached_tokens += output.num_cached_tokens - - def _update_num_output_tokens(self, token_ids: Sequence[int]): - self.num_output_tokens += len(token_ids) + # Turn tracking - replaces multiple individual tracking variables + self.current_turn = TurnTokens() + self.previous_turn = TurnTokens() + self.is_first_turn = True + self.first_tok_of_message = True # For streaming support - def _update_num_reasoning_tokens(self, token_ids: Sequence[int]): - # Count tokens that are part of reasoning content (analysis channel - # or tool-directed messages like python/browser calls) - is_analysis = self.parser.current_channel == "analysis" - is_tool_call = (self.parser.current_recipient is not None and - (self.parser.current_recipient.startswith("python") or - self.parser.current_recipient.startswith("browser."))) - if is_analysis or is_tool_call: - self.num_reasoning_tokens += len(token_ids) + def _update_num_reasoning_tokens(self): + # Count all analysis and commentary channels as reasoning tokens + if self.parser.current_channel in {"analysis", "commentary"}: + self.num_reasoning_tokens += 1 def append_output(self, output) -> None: if isinstance(output, RequestOutput): - self._update_num_prompt_tokens(output) - self._update_num_cached_tokens(output) output_token_ids = output.outputs[0].token_ids - self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) # Check if the current token is part of reasoning content - self._update_num_reasoning_tokens([token_id]) + self._update_num_reasoning_tokens() + self._update_prefill_token_usage(output) + # Reset current turn output tokens for this turn + self.current_turn.output_tokens = 0 + self._update_decode_token_usage(output) + # Move current turn to previous turn for next turn's calculations + self.previous_turn = self.current_turn.copy() output_msgs = self.parser.messages else: # Tool output. output_msgs = output self._messages.extend(output_msgs) + def _update_prefill_token_usage(self, output: RequestOutput) -> None: + """Update token usage statistics for the prefill phase of generation. + + The prefill phase processes the input prompt tokens. This method: + 1. Counts the prompt tokens for this turn + 2. Calculates tool output tokens for multi-turn conversations + 3. Updates cached token counts + 4. Tracks state for next turn calculations + + Tool output tokens are calculated as: + current_prompt_tokens - last_turn_prompt_tokens - + last_turn_output_tokens + This represents tokens added between turns (typically tool responses). + + Args: + output: The RequestOutput containing prompt token information + """ + if output.prompt_token_ids is not None: + this_turn_input_tokens = len(output.prompt_token_ids) + else: + this_turn_input_tokens = 0 + logger.error( + "RequestOutput appended contains no prompt_token_ids.") + + # Update current turn input tokens + self.current_turn.input_tokens = this_turn_input_tokens + self.num_prompt_tokens += this_turn_input_tokens + + # Calculate tool tokens (except on first turn) + if self.is_first_turn: + self.is_first_turn = False + else: + # start counting tool after first turn + # tool tokens = this turn prefill - last turn prefill - + # last turn decode + this_turn_tool_tokens = (self.current_turn.input_tokens - + self.previous_turn.input_tokens - + self.previous_turn.output_tokens) + + # Handle negative tool token counts (shouldn't happen in normal + # cases) + if this_turn_tool_tokens < 0: + logger.error( + "Negative tool output tokens calculated: %d " + "(current_input=%d, previous_input=%d, " + "previous_output=%d). Setting to 0.", + this_turn_tool_tokens, self.current_turn.input_tokens, + self.previous_turn.input_tokens, + self.previous_turn.output_tokens) + this_turn_tool_tokens = 0 + + self.num_tool_output_tokens += this_turn_tool_tokens + + # Update cached tokens + if output.num_cached_tokens is not None: + self.num_cached_tokens += output.num_cached_tokens + + def _update_decode_token_usage(self, output: RequestOutput) -> int: + """Update token usage statistics for the decode phase of generation. + + The decode phase processes the generated output tokens. This method: + 1. Counts output tokens from all completion outputs + 2. Updates the total output token count + 3. Tracks tokens generated in the current turn + + In streaming mode, this is called for each token generated. + In non-streaming mode, this is called once with all output tokens. + + Args: + output: The RequestOutput containing generated token information + + Returns: + int: Number of output tokens processed in this call + """ + updated_output_token_count = 0 + if output.outputs: + for completion_output in output.outputs: + # only keep last round + updated_output_token_count += len(completion_output.token_ids) + self.num_output_tokens += updated_output_token_count + self.current_turn.output_tokens += updated_output_token_count + return updated_output_token_count + @property def messages(self) -> list: return self._messages @@ -231,8 +319,8 @@ def append_output(self, output) -> None: # append_output is called for each output token in streaming case, # so we only want to add the prompt tokens once for each message. if self.first_tok_of_message: - self._update_num_prompt_tokens(output) - self._update_num_cached_tokens(output) + self._update_prefill_token_usage(output) + self.current_turn.output_tokens = 0 # Reset self.first_tok_of_message if needed: # if the current token is the last one of the current message # (finished=True), then the next token processed will mark the @@ -240,9 +328,13 @@ def append_output(self, output) -> None: self.first_tok_of_message = output.finished for tok in output.outputs[0].token_ids: self.parser.process(tok) - self._update_num_output_tokens(output.outputs[0].token_ids) + self._update_decode_token_usage(output) + + # For streaming, update previous turn when message is complete + if output.finished: + self.previous_turn = self.current_turn.copy() # Check if the current token is part of reasoning content - self._update_num_reasoning_tokens(output.outputs[0].token_ids) + self._update_num_reasoning_tokens() self.last_tok = tok else: # Handle the case of tool output in direct message format diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 413e1dd8d633..c56c68cf7644 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1841,7 +1841,8 @@ class InputTokensDetails(OpenAIBaseModel): class OutputTokensDetails(OpenAIBaseModel): - reasoning_tokens: int + reasoning_tokens: int = 0 + tool_output_tokens: int = 0 class ResponseUsage(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index d49724b0439c..a102d4a4a5e6 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -460,7 +460,7 @@ async def responses_full_generator( if self.use_harmony: assert isinstance(context, HarmonyContext) output = self._make_response_output_items_with_harmony(context) - # TODO: these are all 0 for now! + num_tool_output_tokens = context.num_tool_output_tokens else: assert isinstance(context, SimpleContext) final_res = context.last_output @@ -473,6 +473,8 @@ async def responses_full_generator( # Calculate usage. assert final_res.prompt_token_ids is not None + num_tool_output_tokens = 0 + assert isinstance(context, (SimpleContext, HarmonyContext)) num_prompt_tokens = context.num_prompt_tokens num_generated_tokens = context.num_output_tokens @@ -486,7 +488,8 @@ async def responses_full_generator( input_tokens_details=InputTokensDetails( cached_tokens=num_cached_tokens), output_tokens_details=OutputTokensDetails( - reasoning_tokens=num_reasoning_tokens), + reasoning_tokens=num_reasoning_tokens, + tool_output_tokens=num_tool_output_tokens), ) response = ResponsesResponse.from_request( request,