 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group, get_world_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank,
+    get_decode_context_model_parallel_world_size, get_pcp_group,
+    get_tensor_model_parallel_rank, get_tp_group, get_world_group)
+
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +68,10 @@ class ReqMeta:
     remote_te_rpc_port: Optional[int]
     remote_kv_caches_base_addr: Optional[list[int]]
     metaserver: Optional[str]
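+    # PCP/DCP topology of the remote peer, carried in kv_transfer_params.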
+    remote_pcp_size: int
+    remote_dcp_size: int
+    remote_pcp_rank: int
+    remote_dcp_rank: int
 
 
 @dataclass
@@ -294,7 +301,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
     def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
                  pd_head_ratio: int, local_engine_id: str,
                  metadata: MooncakeAgentMetadata,
-                 ready_event: threading.Event):
+                 ready_event: threading.Event, pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -307,6 +314,7 @@ def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
         self.task_tracker = dict[str, int]()
         self.ready_event = ready_event
         self.metadata = metadata
+        self.pcp_rank = pcp_rank
 
     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -328,7 +336,8 @@ def update_task(self, req_id):
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
-        handshake_port = self.side_channel_port + self.tp_rank
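+        # Each (pcp_rank, tp_rank) pair listens on its own side-channel port:
+        # handshake_port = base_port + pcp_rank * tp_size + tp_rank.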
+        handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         encoder = msgspec.msgpack.Encoder()
@@ -389,6 +398,10 @@ def add_new_req(self,
             remote_kv_caches_base_addr=kv_transfer_params.get(
                 "remote_kv_caches_base_addr", None),
             metaserver=kv_transfer_params.get("metaserver", None),
+            remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
+            remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+            remote_pcp_rank=kv_transfer_params.get("remote_pcp_rank", 0),
+            remote_dcp_rank=kv_transfer_params.get("remote_dcp_rank", 0),
         )
 
 
@@ -497,12 +510,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
 
         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
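+        # With PCP enabled, each DP rank spans tp_size * pcp_size handshake
+        # ports, so the base port below is scaled accordingly.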
 
         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
 
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
@@ -673,6 +688,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.tp_group = get_tp_group()
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group(
+        ).rank_in_group if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
@@ -685,8 +707,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
         logger.info("Initializing Mooncake work %s", engine_id)
         self.engine = global_te.get_transfer_engine(self.side_channel_host,
@@ -720,6 +742,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             deque)
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
+        self.local_remote_block_port_mapping = None
+        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -830,7 +854,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
             self.tp_rank, self.side_channel_port, self.tp_size,
-            self.pd_head_ratio, self.engine_id, metadata, ready_event)
+            self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank)
         self.kv_recv_layer_thread.start()
         ready_event.wait()
 
@@ -860,6 +884,179 @@ def get_finished(self) -> tuple[set[str], set[str]]:
860884 "requests: %d" , 0 , len (done_recving ))
861885 return set (), done_recving
862886
+    def _get_remote_tp_rank(self, req_id: str) -> List[int]:
+        return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
+
+    def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
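+        # Maps each local (P-side) TP rank to the remote (D-side) TP ranks it
+        # pairs with, e.g. prefill_tp=2, decode_tp=8 -> [[0, 1, 2, 3], [4, 5, 6, 7]].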
+        if self._prefill_tp_size == self._decode_tp_size:
+            result = list(map(lambda x: [x], range(self._decode_tp_size)))
+            return result
+
+        sampled_nums = []
+        ori_data = np.arange(self._decode_tp_size)
+
+        group_size = self._decode_tp_size // self._prefill_tp_size
+        for i in range(self._prefill_tp_size):
+            ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
+            sampled_nums.append(ori_data_slice.tolist())
+        return sampled_nums
+
+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In the CP/DCP scenario the KV cache may be split across ranks, so we
+        may need to pull blocks from multiple remote P nodes. This function
+        computes the remote handshake port and the remote block ids to
+        transfer for each remote P node.
+        """
+        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
+            chosen_rank_list = self._get_remote_tp_rank(req_id)
+            remote_handshake_port_list = [[
+                x + meta.remote_port for x in chosen_rank_list
+            ]]
+            local_block_ids_list, remote_block_ids_list = [
+                meta.local_block_ids
+            ], [meta.remote_block_ids]
+            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
+        # TODO: add context_parallel_parameters_check
+
+        def get_kv_head_groups(tp_size):
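+            # Group the KV head ids owned by each TP rank, e.g. with
+            # num_key_value_heads=8 and tp_size=4 -> [(0, 1), (2, 3), (4, 5), (6, 7)];
+            # for MLA there is a single shared group [(0,)].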
+            if self.use_mla:
+                kv_head_groups = []
+                kv_head_ids = [0]
+                kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if self.num_key_value_heads // tp_size >= 1:
+                kv_head_groups = []
+                for tp_rank in range(tp_size):
+                    kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size)
+                                   for head_idx in range(self.num_key_value_heads // tp_size)]
+                    kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if tp_size // self.num_key_value_heads > 1:
+                kv_head_groups = []
+                for kv_head_id in range(self.num_key_value_heads):
+                    kv_head_groups.append(tuple([kv_head_id]))
+                return kv_head_groups
+
+        def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
+            # Key: kv_head_group; value: its cp_groups and which cp_group to
+            # pick next (select_cp_groups_id).
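+            # e.g. tp_size=4, pcp_size=2, dcp_size=1, num_key_value_heads=8,
+            # port_base=30000: head group g gets the single cp_group
+            # [30000 + g, 30004 + g] (one port per (pcp_rank, dcp_rank) pair).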
+            cp_group_meta = {}
+            kv_head_groups = get_kv_head_groups(tp_size)
+            dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
+
+            for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
+                if kv_head_group not in cp_group_meta:
+                    cp_group_meta[kv_head_group] = {}
+                    cp_group_meta[kv_head_group]['cp_groups'] = []
+                    cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
+                kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx
+                for dcp_repeat_idx in range(dcp_repeat_num):
+                    # len(cp_group) == pcp_size * dcp_size
+                    cp_group = []
+                    dcp_repeat_offset = dcp_size * dcp_repeat_idx
+                    for pcp_rank in range(pcp_size):
+                        pcp_rank_offset = tp_size * pcp_rank
+                        for dcp_rank in range(dcp_size):
+                            cp_group.append(dcp_rank + port_base + pcp_rank_offset +
+                                            dcp_repeat_offset + kv_head_group_offset)
+                    cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
+
+            return cp_group_meta
+
+        def get_local_remote_block_port_mapping():
+            p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size,
+                                                     self.dcp_size, self.side_channel_port)
+            d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size,
+                                                     meta.remote_dcp_size, meta.remote_port)
+            local_remote_block_port_mappings = {}
+            for p_node_head_key in p_node_cp_group_meta.keys():
+                for d_node_head_key in d_node_cp_group_meta.keys():
+                    if not set(d_node_head_key).issubset(set(p_node_head_key)):
+                        continue
+                    d_node_head_group = d_node_cp_group_meta[d_node_head_key]
+                    p_node_head_group = p_node_cp_group_meta[p_node_head_key]
+                    for p_cp_group in p_node_head_group['cp_groups']:
+                        select_cp_groups_id = d_node_head_group['select_cp_groups_id']
+                        d_cp_groups = d_node_head_group['cp_groups']
+                        d_cp_group = d_cp_groups[select_cp_groups_id]
+                        d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
+                            if select_cp_groups_id + 1 < len(d_cp_groups) else 0
+                        for p_idx, p_port in enumerate(p_cp_group):
+                            if p_port not in local_remote_block_port_mappings:
+                                local_remote_block_port_mappings[p_port] = []
+                            d_port_remote_list = []
+                            for d_idx, d_port in enumerate(d_cp_group):
+                                if p_idx % len(d_cp_group) == d_idx:
+                                    d_port_remote_list.append(d_port)
+                            local_remote_block_port_mappings[p_port].append(d_port_remote_list)
+            local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port]
+
+            logger.info(
+                "p_node_cp_group_meta is: %s. d_node_cp_group_meta is: %s. "
+                "local_remote_block_port_mappings is: %s.", p_node_cp_group_meta,
+                d_node_cp_group_meta, local_remote_block_port_mappings)
+
+            return local_remote_block_port_mapping
+
+        if self.local_remote_block_port_mapping is None:
+            self.local_remote_block_port_mapping = get_local_remote_block_port_mapping()
+
+        local_cp_size = self.pcp_size * self.dcp_size
+        prompt_block_nums = len(meta.local_block_ids)
+        local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size
+        num_remain_blocks = prompt_block_nums % local_cp_size
+        for i in range(num_remain_blocks):
+            local_block_nums_all[i] += 1
+        num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size
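+        # e.g. 10 prompt blocks over local_cp_size=4 -> counts [3, 3, 2, 2];
+        # num_remain_blocks now indexes the cp_rank whose (round-robin) share
+        # contains the last, possibly partially filled, block (here rank 1).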
+
+        # Make sure the last block (which may be partially filled) on the P
+        # side ends up as the last block on the D side.
+        local_block_nums = []
+        final_block_idx = None
+
+        remote_dcp_rank = meta.remote_dcp_rank
+        remote_pcp_rank = meta.remote_pcp_rank
+        remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size
+        remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
+        for cp_rank, block_num in enumerate(local_block_nums_all):
+            if cp_rank % remote_cp_size == remote_cp_rank:
+                if num_remain_blocks == cp_rank:
+                    final_block_idx = len(local_block_nums)
+                local_block_nums.append(block_num)
+
+        if final_block_idx is not None:
+            final_block_num = local_block_nums.pop(final_block_idx)
+            local_block_nums.append(final_block_num)
+            for mapping in self.local_remote_block_port_mapping:
+                final_block_port = mapping.pop(final_block_idx)
+                mapping.append(final_block_port)
+
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
+        for idx in range(len(self.local_remote_block_port_mapping[0])):
+            mapping_list = []
+            for mapping in self.local_remote_block_port_mapping:
+                mapping_list.append(mapping[idx])
+            remote_handshake_port_list.append(mapping_list)
+
+        # local_block_ids_list and remote_block_ids_list line up with
+        # remote_handshake_port_list, e.g. local_block_ids_list=[[1],[2],[5],[6]],
+        # remote_block_ids_list=[[1],[1],[1],[1]],
+        # remote_handshake_port_list=[[30000],[30001],[30004],[30005]]:
+        # the D rank on port 30004 gets remote block 1 and saves it into local block 5.
+        remote_block_offset = 0
+        for local_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_push = local_block_nums[local_kv_id]
+            local_block_ids_list.append(
+                meta.local_block_ids[:num_blocks_to_push])
+            remote_block_ids_list.append(
+                meta.remote_block_ids[remote_block_offset:remote_block_offset +
+                                      num_blocks_to_push])
+            remote_block_offset += num_blocks_to_push
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         self.current_layer = 0
@@ -880,6 +1077,8 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
                 remote_engine_id=self.engine_id,
                 remote_host=self.side_channel_host,
                 remote_port=self.side_channel_port,
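+                # Tell the peer which PCP/DCP rank is requesting so it can
+                # select the matching slice of blocks.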
+                remote_pcp_rank=self.pcp_rank,
+                remote_dcp_rank=self.dcp_rank,
             )
             future = self.executor.submit(
                 self._access_metaserver,
@@ -964,8 +1163,14 @@ def sort_kv_cache(input_kv: list[list[int]]):
9641163 f"Add request { req_id } to kv send layer thread. { req_meta_update = } "
9651164 )
9661165 assert self .kv_send_layer_thread is not None
967- self .kv_send_layer_thread .send_queue .put (
968- (req_id , req_meta_update , self .current_layer , key , value ))
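+            # With PCP/DCP the blocks of one request may be spread over several
+            # remote peers, so enqueue one send task per target handshake port.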
+            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+                req_id, req_meta_update)
+            for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+                req_meta_for_queue = copy.copy(req_meta_update)
+                req_meta_for_queue.local_block_ids = local_block_ids_list[pcp_dcp_rank]
+                req_meta_for_queue.remote_block_ids = remote_block_ids_list[pcp_dcp_rank]
+                self.kv_send_layer_thread.send_queue.put(
+                    (req_id, req_meta_for_queue, self.current_layer, key, value))
         self.current_layer += 1
 
     def _get_remote_socket(