Offline Inference CLI

Source: vllm-project/vllm.
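
This example wraps vLLM's offline inference API in a small command-line tool: EngineArgs.add_cli_args exposes every engine option (model, dtype, tensor parallelism, and so on) as a flag, and a separate argument group covers the SamplingParams options that control generation.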

from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


def get_prompts(num_prompts: int):
    # The default sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    if num_prompts != len(prompts):
        # Repeat the sample prompts as many times as needed,
        # then truncate to exactly num_prompts.
        prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts]

    return prompts


def main(args):
    # Create prompts
    prompts = get_prompts(args.num_prompts)

    # Create a sampling params object.
    sampling_params = SamplingParams(n=args.n,
                                     temperature=args.temperature,
                                     top_p=args.top_p,
                                     top_k=args.top_k,
                                     max_tokens=args.max_tokens)

    # Create an LLM.
    # The default model is "facebook/opt-125m".
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**asdict(engine_args))

    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    parser = FlexibleArgumentParser()
    # Expose all engine options (model, dtype, parallelism, ...)
    # as command-line flags.
    parser = EngineArgs.add_cli_args(parser)
    group = parser.add_argument_group("SamplingParams options")
    group.add_argument("--num-prompts",
                       type=int,
                       default=4,
                       help="Number of prompts used for inference")
    group.add_argument("--max-tokens",
                       type=int,
                       default=16,
                       help="Generated output length for sampling")
    group.add_argument('--n',
                       type=int,
                       default=1,
                       help='Number of generated sequences per prompt')
    group.add_argument('--temperature',
                       type=float,
                       default=0.8,
                       help='Temperature for text generation')
    group.add_argument('--top-p',
                       type=float,
                       default=0.95,
                       help='top_p for text generation')
    group.add_argument('--top-k',
                       type=int,
                       default=-1,
                       help='top_k for text generation')

    args = parser.parse_args()
    main(args)
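
A minimal sketch of how the script might be invoked, assuming it is saved as offline_inference_cli.py (the filename is illustrative); --model is one of the flags added by EngineArgs.add_cli_args, while the sampling flags come from the group defined above:

python offline_inference_cli.py --model facebook/opt-125m --num-prompts 8 --temperature 0.7 --max-tokens 32

FlexibleArgumentParser also accepts underscores in place of dashes in flag names, so --max_tokens works the same as --max-tokens.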