Offline Profile

Offline Profile

Source: examples/offline_profile.py.

  1import inspect
  2import json
  3import os
  4import sys
  5from argparse import RawTextHelpFormatter
  6from dataclasses import asdict, dataclass
  7from typing import Any, Dict, Generator, List, Optional, TypeAlias
  8
  9import torch
 10import tqdm
 11
 12from vllm import LLM, SamplingParams
 13from vllm.engine.arg_utils import EngineArgs
 14from vllm.profiler import layerwise_profile
 15from vllm.utils import FlexibleArgumentParser
 16
# Defaults used by the --batch-size and --prompt-len command-line options.
BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
 19
 20
@dataclass
class ProfileContext:
    """
    Settings describing a single offline profiling run.

    The number of engine steps to profile is selected through either
    num_steps or complete_num_requests_per_step; determine_requests_per_step
    consults num_steps first and only then complete_num_requests_per_step.
    """
    # Arguments used to construct the LLM engine.
    engine_args: EngineArgs
    # Number of (random) prompt tokens per request.
    prompt_len: int
    # Number of requests submitted together.
    batch_size: int

    # The profiler can run in 2 modes,
    # 1. Run profiler for user specified num_steps
    num_steps: Optional[int] = None
    # 2. Run profiler until all requests complete, finishing this many
    #    requests at every decode step.
    complete_num_requests_per_step: Optional[int] = None

    # When set, chrome traces for every profiled step are written to this
    # folder.
    save_chrome_traces_folder: Optional[str] = None
 34
 35
 36def get_dtype(dtype: str):
 37    if dtype == "torch.float":
 38        return torch.float
 39    else:
 40        return dtype
 41
 42
 43OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
 44def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
 45      -> OutputLen_NumReqs_Map:
 46    """
 47    Given the number of requests, batch_size, and the number of requests
 48    that each engine-step should process, step_requests, determine the
 49    output lengths of the requests such that step_request is honoured.
 50
 51    Example: 
 52    if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1]
 53    then return,
 54    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
 55    32 requests should have output length 2,
 56    32 requests should have output length 3,
 57    32 requests should have output length 4,
 58    31 requests should have output length 5,
 59    1 request should have output length 6.
 60
 61    Args:
 62        batch_size (int): Number of requests submitted for profile. This is
 63            args.batch_size.
 64        step_requests (List[int]): step_requests[i] is the number of requests
 65            that the ith engine step should process.
 66
 67    Returns:
 68        OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
 69            number of requests required to have that output-length as values.
 70    """
 71    ol_nr: OutputLen_NumReqs_Map = {}
 72
 73    # Number of request that are assigned an output-length
 74    num_reqs_assigned: int = 0
 75    num_steps: int = len(step_requests)
 76
 77    # sanity check. The first step (prefill-step), must process all requests.
 78    assert step_requests[0] == batch_size
 79
 80    # Begin assignments from the last step.
 81    output_length: int = num_steps
 82    for num_requests_at_step in reversed(step_requests):
 83        if num_reqs_assigned == batch_size:
 84            break
 85
 86        assert num_reqs_assigned < batch_size
 87
 88        # Remove the number of requests that have been determined
 89        # to participate in this step and beyond.
 90        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
 91        assert num_reqs_unassigned_at_step >= 0
 92
 93        if num_reqs_unassigned_at_step > 0:
 94            ol_nr[output_length] = num_reqs_unassigned_at_step
 95            num_reqs_assigned += num_reqs_unassigned_at_step
 96
 97        output_length -= 1
 98
 99    # sanity checks.
100    assert sum(ol_nr.values()) == batch_size, \
101            ("Number of requests in output-length assignment does not match "
102             f"batch-size.\n batch size {batch_size} - "
103             f"step requests {step_requests} - assignments {ol_nr}")
104
105    # Check that the output-length is in [1, num-steps]. Output length must be
106    # at least 1 as all requests must participate in the prefill-step.
107    assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
108            ("Output lengths of requests should be in range "
109             f"[1, num-engine-steps].\n batch size {batch_size} - "
110             f"step requests {step_requests} - assignments {ol_nr}")
111
112    return ol_nr
113
114
def determine_requests_per_step(context: ProfileContext) -> List[int]:
    """
    Determine number of requests each engine step should process.

    If context.num_steps is set (non-zero), every engine step processes the
    whole batch and the returned list has length context.num_steps.

    Otherwise context.complete_num_requests_per_step must be a positive
    integer. The prefill step and the first decode step process the whole
    batch; every following decode step processes
    complete_num_requests_per_step fewer requests than the previous one
    until none are left. A final single-request step is appended unless the
    schedule already ends with exactly one request.

    Args:
        context: ProfileContext object.

    Returns:
        List[int]: Number of requests to process for all engine-steps.
         output[i] contains the number of requests that the ith step
         should process.

    Raises:
        AssertionError: if neither num_steps nor a positive
            complete_num_requests_per_step is provided.
    """
    if context.num_steps:
        # All requests must run until num_engine_steps. This implies
        # that their output lengths must be equal to num_engine_steps.
        return [context.batch_size] * context.num_steps

    completed_per_step = context.complete_num_requests_per_step
    assert completed_per_step and completed_per_step > 0, \
        ("Expected a positive complete_num_requests_per_step argument. "
         f"Instead got {completed_per_step}")

    # We start dropping after the first decode step.
    step_requests = [
        context.batch_size,  # prefill
        context.batch_size,  # decode
    ]

    # Remaining decode steps: batch_size - k, batch_size - 2k, ... while the
    # count stays positive.
    step_requests.extend(
        range(context.batch_size - completed_per_step, 0,
              -completed_per_step))

    if step_requests[-1] != 1:
        # have 1 request running at the last step. This is often
        # useful
        step_requests.append(1)

    return step_requests
163
164
def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    """
    Run a warm-up pass followed by a profiled pass, print the resulting
    tables and optionally export them.

    The per-step request schedule comes from determine_requests_per_step and
    is turned into per-request output lengths with
    compute_request_output_lengths. Random prompts of context.prompt_len
    tokens are submitted directly to the engine; the first profiled step is
    treated as the prefill step and every later step as a decode step.

    Args:
        context: ProfileContext describing the engine and the workload.
        csv_output: Root filename for the CSV tables (a trailing ".csv" is
            dropped before the table-specific suffixes are appended), or
            None/empty to skip the CSV export.
        json_output: Filename for the JSON export (".json" is appended when
            missing), or None/empty to skip the JSON export.

    Exits the process with status -1 (after printing an error) when the
    requested batch does not fit the engine's max_num_batched_tokens,
    max_num_seqs or max_model_len limits.
    """
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f"  {key} = {value}")

    requests_per_step: List[int] = determine_requests_per_step(context)

    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
        context.batch_size, requests_per_step)

    num_steps_to_profile: int = len(requests_per_step)
    max_output_len: int = max(ol_nr.keys())
    assert max_output_len >= 1

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        # max_tokens is set on a per-request basis.
        max_tokens=None,
        ignore_eos=True)

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len

    scheduler_config = llm.llm_engine.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    # The whole batch has to fit into a single prefill step, so reject
    # configurations that exceed the engine limits up front.
    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is  "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max-num-batched-tokens")
        sys.exit(-1)
    if batch_size > max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)
    print("llm.llm_engine.model_config.max_model_len: ", max_model_len)
    if prompt_len + max_output_len > max_model_len:
        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
              f"than the model's max_model_len ({max_model_len}), please "
              f"choose a smaller prompt_len or max_output_len, or increase "
              f"--max-model-len")
        sys.exit(-1)

    def add_requests():
        """Submit batch_size random-prompt requests named seq0..seqN-1."""

        def get_output_len_generator() -> Generator[int, Any, Any]:
            # Yield every output length as many times as requests need it.
            for output_len, num_reqs in ol_nr.items():
                for _ in range(num_reqs):
                    yield output_len

        output_len_generator = get_output_len_generator()
        for i in range(batch_size):
            # NOTE(review): the same SamplingParams instance is mutated for
            # every request; this presumably relies on the engine copying
            # the params in add_request - confirm.
            sampling_params.max_tokens = next(output_len_generator)
            assert isinstance(sampling_params.max_tokens, int)

            prompt_token_ids = torch.randint(
                llm.llm_engine.model_config.get_vocab_size(),
                size=(prompt_len, )).tolist()

            llm.llm_engine.add_request(
                request_id=f"seq{i}",
                prompt={'prompt_token_ids': prompt_token_ids},
                params=sampling_params)

    def abort_requests():
        """Abort every request submitted by add_requests."""
        for i in range(batch_size):
            llm.llm_engine.abort_request(f"seq{i}")

    # Warm up run
    print("Warm up run ...")
    add_requests()
    llm.llm_engine.step()  # Prefill
    llm.llm_engine.step()  # Decode
    abort_requests()

    print("Profile run ...")
    add_requests()

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    # Every remaining step is profiled on its own.
    decode_profs = []
    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
        num_running_seqs = llm.llm_engine.scheduler[
            0].get_num_unfinished_seq_groups()
        with layerwise_profile(
                num_running_seqs=num_running_seqs) as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)

    decode_results_list = [prof.results for prof in decode_profs]
    prefill_results = prefill_prof.results
    has_decode = len(decode_results_list) > 0

    LINE_WIDTH = 80

    def print_table_banner(title: str, leading_blank: bool = True):
        """Print the boxed heading that precedes each table."""
        if leading_blank:
            print()
        print("=" * LINE_WIDTH)
        print(f"= {title} "
              f"(prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * LINE_WIDTH)
        print()

    print_table_banner("Prefill Model Table", leading_blank=False)
    prefill_results.print_model_table()

    if has_decode:
        print_table_banner("First Decode Step Model Table")
        decode_results_list[0].print_model_table()

    print_table_banner("Prefill Summary Table")
    prefill_results.print_summary_table()

    if has_decode:
        print_table_banner("First Decode Step Summary Table")
        decode_results_list[0].print_summary_table()

    if csv_output:
        # Accept both "name" and "name.csv" as the root filename.
        csv_filename_base = csv_output.removesuffix('.csv')
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv")
        prefill_results.export_summary_stats_table_csv(
            csv_filename_base + "_prefill_summary_table.csv")

        if has_decode:
            # Only the first decode step is exported to CSV.
            decode_results_list[0].export_model_stats_table_csv(
                csv_filename_base + "_decode_model_table.csv")
            decode_results_list[0].export_summary_stats_table_csv(
                csv_filename_base + "_decode_summary_table.csv")

    if json_output:
        cuda_devices = [
            torch.cuda.get_device_properties(dev_idx)
            for dev_idx in range(torch.cuda.device_count())
        ]

        json_dict = {
            "context": {
                "python_version": f"{sys.version}",
                "torch_version": f"{torch.__version__}",
                "torch_cuda_version": f"{torch.version.cuda}",
                "cuda_devices": f"{cuda_devices}",
                **asdict(context)
            },
            "prefill": prefill_results.convert_stats_to_dict(),
        }

        # Unlike the CSV export, every decode step is stored (1-based keys).
        for idx, dr in enumerate(decode_results_list):
            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

        # Add .json to json_output filename if it doesn't exist already.
        json_output_file = json_output if json_output.endswith(
            '.json') else json_output + '.json'
        with open(json_output_file, "w", encoding="utf-8") as f:
            json.dump(json_dict, f, indent=2)

    if context.save_chrome_traces_folder is not None:
        traces_folder = context.save_chrome_traces_folder
        os.makedirs(traces_folder, exist_ok=True)
        prefill_prof.profiler.export_chrome_trace(
            os.path.join(traces_folder, "prefill.json"))
        for idx, decode_prof in enumerate(decode_profs):
            decode_prof.profiler.export_chrome_trace(
                os.path.join(traces_folder, f"decode_{idx + 1}.json"))
        print("Traces saved as prefill.json and decode_1.json, etc."
              f" in folder {context.save_chrome_traces_folder}")
358
359
# Command-line entry point: parses the profiling options, one of the two
# sub-commands (run_num_steps / run_to_completion) and the engine arguments,
# then runs the profile.
if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="""
Profile a model

    example:
    ```
    python examples/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager run_num_steps -n 2
    ```

    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
                                    formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv file. This should be the root "
        "filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument("--save-chrome-traces-folder",
                        type=str,
                        help="Save chrome traces for the prefill and decode "
                        "will save traces as prefill.json and decode_1.json, "
                        "etc. inside this folder")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")

    subparsers = parser.add_subparsers(dest="cmd")

    # Both sub-commands take their step setting through -n. It is marked
    # required because leaving it out only surfaced later as an assertion
    # failure inside determine_requests_per_step.
    run_num_steps_parser = subparsers.add_parser(
        "run_num_steps",
        help="This variation profiles n engine.step() invocations.")
    run_num_steps_parser.add_argument(
        '-n',
        '--num-steps',
        type=int,
        required=True,
        help="Number of engine steps to profile.\n"
        "Setting it to 1, profiles only the prefill step.\n"
        "Setting it to 2, profiles the prefill and first decode step\n"
        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
        "and so on ...")

    run_to_completion_parser = subparsers.add_parser(
        "run_to_completion",
        help="This variation profiles all the engine.step() invocations "
        "until the engine exhausts all submitted requests.")
    run_to_completion_parser.add_argument(
        '-n',
        '--complete-num-requests-per-step',
        type=int,
        required=True,
        help=
        "Complete complete_num_requests_per_step requests every decode step. "
        "For e.g., with batch_size 128 and complete_num_requests_per_step 32, "
        "the profiler is run for 6 engine steps, with the steps processing, "
        "128, 128, 96, 64, 32, 1 requests respectively.\n"
        "Note that we tack-on a one-request step at the end as it is often "
        "useful.")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()

    # Forward only those parsed options that ProfileContext accepts;
    # everything engine-related travels through engine_args.
    context_fields = inspect.signature(ProfileContext).parameters
    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{
            k: v
            for k, v in vars(args).items() if k in context_fields
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)
458    run_profile(context, csv_output=args.csv, json_output=args.json)