
Commit 1324a5b

Author: kongsiyuan (committed)

PCP/DCP: adapt the mooncake layerwise connector
1 parent 18221c0 commit 1324a5b

File tree

1 file changed: +215 -10 lines changed


vllm_ascend/distributed/mooncake_layerwise_connector.py

Lines changed: 215 additions & 10 deletions
@@ -25,8 +25,11 @@
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group, get_world_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank,
+    get_decode_context_model_parallel_world_size, get_pcp_group,
+    get_tensor_model_parallel_rank, get_tp_group, get_world_group)
+
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +68,10 @@ class ReqMeta:
     remote_te_rpc_port: Optional[int]
     remote_kv_caches_base_addr: Optional[list[int]]
     metaserver: Optional[str]
+    remote_pcp_size: int
+    remote_dcp_size: int
+    remote_pcp_rank: int
+    remote_dcp_rank: int
 
 
 @dataclass
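
ReqMeta grows four context-parallel fields; they are populated from kv_transfer_params in the add_new_req hunk further down, falling back to 1 when a key is absent. A minimal sketch of that lookup, with made-up payload values:

    # Hypothetical payload: only the four key names and the fallback value of 1
    # come from this commit; the numbers are illustrative examples.
    kv_transfer_params = {
        "remote_pcp_size": 2,
        "remote_dcp_size": 1,
        "remote_pcp_rank": 0,
        "remote_dcp_rank": 0,
    }
    remote_pcp_size = kv_transfer_params.get("remote_pcp_size", 1)   # -> 2
    remote_dcp_rank = kv_transfer_params.get("remote_dcp_rank", 1)   # -> 0
    absent_default = {}.get("remote_pcp_size", 1)                    # missing key -> 1
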
@@ -294,7 +301,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
     def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
                  pd_head_ratio: int, local_engine_id: str,
                  metadata: MooncakeAgentMetadata,
-                 ready_event: threading.Event):
+                 ready_event: threading.Event, pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -307,6 +314,7 @@ def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
         self.task_tracker = dict[str, int]()
         self.ready_event = ready_event
         self.metadata = metadata
+        self.pcp_rank = pcp_rank
 
     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -328,7 +336,8 @@ def update_task(self, req_id):
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         encoder = msgspec.msgpack.Encoder()
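
The listener port formula above gives every (pcp_rank, tp_rank) pair on a node its own side-channel socket. A minimal arithmetic sketch, assuming an illustrative base port of 30000 and tp_size of 4 (both values made up):

    # Hypothetical numbers, only to illustrate the port spacing used above.
    side_channel_port = 30000
    tp_size = 4

    def listener_port(pcp_rank: int, tp_rank: int) -> int:
        # same formula as KVCacheRecvingLayerThread.run()
        return side_channel_port + pcp_rank * tp_size + tp_rank

    assert listener_port(0, 3) == 30003   # last TP rank of PCP rank 0
    assert listener_port(1, 0) == 30004   # PCP rank 1 starts in the next tp_size-wide slot
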
@@ -389,6 +398,10 @@ def add_new_req(self,
             remote_kv_caches_base_addr=kv_transfer_params.get(
                 "remote_kv_caches_base_addr", None),
             metaserver=kv_transfer_params.get("metaserver", None),
+            remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
+            remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+            remote_pcp_rank=kv_transfer_params.get("remote_pcp_rank", 1),
+            remote_dcp_rank=kv_transfer_params.get("remote_dcp_rank", 1),
         )
 
 
@@ -497,12 +510,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
 
         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
 
         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
 
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
@@ -673,6 +688,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.tp_group = get_tp_group()
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group(
+        ).rank_in_group if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
@@ -685,8 +707,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
         logger.info("Initializing Mooncake work %s", engine_id)
         self.engine = global_te.get_transfer_engine(self.side_channel_host,
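
Because the per-data-parallel-rank port window is now tensor_parallel_size * pcp_size wide, the (pcp_rank, tp_rank) offsets of different DP ranks cannot collide. A minimal sketch of the resulting layout, assuming kv_port 20000, tp_size 2 and pcp_size 2 (illustrative values only):

    # Hypothetical sizes; the two formulas mirror side_channel_port and handshake_port above.
    kv_port, tp_size, pcp_size = 20000, 2, 2

    def handshake_port(dp_rank: int, pcp_rank: int, tp_rank: int) -> int:
        side_channel_port = kv_port + dp_rank * tp_size * pcp_size
        return side_channel_port + pcp_rank * tp_size + tp_rank

    # DP rank 0 occupies 20000-20003, DP rank 1 occupies 20004-20007: no overlap.
    assert [handshake_port(0, p, t) for p in range(2) for t in range(2)] == [20000, 20001, 20002, 20003]
    assert [handshake_port(1, p, t) for p in range(2) for t in range(2)] == [20004, 20005, 20006, 20007]
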
@@ -720,6 +742,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             deque)
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
+        self.local_remote_block_port_mapping = None
+        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -830,7 +854,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
             self.tp_rank, self.side_channel_port, self.tp_size,
-            self.pd_head_ratio, self.engine_id, metadata, ready_event)
+            self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank)
         self.kv_recv_layer_thread.start()
         ready_event.wait()
 
@@ -860,6 +884,179 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             "requests: %d", 0, len(done_recving))
         return set(), done_recving
 
+    def _get_remote_tp_rank(self, req_id: str) -> List[int]:
+        return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
+
+    def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
+        if self._prefill_tp_size == self._decode_tp_size:
+            result = list(map(lambda x: [x], range(self._decode_tp_size)))
+            return result
+
+        sampled_nums = []
+        ori_data = np.arange(self._decode_tp_size)
+
+        group_size = self._decode_tp_size // self._prefill_tp_size
+        for i in range(self._prefill_tp_size):
+            ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
+            sampled_nums.append(ori_data_slice.tolist())
+        return sampled_nums
+
+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
+        Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
+        """
+        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
+            choosen_rank_list = self._get_remote_tp_rank(req_id)
+            remote_handshake_port_list = [[
+                x + meta.remote_port for x in choosen_rank_list
+            ]]
+            local_block_ids_list, remote_block_ids_list = [
+                meta.local_block_ids
+            ], [meta.remote_block_ids]
+            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
+        ## TODO: add context_parallel_parameters_check
+
+        def get_kv_head_groups(tp_size):
+            if self.use_mla:
+                kv_head_groups = []
+                kv_head_ids = [0]
+                kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if self.num_key_value_heads // tp_size >= 1:
+                kv_head_groups = []
+                for tp_rank in range(tp_size):
+                    kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
+                        for head_idx in range(self.num_key_value_heads // tp_size)]
+                    kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if tp_size // self.num_key_value_heads > 1:
+                kv_head_groups = []
+                for kv_head_id in range(self.num_key_value_heads):
+                    kv_head_groups.append(tuple([kv_head_id]))
+                return kv_head_groups
+
+        def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
+            # key is kv_head_group, value is cp_groups and which cp_groups to select
+            cp_group_meta = {}
+            kv_head_groups = get_kv_head_groups(tp_size)
+            dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
+
+            for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
+                if kv_head_group not in cp_group_meta:
+                    cp_group_meta[kv_head_group] = {}
+                    cp_group_meta[kv_head_group]['cp_groups'] = []
+                    cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
+                kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx
+                for dcp_repeat_idx in range(dcp_repeat_num):
+                    # len(cp_group) == pcp_size * dcp_size
+                    cp_group = []
+                    dcp_repeat_offset = dcp_size * dcp_repeat_idx
+                    for pcp_rank in range(pcp_size):
+                        pcp_rank_offset = tp_size * pcp_rank
+                        for dcp_rank in range(dcp_size):
+                            cp_group.append(dcp_rank + port_base + pcp_rank_offset +
+                                            dcp_repeat_offset + kv_head_group_offset)
+                    cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
+
+            return cp_group_meta
+
+        def get_local_remote_block_port_mapping():
+            p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size,
+                                                     self.dcp_size, self.side_channel_port)
+            d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size,
+                                                     meta.remote_dcp_size, meta.remote_port)
+            local_remote_block_port_mappings = {}
+            for p_node_head_key in p_node_cp_group_meta.keys():
+                for d_node_head_key in d_node_cp_group_meta.keys():
+                    if not set(d_node_head_key).issubset(set(p_node_head_key)):
+                        continue
+                    d_node_head_group = d_node_cp_group_meta[d_node_head_key]
+                    p_node_head_group = p_node_cp_group_meta[p_node_head_key]
+                    for p_cp_group in p_node_head_group['cp_groups']:
+                        select_cp_groups_id = d_node_head_group['select_cp_groups_id']
+                        d_cp_groups = d_node_head_group['cp_groups']
+                        d_cp_group = d_cp_groups[select_cp_groups_id]
+                        d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
+                            if select_cp_groups_id + 1 < len(d_cp_groups) else 0
+                        for p_idx, p_port in enumerate(p_cp_group):
+                            if p_port not in local_remote_block_port_mappings:
+                                local_remote_block_port_mappings[p_port] = []
+                            d_port_remote_list = []
+                            for d_idx, d_port in enumerate(d_cp_group):
+                                if p_idx % len(d_cp_group) == d_idx:
+                                    d_port_remote_list.append(d_port)
+                            local_remote_block_port_mappings[p_port].append(d_port_remote_list)
+            local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port]
+
+            logger.info(
+                "p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. "
+                "local_remote_block_port_mappings is:: %s. ", p_node_cp_group_meta,
+                d_node_cp_group_meta, local_remote_block_port_mappings)
+
+            return local_remote_block_port_mapping
+
+        if self.local_remote_block_port_mapping is None:
+            self.local_remote_block_port_mapping = get_local_remote_block_port_mapping()
+
+
+        local_cp_size = self.pcp_size * self.dcp_size
+        prompt_block_nums = len(meta.local_block_ids)
+        local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size
+        num_remain_blocks = prompt_block_nums % local_cp_size
+        for i in range(num_remain_blocks):
+            local_block_nums_all[i] += 1
+        num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size
+
+        # make sure the last block (which may be unfull) of P nodes is put to the last block of D node
+        local_block_nums = []
+        final_block_idx = None
+
+        remote_dcp_rank = meta.remote_dcp_rank
+        remote_pcp_rank = meta.remote_pcp_rank
+        remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size
+        remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
+        for cp_rank, block_num in enumerate(local_block_nums_all):
+            if cp_rank % remote_cp_size == remote_cp_rank:
+                if num_remain_blocks == cp_rank:
+                    final_block_idx = len(local_block_nums)
+                local_block_nums.append(block_num)
+
+        if final_block_idx is not None:
+            final_block_num = local_block_nums.pop(final_block_idx)
+            local_block_nums.append(final_block_num)
+            for mapping in self.local_remote_block_port_mapping:
+                final_block_port = mapping.pop(final_block_idx)
+                mapping.append(final_block_port)
+
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
+        for idx in range(len(self.local_remote_block_port_mapping[0])):
+            mapping_list = []
+            for mapping in self.local_remote_block_port_mapping:
+                mapping_list.append(mapping[idx])
+            remote_handshake_port_list.append(mapping_list)
+
+        # the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
+        # such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
+        # remote_handshake_port_list[[30000],[30001],[30004],[30005]]
+        # D rank will get remote block 1 in port 30004 and save it in local block 5
+        remote_block_offset = 0
+        for local_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_push = local_block_nums[local_kv_id]
+            local_block_ids_list.append(
+                meta.local_block_ids[:num_blocks_to_push])
+            remote_block_ids_list.append(
+                meta.remote_block_ids[remote_block_offset:remote_block_offset+
+                                      num_blocks_to_push])
+            remote_block_offset += num_blocks_to_push
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         self.current_layer = 0
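
The block accounting at the end of _get_kv_split_metadata first splits the prompt's blocks as evenly as possible across the local pcp_size * dcp_size ranks (the first `remainder` ranks get one extra block), then keeps only the entries whose cp rank maps to the requesting remote cp rank. A minimal sketch of those two steps, with made-up counts:

    # Illustrative only: 11 blocks over pcp_size=2 * dcp_size=2 local ranks.
    def split_block_counts(prompt_block_nums: int, local_cp_size: int) -> list[int]:
        counts = [prompt_block_nums // local_cp_size] * local_cp_size
        for i in range(prompt_block_nums % local_cp_size):
            counts[i] += 1  # first `remainder` cp ranks take one extra block
        return counts

    counts = split_block_counts(11, 4)
    assert counts == [3, 3, 3, 2]

    # A remote decode cp rank keeps every remote_cp_size-th entry starting at its own rank,
    # mirroring the `cp_rank % remote_cp_size == remote_cp_rank` filter above.
    remote_cp_size, remote_cp_rank = 2, 1
    kept = [c for cp_rank, c in enumerate(counts) if cp_rank % remote_cp_size == remote_cp_rank]
    assert kept == [3, 2]
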
@@ -880,6 +1077,8 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
                 remote_engine_id=self.engine_id,
                 remote_host=self.side_channel_host,
                 remote_port=self.side_channel_port,
+                remote_pcp_rank=self.pcp_rank,
+                remote_dcp_rank=self.dcp_rank,
             )
             future = self.executor.submit(
                 self._access_metaserver,
@@ -964,8 +1163,14 @@ def sort_kv_cache(input_kv: list[list[int]]):
             f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
         )
         assert self.kv_send_layer_thread is not None
-        self.kv_send_layer_thread.send_queue.put(
-            (req_id, req_meta_update, self.current_layer, key, value))
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+            req_id, req_meta_update)
+        for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+            req_meta_for_queue = copy.copy(req_meta_update)
+            req_meta_for_queue.local_block_ids = local_block_ids_list[pcp_dcp_rank]
+            req_meta_for_queue.remote_block_ids = remote_block_ids_list[pcp_dcp_rank]
+            self.kv_send_layer_thread.send_queue.put(
+                (req_id, req_meta_for_queue, self.current_layer, key, value))
         self.current_layer += 1
 
     def _get_remote_socket(
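
For reference, the kv-head grouping that drives the port mapping in _get_kv_split_metadata can be sketched on its own for the non-MLA case; the head counts below are made-up examples, and the fallback branch is an assumption for when there are more TP ranks than kv heads:

    # Mirrors the two non-MLA branches of get_kv_head_groups() in the diff above.
    def kv_head_groups(num_key_value_heads: int, tp_size: int) -> list[tuple[int, ...]]:
        if num_key_value_heads // tp_size >= 1:
            per_rank = num_key_value_heads // tp_size
            return [tuple(range(r * per_rank, (r + 1) * per_rank)) for r in range(tp_size)]
        # more TP ranks than kv heads: each head forms its own single-head group
        return [(h,) for h in range(num_key_value_heads)]

    assert kv_head_groups(8, 4) == [(0, 1), (2, 3), (4, 5), (6, 7)]  # 2 heads per TP rank
    assert kv_head_groups(2, 8) == [(0,), (1,)]                      # heads replicated across ranks
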
