
Commit 332b547

[Bugfix] support mtp kv transfer and pp partition by hand in kv transfer (#4892)
### What this PR does / why we need it?

The current mooncake connector has the following problems when PP and MTP are enabled:

1. MTP layer KV caches are not transferred, which may lower the accept ratio. This PR adds the MTP layer indices for the last PP stage after computing `end_layer` in `_transfer_kv_cache`.
2. With MTP enabled, the default PP layer split may be imbalanced across stages, so the `VLLM_PP_LAYER_PARTITION` environment variable is needed to balance it by hand; however, during mooncake connector KV transfer the decode node does not know the prefill node's partition. This PR adds a `pp_layer_partition` entry to `kv_connector_extra_config` so the decode node can obtain the prefill node's partition information.

### Does this PR introduce _any_ user-facing change?

When prefill uses the `VLLM_PP_LAYER_PARTITION` environment variable, add `pp_layer_partition` to `kv_connector_extra_config` as below:

```
export VLLM_PP_LAYER_PARTITION=33,28

"kv_connector_extra_config": {
    "use_ascend_direct": true,
    "prefill": {
        "dp_size": 1,
        "tp_size": 8,
        "pp_size": 2,
        "pp_layer_partition": "33,28"
    },
    "decode": {
        "dp_size": 16,
        "tp_size": 1,
        "pp_size": 1
    }
}
```

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: lidenghui <[email protected]>
1 parent a47aa4d commit 332b547
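As a quick illustration (not part of the commit), the sketch below shows how a hand-specified partition string such as `"33,28"` from the config example above maps to per-stage layer ranges. The helper name `partition_to_ranges` is made up for this example; the commit's real helper is `get_prefill_pp_indices` in `mooncake_connector.py` (see the diff below).

```python
# Illustrative sketch only: mirrors how a pp_layer_partition string like
# "33,28" is turned into (start_layer, end_layer) ranges per PP stage.
def partition_to_ranges(partition_list_str: str) -> list[tuple[int, int]]:
    partitions = [int(n) for n in partition_list_str.split(",")]
    ranges, start = [], 0
    for n in partitions:
        ranges.append((start, start + n))
        start += n
    return ranges

print(partition_to_ranges("33,28"))  # [(0, 33), (33, 61)]
```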

2 files changed: +71, -14 lines


tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 11 additions & 5 deletions
@@ -242,7 +242,8 @@ def setUp(self):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
 
     def test_add_request(self):
         test_req = {
@@ -295,7 +296,8 @@ def setUp(self):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.remote_sockets = defaultdict(deque)
         self.thread.remote_poller = MagicMock()
 
@@ -352,7 +354,8 @@ def setUp(self):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.request_queue = self.mock_queue
         self.test_req = {
             "request_id": "req1",
@@ -434,7 +437,8 @@ def setUp(self):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.test_metadata = MooncakeAgentMetadata(
             engine_id="remote_engine",
             te_rpc_port=9090,
@@ -498,7 +502,8 @@ def setUp(self):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.request_queue = queue.Queue()
 
     @patch.object(KVCacheRecvingThread, '_handle_request')
@@ -535,6 +540,7 @@ def __init__(self):
         self.parallel_config = MagicMock()
         self.cache_config = MagicMock()
         self.kv_transfer_config = MagicMock()
+        self.speculative_config = MagicMock()
         self.model_config.use_mla = True
         self.parallel_config.tensor_parallel_size = 2
         self.parallel_config.data_parallel_rank = 0

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 60 additions & 9 deletions
@@ -271,12 +271,19 @@ def run(self):
 
 class KVCacheRecvingThread(threading.Thread):
 
-    def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
-                 engine: TransferEngine, local_engine_id: str,
+    def __init__(self,
+                 tp_rank: int,
+                 tp_size: int,
+                 _prefill_pp_size: int,
+                 engine: TransferEngine,
+                 local_engine_id: str,
                  local_handshake_port: int,
-                 local_kv_caches_base_addr: list[int], block_len: list[int],
-                 ready_event: threading.Event, vllm_config: VllmConfig,
-                 kv_caches: dict[str, Any]):
+                 local_kv_caches_base_addr: list[int],
+                 block_len: list[int],
+                 ready_event: threading.Event,
+                 vllm_config: VllmConfig,
+                 kv_caches: dict[str, Any],
+                 prefill_pp_layer_partition: Optional[str] = None):
         super().__init__(daemon=True, name="KVCacheRecvingThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -315,6 +322,14 @@ def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
         self.vllm_config = vllm_config
         self.model_config = self.vllm_config.model_config
         self.block_size = self.vllm_config.cache_config.block_size
+        self.num_layers = self.model_config.hf_config.num_hidden_layers
+        self.pp_layer_indices = {
+            rank:
+            get_prefill_pp_indices(self.num_layers, rank,
+                                   self._prefill_pp_size,
+                                   prefill_pp_layer_partition)
+            for rank in range(self._prefill_pp_size)
+        }
         if self.use_mla:
             self.k_head_dim = self.model_config.hf_config.kv_lora_rank
             self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
@@ -435,9 +450,14 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]):
 
         remote_kv_caches_base_addrs = \
             self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
-        num_layers = self.model_config.hf_config.num_hidden_layers
-        first_layer_index, end_layer_index = get_pp_indices(
-            num_layers, prefill_pp_rank, self._prefill_pp_size)
+        first_layer_index, end_layer_index = self.pp_layer_indices[
+            prefill_pp_rank]
+        # support MTP layer kv transfer
+        if self.vllm_config.speculative_config is not None:
+            num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
+            num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens
+            if prefill_pp_rank == self._prefill_pp_size - 1:
+                end_layer_index = end_layer_index + num_speculative_tokens
         num_cache_per_layer = len(list(
             self.kv_caches.values())[0])  # Number of KV caches per layer
         local_kv_caches_base_addrs = \
@@ -1020,6 +1040,8 @@ def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill pp size from extra config
         self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
         assert self._decode_pp_size == 1, "decode pp size must be 1"
+        self._prefill_pp_layer_partition = prefill_parallel_config.get(
+            "pp_layer_partition", None)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""
@@ -1126,7 +1148,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self.kv_recv_thread = KVCacheRecvingThread(
             self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
             self.engine_id, self.handshake_port, kv_caches_base_addr,
-            self.block_len, ready_event, self.vllm_config, self.kv_caches)
+            self.block_len, ready_event, self.vllm_config, self.kv_caches,
+            self._prefill_pp_layer_partition)
         self.kv_recv_thread.start()
         ready_event.wait()
 
@@ -1455,3 +1478,31 @@ def ensure_zmq_recv(
             raise RuntimeError(
                 f"Failed to receive data after {max_retries} "
                 f"retries: {e}")
+
+
+# The decode node needs to know the prefill node's pp_layer_partition;
+# it is passed in via partition_list_str from kv_transfer_config and,
+# when absent, falls back to vLLM's default layer-split algorithm.
+def get_prefill_pp_indices(
+        num_hidden_layers: int,
+        pp_rank: int,
+        pp_size: int,
+        partition_list_str: Optional[str] = None) -> tuple[int, int]:
+    if partition_list_str is None:
+        return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
+    else:
+        try:
+            partitions = [
+                int(layer) for layer in partition_list_str.split(",")
+            ]
+        except ValueError as err:
+            raise ValueError("Invalid partition string: {}".format(
+                partition_list_str)) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(
+                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+    start_layer = sum(partitions[:pp_rank])
+    end_layer = start_layer + partitions[pp_rank]
+    return (start_layer, end_layer)
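To see how the two fixes combine, here is a rough sketch with assumed values (a 61-layer model split `33,28` across two prefill PP stages and a single MTP speculative token); it mirrors the logic in `_transfer_kv_cache`, where only the last prefill stage extends `end_layer_index` so the MTP layer's KV cache is transferred as well.

```python
# Assumed values for illustration: a 61-layer model split "33,28" across
# two prefill PP stages, with one MTP (speculative) layer on the last stage.
num_hidden_layers = 61
prefill_pp_size = 2
num_speculative_tokens = 1
partition = [33, 28]

for pp_rank in range(prefill_pp_size):
    start = sum(partition[:pp_rank])
    end = start + partition[pp_rank]
    if pp_rank == prefill_pp_size - 1:
        # Last stage also transfers the MTP layer's KV cache.
        end += num_speculative_tokens
    print(pp_rank, (start, end))
# Output: 0 (0, 33)
#         1 (33, 62)   -> index 61 is the MTP layer
```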
