MultiLoRA Inference

MultiLoRA Inference

Source: vllm-project/vllm.

  1"""
  2This example shows how to use the multi-LoRA functionality
  3for offline inference.
  4
  5Requires HuggingFace credentials for access to Llama2.
  6"""
  7
  8from typing import List, Optional, Tuple
  9
 10from huggingface_hub import snapshot_download
 11
 12from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
 13from vllm.lora.request import LoRARequest
 14
 15
 16def create_test_prompts(
 17        lora_path: str
 18) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
 19    """Create a list of test prompts with their sampling parameters.
 20
 21    2 requests for base model, 4 requests for the LoRA. We define 2
 22    different LoRA adapters (using the same model for demo purposes).
 23    Since we also set `max_loras=1`, the expectation is that the requests
 24    with the second LoRA adapter will be ran after all requests with the
 25    first adapter have finished.
 26    """
 27    return [
 28        ("A robot may not injure a human being",
 29         SamplingParams(temperature=0.0,
 30                        logprobs=1,
 31                        prompt_logprobs=1,
 32                        max_tokens=128), None),
 33        ("To be or not to be,",
 34         SamplingParams(temperature=0.8,
 35                        top_k=5,
 36                        presence_penalty=0.2,
 37                        max_tokens=128), None),
 38        (
 39            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
 40            SamplingParams(temperature=0.0,
 41                           logprobs=1,
 42                           prompt_logprobs=1,
 43                           max_tokens=128,
 44                           stop_token_ids=[32003]),
 45            LoRARequest("sql-lora", 1, lora_path)),
 46        (
 47            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
 48            SamplingParams(temperature=0.0,
 49                           logprobs=1,
 50                           prompt_logprobs=1,
 51                           max_tokens=128,
 52                           stop_token_ids=[32003]),
 53            LoRARequest("sql-lora2", 2, lora_path)),
 54    ]
 55
 56
 57def process_requests(engine: LLMEngine,
 58                     test_prompts: List[Tuple[str, SamplingParams,
 59                                              Optional[LoRARequest]]]):
 60    """Continuously process a list of prompts and handle the outputs."""
 61    request_id = 0
 62
 63    while test_prompts or engine.has_unfinished_requests():
 64        if test_prompts:
 65            prompt, sampling_params, lora_request = test_prompts.pop(0)
 66            engine.add_request(str(request_id),
 67                               prompt,
 68                               sampling_params,
 69                               lora_request=lora_request)
 70            request_id += 1
 71
 72        request_outputs: List[RequestOutput] = engine.step()
 73
 74        for request_output in request_outputs:
 75            if request_output.finished:
 76                print(request_output)
 77
 78
 79def initialize_engine() -> LLMEngine:
 80    """Initialize the LLMEngine."""
 81    # max_loras: controls the number of LoRAs that can be used in the same
 82    #   batch. Larger numbers will cause higher memory usage, as each LoRA
 83    #   slot requires its own preallocated tensor.
 84    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
 85    #   numbers will cause higher memory usage. If you know that all LoRAs will
 86    #   use the same rank, it is recommended to set this as low as possible.
 87    # max_cpu_loras: controls the size of the CPU LoRA cache.
 88    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
 89                             enable_lora=True,
 90                             max_loras=1,
 91                             max_lora_rank=8,
 92                             max_cpu_loras=2,
 93                             max_num_seqs=256)
 94    return LLMEngine.from_engine_args(engine_args)
 95
 96
 97def main():
 98    """Main function that sets up and runs the prompt processing."""
 99    engine = initialize_engine()
100    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
101    test_prompts = create_test_prompts(lora_path)
102    process_requests(engine, test_prompts)
103
104
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()