Offline Profile

Source vllm-project/vllm.

  1import inspect
  2import json
  3import os
  4import sys
  5from argparse import RawTextHelpFormatter
  6from dataclasses import asdict, dataclass
  7from typing import Optional
  8
  9import torch
 10
 11from vllm import LLM, SamplingParams
 12from vllm.engine.arg_utils import EngineArgs
 13from vllm.profiler import layerwise_profile
 14from vllm.utils import FlexibleArgumentParser
 15
# Defaults for the profiling command-line options declared in the
# ``__main__`` section below.
# Default for --batch-size: number of requests run as a single batch.
BATCH_SIZE_DEFAULT = 1
# Default for --prompt-len: length of the random prompt of every request.
PROMPT_LEN_DEFAULT = 256
# Default for --output-len: number of llm steps to run (prefill + decode).
OUTPUT_LEN_DEFAULT = 2
 19
 20
@dataclass
class ProfileContext:
    """Settings describing one profiling run.

    The instance is printed field by field at the start of ``run_profile``
    and is embedded (through ``asdict``) in the JSON export.
    """
    # Engine configuration; expanded with ``asdict`` into the ``LLM``
    # constructor.
    engine_args: EngineArgs
    # Number of random token ids that make up each request's prompt.
    prompt_len: int
    # Value of --output-len: number of llm steps to run (prefill and decode).
    output_len: int
    # Number of requests added to the engine for each run.
    batch_size: int
    # Folder that receives the chrome traces; ``None`` skips the export.
    save_chrome_traces_folder: Optional[str]
 28
 29
 30def get_dtype(dtype: str):
 31    if dtype == "torch.float":
 32        return torch.float
 33    else:
 34        return dtype
 35
 36
 37def run_profile(context: ProfileContext, csv_output: Optional[str],
 38                json_output: Optional[str]):
 39    print("Run profile with:")
 40    for key, value in asdict(context).items():
 41        print(f"  {key} = {value}")
 42
 43    # Create sampling params
 44    sampling_params = SamplingParams(temperature=0.8,
 45                                     top_p=0.95,
 46                                     max_tokens=args.output_len,
 47                                     ignore_eos=True)
 48
 49    # Create LLM
 50    llm = LLM(**asdict(context.engine_args))
 51    batch_size = context.batch_size
 52    prompt_len = context.prompt_len
 53    output_len = context.output_len
 54
 55    scheduler_config = llm.llm_engine.scheduler_config
 56    max_model_len = llm.llm_engine.model_config.max_model_len
 57    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
 58    max_num_seqs = scheduler_config.max_num_seqs
 59
 60    if batch_size * prompt_len > max_num_batched_tokens:
 61        print(f"ERROR: chosen batch_size * prompt_len "
 62              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is  "
 63              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
 64              f"and therefore cannot be run in a single profile step, please "
 65              f"choose a smaller batch size or prompt length, or increase "
 66              f"--max-num-batched-tokens")
 67        sys.exit(-1)
 68    if batch_size >= max_num_seqs:
 69        print(
 70            f"ERROR: chosen batch_size ({batch_size}) is larger than "
 71            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
 72            f"single profile step, please choose a smaller batch size")
 73        sys.exit(-1)
 74    print("llm.llm_engine.model_config.max_model_len: ",
 75          llm.llm_engine.model_config.max_model_len)
 76    if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
 77        print(
 78            f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
 79            f"{output_len} = {prompt_len + output_len}) is larger than the "
 80            f"model's max_model_len ({max_model_len}), please choose a smaller "
 81            f"prompt_len or output_len, or increase --max-model-len")
 82        sys.exit(-1)
 83
 84    def add_requests():
 85        for i in range(batch_size):
 86            prompt_token_ids = torch.randint(
 87                llm.llm_engine.model_config.get_vocab_size(),
 88                size=(prompt_len, )).tolist()
 89
 90            llm.llm_engine.add_request(
 91                request_id=f"seq{i}",
 92                prompt={'prompt_token_ids': prompt_token_ids},
 93                params=sampling_params)
 94
 95    def abort_requests():
 96        for i in range(batch_size):
 97            llm.llm_engine.abort_request(f"seq{i}")
 98
 99    # Warm up run
100    print("Warm up run ...")
101    add_requests()
102    llm.llm_engine.step()  # Prefill
103    llm.llm_engine.step()  # Decode
104    abort_requests()
105
106    print("Profile run ...")
107    add_requests()
108
109    with layerwise_profile() as prefill_prof:
110        llm.llm_engine.step()  # First step is prefill
111
112    decode_profs = []
113    for x in range(args.output_len - 1):
114        with layerwise_profile() as decode_prof:
115            llm.llm_engine.step()
116        decode_profs.append(decode_prof)
117
118    decode_results_list = [prof.results for prof in decode_profs]
119    prefill_results = prefill_prof.results
120    has_decode = len(decode_results_list) > 0
121
122    LINE_WIDTH = 80
123    print("=" * LINE_WIDTH)
124    print(f"= Prefill Model Table "
125          f"(prompt_len={prompt_len}, batch_size={batch_size})")
126    print("=" * LINE_WIDTH)
127    print()
128    prefill_results.print_model_table()
129
130    if has_decode:
131        print()
132        print("=" * LINE_WIDTH)
133        print(f"= First Decode Step Model Table "
134              f"(prompt_len={prompt_len}, batch_size={batch_size})")
135        print("=" * LINE_WIDTH)
136        print()
137        decode_results_list[0].print_model_table()
138
139    print()
140    print("=" * LINE_WIDTH)
141    print(f"= Prefill Summary Table "
142          f"(prompt_len={prompt_len}, batch_size={batch_size})")
143    print("=" * LINE_WIDTH)
144    print()
145    prefill_results.print_summary_table()
146
147    if has_decode:
148        print()
149        print("=" * LINE_WIDTH)
150        print(f"= First Decode Step Summary Table "
151              f"(prompt_len={prompt_len}, batch_size={batch_size})")
152        print("=" * LINE_WIDTH)
153        print()
154        decode_results_list[0].print_summary_table()
155
156    if csv_output:
157        csv_filename_base = csv_output.rstrip(".csv")
158        prefill_results.export_model_stats_table_csv(
159            csv_filename_base + "_prefill_model_table.csv")
160        prefill_results.export_summary_stats_table_csv(
161            csv_filename_base + "_prefill_summary_table.csv")
162
163        if has_decode:
164            decode_results_list[0].export_model_stats_table_csv(\
165                csv_filename_base + "_decode_model_table.csv")
166            decode_results_list[0].export_summary_stats_table_csv(
167                csv_filename_base + "_decode_summary_table.csv")
168
169    if json_output:
170        cuda_devices = [
171            torch.cuda.get_device_properties(dev_idx)
172            for dev_idx in range(torch.cuda.device_count())
173        ]
174
175        json_dict = {
176            "context": {
177                "python_version": f"{sys.version}",
178                "torch_version": f"{torch.__version__}",
179                "torch_cuda_version": f"{torch.version.cuda}",
180                "cuda_devices": f"{cuda_devices}",
181                **asdict(context)
182            },
183            "prefill": prefill_results.convert_stats_to_dict(),
184        }
185
186        if has_decode:
187            for idx, dr in enumerate(decode_results_list):
188                json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
189
190        for idx, dr in enumerate(decode_results_list[1:]):
191            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
192
193        with open(json_output.rstrip(".json") + ".json", "w+") as f:
194            json.dump(json_dict, f, indent=2)
195        pass
196
197    if context.save_chrome_traces_folder is not None:
198        os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
199        prefill_prof.profiler.export_chrome_trace(
200            context.save_chrome_traces_folder + "/prefill.json")
201        for idx, decode_prof in enumerate(decode_profs):
202            decode_prof.profiler.export_chrome_trace(
203                context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
204        print("Traces saved as prefill.json and decode_1.json, etc."
205              f" in folder {context.save_chrome_traces_folder}")
206
207
if __name__ == "__main__":
    # Command-line entry point: declare the profiling options next to the
    # engine options, build a ProfileContext and run the profile.
    parser = FlexibleArgumentParser(description="""
Profile a model

    example:
    ```
    python examples/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager
    ```

    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
                                    formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv file. This should be the root "
        "filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument("--save-chrome-traces-folder",
                        type=str,
                        help="Save chrome traces for the prefill and decode "
                        "will save traces as prefill.json and decode_1.json, "
                        "etc. inside this folder")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")
    # The second fragment must be an f-string; otherwise the help output
    # shows the literal text "{OUTPUT_LEN_DEFAULT}" instead of the value.
    parser.add_argument(
        "--output-len",
        type=int,
        default=OUTPUT_LEN_DEFAULT,
        help="Number of llm steps to run (includes prefill and decode) "
        f"- default={OUTPUT_LEN_DEFAULT}")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()

    # Forward only the parsed options that are ProfileContext fields; the
    # engine options are bundled separately into ``engine_args``.
    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{
            k: v
            for k, v in vars(args).items()
            if k in inspect.signature(ProfileContext).parameters
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)