MultiLoRA Inference
Source: vllm-project/vllm.
1"""
2This example shows how to use the multi-LoRA functionality
3for offline inference.
4
5Requires HuggingFace credentials for access to Llama2.
6"""
7
8from typing import List, Optional, Tuple
9
10from huggingface_hub import snapshot_download
11
12from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
13from vllm.lora.request import LoRARequest
14
15
def create_test_prompts(
    lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    The list holds 2 requests for the base model followed by 4 LoRA
    requests. The same pair of SQL prompts is sent once through adapter
    ``sql-lora`` (ID 1) and once through adapter ``sql-lora2`` (ID 2).
    Both adapters load the same weights from ``lora_path`` for demo
    purposes. Since the engine is also configured with ``max_loras=1``, the
    expectation is that the requests with the second LoRA adapter will run
    after all requests with the first adapter have finished.

    Args:
        lora_path: Path to the LoRA adapter weights (``main`` passes the
            directory returned by ``snapshot_download``).

    Returns:
        ``(prompt, sampling_params, lora_request)`` tuples in submission
        order; ``lora_request`` is ``None`` for base-model requests.
    """
    sql_icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    sql_nationality_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def sql_requests(
        lora_name: str, lora_id: int
    ) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
        # Return the two SQL requests for one adapter: a greedy run with
        # logprobs, then a 3-way beam search. Fresh SamplingParams and
        # LoRARequest objects are built per request, so no two requests
        # share a parameter object.
        # NOTE(review): stop token 32003 presumably marks the end of the
        # adapter's SQL answer -- confirm against the adapter's tokenizer.
        return [
            (sql_icao_prompt,
             SamplingParams(temperature=0.0,
                            logprobs=1,
                            prompt_logprobs=1,
                            max_tokens=128,
                            stop_token_ids=[32003]),
             LoRARequest(lora_name, lora_id, lora_path)),
            (sql_nationality_prompt,
             SamplingParams(n=3,
                            best_of=3,
                            use_beam_search=True,
                            temperature=0,
                            max_tokens=128,
                            stop_token_ids=[32003]),
             LoRARequest(lora_name, lora_id, lora_path)),
        ]

    base_requests: List[Tuple[str, SamplingParams,
                              Optional[LoRARequest]]] = [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
    ]

    # Every request for adapter 1 precedes every request for adapter 2, so
    # with max_loras=1 the second adapter is expected to wait for the first.
    return (base_requests
            + sql_requests("sql-lora", 1)
            + sql_requests("sql-lora2", 2))
73
74
def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Submit prompts to the engine and print each request when it finishes.

    While prompts remain, exactly one request is added before each call to
    ``engine.step()``. After every prompt has been submitted, the engine
    keeps stepping until ``engine.has_unfinished_requests()`` is false.
    Each returned request output is printed once it reports ``finished``.

    Args:
        engine: The engine that executes the requests.
        test_prompts: ``(prompt, sampling_params, lora_request)`` tuples,
            submitted in order with request IDs ``"0"``, ``"1"``, ...
            The sequence is only iterated, never modified, so any iterable
            of such tuples works.
    """

    def step_and_report() -> None:
        # Advance the engine by one step and print every finished output.
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)

    # Iterate instead of pop(0)-ing so the caller's list is left intact
    # and consuming each prompt is O(1).
    for request_id, request in enumerate(test_prompts):
        prompt, sampling_params, lora_request = request
        engine.add_request(str(request_id),
                           prompt,
                           sampling_params,
                           lora_request=lora_request)
        step_and_report()

    # Every prompt has been submitted; drain the requests still in flight.
    while engine.has_unfinished_requests():
        step_and_report()
95
96
def initialize_engine() -> LLMEngine:
    """Build an LLMEngine for meta-llama/Llama-2-7b-hf with LoRA enabled."""
    lora_settings = dict(
        # Switch on LoRA adapter support in the engine.
        enable_lora=True,
        # max_loras: how many LoRAs may be used within one batch. Every
        # LoRA slot owns its own preallocated tensor, so a larger value
        # means more memory.
        max_loras=1,
        # max_lora_rank: the highest rank any LoRA may have. Larger values
        # cost more memory; if all LoRAs share one rank, keep this as low
        # as possible.
        max_lora_rank=8,
        # max_cpu_loras: capacity of the CPU-side LoRA cache.
        max_cpu_loras=2,
    )
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             max_num_seqs=256,
                             **lora_settings)
    return LLMEngine.from_engine_args(engine_args)
113
114
def main():
    """Run the multi-LoRA demo end to end.

    Builds the engine, fetches the demo SQL LoRA adapter with
    ``snapshot_download``, and feeds the test prompts through the engine.
    """
    engine = initialize_engine()
    adapter_dir = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    process_requests(engine, create_test_prompts(adapter_dir))


if __name__ == '__main__':
    main()