diff --git a/tests/test_inputs.py b/tests/test_inputs.py index c4339827de8b..0821de3f803a 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -34,6 +34,13 @@ ] +# Test that a nested mixed-type list of lists raises a TypeError. +@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]]) +def test_invalid_input_raise_type_error(invalid_input): + with pytest.raises(TypeError): + parse_raw_prompts(invalid_input) + + def test_parse_raw_single_batch_empty(): with pytest.raises(ValueError, match="at least one prompt"): parse_raw_prompts([]) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 211551be8e60..71289277eb98 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -33,22 +33,31 @@ def parse_raw_prompts( if len(prompt) == 0: raise ValueError("please provide at least one prompt") + # case 2: array of strings if is_list_of(prompt, str): - # case 2: array of strings prompt = cast(list[str], prompt) return [TextPrompt(prompt=elem) for elem in prompt] + + # case 3: array of tokens if is_list_of(prompt, int): - # case 3: array of tokens prompt = cast(list[int], prompt) return [TokensPrompt(prompt_token_ids=prompt)] + + # case 4: array of token arrays if is_list_of(prompt, list): - prompt = cast(list[list[int]], prompt) - if len(prompt[0]) == 0: - raise ValueError("please provide at least one prompt") + first = prompt[0] + if not isinstance(first, list): + raise ValueError("prompt expected to be a list of lists") - if is_list_of(prompt[0], int): - # case 4: array of token arrays - return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] + if len(first) == 0: + raise ValueError("Please provide at least one prompt") + + # strict validation: every nested list must be list[int] + if not all(is_list_of(elem, int) for elem in prompt): + raise TypeError("Nested lists must contain only integers") + + prompt = cast(list[list[int]], prompt) + return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] raise TypeError( "prompt must be a string, array of strings, "