224 changes: 214 additions & 10 deletions vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -25,8 +25,11 @@
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
get_tp_group, get_world_group)
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size, get_pcp_group,
get_tensor_model_parallel_rank, get_tp_group, get_world_group)

from vllm.logger import logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
@@ -65,6 +68,10 @@
remote_te_rpc_port: Optional[int]
remote_kv_caches_base_addr: Optional[list[int]]
metaserver: Optional[str]
remote_pcp_size: int
remote_dcp_size: int
remote_pcp_rank: int
remote_dcp_rank: int


@dataclass
@@ -294,7 +301,7 @@
def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
pd_head_ratio: int, local_engine_id: str,
metadata: MooncakeAgentMetadata,
ready_event: threading.Event):
ready_event: threading.Event, pcp_rank: int):
super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
@@ -307,6 +314,7 @@
self.task_tracker = dict[str, int]()
self.ready_event = ready_event
self.metadata = metadata
self.pcp_rank = pcp_rank

def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -328,7 +336,8 @@

def run(self):
"""Run the thread to handle KV cache transfer requests."""
handshake_port = self.side_channel_port + self.tp_rank
handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size \
+ self.tp_rank
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
logger.info("Starting listening on path: %s", path)
encoder = msgspec.msgpack.Encoder()
@@ -389,6 +398,10 @@
remote_kv_caches_base_addr=kv_transfer_params.get(
"remote_kv_caches_base_addr", None),
metaserver=kv_transfer_params.get("metaserver", None),
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
remote_pcp_rank=kv_transfer_params.get("remote_pcp_rank", 0),
remote_dcp_rank=kv_transfer_params.get("remote_dcp_rank", 0),
)


@@ -497,12 +510,14 @@
logger.info("Initializing Mooncake Scheduler %s", engine_id)

self.side_channel_host = get_ip()
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size

# Handshake base port
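# Each DP rank reserves tensor_parallel_size * pcp_size consecutive ports so that
# every (pcp_rank, tp_rank) worker gets a unique handshake port.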
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size)
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)

# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
@@ -673,6 +688,13 @@
self.tp_group = get_tp_group()
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0

self.total_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)

@@ -685,8 +707,8 @@
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size)
self.handshake_port = self.side_channel_port + self.tp_rank
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
self.sockets: dict = {}
logger.info("Initializing Mooncake work %s", engine_id)
self.engine = global_te.get_transfer_engine(self.side_channel_host,
@@ -720,6 +742,8 @@
deque)
self.remote_poller = zmq.Poller() # type: ignore
self.timeout = 1.0 # seconds
self.local_remote_block_port_mapping = None
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads

def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -830,7 +854,7 @@
ready_event = threading.Event()
self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
self.tp_rank, self.side_channel_port, self.tp_size,
self.pd_head_ratio, self.engine_id, metadata, ready_event)
self.pd_head_ratio, self.engine_id, metadata, ready_event, self.pcp_rank)
self.kv_recv_layer_thread.start()
ready_event.wait()

@@ -860,6 +884,178 @@
"requests: %d", 0, len(done_recving))
return set(), done_recving

def _get_remote_tp_rank(self, req_id: str) -> List[int]:
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]

def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
if self._prefill_tp_size == self._decode_tp_size:
result = list(map(lambda x: [x], range(self._decode_tp_size)))
return result

sampled_nums = []
ori_data = np.arange(self._decode_tp_size)

group_size = self._decode_tp_size // self._prefill_tp_size
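# e.g. prefill_tp_size=2, decode_tp_size=4 -> [[0, 1], [2, 3]]
# (assumes the decode tp size is a multiple of the prefill tp size)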
for i in range(self._prefill_tp_size):
ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
sampled_nums.append(ori_data_slice.tolist())
return sampled_nums

def _get_kv_split_metadata(
self,
req_id: str,
meta: ReqMeta,
) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
"""
In the cp/dcp scenario the kv_cache may be split, so we need to pull multiple blocks from multiple remote P nodes.
This function calculates the remote ports and the remote block ids of each remote P node that we need to pull from.
"""
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
chosen_rank_list = self._get_remote_tp_rank(req_id)
remote_handshake_port_list = [[
x + meta.remote_port for x in chosen_rank_list
]]
local_block_ids_list, remote_block_ids_list = [
meta.local_block_ids
], [meta.remote_block_ids]
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

## TODO: add context_parallel_parameters_check

def get_kv_head_groups(tp_size):
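# Return the kv head ids owned by each tp rank: a single group when using MLA,
# num_key_value_heads // tp_size heads per rank when heads >= tp_size,
# otherwise one group per kv head.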
if self.use_mla:
kv_head_groups = []
kv_head_ids = [0]
kv_head_groups.append(tuple(kv_head_ids))
return kv_head_groups
if self.num_key_value_heads // tp_size >= 1:
kv_head_groups = []
for tp_rank in range(tp_size):
kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
for head_idx in range(self.num_key_value_heads // tp_size)]
kv_head_groups.append(tuple(kv_head_ids))
return kv_head_groups
if tp_size // self.num_key_value_heads > 1:
kv_head_groups = []
for kv_head_id in range(self.num_key_value_heads):
kv_head_groups.append(tuple([kv_head_id]))
return kv_head_groups

def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
# key is a kv_head_group tuple; value holds its cp_groups (lists of handshake ports) and the index of the next cp_group to select
cp_group_meta = {}

kv_head_groups = get_kv_head_groups(tp_size)
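# number of cp groups built per kv head group: the tp ranks that share one
# head group are split into dcp_size-sized chunks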
dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size

for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
if kv_head_group not in cp_group_meta:
cp_group_meta[kv_head_group] = {}
cp_group_meta[kv_head_group]['cp_groups'] = []
cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
kv_head_group_offset = tp_size // len(kv_head_groups) * kv_head_group_idx
for dcp_repeat_idx in range(dcp_repeat_num):
# len(cp_group) == pcp_size * dcp_size
cp_group = []
dcp_repeat_offset = dcp_size * dcp_repeat_idx
for pcp_rank in range(pcp_size):
pcp_rank_offset = tp_size * pcp_rank
for dcp_rank in range(dcp_size):
cp_group.append(dcp_rank + port_base + pcp_rank_offset +
dcp_repeat_offset + kv_head_group_offset)
cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)

return cp_group_meta

def get_local_remote_block_port_mapping():
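# Pair each P-node cp group with a D-node cp group whose kv heads are a subset of
# its own, rotating through the D-node cp groups, then keep only the mapping entry
# for this worker's handshake port.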
p_node_cp_group_meta = get_cp_group_meta(self.tp_size, self.pcp_size,
self.dcp_size, self.side_channel_port)
d_node_cp_group_meta = get_cp_group_meta(self._decode_tp_size, meta.remote_pcp_size,
meta.remote_dcp_size, meta.remote_port)
local_remote_block_port_mappings = {}
for p_node_head_key in p_node_cp_group_meta.keys():
for d_node_head_key in d_node_cp_group_meta.keys():
if not set(d_node_head_key).issubset(set(p_node_head_key)):
continue
d_node_head_group = d_node_cp_group_meta[d_node_head_key]
p_node_head_group = p_node_cp_group_meta[p_node_head_key]
for p_cp_group in p_node_head_group['cp_groups']:

select_cp_groups_id = d_node_head_group['select_cp_groups_id']
d_cp_groups = d_node_head_group['cp_groups']
d_cp_group = d_cp_groups[select_cp_groups_id]
d_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
if select_cp_groups_id + 1 < len(d_cp_groups) else 0
for p_idx, p_port in enumerate(p_cp_group):
if p_port not in local_remote_block_port_mappings:
local_remote_block_port_mappings[p_port] = []
d_port_remote_list = []
for d_idx, d_port in enumerate(d_cp_group):
if p_idx % len(d_cp_group) == d_idx:
d_port_remote_list.append(d_port)
local_remote_block_port_mappings[p_port].append(d_port_remote_list)
local_remote_block_port_mapping = local_remote_block_port_mappings[self.handshake_port]

logger.info(
"p_node_cp_group_meta is: %s. d_node_cp_group_meta is: %s. "
"local_remote_block_port_mappings is: %s.", p_node_cp_group_meta,
d_node_cp_group_meta, local_remote_block_port_mappings)

return local_remote_block_port_mapping

if self.local_remote_block_port_mapping is None:
self.local_remote_block_port_mapping = get_local_remote_block_port_mapping()


local_cp_size = self.pcp_size * self.dcp_size
prompt_block_nums = len(meta.local_block_ids)
local_block_nums_all = [prompt_block_nums // local_cp_size] * local_cp_size
num_remain_blocks = prompt_block_nums % local_cp_size
for i in range(num_remain_blocks):
local_block_nums_all[i] += 1
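# reuse num_remain_blocks as the index of the cp rank that ends up holding the final
# (possibly partial) block, assuming blocks are distributed round-robin across cp ranks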
num_remain_blocks = (num_remain_blocks + local_cp_size - 1) % local_cp_size

# make sure the last block (which may not be full) of the P node is put into the last block of the D node
local_block_nums = []
final_block_idx = None

remote_dcp_rank = meta.remote_dcp_rank
remote_pcp_rank = meta.remote_pcp_rank
remote_cp_rank = remote_dcp_rank + remote_pcp_rank * meta.remote_dcp_size
remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
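# remote_cp_rank flattens the remote (pcp_rank, dcp_rank) pair into a single index;
# only the local cp splits assigned to this remote cp rank are collected below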
for cp_rank, block_num in enumerate(local_block_nums_all):
if cp_rank % remote_cp_size == remote_cp_rank:
if num_remain_blocks == cp_rank:
final_block_idx = len(local_block_nums)
local_block_nums.append(block_num)

if final_block_idx is not None:

final_block_num = local_block_nums.pop(final_block_idx)
local_block_nums.append(final_block_num)
final_block_mapping = self.local_remote_block_port_mapping.pop(final_block_idx)
self.local_remote_block_port_mapping.append(final_block_mapping)

remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
for idx in range(len(self.local_remote_block_port_mapping[0])):
mapping_list = []
for mapping in self.local_remote_block_port_mapping:
mapping_list.append(mapping[idx])
remote_handshake_port_list.append(mapping_list)

# local_block_ids_list and remote_block_ids_list correspond one-to-one with remote_handshake_port_list,
# e.g.: local_block_ids_list [[1],[2],[5],[6]], remote_block_ids_list [[1],[1],[1],[1]],
# remote_handshake_port_list [[30000],[30001],[30004],[30005]]
# means the D rank will get remote block 1 from port 30004 and save it in local block 5

remote_block_offset = 0
for local_kv_id in range(len(remote_handshake_port_list)):

num_blocks_to_push = local_block_nums[local_kv_id]
local_block_ids_list.append(
meta.local_block_ids[:num_blocks_to_push])

remote_block_ids_list.append(
meta.remote_block_ids[remote_block_offset:remote_block_offset+

num_blocks_to_push])
remote_block_offset += num_blocks_to_push
Comment on lines +1047 to +1055
Contributor


critical

The logic for splitting meta.local_block_ids is incorrect. On every loop iteration it slices from the beginning of the list again (meta.local_block_ids[:num_blocks_to_push]), so every split ends up with the same initial set of local blocks, which leads to incorrect data transfer. An offset should be used to slice the correct portion of meta.local_block_ids on each iteration, similar to how remote_block_offset is used for meta.remote_block_ids.

        remote_block_offset = 0
        local_block_offset = 0
        for local_kv_id in range(len(remote_handshake_port_list)):
            num_blocks_to_push = local_block_nums[local_kv_id]
            local_block_ids_list.append(
                meta.local_block_ids[local_block_offset:local_block_offset +
                                     num_blocks_to_push])
            remote_block_ids_list.append(
                meta.remote_block_ids[remote_block_offset:remote_block_offset +
                                      num_blocks_to_push])
            remote_block_offset += num_blocks_to_push
            local_block_offset += num_blocks_to_push
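
To make the difference concrete, here is a minimal sketch with hypothetical block ids (not taken from this PR) comparing the slicing in the current diff with the offset-based slicing suggested above:

# Hypothetical example: three splits of sizes [2, 2, 1] over local block ids [10, 11, 12, 13, 14].
local_block_ids = [10, 11, 12, 13, 14]
local_block_nums = [2, 2, 1]

# slicing from the start every time (current diff) -> [[10, 11], [10, 11], [10]]
naive = [local_block_ids[:n] for n in local_block_nums]

# offset-based slicing (suggested fix) -> [[10, 11], [12, 13], [14]]
fixed, offset = [], 0
for n in local_block_nums:
    fixed.append(local_block_ids[offset:offset + n])
    offset += n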


return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
"""Start loading KV blocks from remote engine."""
self.current_layer = 0
Expand All @@ -880,6 +1076,8 @@
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
remote_pcp_rank=self.pcp_rank,
remote_dcp_rank=self.dcp_rank,
)
future = self.executor.submit(
self._access_metaserver,
@@ -964,8 +1162,14 @@
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
)
assert self.kv_send_layer_thread is not None
self.kv_send_layer_thread.send_queue.put(
(req_id, req_meta_update, self.current_layer, key, value))
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
req_id, req_meta_update)
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
req_meta_for_queue = copy.copy(req_meta_update)
req_meta_for_queue.local_block_ids = local_block_ids_list[pcp_dcp_rank]
req_meta_for_queue.remote_block_ids = remote_block_ids_list[pcp_dcp_rank]
self.kv_send_layer_thread.send_queue.put(
(req_id, req_meta_for_queue, self.current_layer, key, value))
self.current_layer += 1

def _get_remote_socket(