# Offline Profile
# Source: examples/offline_profile.py
1import inspect
2import json
3import os
4import sys
5from argparse import RawTextHelpFormatter
6from dataclasses import asdict, dataclass
7from typing import Any, Dict, Generator, List, Optional, TypeAlias
8
9import torch
10import tqdm
11
12from vllm import LLM, SamplingParams
13from vllm.engine.arg_utils import EngineArgs
14from vllm.profiler import layerwise_profile
15from vllm.utils import FlexibleArgumentParser
16
# Defaults for the --batch-size and --prompt-len command line options:
# the number of requests profiled together as one batch, and the number of
# random prompt tokens generated for each of those requests.
BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
19
20
@dataclass
class ProfileContext:
    """Settings describing a single profiling run.

    The number of engine steps that get profiled is driven by either
    ``num_steps`` or ``complete_num_requests_per_step``; when ``num_steps``
    is set it takes precedence.
    """
    engine_args: EngineArgs
    prompt_len: int
    batch_size: int

    # Mode 1: profile a fixed, user-specified number of engine steps.
    num_steps: Optional[int] = None
    # Mode 2: profile until every request has completed, finishing this many
    # requests on each decode step.
    complete_num_requests_per_step: Optional[int] = None

    # Folder that receives the chrome trace files; None disables the export.
    save_chrome_traces_folder: Optional[str] = None
34
35
def get_dtype(dtype: str):
    """Map the literal string ``"torch.float"`` onto ``torch.float``.

    Every other value is handed back to the caller unchanged.
    """
    return torch.float if dtype == "torch.float" else dtype
41
42
43OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
44def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
45 -> OutputLen_NumReqs_Map:
46 """
47 Given the number of requests, batch_size, and the number of requests
48 that each engine-step should process, step_requests, determine the
49 output lengths of the requests such that step_request is honoured.
50
51 Example:
52 if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1]
53 then return,
54 {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
55 32 requests should have output length 2,
56 32 requests should have output length 3,
57 32 requests should have output length 4,
58 31 requests should have output length 5,
59 1 request should have output length 6.
60
61 Args:
62 batch_size (int): Number of requests submitted for profile. This is
63 args.batch_size.
64 step_requests (List[int]): step_requests[i] is the number of requests
65 that the ith engine step should process.
66
67 Returns:
68 OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
69 number of requests required to have that output-length as values.
70 """
71 ol_nr: OutputLen_NumReqs_Map = {}
72
73 # Number of request that are assigned an output-length
74 num_reqs_assigned: int = 0
75 num_steps: int = len(step_requests)
76
77 # sanity check. The first step (prefill-step), must process all requests.
78 assert step_requests[0] == batch_size
79
80 # Begin assignments from the last step.
81 output_length: int = num_steps
82 for num_requests_at_step in reversed(step_requests):
83 if num_reqs_assigned == batch_size:
84 break
85
86 assert num_reqs_assigned < batch_size
87
88 # Remove the number of requests that have been determined
89 # to participate in this step and beyond.
90 num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
91 assert num_reqs_unassigned_at_step >= 0
92
93 if num_reqs_unassigned_at_step > 0:
94 ol_nr[output_length] = num_reqs_unassigned_at_step
95 num_reqs_assigned += num_reqs_unassigned_at_step
96
97 output_length -= 1
98
99 # sanity checks.
100 assert sum(ol_nr.values()) == batch_size, \
101 ("Number of requests in output-length assignment does not match "
102 f"batch-size.\n batch size {batch_size} - "
103 f"step requests {step_requests} - assignments {ol_nr}")
104
105 # Check that the output-length is in [1, num-steps]. Output length must be
106 # at least 1 as all requests must participate in the prefill-step.
107 assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
108 ("Output lengths of requests should be in range "
109 f"[1, num-engine-steps].\n batch size {batch_size} - "
110 f"step requests {step_requests} - assignments {ol_nr}")
111
112 return ol_nr
113
114
def determine_requests_per_step(context: "ProfileContext") -> List[int]:
    """
    Determine number of requests each engine step should process.
    If context.num_steps is set, then all engine steps process the
    same number of requests and the output list is of length
    context.num_steps.

    If context.complete_num_requests_per_step is set, then each decode step
    processes fewer and fewer requests until there are no requests to process.
    In this case, the output list is as big as the number of steps
    required to process all requests.

    Args:
        context: ProfileContext object.

    Returns:
        List[int]: Number of requests to process for all engine-steps.
        output[i], contains the number of requests that the ith step
        should process.

    Raises:
        ValueError: If context.num_steps is negative.
        AssertionError: If context.num_steps is unset (or 0) and
            context.complete_num_requests_per_step is not a positive integer.
    """
    num_steps = context.num_steps

    # A negative count is truthy, so without this check
    # [batch_size] * num_steps would silently yield an empty schedule.
    if num_steps is not None and num_steps < 0:
        raise ValueError(
            f"Expected a positive num_steps argument. Instead got {num_steps}")

    if num_steps:
        # All requests must run until num_engine_steps. This implies
        # that their output lengths must be equal to num_engine_steps.
        return [context.batch_size] * num_steps

    per_step = context.complete_num_requests_per_step
    assert per_step and per_step > 0, \
        ("Expected a positive complete_num_requests_per_step argument. "
         f"Instead got {per_step}")

    # We start dropping after the first decode step.
    step_requests = [
        context.batch_size,  # prefill
        context.batch_size,  # decode
    ]

    # Every further step runs per_step fewer requests than the one before,
    # for as long as at least one request is still running.
    step_requests.extend(range(context.batch_size - per_step, 0, -per_step))

    if step_requests[-1] != 1:
        # have 1 request running at the last step. This is often
        # useful
        step_requests.append(1)

    return step_requests
163
164
_LINE_WIDTH = 80


def _check_engine_limits(llm: Any, batch_size: int, prompt_len: int,
                         max_output_len: int) -> None:
    """Exit the process if the workload does not fit the engine's limits."""
    scheduler_config = llm.llm_engine.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max-num-batched-tokens")
        sys.exit(-1)
    if batch_size > max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)
    print("llm.llm_engine.model_config.max_model_len: ", max_model_len)
    if prompt_len + max_output_len > max_model_len:
        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
              f"than the model's max_model_len ({max_model_len}), please "
              f"choose a smaller prompt_len or max_output_len, or increase "
              f"--max-model-len")
        sys.exit(-1)


def _print_table_banner(title: str, prompt_len: int, batch_size: int) -> None:
    """Print the boxed heading that precedes a results table."""
    print("=" * _LINE_WIDTH)
    print(f"= {title} "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * _LINE_WIDTH)
    print()


def _export_csv(csv_output: str, prefill_results: Any,
                decode_results_list: List[Any]) -> None:
    """Write the prefill and first-decode-step tables as csv files."""
    csv_filename_base = csv_output[:-4] \
        if csv_output.endswith('.csv') else csv_output
    prefill_results.export_model_stats_table_csv(
        csv_filename_base + "_prefill_model_table.csv")
    prefill_results.export_summary_stats_table_csv(
        csv_filename_base + "_prefill_summary_table.csv")

    # Only the first decode step is exported to csv.
    if decode_results_list:
        first_decode = decode_results_list[0]
        first_decode.export_model_stats_table_csv(
            csv_filename_base + "_decode_model_table.csv")
        first_decode.export_summary_stats_table_csv(
            csv_filename_base + "_decode_summary_table.csv")


def _export_json(json_output: str, context: ProfileContext,
                 prefill_results: Any,
                 decode_results_list: List[Any]) -> None:
    """Write the run context and the stats of every profiled step as json."""
    cuda_devices = [
        torch.cuda.get_device_properties(dev_idx)
        for dev_idx in range(torch.cuda.device_count())
    ]

    json_dict = {
        "context": {
            "python_version": f"{sys.version}",
            "torch_version": f"{torch.__version__}",
            "torch_cuda_version": f"{torch.version.cuda}",
            "cuda_devices": f"{cuda_devices}",
            **asdict(context)
        },
        "prefill": prefill_results.convert_stats_to_dict(),
    }

    for idx, decode_results in enumerate(decode_results_list, start=1):
        json_dict[f"decode_{idx}"] = decode_results.convert_stats_to_dict()

    # Add .json to json_output filename if it doesn't exist already.
    json_output_file = json_output if json_output.endswith(
        '.json') else json_output + '.json'
    with open(json_output_file, "w", encoding="utf-8") as f:
        json.dump(json_dict, f, indent=2)


def _export_chrome_traces(folder: str, prefill_prof: Any,
                          decode_profs: List[Any]) -> None:
    """Save one chrome trace per profiled engine step inside ``folder``."""
    os.makedirs(folder, exist_ok=True)
    prefill_prof.profiler.export_chrome_trace(
        os.path.join(folder, "prefill.json"))
    for idx, decode_prof in enumerate(decode_profs, start=1):
        decode_prof.profiler.export_chrome_trace(
            os.path.join(folder, f"decode_{idx}.json"))
    print("Traces saved as prefill.json and decode_1.json, etc."
          f" in folder {folder}")


def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    """Profile the prefill step and the following decode steps of a batch.

    Builds the engine from ``context.engine_args``, submits
    ``context.batch_size`` requests of ``context.prompt_len`` random prompt
    tokens, performs a short warm up, then profiles each engine step and
    prints the prefill and first-decode tables.

    Args:
        context: Profiling settings.
        csv_output: Root filename for the csv tables, or None to skip them.
        json_output: Filename for the json export, or None to skip it.

    Note:
        Calls ``sys.exit(-1)`` when the requested batch does not fit the
        engine's limits.
    """
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f" {key} = {value}")

    requests_per_step: List[int] = determine_requests_per_step(context)

    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
        context.batch_size, requests_per_step)

    num_steps_to_profile: int = len(requests_per_step)
    max_output_len: int = max(ol_nr)
    assert max_output_len >= 1

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        # max_tokens is set on a per-request basis.
        max_tokens=None,
        ignore_eos=True)

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len

    _check_engine_limits(llm, batch_size, prompt_len, max_output_len)

    def add_requests():
        # One output length per request, in the order given by ol_nr.
        output_len_generator = (output_len
                                for output_len, num_reqs in ol_nr.items()
                                for _ in range(num_reqs))
        vocab_size = llm.llm_engine.model_config.get_vocab_size()

        for i in range(batch_size):
            # NOTE(review): the same SamplingParams object is mutated and
            # re-submitted for every request; this presumably relies on the
            # engine copying the params in add_request -- confirm.
            sampling_params.max_tokens = next(output_len_generator)
            assert isinstance(sampling_params.max_tokens, int)

            prompt_token_ids = torch.randint(
                vocab_size, size=(prompt_len, )).tolist()

            llm.llm_engine.add_request(
                request_id=f"seq{i}",
                prompt={'prompt_token_ids': prompt_token_ids},
                params=sampling_params)

    def abort_requests():
        for i in range(batch_size):
            llm.llm_engine.abort_request(f"seq{i}")

    # Warm up run
    print("Warm up run ...")
    add_requests()
    llm.llm_engine.step()  # Prefill
    llm.llm_engine.step()  # Decode
    abort_requests()

    print("Profile run ...")
    add_requests()

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    decode_profs = []
    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
        num_running_seqs = llm.llm_engine.scheduler[
            0].get_num_unfinished_seq_groups()
        with layerwise_profile(
                num_running_seqs=num_running_seqs) as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)

    decode_results_list = [prof.results for prof in decode_profs]
    prefill_results = prefill_prof.results
    has_decode = len(decode_results_list) > 0

    _print_table_banner("Prefill Model Table", prompt_len, batch_size)
    prefill_results.print_model_table()

    if has_decode:
        print()
        _print_table_banner("First Decode Step Model Table", prompt_len,
                            batch_size)
        decode_results_list[0].print_model_table()

    print()
    _print_table_banner("Prefill Summary Table", prompt_len, batch_size)
    prefill_results.print_summary_table()

    if has_decode:
        print()
        _print_table_banner("First Decode Step Summary Table", prompt_len,
                            batch_size)
        decode_results_list[0].print_summary_table()

    if csv_output:
        _export_csv(csv_output, prefill_results, decode_results_list)

    if json_output:
        _export_json(json_output, context, prefill_results,
                     decode_results_list)

    if context.save_chrome_traces_folder is not None:
        _export_chrome_traces(context.save_chrome_traces_folder, prefill_prof,
                              decode_profs)
358
359
if __name__ == "__main__":
    # Command line entry point: parse the profiling options, one of the two
    # sub-commands and the engine arguments, then run the profile.
    parser = FlexibleArgumentParser(description="""
Profile a model

    example:
    ```
    python examples/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager run_num_steps -n 2
    ```

    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
                                    formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv file. This should be the root "
        "filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument("--save-chrome-traces-folder",
                        type=str,
                        help="Save chrome traces for the prefill and decode "
                        "will save traces as prefill.json and decode_1.json, "
                        "etc. inside this folder")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")

    # A sub-command is mandatory: without one neither profiling mode is
    # configured and the run cannot determine how many steps to profile.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    run_num_steps_parser = subparsers.add_parser(
        "run_num_steps",
        help="This variation profiles n engine.step() invocations.")
    run_num_steps_parser.add_argument(
        '-n',
        '--num-steps',
        type=int,
        required=True,
        help="Number of engine steps to profile.\n"
        "Setting it to 1, profiles only the prefill step.\n"
        "Setting it to 2, profiles the prefill and first decode step\n"
        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
        "and so on ...")

    run_to_completion_parser = subparsers.add_parser(
        "run_to_completion",
        help="This variation profiles all the engine.step() invocations "
        "until the engine exhausts all submitted requests.")
    run_to_completion_parser.add_argument(
        '-n',
        '--complete-num-requests-per-step',
        type=int,
        required=True,
        help=
        "Complete complete_num_requests_per_step requests every decode step. "
        "For e.g., with batch_size 128 and complete_num_requests_per_step 32, "
        "the profiler is run for 6 engine steps, with the steps processing, "
        "128, 128, 96, 64, 32, 1 requests respectively.\n"
        "Note that we tack-on a one-request step at the end as it is often "
        "useful.")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()

    # Forward only the parsed options that ProfileContext declares as fields;
    # everything else on the namespace belongs to the engine or the exports.
    context_fields = inspect.signature(ProfileContext).parameters
    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{
            k: v
            for k, v in vars(args).items() if k in context_fields
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)