Skip to content
46 changes: 46 additions & 0 deletions tests/ut/worker/test_model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import AscendSocVersion
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
from vllm.platforms import current_platform
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig, set_current_vllm_config)

DEVICE = current_platform.device_type
BLOCK_SIZE = 16


# yapf: disable
Expand Down Expand Up @@ -109,3 +115,43 @@ def test_select_moe_comm_method_unsupported_soc():
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):

NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)


def get_vllm_config():
    """Assemble the small ``VllmConfig`` shared by the model-runner tests.

    The scheduler, model, cache and parallel sub-configs are created in
    that order and then bundled into a single ``VllmConfig``.

    Returns:
        The combined ``VllmConfig`` instance.
    """
    sched_cfg = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
    )
    mdl_cfg = ModelConfig(
        model="facebook/opt-125m",
        dtype="float16",
        seed=42,
    )
    kv_cfg = CacheConfig(
        block_size=BLOCK_SIZE,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    # ParallelConfig is built with its defaults, after the other sub-configs.
    return VllmConfig(
        model_config=mdl_cfg,
        cache_config=kv_cfg,
        scheduler_config=sched_cfg,
        parallel_config=ParallelConfig(),
    )

@pytest.fixture
def model_runner():
    """Provide an ``NPUModelRunner`` built from the shared test config."""
    config = get_vllm_config()
    runner = NPUModelRunner(config, DEVICE)
    return runner

def test_update_config(model_runner):
# Simple update
model_runner.update_config({"load_config": {"load_format": "dummy"}})
assert model_runner.load_config.load_format == "dummy"
# Raise error on non-existing config
with pytest.raises(AssertionError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To align with the recommended change of raising a ValueError for invalid inputs in NPUModelRunner.update_config, this test should be updated to expect a ValueError.

Suggested change
- with pytest.raises(AssertionError):
+ with pytest.raises(ValueError):

model_runner.update_config({"do_not_exist_config": "dummy"})
12 changes: 11 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, update_config
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
Expand Down Expand Up @@ -4592,3 +4592,13 @@ def _generate_pcp_mtp_input(
self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)

def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items():
assert config_name in allowed_config_names, \
f"Config `{config_name}` not supported. " \
f"Allowed configs: {allowed_config_names}"
Comment on lines +4488 to +4490
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using assert for input validation is not recommended as assertions can be disabled with the -O Python flag, which would lead to this check being skipped. It's more robust to raise a ValueError for invalid configuration keys.

Additionally, sorting the allowed_config_names in the error message will ensure deterministic output, which is good practice.

Note that this change will require updating test_update_config to expect a ValueError.

Suggested change
- assert config_name in allowed_config_names, \
-     f"Config `{config_name}` not supported. " \
-     f"Allowed configs: {allowed_config_names}"
+ if config_name not in allowed_config_names:
+     allowed = sorted(list(allowed_config_names))
+     raise ValueError(f"Config `{config_name}` not supported. Allowed: {allowed}")

config = getattr(self, config_name)
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
5 changes: 4 additions & 1 deletion vllm_ascend/worker/worker_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#

import copy
from typing import Optional, Union
from typing import Optional, Union, Any

import torch
import torch.nn as nn
Expand Down Expand Up @@ -461,3 +461,6 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return self.model_runner.take_draft_token_ids()

def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward configuration overrides to this worker's model runner.

    Args:
        overrides: Mapping passed unchanged to
            ``self.model_runner.update_config``.
    """
    runner = self.model_runner
    runner.update_config(overrides)
Loading