1717from common .token_utils import num_tokens_from_string , total_token_count_from_response , truncate , encoder
1818import pytest
1919
20+
2021class TestNumTokensFromString :
2122 """Test cases for num_tokens_from_string function"""
2223
@@ -111,8 +112,6 @@ def test_consistency():
111112 assert first_result > 0
112113
113114
114- from unittest .mock import Mock
115-
116115class TestTotalTokenCountFromResponse :
117116 """Test cases for total_token_count_from_response function"""
118117
@@ -153,35 +152,6 @@ def test_dict_with_meta_tokens_input_output(self):
153152 result = total_token_count_from_response (resp_dict )
154153 assert result == 120
155154
156- def test_priority_order_usage_total_tokens_first (self ):
157- """Test that resp.usage.total_tokens takes priority over other formats"""
158- # Create a response that matches multiple conditions
159- mock_usage = Mock ()
160- mock_usage .total_tokens = 300
161-
162- mock_usage_metadata = Mock ()
163- mock_usage_metadata .total_tokens = 400
164-
165- mock_resp = Mock ()
166- mock_resp .usage = mock_usage
167- mock_resp .usage_metadata = mock_usage_metadata
168-
169- result = total_token_count_from_response (mock_resp )
170- assert result == 300 # Should use the first matching condition
171-
172- def test_priority_order_usage_metadata_second (self ):
173- """Test that resp.usage_metadata.total_tokens is second in priority"""
174- # Create a response without resp.usage but with resp.usage_metadata
175- mock_usage_metadata = Mock ()
176- mock_usage_metadata .total_tokens = 250
177-
178- mock_resp = Mock ()
179- delattr (mock_resp , 'usage' ) # Ensure no usage attribute
180- mock_resp .usage_metadata = mock_usage_metadata
181-
182- result = total_token_count_from_response (mock_resp )
183- assert result == 250
184-
185155 def test_priority_order_dict_usage_total_tokens_third (self ):
186156 """Test that dict['usage']['total_tokens'] is third in priority"""
187157 resp_dict = {
@@ -279,34 +249,6 @@ def test_invalid_response_type(self):
279249 # assert result == 0
280250
281251
282- # Parameterized tests for different response formats
283- @pytest .mark .parametrize ("response_data,expected_tokens" , [
284- # Object with usage.total_tokens
285- ({"usage" : Mock (total_tokens = 150 )}, 150 ),
286- # Dict with usage.total_tokens
287- ({"usage" : {"total_tokens" : 175 }}, 175 ),
288- # Dict with usage.input_tokens + output_tokens
289- ({"usage" : {"input_tokens" : 100 , "output_tokens" : 50 }}, 150 ),
290- # Dict with meta.tokens.input_tokens + output_tokens
291- ({"meta" : {"tokens" : {"input_tokens" : 80 , "output_tokens" : 40 }}}, 120 ),
292- # Empty dict
293- ({}, 0 ),
294- ])
295- def test_various_response_formats (response_data , expected_tokens ):
296- """Test various response formats using parameterized tests"""
297- if isinstance (response_data , dict ) and not any (isinstance (v , Mock ) for v in response_data .values ()):
298- # Regular dictionary
299- resp = response_data
300- else :
301- # Mock object
302- resp = Mock ()
303- for key , value in response_data .items ():
304- setattr (resp , key , value )
305-
306- result = total_token_count_from_response (resp )
307- assert result == expected_tokens
308-
309-
310252class TestTruncate :
311253 """Test cases for truncate function"""
312254
@@ -428,4 +370,4 @@ def test_numbers_and_symbols(self):
428370 max_len = 4
429371
430372 result = truncate (number_string , max_len )
431- assert len (encoder .encode (result )) == max_len
373+ assert len (encoder .encode (result )) == max_len
0 commit comments