@@ -74,6 +74,10 @@ def __init__(self):
7474 self .record_finished_requests : set [str ] = set ()
7575 self .delayed_free_requests : OrderedDict [str , float ] = OrderedDict ()
7676
def add_not_transfer_request(self, request_id: str):
    """Record ``request_id`` in ``self.finished_requests``.

    The id is added while holding ``self.done_task_lock``, so the update
    is serialized with any other code that takes the same lock.

    Args:
        request_id: Identifier of the request to record as finished.
    """
    # NOTE(review): the method name suggests this is for requests that need
    # no KV transfer from this rank -- confirm against the caller.
    with self.done_task_lock:
        # `.add` implies a set-like container; adding an id twice is a no-op
        # for a real set.
        self.finished_requests.add(request_id)
80+
7781 def update_done_task_count (self , request_id : str ):
7882 with self .done_task_lock :
7983 self .finished_requests .add (request_id )
@@ -151,6 +155,9 @@ def get_and_clear_finished_requests(self) -> set[str]:
151155 """
152156 return self .task_tracker .get_and_clear_finished_requests ()
153157
def add_not_transfer_request(self, request_id: str):
    """Forward ``request_id`` to ``self.task_tracker.add_not_transfer_request``.

    Args:
        request_id: Identifier of the request handed to the task tracker.
    """
    # Pure delegation: the tracker's return value is intentionally not
    # propagated, so this method always returns None.
    self.task_tracker.add_not_transfer_request(request_id)
160+
def add_delayed_request(self, request_id: str, delay_start_time: float):
    """Forward a delayed request to ``self.task_tracker``.

    Args:
        request_id: Identifier of the request.
        delay_start_time: Start time passed through unchanged to the tracker.

    Returns:
        Whatever ``self.task_tracker.add_delayed_request`` returns.
    """
    return self.task_tracker.add_delayed_request(request_id, delay_start_time)
@@ -652,10 +659,6 @@ def request_finished(
652659 assert self .connector_scheduler is not None
653660 return self .connector_scheduler .request_finished (request , block_ids )
654661
def get_finished_count(self) -> Optional[int]:
    """Return ``self.connector_scheduler.get_finished_count()``.

    Returns:
        The value produced by the scheduler-side implementation.

    Raises:
        AssertionError: If ``self.connector_scheduler`` is None.
    """
    # This wrapper is only usable when a scheduler-side connector exists.
    assert self.connector_scheduler is not None
    return self.connector_scheduler.get_finished_count()
658-
659662 ############################################################
660663 # Worker Side Methods
661664 ############################################################
@@ -840,39 +843,6 @@ def request_finished(
840843 last_token_id = request .output_token_ids [- 1 ],
841844 )
842845
def get_finished_count(self) -> Optional[int]:
    """Compute the finished count from the prefill/decode ``tp_size`` values.

    Reads ``tp_size`` from the ``"prefill"`` and ``"decode"`` entries of the
    KV-transfer extra config and stores them on ``self._prefill_tp_size``
    and ``self._decode_tp_size``.

    ``num_need_pulls`` is 1 when the model config has ``use_mla`` set or the
    HF config has an ``index_topk`` attribute. Otherwise it is the ratio of
    per-rank KV heads on the decode side to per-rank KV heads on the prefill
    side (each clamped to at least 1).

    Returns:
        ``num_need_pulls * decode_tp_size`` when ``kv_role`` is
        ``'kv_producer'``; ``decode_tp_size`` for any other role.

    Raises:
        AssertionError: If either config entry has no ``"tp_size"`` key.
    """
    kv_transfer_config = self.vllm_config.kv_transfer_config
    model_config = self.vllm_config.model_config

    # Prefill side is validated and stored first, so a failure on the decode
    # entry still leaves _prefill_tp_size updated (same order as before).
    prefill_parallel_config: dict[str, Any] = (
        kv_transfer_config.get_from_extra_config("prefill", {}))
    assert "tp_size" in prefill_parallel_config
    self._prefill_tp_size = prefill_parallel_config["tp_size"]

    decode_parallel_config: dict[str, Any] = (
        kv_transfer_config.get_from_extra_config("decode", {}))
    assert "tp_size" in decode_parallel_config
    self._decode_tp_size = decode_parallel_config["tp_size"]

    # Read unconditionally, including on the MLA / index_topk path.
    num_key_value_heads = model_config.hf_config.num_key_value_heads
    if model_config.use_mla or hasattr(model_config.hf_config, "index_topk"):
        num_need_pulls = 1
    else:
        num_p_block_heads = max(
            1, num_key_value_heads // self._prefill_tp_size)
        num_d_block_heads = max(
            1, num_key_value_heads // self._decode_tp_size)
        # NOTE(review): floor division yields 0 when the decode side holds
        # fewer heads per rank than the prefill side, which makes a producer
        # return 0 -- confirm that layout is excluded upstream.
        num_need_pulls = num_d_block_heads // num_p_block_heads

    kv_role = kv_transfer_config.kv_role
    logger.debug(
        "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
        kv_role, num_need_pulls, self._decode_tp_size)

    if kv_role == 'kv_producer':
        return num_need_pulls * self._decode_tp_size
    return self._decode_tp_size
875-
876846
877847class MooncakeConnectorWorker :
878848 """Implementation of Worker side methods"""
@@ -1144,6 +1114,8 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
11441114 if self .tp_rank in self ._prefill_get_remote_tp_rank (req_id ):
11451115 self .kv_send_thread .add_delayed_request (
11461116 req_id , delay_start_time )
1117+ else :
1118+ self .kv_send_thread .add_not_transfer_request (req_id )
11471119
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
    """Return the flattened remote TP ranks for ``req_id``.

    ``self._get_remote_tp_ranks_for_req(req_id)`` yields a sequence of rank
    groups; this concatenates them, in order, into one new list.

    Args:
        req_id: Identifier of the request whose remote ranks are wanted.

    Returns:
        A single flat list containing every rank from every group.
    """
    # Function-scope import keeps the block self-contained.
    from itertools import chain

    # chain.from_iterable flattens in one linear pass; the previous
    # sum(groups, []) re-copied the accumulated list for every group.
    return list(chain.from_iterable(self._get_remote_tp_ranks_for_req(req_id)))
0 commit comments