|
46 | 46 | from vllm.compilation.monitor import set_cudagraph_capturing_enabled |
47 | 47 | from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, |
48 | 48 | get_layers_from_vllm_config) |
| 49 | +from vllm.distributed.ec_transfer import (get_ec_transfer, |
| 50 | + has_ec_transfer) |
49 | 51 | from vllm.distributed import tensor_model_parallel_all_gather |
50 | 52 | from vllm.distributed.kv_transfer import (get_kv_transfer_group, |
51 | 53 | has_kv_transfer_group) |
|
89 | 91 | KVCacheGroupSpec, KVCacheSpec, |
90 | 92 | MambaSpec, MLAAttentionSpec, |
91 | 93 | UniformTypeKVCacheSpecs) |
92 | | -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, |
93 | | - DraftTokenIds, LogprobsTensors, ModelRunnerOutput, |
94 | | - PoolerOutput) |
| 94 | +# yapf: enable |
| 95 | +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, |
| 96 | + AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors, |
| 97 | + ModelRunnerOutput, |
| 98 | + PoolerOutput, make_empty_encoder_model_runner_output) |
95 | 99 | from vllm.v1.pool.metadata import PoolingMetadata |
96 | 100 | from vllm.v1.sample.metadata import SamplingMetadata |
97 | 101 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
|
103 | 107 | gather_mm_placeholders, |
104 | 108 | sanity_check_mm_encoder_outputs, |
105 | 109 | scatter_mm_placeholders) |
| 110 | +from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin |
106 | 111 |
|
107 | 112 | import vllm_ascend.envs as envs_ascend |
108 | 113 | from vllm_ascend.ascend_config import get_ascend_config |
@@ -792,6 +797,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
792 | 797 |
|
793 | 798 | req_ids_to_add.append(req_id) |
794 | 799 |
|
| 800 | + # If this rank is an EC transfer producer, |
| 801 | + # skip updating the states of KV cache blocks. |
| 802 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 803 | + return |
| 804 | + |
795 | 805 | # Update the states of the running/resumed requests. |
796 | 806 | is_last_rank = get_pp_group().is_last_rank |
797 | 807 | req_data = scheduler_output.scheduled_cached_reqs |
@@ -1620,8 +1630,12 @@ def _prepare_inputs( |
1620 | 1630 | # _prepare_inputs may reorder the batch, so we must gather |
1621 | 1631 | # multi-modal outputs after that to ensure the correct order |
1622 | 1632 | if self.is_multimodal_model: |
1623 | | - # Run the multimodal encoder if any. |
1624 | | - self._execute_mm_encoder(scheduler_output) |
| 1633 | + with self.maybe_get_ec_connector_output( |
| 1634 | + scheduler_output, |
| 1635 | + encoder_cache=self.encoder_cache, |
| 1636 | + ) as ec_connector_output: |
| 1637 | + # Run the multimodal encoder if any. |
| 1638 | + self._execute_mm_encoder(scheduler_output) |
1625 | 1639 |
|
1626 | 1640 | # NOTE(woosuk): To unify token ids and soft tokens (vision |
1627 | 1641 | # embeddings), we always use embeddings (rather than token ids) |
@@ -2272,6 +2286,14 @@ def execute_model( |
2272 | 2286 |
|
2273 | 2287 | with ProfileExecuteDuration().capture_async("prepare input"): |
2274 | 2288 | self._update_states(scheduler_output) |
| 2289 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 2290 | + with self.maybe_get_ec_connector_output( |
| 2291 | + scheduler_output, |
| 2292 | + encoder_cache=self.encoder_cache, |
| 2293 | + ) as ec_connector_output: |
| 2294 | + self._execute_mm_encoder(scheduler_output) |
| 2295 | + return make_empty_encoder_model_runner_output(scheduler_output) |
| 2296 | + |
2275 | 2297 | if not scheduler_output.total_num_scheduled_tokens: |
2276 | 2298 | if not has_kv_transfer_group(): |
2277 | 2299 | logger.debug( |
|
0 commit comments