diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py
index 0f27548a6f5..81897780095 100644
--- a/tests/ut/worker/test_model_runner_v1.py
+++ b/tests/ut/worker/test_model_runner_v1.py
@@ -14,11 +14,17 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
 
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, VllmConfig)
+from vllm.platforms import current_platform
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import AscendDeviceType
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+DEVICE = current_platform.device_type
+BLOCK_SIZE = 16
+
 # yapf: disable
 @pytest.mark.parametrize(
@@ -109,3 +115,46 @@ def test_select_moe_comm_method_unsupported_soc():
          pytest.raises(ValueError,
                        match=f"Unsupported soc_version: {unsupported_soc}"):
         NPUModelRunner._select_moe_comm_method(mock_runner, 100)
+
+
+def get_vllm_config():
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+    )
+    model_config = ModelConfig(
+        model="facebook/opt-125m",
+        dtype="float16",
+        seed=42,
+    )
+    cache_config = CacheConfig(
+        block_size=BLOCK_SIZE,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+    )
+    parallel_config = ParallelConfig()
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
+    )
+    return vllm_config
+
+
+@pytest.fixture
+def model_runner():
+    vllm_config = get_vllm_config()
+    return NPUModelRunner(vllm_config, DEVICE)
+
+
+def test_update_config(model_runner):
+    """
+    Tests that update_config updates the runner's load_config and raises on an unknown config name.
+    """
+    model_runner.update_config({"load_config": {"load_format": "dummy"}})
+    assert model_runner.load_config.load_format == "dummy"
+    with pytest.raises(AssertionError):
+        model_runner.update_config({"do_not_exist_config": "dummy"})
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ff55d1d1897..f43f2b7be08 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -45,7 +45,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
-                         get_layers_from_vllm_config)
+                         get_layers_from_vllm_config, update_config)
 from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -4465,7 +4465,7 @@ def _generate_pcp_mtp_input(
             self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full],
             non_blocking=True,
         )
-
+
     def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]:
         # This is a short term mitigation for issue mentioned in
         # https://github.com/vllm-project/vllm/issues/22754.
@@ -4480,3 +4480,14 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]:
         self.transfer_event.record()
         self.transfer_event.synchronize()
         return [row for row in pinned.numpy()]
+
+
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index e9000eae38e..f9273c8b7e0 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -19,7 +19,7 @@
 import copy
 from types import NoneType
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
@@ -468,3 +468,6 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
 
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         return self.model_runner.take_draft_token_ids()
+
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
\ No newline at end of file