# Offline Inference CLI example.
# Source: vllm-project/vllm.
1from dataclasses import asdict
2
3from vllm import LLM, SamplingParams
4from vllm.engine.arg_utils import EngineArgs
5from vllm.utils import FlexibleArgumentParser
6
7
def get_prompts(num_prompts: int):
    """Return ``num_prompts`` prompts, cycling through the default samples.

    The four built-in sample prompts are repeated as many times as needed
    and the result is truncated to exactly ``num_prompts`` entries.
    """
    base = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Fast path: the caller asked for exactly the default set.
    if num_prompts == len(base):
        return base

    # Repeat the base list one more time than strictly necessary, then
    # slice down to the requested count (also handles num_prompts < 4).
    repeats = (num_prompts // len(base)) + 1
    return (base * repeats)[:num_prompts]
21
22
def main(args):
    """Run a batch of offline generations with the CLI-provided settings.

    Builds the prompt list, a SamplingParams object, and an LLM engine
    from the parsed arguments, then prints each prompt with its first
    generated completion.
    """
    prompts = get_prompts(args.num_prompts)

    # Sampling configuration taken straight from the CLI flags.
    sampling_params = SamplingParams(
        n=args.n,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
    )

    # Engine configuration comes from the EngineArgs-provided CLI flags.
    # The default model is 'facebook/opt-125m'.
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**asdict(engine_args))

    # llm.generate returns one RequestOutput per prompt, each carrying the
    # prompt text plus the generated completions.
    for output in llm.generate(prompts, sampling_params):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
48
49
if __name__ == "__main__":
    # Start from a parser that already knows every EngineArgs option,
    # then bolt on a group for the sampling-related knobs.
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    sampling_group = parser.add_argument_group("SamplingParams options")
    sampling_group.add_argument(
        "--num-prompts",
        type=int,
        default=4,
        help="Number of prompts used for inference",
    )
    sampling_group.add_argument(
        "--max-tokens",
        type=int,
        default=16,
        help="Generated output length for sampling",
    )
    sampling_group.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt",
    )
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Temperature for text generation",
    )
    sampling_group.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="top_p for text generation",
    )
    sampling_group.add_argument(
        "--top-k",
        type=int,
        default=-1,
        help="top_k for text generation",
    )

    main(parser.parse_args())