Adding a New Model

This document provides a high-level guide on integrating a HuggingFace Transformers model into vLLM.

Note

The complexity of adding a new model depends heavily on the model’s architecture. The process is fairly straightforward if the model shares a similar architecture with an existing model in vLLM. However, for models that include new operators (e.g., a new attention mechanism), the process can be more complex.

Tip

If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our GitHub repository. We will be happy to help you out!

0. Fork the vLLM repository

Start by forking our GitHub repository and then build it from source. This gives you the ability to modify the codebase and test your model.

Tip

If you don’t want to fork the repository and modify vLLM’s codebase, please refer to the “Out-of-Tree Model Integration” section below.

1. Bring your model code

Copy the PyTorch model code from the HuggingFace Transformers repository into the vllm/model_executor/models directory. For instance, vLLM’s OPT model was adapted from HuggingFace’s modeling_opt.py file.

Warning

When copying the model code, make sure to review and adhere to the code’s copyright and licensing terms.

2. Rewrite the forward methods

Next, you need to rewrite the forward methods of your model by following these steps:

  1. Remove any unnecessary code, such as the code only used for training.

  2. Change the input parameters:

def forward(
    self,
    input_ids: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    positions: torch.Tensor,
+    kv_caches: List[torch.Tensor],
+    attn_metadata: AttentionMetadata,
+) -> Optional[SamplerOutput]:
  3. Update the code by considering that input_ids and positions are now flattened tensors.

  4. Replace the attention operation with either PagedAttention, PagedAttentionWithRoPE, or PagedAttentionWithALiBi depending on the model’s architecture.
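
To make the flattened input_ids and positions mentioned above more concrete, here is a minimal sketch in plain Python with no vLLM imports. The helper name flatten_batch and the token IDs are hypothetical, and the sketch assumes every sequence is a fresh prompt that starts at position 0. It only illustrates how several variable-length sequences collapse into one 1-D list of token IDs and a matching list of positions.

```python
from typing import List, Tuple


def flatten_batch(sequences: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate variable-length token sequences into flat lists.

    Hypothetical helper for illustration only; it is not part of vLLM.
    """
    input_ids: List[int] = []
    positions: List[int] = []
    for seq in sequences:
        input_ids.extend(seq)
        # Assumption: each sequence is a new prompt starting at position 0.
        positions.extend(range(len(seq)))
    return input_ids, positions


# Two prompts of different lengths, with made-up token IDs.
batch = [[101, 7, 8], [101, 9]]
flat_ids, flat_pos = flatten_batch(batch)
print(flat_ids)  # [101, 7, 8, 101, 9]
print(flat_pos)  # [0, 1, 2, 0, 1]
```

With this layout there is no separate batch dimension and no padding, so code copied from HuggingFace that indexes a (batch, sequence) shape usually needs to be adjusted.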

Note

Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.

3. (Optional) Implement tensor parallelism and quantization support

If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it. To do this, substitute your model’s linear and embedding layers with their tensor-parallel versions. For the embedding layer, you can simply replace nn.Embedding with VocabParallelEmbedding. For the output LM head, you can use ParallelLMHead. When it comes to the linear layers, we provide the following options to parallelize them:

  • ReplicatedLinear: Replicates the inputs and weights across multiple GPUs. No memory saving.

  • RowParallelLinear: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An all-reduce operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.

  • ColumnParallelLinear: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.

  • MergedColumnParallelLinear: Column-parallel linear that merges multiple ColumnParallelLinear operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.

  • QKVParallelLinear: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When the number of key/value heads is less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
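
The pairing of a column-parallel first layer with a row-parallel second layer can be checked with a small simulation. The sketch below uses plain Python lists rather than vLLM classes or real GPUs; the helper names, the matrix values, and the two-rank setup are all illustrative assumptions. It shows that splitting the first weight by columns and the second by rows, then summing the per-rank partial outputs (the role played by the all-reduce), reproduces the single-device result.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_columns(m: Matrix, parts: int) -> List[Matrix]:
    """Split a matrix into `parts` equal blocks along its columns."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]


def split_rows(m: Matrix, parts: int) -> List[Matrix]:
    """Split a matrix into `parts` equal blocks along its rows."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]


world_size = 2                      # number of simulated ranks
x = [[1.0, 2.0]]                    # one token, hidden size 2
w1 = [[1.0, 0.0, 2.0, 1.0],
      [0.0, 1.0, 1.0, 3.0]]         # first FFN weight, shape (2, 4)
w2 = [[1.0, 0.0],
      [0.0, 1.0],
      [1.0, 1.0],
      [2.0, 0.0]]                   # second FFN weight, shape (4, 2)

# Single-device reference (the activation function is omitted for brevity).
reference = matmul(matmul(x, w1), w2)

# Column-parallel: each rank holds a column slice of w1 and sees the full input.
w1_shards = split_columns(w1, world_size)
# Row-parallel: each rank holds the matching row slice of w2.
w2_shards = split_rows(w2, world_size)

# Each rank computes a partial output from its own shards only.
partials = [matmul(matmul(x, w1_shards[r]), w2_shards[r]) for r in range(world_size)]

# Simulated all-reduce: element-wise sum of the per-rank partial outputs.
reduced = [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
           for i in range(len(partials[0]))]

print(reference)  # [[19.0, 6.0]]
print(reduced)    # [[19.0, 6.0]]
```

An element-wise activation between the two layers would not change this picture, because each rank can apply it to its own column slice independently. That is why only one communication step is needed, after the second layer.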

Note that all the linear layers above take linear_method as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.

4. Implement the weight loading logic

You now need to implement the load_weights method in your *ForCausalLM class. This method should load the weights from the HuggingFace checkpoint file and assign them to the corresponding layers in your model. Specifically, for MergedColumnParallelLinear and QKVParallelLinear layers, if the original model has separate weight matrices, you need to load each part separately.
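
The idea of loading separate checkpoint matrices into one merged layer can be sketched without any vLLM code. Everything in the example below is a hypothetical stand-in: the parameter names, the stacked_mapping table, and the use of explicit row offsets are assumptions made for illustration, and plain Python lists replace tensors. In vLLM itself the parallel layer classes handle shard placement, as described in the previous step.

```python
from typing import Dict, List

Matrix = List[List[float]]

# Hypothetical checkpoint with separate Q, K, V matrices (2 output rows each).
checkpoint: Dict[str, Matrix] = {
    "attn.q_proj.weight": [[1.0, 1.0], [1.0, 1.0]],
    "attn.k_proj.weight": [[2.0, 2.0], [2.0, 2.0]],
    "attn.v_proj.weight": [[3.0, 3.0], [3.0, 3.0]],
    "attn.o_proj.weight": [[4.0, 4.0], [4.0, 4.0]],
}

# Model-side parameters: one merged QKV matrix (6 rows) plus the output projection.
params: Dict[str, Matrix] = {
    "attn.qkv_proj.weight": [[0.0, 0.0] for _ in range(6)],
    "attn.o_proj.weight": [[0.0, 0.0] for _ in range(2)],
}

# (merged name fragment, checkpoint name fragment, row offset in the merged matrix)
stacked_mapping = [
    ("qkv_proj", "q_proj", 0),
    ("qkv_proj", "k_proj", 2),
    ("qkv_proj", "v_proj", 4),
]


def load_weights(checkpoint: Dict[str, Matrix], params: Dict[str, Matrix]) -> None:
    """Copy checkpoint matrices into model parameters, merging Q/K/V parts."""
    for name, weight in checkpoint.items():
        for merged, separate, offset in stacked_mapping:
            if separate in name:
                target = params[name.replace(separate, merged)]
                # Write this part into its slice of the merged matrix.
                target[offset:offset + len(weight)] = [row[:] for row in weight]
                break
        else:
            # No merge rule matched: copy the weight one-to-one.
            params[name][:] = [row[:] for row in weight]


load_weights(checkpoint, params)
print(params["attn.qkv_proj.weight"])
```

The for/else structure keeps the two cases apart: names that match a merge rule are redirected to the merged parameter, and all other names are copied under their original name.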

5. Register your model

Finally, register your *ForCausalLM class in _MODELS in vllm/model_executor/models/__init__.py.
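
As a rough mental model of what registration achieves, the sketch below implements a toy name-to-class registry. It is not vLLM's code: the TOY_MODELS table, its assumed (module name, class name) layout, and the resolve_architecture helper are hypothetical, and standard-library classes stand in for model classes so the example runs anywhere. Check the actual structure of _MODELS in your checkout before editing it.

```python
import importlib
from typing import Dict, List, Optional, Tuple

# Assumed registry shape: architecture name -> (module name, class name).
# Standard-library classes stand in for model classes in this toy example.
TOY_MODELS: Dict[str, Tuple[str, str]] = {
    "OrderedDict": ("collections", "OrderedDict"),
    "Fraction": ("fractions", "Fraction"),
}


def resolve_architecture(architectures: List[str]) -> Optional[type]:
    """Return the first registered class matching one of the given names."""
    for arch in architectures:
        if arch in TOY_MODELS:
            module_name, class_name = TOY_MODELS[arch]
            # Import lazily so that unused entries are never loaded.
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
    return None


# Stand-in for the "architectures" list of a HuggingFace config.json.
found = resolve_architecture(["Fraction"])
missing = resolve_architecture(["NotRegistered"])
print(found, missing)
```

The practical consequence is that the name you register should match the architecture name listed in the model's HuggingFace configuration, otherwise the lookup finds nothing.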

6. Out-of-Tree Model Integration

We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.

Just add the following lines in your code:

from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

If you are running the API server with python -m vllm.entrypoints.openai.api_server args, you can wrap the entrypoint with the following code:

from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

Save the above code in a file and run it with python your_file.py args.