
Commit 4dbe4fd

[feature] Pooling Features and PCP Adaptation (#4143)

This PR lets the pooling KV connector support the PCP (prefill context parallel) feature.

vLLM version: v0.11.2

Signed-off-by: fjw <[email protected]>
Signed-off-by: SlightwindSec <[email protected]>
Co-authored-by: SlightwindSec <[email protected]>

1 parent 1eb5295 commit 4dbe4fd

File tree

5 files changed, +89 -29 lines

vllm_ascend/distributed/kvpool/ascend_store_connector.py

Lines changed: 0 additions & 2 deletions
@@ -43,8 +43,6 @@ def __init__(self,

         self.kv_caches: dict[str, torch.Tensor] = {}

-        self._block_size = vllm_config.cache_config.block_size
-
         self.sended_but_unfinished_reqs: set[str] = set()

         if role == KVConnectorRole.SCHEDULER:

vllm_ascend/distributed/kvpool/config_data.py

Lines changed: 10 additions & 0 deletions
@@ -17,6 +17,10 @@ class KeyMetadata:
     model_name: str
     """ worker id when running under a distributed setting """
     head_or_tp_rank: int
+    """ Initialize the current prefill context model parallel rank """
+    pcp_rank: int
+    """ Initialize the current decode context model parallel rank """
+    dcp_rank: int


 @dataclass(order=True)
@@ -28,12 +32,15 @@ def __hash__(self):
         return hash((
             self.key_metadata.model_name,
             self.key_metadata.head_or_tp_rank,
+            self.key_metadata.pcp_rank,
+            self.key_metadata.dcp_rank,
             self.chunk_hash,
         ))

     def to_string(self):
         return (
             f"{self.key_metadata.model_name}"
+            f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
             f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
         )

@@ -60,13 +67,16 @@ def __hash__(self):
         return hash((
             self.key_metadata.model_name,
             self.key_metadata.head_or_tp_rank,
+            self.key_metadata.pcp_rank,
+            self.key_metadata.dcp_rank,
             self.chunk_hash,
             self.layer_id,
         ))

     def to_string(self):
         return (
             f"{self.key_metadata.model_name}"
+            f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
             f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
         )
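
Note: for reference, a minimal, self-contained sketch of how the extra ranks change key identity. KeyMetadataSketch, key_string, and the "Qwen3-8B" model name are illustrative stand-ins; only the field names and the to_string() format follow the diff above.

from dataclasses import dataclass


@dataclass(order=True)
class KeyMetadataSketch:
    # illustrative stand-in for KeyMetadata after this change
    model_name: str
    head_or_tp_rank: int
    pcp_rank: int  # prefill context parallel rank
    dcp_rank: int  # decode context parallel rank


def key_string(meta: KeyMetadataSketch, chunk_hash: str) -> str:
    # mirrors the to_string() format in the diff: the pcp/dcp ranks are now
    # part of the key, so ranks holding different KV shards never collide
    return (f"{meta.model_name}"
            f"@pcp{meta.pcp_rank}@dcp{meta.dcp_rank}"
            f"@head_or_tp_rank:{meta.head_or_tp_rank}@{chunk_hash}")


# the same chunk hash yields distinct pool keys on different pcp ranks
a = key_string(KeyMetadataSketch("Qwen3-8B", 0, 0, 0), "abc123")
b = key_string(KeyMetadataSketch("Qwen3-8B", 0, 1, 0), "abc123")
assert a != b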

vllm_ascend/distributed/kvpool/kv_transfer.py

Lines changed: 35 additions & 16 deletions
@@ -19,11 +19,13 @@
 class KVTransferThread(threading.Thread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, ready_event: threading.Event, name: str):
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
+                 name: str):
         super().__init__(daemon=True, name=name)
         self.m_store = m_store
         self.ready_event = ready_event
         self.tp_rank = tp_rank
+        self.dcp_size = dcp_size
         self.token_database = token_database
         self.done_task_lock = threading.Lock()
         self.request_queue: queue.Queue[Any] = queue.Queue()
@@ -87,10 +89,12 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreSendingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, put_step: int, ready_event: threading.Event):
+                 tp_rank: int, dcp_size: int, put_step: int,
+                 ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheSendingThread")
         self.put_step = put_step
@@ -112,12 +116,18 @@ def _handle_request(self, req_meta: dict[str, Any]):
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-        addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
-        size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
-        if key_list_tp:
+        if self.dcp_size > 1:
             torch.npu.current_stream().synchronize()
-            self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
+            self.m_store.put(key_list, addr_list, size_list)
+        else:
+            key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
+            addr_list_tp = addr_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            size_list_tp = size_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            if key_list_tp:
+                torch.npu.current_stream().synchronize()
+                self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
         if is_last_chunk:
             self.set_finished_request(req_id)
         self.request_queue.task_done()
@@ -126,10 +136,11 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreRecvingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, ready_event: threading.Event):
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheStoreRecvingThread")

@@ -166,11 +177,12 @@ def _handle_request(self, req_meta: dict[str, Any]):
 class KVCacheStoreLayerSendingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, put_step: int, ready_event: threading.Event,
-                 num_layers: int):
+                 tp_rank: int, dcp_size: int, put_step: int,
+                 ready_event: threading.Event, num_layers: int):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheStoreLayerSendingThread")
         self.final_layer_id = num_layers - 1
@@ -192,12 +204,18 @@ def _handle_request(  # type: ignore[override]
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
-        addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
-        size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
-        if key_list_tp:
+        if self.dcp_size > 1:
             torch.npu.current_stream().synchronize()
-            self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
+            self.m_store.put(key_list, addr_list, size_list)
+        else:
+            key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
+            addr_list_tp = addr_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            size_list_tp = size_list[self.tp_rank %
+                                     self.put_step::self.put_step]
+            if key_list_tp:
+                torch.npu.current_stream().synchronize()
+                self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
         if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
             self.set_finished_request(req_meta.req_id)
         self.request_queue.task_done()
@@ -206,11 +224,12 @@ def _handle_request(  # type: ignore[override]
 class KVCacheStoreLayerRecvingThread(KVTransferThread):

     def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 tp_rank: int, ready_event: threading.Event,
+                 tp_rank: int, dcp_size: int, ready_event: threading.Event,
                  get_event: threading.Event):
         super().__init__(m_store,
                          token_database,
                          tp_rank,
+                         dcp_size,
                          ready_event,
                          name="KVCacheStoreLayerRecvingThread")
         self.get_event = get_event
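
Note: the sending-thread change above selects between two put strategies. Below is a rough, dependency-free sketch of that selection; put_kv_chunks and store_put are illustrative stand-ins (the real code calls self.m_store.put and synchronizes the NPU stream first, which is omitted here), while the tp_rank/put_step striping and the dcp_size branch follow the diff.

from typing import Callable, Sequence


def put_kv_chunks(keys: Sequence[str], addrs: Sequence[int],
                  sizes: Sequence[int], tp_rank: int, dcp_size: int,
                  put_step: int,
                  store_put: Callable[[Sequence[str], Sequence[int],
                                       Sequence[int]], None]) -> None:
    """Sketch of the put path in the sending threads after this change."""
    if dcp_size > 1:
        # with decode context parallel, each rank's keys already carry its
        # dcp rank, so every rank writes its full key list to the pool
        store_put(keys, addrs, sizes)
    else:
        # without DCP, keys are striped across ranks with stride put_step,
        # so each chunk is written by exactly one rank
        keys_tp = keys[tp_rank % put_step::put_step]
        addrs_tp = addrs[tp_rank % put_step::put_step]
        sizes_tp = sizes[tp_rank % put_step::put_step]
        if keys_tp:
            store_put(keys_tp, addrs_tp, sizes_tp)


# e.g. rank 1 of a 2-way striping writes keys[1::2]
written: list = []
put_kv_chunks(["k0", "k1", "k2", "k3"], [0, 1, 2, 3], [8, 8, 8, 8],
              tp_rank=1, dcp_size=1, put_step=2,
              store_put=lambda k, a, s: written.extend(k))
assert written == ["k1", "k3"]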

vllm_ascend/distributed/kvpool/pool_scheduler.py

Lines changed: 7 additions & 0 deletions
@@ -29,7 +29,14 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise):
             "load_async", False)
         # request_id -> (vllm cached tokes, kvpool cached tokens)
         self.load_specs: dict[str, LoadSpec] = {}
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+
         self._block_size = vllm_config.cache_config.block_size
+        if self.pcp_size > 1:
+            self._block_size *= self.pcp_size
+        if self.dcp_size > 1:
+            self._block_size *= self.dcp_size
         # request_id -> full_token_ids
         self._request_trackers: dict[str, RequestTracker] = {}
         # Whether to discard partial chunks
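
Note: the two multiplications above mean the scheduler hashes token chunks at the granularity of a whole PCP x DCP group rather than a single rank's block. A small worked example with assumed sizes (128, 2 and 2 are illustrative values, not defaults from this PR):

# illustrative values, not defaults from this PR
block_size = 128  # vllm_config.cache_config.block_size
pcp_size = 2      # prefill context parallel world size
dcp_size = 2      # decode context parallel world size

effective_block_size = block_size
if pcp_size > 1:
    effective_block_size *= pcp_size
if dcp_size > 1:
    effective_block_size *= dcp_size

assert effective_block_size == 512  # tokens covered by one pool chunk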

vllm_ascend/distributed/kvpool/pool_worker.py

Lines changed: 37 additions & 11 deletions
@@ -1,11 +1,13 @@
-# Standard
 import math
 import threading
 from typing import Dict, Generator, Optional, Type

-# Third Party
 import torch
 from vllm.config import VllmConfig
+from vllm.distributed import (get_decode_context_model_parallel_rank,
+                              get_decode_context_model_parallel_world_size,
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.utils import logger
 from vllm.v1.core.kv_cache_utils import BlockHash

@@ -20,6 +22,14 @@
 from vllm_ascend.distributed.kvpool.kv_transfer import (
     KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
     KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
+from vllm_ascend.utils import prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+    # isort: off
+    from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                                  get_prefill_context_model_parallel_world_size
+                                  )
+    # isort: on

 backend_map: Dict[str, Type[Backend]] = {
     "mooncake": MooncakeBackend,
@@ -44,17 +54,30 @@ def __init__(
                 and model_config.use_mla):
             self.use_mla = True
         self.use_layerwise = use_layerwize
-        self.tp_rank = parallel_config.rank
-        self.tp_size = parallel_config.tensor_parallel_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "load_async", False)
         self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "backend", "mooncake")
         self.block_size = vllm_config.cache_config.block_size
+
+        if self.pcp_size > 1:
+            self.block_size *= self.pcp_size
+        if self.dcp_size > 1:
+            self.block_size *= self.dcp_size
         self.current_layer = 0
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.block_size = vllm_config.cache_config.block_size

         if self.use_mla:
             self.num_kv_head = 1
@@ -69,8 +92,10 @@ def __init__(
             self.put_step = 1

         self.metadata = KeyMetadata(
-            model_config.model,
+            model_config.model.split('/')[-1],
            self.head_or_tp_rank,
+            self.pcp_rank,
+            self.dcp_rank,
         )

         self.token_database = ChunkedTokenDatabase(self.metadata,
@@ -147,26 +172,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreLayerSendingThread(
                     self.m_store, self.token_database, self.tp_rank,
-                    self.put_step, ready_event_sending, self.num_layers)
+                    self.dcp_size, self.put_step, ready_event_sending,
+                    self.num_layers)
                 self.kv_send_thread.start()
             ready_event = threading.Event()
             self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
-                self.m_store, self.token_database, self.tp_rank, ready_event,
-                self.get_event)
+                self.m_store, self.token_database, self.tp_rank, self.dcp_size,
+                ready_event, self.get_event)
             self.kv_recv_thread.start()
             ready_event.wait()
         else:
             if self.kv_role in ['kv_producer', 'kv_both']:
                 ready_event_sending = threading.Event()
                 self.kv_send_thread = KVCacheStoreSendingThread(
                     self.m_store, self.token_database, self.tp_rank,
-                    self.put_step, ready_event_sending)
+                    self.dcp_size, self.put_step, ready_event_sending)
                 self.kv_send_thread.start()
             if self.load_async:
                 ready_event = threading.Event()
                 self.kv_recv_thread = KVCacheStoreRecvingThread(
                     self.m_store, self.token_database, self.tp_rank,
-                    ready_event)
+                    self.dcp_size, ready_event)
                 self.kv_recv_thread.start()
                 ready_event.wait()
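
Note: the constructor changes above reduce to a small rank-resolution rule: query a parallel group only when it is available, otherwise fall back to a world size of 1 and rank 0. A dependency-free sketch under that assumption (resolve_cp_ranks is an illustrative helper, not part of the PR):

from typing import Callable, Tuple


def resolve_cp_ranks(
        pcp_enabled: bool, get_pcp_world_size: Callable[[], int],
        get_pcp_rank: Callable[[], int], get_dcp_world_size: Callable[[], int],
        get_dcp_rank: Callable[[], int]) -> Tuple[int, int, int, int]:
    """Mirrors the pcp/dcp resolution in PoolWorker.__init__ (sketch)."""
    # PCP: only query the group when the feature is enabled in this build
    pcp_size = get_pcp_world_size() if pcp_enabled else 1
    pcp_rank = get_pcp_rank() if pcp_size > 1 else 0
    # DCP: the world-size getter always exists; rank defaults to 0 when unused
    dcp_size = get_dcp_world_size()
    dcp_rank = get_dcp_rank() if dcp_size > 1 else 0
    return pcp_size, pcp_rank, dcp_size, dcp_rank


# with PCP disabled and a DCP group of size 1, everything collapses to rank 0
assert resolve_cp_ranks(False, lambda: 4, lambda: 2, lambda: 1,
                        lambda: 0) == (1, 0, 1, 0)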
