|
46 | 46 | from vllm.compilation.monitor import set_cudagraph_capturing_enabled |
47 | 47 | from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, |
48 | 48 | get_layers_from_vllm_config) |
| 49 | +from vllm.distributed.ec_transfer import (get_ec_transfer, |
| 50 | + has_ec_transfer) |
49 | 51 | from vllm.distributed import tensor_model_parallel_all_gather |
50 | 52 | from vllm.distributed.kv_transfer import (get_kv_transfer_group, |
51 | 53 | has_kv_transfer_group) |
|
89 | 91 | KVCacheGroupSpec, KVCacheSpec, |
90 | 92 | MambaSpec, MLAAttentionSpec, |
91 | 93 | UniformTypeKVCacheSpecs) |
92 | | -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, |
93 | | - DraftTokenIds, LogprobsTensors, ModelRunnerOutput, |
94 | | - PoolerOutput) |
| 94 | +# yapf: enable |
| 95 | +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, |
| 96 | + AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors, |
| 97 | + ModelRunnerOutput, |
| 98 | + PoolerOutput, make_empty_encoder_model_runner_output) |
95 | 99 | from vllm.v1.pool.metadata import PoolingMetadata |
96 | 100 | from vllm.v1.sample.metadata import SamplingMetadata |
97 | 101 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
|
103 | 107 | gather_mm_placeholders, |
104 | 108 | sanity_check_mm_encoder_outputs, |
105 | 109 | scatter_mm_placeholders) |
| 110 | +from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin |
106 | 111 |
|
107 | 112 | import vllm_ascend.envs as envs_ascend |
108 | 113 | from vllm_ascend.ascend_config import get_ascend_config |
@@ -252,7 +257,7 @@ def get_output(self) -> ModelRunnerOutput: |
252 | 257 | return output |
253 | 258 |
|
254 | 259 |
|
255 | | -class NPUModelRunner(LoRAModelRunnerMixin): |
| 260 | +class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): |
256 | 261 |
|
257 | 262 | def __init__(self, vllm_config: VllmConfig, device: torch.device): |
258 | 263 | self.vllm_config = vllm_config |
@@ -758,6 +763,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
758 | 763 |
|
759 | 764 | req_ids_to_add.append(req_id) |
760 | 765 |
|
| 766 | + # If this rank is an EC transfer producer, |
| 767 | + # skip updating the states of KV cache blocks. |
| 768 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 769 | + return |
| 770 | + |
761 | 771 | # Update the states of the running/resumed requests. |
762 | 772 | is_last_rank = get_pp_group().is_last_rank |
763 | 773 | req_data = scheduler_output.scheduled_cached_reqs |
@@ -1614,8 +1624,12 @@ def _prepare_inputs( |
1614 | 1624 | # _prepare_inputs may reorder the batch, so we must gather |
1615 | 1625 | # multi-modal outputs after that to ensure the correct order |
1616 | 1626 | if self.is_multimodal_model: |
1617 | | - # Run the multimodal encoder if any. |
1618 | | - self._execute_mm_encoder(scheduler_output) |
| 1627 | + with self.maybe_get_ec_connector_output( |
| 1628 | + scheduler_output, |
| 1629 | + encoder_cache=self.encoder_cache, |
| 1630 | + ) as ec_connector_output: |
| 1631 | + # Run the multimodal encoder if any. |
| 1632 | + self._execute_mm_encoder(scheduler_output) |
1619 | 1633 |
|
1620 | 1634 | # NOTE(woosuk): To unify token ids and soft tokens (vision |
1621 | 1635 | # embeddings), we always use embeddings (rather than token ids) |
@@ -2261,6 +2275,14 @@ def execute_model( |
2261 | 2275 | ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: |
2262 | 2276 | with ProfileExecuteDuration().capture_async("prepare input"): |
2263 | 2277 | self._update_states(scheduler_output) |
| 2278 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 2279 | + with self.maybe_get_ec_connector_output( |
| 2280 | + scheduler_output, |
| 2281 | + encoder_cache=self.encoder_cache, |
| 2282 | + ) as ec_connector_output: |
| 2283 | + self._execute_mm_encoder(scheduler_output) |
| 2284 | + return make_empty_encoder_model_runner_output(scheduler_output) |
| 2285 | + |
2264 | 2286 | if not scheduler_output.total_num_scheduled_tokens: |
2265 | 2287 | if not has_kv_transfer_group(): |
2266 | 2288 | logger.debug( |
|
0 commit comments