Offline Inference Audio Language
Source: vllm-project/vllm.
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser

# Input audio and question
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
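# The asset resolves to an (audio array, sampling rate) tuple, which is the
# format passed under the "audio" key of multi_modal_data below. A local
# recording loaded the same way (e.g. with librosa) could be substituted here.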
question = "What is recited in the audio?"


# Ultravox 0.3
def run_ultravox(question):
    model_name = "fixie-ai/ultravox-v0_3"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # "<|reserved_special_token_0|>" is the placeholder that marks where the
    # audio clip is inserted in the Ultravox prompt.
    messages = [{
        'role': 'user',
        'content': f"<|reserved_special_token_0|>\n{question}"
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    llm = LLM(model=model_name)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "ultravox": run_ultravox,
}
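# Each entry in model_example_map returns an (llm, prompt, stop_token_ids)
# triple for its model type; main() below consumes that triple to build the
# generation request.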


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that the outputs can differ even when all
    # prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                "audio": audio_and_sample_rate
            },
        }

    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                "audio": audio_and_sample_rate
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')

    args = parser.parse_args()
    main(args)
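Assuming the script is saved as offline_inference_audio_language.py, it can be
run from the command line, for example:

    python offline_inference_audio_language.py --model-type ultravox --num-prompts 4

This sends four copies of the same audio prompt as a batch; with the default
--num-prompts 1 a single request is generated instead.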