Skip to content

Commit 63971b1

Browse files
committed
Encoder separation for Encode-Prefill-Decode Disaggregation
1 parent: ae068a3 · commit: 63971b1

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
4747
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
4848
get_layers_from_vllm_config)
49+
from vllm.distributed.ec_transfer import (get_ec_transfer,
50+
has_ec_transfer)
4951
from vllm.distributed import tensor_model_parallel_all_gather
5052
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
5153
has_kv_transfer_group)
@@ -89,9 +91,11 @@
8991
KVCacheGroupSpec, KVCacheSpec,
9092
MambaSpec, MLAAttentionSpec,
9193
UniformTypeKVCacheSpecs)
92-
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
93-
DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
94-
PoolerOutput)
94+
# yapf: enable
95+
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT,
96+
AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors,
97+
ModelRunnerOutput,
98+
PoolerOutput, make_empty_encoder_model_runner_output)
9599
from vllm.v1.pool.metadata import PoolingMetadata
96100
from vllm.v1.sample.metadata import SamplingMetadata
97101
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -103,6 +107,7 @@
103107
gather_mm_placeholders,
104108
sanity_check_mm_encoder_outputs,
105109
scatter_mm_placeholders)
110+
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
106111

107112
import vllm_ascend.envs as envs_ascend
108113
from vllm_ascend.ascend_config import get_ascend_config
@@ -252,7 +257,7 @@ def get_output(self) -> ModelRunnerOutput:
252257
return output
253258

254259

255-
class NPUModelRunner(LoRAModelRunnerMixin):
260+
class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
256261

257262
def __init__(self, vllm_config: VllmConfig, device: torch.device):
258263
self.vllm_config = vllm_config
@@ -758,6 +763,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
758763

759764
req_ids_to_add.append(req_id)
760765

766+
# If this rank is an EC transfer producer,
767+
# skip updating the states of KV cache blocks.
768+
if has_ec_transfer() and get_ec_transfer().is_producer:
769+
return
770+
761771
# Update the states of the running/resumed requests.
762772
is_last_rank = get_pp_group().is_last_rank
763773
req_data = scheduler_output.scheduled_cached_reqs
@@ -1614,8 +1624,12 @@ def _prepare_inputs(
16141624
# _prepare_inputs may reorder the batch, so we must gather
16151625
# multi-modal outputs after that to ensure the correct order
16161626
if self.is_multimodal_model:
1617-
# Run the multimodal encoder if any.
1618-
self._execute_mm_encoder(scheduler_output)
1627+
with self.maybe_get_ec_connector_output(
1628+
scheduler_output,
1629+
encoder_cache=self.encoder_cache,
1630+
) as ec_connector_output:
1631+
# Run the multimodal encoder if any.
1632+
self._execute_mm_encoder(scheduler_output)
16191633

16201634
# NOTE(woosuk): To unify token ids and soft tokens (vision
16211635
# embeddings), we always use embeddings (rather than token ids)
@@ -2261,6 +2275,14 @@ def execute_model(
22612275
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
22622276
with ProfileExecuteDuration().capture_async("prepare input"):
22632277
self._update_states(scheduler_output)
2278+
if has_ec_transfer() and get_ec_transfer().is_producer:
2279+
with self.maybe_get_ec_connector_output(
2280+
scheduler_output,
2281+
encoder_cache=self.encoder_cache,
2282+
) as ec_connector_output:
2283+
self._execute_mm_encoder(scheduler_output)
2284+
return make_empty_encoder_model_runner_output(scheduler_output)
2285+
22642286
if not scheduler_output.total_num_scheduled_tokens:
22652287
if not has_kv_transfer_group():
22662288
logger.debug(

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from vllm.config import VllmConfig
3030
from vllm.distributed import (ensure_model_parallel_initialized,
3131
init_distributed_environment)
32+
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
3233
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
3334
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
3435
from vllm.logger import logger
@@ -405,6 +406,7 @@ def _init_worker_distributed_environment(self) -> None:
405406
self.parallel_config.decode_context_parallel_size)
406407
init_ascend_model_parallel(self.parallel_config)
407408
ensure_kv_transfer_initialized(self.vllm_config)
409+
ensure_ec_transfer_initialized(self.vllm_config)
408410

409411
def _init_profiler(self):
410412
# Torch profiler. Enabled and configured through env vars:

0 commit comments

Comments (0)