Speculative Decoding

Warning

Please note that speculative decoding in vLLM is not yet optimized and does not reduce inter-token latency for every prompt dataset or sampling configuration. Work to optimize it is ongoing and can be followed in Issue #4630.

Warning

Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.

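This document shows how to use speculative decoding with vLLM. Speculative decoding is a technique that improves inter-token latency in memory-bound LLM inference: a cheap proposer (for example, a small draft model) guesses several upcoming tokens, and the target model verifies all of them in a single forward pass.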
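The draft-then-verify loop can be illustrated in plain Python. This is a minimal sketch of greedy speculative decoding, not vLLM's implementation: `target_model` and `draft_model` are hypothetical toy functions standing in for real LLMs, and where a real engine scores all proposed positions in one batched forward pass of the target model, the sketch calls the target once per position for clarity.

```python
from typing import Callable, List

# Hypothetical "model": maps a token sequence to its greedy next token.
Model = Callable[[List[int]], int]


def target_model(tokens: List[int]) -> int:
    # Toy deterministic stand-in for the large target model.
    return (sum(tokens) * 7 + len(tokens)) % 50


def draft_model(tokens: List[int]) -> int:
    # Toy draft that agrees with the target except at every third position.
    guess = target_model(tokens)
    return guess if len(tokens) % 3 else (guess + 1) % 50


def greedy_decode(model: Model, prompt: List[int], max_new_tokens: int) -> List[int]:
    # Plain autoregressive greedy decoding: one model call per new token.
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens


def speculative_greedy_decode(target: Model, draft: Model, prompt: List[int],
                              max_new_tokens: int, k: int) -> List[int]:
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1. The draft proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(tokens + proposal))
        # 2. The target checks each proposed position and keeps the matching prefix.
        accepted: List[int] = []
        for i in range(k):
            expected = target(tokens + proposal[:i])
            accepted.append(expected)  # always the target's own greedy token
            if proposal[i] != expected:
                break  # first mismatch: discard the remaining proposals
        else:
            # All k proposals matched, so the target also yields one bonus token.
            accepted.append(target(tokens + proposal))
        tokens.extend(accepted)
    # The last round may overshoot; trim to the requested length.
    return tokens[: len(prompt) + max_new_tokens]


prompt = [1, 2, 3]
spec = speculative_greedy_decode(target_model, draft_model, prompt, max_new_tokens=20, k=5)
assert spec == greedy_decode(target_model, prompt, max_new_tokens=20)
```

Every emitted token is the target model's own greedy choice, which is why the output matches plain greedy decoding; the speedup comes from the target verifying several tokens per pass whenever the draft guesses well.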

Speculating with a draft model

The following code configures vLLM in offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time.

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

To do the same in online mode, launch the server:

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
    --seed 42 -tp 1 --speculative-model facebook/opt-125m --use-v2-block-manager \
    --num-speculative-tokens 5 --gpu-memory-utilization 0.8

Then use a client:

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="The future of AI is",
    echo=False,
    n=1,
    stream=stream,
)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)
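Whether speculating `num_speculative_tokens=5` tokens pays off depends on how often the target model accepts the draft's guesses. Under the simplifying assumption from Leviathan et al. (2023) that each draft token is accepted independently with probability alpha, the expected number of tokens produced per target forward pass with k speculative tokens is (1 - alpha^(k+1)) / (1 - alpha). The helper below is a back-of-the-envelope sketch, not a vLLM API, and it ignores the cost of running the draft model.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass with k speculative tokens.

    Assumes each draft token is accepted independently with probability alpha.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(k + 1)  # every proposal accepted, plus one bonus token
    return (1 - alpha ** (k + 1)) / (1 - alpha)


for alpha in (0.5, 0.7, 0.9):
    print(f"alpha={alpha}: {expected_tokens_per_step(alpha, k=5):.2f} tokens per target pass")
```

At low acceptance rates there is little to gain from a larger k, which is one reason the speedup varies across prompt datasets and sampling parameters.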

Speculating by matching n-grams in the prompt

The following code configures vLLM to use speculative decoding where proposals are generated by matching n-grams in the prompt; ngram_prompt_lookup_max sets the largest n-gram window to match. For more information, read this thread.

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
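N-gram proposals work because generated text often repeats spans that already appear in the prompt. A simplified sketch of prompt lookup in plain Python follows; the function name `ngram_propose` and its `n_min` parameter are hypothetical, `n_max` plays the role of `ngram_prompt_lookup_max`, and vLLM's actual proposer may differ in details such as which earlier occurrence it picks.

```python
from typing import List, Optional


def ngram_propose(tokens: List[int], k: int, n_max: int = 4, n_min: int = 1) -> Optional[List[int]]:
    """Propose up to k tokens by finding the trailing n-gram earlier in the sequence.

    Tries the longest n-gram first (n_max down to n_min) and uses the most recent
    earlier occurrence. Returns None when no n-gram matches.
    """
    for n in range(n_max, n_min - 1, -1):
        if len(tokens) <= n:
            continue  # not enough context for an earlier occurrence of this size
        suffix = tokens[-n:]
        # Search right-to-left, skipping the suffix's own position.
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                # Propose the tokens that followed the earlier occurrence.
                return tokens[start + n:start + n + k]
    return None


print(ngram_propose([5, 6, 7, 8, 9, 5, 6], k=3))  # [7, 8, 9]
```

Because the target model still verifies every proposed token, a poor match only costs wasted verification work, not output quality.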

Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by draft models that condition draft predictions on both context vectors and sampled tokens. For more information, see this blog or this technical report.

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,
    speculative_model="ibm-fms/llama3-70b-accelerator",
    speculative_draft_tensor_parallel_size=1,
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Note that these speculative models currently need to be run without tensor parallelism, although the main model can use tensor parallelism (see the example above). Because the speculative models are relatively small, we still see significant speedups. This limitation will be fixed in a future release.
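To make "conditioning on both context vectors and sampled tokens" concrete, here is a toy, randomly initialized speculator in plain Python. It is a hypothetical illustration of the general idea: each head mixes a running state (starting from the base model's last hidden state) with the embedding of the most recently predicted token, then projects to vocabulary logits. The class name `ToyMLPSpeculator`, the additive mixing, and the tanh activation are assumptions and do not reproduce the architecture of the ibm-fms checkpoints.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    # Matrix-vector product on plain lists.
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


class ToyMLPSpeculator:
    """Hypothetical speculator with one small head per speculated position."""

    def __init__(self, hidden: int, vocab: int, n_heads: int, seed: int = 0):
        rng = random.Random(seed)
        self.embed = [rand_matrix(vocab, hidden, rng) for _ in range(n_heads)]  # token -> vector
        self.proj = [rand_matrix(hidden, hidden, rng) for _ in range(n_heads)]  # state -> state
        self.head = [rand_matrix(vocab, hidden, rng) for _ in range(n_heads)]   # state -> logits

    def propose(self, context_vector: Vector, last_token: int) -> List[int]:
        state, token, proposal = context_vector, last_token, []
        for emb, proj, head in zip(self.embed, self.proj, self.head):
            # Condition on both the running state and the last sampled token's embedding.
            mixed = [s + e for s, e in zip(matvec(proj, state), emb[token])]
            state = [math.tanh(x) for x in mixed]  # arbitrary nonlinearity for the sketch
            logits = matvec(head, state)
            token = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
            proposal.append(token)
        return proposal


speculator = ToyMLPSpeculator(hidden=4, vocab=10, n_heads=3)
print(speculator.propose([0.1, -0.2, 0.3, 0.0], last_token=2))
```

Each head is a few small matrix products rather than a full transformer pass, which is why speculators of this kind are cheap to run at every step.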

A variety of speculative models of this type are available on the Hugging Face Hub.

Speculating using EAGLE based draft models

The following code configures vLLM to use speculative decoding where proposals are generated by an EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) based draft model.

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=4,
    speculative_model="path/to/modified/eagle/model",
    speculative_draft_tensor_parallel_size=1,
)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

A few important things to consider when using the EAGLE based draft models:

  1. The EAGLE draft models available in the EAGLE Hugging Face repository cannot be used directly with vLLM because of differences in the expected layer names and model definition. To use these models with vLLM, convert them with the following script. Note that this script does not modify the model’s weights.

    In the above example, use the script to first convert the yuhuili/EAGLE-LLaMA3-Instruct-8B model and then use the converted checkpoint as the draft model in vLLM.

  2. The EAGLE based draft models need to be run without tensor parallelism (i.e. speculative_draft_tensor_parallel_size is set to 1), although it is possible to run the main model using tensor parallelism (see example above).

  3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is reported in the reference implementation here. This issue is under investigation and tracked here: vllm-project/vllm#9565.

A variety of EAGLE draft models are available on the Hugging Face hub:

Base Model                  EAGLE on Hugging Face                EAGLE Parameters
Vicuna-7B-v1.3              yuhuili/EAGLE-Vicuna-7B-v1.3         0.24B
Vicuna-13B-v1.3             yuhuili/EAGLE-Vicuna-13B-v1.3        0.37B
Vicuna-33B-v1.3             yuhuili/EAGLE-Vicuna-33B-v1.3        0.56B
LLaMA2-Chat 7B              yuhuili/EAGLE-llama2-chat-7B         0.24B
LLaMA2-Chat 13B             yuhuili/EAGLE-llama2-chat-13B        0.37B
LLaMA2-Chat 70B             yuhuili/EAGLE-llama2-chat-70B        0.99B
Mixtral-8x7B-Instruct-v0.1  yuhuili/EAGLE-mixtral-instruct-8x7B  0.28B
LLaMA3-Instruct 8B          yuhuili/EAGLE-LLaMA3-Instruct-8B     0.25B
LLaMA3-Instruct 70B         yuhuili/EAGLE-LLaMA3-Instruct-70B    0.99B
Qwen2-7B-Instruct           yuhuili/EAGLE-Qwen2-7B-Instruct      0.26B
Qwen2-72B-Instruct          yuhuili/EAGLE-Qwen2-72B-Instruct     1.05B

Lossless guarantees of Speculative Decoding

In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of speculative decoding, breaking down the guarantees into three key areas:

  1. Theoretical Losslessness - Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might cause slight variations in output distributions, as discussed in Accelerating Large Language Model Decoding with Speculative Sampling.

  2. Algorithmic Losslessness - vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:

    • Rejection Sampler Convergence: Ensures that samples from vLLM’s rejection sampler align with the target distribution. View Test Code

    • Greedy Sampling Equality: Confirms that greedy sampling with speculative decoding matches greedy sampling without it. This verifies that vLLM’s speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, provides a lossless guarantee. Almost all of the tests in tests/spec_decode/e2e verify this property using this assertion implementation.

  3. vLLM Logprob Stability - vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the same request across runs. For more details, see the FAQ section titled Can the output of a prompt vary across runs in vLLM? in the FAQs.
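The "Rejection Sampler Convergence" property refers to the standard speculative sampling rule from the paper cited in item 1 (and from Leviathan et al.): accept a draft token x with probability min(1, p(x)/q(x)), where p is the target distribution and q the draft distribution; otherwise resample from the normalized residual max(0, p - q). The sketch below is a simplified single-token illustration in plain Python, not vLLM's batched rejection sampler, and it empirically checks that the output frequencies match p.

```python
import random
from collections import Counter
from typing import List


def sample(dist: List[float], rng: random.Random) -> int:
    # Draw an index with probability proportional to dist (need not be normalized).
    return rng.choices(range(len(dist)), weights=dist)[0]


def speculative_sample(p: List[float], q: List[float], rng: random.Random) -> int:
    """Return one token whose distribution equals the target distribution p."""
    x = sample(q, rng)  # draft proposal, drawn from q
    if rng.random() < min(1.0, p[x] / q[x]):
        return x  # accept the draft token
    # Reject: resample from the residual max(0, p - q); choices() normalizes it.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return sample(residual, rng)


rng = random.Random(0)
p = [0.5, 0.3, 0.2]  # target distribution
q = [0.2, 0.2, 0.6]  # draft distribution
n = 100_000
counts = Counter(speculative_sample(p, q, rng) for _ in range(n))
print([round(counts[i] / n, 3) for i in range(len(p))])
```

Accepting contributes min(p(x), q(x)) for each token and the residual contributes max(0, p(x) - q(x)), so the two paths together give exactly p(x).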

While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding can occur due to the following factors:

  • Floating-Point Precision: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.

  • Batch Size and Numerical Stability: Changes in batch size may cause variations in logprobs and output probabilities, potentially due to non-deterministic behavior in batched operations or numerical instability.
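A tiny, self-contained illustration of the floating-point factor above: addition is not associative, so kernels that reduce values in a different order (for example, because the batch size changed) can produce slightly different numbers, and a near-tie between two tokens can then resolve differently. The two-token scores below are hypothetical and do not come from vLLM.

```python
# The same three values, summed in two different orders.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)  # 0.6000000000000001 0.6 False


def greedy_pick(score_x: float) -> str:
    # Hypothetical two-token vocabulary; on an exact tie, max() keeps the first key, "y".
    scores = {"y": 0.6, "x": score_x}
    return max(scores, key=scores.get)


print(greedy_pick(a), greedy_pick(b))  # x y
```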

For mitigation strategies, please refer to the FAQ entry Can the output of a prompt vary across runs in vLLM? in the FAQs.

Resources for vLLM contributors