Skip to content

Commit 29f7d97

Browse files
authored
Improve parse_raw_prompt test cases for invalid input .v2 (#30512)
Signed-off-by: Kayvan Mivehnejad <[email protected]>
1 parent dc7fb5b commit 29f7d97

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

tests/test_inputs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@
3434
]
3535

3636

37+
# Test that a nested mixed-type list of lists raises a TypeError.
38+
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
39+
def test_invalid_input_raise_type_error(invalid_input):
40+
with pytest.raises(TypeError):
41+
parse_raw_prompts(invalid_input)
42+
43+
3744
def test_parse_raw_single_batch_empty():
3845
with pytest.raises(ValueError, match="at least one prompt"):
3946
parse_raw_prompts([])

vllm/inputs/parse.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,31 @@ def parse_raw_prompts(
3333
if len(prompt) == 0:
3434
raise ValueError("please provide at least one prompt")
3535

36+
# case 2: array of strings
3637
if is_list_of(prompt, str):
37-
# case 2: array of strings
3838
prompt = cast(list[str], prompt)
3939
return [TextPrompt(prompt=elem) for elem in prompt]
40+
41+
# case 3: array of tokens
4042
if is_list_of(prompt, int):
41-
# case 3: array of tokens
4243
prompt = cast(list[int], prompt)
4344
return [TokensPrompt(prompt_token_ids=prompt)]
45+
46+
# case 4: array of token arrays
4447
if is_list_of(prompt, list):
45-
prompt = cast(list[list[int]], prompt)
46-
if len(prompt[0]) == 0:
47-
raise ValueError("please provide at least one prompt")
48+
first = prompt[0]
49+
if not isinstance(first, list):
50+
raise ValueError("prompt expected to be a list of lists")
4851

49-
if is_list_of(prompt[0], int):
50-
# case 4: array of token arrays
51-
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
52+
if len(first) == 0:
53+
raise ValueError("Please provide at least one prompt")
54+
55+
# strict validation: every nested list must be list[int]
56+
if not all(is_list_of(elem, int) for elem in prompt):
57+
raise TypeError("Nested lists must contain only integers")
58+
59+
prompt = cast(list[list[int]], prompt)
60+
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
5261

5362
raise TypeError(
5463
"prompt must be a string, array of strings, "

0 commit comments

Comments
 (0)