@MULTIMODAL_REGISTRY.register_processor(
    MusicFlamingoMultiModalProcessor,
    info=MusicFlamingoProcessingInfo,
    dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
    """vLLM MusicFlamingo model, mirroring HF ``modular_musicflamingo``.

    Builds on the AudioFlamingo3 implementation, but installs the
    MusicFlamingo audio encoder and projector. It also applies a rotary time
    embedding (``apply_rotary_time_emb``) to the encoder outputs before they
    are projected.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        cfg = self.config
        # These attributes are (re)assigned after the parent constructor runs,
        # so the MusicFlamingo versions are the ones the model ends up using.
        self.audio_tower = MusicFlamingoEncoder(cfg.audio_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(cfg)
        # Produces the cos/sin tables consumed by apply_rotary_time_emb.
        self.pos_emb = MusicFlamingoRotaryEmbedding(cfg)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MusicFlamingoInputs | None:
        """Parse audio kwargs, carrying the optional ``rote_timestamps`` along.

        The ``rote_timestamps`` key is removed before delegating to the parent
        parser. Feature inputs are re-wrapped as ``MusicFlamingoFeatureInputs``
        with the timestamps attached. ``None`` and ``"audio_embeds"`` inputs
        are returned untouched; their timestamps are dropped.
        """
        # Pop first so the parent parser never receives the extra key.
        timestamps = kwargs.pop("rote_timestamps", None)
        parsed = super()._parse_and_validate_audio_input(**kwargs)

        if parsed is not None and parsed["type"] != "audio_embeds":
            return MusicFlamingoFeatureInputs(
                type="audio_features",
                input_features=parsed["input_features"],
                feature_attention_mask=parsed["feature_attention_mask"],
                chunk_counts=parsed["chunk_counts"],
                rote_timestamps=timestamps,
            )
        return parsed

    def _build_audio_timestamps(
        self,
        chunk_counts: list[int],
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build default per-frame timestamps for every audio window.

        Let ``step = config.audio_frame_step * 4``. The result is a float32
        tensor of shape ``(sum(chunk_counts), seq_len)``. For the ``k``-th
        window of an audio item (``k`` restarts at 0 for each item), the value
        at frame ``j`` is ``k * seq_len * step + j * step``.
        """
        # NOTE(review): the factor 4 presumably matches the encoder's temporal
        # downsampling relative to audio_frame_step; confirm against HF.
        step = self.config.audio_frame_step * 4
        per_frame = (
            torch.arange(seq_len, device=device, dtype=torch.float32) * step
        )

        if not chunk_counts:
            return per_frame.new_empty((0, seq_len))

        # Window index within each audio item, flattened over all items.
        window_idx = torch.cat(
            [
                torch.arange(n, device=device, dtype=torch.float32)
                for n in chunk_counts
            ]
        )
        # Same operation order as (idx * seq_len) * step, so float results
        # match exactly; (N, 1) + (seq_len,) broadcasts to (N, seq_len).
        window_start = window_idx.unsqueeze(1) * seq_len * step
        return window_start + per_frame

    def _process_audio_input(
        self, audio_input: MusicFlamingoInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed audio inputs into embeddings for the language model.

        ``"audio_embeds"`` inputs are handled entirely by the parent class.
        Feature inputs go through these steps:

        1. Encode the features.
        2. Apply the rotary time embedding, using the caller's
           ``rote_timestamps`` (a tensor, or a list concatenated along dim 0)
           or, when absent, the defaults from ``_build_audio_timestamps``.
        3. Project the result.
        4. Regroup it with ``_group_audio_embeddings``.
        """
        if audio_input["type"] == "audio_embeds":
            return super()._process_audio_input(audio_input)

        timestamps = audio_input["rote_timestamps"]
        features, attn_mask, counts = self._normalize_audio_feature_inputs(
            audio_input
        )
        encoded = self._encode_audio_features(features, attn_mask)
        num_frames = encoded.shape[-2]

        if timestamps is None:
            timestamps = self._build_audio_timestamps(
                counts,
                seq_len=num_frames,
                device=encoded.device,
            )
        elif isinstance(timestamps, list):
            timestamps = torch.cat(timestamps, dim=0)

        cos, sin = self.pos_emb(
            timestamps.to(encoded.device),
            seq_len=num_frames,
        )
        encoded = apply_rotary_time_emb(encoded, cos, sin)
        projected = self.multi_modal_projector(encoded)
        # NOTE(review): presumably splits/merges window embeddings back into
        # per-audio-item outputs using the mask and chunk counts.
        return self._group_audio_embeddings(projected, attn_mask, counts)