@@ -33,22 +33,31 @@ def parse_raw_prompts(
3333 if len (prompt ) == 0 :
3434 raise ValueError ("please provide at least one prompt" )
3535
36+ # case 2: array of strings
3637 if is_list_of (prompt , str ):
37- # case 2: array of strings
3838 prompt = cast (list [str ], prompt )
3939 return [TextPrompt (prompt = elem ) for elem in prompt ]
40+
41+ # case 3: array of tokens
4042 if is_list_of (prompt , int ):
41- # case 3: array of tokens
4243 prompt = cast (list [int ], prompt )
4344 return [TokensPrompt (prompt_token_ids = prompt )]
45+
46+ # case 4: array of token arrays
4447 if is_list_of (prompt , list ):
45- prompt = cast ( list [ list [ int ]], prompt )
46- if len ( prompt [ 0 ]) == 0 :
47- raise ValueError ("please provide at least one prompt " )
48+ first = prompt [ 0 ]
49+ if not isinstance ( first , list ) :
50+ raise ValueError ("prompt expected to be a list of lists " )
4851
49- if is_list_of (prompt [0 ], int ):
50- # case 4: array of token arrays
51- return [TokensPrompt (prompt_token_ids = elem ) for elem in prompt ]
52+ if len (first ) == 0 :
53+ raise ValueError ("Please provide at least one prompt" )
54+
55+ # strict validation: every nested list must be list[int]
56+ if not all (is_list_of (elem , int ) for elem in prompt ):
57+ raise TypeError ("Nested lists must contain only integers" )
58+
59+ prompt = cast (list [list [int ]], prompt )
60+ return [TokensPrompt (prompt_token_ids = elem ) for elem in prompt ]
5261
5362 raise TypeError (
5463 "prompt must be a string, array of strings, "
0 commit comments