LoRA With Quantization Inference

Source: vllm-project/vllm.

  1"""
  2This example shows how to use LoRA with different quantization techniques
  3for offline inference.
  4
  5Requires HuggingFace credentials for access.
  6"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
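    # Each test case is a (prompt, sampling parameters, optional LoRARequest)
    # tuple. LoRARequest takes an adapter name, an integer id, and the local
    # adapter path; the three LoRA requests below share id 1 since they all
    # load the same adapter.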
    return [
        # This is an example of using quantization without LoRA.
        ("My name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        # The next three examples use quantization with LoRA.
        ("my name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-1", 1, lora_path)),
        ("The capital of USA is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-2", 1, lora_path)),
        ("The capital of France is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-3", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

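        # engine.step() performs one engine iteration: it schedules pending
        # requests, runs the model, and returns outputs for requests that
        # made progress, including any that finished in this step.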
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine."""

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique
        # that quantizes the model at load time, using some config info from
        # the LoRA adapter repo, so load_format and qlora_adapter_name_or_path
        # must be set as below.
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            qlora_adapter_name_or_path=lora_repo,
            load_format="bitsandbytes",
            enable_lora=True,
            max_lora_rank=64,
            # Enable this only on GPUs with limited memory.
            enforce_eager=True)
    else:
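        # AWQ and GPTQ checkpoints are already quantized offline, so no extra
        # load options are needed; max_loras bounds how many LoRA adapters can
        # be used in a single batch.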
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            enable_lora=True,
            max_loras=4,
            # Enable this only on GPUs with limited memory.
            enforce_eager=True)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""

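    # Each config pairs a base model (or a pre-quantized checkpoint) with a
    # LoRA adapter repo on the Hugging Face Hub.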
    test_configs = [{
        "name": "qlora_inference_example",
        "model": "huggyllama/llama-7b",
        "quantization": "bitsandbytes",
        "lora_repo": "timdettmers/qlora-flan-7b"
    }, {
        "name": "AWQ_inference_with_lora_example",
        "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
        "quantization": "awq",
        "lora_repo": "jashing/tinyllama-colorist-lora"
    }, {
        "name": "GPTQ_inference_with_lora_example",
        "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
        "quantization": "gptq",
        "lora_repo": "jashing/tinyllama-colorist-lora"
    }]

    for test_config in test_configs:
        print(
            f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
        )
        engine = initialize_engine(test_config['model'],
                                   test_config['quantization'],
                                   test_config['lora_repo'])
        lora_path = snapshot_download(repo_id=test_config['lora_repo'])
        test_prompts = create_test_prompts(lora_path)
        process_requests(engine, test_prompts)

        # Clean up the GPU memory for the next test.
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()