Skip to content

Commit e8e0851

Browse files
committed
Supports multiple input suffixes for global_segment_size
1 parent 292e213 commit e8e0851

File tree

3 files changed

+154
-5
lines changed

3 files changed

+154
-5
lines changed

examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
2424
"protocol": "ascend",
2525
"device_name": "",
2626
"master_server_address": "xx.xx.xx.xx:50088",
27-
"global_segment_size": 30000000000
27+
"global_segment_size": "1GB/1024MB/1048576KB/1073741824B" or 1073741824
2828
}
2929
```
3030

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import unittest
2+
3+
from vllm_ascend.distributed.mooncake.config_data import (
4+
_parse_global_segment_size, _convert_to_bytes)
5+
6+
7+
class TestParseGlobalSegmentSize(unittest.TestCase):
8+
def test_int_input(self):
9+
self.assertEqual(_parse_global_segment_size(1024), 1024)
10+
self.assertEqual(_parse_global_segment_size(0), 0)
11+
12+
def test_gb_unit(self):
13+
self.assertEqual(_parse_global_segment_size("2GB"), 2 * 1024**3)
14+
self.assertEqual(_parse_global_segment_size("1.5GB"), int(1.5 * 1024**3))
15+
self.assertEqual(_parse_global_segment_size(" 2 GB "), 2 * 1024**3)
16+
17+
def test_gb_unit_edge_cases(self):
18+
with self.assertRaises(ValueError):
19+
_parse_global_segment_size("GB")
20+
with self.assertRaises(ValueError):
21+
_parse_global_segment_size("abcGB")
22+
23+
def test_mb_unit(self):
24+
self.assertEqual(_parse_global_segment_size("512MB"), 512 * 1024**2)
25+
self.assertEqual(_parse_global_segment_size("0.5MB"), int(0.5 * 1024**2))
26+
self.assertEqual(_parse_global_segment_size("1024MB"), 1024 * 1024**2)
27+
28+
def test_kb_unit(self):
29+
self.assertEqual(_parse_global_segment_size("256KB"), 256 * 1024)
30+
self.assertEqual(_parse_global_segment_size("1.25KB"), int(1.25 * 1024))
31+
32+
def test_b_unit(self):
33+
self.assertEqual(_parse_global_segment_size("4096B"), 4096)
34+
self.assertEqual(_parse_global_segment_size("1024b"), 1024)
35+
36+
def test_no_unit(self):
37+
self.assertEqual(_parse_global_segment_size("2048"), 2048)
38+
self.assertEqual(_parse_global_segment_size("0"), 0)
39+
40+
def test_non_string_non_int_input(self):
41+
self.assertEqual(_parse_global_segment_size(2048.0), 2048)
42+
self.assertEqual(_parse_global_segment_size(True), 1)
43+
44+
with self.assertRaises(TypeError):
45+
_parse_global_segment_size(None)
46+
47+
with self.assertRaises(TypeError):
48+
_parse_global_segment_size({"size": 1024})
49+
50+
51+
class TestConvertToBytes(unittest.TestCase):
52+
def test_valid_conversion(self):
53+
self.assertEqual(_convert_to_bytes("10", 1, "10"), 10)
54+
self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"), int(1.5 * 1024))
55+
self.assertEqual(_convert_to_bytes("0", 1024**3, "0GB"), 0)
56+
57+
def test_invalid_numbers(self):
58+
with self.assertRaises(ValueError):
59+
_convert_to_bytes("abc", 1, "abc")
60+
61+
with self.assertRaises(ValueError):
62+
_convert_to_bytes("1.2.3", 1024, "1.2.3KB")
63+
64+
65+
if __name__ == '__main__':
66+
unittest.main()

vllm_ascend/distributed/mooncake/config_data.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import hashlib
33
import json
44
import os
5+
from typing import Union
56
from dataclasses import dataclass
67
from typing import Iterable, List, Optional, Tuple, Union
78

@@ -11,6 +12,8 @@
1112
from vllm.utils import cdiv, logger
1213
from vllm.v1.core.sched.output import NewRequestData
1314

15+
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
16+
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
1417

1518
@dataclass
1619
class MooncakeEngineMetadata:
@@ -419,7 +422,7 @@ class LasyerMultiBlockReqMeta:
419422
class MooncakeStoreConfig:
420423
local_hostname: str
421424
metadata_server: str
422-
global_segment_size: int
425+
global_segment_size: Union[int, str]
423426
local_buffer_size: int
424427
protocol: str
425428
device_name: str
@@ -433,8 +436,10 @@ def from_file(file_path: str) -> "MooncakeStoreConfig":
433436
return MooncakeStoreConfig(
434437
local_hostname=config.get("local_hostname"),
435438
metadata_server=config.get("metadata_server"),
436-
global_segment_size=config.get("global_segment_size", 3355443200),
437-
local_buffer_size=config.get("local_buffer_size", 1073741824),
439+
global_segment_size=_parse_global_segment_size(
440+
config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
441+
),
442+
local_buffer_size=config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE),
438443
protocol=config.get("protocol", "tcp"),
439444
device_name=config.get("device_name", ""),
440445
master_server_address=config.get("master_server_address"),
@@ -446,4 +451,82 @@ def load_from_env() -> "MooncakeStoreConfig":
446451
if not config_path:
447452
raise ValueError(
448453
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
449-
return MooncakeStoreConfig.from_file(config_path)
454+
return MooncakeStoreConfig.from_file(config_path)
455+
456+
457+
def _parse_global_segment_size(value) -> int:
458+
"""
459+
Parse storage size strings with support for units: GB, MB, KB, B
460+
461+
Args:
462+
value: Input value (int, str, or other convertible types)
463+
464+
Returns:
465+
int: Size in bytes
466+
467+
Raises:
468+
ValueError: For invalid format, missing number, or negative values
469+
TypeError: For unsupported input types
470+
"""
471+
472+
if isinstance(value, int):
473+
return value
474+
elif not isinstance(value, str):
475+
try:
476+
return int(value)
477+
except (TypeError, ValueError) as e:
478+
raise TypeError(
479+
f"Unsupported type for global_segment_size: {type(value)}"
480+
) from e
481+
# Clean input string
482+
cleaned_input = value.strip().lower()
483+
if not cleaned_input:
484+
raise ValueError("global segment size cannot be empty.")
485+
486+
UNIT_MULTIPLIERS = {
487+
'gb': 1024 ** 3, # 1 GB = 1024^3 bytes
488+
'mb': 1024 ** 2, # 1 MB = 1024^2 bytes
489+
'kb': 1024, # 1 KB = 1024 bytes
490+
'b': 1 # 1 B = 1 byte
491+
}
492+
# Check for unit suffixes
493+
for unit, multiplier in UNIT_MULTIPLIERS.items():
494+
if cleaned_input.endswith(unit):
495+
number_part = cleaned_input[:-len(unit)].strip()
496+
if not number_part:
497+
raise ValueError(
498+
f"Missing numeric value before unit '{unit}' in: '{value}'"
499+
)
500+
return _convert_to_bytes(number_part, multiplier, value)
501+
# Handle unit-less input (bytes)
502+
return _convert_to_bytes(cleaned_input, 1, value)
503+
504+
505+
def _convert_to_bytes(number_str: str, multiplier: int,
506+
original_input: str) -> int:
507+
"""
508+
Convert numeric string to byte count
509+
510+
Args:
511+
number_str: Numeric portion of input
512+
multiplier: Unit conversion factor
513+
original_input: Original input string (for error messages)
514+
515+
Returns:
516+
int: Byte count
517+
518+
Raises:
519+
ValueError: For invalid numbers or negative results
520+
"""
521+
try:
522+
numeric_value = float(number_str)
523+
except ValueError:
524+
raise ValueError(
525+
f"Invalid numeric value '{number_str}' in: '{original_input}'"
526+
)
527+
# Calculate byte count
528+
try:
529+
byte_count = int(numeric_value * multiplier)
530+
except OverflowError:
531+
raise ValueError(f"Storage size too large: '{original_input}'")
532+
return byte_count

0 commit comments

Comments
 (0)