PCP/DCP adaptation for the mooncake layerwise connector #4924
base: main
Showing changes to vllm_ascend/distributed/mooncake_layerwise_connector.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,8 +25,11 @@ | |
| from vllm.config import VllmConfig | ||
| from vllm.distributed.kv_transfer.kv_connector.v1.base import ( | ||
| KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) | ||
| from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, | ||
| get_tp_group, get_world_group) | ||
| from vllm.distributed.parallel_state import ( | ||
| get_decode_context_model_parallel_rank, | ||
| get_decode_context_model_parallel_world_size, get_pcp_group, | ||
| get_tensor_model_parallel_rank, get_tp_group, get_world_group) | ||
|
|
||
| from vllm.logger import logger | ||
| from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket | ||
| from vllm.v1.core.sched.output import SchedulerOutput | ||
|
|
@@ -65,6 +68,10 @@ | |
| remote_te_rpc_port: Optional[int] | ||
| remote_kv_caches_base_addr: Optional[list[int]] | ||
| metaserver: Optional[str] | ||
| remote_pcp_size: int | ||
| remote_dcp_size: int | ||
| remote_pcp_rank: int | ||
| remote_dcp_rank: int | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -294,7 +301,7 @@ | |
| def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int, | ||
| pd_head_ratio: int, local_engine_id: str, | ||
| metadata: MooncakeAgentMetadata, | ||
| ready_event: threading.Event): | ||
| ready_event: threading.Event, pcp_rank: int): | ||
| super().__init__(daemon=True, name="KVCacheRecvingLayerThread") | ||
| self.tp_rank = tp_rank | ||
| self.tp_size = tp_size | ||
|
|
@@ -307,6 +314,7 @@ | |
| self.task_tracker = dict[str, int]() | ||
| self.ready_event = ready_event | ||
| self.metadata = metadata | ||
| self.pcp_rank = pcp_rank | ||
|
|
||
| def get_and_clear_finished_requests(self) -> set[str]: | ||
| """ | ||
|
|
@@ -328,7 +336,8 @@ | |
|
|
||
| def run(self): | ||
| """Run the thread to handle KV cache transfer requests.""" | ||
| handshake_port = self.side_channel_port + self.tp_rank | ||
| handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \ | ||
| + self.tp_rank | ||
| path = make_zmq_path("tcp", self.side_channel_host, handshake_port) | ||
| logger.info("Starting listening on path: %s", path) | ||
| encoder = msgspec.msgpack.Encoder() | ||
|
|
@@ -389,6 +398,10 @@ | |
| remote_kv_caches_base_addr=kv_transfer_params.get( | ||
| "remote_kv_caches_base_addr", None), | ||
| metaserver=kv_transfer_params.get("metaserver", None), | ||
| remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1), | ||
| remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1), | ||
| remote_pcp_rank=kv_transfer_params.get("remote_pcp_rank", 1), | ||
| remote_dcp_rank=kv_transfer_params.get("remote_dcp_rank", 1), | ||
| ) | ||
|
|
||
|
|
||
|
|
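For reference, a kv_transfer_params payload carrying the new context-parallel fields might look like the sketch below; the field names match the .get() calls above, while the concrete values are purely illustrative and not taken from this PR.

# Illustrative payload only; values are assumptions, not taken from this PR.
kv_transfer_params = {
    "remote_kv_caches_base_addr": None,
    "metaserver": None,
    "remote_pcp_size": 2,   # prefill context-parallel world size on the remote side
    "remote_dcp_size": 1,   # decode context-parallel world size on the remote side
    "remote_pcp_rank": 0,
    "remote_dcp_rank": 0,
}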
@@ -497,12 +510,14 @@ | |
| logger.info("Initializing Mooncake Scheduler %s", engine_id) | ||
|
|
||
| self.side_channel_host = get_ip() | ||
| self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size | ||
| self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size | ||
|
|
||
| # Handshake base port | ||
| self.side_channel_port = ( | ||
| vllm_config.kv_transfer_config.kv_port + | ||
| vllm_config.parallel_config.data_parallel_rank * | ||
| vllm_config.parallel_config.tensor_parallel_size) | ||
| vllm_config.parallel_config.tensor_parallel_size * self.pcp_size) | ||
|
|
||
| # Requests that need to start recv. | ||
| # New requests are added by update_state_after_alloc in | ||
|
|
@@ -673,6 +688,13 @@ | |
| self.tp_group = get_tp_group() | ||
| self.kv_caches: dict[str, torch.Tensor] = {} | ||
| self.side_channel_host = get_ip() | ||
| self.pcp_size = get_pcp_group().world_size | ||
| self.pcp_rank = get_pcp_group( | ||
| ).rank_in_group if self.pcp_size > 1 else 0 | ||
| self.dcp_size = get_decode_context_model_parallel_world_size() | ||
| self.dcp_rank = get_decode_context_model_parallel_rank( | ||
| ) if self.dcp_size > 1 else 0 | ||
|
|
||
| self.total_layers = vllm_config.model_config.get_num_layers( | ||
| vllm_config.parallel_config) | ||
|
|
||
|
|
@@ -685,8 +707,8 @@ | |
| self.side_channel_port = ( | ||
| vllm_config.kv_transfer_config.kv_port + | ||
| vllm_config.parallel_config.data_parallel_rank * | ||
| vllm_config.parallel_config.tensor_parallel_size) | ||
| self.handshake_port = self.side_channel_port + self.tp_rank | ||
| vllm_config.parallel_config.tensor_parallel_size * self.pcp_size) | ||
| self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank | ||
| self.sockets: dict = {} | ||
| logger.info("Initializing Mooncake work %s", engine_id) | ||
| self.engine = global_te.get_transfer_engine(self.side_channel_host, | ||
|
|
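A minimal sketch of the resulting port layout, assuming hypothetical sizes (kv_port=30000, data_parallel_rank=0, tp_size=4, pcp_size=2); it only re-evaluates the two expressions from the diff above.

# Hypothetical sizes for illustration; the formulas mirror the diff above.
kv_port, dp_rank, tp_size, pcp_size = 30000, 0, 4, 2
side_channel_port = kv_port + dp_rank * tp_size * pcp_size  # 30000
for pcp_rank in range(pcp_size):
    for tp_rank in range(tp_size):
        handshake_port = side_channel_port + pcp_rank * tp_size + tp_rank
        print(pcp_rank, tp_rank, handshake_port)
# pcp_rank 0 -> ports 30000..30003, pcp_rank 1 -> ports 30004..30007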
@@ -720,6 +742,8 @@ | |
| deque) | ||
| self.remote_poller = zmq.Poller() # type: ignore | ||
| self.timeout = 1.0 # seconds | ||
| self.local_remote_block_port_mapping = None | ||
| self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads | ||
|
|
||
| def _get_prefill_decode_size(self, vllm_config: VllmConfig): | ||
| # get prefill tp and dp size from extra config | ||
|
|
@@ -830,7 +854,7 @@ | |
| ready_event = threading.Event() | ||
| self.kv_recv_layer_thread = KVCacheRecvingLayerThread( | ||
| self.tp_rank, self.side_channel_port, self.tp_size, | ||
| self.pd_head_ratio, self.engine_id, metadata, ready_event) | ||
| self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank) | ||
| self.kv_recv_layer_thread.start() | ||
| ready_event.wait() | ||
|
|
||
|
|
@@ -860,6 +884,178 @@ | |
| "requests: %d", 0, len(done_recving)) | ||
| return set(), done_recving | ||
|
|
||
| def _get_remote_tp_rank(self, req_id: str) -> List[int]: | ||
| return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] | ||
|
|
||
| def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]: | ||
| if self._prefill_tp_size == self._decode_tp_size: | ||
| result = list(map(lambda x: [x], range(self._decode_tp_size))) | ||
| return result | ||
|
|
||
| sampled_nums = [] | ||
| ori_data = np.arange(self._decode_tp_size) | ||
|
|
||
| group_size = self._decode_tp_size // self._prefill_tp_size | ||
| for i in range(self._prefill_tp_size): | ||
| ori_data_slice = ori_data[i * group_size:(i + 1) * group_size] | ||
| sampled_nums.append(ori_data_slice.tolist()) | ||
| return sampled_nums | ||
|
|
||
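As a worked example of the grouping above (hypothetical sizes, not taken from this PR): with a prefill TP size of 2 and a decode TP size of 4, each prefill TP rank is mapped to a contiguous group of decode TP ranks.

# Standalone re-implementation of _get_remote_tp_ranks_for_req's grouping,
# for illustration only; sizes are assumed.
prefill_tp_size, decode_tp_size = 2, 4
if prefill_tp_size == decode_tp_size:
    groups = [[x] for x in range(decode_tp_size)]
else:
    group_size = decode_tp_size // prefill_tp_size
    groups = [list(range(i * group_size, (i + 1) * group_size))
              for i in range(prefill_tp_size)]
print(groups)  # [[0, 1], [2, 3]]; tp_rank 0 maps to decode ranks 0/1, tp_rank 1 to 2/3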
| def _get_kv_split_metadata( | ||
| self, | ||
| req_id: str, | ||
| meta: ReqMeta, | ||
| ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]: | ||
| """ | ||
| In the CP/DCP scenario, the KV cache may be split, so we need to pull blocks from multiple remote P nodes. | ||
| This function calculates the remote handshake port and the number of remote blocks to pull from each remote P node. | ||
| """ | ||
| if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1: | ||
| choosen_rank_list = self._get_remote_tp_rank(req_id) | ||
| remote_handshake_port_list = [[ | ||
| x + meta.remote_port for x in choosen_rank_list | ||
|
|
||
| ]] | ||
| local_block_ids_list, remote_block_ids_list = [ | ||
| meta.local_block_ids | ||
| ], [meta.remote_block_ids] | ||
| return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list | ||
|
|
||
| ## TODO: add context_parallel_parameters_check | ||
|
|
||
| def get_kv_head_groups(tp_size): | ||
| if self.use_mla: | ||
| kv_head_groups = [] | ||
| kv_head_ids = [0] | ||
| kv_head_groups.append(tuple(kv_head_ids)) | ||
| return kv_head_groups | ||
| if self.num_key_value_heads // tp_size >= 1: | ||
| kv_head_groups = [] | ||
| for tp_rank in range(tp_size): | ||
| kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \ | ||
| for head_idx in range(self.num_key_value_heads // tp_size)] | ||
| kv_head_groups.append(tuple(kv_head_ids)) | ||
| return kv_head_groups | ||
| if tp_size // self.num_key_value_heads > 1: | ||
| kv_head_groups = [] | ||
| for kv_head_id in range(self.num_key_value_heads): | ||
| kv_head_groups.append(tuple([kv_head_id])) | ||
| return kv_head_groups | ||
|
|
||
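A worked example of get_kv_head_groups under assumed shapes (8 KV heads, TP sizes 4 and 16; none of these values come from this PR): when each rank owns at least one head, heads are split contiguously per rank; otherwise each head becomes its own group, and MLA always collapses to head 0.

# Illustration only; a simplified re-implementation of the grouping logic above.
def kv_head_groups(num_key_value_heads, tp_size, use_mla=False):
    if use_mla:
        return [(0,)]
    per_rank = num_key_value_heads // tp_size
    if per_rank >= 1:
        return [tuple(range(r * per_rank, (r + 1) * per_rank))
                for r in range(tp_size)]
    return [(h,) for h in range(num_key_value_heads)]

print(kv_head_groups(8, 4))   # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(kv_head_groups(8, 16))  # [(0,), (1,), ..., (7,)]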
| def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base): | ||
| # key is kv_head_group, value is cp_groups and which cp_groups to select | ||
| cp_group_meta = {} | ||
|
|
||
| kv_head_groups = get_kv_head_groups(tp_size) | ||
| dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size | ||
|
|
||
| for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups): | ||
| if kv_head_group not in cp_group_meta: | ||
| cp_group_meta[kv_head_group] = {} | ||
| cp_group_meta[kv_head_group]['cp_groups'] = [] | ||
| cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0 | ||
| kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx | ||
| for dcp_repeat_idx in range(dcp_repeat_num): | ||
| # len(cp_group) == pcp_size * dcp_size | ||
| cp_group = [] | ||
| dcp_repeat_offset = dcp_size * dcp_repeat_idx | ||
| for pcp_rank in range(pcp_size): | ||
| pcp_rank_offset = tp_size * pcp_rank | ||
| for dcp_rank in range(dcp_size): | ||
| cp_group.append(dcp_rank + port_base + pcp_rank_offset + | ||
| dcp_repeat_offset + kv_head_group_offset) | ||
| cp_group_meta[kv_head_group]['cp_groups'].append(cp_group) | ||
|
|
||
| return cp_group_meta | ||
|
|
||
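To make the port arithmetic concrete, the snippet below traces get_cp_group_meta under assumed sizes (tp_size=8, pcp_size=2, dcp_size=2, base port 30000, two KV head groups); it reproduces only the port formula, nothing else from the class.

# Hypothetical sizes; mirrors only the port arithmetic of get_cp_group_meta above.
tp_size, pcp_size, dcp_size, port_base = 8, 2, 2, 30000
head_groups = [(0,), (1,)]
dcp_repeat_num = tp_size // len(head_groups) // dcp_size  # 2
for idx, group in enumerate(head_groups):
    head_offset = tp_size // len(head_groups) * idx
    for rep in range(dcp_repeat_num):
        cp_group = [port_base + tp_size * pcp + dcp_size * rep + head_offset + dcp
                    for pcp in range(pcp_size) for dcp in range(dcp_size)]
        print(group, cp_group)
# (0,) [30000, 30001, 30008, 30009]   (0,) [30002, 30003, 30010, 30011]
# (1,) [30004, 30005, 30012, 30013]   (1,) [30006, 30007, 30014, 30015]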
| def get_local_remote_block_port_mapping(): | ||
| p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size, | ||
| self.dcp_size, self.side_channel_port) | ||
| d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size, | ||
| meta.remote_dcp_size, meta.remote_port) | ||
| local_remote_block_port_mappings = {} | ||
| for p_node_head_key in p_node_cp_group_meta.keys(): | ||
| for d_node_head_key in d_node_cp_group_meta.keys(): | ||
| if not set(d_node_head_key).issubset(set(p_node_head_key)): | ||
| continue | ||
| d_node_head_group = d_node_cp_group_meta[d_node_head_key] | ||
| p_node_head_group = p_node_cp_group_meta[p_node_head_key] | ||
| for p_cp_group in p_node_head_group['cp_groups']: | ||
|
|
||
| select_cp_groups_id = d_node_head_group['select_cp_groups_id'] | ||
| d_cp_groups = d_node_head_group['cp_groups'] | ||
| d_cp_group = d_cp_groups[select_cp_groups_id] | ||
| d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \ | ||
| if select_cp_groups_id + 1 < len(d_cp_groups) else 0 | ||
| for p_idx, p_port in enumerate(p_cp_group): | ||
| if p_port not in local_remote_block_port_mappings: | ||
| local_remote_block_port_mappings[p_port] = [] | ||
| d_port_remote_list = [] | ||
| for d_idx, d_port in enumerate(d_cp_group): | ||
| if p_idx % len(d_cp_group) == d_idx: | ||
| d_port_remote_list.append(d_port) | ||
| local_remote_block_port_mappings[p_port].append(d_port_remote_list) | ||
| local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port] | ||
|
|
||
| logger.info( | ||
| "p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. " | ||
| "local_remote_block_port_mappings is:: %s. ", p_node_cp_group_meta, | ||
| d_node_cp_group_meta, local_remote_block_port_mappings) | ||
|
|
||
| return local_remote_block_port_mapping | ||
|
|
||
| if self.local_remote_block_port_mapping is None: | ||
| self.local_remote_block_port_mapping = get_local_remote_block_port_mapping() | ||
|
|
||
|
|
||
| local_cp_size = self.pcp_size * self.dcp_size | ||
| prompt_block_nums = len(meta.local_block_ids) | ||
| local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size | ||
| num_remain_blocks = prompt_block_nums % local_cp_size | ||
| for i in range(num_remain_blocks): | ||
| local_block_nums_all[i] += 1 | ||
| num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size | ||
|
|
||
| # make sure the last block (which may be partially filled) of the P node is placed as the last block of the D node | ||
| local_block_nums = [] | ||
| final_block_idx = None | ||
|
|
||
| remote_dcp_rank = meta.remote_dcp_rank | ||
| remote_pcp_rank = meta.remote_pcp_rank | ||
| remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size | ||
| remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size | ||
| for cp_rank, block_num in enumerate(local_block_nums_all): | ||
| if cp_rank % remote_cp_size == remote_cp_rank: | ||
| if num_remain_blocks == cp_rank: | ||
| final_block_idx = len(local_block_nums) | ||
| local_block_nums.append(block_num) | ||
|
|
||
| if final_block_idx is not None: | ||
| final_block_num = local_block_nums.pop(final_block_idx) | ||
| local_block_nums.append(final_block_num) | ||
| final_block_mapping = self.local_remote_block_port_mapping.pop(final_block_idx) | ||
| self.local_remote_block_port_mapping.append(final_block_mapping) | ||
|
|
||
| remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], [] | ||
| for idx in range(len(self.local_remote_block_port_mapping[0])): | ||
| mapping_list = [] | ||
| for mapping in self.local_remote_block_port_mapping: | ||
| mapping_list.append(mapping[idx]) | ||
| remote_handshake_port_list.append(mapping_list) | ||
|
|
||
| # local_block_ids_list and remote_block_ids_list are index-aligned with remote_handshake_port_list, | ||
| # e.g. local_block_ids_list=[[1],[2],[5],[6]], remote_block_ids_list=[[1],[1],[1],[1]], | ||
| # remote_handshake_port_list=[[30000],[30001],[30004],[30005]]: | ||
| # the D rank will fetch remote block 1 from port 30004 and store it in local block 5 | ||
| remote_block_offset = 0 | ||
| for local_kv_id in range(len(remote_handshake_port_list)): | ||
| num_blocks_to_push = local_block_nums[local_kv_id] | ||
| local_block_ids_list.append( | ||
| meta.local_block_ids[:num_blocks_to_push]) | ||
| remote_block_ids_list.append( | ||
| meta.remote_block_ids[remote_block_offset:remote_block_offset+ | ||
| num_blocks_to_push]) | ||
| remote_block_offset += num_blocks_to_push | ||
|
Comment on lines +1047 to +1055
Contributor
Split the local blocks with a running offset as well:

remote_block_offset = 0
local_block_offset = 0
for local_kv_id in range(len(remote_handshake_port_list)):
    num_blocks_to_push = local_block_nums[local_kv_id]
    local_block_ids_list.append(
        meta.local_block_ids[local_block_offset:local_block_offset +
                             num_blocks_to_push])
    remote_block_ids_list.append(
        meta.remote_block_ids[remote_block_offset:remote_block_offset +
                              num_blocks_to_push])
    remote_block_offset += num_blocks_to_push
    local_block_offset += num_blocks_to_push
||
|
|
||
| return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list | ||
|
|
||
| def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): | ||
| """Start loading KV blocks from remote engine.""" | ||
| self.current_layer = 0 | ||
|
|
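The remainder handling above is easiest to see with numbers; a standalone sketch with 7 prompt blocks over pcp_size * dcp_size = 4 local CP ranks (hypothetical values) gives the per-rank block counts and the rank index that ends up holding the final, possibly partial, block.

# Hypothetical: 7 prompt blocks split across 4 local CP ranks, mirroring the
# remainder logic of _get_kv_split_metadata above.
local_cp_size = 4
prompt_block_nums = 7
local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size  # [1, 1, 1, 1]
num_remain_blocks = prompt_block_nums % local_cp_size                        # 3
for i in range(num_remain_blocks):
    local_block_nums_all[i] += 1                                             # [2, 2, 2, 1]
# Index of the CP rank holding the last prompt block (consistent with
# round-robin block placement across CP ranks).
last_block_rank = (num_remain_blocks + local_cp_size - 1) % local_cp_size    # 2
print(local_block_nums_all, last_block_rank)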
@@ -880,6 +1076,8 @@ | |
| remote_engine_id=self.engine_id, | ||
| remote_host=self.side_channel_host, | ||
| remote_port=self.side_channel_port, | ||
| remote_pcp_rank=self.pcp_rank, | ||
| remote_dcp_rank=self.dcp_rank, | ||
| ) | ||
| future = self.executor.submit( | ||
| self._access_metaserver, | ||
|
|
@@ -964,8 +1162,14 @@ | |
| f"Add request {req_id} to kv send layer thread. {req_meta_update=}" | ||
| ) | ||
| assert self.kv_send_layer_thread is not None | ||
| self.kv_send_layer_thread.send_queue.put( | ||
| (req_id, req_meta_update, self.current_layer, key, value)) | ||
| remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( | ||
| req_id, req_meta_update) | ||
| for pcp_dcp_rank in range(len(remote_handshake_port_list)): | ||
| req_meta_for_queue = copy.copy(req_meta_update) | ||
| req_meta_for_queue.local_block_ids = local_block_ids_list[pcp_dcp_rank] | ||
| req_meta_for_queue.remote_block_ids = remote_block_ids_list[pcp_dcp_rank] | ||
| self.kv_send_layer_thread.send_queue.put( | ||
| (req_id, req_meta_for_queue, self.current_layer, key, value)) | ||
| self.current_layer += 1 | ||
|
|
||
| def _get_remote_socket( | ||
|
|
||
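A rough sketch of the fan-out pattern above, with stand-in objects (the queue, metadata class, and block lists below are made up for illustration): one shallow copy of the request metadata is enqueued per remote pcp/dcp rank, each carrying its own block slices.

import copy
from queue import Queue

class FakeReqMeta:  # stand-in for the real ReqMeta, illustration only
    local_block_ids: list = []
    remote_block_ids: list = []

send_queue: Queue = Queue()
req_meta_update = FakeReqMeta()
local_block_ids_list = [[1, 2], [3]]
remote_block_ids_list = [[7, 8], [9]]

for pcp_dcp_rank in range(len(local_block_ids_list)):
    m = copy.copy(req_meta_update)          # shallow copy per remote rank
    m.local_block_ids = local_block_ids_list[pcp_dcp_rank]
    m.remote_block_ids = remote_block_ids_list[pcp_dcp_rank]
    send_queue.put(("req-0", m, 0, None, None))  # (req_id, meta, layer, key, value)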