Tensorize vLLM Model

Source: vllm-project/vllm.

import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig

from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry

# yapf conflicts with isort for this docstring
# yapf: disable
 24"""
 25tensorize_vllm_model.py is a script that can be used to serialize and 
 26deserialize vLLM models. These models can be loaded using tensorizer 
 27to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
 28or locally. Tensor encryption and decryption is also supported, although 
 29libsodium must be installed to use it. Install vllm with tensorizer support 
 30using `pip install vllm[tensorizer]`.
 31
 32To serialize a model, install vLLM from source, then run something 
 33like this from the root level of this repository:
 34
 35python -m examples.tensorize_vllm_model \
 36   --model EleutherAI/gpt-j-6B \
 37   --dtype float16 \
 38   serialize \
 39   --serialized-directory s3://my-bucket/ \
 40   --suffix vllm
 41   
 42Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
 43and saves it to your S3 bucket. A local directory can also be used. This
 44assumes your S3 credentials are specified as environment variables
 45in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
 46To provide S3 credentials directly, you can provide `--s3-access-key-id` and 
 47`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this 
 48script.
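
A sketch of such an invocation (the credentials and endpoint URL below are
placeholders):

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   --s3-access-key-id ACCESS_KEY \
   --s3-secret-access-key SECRET_KEY \
   --s3-endpoint https://object-storage.example.com \
   serialize \
   --serialized-directory s3://my-bucket/ \
   --suffix vllm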

You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
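
For example, to serialize with encryption and write the generated key to
`gpt-j.key` (the key filename here is arbitrary):

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   serialize \
   --serialized-directory s3://my-bucket/ \
   --suffix vllm \
   --keyfile gpt-j.key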

To deserialize a model, you can run something like this from the root
level of this repository:

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors

This downloads the model tensors from your S3 bucket and deserializes them.

You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
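
For example, reusing the key written during serialization (the tensor path
and key filename are illustrative):

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors \
   --keyfile gpt-j.key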

For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.

Or for deserializing:

`python -m examples.tensorize_vllm_model deserialize --help`.

Once a model is serialized, it can be loaded by the OpenAI inference server
at `vllm/entrypoints/openai/api_server.py` by providing the `--tensorizer-uri`
CLI argument, which is functionally the same as the `--path-to-tensors`
argument in this script, along with `--vllm-tensorized`, to signify that the
model to be deserialized is a vLLM model rather than a HuggingFace
`PreTrainedModel`. A `PreTrainedModel` can also be deserialized using
tensorizer in the same inference server, albeit without the speed
optimizations. To deserialize an encrypted file, the `--encryption-keyfile`
argument can be used to provide the path to the keyfile used to encrypt the
model weights. For information on all the arguments that can be used to
configure tensorizer's deserialization, check out the tensorizer options
argument group in the `vllm/entrypoints/openai/api_server.py` script with
`--help`.
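
A sketch of such an invocation, combining the flags described above with
`--load-format tensorizer` (the tensor URI is illustrative):

python -m vllm.entrypoints.openai.api_server \
   --model EleutherAI/gpt-j-6B \
   --load-format tensorizer \
   --tensorizer-uri s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors \
   --vllm-tensorized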

Tensorizer can also be invoked with the `LLM` class directly to load models:

    llm = LLM(model="facebook/opt-125m",
              load_format="tensorizer",
              tensorizer_uri=path_to_opt_tensors,
              num_readers=3,
              vllm_tensorized=True)
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models can be loaded using "
        "tensorizer directly to the GPU extremely quickly. Tensor "
        "encryption and decryption are also supported, although "
        "libsodium must be installed to use them.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary "
              "key, and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize.")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    return parser.parse_args()


def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def serialize():
    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    model = (engine.model_executor.driver_worker.
             model_runner.model)
    # Tensors must be contiguous in memory before serialization.
    make_model_contiguous(model)

    encryption_params = EncryptionParams.random() if keyfile else None
    if keyfile:
        # Write the randomly-generated encryption key to the given path.
        with _write_stream(keyfile) as stream:
            stream.write(encryption_params.key)

    with _write_stream(model_path) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    print("Serialization complete. Model tensors saved to", model_path)
    if keyfile:
        print("Key saved to", keyfile)


def deserialize():
    config = AutoConfig.from_pretrained(model_ref)

    # Instantiate the model architecture without allocating or initializing
    # any weights; tensorizer streams the real weights in below.
    with no_init_or_tensor():
        model_class = _get_vllm_model_architecture(config)
        model = model_class(config)

    before_mem = get_mem_usage()
    start = time.time()

    if keyfile:
        # Read the key back and pass decryption parameters to the
        # deserializer.
        with _read_stream(keyfile) as stream:
            key = stream.read()
            decryption_params = DecryptionParams.from_key(key)
            tensorizer_args.deserializer_params['encryption'] = \
                decryption_params

    with _read_stream(model_path) as stream, TensorDeserializer(
            stream, **tensorizer_args.deserializer_params) as deserializer:
        deserializer.load_into_module(model)
        end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    print(f"Deserialized {total_bytes_str} in {duration:0.2f}s, {per_second}/s")
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model


args = parse_args()

# CLI arguments take precedence over environment variables for S3 credentials.
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
                    or None)
s3_secret_access_key = (args.s3_secret_access_key
                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)

# Bind the credentials to read ("rb") and write ("wb+") stream helpers.
_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

# Set up a single-process distributed environment, which vLLM's
# model-parallel machinery expects to be initialized.
torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None

if args.command == "serialize":
    input_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    serialize()
elif args.command == "deserialize":
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    model_path = args.path_to_tensors
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")