
Commit 86c36ae

Authored and committed by kongsiyuan

PCP/DCP: adapt the mooncake layerwise connector

1 parent 18221c0 commit 86c36ae

File tree

1 file changed: +193 -9 lines

vllm_ascend/distributed/mooncake_layerwise_connector.py

Lines changed: 193 additions & 9 deletions
@@ -25,8 +25,12 @@
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group, get_world_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank, get_pcp_group,
+    get_tensor_model_parallel_rank, get_tp_group, get_world_group)
+
+from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                              get_prefill_context_model_parallel_world_size)
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +69,8 @@ class ReqMeta:
     remote_te_rpc_port: Optional[int]
     remote_kv_caches_base_addr: Optional[list[int]]
     metaserver: Optional[str]
+    remote_pcp_size: int
+    remote_dcp_size: int
 
 
 @dataclass
@@ -294,7 +300,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
     def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
                  pd_head_ratio: int, local_engine_id: str,
                  metadata: MooncakeAgentMetadata,
-                 ready_event: threading.Event):
+                 ready_event: threading.Event, pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -307,6 +313,7 @@ def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
         self.task_tracker = dict[str, int]()
         self.ready_event = ready_event
         self.metadata = metadata
+        self.pcp_rank = pcp_rank
 
     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -328,7 +335,8 @@ def update_task(self, req_id):
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         encoder = msgspec.msgpack.Encoder()
@@ -389,6 +397,8 @@ def add_new_req(self,
             remote_kv_caches_base_addr=kv_transfer_params.get(
                 "remote_kv_caches_base_addr", None),
             metaserver=kv_transfer_params.get("metaserver", None),
+            remote_pcp_size=kv_transfer_params["remote_pcp_size"],
+            remote_dcp_size=kv_transfer_params["remote_dcp_size"],
         )
 
 
@@ -497,12 +507,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
 
         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
 
         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
 
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
@@ -673,6 +685,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.tp_group = get_tp_group()
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group(
+        ).rank_in_group if self.pcp_size > 1 else 0
+        self.dcp_size = get_prefill_context_model_parallel_world_size()
+        self.dcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
@@ -686,7 +705,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
         logger.info("Initializing Mooncake work %s", engine_id)
         self.engine = global_te.get_transfer_engine(self.side_channel_host,
@@ -720,6 +739,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             deque)
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
+        self.local_remote_block_port_mapping = None
+        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -830,7 +851,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
             self.tp_rank, self.side_channel_port, self.tp_size,
-            self.pd_head_ratio, self.engine_id, metadata, ready_event)
+            self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank)
        self.kv_recv_layer_thread.start()
         ready_event.wait()
 
@@ -860,6 +881,164 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                     "requests: %d", 0, len(done_recving))
         return set(), done_recving
 
+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In the cp/dcp scenario the kv_cache may be split, so we need to pull blocks from multiple remote P nodes.
+        This function calculates the remote port and the number of remote blocks to pull from each remote P node.
+        """
+        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
+            choosen_rank_list = self.tp_rank
+            remote_handshake_port_list = [[
+                x + meta.remote_port for x in choosen_rank_list
+            ]]
+            local_block_ids_list, remote_block_ids_list = [
+                meta.local_block_ids
+            ], [meta.remote_block_ids]
+            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
+        ## TODO: add context_parallel_parameters_check
+
+        def get_kv_head_groups(tp_size):
+            if self.use_mla:
+                kv_head_groups = []
+                kv_head_ids = [0]
+                kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if self.num_key_value_heads // tp_size >= 1:
+                kv_head_groups = []
+                for tp_rank in range(tp_size):
+                    kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
+                        for head_idx in range(self.num_key_value_heads // tp_size)]
+                    kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if tp_size // self.num_key_value_heads > 1:
+                kv_head_groups = []
+                for kv_head_ids in range(self.num_key_value_heads):
+                    kv_head_groups.append(tuple([kv_head_ids]))
+                return kv_head_groups
+
+        def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
+            # key is kv_head_group; value holds the cp_groups and which cp_group to select next
+            cp_group_meta = {}
+            kv_head_groups = get_kv_head_groups(tp_size)
+            dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
+
+            for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
+                if kv_head_group not in cp_group_meta:
+                    cp_group_meta[kv_head_group] = {}
+                    cp_group_meta[kv_head_group]['cp_groups'] = []
+                    cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
+                kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx
+                for dcp_repeat_idx in range(dcp_repeat_num):
+                    # len(cp_group) == pcp_size * dcp_size
+                    cp_group = []
+                    dcp_repeat_offset = dcp_size * dcp_repeat_idx
+                    for pcp_rank in range(pcp_size):
+                        pcp_rank_offset = tp_size * pcp_rank
+                        for dcp_rank in range(dcp_size):
+                            cp_group.append(dcp_rank + port_base + pcp_rank_offset +
+                                            dcp_repeat_offset + kv_head_group_offset)
+                    cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
+
+            return cp_group_meta
+
+        def get_local_remote_block_port_mapping():
+            p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size,
+                                                     self.dcp_size, self.side_channel_port)
+            d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size,
+                                                     meta.remote_dcp_size, meta.remote_port)
+            local_remote_block_port_mappings = {}
+            for p_node_head_key in p_node_cp_group_meta.keys():
+                for d_node_head_key in d_node_cp_group_meta.keys():
+                    if not set(d_node_head_key).issubset(set(p_node_head_key)):
+                        continue
+                    d_node_head_group = d_node_cp_group_meta[d_node_head_key]
+                    p_node_head_group = p_node_cp_group_meta[p_node_head_key]
+                    for p_cp_group in p_node_head_group['cp_groups']:
+                        select_cp_groups_id = d_node_head_group['select_cp_groups_id']
+                        d_cp_groups = d_node_head_group['cp_groups']
+                        d_cp_group = d_cp_groups[select_cp_groups_id]
+                        d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
+                            if select_cp_groups_id + 1 < len(d_cp_groups) else 0
+                        for p_idx, p_port in enumerate(p_cp_group):
+                            if p_port not in local_remote_block_port_mappings:
+                                local_remote_block_port_mappings[p_port] = []
+                            d_port_remote_list = []
+                            for d_idx, d_port in enumerate(d_cp_group):
+                                if d_idx % len(d_cp_group) == p_idx:
+                                    d_port_remote_list.append(d_port)
+                            local_remote_block_port_mappings[p_port].append(d_port_remote_list)
+            local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port]
+
+            logger.info(
+                "p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. "
+                "local_remote_block_port_mappings is:: %s. ", p_node_cp_group_meta,
+                d_node_cp_group_meta, local_remote_block_port_mappings)
+
+            return local_remote_block_port_mapping
+
+        if self.local_remote_block_port_mapping is None:
+            self.local_remote_block_port_mapping = get_local_remote_block_port_mapping()
+
+
+        local_cp_size = self.pcp_size * self.dcp_size
+        prompt_block_nums = len(meta.local_block_ids)
+        local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size
+        num_remain_blocks = prompt_block_nums % local_cp_size
+        for i in range(num_remain_blocks):
+            local_block_nums_all[i] += 1
+        num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size
+
+        # make sure the last block (which may not be full) of the P nodes is put to the last block of the D node
+        local_block_nums = []
+        final_block_idx = None
+
+        remote_dcp_rank = get_decode_context_model_parallel_rank(
+        ) if meta.remote_dcp_size > 1 else 0
+        remote_pcp_rank = get_pcp_group(
+        ).rank_in_group if meta.remote_pcp_size else 0
+        remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size
+        remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
+        for cp_rank, block_num in enumerate(local_block_nums_all):
+            if cp_rank % remote_cp_size == remote_cp_rank:
+                if num_remain_blocks == cp_rank:
+                    final_block_idx = len(local_block_nums)
+                local_block_nums.append(block_num)
+
+        if final_block_idx is not None:
+            final_block_num = local_block_nums.pop(final_block_idx)
+            local_block_nums.append(final_block_num)
+            for mapping in self.local_remote_block_port_mapping:
+                final_block_port = mapping.pop(final_block_idx)
+                mapping.append(final_block_port)
+
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
+        for idx in range(len(self.local_remote_block_port_mapping[0])):
+            mapping_list = []
+            for mapping in self.local_remote_block_port_mapping:
+                mapping_list.append(mapping[idx])
+            remote_handshake_port_list.append(mapping_list)
+
+        # local_block_ids_list and remote_block_ids_list correspond one-to-one with remote_handshake_port_list,
+        # e.g. local_block_ids_list [[1],[2],[5],[6]], remote_block_ids_list [[1],[1],[1],[1]],
+        # remote_handshake_port_list [[30000],[30001],[30004],[30005]]:
+        # the D rank pulls remote block 1 from port 30004 and stores it in local block 5
+        remote_block_offset = 0
+        for local_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_push = local_block_nums[local_kv_id]
+            local_block_ids_list.append(
+                meta.local_block_ids[:num_blocks_to_push])
+            remote_block_ids_list.append(
+                meta.remote_block_ids[remote_block_offset:remote_block_offset +
+                                      num_blocks_to_push])
+            remote_block_offset += num_blocks_to_push
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         self.current_layer = 0
@@ -964,8 +1143,13 @@ def sort_kv_cache(input_kv: list[list[int]]):
                 f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
             )
             assert self.kv_send_layer_thread is not None
-            self.kv_send_layer_thread.send_queue.put(
-                (req_id, req_meta_update, self.current_layer, key, value))
+            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+                req_id, req_meta_update)
+            for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+                req_meta_update.local_block_ids = local_block_ids_list[pcp_dcp_rank]
+                req_meta_update.remote_block_ids = remote_block_ids_list[pcp_dcp_rank]
+                self.kv_send_layer_thread.send_queue.put(
+                    (req_id, req_meta_update, self.current_layer, key, value))
         self.current_layer += 1
 
     def _get_remote_socket(
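
A note on the port arithmetic in this patch: the sketch below is illustrative only (the helper functions are made up for this note and are not part of the commit). It mirrors how the scheduler's side_channel_port base and the per-worker handshake_port are derived once PCP is folded into the offsets above.

def side_channel_base(kv_port: int, dp_rank: int, tp_size: int, pcp_size: int) -> int:
    # Each data-parallel replica now reserves tp_size * pcp_size consecutive ports
    # (previously only tp_size), matching the scheduler change above.
    return kv_port + dp_rank * tp_size * pcp_size

def handshake_port(base: int, pcp_rank: int, tp_rank: int, tp_size: int) -> int:
    # Within a replica, ports are grouped by PCP rank first, then TP rank,
    # matching KVCacheRecvingLayerThread.run() and the worker __init__ above.
    return base + pcp_rank * tp_size + tp_rank

# Example: kv_port=30000, dp_rank=0, tp_size=4, pcp_size=2,
# pcp_rank=1, tp_rank=2  ->  base 30000, handshake port 30000 + 1 * 4 + 2 = 30006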
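
The block-splitting step in _get_kv_split_metadata spreads the prompt's blocks over the local pcp_size * dcp_size ranks, with the leading ranks taking one extra block each when the count does not divide evenly. A minimal standalone sketch of that split (a hypothetical helper written for this note; the committed code operates on meta.local_block_ids directly):

def split_block_counts(prompt_block_nums: int, pcp_size: int, dcp_size: int) -> list[int]:
    # Even split of prompt blocks across the local context-parallel ranks.
    local_cp_size = pcp_size * dcp_size
    counts = [prompt_block_nums // local_cp_size] * local_cp_size
    # Leftover blocks go to the leading CP ranks, one extra block each.
    for i in range(prompt_block_nums % local_cp_size):
        counts[i] += 1
    return counts

# Example: 10 prompt blocks with pcp_size=2, dcp_size=2 -> [3, 3, 2, 2]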
