Commit ad5aaa2
Encoder separation for Encode-Prefill-Decode Disaggregation
1 parent: bc69d7c

File tree: 2 files changed, +29 −5 lines

vllm_ascend/worker/model_runner_v1.py

Lines changed: 27 additions & 5 deletions
@@ -46,6 +46,8 @@
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
+from vllm.distributed.ec_transfer import (get_ec_transfer,
+                                          has_ec_transfer)
 from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -89,9 +91,11 @@
8991
KVCacheGroupSpec, KVCacheSpec,
9092
MambaSpec, MLAAttentionSpec,
9193
UniformTypeKVCacheSpecs)
92-
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
93-
DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
94-
PoolerOutput)
94+
# yapf: enable
95+
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT,
96+
AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors,
97+
ModelRunnerOutput,
98+
PoolerOutput, make_empty_encoder_model_runner_output)
9599
from vllm.v1.pool.metadata import PoolingMetadata
96100
from vllm.v1.sample.metadata import SamplingMetadata
97101
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -103,6 +107,7 @@
                                        gather_mm_placeholders,
                                        sanity_check_mm_encoder_outputs,
                                        scatter_mm_placeholders)
+from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -792,6 +797,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

             req_ids_to_add.append(req_id)

+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
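
The producer gate above short-circuits all KV-cache bookkeeping on encoder ranks. Below is a minimal sketch of that gating pattern; the get_ec_transfer()/has_ec_transfer() names and the is_producer check come from the diff, while the stub class and module-level global are hypothetical and do not reflect vLLM's real ec_transfer module.

# Hypothetical stand-in for vllm.distributed.ec_transfer; only the
# get_ec_transfer()/has_ec_transfer() names and the is_producer check
# are taken from the diff above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ECTransferStub:
    # True on encoder (producer) ranks, False on prefill/decode consumers.
    is_producer: bool


_EC_TRANSFER: Optional[_ECTransferStub] = None


def has_ec_transfer() -> bool:
    return _EC_TRANSFER is not None


def get_ec_transfer() -> _ECTransferStub:
    assert _EC_TRANSFER is not None, "EC transfer group is not initialized"
    return _EC_TRANSFER


def update_states_sketch(scheduler_output) -> None:
    # Encoder ranks own no KV-cache blocks, so they skip the
    # running/resumed-request state updates entirely.
    if has_ec_transfer() and get_ec_transfer().is_producer:
        return
    # ... normal KV-cache state updates continue here on other ranks ...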
@@ -1620,8 +1630,12 @@ def _prepare_inputs(
16201630
# _prepare_inputs may reorder the batch, so we must gather
16211631
# multi-modal outputs after that to ensure the correct order
16221632
if self.is_multimodal_model:
1623-
# Run the multimodal encoder if any.
1624-
self._execute_mm_encoder(scheduler_output)
1633+
with self.maybe_get_ec_connector_output(
1634+
scheduler_output,
1635+
encoder_cache=self.encoder_cache,
1636+
) as ec_connector_output:
1637+
# Run the multimodal encoder if any.
1638+
self._execute_mm_encoder(scheduler_output)
16251639

16261640
# NOTE(woosuk): To unify token ids and soft tokens (vision
16271641
# embeddings), we always use embeddings (rather than token ids)
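
The maybe_get_ec_connector_output(...) context manager comes from vLLM's ECConnectorModelRunnerMixin imported above; its exact signature is not shown in this diff, so the sketch below is only an assumption about the general shape of such a helper: a pass-through when no EC transfer group exists, otherwise a hook that can ship the freshly filled encoder cache to consumer ranks on exit.

# Hypothetical sketch of a connector-output context manager; the real
# ECConnectorModelRunnerMixin in vLLM may differ in signature and behavior.
import contextlib
from typing import Any, Dict, Iterator, Optional


class ECConnectorSketchMixin:
    # Assumed attribute: None on ranks without an EC transfer group.
    ec_transfer: Optional[Any] = None

    @contextlib.contextmanager
    def maybe_get_ec_connector_output(
        self,
        scheduler_output: Any,
        encoder_cache: Dict[str, Any],
    ) -> Iterator[Optional[Dict[str, Any]]]:
        # Without an EC transfer group this is a plain pass-through, so
        # non-disaggregated deployments keep their original code path.
        if self.ec_transfer is None:
            yield None
            return
        output: Dict[str, Any] = {}
        yield output
        # After _execute_mm_encoder() has filled encoder_cache, a real
        # connector would push those cached embeddings to consumer ranks.
        output["sent_items"] = list(encoder_cache.keys())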
@@ -2272,6 +2286,14 @@ def execute_model(
22722286

22732287
with ProfileExecuteDuration().capture_async("prepare input"):
22742288
self._update_states(scheduler_output)
2289+
if has_ec_transfer() and get_ec_transfer().is_producer:
2290+
with self.maybe_get_ec_connector_output(
2291+
scheduler_output,
2292+
encoder_cache=self.encoder_cache,
2293+
) as ec_connector_output:
2294+
self._execute_mm_encoder(scheduler_output)
2295+
return make_empty_encoder_model_runner_output(scheduler_output)
2296+
22752297
if not scheduler_output.total_num_scheduled_tokens:
22762298
if not has_kv_transfer_group():
22772299
logger.debug(
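
On producer ranks, execute_model therefore never reaches prefill or decode: it runs the multimodal encoder under the connector context and returns an empty per-request output so the scheduler's bookkeeping still advances. A simplified sketch of what such an empty output could look like follows, assuming the scheduler output exposes a per-request num_scheduled_tokens mapping; the dataclass and helper are illustrative stand-ins, not the real vllm.v1.outputs types.

# Illustrative sketch of an "empty" model runner output for encoder ranks;
# the real make_empty_encoder_model_runner_output lives in vllm.v1.outputs.
from dataclasses import dataclass, field
from typing import List


@dataclass
class SketchModelRunnerOutput:
    req_ids: List[str] = field(default_factory=list)
    sampled_token_ids: List[List[int]] = field(default_factory=list)


def make_empty_encoder_output_sketch(scheduler_output) -> SketchModelRunnerOutput:
    # No tokens are sampled on an encoder rank; every scheduled request
    # gets an empty token list so the scheduler can still account for it.
    req_ids = list(scheduler_output.num_scheduled_tokens.keys())
    return SketchModelRunnerOutput(
        req_ids=req_ids,
        sampled_token_ids=[[] for _ in req_ids],
    )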

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
@@ -420,6 +421,7 @@ def _init_worker_distributed_environment(self) -> None:
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)

     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
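
The EC transfer group is brought up right after KV transfer during worker bootstrap, mirroring the existing ensure_kv_transfer_initialized(self.vllm_config) call. A minimal sketch of an idempotent ensure-style guard of this kind follows; the _EC_GROUP global and the ec_transfer_config attribute are assumptions for illustration only.

# Illustrative sketch of an idempotent ensure_*_initialized-style guard; the
# real ensure_ec_transfer_initialized lives in vllm.distributed.ec_transfer.
from typing import Any, Optional

_EC_GROUP: Optional[Any] = None


def ensure_ec_transfer_initialized_sketch(vllm_config: Any) -> None:
    global _EC_GROUP
    # Repeated calls are harmless: the group is created at most once,
    # and ranks without an EC transfer config simply leave it unset.
    if _EC_GROUP is not None:
        return
    ec_config = getattr(vllm_config, "ec_transfer_config", None)
    if ec_config is None:
        return
    _EC_GROUP = object()  # placeholder for the real connector/group object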
