# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of guided decoding
to generate structured outputs using vLLM. It shows how to apply
different guided decoding techniques such as Choice, Regex, JSON schema,
and Grammar to produce structured and formatted results
based on specific prompts.
"""

from enum import Enum

from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Guided decoding by Choice (list of possible options).
# The model may only emit one of the listed strings.
guided_decoding_params_choice = GuidedDecodingParams(
    choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(
    guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"

# Guided decoding by Regex.
# Constrains output to an email-shaped string; stop on the trailing newline.
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
    guided_decoding=guided_decoding_params_regex, stop=["\n"])
# NOTE(review): the final example line of this prompt was lost in extraction;
# reconstructed from the surviving text -- confirm against the original.
prompt_regex = (
    "Generate an email address for Alan Turing, who works in Enigma."
    "End in .com and new line. Example result:"
    "alan.turing@enigma.com\n")


# Guided decoding by JSON using Pydantic schema.
# The schema is derived from the CarDescription model defined above.
json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json)
prompt_json = ("Generate a JSON with the brand, model and car_type of"
               "the most iconic car from the 90's")

# Guided decoding by Grammar (GBNF-style EBNF).
# NOTE(review): the middle rules of this grammar were lost in extraction;
# reconstructed from the surviving `condition`/`number` rules -- confirm
# against the original. Trailing spaces inside the quoted terminals are
# intentional token separators.
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
column ::= "col_1 " | "col_2 "
table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params_grammar = GuidedDecodingParams(
    grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(
    guided_decoding=guided_decoding_params_grammar)
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'"
                  "from the 'users' table.")


def format_output(title: str, output: str):
    """Print *output* labelled with *title*, framed by 50-dash separators."""
    bar = "-" * 50
    print(f"{bar}\n{title}: {output}\n{bar}")


def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
    """Run *llm* on a single *prompt* and return the first completion's text."""
    request_outputs = llm.generate(prompts=prompt,
                                   sampling_params=sampling_params)
    return request_outputs[0].outputs[0].text


def main():
    """Exercise each guided-decoding mode against one shared model."""
    llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

    # (title, prompt, sampling params) triples; params were built at module
    # level, each carrying its matching GuidedDecodingParams.
    examples = [
        ("Guided decoding by Choice", prompt_choice, sampling_params_choice),
        ("Guided decoding by Regex", prompt_regex, sampling_params_regex),
        ("Guided decoding by JSON", prompt_json, sampling_params_json),
        ("Guided decoding by Grammar", prompt_grammar,
         sampling_params_grammar),
    ]
    for title, prompt, params in examples:
        format_output(title, generate_output(prompt, params, llm))


if __name__ == "__main__":
    main()