Offline Inference Audio Language

Source: vllm-project/vllm.

"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow the corresponding
examples on the HuggingFace model repository.
"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser

audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = [
    "What is recited in the audio?",
    "What sport and what nursery rhyme are referenced?"
]


# Ultravox 0.3
def run_ultravox(question, audio_count):
    model_name = "fixie-ai/ultravox-v0_3"

    # The Ultravox chat template expects one audio placeholder token per
    # audio clip, placed before the text question.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{
        'role': 'user',
        'content': "<|reserved_special_token_0|>\n" * audio_count + question
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # limit_mm_per_prompt must allow as many audio items as the prompt uses.
    llm = LLM(model=model_name,
              enforce_eager=True,
              enable_chunked_prefill=False,
              max_model_len=8192,
              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "ultravox": run_ultravox,
}


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    audio_count = args.num_audios
    llm, prompt, stop_token_ids = model_example_map[model](
        question_per_audio_count[audio_count - 1], audio_count)

    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    inputs = {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": [
                asset.audio_and_sample_rate
                for asset in audio_assets[:audio_count]
            ]
        },
    }
    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument("--num-audios",
                        type=int,
                        default=1,
                        choices=[1, 2],
                        help="Number of audio items per prompt.")

    args = parser.parse_args()
    main(args)
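
To run the example on your own recording instead of the bundled AudioAsset clips, the only change is the audio item passed in multi_modal_data: vLLM accepts an (audio_array, sample_rate) tuple per clip, which is exactly what asset.audio_and_sample_rate returns above. The following is a minimal sketch, not part of the original example; it assumes the definitions from the script above (run_ultravox, SamplingParams) are in scope, that librosa is installed, and that "my_audio.wav" is a hypothetical path.

import librosa

# Hypothetical local file; librosa.load returns (waveform, sample_rate),
# the same (array, rate) pair that AudioAsset.audio_and_sample_rate yields.
audio, sample_rate = librosa.load("my_audio.wav", sr=None)

llm, prompt, stop_token_ids = run_ultravox("What is happening in the audio?",
                                           audio_count=1)
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"audio": [(audio, sample_rate)]},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)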