4040from vllm_ascend .distributed .mooncake .transfer_engine import get_global_te
4141from vllm_ascend .distributed .utils import get_transfer_timeout_value
4242from vllm_ascend .utils import prefill_context_parallel_enable , vllm_version_is
43- from typing import List
4443
4544# isort: off
4645if prefill_context_parallel_enable ():
@@ -1289,21 +1288,22 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12891288 self .kv_recv_thread .add_request (
12901289 request_id = req_id ,
12911290 local_block_ids = local_block_ids_list [pcp_dcp_rank ],
1292- remote_block_ids = remote_block_ids_list [pcp_dcp_rank ],
1291+ remote_block_ids = remote_block_ids_list [
1292+ pcp_dcp_rank ],
12931293 remote_engine_id = meta .remote_engine_id ,
12941294 remote_host = meta .remote_host ,
12951295 remote_handshake_port = remote_handshake_port_list [
12961296 pcp_dcp_rank ][i ],
12971297 offset = i ,
12981298 tp_num_need_pulls = self .tp_num_need_pulls ,
1299- all_task_done = (pcp_dcp_rank
1300- == len (remote_handshake_port_list ) - 1
1301- and i == self .tp_num_need_pulls - 1 ))
1302- else : #TODO: support prefill context parallel and pipeline parallel open at the same time
1299+ all_task_done = (
1300+ pcp_dcp_rank
1301+ == len (remote_handshake_port_list ) - 1
1302+ and i == self .tp_num_need_pulls - 1 ))
1303+ else : #TODO: support prefill context parallel and pipeline parallel open at the same time
13031304 choosen_rank_list = self ._get_remote_tp_rank (req_id )
1304- remote_handshake_port_list = [
1305- [x + meta .remote_port ] for x in choosen_rank_list
1306- ]
1305+ remote_handshake_port_list = [[x + meta .remote_port ]
1306+ for x in choosen_rank_list ]
13071307 for i in range (self .tp_num_need_pulls * self ._prefill_pp_size ):
13081308 assert self .kv_recv_thread is not None
13091309 self .kv_recv_thread .add_request (
@@ -1315,7 +1315,8 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
13151315 remote_handshake_port = remote_handshake_port_list [i ][0 ],
13161316 offset = i ,
13171317 tp_num_need_pulls = self .tp_num_need_pulls ,
1318- all_task_done = (i == self .tp_num_need_pulls * self ._prefill_pp_size - 1 ))
1318+ all_task_done = (i == self .tp_num_need_pulls *
1319+ self ._prefill_pp_size - 1 ))
13191320
13201321 if self .kv_send_thread is not None :
13211322 for req_id , delay_start_time in metadata .requests_to_send .items ():
0 commit comments