MultiLoRA Inference
Source: vllm-project/vllm.
1"""
2This example shows how to use the multi-LoRA functionality
3for offline inference.
4
5Requires HuggingFace credentials for access to Llama2.
6"""
7
8from typing import List, Optional, Tuple
9
10from huggingface_hub import snapshot_download
11
12from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
13from vllm.lora.request import LoRARequest
14
15
16def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for the base model and 2 requests for the LoRA adapters. We
    define 2 different LoRA adapters (using the same model for demo
    purposes). Since we also set `max_loras=1`, the expectation is that the
    requests with the second LoRA adapter will be run after all requests
    with the first adapter have finished.
    """
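    # Each entry is a (prompt, SamplingParams, Optional[LoRARequest]) tuple.
    # A LoRARequest is built from a human-readable adapter name, a unique
    # positive integer id for the adapter, and the local path to its weights;
    # passing None runs the prompt on the base model.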
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora2", 2, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

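    # Feed at most one new request into the engine per loop iteration, then
    # advance the engine by one step and print any requests that finished.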
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

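        # engine.step() performs one scheduling/decoding iteration and
        # returns the outputs of the requests that made progress.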
        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""
    # max_loras: controls the number of LoRAs that can be used in the same
    #   batch. Larger numbers will cause higher memory usage, as each LoRA
    #   slot requires its own preallocated tensor.
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
    #   numbers will cause higher memory usage. If you know that all LoRAs will
    #   use the same rank, it is recommended to set this as low as possible.
    # max_cpu_loras: controls the size of the CPU LoRA cache.
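    # max_num_seqs: upper bound on the number of sequences processed in a
    #   single iteration (i.e. the maximum batch size).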
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
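    # snapshot_download fetches the LoRA adapter from the HF Hub and returns
    # the local directory it was downloaded to.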
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()
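The script above drives the low-level LLMEngine directly so that the scheduling across adapters is visible. For quick experiments, roughly the same multi-LoRA behavior is available through vLLM's higher-level LLM entrypoint, which also accepts enable_lora and takes a lora_request argument in generate(). The sketch below is a minimal illustration, assuming the same base model and adapter as above (the SQL prompt is abbreviated):

from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Same base model and SQL LoRA adapter as in the example above.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

# Passing lora_request routes the whole batch through the given adapter;
# omitting it runs the prompts on the base model.
outputs = llm.generate(
    ["[user] Write a SQL query to answer the question based on the table schema. ... [/user] [assistant]"],
    SamplingParams(temperature=0.0, max_tokens=128, stop_token_ids=[32003]),
    lora_request=LoRARequest("sql-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)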