Offline Inference Structured Outputs

Offline Inference Structured Outputs

Source vllm-project/vllm.

 1from enum import Enum
 2
 3from pydantic import BaseModel
 4
 5from vllm import LLM, SamplingParams
 6from vllm.sampling_params import GuidedDecodingParams
 7
 8llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
 9
10# Guided decoding by Choice (list of possible options)
11guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
12sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
13outputs = llm.generate(
14    prompts="Classify this sentiment: vLLM is wonderful!",
15    sampling_params=sampling_params,
16)
17print(outputs[0].outputs[0].text)
18
# Guided decoding by Regex
# The pattern is a raw string so that `\w` and `\.` are handed to the regex
# engine as written instead of being treated as (invalid) Python string
# escapes. The trailing `\n` in the pattern pairs with stop=["\n"] so that
# generation ends right after the address.
guided_decoding_params = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
                                 stop=["\n"])
# Adjacent string literals are joined with no separator, so each fragment
# carries its own trailing space to keep the sentences apart.
prompt = ("Generate an email address for Alan Turing, who works in Enigma. "
          "End in .com and new line. Example result: "
          "alan.turing@enigma.com\n")
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
28
29
30# Guided decoding by JSON using Pydantic schema
# Closed set of body styles a generated car description may use.
# The `str` mix-in makes every member compare equal to its plain string
# value. Documented with comments rather than a docstring so the JSON
# schema derived from this type stays exactly as before.
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"
36
37
# Shape of the JSON object the model is asked to produce; the schema
# derived from this class is used as the decoding constraint.
# Documented with comments rather than a docstring so the derived schema
# stays exactly as before.
class CarDescription(BaseModel):
    brand: str  # free-form string
    model: str  # free-form string
    car_type: CarType  # must be one of the CarType members
42
43
# Turn the Pydantic model into a JSON schema and constrain decoding to it.
json_schema = CarDescription.model_json_schema()

guided_decoding_params = GuidedDecodingParams(json=json_schema)
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
# Adjacent string literals are joined with no separator, so the first
# fragment ends with a space to avoid producing "ofthe".
prompt = ("Generate a JSON with the brand, model and car_type of "
          "the most iconic car from the 90's")
outputs = llm.generate(
    prompts=prompt,
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
55
56# Guided decoding by Grammar
# Grammar for a minimal statement of the form
# "SELECT <col>[,<col>...] FROM <table>"; decoding is constrained to it.
simplified_sql_grammar = """
    ?start: select_statement

    ?select_statement: "SELECT " column_list " FROM " table_name

    ?column_list: column_name ("," column_name)*

    ?table_name: identifier

    ?column_name: identifier

    ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
"""
guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
# Adjacent string literals are joined with no separator, so the first
# fragment ends with a space to avoid producing "'email'from".
prompt = ("Generate an SQL query to show the 'username' and 'email' "
          "from the 'users' table.")
outputs = llm.generate(
    prompts=prompt,
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)