diff --git a/docs/source/user_guide/feature_guide/kv_pool_mooncake.md b/docs/source/user_guide/feature_guide/kv_pool_mooncake.md index a9c61cd1ad5..34ab0479079 100644 --- a/docs/source/user_guide/feature_guide/kv_pool_mooncake.md +++ b/docs/source/user_guide/feature_guide/kv_pool_mooncake.md @@ -5,7 +5,7 @@ * Software: * Python >= 3.9, < 3.12 * CANN >= 8.3.rc1 - * PyTorch == 2.7.1, torch-npu == 2.7.1 + * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM:main branch * vLLM-Ascend:main branch * Mooncake:main branch @@ -41,7 +41,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path "use_ascend_direct": true, "alloc_in_same_node": true, "master_server_address": "xx.xx.xx.xx:50088", - "global_segment_size": 30000000000 + "global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824) } ``` diff --git a/tests/ut/distributed/mooncake/test_config_data.py b/tests/ut/distributed/mooncake/test_config_data.py new file mode 100644 index 00000000000..4408b41a825 --- /dev/null +++ b/tests/ut/distributed/mooncake/test_config_data.py @@ -0,0 +1,68 @@ +import unittest + +from vllm_ascend.distributed.mooncake.config_data import ( + _convert_to_bytes, _parse_global_segment_size) + + +class TestParseGlobalSegmentSize(unittest.TestCase): + + def test_int_input(self): + self.assertEqual(_parse_global_segment_size(1024), 1024) + self.assertEqual(_parse_global_segment_size(0), 0) + + def test_gb_unit(self): + self.assertEqual(_parse_global_segment_size("2GB"), 2 * 1024**3) + self.assertEqual(_parse_global_segment_size("1.5GB"), + int(1.5 * 1024**3)) + self.assertEqual(_parse_global_segment_size(" 2 GB "), 2 * 1024**3) + + def test_gb_unit_edge_cases(self): + with self.assertRaises(ValueError): + _parse_global_segment_size("GB") + with self.assertRaises(ValueError): + _parse_global_segment_size("abcGB") + + def test_mb_unit(self): + self.assertEqual(_parse_global_segment_size("512MB"), 512 * 1024**2) + self.assertEqual(_parse_global_segment_size("0.5MB"), + int(0.5 * 1024**2)) + self.assertEqual(_parse_global_segment_size("1024MB"), 1024 * 1024**2) + + def test_kb_unit(self): + self.assertEqual(_parse_global_segment_size("256KB"), 256 * 1024) + self.assertEqual(_parse_global_segment_size("1.25KB"), + int(1.25 * 1024)) + + def test_b_unit(self): + self.assertEqual(_parse_global_segment_size("4096B"), 4096) + self.assertEqual(_parse_global_segment_size("1024b"), 1024) + + def test_no_unit(self): + self.assertEqual(_parse_global_segment_size("2048"), 2048) + self.assertEqual(_parse_global_segment_size("0"), 0) + + def test_non_string_non_int_input(self): + self.assertEqual(_parse_global_segment_size(2048.0), 2048) + self.assertEqual(_parse_global_segment_size(True), 1) + + with self.assertRaises(TypeError): + _parse_global_segment_size(None) + + with self.assertRaises(TypeError): + _parse_global_segment_size({"size": 1024}) + + +class TestConvertToBytes(unittest.TestCase): + + def test_valid_conversion(self): + self.assertEqual(_convert_to_bytes("10", 1, "10"), 10) + self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"), + int(1.5 * 1024)) + self.assertEqual(_convert_to_bytes("0", 1024**3, "0GB"), 0) + + def test_invalid_numbers(self): + with self.assertRaises(ValueError): + _convert_to_bytes("abc", 1, "abc") + + with self.assertRaises(ValueError): + _convert_to_bytes("1.2.3", 1024, "1.2.3KB") diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 745d91131fa..36c820b0890 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -2,6 +2,7 @@ import hashlib import json import os +import re from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple, Union @@ -11,6 +12,9 @@ from vllm.utils import cdiv, logger from vllm.v1.core.sched.output import NewRequestData +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + @dataclass class MooncakeEngineMetadata: @@ -419,7 +423,7 @@ class LasyerMultiBlockReqMeta: class MooncakeStoreConfig: local_hostname: str metadata_server: str - global_segment_size: int + global_segment_size: Union[int, str] local_buffer_size: int protocol: str device_name: str @@ -433,8 +437,11 @@ def from_file(file_path: str) -> "MooncakeStoreConfig": return MooncakeStoreConfig( local_hostname=config.get("local_hostname"), metadata_server=config.get("metadata_server"), - global_segment_size=config.get("global_segment_size", 3355443200), - local_buffer_size=config.get("local_buffer_size", 1073741824), + global_segment_size=_parse_global_segment_size( + config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE)), + local_buffer_size=(config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE)), protocol=config.get("protocol", "tcp"), device_name=config.get("device_name", ""), master_server_address=config.get("master_server_address"), @@ -446,4 +453,81 @@ def load_from_env() -> "MooncakeStoreConfig": if not config_path: raise ValueError( "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeStoreConfig.from_file(config_path) \ No newline at end of file + return MooncakeStoreConfig.from_file(config_path) + + +def _parse_global_segment_size(value) -> int: + """ + Parse storage size strings with support for units: GB, MB, KB, B + + Args: + value: Input value (int, str, or other convertible types) + + Returns: + int: Size in bytes + + Raises: + ValueError: For invalid format, missing number, or negative values + TypeError: For unsupported input types + """ + + if isinstance(value, int): + return value + elif not isinstance(value, str): + try: + return int(value) + except (TypeError, ValueError) as e: + raise TypeError( + f"Unsupported type for global_segment_size: {type(value)}" + ) from e + + cleaned_input = value.strip().lower() + if not cleaned_input: + raise ValueError("global segment size cannot be empty.") + + UNIT_MULTIPLIERS = { + 'gb': 1024**3, # 1 GB = 1024^3 bytes + 'mb': 1024**2, # 1 MB = 1024^2 bytes + 'kb': 1024, # 1 KB = 1024 bytes + 'b': 1 # 1 B = 1 byte + } + pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' + match = re.match(pattern, cleaned_input) + + if not match: + raise ValueError(f"Invalid format: '{value}'") + + number_str = match.group(1) + unit = match.group(2) or 'b' + + multiplier = UNIT_MULTIPLIERS[unit] + return _convert_to_bytes(number_str, multiplier, value) + + +def _convert_to_bytes(number_str: str, multiplier: int, + original_input: str) -> int: + """ + Convert numeric string to byte count + + Args: + number_str: Numeric portion of input + multiplier: Unit conversion factor + original_input: Original input string (for error messages) + + Returns: + int: Byte count + + Raises: + ValueError: For invalid numbers or negative results + """ + try: + numeric_value = float(number_str) + except ValueError: + raise ValueError( + f"Invalid numeric value '{number_str}' in: '{original_input}'") + # Calculate byte count + try: + byte_count = int(numeric_value * multiplier) + except OverflowError: + raise ValueError(f"Storage size too large: '{original_input}'") + return byte_count