 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group, get_world_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank, get_pcp_group,
+    get_tensor_model_parallel_rank, get_tp_group, get_world_group)
+
+from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                              get_prefill_context_model_parallel_world_size)
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +69,8 @@ class ReqMeta:
     remote_te_rpc_port: Optional[int]
     remote_kv_caches_base_addr: Optional[list[int]]
     metaserver: Optional[str]
+    remote_pcp_size: int
+    remote_dcp_size: int
 
 
 @dataclass
@@ -294,7 +300,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
     def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
                  pd_head_ratio: int, local_engine_id: str,
                  metadata: MooncakeAgentMetadata,
-                 ready_event: threading.Event):
+                 ready_event: threading.Event, pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -307,6 +313,7 @@ def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
         self.task_tracker = dict[str, int]()
         self.ready_event = ready_event
         self.metadata = metadata
+        self.pcp_rank = pcp_rank
 
     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -328,7 +335,8 @@ def update_task(self, req_id):
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         encoder = msgspec.msgpack.Encoder()
@@ -389,6 +397,8 @@ def add_new_req(self,
             remote_kv_caches_base_addr=kv_transfer_params.get(
                 "remote_kv_caches_base_addr", None),
             metaserver=kv_transfer_params.get("metaserver", None),
+            remote_pcp_size=kv_transfer_params["remote_pcp_size"],
+            remote_dcp_size=kv_transfer_params["remote_dcp_size"],
         )
 
 
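The two new ReqMeta fields are populated from required keys in kv_transfer_params, unlike the optional fields above that fall back to None. A minimal sketch of such a payload follows; the values are placeholders and the non-new keys are only guesses based on the ReqMeta fields visible in this diff:

    # Illustrative only: "remote_pcp_size" / "remote_dcp_size" come from the hunk
    # above; the other keys and all values are assumed placeholders.
    kv_transfer_params = {
        "remote_kv_caches_base_addr": None,
        "metaserver": None,
        "remote_pcp_size": 2,  # peer engine's prefill-context-parallel size (new)
        "remote_dcp_size": 1,  # peer engine's decode-context-parallel size (new)
    }

    # add_new_req indexes the new keys directly, so omitting either raises
    # KeyError instead of silently defaulting to None.
    remote_pcp_size = kv_transfer_params["remote_pcp_size"]
    remote_dcp_size = kv_transfer_params["remote_dcp_size"]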
@@ -497,12 +507,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
 
         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
 
         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
 
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
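To see why the base-port stride now includes pcp_size: every (pcp_rank, tp_rank) pair needs its own handshake port, so each data-parallel rank must reserve a window of tp_size * pcp_size consecutive ports. A small sketch with assumed values (kv_port=30000, tp_size=4, pcp_size=2; none of these numbers come from the change):

    # Assumed example values, not taken from the PR.
    kv_port, tp_size, pcp_size = 30000, 4, 2

    # Each data-parallel rank gets a contiguous window of tp_size * pcp_size ports.
    for dp_rank in range(2):
        base = kv_port + dp_rank * tp_size * pcp_size
        print(f"dp_rank={dp_rank}: ports {base}..{base + tp_size * pcp_size - 1}")
    # dp_rank=0: ports 30000..30007
    # dp_rank=1: ports 30008..30015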
@@ -673,6 +685,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.tp_group = get_tp_group()
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = (get_pcp_group().rank_in_group
+                         if self.pcp_size > 1 else 0)
+        self.dcp_size = get_prefill_context_model_parallel_world_size()
+        self.dcp_rank = (get_prefill_context_model_parallel_rank()
+                         if self.dcp_size > 1 else 0)
+
         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
@@ -686,7 +705,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
             vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
         logger.info("Initializing Mooncake work %s", engine_id)
         self.engine = global_te.get_transfer_engine(self.side_channel_host,
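Within one engine the listener offset is pcp_rank * tp_size + tp_rank, so every (pcp_rank, tp_rank) pair lands on a distinct port above the per-data-parallel-rank base. A quick check with assumed sizes (tp_size=4, pcp_size=2, arbitrary base; example values only):

    # Assumed example values, not taken from the PR.
    side_channel_port, tp_size, pcp_size = 30000, 4, 2

    ports = {(pcp_rank, tp_rank): side_channel_port + pcp_rank * tp_size + tp_rank
             for pcp_rank in range(pcp_size) for tp_rank in range(tp_size)}

    # Every (pcp_rank, tp_rank) pair gets a unique handshake port.
    assert len(set(ports.values())) == pcp_size * tp_size
    print(ports[(0, 0)], ports[(1, 3)])  # 30000 30007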
@@ -720,6 +739,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             deque)
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
+        self.local_remote_block_port_mapping = None
+        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -830,7 +851,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
             self.tp_rank, self.side_channel_port, self.tp_size,
-            self.pd_head_ratio, self.engine_id, metadata, ready_event)
+            self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank)
         self.kv_recv_layer_thread.start()
         ready_event.wait()
 
@@ -860,6 +881,164 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             "requests: %d", 0, len(done_recving))
         return set(), done_recving
 
+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In the cp/dcp scenario the kv_cache may be split, so blocks may need to be pulled from multiple remote P nodes.
+        This function calculates the remote handshake port and the number of remote blocks to pull from each P node.
+        """
+        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
+            chosen_rank_list = [self.tp_rank]
+            remote_handshake_port_list = [[
+                x + meta.remote_port for x in chosen_rank_list
+            ]]
+            local_block_ids_list, remote_block_ids_list = [
+                meta.local_block_ids
+            ], [meta.remote_block_ids]
+            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
+        ## TODO: add context_parallel_parameters_check
+
+        def get_kv_head_groups(tp_size):
+            if self.use_mla:
+                kv_head_groups = []
+                kv_head_ids = [0]
+                kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if self.num_key_value_heads // tp_size >= 1:
+                kv_head_groups = []
+                for tp_rank in range(tp_size):
+                    kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size)
+                                   for head_idx in range(self.num_key_value_heads // tp_size)]
+                    kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if tp_size // self.num_key_value_heads > 1:
+                kv_head_groups = []
+                for kv_head_ids in range(self.num_key_value_heads):
+                    kv_head_groups.append(tuple([kv_head_ids]))
+                return kv_head_groups
+
+        def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
+            # key is kv_head_group, value is its cp_groups and which cp_group to select
+            cp_group_meta = {}
+            kv_head_groups = get_kv_head_groups(tp_size)
+            dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
+
+            for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
+                if kv_head_group not in cp_group_meta:
+                    cp_group_meta[kv_head_group] = {}
+                    cp_group_meta[kv_head_group]['cp_groups'] = []
+                    cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
+                kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx
+                for dcp_repeat_idx in range(dcp_repeat_num):
+                    # len(cp_group) == pcp_size * dcp_size
+                    cp_group = []
+                    dcp_repeat_offset = dcp_size * dcp_repeat_idx
+                    for pcp_rank in range(pcp_size):
+                        pcp_rank_offset = tp_size * pcp_rank
+                        for dcp_rank in range(dcp_size):
+                            cp_group.append(dcp_rank + port_base + pcp_rank_offset +
+                                            dcp_repeat_offset + kv_head_group_offset)
+                    cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
+
+            return cp_group_meta
+
+        def get_local_remote_block_port_mapping():
+            p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size,
+                                                     self.dcp_size, self.side_channel_port)
+            d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size,
+                                                     meta.remote_dcp_size, meta.remote_port)
+            local_remote_block_port_mappings = {}
+            for p_node_head_key in p_node_cp_group_meta.keys():
+                for d_node_head_key in d_node_cp_group_meta.keys():
+                    if not set(d_node_head_key).issubset(set(p_node_head_key)):
+                        continue
+                    d_node_head_group = d_node_cp_group_meta[d_node_head_key]
+                    p_node_head_group = p_node_cp_group_meta[p_node_head_key]
+                    for p_cp_group in p_node_head_group['cp_groups']:
+                        select_cp_groups_id = d_node_head_group['select_cp_groups_id']
+                        d_cp_groups = d_node_head_group['cp_groups']
+                        d_cp_group = d_cp_groups[select_cp_groups_id]
+                        d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
+                            if select_cp_groups_id + 1 < len(d_cp_groups) else 0
+                        for p_idx, p_port in enumerate(p_cp_group):
+                            if p_port not in local_remote_block_port_mappings:
+                                local_remote_block_port_mappings[p_port] = []
+                            d_port_remote_list = []
+                            for d_idx, d_port in enumerate(d_cp_group):
+                                if d_idx % len(d_cp_group) == p_idx:
+                                    d_port_remote_list.append(d_port)
+                            local_remote_block_port_mappings[p_port].append(d_port_remote_list)
+            local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port]
+
+            logger.info(
+                "p_node_cp_group_meta is: %s. d_node_cp_group_meta is: %s. "
+                "local_remote_block_port_mappings is: %s.", p_node_cp_group_meta,
+                d_node_cp_group_meta, local_remote_block_port_mappings)
+
+            return local_remote_block_port_mapping
+
+        if self.local_remote_block_port_mapping is None:
+            self.local_remote_block_port_mapping = get_local_remote_block_port_mapping()
+
+
+        local_cp_size = self.pcp_size * self.dcp_size
+        prompt_block_nums = len(meta.local_block_ids)
+        local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size
+        num_remain_blocks = prompt_block_nums % local_cp_size
+        for i in range(num_remain_blocks):
+            local_block_nums_all[i] += 1
+        num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size
+
+        # Make sure the last block of the P node (which may be partially filled) lands in the last block slot on the D node.
+        local_block_nums = []
+        final_block_idx = None
+
+        remote_dcp_rank = get_decode_context_model_parallel_rank(
+        ) if meta.remote_dcp_size > 1 else 0
+        remote_pcp_rank = get_pcp_group(
+        ).rank_in_group if meta.remote_pcp_size else 0
+        remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size
+        remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
+        for cp_rank, block_num in enumerate(local_block_nums_all):
+            if cp_rank % remote_cp_size == remote_cp_rank:
+                if num_remain_blocks == cp_rank:
+                    final_block_idx = len(local_block_nums)
+                local_block_nums.append(block_num)
+
+        if final_block_idx is not None:
+            final_block_num = local_block_nums.pop(final_block_idx)
+            local_block_nums.append(final_block_num)
+            for mapping in self.local_remote_block_port_mapping:
+                final_block_port = mapping.pop(final_block_idx)
+                mapping.append(final_block_port)
+
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
+        for idx in range(len(self.local_remote_block_port_mapping[0])):
+            mapping_list = []
+            for mapping in self.local_remote_block_port_mapping:
+                mapping_list.append(mapping[idx])
+            remote_handshake_port_list.append(mapping_list)
+
+        # local_block_ids_list and remote_block_ids_list line up one-to-one with
+        # remote_handshake_port_list, e.g. local_block_ids_list [[1],[2],[5],[6]],
+        # remote_block_ids_list [[1],[1],[1],[1]], remote_handshake_port_list
+        # [[30000],[30001],[30004],[30005]]: the D rank gets remote block 1 from port 30004 and saves it in local block 5.
+        remote_block_offset = 0
+        for local_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_push = local_block_nums[local_kv_id]
+            local_block_ids_list.append(
+                meta.local_block_ids[:num_blocks_to_push])
+            remote_block_ids_list.append(
+                meta.remote_block_ids[remote_block_offset:remote_block_offset +
+                                      num_blocks_to_push])
+            remote_block_offset += num_blocks_to_push
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         self.current_layer = 0
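To make the kv-head grouping in get_kv_head_groups above easier to follow, here is a standalone sketch of the same rule for the non-MLA path, assuming num_key_value_heads and tp_size divide evenly; the function name and example sizes are illustrative, not part of the change:

    # Simplified, standalone illustration of the kv-head grouping rule
    # (non-MLA path), assuming the sizes divide evenly.
    def kv_head_groups(num_key_value_heads: int, tp_size: int) -> list[tuple[int, ...]]:
        if num_key_value_heads >= tp_size:
            # Each TP rank owns a contiguous slice of KV heads.
            per_rank = num_key_value_heads // tp_size
            return [tuple(range(r * per_rank, (r + 1) * per_rank))
                    for r in range(tp_size)]
        # More TP ranks than KV heads: each head forms its own group and is
        # replicated across tp_size // num_key_value_heads ranks.
        return [(h,) for h in range(num_key_value_heads)]

    print(kv_head_groups(8, 4))  # [(0, 1), (2, 3), (4, 5), (6, 7)]
    print(kv_head_groups(2, 4))  # [(0,), (1,)]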
@@ -964,8 +1143,13 @@ def sort_kv_cache(input_kv: list[list[int]]):
                 f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
             )
             assert self.kv_send_layer_thread is not None
-            self.kv_send_layer_thread.send_queue.put(
-                (req_id, req_meta_update, self.current_layer, key, value))
+            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+                req_id, req_meta_update)
+            for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+                req_meta_update.local_block_ids = local_block_ids_list[pcp_dcp_rank]
+                req_meta_update.remote_block_ids = remote_block_ids_list[pcp_dcp_rank]
+                self.kv_send_layer_thread.send_queue.put(
+                    (req_id, req_meta_update, self.current_layer, key, value))
         self.current_layer += 1
 
     def _get_remote_socket(
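For the split-and-fan-out step above, a minimal standalone sketch of the even-split-with-remainder rule and of queuing one work item per remote rank; split_block_counts, FakeReqMeta and the use of dataclasses.replace (so each queued item carries its own copy of the block ids) are illustrative assumptions, not code from this change:

    from dataclasses import dataclass, replace
    from queue import Queue


    def split_block_counts(num_blocks: int, cp_size: int) -> list[int]:
        """Split num_blocks as evenly as possible; remainders go to the lowest ranks."""
        counts = [num_blocks // cp_size] * cp_size
        for i in range(num_blocks % cp_size):
            counts[i] += 1
        return counts


    @dataclass
    class FakeReqMeta:  # stand-in for ReqMeta, illustration only
        local_block_ids: list[int]
        remote_block_ids: list[int]


    meta = FakeReqMeta(local_block_ids=[0, 1, 2, 3, 4],
                       remote_block_ids=[7, 8, 9, 10, 11])
    counts = split_block_counts(len(meta.local_block_ids), cp_size=2)  # [3, 2]

    send_queue: Queue = Queue()
    offset = 0
    for rank, n in enumerate(counts):
        # replace() copies the metadata per rank, so later iterations cannot
        # overwrite the block ids of items already queued.
        item = replace(meta,
                       local_block_ids=meta.local_block_ids[:n],
                       remote_block_ids=meta.remote_block_ids[offset:offset + n])
        send_queue.put((rank, item))
        offset += n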