Offline Inference Audio Language

Offline Inference Audio Language

Source: examples/offline_inference_audio_language.py.

  1"""
  2This example shows how to use vLLM for running offline inference 
  3with the correct prompt format on audio language models.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8from transformers import AutoTokenizer
  9
 10from vllm import LLM, SamplingParams
 11from vllm.assets.audio import AudioAsset
 12from vllm.utils import FlexibleArgumentParser
 13
 14audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
 15question_per_audio_count = {
 16    0: "What is 1+1?",
 17    1: "What is recited in the audio?",
 18    2: "What sport and what nursery rhyme are referenced?"
 19}
 20
 21# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
 22# lower-end GPUs.
 23# Unless specified, these settings have been tested to work on a single L4.
 24
 25
 26# Ultravox 0.3
 27def run_ultravox(question: str, audio_count: int):
 28    model_name = "fixie-ai/ultravox-v0_3"
 29
 30    tokenizer = AutoTokenizer.from_pretrained(model_name)
 31    messages = [{
 32        'role': 'user',
 33        'content': "<|audio|>\n" * audio_count + question
 34    }]
 35    prompt = tokenizer.apply_chat_template(messages,
 36                                           tokenize=False,
 37                                           add_generation_prompt=True)
 38
 39    llm = LLM(model=model_name,
 40              max_model_len=4096,
 41              max_num_seqs=5,
 42              trust_remote_code=True,
 43              limit_mm_per_prompt={"audio": audio_count})
 44    stop_token_ids = None
 45    return llm, prompt, stop_token_ids
 46
 47
 48# Qwen2-Audio
 49def run_qwen2_audio(question: str, audio_count: int):
 50    model_name = "Qwen/Qwen2-Audio-7B-Instruct"
 51
 52    llm = LLM(model=model_name,
 53              max_model_len=4096,
 54              max_num_seqs=5,
 55              limit_mm_per_prompt={"audio": audio_count})
 56
 57    audio_in_prompt = "".join([
 58        f"Audio {idx+1}: "
 59        f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
 60    ])
 61
 62    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
 63              "<|im_start|>user\n"
 64              f"{audio_in_prompt}{question}<|im_end|>\n"
 65              "<|im_start|>assistant\n")
 66    stop_token_ids = None
 67    return llm, prompt, stop_token_ids
 68
 69
 70model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio}
 71
 72
 73def main(args):
 74    model = args.model_type
 75    if model not in model_example_map:
 76        raise ValueError(f"Model type {model} is not supported.")
 77
 78    audio_count = args.num_audios
 79    llm, prompt, stop_token_ids = model_example_map[model](
 80        question_per_audio_count[audio_count], audio_count)
 81
 82    # We set temperature to 0.2 so that outputs can be different
 83    # even when all prompts are identical when running batch inference.
 84    sampling_params = SamplingParams(temperature=0.2,
 85                                     max_tokens=64,
 86                                     stop_token_ids=stop_token_ids)
 87
 88    mm_data = {}
 89    if audio_count > 0:
 90        mm_data = {
 91            "audio": [
 92                asset.audio_and_sample_rate
 93                for asset in audio_assets[:audio_count]
 94            ]
 95        }
 96
 97    assert args.num_prompts > 0
 98    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
 99    if args.num_prompts > 1:
100        # Batch inference
101        inputs = [inputs] * args.num_prompts
102
103    outputs = llm.generate(inputs, sampling_params=sampling_params)
104
105    for o in outputs:
106        generated_text = o.outputs[0].text
107        print(generated_text)
108
109
if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the demo.
    parser = FlexibleArgumentParser(
        description=(
            "Demo on using vLLM for offline inference with audio language models"
        ),
    )

    # Which model-specific setup function to use.
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    # How many identical prompts to submit in one batch.
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1,
        help="Number of prompts to run.",
    )
    # How many audio clips to attach to each prompt.
    parser.add_argument(
        "--num-audios",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Number of audio items per prompt.",
    )

    args = parser.parse_args()
    main(args)