Offline Inference MLPSpeculator
Source: vllm-project/vllm.
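This example compares offline generation speed with and without speculative decoding via an MLPSpeculator draft model: it runs a greedy 200-token generation on meta-llama/Llama-2-13b-chat-hf alone, then repeats the run with the ibm-fms/llama-13b-accelerator speculator attached, printing the per-token latency for each.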
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    # Report the average wall-clock time per generated token.
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    # Free the baseline engine before loading the speculative one.
    del llm
    gc.collect()

    # Create an LLM with spec decoding. The MLPSpeculator draft heads propose
    # several tokens per step, which the base model then verifies in parallel.
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
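
The helper above prints the per-token latency for each run but does not compute the speedup itself. A minimal variant that returns the measurement instead, so the two runs can be compared directly, could look like the sketch below; it reuses only the vLLM calls already shown in this example, and the name time_generation_per_token is illustrative, not part of vLLM.

def time_generation_per_token(llm: LLM, prompts: List[str],
                              sampling_params: SamplingParams) -> float:
    # Warm up once, then time a single batch and return the average
    # wall-clock seconds per generated token.
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    num_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return (end - start) / num_tokens

# Example usage, mirroring the main block above:
#   baseline = time_generation_per_token(llm, prompts, sampling_params)
#   ... rebuild `llm` with speculative_model as shown above ...
#   speculative = time_generation_per_token(llm, prompts, sampling_params)
#   print(f"Speedup: {baseline / speculative:.2f}x")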