AudioX offline inference
Source https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/audiox.
Generate audio with the AudioX MMDiT diffusion pipeline (AudioXPipeline). Six tasks: t2a, t2m, v2a, v2m, tv2a, tv2m.
Prerequisites
Download a vLLM-Omni weight bundle (component-sharded safetensors):
huggingface-cli download zhangj1an/AudioX --local-dir ./audiox_weights
The Hugging Face id zhangj1an/AudioX also works directly without prefetching.
Usage
# Text-to-audio only (default uses zhangj1an/AudioX from the Hub):
python end2end.py --tasks t2a
# All six tasks against a local bundle and a sample video for v2*/tv2*:
python end2end.py \
--model ./audiox_weights \
--video https://zeyuet.github.io/AudioX/static/samples/V2M/1XeBotOFqHA.mp4
# Subset of tasks, custom seed and steps:
python end2end.py --tasks t2a tv2a --num-inference-steps 100 --seed 0
Arguments
- --model: HF id or local bundle path (default: zhangj1an/AudioX).
- --tasks: any subset of t2a t2m v2a v2m tv2a tv2m (default: all).
- --video: video file/URL — required for v2* and tv2*.
- --reference-audio: optional audio prompt (audio-conditioned generation).
- --num-inference-steps, --guidance-scale, --seed, --seconds-total, --sample-rate, --output-dir: generation knobs.
Outputs land in <output-dir>/<task>.wav as 16-bit stereo WAV.
Example materials
end2end.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""End-to-end AudioX offline example covering the 6 t2*/v2*/tv2* tasks.
Provide a directory with the **vLLM-Omni AudioX safetensors bundle** (e.g. from
``zhangj1an/AudioX`` on Hugging Face)::
huggingface-cli download zhangj1an/AudioX --local-dir ./audiox_weights
python end2end.py --model ./audiox_weights
python end2end.py --model ./audiox_weights --tasks t2a tv2a
"""
from __future__ import annotations
import argparse
import time
from pathlib import Path
import soundfile
import torch
import torchaudio.functional as TF
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
# Directory containing this script; used as the base for the default output dir.
ROOT = Path(__file__).resolve().parent

# Every task id accepted by --tasks, in the order they run by default.
ALL_TASKS = ("t2a", "t2m", "v2a", "v2m", "tv2a", "tv2m")

# Tasks that take a video input (--video must be supplied for these).
VIDEO_TASKS = frozenset(("v2a", "v2m", "tv2a", "tv2m"))

# Tasks that take a text prompt; all other tasks are sent an empty prompt.
TEXT_TASKS = frozenset(("t2a", "t2m", "tv2a", "tv2m"))

# Built-in demo prompt per task. Video-only tasks carry an empty string.
SAMPLE_PROMPTS: dict[str, str] = {
    "t2a": "Fireworks burst twice, followed by a period of silence before a clock begins ticking.",
    "t2m": "Uplifting ukulele tune for a travel vlog",
    "v2a": "",
    "v2m": "",
    "tv2a": "drum beating sound and human talking",
    "tv2m": "uplifting music matching the scene",
}
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the AudioX example.

    Returns:
        The populated ``argparse.Namespace``. ``tasks`` defaults to every
        entry of ``ALL_TASKS``; ``video`` and ``reference_audio`` default to
        empty strings, meaning "not provided".
    """
    parser = argparse.ArgumentParser(
        description="AudioX offline end-to-end (6 t2*/v2*/tv2* tasks).",
    )

    # Model and task selection.
    parser.add_argument(
        "--model",
        default="zhangj1an/AudioX",
        help="HF id or local AudioX bundle path.",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=list(ALL_TASKS),
        choices=ALL_TASKS,
    )

    # Optional conditioning inputs.
    parser.add_argument(
        "--video",
        default="",
        help="Video path / URL (required for v2*/tv2*).",
    )
    parser.add_argument(
        "--reference-audio",
        default="",
        help="Optional audio prompt for audio-conditioned generation.",
    )

    # Output location and generation knobs.
    parser.add_argument("--output-dir", default=str(ROOT / "audiox_task_outputs"))
    parser.add_argument("--num-inference-steps", type=int, default=250)
    parser.add_argument("--seconds-total", type=float, default=10.0)
    parser.add_argument("--guidance-scale", type=float, default=6.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=48000,
        help="Output WAV rate (resampled if != model rate).",
    )

    return parser.parse_args()
def save_wav(audio: torch.Tensor, path: Path, sample_rate: int) -> None:
    """Save ``audio`` to ``path`` as a 16-bit PCM WAV file.

    Args:
        audio: Float tensor laid out as ``[channels, samples]``. Values are
            clamped to [-1, 1] before writing.
        path: Destination file; missing parent directories are created.
        sample_rate: Sample rate recorded in the WAV header.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    clipped = audio.clamp(-1.0, 1.0).cpu()
    # Transpose to [samples, channels] before handing the array to soundfile.
    frames = clipped.T.numpy()
    soundfile.write(str(path), frames, sample_rate, subtype="PCM_16")
def main() -> None:
    """Run every requested AudioX task and write one WAV file per task.

    For each task this builds the per-request ``extra_args`` (task name, time
    window, and the optional video / reference-audio paths), calls
    ``Omni.generate``, resamples the result to ``--sample-rate`` when the
    reported model rate differs, and saves ``<output-dir>/<task>.wav``.

    Raises:
        SystemExit: If a video-conditioned task is requested without
            ``--video``. Raised before the model is loaded.
        RuntimeError: If a request returns no ``"audio"`` entry.
    """
    args = parse_args()

    # Fail fast: check the --video requirement for *all* requested tasks
    # before constructing the pipeline, so a missing flag does not cost a
    # model load (and several completed generations) before being reported.
    if not args.video:
        needs_video = [task for task in args.tasks if task in VIDEO_TASKS]
        if needs_video:
            raise SystemExit(f"task={needs_video[0]!r} requires --video")

    omni = Omni(model=args.model, model_class_name="AudioXPipeline")
    try:
        for task in args.tasks:
            # Video-only tasks (v2a / v2m) are sent an empty text prompt.
            prompt = SAMPLE_PROMPTS[task] if task in TEXT_TASKS else ""

            extra: dict = {
                "audiox_task": task,
                "seconds_start": 0.0,
                "seconds_total": float(args.seconds_total),
            }
            if task in VIDEO_TASKS:
                extra["video_path"] = args.video
            if args.reference_audio:
                extra["audio_path"] = args.reference_audio

            # A fresh generator per task, so every task starts from the same seed.
            generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)

            t0 = time.perf_counter()
            outputs = omni.generate(
                prompt,
                OmniDiffusionSamplingParams(
                    generator=generator,
                    guidance_scale=args.guidance_scale,
                    num_inference_steps=args.num_inference_steps,
                    seed=args.seed,
                    extra_args=extra,
                ),
            )

            multimodal_output = outputs[0].request_output.multimodal_output
            audio = multimodal_output.get("audio")
            if audio is None:
                raise RuntimeError(f"No audio produced for task {task!r}")

            audio = torch.as_tensor(audio).detach().cpu().float()
            if audio.ndim == 3:
                # Drop the leading dimension (presumably batch), keeping the first item.
                audio = audio[0]

            # Fall back to 44100 Hz when the output carries no sample rate.
            model_sr = int(multimodal_output.get("audio_sample_rate") or 44100)
            if model_sr != args.sample_rate:
                audio = TF.resample(audio, model_sr, args.sample_rate)

            out_path = Path(args.output_dir) / f"{task}.wav"
            save_wav(audio, out_path, args.sample_rate)
            print(f"[{task}] saved {out_path} ({time.perf_counter() - t0:.2f}s)")
    finally:
        # Always shut the engine down, including when a task raises.
        omni.close()


if __name__ == "__main__":
    main()