|
47 | 47 | from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, |
48 | 48 | get_layers_from_vllm_config) |
49 | 49 | from vllm.distributed import tensor_model_parallel_all_gather |
| 50 | +from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer |
50 | 51 | from vllm.distributed.kv_transfer import (get_kv_transfer_group, |
51 | 52 | has_kv_transfer_group) |
52 | 53 | from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 |
|
91 | 92 | UniformTypeKVCacheSpecs) |
92 | 93 | from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, |
93 | 94 | DraftTokenIds, LogprobsTensors, ModelRunnerOutput, |
94 | | - PoolerOutput) |
| 95 | + PoolerOutput, |
| 96 | + make_empty_encoder_model_runner_output) |
95 | 97 | from vllm.v1.pool.metadata import PoolingMetadata |
96 | 98 | from vllm.v1.sample.metadata import SamplingMetadata |
97 | 99 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
98 | 100 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
99 | 101 | from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer |
100 | 102 | from vllm.v1.utils import CpuGpuBuffer |
| 103 | +from vllm.v1.worker.ec_connector_model_runner_mixin import \ |
| 104 | + ECConnectorModelRunnerMixin |
101 | 105 | from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput |
102 | 106 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin |
103 | 107 | from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache, |
@@ -268,7 +272,7 @@ class ExecuteModelState(NamedTuple): |
268 | 272 | positions: torch.Tensor |
269 | 273 |
270 | 274 |
271 | | -class NPUModelRunner(LoRAModelRunnerMixin): |
| 275 | +class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): |
272 | 276 |
273 | 277 | def __init__(self, vllm_config: VllmConfig, device: torch.device): |
274 | 278 | self.vllm_config = vllm_config |
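
For orientation, here is a minimal sketch of the `ECConnectorModelRunnerMixin` surface that the runner now inherits. Only the method names and call shapes (`maybe_save_ec_to_connector`, and `maybe_get_ec_connector_output` used as a context manager) are taken from this diff; the bodies and the connector-side calls are hypothetical placeholders, not the real implementation.

```python
# Hedged sketch, not the actual mixin: method names/signatures mirror the
# calls added in this diff; bodies and the connector calls are assumptions.
from contextlib import contextmanager

from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer


class _ECConnectorMixinSketch:

    def maybe_save_ec_to_connector(self, encoder_cache, mm_hash):
        # Producer side: publish a freshly computed encoder embedding so
        # consumer ranks can fetch it instead of re-running the encoder.
        if has_ec_transfer():
            # `save_cache` is a hypothetical connector method.
            get_ec_transfer().save_cache(mm_hash, encoder_cache[mm_hash])

    @contextmanager
    def maybe_get_ec_connector_output(self, scheduler_output, encoder_cache):
        # Consumer side: pull encoder outputs produced by remote ranks into
        # `encoder_cache` before the local encoder runs, then yield.
        if has_ec_transfer():
            # `load_cache` is a hypothetical connector method.
            get_ec_transfer().load_cache(scheduler_output, encoder_cache)
        yield
```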
@@ -791,6 +795,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
791 | 795 |
792 | 796 | req_ids_to_add.append(req_id) |
793 | 797 |
| 798 | + # If this rank is an EC transfer producer, |
| 799 | + # skip updating the states of KV cache blocks. |
| 800 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 801 | + return |
| 802 | + |
794 | 803 | # Update the states of the running/resumed requests. |
795 | 804 | is_last_rank = get_pp_group().is_last_rank |
796 | 805 | req_data = scheduler_output.scheduled_cached_reqs |
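
The producer guard `has_ec_transfer() and get_ec_transfer().is_producer` is repeated here and again in `execute_model` and `get_kv_cache_spec` below. A hypothetical helper (not part of this change) that names the check:

```python
# Hypothetical helper, not in this PR: names the repeated "encoder producer
# rank" check used in _update_states, execute_model and get_kv_cache_spec.
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer


def is_ec_producer() -> bool:
    # True only when EC transfer is configured and this rank produces
    # encoder caches for other ranks instead of running the decoder.
    return has_ec_transfer() and get_ec_transfer().is_producer
```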
@@ -1072,6 +1081,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): |
1072 | 1081 | output, |
1073 | 1082 | is_embed=pos_info.is_embed, |
1074 | 1083 | ) |
| 1084 | + self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) |
1075 | 1085 |
1076 | 1086 | def _batch_mm_kwargs_from_scheduler( |
1077 | 1087 | self, |
@@ -1597,15 +1607,19 @@ def _prepare_inputs( |
1597 | 1607 | # _prepare_inputs may reorder the batch, so we must gather |
1598 | 1608 | # multi-modal outputs after that to ensure the correct order |
1599 | 1609 | if self.is_multimodal_model: |
1600 | | - # Run the multimodal encoder if any. |
1601 | | - self._execute_mm_encoder(scheduler_output) |
1602 | | - |
1603 | | - # NOTE(woosuk): To unify token ids and soft tokens (vision |
1604 | | - # embeddings), we always use embeddings (rather than token ids) |
1605 | | - # as input to the multimodal model, even when the input is text. |
1606 | | - input_ids = self.input_ids[:total_num_scheduled_tokens] |
1607 | | - mm_embeds, is_mm_embed = self._gather_mm_embeddings( |
1608 | | - scheduler_output) |
| 1610 | + with self.maybe_get_ec_connector_output( |
| 1611 | + scheduler_output, |
| 1612 | + encoder_cache=self.encoder_cache, |
| 1613 | + ): |
| 1614 | + # Run the multimodal encoder if any. |
| 1615 | + self._execute_mm_encoder(scheduler_output) |
| 1616 | + |
| 1617 | + # NOTE(woosuk): To unify token ids and soft tokens (vision |
| 1618 | + # embeddings), we always use embeddings (rather than token ids) |
| 1619 | + # as input to the multimodal model, even when the input is text. |
| 1620 | + input_ids = self.input_ids[:total_num_scheduled_tokens] |
| 1621 | + mm_embeds, is_mm_embed = self._gather_mm_embeddings( |
| 1622 | + scheduler_output) |
1609 | 1623 |
1610 | 1624 | inputs_embeds = self.model.embed_input_ids( |
1611 | 1625 | input_ids, |
@@ -2248,6 +2262,15 @@ def execute_model( |
2248 | 2262 |
2249 | 2263 | with ProfileExecuteDuration().capture_async("prepare input"): |
2250 | 2264 | self._update_states(scheduler_output) |
| 2265 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 2266 | + with self.maybe_get_ec_connector_output( |
| 2267 | + scheduler_output, |
| 2268 | + encoder_cache=self.encoder_cache, |
| 2269 | + ): |
| 2270 | + self._execute_mm_encoder(scheduler_output) |
| 2271 | + return make_empty_encoder_model_runner_output( |
| 2272 | + scheduler_output) |
| 2273 | + |
2251 | 2274 | if not scheduler_output.total_num_scheduled_tokens: |
2252 | 2275 | if not has_kv_transfer_group(): |
2253 | 2276 | logger.debug( |
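
Taken together, the branch above means an EC producer rank never reaches the decoder or sampling: after `_update_states`, it runs only the multimodal encoder inside the EC connector context and returns an output with no sampled tokens. A condensed, hedged restatement follows; the names come from this diff, but what `make_empty_encoder_model_runner_output` actually populates is an assumption.

```python
# Hedged restatement of the producer fast path in execute_model; not a
# separate method in the PR, just the control flow spelled out.
def _producer_step(self, scheduler_output):
    with self.maybe_get_ec_connector_output(
            scheduler_output,
            encoder_cache=self.encoder_cache,
    ):
        # Run only the multimodal encoder; each finished embedding is
        # handed to the EC connector inside _execute_mm_encoder.
        self._execute_mm_encoder(scheduler_output)
    # No decoder forward pass, no sampling: return an output that only
    # acknowledges the scheduled requests (assumed behaviour of the helper).
    return make_empty_encoder_model_runner_output(scheduler_output)
```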
@@ -3741,6 +3764,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: |
3741 | 3764 | KVCacheSpec: A dictionary mapping layer names to their KV cache |
3742 | 3765 | format. Layers that do not need KV cache are not included. |
3743 | 3766 | """ |
| 3767 | + |
| 3768 | + if has_ec_transfer() and get_ec_transfer().is_producer: |
| 3769 | + return {} |
| 3770 | + |
3744 | 3771 | block_size = self.vllm_config.cache_config.block_size |
3745 | 3772 | use_mla = self.vllm_config.model_config.use_mla |
3746 | 3773 | kv_cache_spec: dict[str, KVCacheSpec] = {} |
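
Note: returning an empty spec on EC producer ranks means no layer reports a KV cache requirement, so the worker presumably allocates and binds no KV-cache blocks on encoder-only ranks, leaving device memory for the encoder and its cache.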
|