Commit 26e8e58

[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#4176)
### What this PR does / why we need it?

Support Encoder separation for Encode-Prefill-Decode Disaggregation.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: amy-why-3459 <[email protected]>
1 parent 6ece666 commit 26e8e58

File tree: 4 files changed (+73 additions, -11 deletions)

vllm_ascend/patch/platform/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_ec_connector  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
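A note on how this one-line change takes effect: the patch_* modules under vllm_ascend.patch apply themselves at import time, so adding patch_ec_connector to the platform __init__ is all that is needed for the replacement class defined in the new file below to be installed. A generic, self-contained sketch of that import-time patching pattern (the "upstream" module here is a stand-in, not real vLLM code):

    import types

    # Stand-in for an upstream module such as
    # vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.
    upstream = types.ModuleType("upstream")

    class Connector:                        # what upstream would define
        def load(self) -> str:
            return "generic"

    upstream.Connector = Connector

    class AscendConnector(Connector):       # the platform-specific variant
        def load(self) -> str:
            return "npu-aware"

    # The actual patch step: later lookups of upstream.Connector resolve to
    # the Ascend variant. Code that grabbed a reference before the patch ran
    # keeps the old class, which is why the patch is imported during
    # platform initialization rather than lazily.
    upstream.Connector = AscendConnector

    print(upstream.Connector().load())      # -> "npu-aware"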

vllm_ascend/patch/platform/patch_ec_connector.py (new file)

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
+from safetensors.torch import load_file
+from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
+from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
+    ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
+from vllm.logger import logger
+
+
+class AscendECSharedStorageConnector(ECSharedStorageConnector):
+
+    def start_load_caches(self, encoder_cache, **kwargs) -> None:
+        metadata: ECConnectorMetadata = self._get_connector_metadata()
+        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
+        assert encoder_cache is not None
+        if metadata is None:
+            logger.warning((
+                "In connector.start_load_caches, ",
+                "but the connector metadata is None",
+            ))
+            return
+        # Load the EC for each mm data
+        for mm_data in metadata.mm_datas:
+            if mm_data.mm_hash in encoder_cache:
+                continue
+            filename = self._generate_filename_debug(mm_data.mm_hash)
+            ec_cache = load_file(filename)["ec_cache"].npu()
+            encoder_cache[mm_data.mm_hash] = ec_cache
+            logger.debug("Success load encoder cache for hash %s",
+                         mm_data.mm_hash)
+
+
+vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
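The patched connector reads each encoder cache back from shared storage with safetensors and moves it onto the NPU before handing it to the model runner; reassigning ECSharedStorageConnector at the bottom makes the rest of vLLM pick up the Ascend variant transparently. A minimal sketch of the file layout this load path assumes, a single tensor stored under the "ec_cache" key, together with a hypothetical producer-side save helper (the real save logic lives in upstream vLLM's shared-storage connector, not in this patch):

    import torch
    from safetensors.torch import load_file, save_file

    def save_encoder_cache(filename: str, ec_cache: torch.Tensor) -> None:
        # Hypothetical producer-side counterpart: safetensors stores a dict
        # of named tensors, and the load path above expects the "ec_cache"
        # key. Move to CPU and make contiguous before serializing.
        save_file({"ec_cache": ec_cache.contiguous().cpu()}, filename)

    def load_encoder_cache(filename: str) -> torch.Tensor:
        # Mirror of the load in AscendECSharedStorageConnector, without the
        # trailing .npu() so the sketch also runs on hosts without an NPU.
        return load_file(filename)["ec_cache"]

    # Round trip with a dummy encoder output.
    dummy = torch.randn(16, 1024)
    save_encoder_cache("/tmp/ec_demo.safetensors", dummy)
    assert torch.equal(dummy, load_encoder_cache("/tmp/ec_demo.safetensors"))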

vllm_ascend/worker/model_runner_v1.py

Lines changed: 38 additions & 11 deletions
@@ -47,6 +47,7 @@
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -91,13 +92,16 @@
                                      UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+                             PoolerOutput,
+                             make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.ec_connector_model_runner_mixin import \
+    ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
@@ -268,7 +272,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor
 
 
-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -791,6 +795,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
 
             req_ids_to_add.append(req_id)
 
+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
@@ -1072,6 +1081,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
                 output,
                 is_embed=pos_info.is_embed,
             )
+            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
 
     def _batch_mm_kwargs_from_scheduler(
         self,
@@ -1597,15 +1607,19 @@ def _prepare_inputs(
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
-
-            # NOTE(woosuk): To unify token ids and soft tokens (vision
-            # embeddings), we always use embeddings (rather than token ids)
-            # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:total_num_scheduled_tokens]
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
-                scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+            ):
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)
+
+                # NOTE(woosuk): To unify token ids and soft tokens (vision
+                # embeddings), we always use embeddings (rather than token ids)
+                # as input to the multimodal model, even when the input is text.
+                input_ids = self.input_ids[:total_num_scheduled_tokens]
+                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
+                    scheduler_output)
 
             inputs_embeds = self.model.embed_input_ids(
                 input_ids,
@@ -2248,6 +2262,15 @@ def execute_model(
 
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                        scheduler_output,
+                        encoder_cache=self.encoder_cache,
+                ):
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(
+                    scheduler_output)
+
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
@@ -3741,6 +3764,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
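Taken together, the model-runner changes give an encoder-only (EC producer) rank a dedicated fast path: _update_states returns early and skips the running/resumed request and KV-block bookkeeping, _execute_mm_encoder pushes every freshly computed encoder cache to the EC connector, execute_model returns an empty output right after encoding, and get_kv_cache_spec reports no layers so no KV cache is allocated on that rank. A self-contained toy of that dispatch, with stubs standing in for vLLM's EC-transfer helpers (none of these stub names are real vLLM APIs):

    class _ECTransferStub:
        is_producer = True              # this rank only runs the encoder

    def has_ec_transfer() -> bool:      # stub for vllm.distributed.ec_transfer
        return True

    def get_ec_transfer() -> _ECTransferStub:
        return _ECTransferStub()

    class ToyModelRunner:
        def __init__(self) -> None:
            self.encoder_cache = {}

        def _execute_mm_encoder(self, scheduler_output) -> None:
            # The real runner also pushes each new cache to the connector
            # via maybe_save_ec_to_connector right after computing it.
            self.encoder_cache["mm-hash-0"] = "encoder output"

        def execute_model(self, scheduler_output) -> str:
            if has_ec_transfer() and get_ec_transfer().is_producer:
                self._execute_mm_encoder(scheduler_output)
                # Skip prefill/decode entirely, mirroring
                # make_empty_encoder_model_runner_output in the hunk above.
                return "empty encoder-only output"
            return "normal prefill/decode output"

    print(ToyModelRunner().execute_model(scheduler_output=None))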

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
@@ -417,6 +418,7 @@ def _init_worker_distributed_environment(self) -> None:
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
