LoRA With Quantization Inference
Source: vllm-project/vllm.
1"""
2This example shows how to use LoRA with different quantization techniques
3for offline inference.
4
5Requires HuggingFace credentials for access.
6"""
7
8import gc
9from typing import List, Optional, Tuple
10
11import torch
12from huggingface_hub import snapshot_download
13
14from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
15from vllm.lora.request import LoRARequest
16
17
18def create_test_prompts(
19 lora_path: str
20) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
21 return [
22 # this is an example of using quantization without LoRA
23 ("My name is",
24 SamplingParams(temperature=0.0,
25 logprobs=1,
26 prompt_logprobs=1,
27 max_tokens=128), None),
28 # the next three examples use quantization with LoRA
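        # A LoRARequest takes a human-readable adapter name, a unique
        # integer ID (which must be >= 1), and the local path to the
        # adapter weights.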
29 ("my name is",
30 SamplingParams(temperature=0.0,
31 logprobs=1,
32 prompt_logprobs=1,
33 max_tokens=128),
34 LoRARequest("lora-test-1", 1, lora_path)),
35 ("The capital of USA is",
36 SamplingParams(temperature=0.0,
37 logprobs=1,
38 prompt_logprobs=1,
39 max_tokens=128),
40 LoRARequest("lora-test-2", 1, lora_path)),
41 ("The capital of France is",
42 SamplingParams(temperature=0.0,
43 logprobs=1,
44 prompt_logprobs=1,
45 max_tokens=128),
46 LoRARequest("lora-test-3", 1, lora_path)),
47 ]
48
49
50def process_requests(engine: LLMEngine,
51 test_prompts: List[Tuple[str, SamplingParams,
52 Optional[LoRARequest]]]):
53 """Continuously process a list of prompts and handle the outputs."""
54 request_id = 0
55
56 while test_prompts or engine.has_unfinished_requests():
57 if test_prompts:
58 prompt, sampling_params, lora_request = test_prompts.pop(0)
59 engine.add_request(str(request_id),
60 prompt,
61 sampling_params,
62 lora_request=lora_request)
63 request_id += 1
64
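        # engine.step() runs one engine iteration: it schedules queued
        # requests, executes a single model forward pass, and returns the
        # outputs of all requests that made progress.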
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine."""

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
        # It quantizes the model at load time, using some configuration
        # information from the LoRA adapter repo, so load_format and
        # qlora_adapter_name_or_path have to be set as below.
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            qlora_adapter_name_or_path=lora_repo,
            load_format="bitsandbytes",
            enable_lora=True,
            max_lora_rank=64,
            # Set this only on GPUs with limited memory.
            enforce_eager=True)
    else:
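        # For AWQ and GPTQ, the pre-quantized model weights are loaded as
        # usual and the LoRA adapters are applied on top; max_loras caps how
        # many distinct adapters can be active in a single batch.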
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            enable_lora=True,
            max_loras=4,
            # Set this only on GPUs with limited memory.
            enforce_eager=True)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""

    test_configs = [{
        "name": "qlora_inference_example",
        "model": "huggyllama/llama-7b",
        "quantization": "bitsandbytes",
        "lora_repo": "timdettmers/qlora-flan-7b"
    }, {
        "name": "AWQ_inference_with_lora_example",
        "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
        "quantization": "awq",
        "lora_repo": "jashing/tinyllama-colorist-lora"
    }, {
        "name": "GPTQ_inference_with_lora_example",
        "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
        "quantization": "gptq",
        "lora_repo": "jashing/tinyllama-colorist-lora"
    }]

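    # Run each configuration in turn: build the engine, download the adapter
    # weights locally, and stream the test prompts through the engine.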
    for test_config in test_configs:
        print(
            f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
        )
        engine = initialize_engine(test_config['model'],
                                   test_config['quantization'],
                                   test_config['lora_repo'])
        lora_path = snapshot_download(repo_id=test_config['lora_repo'])
        test_prompts = create_test_prompts(lora_path)
        process_requests(engine, test_prompts)

        # Clean up the GPU memory for the next test.
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
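To try the example end to end, authenticate with the Hugging Face Hub first (for instance via huggingface-cli login, since the base models require credentials), then run the script directly with Python, e.g. python lora_with_quantization_inference.py (the file name here is assumed from this example's title).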