Skip to content

vllm.model_executor.models.musicflamingo

Classes:

MusicFlamingoForConditionalGeneration

Bases: AudioFlamingo3ForConditionalGeneration

vLLM implementation of the MusicFlamingo model, aligned with the Hugging Face Transformers `modular_musicflamingo` implementation.

Source code in vllm/model_executor/models/musicflamingo.py
@MULTIMODAL_REGISTRY.register_processor(
    MusicFlamingoMultiModalProcessor,
    info=MusicFlamingoProcessingInfo,
    dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
    """vLLM MusicFlamingo model aligned with HF modular_musicflamingo.

    Builds on the AudioFlamingo3 pipeline, installing the MusicFlamingo audio
    encoder and projector and applying a rotary time embedding to the encoder
    output before it is projected into the language model's embedding space.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        cfg = self.config
        # Assigned after the parent constructor runs, so these take precedence
        # over any same-named modules the parent may have created.
        self.audio_tower = MusicFlamingoEncoder(cfg.audio_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(cfg)
        self.pos_emb = MusicFlamingoRotaryEmbedding(cfg)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MusicFlamingoInputs | None:
        """Parse audio keyword arguments, carrying optional ``rote_timestamps``.

        ``rote_timestamps`` is removed from ``kwargs`` before the remaining
        keyword arguments are passed to the parent parser. A ``None`` result or
        an ``"audio_embeds"`` input is returned unchanged. Feature inputs are
        re-wrapped as ``MusicFlamingoFeatureInputs`` that also hold the
        timestamps, which are ``None`` when the caller did not supply any.
        """
        timestamps = kwargs.pop("rote_timestamps", None)
        parsed = super()._parse_and_validate_audio_input(**kwargs)

        if parsed is not None and parsed["type"] != "audio_embeds":
            features = parsed["input_features"]
            mask = parsed["feature_attention_mask"]
            counts = parsed["chunk_counts"]
            return MusicFlamingoFeatureInputs(
                type="audio_features",
                input_features=features,
                feature_attention_mask=mask,
                chunk_counts=counts,
                rote_timestamps=timestamps,
            )
        return parsed

    def _build_audio_timestamps(
        self,
        chunk_counts: list[int],
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Synthesize per-frame timestamps for every audio window.

        Returns a float32 tensor on ``device`` with shape
        ``(sum(chunk_counts), seq_len)``. Window indices restart at 0 for each
        entry of ``chunk_counts``. Row ``w`` of a given audio therefore holds
        ``w * seq_len * step + j * step`` for ``j`` in ``range(seq_len)``,
        where ``step = config.audio_frame_step * 4``.
        """
        # NOTE(review): the factor 4 presumably maps the configured frame step
        # onto the spacing of encoder output frames — confirm against the
        # encoder's downsampling.
        step = self.config.audio_frame_step * 4
        frame_times = (
            torch.arange(seq_len, device=device, dtype=torch.float32) * step
        )

        if not chunk_counts:
            # No windows: return an empty batch that keeps the frame dimension.
            return frame_times.new_empty((0, seq_len))

        per_audio = [
            torch.arange(n, device=device, dtype=torch.float32)
            for n in chunk_counts
        ]
        window_ids = torch.cat(per_audio)
        # Offset of each window's first frame, broadcast across all frames.
        window_starts = window_ids.unsqueeze(1) * seq_len * step
        return window_starts + frame_times

    def _process_audio_input(
        self, audio_input: MusicFlamingoInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed audio inputs into embeddings for the language model.

        ``"audio_embeds"`` inputs are delegated to the parent implementation.
        Feature inputs go through these steps:

        - The features are normalized and encoded.
        - A rotary time embedding is applied. Its timestamps are the supplied
          ``rote_timestamps`` (a list of tensors is concatenated along dim 0)
          or, when none were supplied, the ones produced by
          ``_build_audio_timestamps``.
        - The result is projected.
        - The projection is passed to ``_group_audio_embeddings`` together
          with the attention mask and ``chunk_counts``.
        """
        if audio_input["type"] == "audio_embeds":
            return super()._process_audio_input(audio_input)

        timestamps = audio_input["rote_timestamps"]
        features, mask, counts = self._normalize_audio_feature_inputs(audio_input)
        encoded = self._encode_audio_features(features, mask)
        num_frames = encoded.shape[-2]
        target_device = encoded.device

        if timestamps is None:
            timestamps = self._build_audio_timestamps(
                counts, seq_len=num_frames, device=target_device
            )
        elif isinstance(timestamps, list):
            timestamps = torch.cat(timestamps, dim=0)

        cos, sin = self.pos_emb(timestamps.to(target_device), seq_len=num_frames)
        rotated = apply_rotary_time_emb(encoded, cos, sin)
        projected = self.multi_modal_projector(rotated)
        return self._group_audio_embeddings(projected, mask, counts)