MLPSpeculator

Source: examples/offline_inference/mlpspeculator.py.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example demonstrates offline text generation with an LLM,
comparing performance with and without MLPSpeculator speculative decoding.

Note that MLPSpeculator speculative decoding is not yet supported on the V1
engine, so run this example with V1 disabled:
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
"""

import gc
import time

from vllm import LLM, SamplingParams


def time_generation(
    llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
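    """Generate completions for `prompts` after warmup, report the elapsed
    time per generated output token, and print the generated texts."""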
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warm up first so the timed run is not skewed by one-time startup costs.
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print("-" * 50)
    print(title)
    print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")
        print("-" * 50)


def main():
    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n"
    )

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
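    # temperature=0.0 selects greedy decoding, so the outputs with and without
    # speculation should match and the two timed runs are directly comparable.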
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    time_generation(llm, prompts, sampling_params, "Without speculation")

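    # Free the first engine and force garbage collection so its GPU memory is
    # released before the speculative engine is constructed.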
    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_config={
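            # MLPSpeculator draft model that proposes candidate tokens for the
            # base model to verify.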
            "model": "ibm-ai-platform/llama-13b-accelerator",
        },
    )

    time_generation(llm, prompts, sampling_params, "With speculation")


if __name__ == "__main__":
    main()