diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py
index d1351049726..afbc67c35c2 100644
--- a/vllm_ascend/distributed/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -25,8 +25,11 @@ from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group, get_world_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank,
+    get_decode_context_model_parallel_world_size, get_pcp_group,
+    get_tensor_model_parallel_rank, get_tp_group, get_world_group)
+
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +68,10 @@ class ReqMeta:
     remote_te_rpc_port: Optional[int]
     remote_kv_caches_base_addr: Optional[list[int]]
     metaserver: Optional[str]
+    remote_pcp_size: int
+    remote_dcp_size: int
+    remote_pcp_rank: int
+    remote_dcp_rank: int
 
 
 @dataclass
@@ -294,7 +301,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
     def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
                  pd_head_ratio: int, local_engine_id: str,
                  metadata: MooncakeAgentMetadata,
-                 ready_event: threading.Event):
+                 ready_event: threading.Event, pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -307,6 +314,7 @@ def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
         self.task_tracker = dict[str, int]()
         self.ready_event = ready_event
         self.metadata = metadata
+        self.pcp_rank = pcp_rank
 
     def get_and_clear_finished_requests(self) -> set[str]:
         """
@@ -328,7 +336,8 @@ def update_task(self, req_id):
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         encoder = msgspec.msgpack.Encoder()
@@ -389,6 +398,10 @@ def add_new_req(self,
             remote_kv_caches_base_addr=kv_transfer_params.get(
                 "remote_kv_caches_base_addr", None),
             metaserver=kv_transfer_params.get("metaserver", None),
+            remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
+            remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
+            remote_pcp_rank=kv_transfer_params.get("remote_pcp_rank", 0),
+            remote_dcp_rank=kv_transfer_params.get("remote_dcp_rank", 0),
         )
 
 
@@ -497,12 +510,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
 
         # Handshake base port
        self.side_channel_port = (
            vllm_config.kv_transfer_config.kv_port +
            vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
 
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
@@ -673,6 +688,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.tp_group = get_tp_group()
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group(
+        ).rank_in_group if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
@@ -685,8 +707,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
         logger.info("Initializing Mooncake work %s", engine_id)
         self.engine = global_te.get_transfer_engine(self.side_channel_host,
@@ -720,6 +742,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
                                                          deque)
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
+        self.local_remote_block_port_mapping = None
+        self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -830,7 +854,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
             self.tp_rank, self.side_channel_port, self.tp_size,
-            self.pd_head_ratio, self.engine_id, metadata, ready_event)
+            self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank)
         self.kv_recv_layer_thread.start()
         ready_event.wait()
 
@@ -860,6 +884,178 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                         "requests: %d", 0, len(done_recving))
         return set(), done_recving
 
+    def _get_remote_tp_rank(self, req_id: str) -> List[int]:
+        return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
+
+    def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
+        if self._prefill_tp_size == self._decode_tp_size:
+            result = list(map(lambda x: [x], range(self._decode_tp_size)))
+            return result
+
+        sampled_nums = []
+        ori_data = np.arange(self._decode_tp_size)
+
+        group_size = self._decode_tp_size // self._prefill_tp_size
+        for i in range(self._prefill_tp_size):
+            ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
+            sampled_nums.append(ori_data_slice.tolist())
+        return sampled_nums
+
+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In the cp/dcp scenario the KV cache may be split, so blocks may have to
+        be pulled from multiple remote P nodes. This function calculates the
+        remote port and the remote block ids for each remote P node involved.
+ """ + if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1: + choosen_rank_list = self._get_remote_tp_rank(req_id) + remote_handshake_port_list = [[ + x + meta.remote_port for x in choosen_rank_list + ]] + local_block_ids_list, remote_block_ids_list = [ + meta.local_block_ids + ], [meta.remote_block_ids] + return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list + + ## TODO: add context_parallel_parameters_check + + def get_kv_head_groups(tp_size): + if self.use_mla: + kv_head_groups = [] + kv_head_ids = [0] + kv_head_groups.append(tuple(kv_head_ids)) + return kv_head_groups + if self.num_key_value_heads // tp_size >= 1: + kv_head_groups = [] + for tp_rank in range(tp_size): + kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \ + for head_idx in range(self.num_key_value_heads // tp_size)] + kv_head_groups.append(tuple(kv_head_ids)) + return kv_head_groups + if tp_size // self.num_key_value_heads > 1: + kv_head_groups = [] + for kv_head_id in range(self.num_key_value_heads): + kv_head_groups.append(tuple([kv_head_id])) + return kv_head_groups + + def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base): + # key is kv_head_group, value is cp_groups and which cp_groups to select + cp_group_meta = {} + kv_head_groups = get_kv_head_groups(tp_size) + dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size + + for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups): + if kv_head_group not in cp_group_meta: + cp_group_meta[kv_head_group] = {} + cp_group_meta[kv_head_group]['cp_groups'] = [] + cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0 + kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx + for dcp_repeat_idx in range(dcp_repeat_num): + # len(cp_group) == pcp_size * dcp_size + cp_group = [] + dcp_repeat_offset = dcp_size * dcp_repeat_idx + for pcp_rank in range(pcp_size): + pcp_rank_offset = tp_size * pcp_rank + for dcp_rank in range(dcp_size): + cp_group.append(dcp_rank + port_base + pcp_rank_offset + + dcp_repeat_offset + kv_head_group_offset) + cp_group_meta[kv_head_group]['cp_groups'].append(cp_group) + + return cp_group_meta + + def get_local_remote_block_port_mapping(): + p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, + self.dcp_size, self.side_channel_port) + d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size, + meta.remote_dcp_size, meta.remote_port) + local_remote_block_port_mappings = {} + for p_node_head_key in p_node_cp_group_meta.keys(): + for d_node_head_key in d_node_cp_group_meta.keys(): + if not set(d_node_head_key).issubset(set(p_node_head_key)): + continue + d_node_head_group = d_node_cp_group_meta[d_node_head_key] + p_node_head_group = p_node_cp_group_meta[p_node_head_key] + for p_cp_group in p_node_head_group['cp_groups']: + select_cp_groups_id = d_node_head_group['select_cp_groups_id'] + d_cp_groups = d_node_head_group['cp_groups'] + d_cp_group = d_cp_groups[select_cp_groups_id] + d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \ + if select_cp_groups_id + 1 < len(d_cp_groups) else 0 + for p_idx, p_port in enumerate(p_cp_group): + if p_port not in local_remote_block_port_mappings: + local_remote_block_port_mappings[p_port] = [] + d_port_remote_list = [] + for d_idx, d_port in enumerate(d_cp_group): + if p_idx % len(d_cp_group) == d_idx: + d_port_remote_list.append(d_port) + local_remote_block_port_mappings[p_port].append(d_port_remote_list) + 
+            local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port]
+
+            logger.info(
+                "p_node_cp_group_meta is: %s. d_node_cp_group_meta is: %s. "
+                "local_remote_block_port_mappings is: %s.", p_node_cp_group_meta,
+                d_node_cp_group_meta, local_remote_block_port_mappings)
+
+            return local_remote_block_port_mapping
+
+        if self.local_remote_block_port_mapping is None:
+            self.local_remote_block_port_mapping = get_local_remote_block_port_mapping()
+
+        local_cp_size = self.pcp_size * self.dcp_size
+        prompt_block_nums = len(meta.local_block_ids)
+        local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size
+        num_remain_blocks = prompt_block_nums % local_cp_size
+        for i in range(num_remain_blocks):
+            local_block_nums_all[i] += 1
+        num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size
+
+        # make sure the last (possibly partially filled) block of the P node
+        # ends up as the last block of the D node
+        local_block_nums = []
+        final_block_idx = None
+
+        remote_dcp_rank = meta.remote_dcp_rank
+        remote_pcp_rank = meta.remote_pcp_rank
+        remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size
+        remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
+        for cp_rank, block_num in enumerate(local_block_nums_all):
+            if cp_rank % remote_cp_size == remote_cp_rank:
+                if num_remain_blocks == cp_rank:
+                    final_block_idx = len(local_block_nums)
+                local_block_nums.append(block_num)
+
+        if final_block_idx is not None:
+            final_block_num = local_block_nums.pop(final_block_idx)
+            local_block_nums.append(final_block_num)
+            final_block_mapping = self.local_remote_block_port_mapping.pop(final_block_idx)
+            self.local_remote_block_port_mapping.append(final_block_mapping)
+
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
+        for idx in range(len(self.local_remote_block_port_mapping[0])):
+            mapping_list = []
+            for mapping in self.local_remote_block_port_mapping:
+                mapping_list.append(mapping[idx])
+            remote_handshake_port_list.append(mapping_list)
+
+        # local_block_ids_list and remote_block_ids_list are index-aligned with
+        # remote_handshake_port_list, e.g.:
+        # local_block_ids_list [[1],[2],[5],[6]], remote_block_ids_list [[1],[1],[1],[1]],
+        # remote_handshake_port_list [[30000],[30001],[30004],[30005]]
+        # -> the D rank gets remote block 1 from port 30004 and saves it in local block 5
+        remote_block_offset = 0
+        for local_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_push = local_block_nums[local_kv_id]
+            local_block_ids_list.append(
+                meta.local_block_ids[:num_blocks_to_push])
+            remote_block_ids_list.append(
+                meta.remote_block_ids[remote_block_offset:remote_block_offset +
+                                      num_blocks_to_push])
+            remote_block_offset += num_blocks_to_push
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         self.current_layer = 0
@@ -880,6 +1076,8 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
                 remote_engine_id=self.engine_id,
                 remote_host=self.side_channel_host,
                 remote_port=self.side_channel_port,
+                remote_pcp_rank=self.pcp_rank,
+                remote_dcp_rank=self.dcp_rank,
             )
             future = self.executor.submit(
                 self._access_metaserver,
@@ -964,8 +1162,14 @@ def sort_kv_cache(input_kv: list[list[int]]):
{req_meta_update=}" ) assert self.kv_send_layer_thread is not None - self.kv_send_layer_thread.send_queue.put( - (req_id, req_meta_update, self.current_layer, key, value)) + remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( + req_id, req_meta_update) + for pcp_dcp_rank in range(len(remote_handshake_port_list)): + req_meta_for_queue = copy.copy(req_meta_update) + req_meta_for_queue.local_block_ids = local_block_ids_list[pcp_dcp_rank] + req_meta_for_queue.remote_block_ids = remote_block_ids_list[pcp_dcp_rank] + self.kv_send_layer_thread.send_queue.put( + (req_id, req_meta_for_queue, self.current_layer, key, value)) self.current_layer += 1 def _get_remote_socket(