Skip to content

Commit 1bff42c

Browse files
authored
[Misc] refactor Structured Outputs example (vllm-project#16322)
Signed-off-by: reidliu41 <[email protected]> Co-authored-by: reidliu41 <[email protected]>
1 parent cb391d8 commit 1bff42c

File tree

1 file changed

+60
-36
lines changed

1 file changed

+60
-36
lines changed
Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
This file demonstrates the example usage of guided decoding
4+
to generate structured outputs using vLLM. It shows how to apply
5+
different guided decoding techniques such as Choice, Regex, JSON schema,
6+
and Grammar to produce structured and formatted results
7+
based on specific prompts.
8+
"""
29

310
from enum import Enum
411

@@ -7,26 +14,21 @@
714
from vllm import LLM, SamplingParams
815
from vllm.sampling_params import GuidedDecodingParams
916

10-
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
11-
1217
# Guided decoding by Choice (list of possible options)
13-
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
14-
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
15-
outputs = llm.generate(
16-
prompts="Classify this sentiment: vLLM is wonderful!",
17-
sampling_params=sampling_params,
18-
)
19-
print(outputs[0].outputs[0].text)
18+
# Guided decoding by Choice (list of possible options): the model must emit
# exactly one of the listed strings.
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
2023

2124
# Guided decoding by Regex
22-
guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
23-
sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
24-
stop=["\n"])
25-
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
26-
"End in .com and new line. Example result:"
27-
28-
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
29-
print(outputs[0].outputs[0].text)
25+
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
26+
sampling_params_regex = SamplingParams(
27+
guided_decoding=guided_decoding_params_regex, stop=["\n"])
28+
prompt_regex = (
29+
"Generate an email address for Alan Turing, who works in Enigma."
30+
"End in .com and new line. Example result:"
31+
3032

3133

3234
# Guided decoding by JSON using Pydantic schema
@@ -44,16 +46,11 @@ class CarDescription(BaseModel):
4446

4547

4648
json_schema = CarDescription.model_json_schema()
47-
48-
guided_decoding_params = GuidedDecodingParams(json=json_schema)
49-
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
50-
prompt = ("Generate a JSON with the brand, model and car_type of"
51-
"the most iconic car from the 90's")
52-
outputs = llm.generate(
53-
prompts=prompt,
54-
sampling_params=sampling_params,
55-
)
56-
print(outputs[0].outputs[0].text)
49+
# Guided decoding by JSON: constrain output to the Pydantic-derived schema.
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json)
# NOTE: adjacent string literals are concatenated with no separator; the
# trailing space after "of" is required so the prompt reads
# "... car_type of the ..." instead of "... car_type ofthe ...".
prompt_json = ("Generate a JSON with the brand, model and car_type of "
               "the most iconic car from the 90's")
5754

5855
# Guided decoding by Grammar
5956
simplified_sql_grammar = """
@@ -64,12 +61,39 @@ class CarDescription(BaseModel):
6461
condition ::= column "= " number
6562
number ::= "1 " | "2 "
6663
"""
67-
guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar)
68-
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
69-
prompt = ("Generate an SQL query to show the 'username' and 'email'"
70-
"from the 'users' table.")
71-
outputs = llm.generate(
72-
prompts=prompt,
73-
sampling_params=sampling_params,
74-
)
75-
print(outputs[0].outputs[0].text)
64+
# Guided decoding by Grammar: constrain output to the simplified SQL grammar
# defined above.
guided_decoding_params_grammar = GuidedDecodingParams(
    grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(
    guided_decoding=guided_decoding_params_grammar)
# NOTE: adjacent string literals are concatenated with no separator; the
# trailing space after "'email'" is required so the prompt reads
# "... 'email' from the 'users' table." instead of "... 'email'from ...".
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email' "
                  "from the 'users' table.")
70+
71+
72+
def format_output(title: str, output: str):
    """Print *output* labelled with *title*, framed by 50-dash rule lines."""
    rule = "-" * 50
    print("\n".join((rule, f"{title}: {output}", rule)))
74+
75+
76+
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
    """Run one generation for *prompt* and return the first completion's text."""
    request_outputs = llm.generate(prompts=prompt,
                                   sampling_params=sampling_params)
    first_completion = request_outputs[0].outputs[0]
    return first_completion.text
79+
80+
81+
def main():
    """Run every guided-decoding example against a single LLM instance."""
    llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

    # (title, prompt, sampling params) for each technique, printed in the
    # same order as the module-level definitions: Choice, Regex, JSON, Grammar.
    demos = [
        ("Guided decoding by Choice", prompt_choice, sampling_params_choice),
        ("Guided decoding by Regex", prompt_regex, sampling_params_regex),
        ("Guided decoding by JSON", prompt_json, sampling_params_json),
        ("Guided decoding by Grammar", prompt_grammar, sampling_params_grammar),
    ]
    for title, prompt, params in demos:
        format_output(title, generate_output(prompt, params, llm))


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)