Skip to content

Commit 2a81478

Browse files
nwpu-zxr and liziyu179
authored and committed
[P/D]Make kv-transfer env variable take effect & Fix load-balance proxy (vllm-project#3981)
### What this PR does / why we need it?
Make kv-transfer env variable take effect and fix the load-balance proxy.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b

---------

Signed-off-by: liziyu <[email protected]>
Signed-off-by: nwpu-zxr <[email protected]>
Co-authored-by: liziyu <[email protected]>
Signed-off-by: hwhaokun <[email protected]>
1 parent 4383da8 commit 2a81478

File tree

7 files changed

+33
-13
lines changed

7 files changed

+33
-13
lines changed

examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,12 @@ async def metaserver(request: Request):
561561
max_retries=global_args.max_retries,
562562
base_delay=global_args.retry_delay)
563563
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
564+
proxy_state.release_prefiller_kv(prefiller_idx,prefiller_score)
564565

565566
except Exception as e:
566567
logger.error(f"Post metaserver failed with: {str(e)}")
568+
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
569+
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
567570

568571

569572
if __name__ == '__main__':

tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -978,9 +978,6 @@ def __init__(self, *args, **kwargs):
978978
self.data_ptr = MagicMock(return_value=0x1000)
979979

980980

981-
mock_envs_ascend = MagicMock()
982-
mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
983-
984981
mock_logger = MagicMock()
985982

986983

@@ -1017,14 +1014,15 @@ def mock_string_to_int64_hash(s):
10171014
class TestMooncakeConnectorWorker(unittest.TestCase):
10181015

10191016
def setUp(self):
1020-
self.envs_ascend_mock = MockEnvsAscend()
10211017
self.mock_transfer_engine = MagicMock()
10221018
self.mock_transfer_engine.get_rpc_port.return_value = 9090
10231019
self.mock_transfer_engine.initialize.return_value = 0
10241020
self.mock_transfer_engine.register_memory.return_value = 0
10251021

10261022
self.patches = [
1027-
patch('os.getenv', return_value="10,11"),
1023+
patch(
1024+
'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
1025+
'10,11'),
10281026
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
10291027
patch('torch.Tensor.element_size', return_value=4),
10301028
patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -1053,8 +1051,6 @@ def setUp(self):
10531051
MagicMock()),
10541052
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
10551053
MagicMock()),
1056-
patch.dict('sys.modules',
1057-
{'vllm_ascend.envs': self.envs_ascend_mock}),
10581054
]
10591055

10601056
for p in self.patches:

tests/ut/kv_connector/test_mooncake_layerwise_connector.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -792,15 +792,15 @@ def test_request_finished(self, mock_method):
792792
class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
793793

794794
def setUp(self):
795-
self.envs_ascend_mock = type("MockEnvsAscend", (),
796-
{"PHYSICAL_DEVICES": "10,11"})()
797795
self.mock_transfer_engine = MagicMock()
798796
self.mock_transfer_engine.get_rpc_port.return_value = 9090
799797
self.mock_transfer_engine.initialize.return_value = 0
800798
self.mock_transfer_engine.register_memory.return_value = 0
801799

802800
self.patches = [
803-
patch('os.getenv', return_value="10,11"),
801+
patch(
802+
'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
803+
'10,11'),
804804
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
805805
patch('torch.Tensor.element_size', return_value=4),
806806
patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -833,8 +833,6 @@ def setUp(self):
833833
patch(
834834
'vllm_ascend.distributed.mooncake_layerwise_connector.threading.Event',
835835
MagicMock()),
836-
patch.dict('sys.modules',
837-
{'vllm_ascend.envs': self.envs_ascend_mock}),
838836
patch(
839837
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
840838
return_value=SimpleNamespace(pd_tp_ratio=1,

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from vllm.v1.request import Request, RequestStatus
3232

3333
import vllm_ascend.envs as envs_ascend
34+
from vllm_ascend.distributed.utils import get_transfer_timeout_value
3435
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
3536
prefill_context_parallel_enable,
3637
vllm_version_is)
@@ -438,7 +439,7 @@ def init_llm_datadist(self):
438439
assert self.local_agent_metadata is not None
439440
llm_config = LLMConfig()
440441
llm_config.device_id = self.local_rank
441-
llm_config.sync_kv_timeout = 20000
442+
llm_config.sync_kv_timeout = get_transfer_timeout_value()
442443
llm_config.enable_switch_role = True
443444
llm_config.enable_cache_manager = True
444445
llm_config.enable_remote_cache_accessible = True

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextlib
33
import hashlib
44
import math
5+
import os
56
import queue
67
import random
78
import struct
@@ -33,6 +34,7 @@
3334
import vllm_ascend.envs as envs_ascend
3435
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
3536
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
37+
from vllm_ascend.distributed.utils import get_transfer_timeout_value
3638
from vllm_ascend.utils import vllm_version_is
3739

3840
if vllm_version_is("0.11.0"):
@@ -855,6 +857,8 @@ class MooncakeConnectorWorker:
855857

856858
def __init__(self, vllm_config: VllmConfig, engine_id: str):
857859
self._get_prefill_decode_size(vllm_config)
860+
os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
861+
get_transfer_timeout_value())
858862
if self._prefill_tp_size < self._decode_tp_size:
859863
raise ValueError(
860864
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"

vllm_ascend/distributed/mooncake_layerwise_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import hashlib
55
import math
6+
import os
67
import queue
78
import struct
89
import threading
@@ -31,6 +32,7 @@
3132
import vllm_ascend.envs as envs_ascend
3233
from vllm_ascend.ascend_config import get_ascend_config
3334
from vllm_ascend.distributed.utils import (align_memory,
35+
get_transfer_timeout_value,
3436
kv_alltoall_and_rearrange)
3537
from vllm_ascend.utils import vllm_version_is
3638

@@ -602,6 +604,8 @@ class MooncakeLayerwiseConnectorWorker:
602604

603605
def __init__(self, vllm_config: VllmConfig, engine_id: str):
604606
self._get_prefill_decode_size(vllm_config)
607+
os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
608+
get_transfer_timeout_value())
605609
if self._prefill_tp_size < self._decode_tp_size:
606610
raise ValueError(
607611
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"

vllm_ascend/distributed/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import torch
24
import torch.distributed as dist
35

@@ -45,3 +47,15 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
4547
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
4648
offset = (aligned_addr - data_ptr) // tensor.element_size()
4749
return tensor[int(offset):]
50+
51+
52+
def get_transfer_timeout_value():
53+
ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
54+
if len(ascend_transfer_timeout) > 0:
55+
return int(ascend_transfer_timeout)
56+
hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
57+
'20')) # type: ignore
58+
hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
59+
'7')) # type: ignore
60+
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
61+
3000)

0 commit comments

Comments (0)