1+ import unittest
2+
3+ from vllm_ascend .distributed .mooncake .config_data import _parse_global_segment_size , _convert_to_bytes
4+
5+
6+ class TestParseGlobalSegmentSize (unittest .TestCase ):
7+ def test_int_input (self ):
8+ self .assertEqual (_parse_global_segment_size (1024 ), 1024 )
9+ self .assertEqual (_parse_global_segment_size (0 ), 0 )
10+
11+ def test_gb_unit (self ):
12+ self .assertEqual (_parse_global_segment_size ("2GB" ), 2 * 1024 ** 3 )
13+ self .assertEqual (_parse_global_segment_size ("1.5GB" ), int (1.5 * 1024 ** 3 ))
14+ self .assertEqual (_parse_global_segment_size (" 2 GB " ), 2 * 1024 ** 3 )
15+
16+ def test_gb_unit_edge_cases (self ):
17+ with self .assertRaises (ValueError ):
18+ _parse_global_segment_size ("GB" )
19+ with self .assertRaises (ValueError ):
20+ _parse_global_segment_size ("abcGB" )
21+
22+ def test_mb_unit (self ):
23+ self .assertEqual (_parse_global_segment_size ("512MB" ), 512 * 1024 ** 2 )
24+ self .assertEqual (_parse_global_segment_size ("0.5MB" ), int (0.5 * 1024 ** 2 ))
25+ self .assertEqual (_parse_global_segment_size ("1024MB" ), 1024 * 1024 ** 2 )
26+
27+ def test_kb_unit (self ):
28+ self .assertEqual (_parse_global_segment_size ("256KB" ), 256 * 1024 )
29+ self .assertEqual (_parse_global_segment_size ("1.25KB" ), int (1.25 * 1024 ))
30+
31+ def test_b_unit (self ):
32+ self .assertEqual (_parse_global_segment_size ("4096B" ), 4096 )
33+ self .assertEqual (_parse_global_segment_size ("1024b" ), 1024 )
34+
35+ def test_no_unit (self ):
36+ self .assertEqual (_parse_global_segment_size ("2048" ), 2048 )
37+ self .assertEqual (_parse_global_segment_size ("0" ), 0 )
38+
39+ def test_non_string_non_int_input (self ):
40+ self .assertEqual (_parse_global_segment_size (2048.0 ), 2048 )
41+ self .assertEqual (_parse_global_segment_size (True ), 1 )
42+
43+ with self .assertRaises (TypeError ):
44+ _parse_global_segment_size (None )
45+
46+ with self .assertRaises (TypeError ):
47+ _parse_global_segment_size ({"size" : 1024 })
48+
49+
50+ class TestConvertToBytes (unittest .TestCase ):
51+ def test_valid_conversion (self ):
52+ self .assertEqual (_convert_to_bytes ("10" , 1 , "10" ), 10 )
53+ self .assertEqual (_convert_to_bytes ("1.5" , 1024 , "1.5KB" ), int (1.5 * 1024 ))
54+ self .assertEqual (_convert_to_bytes ("0" , 1024 ** 3 , "0GB" ), 0 )
55+
56+ def test_invalid_numbers (self ):
57+ with self .assertRaises (ValueError ):
58+ _convert_to_bytes ("abc" , 1 , "abc" )
59+
60+ with self .assertRaises (ValueError ):
61+ _convert_to_bytes ("1.2.3" , 1024 , "1.2.3KB" )
62+
63+
64+ if __name__ == '__main__' :
65+ unittest .main ()
0 commit comments