Offline Inference Audio Language

Source: vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference 
  3with the correct prompt format on audio language models.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8from transformers import AutoTokenizer
  9
 10from vllm import LLM, SamplingParams
 11from vllm.assets.audio import AudioAsset
 12from vllm.utils import FlexibleArgumentParser
 13
# Two sample audio clips bundled with vLLM; main() attaches the first
# `audio_count` of them to each prompt.
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
# Ask a question that matches the number of audio clips in the prompt.
question_per_audio_count = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?",
}


# Ultravox 0.3
def run_ultravox(question: str, audio_count: int):
    model_name = "fixie-ai/ultravox-v0_3"

    # Ultravox marks where audio features go with <|reserved_special_token_0|>,
    # one occurrence per audio clip, placed before the text question.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{
        'role': 'user',
        'content': "<|reserved_special_token_0|>\n" * audio_count + question,
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # limit_mm_per_prompt caps how many audio items one prompt may carry.
    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids


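# For reference, Ultravox v0.3 is built on a Llama-3-style chat template
# (an assumption here, not something this script pins down), so the rendered
# prompt for one clip looks roughly like:
#
#   <|start_header_id|>user<|end_header_id|>
#
#   <|reserved_special_token_0|>
#   What is recited in the audio?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

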
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    # A small context window and batch size keep the example's memory
    # footprint modest.
    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=5,
              limit_mm_per_prompt={"audio": audio_count})

    # Qwen2-Audio announces each clip as "Audio <n>:" followed by the
    # <|audio_bos|><|AUDIO|><|audio_eos|> placeholder sequence.
    audio_in_prompt = "".join([
        f"Audio {idx + 1}: "
        "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
    ])

    # ChatML-style prompt, written out manually instead of using a tokenizer
    # chat template.
    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Map the --model-type CLI choice to the runner that builds the engine and prompt.
model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio}


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    audio_count = args.num_audios
    llm, prompt, stop_token_ids = model_example_map[model](
        question_per_audio_count[audio_count], audio_count)

    # Temperature is set to 0.2 so that outputs can differ across prompts
    # even when all prompts are identical in batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    # Each audio item is an (ndarray, sample_rate) tuple.
    mm_data = {}
    if audio_count > 0:
        mm_data = {
            "audio": [
                asset.audio_and_sample_rate
                for asset in audio_assets[:audio_count]
            ]
        }

    assert args.num_prompts > 0
    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
    if args.num_prompts > 1:
        # Batch inference: replicate the same prompt.
        inputs = [inputs] * args.num_prompts

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


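# To feed your own audio instead of the bundled assets, pass the same kind of
# (ndarray, sample_rate) tuple. A sketch, assuming librosa is installed;
# "my_clip.wav" is a hypothetical file:
#
#     import librosa
#     audio, sample_rate = librosa.load("my_clip.wav", sr=None)
#     mm_data = {"audio": [(audio, sample_rate)]}

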
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument("--num-audios",
                        type=int,
                        default=1,
                        choices=[0, 1, 2],
                        help="Number of audio items per prompt.")

    args = parser.parse_args()
    main(args)
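
To run the example (assuming the script is saved as offline_inference_audio_language.py, its name in the vLLM examples directory):

    # Default: Ultravox with one audio clip per prompt.
    python offline_inference_audio_language.py

    # Qwen2-Audio with two clips per prompt, batched over four identical prompts.
    python offline_inference_audio_language.py -m qwen2_audio --num-audios 2 --num-prompts 4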