Skip to content

Commit 19f49ec

Browse files
[0.11.0][Bugfix]fix_mulit_connector_bug (#3332) (#3882)
### What this PR does / why we need it? When using multi connector, the multi connector does not define get_finished_count, which will cause the kv cache to be released ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: vllm-project/vllm@83f478b Signed-off-by: baxingpiaochong <[email protected]> Co-authored-by: baxingpiaochong <[email protected]>
1 parent e5b938c commit 19f49ec

File tree

2 files changed

+9
-41
lines changed

2 files changed

+9
-41
lines changed

tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -667,10 +667,6 @@ def test_build_connector_meta(self):
667667
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
668668
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
669669

670-
def test_get_finished_count(self):
671-
count = self.scheduler.get_finished_count()
672-
self.assertEqual(count, 2)
673-
674670

675671
class TestHelperFunctions(unittest.TestCase):
676672

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def __init__(self):
7474
self.record_finished_requests: set[str] = set()
7575
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
7676

77+
def add_not_transfer_request(self, request_id: str):
78+
with self.done_task_lock:
79+
self.finished_requests.add(request_id)
80+
7781
def update_done_task_count(self, request_id: str):
7882
with self.done_task_lock:
7983
self.finished_requests.add(request_id)
@@ -151,6 +155,9 @@ def get_and_clear_finished_requests(self) -> set[str]:
151155
"""
152156
return self.task_tracker.get_and_clear_finished_requests()
153157

158+
def add_not_transfer_request(self, request_id: str):
159+
self.task_tracker.add_not_transfer_request(request_id)
160+
154161
def add_delayed_request(self, request_id: str, delay_start_time: float):
155162
return self.task_tracker.add_delayed_request(request_id,
156163
delay_start_time)
@@ -652,10 +659,6 @@ def request_finished(
652659
assert self.connector_scheduler is not None
653660
return self.connector_scheduler.request_finished(request, block_ids)
654661

655-
def get_finished_count(self) -> Optional[int]:
656-
assert self.connector_scheduler is not None
657-
return self.connector_scheduler.get_finished_count()
658-
659662
############################################################
660663
# Worker Side Methods
661664
############################################################
@@ -840,39 +843,6 @@ def request_finished(
840843
last_token_id=request.output_token_ids[-1],
841844
)
842845

843-
def get_finished_count(self) -> Optional[int]:
844-
prefill_parallel_config: dict[
845-
str,
846-
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
847-
"prefill", {})
848-
849-
assert "tp_size" in prefill_parallel_config.keys()
850-
self._prefill_tp_size = prefill_parallel_config["tp_size"]
851-
decode_parallel_config: dict[
852-
str,
853-
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
854-
"decode", {})
855-
assert "tp_size" in decode_parallel_config.keys()
856-
self._decode_tp_size = decode_parallel_config["tp_size"]
857-
num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
858-
if self.vllm_config.model_config.use_mla or hasattr(
859-
self.vllm_config.model_config.hf_config, "index_topk"):
860-
num_need_pulls = 1
861-
else:
862-
num_p_block_heads = max(
863-
1, num_key_value_heads // self._prefill_tp_size)
864-
num_d_block_heads = max(
865-
1, num_key_value_heads // self._decode_tp_size)
866-
num_need_pulls = num_d_block_heads // num_p_block_heads
867-
kv_role = self.vllm_config.kv_transfer_config.kv_role
868-
logger.debug(
869-
"get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
870-
kv_role, num_need_pulls, self._decode_tp_size)
871-
if kv_role == 'kv_producer':
872-
return num_need_pulls * self._decode_tp_size
873-
else:
874-
return self._decode_tp_size
875-
876846

877847
class MooncakeConnectorWorker:
878848
"""Implementation of Worker side methods"""
@@ -1144,6 +1114,8 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
11441114
if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
11451115
self.kv_send_thread.add_delayed_request(
11461116
req_id, delay_start_time)
1117+
else:
1118+
self.kv_send_thread.add_not_transfer_request(req_id)
11471119

11481120
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
11491121
return sum(self._get_remote_tp_ranks_for_req(req_id), [])

0 commit comments

Comments
 (0)