Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@
]


# Test that a nested mixed-type list of lists raises a TypeError.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test case should still be valid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test was meant to enforce thorough validation across all nested lists, which goes beyond the intended scope of parse_raw_prompts. That function performs lightweight structural checks, avoiding asymptotically expensive ‘check=all’ validation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still check the first item of each prompt (1 and foo in this case) which should still trigger the error

@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
with pytest.raises(TypeError):
parse_raw_prompts(invalid_input)


def test_parse_raw_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([])
Expand Down
21 changes: 11 additions & 10 deletions vllm/inputs/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,17 @@ def parse_raw_prompts(

# case 4: array of token arrays
if is_list_of(prompt, list):
first = prompt[0]
if not isinstance(first, list):
raise ValueError("prompt expected to be a list of lists")

if len(first) == 0:
raise ValueError("Please provide at least one prompt")

# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
for elem in prompt:
if not isinstance(elem, list):
raise TypeError(
"prompt must be a list of lists, but found a non-list element."
)
if not is_list_of(elem, int):
raise TypeError(
"Nested lists of tokens must contain only integers."
)

prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
Expand Down