Skip to content

vllm.model_executor.models.moss_audio

Inference-only MOSS-Audio model compatible with HuggingFace weights.

Classes:

MossAudioAudioInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • nmb: Number of mel bins
  • t: Time frames
Source code in vllm/model_executor/models/moss_audio.py
class MossAudioAudioInputs(TensorSchema):
    """
    Validated, batched audio inputs consumed by the MOSS-Audio encoder.

    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames

    Fields:
        audio_data: mel features shaped ``(b, nmb, t)``. When the model
            receives a list of per-item ``[nmb, time]`` tensors, it zero-pads
            them along ``t`` before building this schema.
        audio_data_seqlens: one sequence length per batch item; the model's
            parser supplies it as a flat ``torch.long`` tensor. Presumably
            the number of valid (unpadded) mel frames per item -- confirm
            against the processor.
    """

    # Field declarations double as shape contracts checked by TensorSchema.
    audio_data: Annotated[torch.Tensor, TensorShape("b", "nmb", "t")]
    audio_data_seqlens: Annotated[torch.Tensor, TensorShape("b")]

MossAudioModel

Bases: Module, SupportsMultiModal, SupportsPP, SupportsLoRA

Source code in vllm/model_executor/models/moss_audio.py
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
@MULTIMODAL_REGISTRY.register_processor(
    MossAudioMultiModalProcessor,
    info=MossAudioProcessingInfo,
    dummy_inputs=MossAudioDummyInputsBuilder,
)
class MossAudioModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """MOSS-Audio: audio encoder plus gated-MLP adapter feeding a Qwen3 LM.

    Mel features go through ``MossAudioEncoder`` and ``audio_adapter`` to
    produce embeddings of the language model's hidden size. These replace the
    audio placeholder tokens in the prompt.

    When DeepStack is configured, one extra ``GatedMLP`` per injected layer
    projects an intermediate encoder hidden state. Those projections are
    scattered onto the audio token positions and passed to the language model
    as ``deepstack_input_embeds``.
    """

    # Fused language-model projections and the checkpoint weights that are
    # packed into each of them.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Checkpoint keys store the decoder directly under ``language_model.``
    # (and ``lm_head.`` at the top level). Remap them onto the nested
    # ``language_model.model.*`` / ``language_model.lm_head`` module tree.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "language_model.embed_tokens.": "language_model.model.embed_tokens.",
            "language_model.layers.": "language_model.model.layers.",
            "language_model.norm.": "language_model.model.norm.",
        }
    )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module-prefix mapping naming the language-model submodule."""
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for audio modalities.

        Raises:
            ValueError: if ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return MOSS_AUDIO_PLACEHOLDER
        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the audio tower, adapters, and the Qwen3 language model.

        Raises:
            ValueError: if ``adapter_hidden_size`` or the encoder's attention
                head count is not divisible by the tensor-parallel size.
        """
        super().__init__()
        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        # Accept a generic HF config by re-wrapping its known fields into a
        # MossAudioConfig, falling back to defaults for missing attributes.
        if not isinstance(config, MossAudioConfig):
            config = MossAudioConfig(
                audio_config=getattr(config, "audio_config", None),
                language_config=getattr(config, "language_config", None),
                adapter_hidden_size=getattr(config, "adapter_hidden_size", 8192),
                ignore_index=getattr(config, "ignore_index", -100),
                deepstack_num_inject_layers=getattr(
                    config, "deepstack_num_inject_layers", None
                ),
            )
        self.config = config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = vllm_config.model_config.multimodal_config

        parallel_config = vllm_config.parallel_config
        tp_size = parallel_config.tensor_parallel_size
        # Both sizes below must divide evenly by tp_size (presumably because
        # they are sharded across tensor-parallel ranks).
        if self.config.adapter_hidden_size % tp_size != 0:
            raise ValueError(
                "MOSS-Audio adapter_hidden_size must be divisible by tensor "
                f"parallel size. Got adapter_hidden_size="
                f"{self.config.adapter_hidden_size} and tensor_parallel_size="
                f"{tp_size}."
            )

        audio_config = MossAudioEncoderConfig.from_config(self.config.audio_config)
        if audio_config.encoder_attention_heads % tp_size != 0:
            raise ValueError(
                "MOSS-Audio encoder_attention_heads must be divisible by "
                "tensor parallel size. Got encoder_attention_heads="
                f"{audio_config.encoder_attention_heads} and "
                f"tensor_parallel_size={tp_size}."
            )
        language_config = self.config.language_config
        self.audio_token_id = MOSS_AUDIO_TOKEN_ID
        # Filled by embed_input_ids and consumed (then cleared) by forward.
        self.deepstack_input_embeds: IntermediateTensors | None = None

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_encoder = MossAudioEncoder(
                audio_config,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "audio_encoder"),
            )
            self.audio_adapter = GatedMLP(
                input_size=audio_config.output_dim,
                hidden_size=self.config.adapter_hidden_size,
                output_size=language_config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "audio_adapter"),
            )

            # One merger per configured encoder DeepStack layer, optionally
            # capped by deepstack_num_inject_layers.
            deepstack_k = len(audio_config.deepstack_encoder_layer_indexes or [])
            if self.config.deepstack_num_inject_layers is not None:
                deepstack_k = min(
                    deepstack_k,
                    int(self.config.deepstack_num_inject_layers),
                )
            self.deepstack_audio_merger_list = nn.ModuleList(
                [
                    GatedMLP(
                        input_size=audio_config.output_dim,
                        hidden_size=self.config.adapter_hidden_size,
                        output_size=language_config.hidden_size,
                        quant_config=self.quant_config,
                        prefix=maybe_prefix(
                            prefix,
                            f"deepstack_audio_merger_list.{layer_idx}",
                        ),
                    )
                    for layer_idx in range(deepstack_k)
                ]
            )

        with self._mark_language_model(vllm_config):
            self.language_model = MossQwen3ForCausalLM(
                vllm_config=vllm_config.with_hf_config(
                    language_config, architectures=["Qwen3ForCausalLM"]
                ),
                prefix=maybe_prefix(prefix, "language_model"),
            )
            # DeepStack embeddings target decoder layers 0..deepstack_k-1.
            self.language_model.deepstack_inject_layer_indices = range(deepstack_k)
            self.language_model.model.deepstack_inject_layer_indices = range(
                deepstack_k
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @staticmethod
    def _validate_audio_batch_size(
        audio_batch_size: int, audio_data_seqlens: torch.Tensor
    ) -> None:
        """Raise ValueError unless there is exactly one seqlen per audio item."""
        if audio_batch_size != audio_data_seqlens.numel():
            raise ValueError(
                "audio_data batch size does not match audio_data_seqlens: "
                f"{audio_batch_size} != {audio_data_seqlens.numel()}."
            )

    @staticmethod
    def _pad_audio_data_list(
        audio_data: list[torch.Tensor],
        audio_data_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Zero-pad a list of ``[mel_dim, time]`` tensors to ``[batch, mel_dim, max_time]``.

        Raises:
            ValueError: empty list, batch/seqlen count mismatch, wrong rank,
                differing ``mel_dim``, or items on different devices.
            TypeError: non-tensor items or items with differing dtypes.
        """
        if len(audio_data) == 0:
            raise ValueError("audio_data list must not be empty.")
        MossAudioModel._validate_audio_batch_size(len(audio_data), audio_data_seqlens)

        # pad_sequence needs every item to share the same trailing feature
        # layout, so validate the mel-major audio tensors before transposing.
        first = audio_data[0]
        if not isinstance(first, torch.Tensor):
            raise TypeError("audio_data list items must be torch.Tensor.")
        if first.ndim != 2:
            raise ValueError("audio_data list items must have shape [mel_dim, time].")

        mel_dim = first.shape[0]
        dtype = first.dtype
        device = first.device
        for item in audio_data[1:]:
            if not isinstance(item, torch.Tensor):
                raise TypeError("audio_data list items must be torch.Tensor.")
            if item.ndim != 2:
                raise ValueError(
                    "audio_data list items must have shape [mel_dim, time]."
                )
            if item.shape[0] != mel_dim:
                raise ValueError("audio_data list items must have the same mel_dim.")
            if item.dtype != dtype:
                raise TypeError("audio_data list items must have the same dtype.")
            if item.device != device:
                raise ValueError("audio_data list items must be on the same device.")

        # pad_sequence pads the FIRST dimension of every item, so transpose
        # each [mel_dim, time] item to [time, mel_dim]. The result is
        # [batch, max_time, mel_dim] (zero-padded), which is transposed back
        # to [batch, mel_dim, max_time].
        time_major = [item.transpose(0, 1) for item in audio_data]
        padded = torch.nn.utils.rnn.pad_sequence(time_major, batch_first=True)
        return padded.transpose(1, 2).contiguous()

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MossAudioAudioInputs | None:
        """Normalize and validate model-side audio kwargs.

        If audio_data is provided, this checks that audio_data_seqlens is also
        present, flattens sequence lengths to a long tensor, pads list inputs
        to [batch, mel_dim, time], validates batch-size/sequence-length
        agreement, and rejects empty, non-positive, or downsampled-zero audio
        lengths.
        """
        audio_data = kwargs.pop("audio_data", None)
        audio_data_seqlens = kwargs.pop("audio_data_seqlens", None)
        if audio_data is None:
            return None
        if audio_data_seqlens is None:
            raise ValueError(
                "audio_data_seqlens is required when audio_data is provided."
            )
        if not isinstance(audio_data_seqlens, torch.Tensor):
            audio_data_seqlens = torch.tensor(audio_data_seqlens, dtype=torch.long)
        audio_data_seqlens = audio_data_seqlens.to(dtype=torch.long).reshape(-1)

        if isinstance(audio_data, list):
            audio_data = self._pad_audio_data_list(audio_data, audio_data_seqlens)
        elif isinstance(audio_data, torch.Tensor):
            if audio_data.ndim == 3:
                self._validate_audio_batch_size(audio_data.shape[0], audio_data_seqlens)
        else:
            raise TypeError("audio_data must be a torch.Tensor or list[torch.Tensor].")

        audio_token_lens = MossAudioEncoder._compute_downsampled_length(
            audio_data_seqlens
        )
        if (
            audio_data_seqlens.numel() == 0
            or torch.any(audio_data_seqlens <= 0).item()
            or torch.any(audio_token_lens <= 0).item()
        ):
            raise ValueError("The audio is too short to be represented.")
        return MossAudioAudioInputs(
            audio_data=audio_data,
            audio_data_seqlens=audio_data_seqlens,
        )

    def _process_audio_input(
        self,
        audio_input: MossAudioAudioInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Run the audio encoder and return one embedding tensor per audio.

        Example:
            audio_data of shape [2, 128, 1200] with
            audio_data_seqlens=[800, 1200] returns
            (audio0_embeds, audio1_embeds), split by each item's downsampled
            token length. With DeepStack mergers configured, each item is
            packed as [main, layer0, layer1, ...] along dim -1.

        Raises:
            RuntimeError: if the encoder returns fewer DeepStack hidden states
                than there are configured mergers.
        """
        num_mergers = len(self.deepstack_audio_merger_list)
        audio_data = audio_input["audio_data"]
        audio_data_seqlens = audio_input["audio_data_seqlens"]
        last_hidden_state, deepstack = self.audio_encoder(
            audio_data.to(self.audio_encoder.dtype),
            feature_lens=audio_data_seqlens,
            output_deepstack_hidden_states=num_mergers > 0,
        )
        audio_embeds = self.audio_adapter(last_hidden_state)
        audio_lengths = MossAudioEncoder._compute_downsampled_length(
            audio_data_seqlens.to(device=audio_embeds.device, dtype=torch.long)
        ).tolist()
        # squeeze(0) + split assumes the encoder packs all items along a
        # single leading batch dim (NOTE(review): confirm encoder layout).
        main_embeddings = tuple(audio_embeds.squeeze(0).split(audio_lengths, dim=0))

        deepstack_embeddings: list[tuple[torch.Tensor, ...]] = []
        # NOTE(review): if the encoder returns no DeepStack states while
        # mergers are configured, only the main embeddings are returned.
        if deepstack is not None:
            if len(deepstack) < num_mergers:
                raise RuntimeError(
                    "DeepStack output count does not match configured audio "
                    f"merger count: the encoder returned {len(deepstack)} "
                    f"hidden states but {num_mergers} mergers are configured."
                )
            # Surplus encoder outputs (when deepstack_num_inject_layers caps
            # the merger count) are intentionally ignored.
            for idx, hidden_states in enumerate(deepstack[:num_mergers]):
                ds_embeds = self.deepstack_audio_merger_list[idx](hidden_states)
                deepstack_embeddings.append(
                    tuple(ds_embeds.squeeze(0).split(audio_lengths, dim=0))
                )

        if not deepstack_embeddings:
            return main_embeddings

        # Pack [main, layer0, layer1, ...] along the feature dim so the
        # generic multimodal path can carry one tensor per audio item.
        return tuple(
            torch.cat(
                [
                    main_embedding,
                    *(
                        layer_embeddings[item_idx]
                        for layer_embeddings in deepstack_embeddings
                    ),
                ],
                dim=-1,
            )
            for item_idx, main_embedding in enumerate(main_embeddings)
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or ``()`` when no audio is supplied."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return ()
        return self._process_audio_input(audio_input)

    def _split_multimodal_embeddings(
        self,
        multimodal_embeddings: MultiModalEmbeddings,
        hidden_size: int,
    ) -> tuple[tuple[torch.Tensor, ...], tuple[tuple[torch.Tensor, ...], ...]]:
        """Unpack audio embeddings before merging them into token embeddings.

        embed_input_ids calls this on the output of embed_multimodal. Plain
        audio embeddings already have width hidden_size and are returned as the
        main embeddings for _merge_multimodal_embeddings. When DeepStack is
        enabled, _process_audio_input packs each audio item as
        [main, layer0, layer1, ...] along the last dimension so the standard
        multimodal path can carry a single embedding object. This method splits
        that packed layout back into main embeddings plus per-layer DeepStack
        embeddings, which _cache_deepstack_input_embeds scatters and forward
        passes into MossQwen3Model for layer injection.
        """
        if isinstance(multimodal_embeddings, torch.Tensor):
            embeddings = tuple(multimodal_embeddings.unbind(0))
        else:
            embeddings = tuple(multimodal_embeddings)

        if len(embeddings) == 0:
            return (), ()

        deepstack_count = len(self.deepstack_audio_merger_list)
        if all(embedding.shape[-1] == hidden_size for embedding in embeddings):
            return embeddings, ()

        packed_hidden_size = hidden_size * (deepstack_count + 1)
        if deepstack_count == 0 or any(
            embedding.shape[-1] != packed_hidden_size for embedding in embeddings
        ):
            got = [int(embedding.shape[-1]) for embedding in embeddings]
            raise ValueError(
                "MOSS-Audio multimodal embedding width mismatch: expected "
                f"{hidden_size} or {packed_hidden_size}, got {got}."
            )

        split_by_item = [
            torch.split(embedding, hidden_size, dim=-1) for embedding in embeddings
        ]
        main_embeddings = tuple(parts[0] for parts in split_by_item)
        deepstack_embeddings = tuple(
            tuple(parts[layer_idx + 1] for parts in split_by_item)
            for layer_idx in range(deepstack_count)
        )
        return main_embeddings, deepstack_embeddings

    def _cache_deepstack_input_embeds(
        self,
        inputs_embeds: torch.Tensor,
        deepstack_embeddings: tuple[tuple[torch.Tensor, ...], ...],
        is_multimodal: torch.Tensor,
    ) -> None:
        """Scatter per-layer DeepStack embeddings and stash them for forward.

        Each layer's audio embeddings are concatenated across items and
        written into a zero tensor shaped like ``inputs_embeds`` at the
        positions selected by ``is_multimodal``. The result is stored in
        ``self.deepstack_input_embeds`` (``None`` when there is nothing to
        inject).

        Raises:
            ValueError: if a layer's token count differs from the number of
                multimodal positions.
        """
        if len(deepstack_embeddings) == 0:
            self.deepstack_input_embeds = None
            return
        flat_by_layer = [
            torch.cat(layer_embeds, dim=0).to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
            for layer_embeds in deepstack_embeddings
        ]
        num_mm_tokens = int(is_multimodal.sum().item())
        if any(layer.shape[0] != num_mm_tokens for layer in flat_by_layer):
            got = [int(layer.shape[0]) for layer in flat_by_layer]
            raise ValueError(
                "DeepStack audio token count mismatch: "
                f"expected {num_mm_tokens}, got {got}."
            )
        data = {}
        for layer_idx, layer_embeds in enumerate(flat_by_layer):
            scattered = inputs_embeds.new_zeros(inputs_embeds.shape)
            scattered[is_multimodal] = layer_embeds
            data[f"deepstack_input_embeds_{layer_idx}"] = scattered
        self.deepstack_input_embeds = IntermediateTensors(data)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed text tokens and splice in audio embeddings.

        Also refreshes ``self.deepstack_input_embeds``: it is reset to
        ``None`` on every call and repopulated only when DeepStack
        embeddings are present.
        """
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
        )

        # Never let DeepStack state from a previous batch leak into this one.
        self.deepstack_input_embeds = None
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds
        is_multimodal = _require_is_multimodal(is_multimodal)
        multimodal_embeddings, deepstack_embeddings = self._split_multimodal_embeddings(
            multimodal_embeddings,
            hidden_size=int(inputs_embeds.shape[-1]),
        )

        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )
        self._cache_deepstack_input_embeds(
            inputs_embeds,
            deepstack_embeddings,
            is_multimodal,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model with the appropriate DeepStack source.

        The first PP rank uses the embeddings cached by ``embed_input_ids``;
        later ranks pass ``intermediate_tensors`` through. The cache is
        cleared afterwards.
        """
        if intermediate_tensors is None:
            deepstack_input_embeds = self.deepstack_input_embeds
        else:
            # Non-first PP ranks consume hidden states from intermediate_tensors.
            # The executor may still pass dummy inputs_embeds during profiling.
            inputs_embeds = None
            deepstack_input_embeds = intermediate_tensors
        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            deepstack_input_embeds=deepstack_input_embeds,
        )
        self.deepstack_input_embeds = None
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``hf_to_vllm_mapper``.

        Checkpoint entries under ``audio_encoder.embed_positions`` are skipped
        (presumably the encoder builds its positional embeddings itself --
        confirm).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["audio_encoder.embed_positions"],
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

_parse_and_validate_audio_input(**kwargs)

Normalize and validate model-side audio kwargs.

If audio_data is provided, this checks that audio_data_seqlens is also present, flattens sequence lengths to a long tensor, pads list inputs to [batch, mel_dim, time], validates batch-size/sequence-length agreement, and rejects empty, non-positive, or downsampled-zero audio lengths.

Source code in vllm/model_executor/models/moss_audio.py
def _parse_and_validate_audio_input(
    self, **kwargs: object
) -> MossAudioAudioInputs | None:
    """Turn raw audio kwargs into a validated ``MossAudioAudioInputs``.

    Returns ``None`` when no ``audio_data`` kwarg is present. Otherwise
    ``audio_data_seqlens`` must be supplied too; it is coerced to a flat
    ``torch.long`` tensor.

    A list of per-item tensors is padded through
    ``self._pad_audio_data_list``. For a 3-D tensor, its leading (batch)
    dimension is checked against the number of sequence lengths via
    ``self._validate_audio_batch_size``.

    Raises:
        ValueError: ``audio_data_seqlens`` is missing, or the lengths are
            empty, non-positive, or non-positive after downsampling.
        TypeError: ``audio_data`` is neither a tensor nor a list.
    """
    raw_audio = kwargs.pop("audio_data", None)
    raw_lens = kwargs.pop("audio_data_seqlens", None)

    # Guard clauses: nothing to do without audio, and audio needs lengths.
    if raw_audio is None:
        return None
    if raw_lens is None:
        raise ValueError(
            "audio_data_seqlens is required when audio_data is provided."
        )

    lens = (
        raw_lens
        if isinstance(raw_lens, torch.Tensor)
        else torch.tensor(raw_lens, dtype=torch.long)
    )
    lens = lens.to(dtype=torch.long).reshape(-1)

    if isinstance(raw_audio, torch.Tensor):
        if raw_audio.ndim == 3:
            self._validate_audio_batch_size(raw_audio.shape[0], lens)
    elif isinstance(raw_audio, list):
        raw_audio = self._pad_audio_data_list(raw_audio, lens)
    else:
        raise TypeError("audio_data must be a torch.Tensor or list[torch.Tensor].")

    token_lens = MossAudioEncoder._compute_downsampled_length(lens)
    # Short-circuit order matters: the emptiness test runs first.
    has_no_items = lens.numel() == 0
    if (
        has_no_items
        or torch.any(lens <= 0).item()
        or torch.any(token_lens <= 0).item()
    ):
        raise ValueError("The audio is too short to be represented.")

    return MossAudioAudioInputs(
        audio_data=raw_audio,
        audio_data_seqlens=lens,
    )

_process_audio_input(audio_input)

Run the audio encoder and return one embedding tensor per audio.

Example

Given audio_data of shape [2, 128, 1200] and audio_data_seqlens=[800, 1200], the method returns (audio0_embeds, audio1_embeds), split by each item's downsampled token length. When DeepStack is enabled, each item is packed as [main, layer0, layer1, ...] along dim -1.

Source code in vllm/model_executor/models/moss_audio.py
def _process_audio_input(
    self,
    audio_input: MossAudioAudioInputs,
) -> tuple[torch.Tensor, ...]:
    """Run the audio encoder and return one embedding tensor per audio.

    Example:
        audio_data of shape [2, 128, 1200] with
        audio_data_seqlens=[800, 1200] returns
        (audio0_embeds, audio1_embeds), split by each item's downsampled
        token length. With DeepStack mergers configured, each item is
        packed as [main, layer0, layer1, ...] along dim -1.

    Raises:
        RuntimeError: if the encoder returns fewer DeepStack hidden states
            than there are configured mergers.
    """
    num_mergers = len(self.deepstack_audio_merger_list)
    audio_data = audio_input["audio_data"]
    audio_data_seqlens = audio_input["audio_data_seqlens"]
    last_hidden_state, deepstack = self.audio_encoder(
        audio_data.to(self.audio_encoder.dtype),
        feature_lens=audio_data_seqlens,
        output_deepstack_hidden_states=num_mergers > 0,
    )
    audio_embeds = self.audio_adapter(last_hidden_state)
    audio_lengths = MossAudioEncoder._compute_downsampled_length(
        audio_data_seqlens.to(device=audio_embeds.device, dtype=torch.long)
    ).tolist()
    # squeeze(0) + split assumes the encoder packs all items along a single
    # leading batch dim (NOTE(review): confirm encoder output layout).
    main_embeddings = tuple(audio_embeds.squeeze(0).split(audio_lengths, dim=0))

    deepstack_embeddings: list[tuple[torch.Tensor, ...]] = []
    # NOTE(review): if the encoder returns no DeepStack states while mergers
    # are configured, only the main embeddings are returned.
    if deepstack is not None:
        if len(deepstack) < num_mergers:
            raise RuntimeError(
                "DeepStack output count does not match configured audio "
                f"merger count: the encoder returned {len(deepstack)} "
                f"hidden states but {num_mergers} mergers are configured."
            )
        # Surplus encoder outputs (when deepstack_num_inject_layers caps the
        # merger count) are intentionally ignored.
        for idx, hidden_states in enumerate(deepstack[:num_mergers]):
            ds_embeds = self.deepstack_audio_merger_list[idx](hidden_states)
            deepstack_embeddings.append(
                tuple(ds_embeds.squeeze(0).split(audio_lengths, dim=0))
            )

    if not deepstack_embeddings:
        return main_embeddings

    # Pack [main, layer0, layer1, ...] along the feature dim so the generic
    # multimodal path can carry a single tensor per audio item.
    return tuple(
        torch.cat(
            [
                main_embedding,
                *(
                    layer_embeddings[item_idx]
                    for layer_embeddings in deepstack_embeddings
                ),
            ],
            dim=-1,
        )
        for item_idx, main_embedding in enumerate(main_embeddings)
    )

_split_multimodal_embeddings(multimodal_embeddings, hidden_size)

Unpack audio embeddings before merging them into token embeddings.

embed_input_ids calls this on the output of embed_multimodal. Plain audio embeddings already have width hidden_size and are returned as the main embeddings for _merge_multimodal_embeddings. When DeepStack is enabled, _process_audio_input packs each audio item as [main, layer0, layer1, ...] along the last dimension so the standard multimodal path can carry a single embedding object. This method splits that packed layout back into main embeddings plus per-layer DeepStack embeddings, which _cache_deepstack_input_embeds scatters and forward passes into MossQwen3Model for layer injection.

Source code in vllm/model_executor/models/moss_audio.py
def _split_multimodal_embeddings(
    self,
    multimodal_embeddings: MultiModalEmbeddings,
    hidden_size: int,
) -> tuple[tuple[torch.Tensor, ...], tuple[tuple[torch.Tensor, ...], ...]]:
    """Separate packed DeepStack audio embeddings from the main embeddings.

    ``embed_input_ids`` feeds the output of ``embed_multimodal`` through this
    helper. Each item is one of two forms:

    * a plain embedding whose last dimension equals ``hidden_size``;
    * when DeepStack mergers exist, a packed embedding of width
      ``hidden_size * (len(self.deepstack_audio_merger_list) + 1)``, laid
      out as ``[main, layer0, layer1, ...]`` along the last dimension.

    A stacked tensor input is unbound along dim 0 into items first.

    Returns:
        ``(main, deepstack)``. ``main`` holds one ``hidden_size``-wide
        tensor per item. ``deepstack[layer][item]`` is that item's slice for
        the given DeepStack layer. ``deepstack`` is ``()`` when every item
        is already ``hidden_size`` wide, and both are ``()`` for empty input.

    Raises:
        ValueError: an item's width is neither ``hidden_size`` nor the packed
            width, or items are wider than ``hidden_size`` while no DeepStack
            mergers are configured.
    """
    items = (
        tuple(multimodal_embeddings.unbind(0))
        if isinstance(multimodal_embeddings, torch.Tensor)
        else tuple(multimodal_embeddings)
    )
    if not items:
        return (), ()

    num_layers = len(self.deepstack_audio_merger_list)
    widths = [item.shape[-1] for item in items]
    if all(width == hidden_size for width in widths):
        return items, ()

    packed_width = hidden_size * (num_layers + 1)
    if num_layers == 0 or any(width != packed_width for width in widths):
        got = [int(width) for width in widths]
        raise ValueError(
            "MOSS-Audio multimodal embedding width mismatch: expected "
            f"{hidden_size} or {packed_width}, got {got}."
        )

    # Each packed item splits into num_layers + 1 chunks of width hidden_size.
    chunks = [torch.split(item, hidden_size, dim=-1) for item in items]
    main = tuple(pieces[0] for pieces in chunks)
    deepstack = tuple(
        tuple(pieces[1 + layer] for pieces in chunks) for layer in range(num_layers)
    )
    return main, deepstack

MossAudioProcessor

Methods:

  • __call__

    Build text tokens and audio tensors for one MossAudio prompt.

Source code in vllm/model_executor/models/moss_audio.py
class MossAudioProcessor:
    """Turn one text prompt plus raw waveforms into MOSS-Audio model inputs.

    Each waveform is converted to mel filter-bank features with a
    ``WhisperFeatureExtractor``. Each audio placeholder span in the prompt
    (matched by ``MOSS_AUDIO_SPAN_RE``) is expanded into ``audio_start_id``,
    one ``audio_token_id`` per audio token (optionally interleaved with
    time-marker digit tokens), and ``audio_end_id``.
    """

    model_input_names = [
        "input_ids",
        "attention_mask",
        "audio_data",
        "audio_data_seqlens",
    ]

    def __init__(
        self,
        tokenizer: object,
        *,
        audio_token_id: int = MOSS_AUDIO_TOKEN_ID,
        audio_start_id: int = MOSS_AUDIO_BOS_TOKEN_ID,
        audio_end_id: int = MOSS_AUDIO_EOS_TOKEN_ID,
        enable_time_marker: bool = False,
        mel_config: Mapping[str, object] | None = None,
    ) -> None:
        """Create a processor.

        Args:
            tokenizer: Text tokenizer; ``encode``, ``decode`` and
                ``batch_decode`` are called on it.
            audio_token_id: Id repeated once per audio token inside a span.
            audio_start_id: Id emitted before each audio span.
            audio_end_id: Id emitted after each audio span.
            enable_time_marker: Insert elapsed-time digit tokens inside audio
                spans (see ``_build_audio_tokens_with_time_markers``).
            mel_config: Mel front-end settings, normalized by
                ``_normalize_moss_audio_mel_config``. The keys ``mel_dim``,
                ``mel_sr``, ``mel_hop_length`` and ``mel_n_fft`` are read.
        """
        self.tokenizer = tokenizer
        self.audio_token_id = int(audio_token_id)
        self.audio_start_id = int(audio_start_id)
        self.audio_end_id = int(audio_end_id)
        self.enable_time_marker = bool(enable_time_marker)
        self.mel_config = _normalize_moss_audio_mel_config(mel_config)
        self.feature_extractor = WhisperFeatureExtractor(
            feature_size=self.mel_config["mel_dim"],
            sampling_rate=self.mel_config["mel_sr"],
            hop_length=self.mel_config["mel_hop_length"],
            n_fft=self.mel_config["mel_n_fft"],
        )
        # Mel frames per second divided by 8, i.e. this rate assumes 8 mel
        # frames per audio token.
        self.audio_tokens_per_second = self.mel_config["mel_sr"] / (
            self.mel_config["mel_hop_length"] * 8
        )
        self.time_marker_every_seconds = 2
        # Truncated to int so a marker never lands past the audio it labels.
        self.time_marker_every_audio_tokens = int(
            self.audio_tokens_per_second * self.time_marker_every_seconds
        )
        # Token ids emitted for the characters "0".."9" of a time marker.
        # NOTE(review): hard-coded rather than looked up via the tokenizer;
        # assumes the tokenizer maps "0".."9" to ids 15..24 — confirm.
        self._digit_token_ids = {str(digit): 15 + digit for digit in range(10)}

    @staticmethod
    def conv3_downsample_len(raw_mel_len: int) -> int:
        """Return the number of audio tokens for ``raw_mel_len`` mel frames."""
        return MossAudioEncoder.compute_num_audio_tokens(raw_mel_len)

    def _extract_mel(self, audio: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Compute mel filter-bank features for one waveform.

        Args:
            audio: Waveform as a NumPy array or tensor. A 2-D input is
                reduced to its first row (``audio[0]``).

        Returns:
            The feature matrix for the single waveform as a CPU tensor.

        Raises:
            ValueError: If the waveform is empty.
        """
        if isinstance(audio, torch.Tensor):
            wav = audio.detach().to("cpu", dtype=torch.float32).numpy()
        else:
            wav = np.asarray(audio, dtype=np.float32)
        if wav.size == 0:
            raise ValueError("The audio is too short to be represented.")
        if wav.ndim == 2:
            wav = wav[0]
        feats = self.feature_extractor._np_extract_fbank_features(
            wav[None, ...], device="cpu"
        )
        return torch.from_numpy(feats[0])

    def _get_default_audio_prompt(self) -> str:
        """Return the placeholder text inserted for one audio item."""
        return MOSS_AUDIO_PLACEHOLDER

    def _ensure_audio_placeholders(
        self,
        prompt_text: str,
        num_audios: int,
    ) -> str:
        """Prepend default placeholders when audio is given but none are present.

        If ``num_audios`` is positive and ``prompt_text`` contains no span
        matched by ``MOSS_AUDIO_SPAN_RE``, one default placeholder per audio
        item is prepended. Non-empty text is separated from the placeholders
        by a newline. Otherwise the prompt is returned unchanged.
        """
        if num_audios == 0 or MOSS_AUDIO_SPAN_RE.search(prompt_text):
            return prompt_text

        audio_prompt = self._get_default_audio_prompt() * num_audios
        if prompt_text:
            return f"{audio_prompt}\n{prompt_text}"
        return audio_prompt

    def _build_audio_tokens_with_time_markers(self, audio_seq_len: int) -> list[int]:
        """Emit ``audio_seq_len`` audio tokens interleaved with time markers.

        After every ``time_marker_every_seconds`` seconds of audio (measured
        in ``time_marker_every_audio_tokens`` tokens), the decimal digits of
        the elapsed second count are inserted as digit tokens. Any audio
        tokens left after the last marker are appended at the end.
        """
        total_duration_seconds = audio_seq_len / self.audio_tokens_per_second
        num_full_seconds = int(total_duration_seconds)
        token_ids: list[int] = []
        audio_tokens_consumed = 0
        for second in range(
            self.time_marker_every_seconds,
            num_full_seconds + 1,
            self.time_marker_every_seconds,
        ):
            marker_pos = (
                second // self.time_marker_every_seconds
            ) * self.time_marker_every_audio_tokens
            audio_segment_len = marker_pos - audio_tokens_consumed
            if audio_segment_len > 0:
                token_ids.extend([self.audio_token_id] * audio_segment_len)
                audio_tokens_consumed += audio_segment_len
            token_ids.extend(self._digit_token_ids[digit] for digit in str(second))

        remaining = audio_seq_len - audio_tokens_consumed
        if remaining > 0:
            token_ids.extend([self.audio_token_id] * remaining)
        return token_ids

    def build_audio_placeholder_ids(self, num_audio_tokens: int) -> list[int]:
        """Return the ids filling one audio span (without start/end ids)."""
        if self.enable_time_marker:
            return self._build_audio_tokens_with_time_markers(num_audio_tokens)
        return [self.audio_token_id] * num_audio_tokens

    @staticmethod
    def _normalize_prompt_text(text: str | Sequence[str] | None) -> str:
        """Reduce ``text`` to a single prompt string.

        Raises:
            ValueError: If a list/tuple does not hold exactly one prompt.
        """
        if isinstance(text, (list, tuple)):
            if len(text) != 1:
                raise ValueError(f"Expected text batch size 1, got {len(text)}")
            return text[0]
        if text is None:
            return ""
        return text

    def _extract_audio_features(
        self,
        audio_list: Sequence[np.ndarray | torch.Tensor],
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, list[int]]:
        """Featurize all waveforms and pad them into one batch.

        Returns:
            ``(audio_batch, audio_data_seqlens, token_lens)``:

            - ``audio_batch`` is a zero-padded ``[n, mel_dim, max_frames]``
              float32 tensor.
            - ``audio_data_seqlens`` holds the unpadded frame count of each
              item.
            - ``token_lens`` holds the audio-token count of each item.

            Both tensors are ``None`` when ``audio_list`` is empty.

        Raises:
            ValueError: If any waveform yields no frames or no audio tokens.
        """
        mels: list[torch.Tensor] = []
        raw_lengths: list[int] = []
        token_lens: list[int] = []
        for one_audio in audio_list:
            mel = self._extract_mel(one_audio)
            raw_len = int(mel.shape[-1])
            # Validate the frame count before handing it to the encoder helper.
            if raw_len <= 0:
                raise ValueError("The audio is too short to be represented.")
            num_tokens = self.conv3_downsample_len(raw_len)
            if num_tokens <= 0:
                raise ValueError("The audio is too short to be represented.")
            mels.append(mel)
            raw_lengths.append(raw_len)
            token_lens.append(num_tokens)

        if not mels:
            return None, None, token_lens

        audio_batch = torch.zeros(
            (len(mels), self.mel_config["mel_dim"], max(raw_lengths)),
            dtype=torch.float32,
        )
        for index, mel in enumerate(mels):
            audio_batch[index, :, : mel.shape[-1]] = mel
        audio_data_seqlens = torch.tensor(raw_lengths, dtype=torch.long)
        return audio_batch, audio_data_seqlens, token_lens

    def _tokenize_with_spans(
        self,
        prompt_text: str,
        spans: Sequence[tuple[int, int, list[int]]],
    ) -> list[int]:
        """Tokenize ``prompt_text``, replacing each span with the given ids.

        Args:
            prompt_text: Full prompt string.
            spans: Sorted, non-overlapping ``(start, end, ids)`` triples. The
                slice ``prompt_text[start:end]`` is replaced by ``ids``.

        Returns:
            Token ids. Text between spans is encoded without special tokens.
        """
        input_ids: list[int] = []
        cursor = 0
        for start, end, span_ids in spans:
            input_ids.extend(
                self.tokenizer.encode(
                    prompt_text[cursor:start], add_special_tokens=False
                )
            )
            input_ids.extend(span_ids)
            cursor = end
        input_ids.extend(
            self.tokenizer.encode(prompt_text[cursor:], add_special_tokens=False)
        )
        return input_ids

    @staticmethod
    def _to_batch_feature(
        input_ids: list[int],
        return_tensors: str,
        audio_data: torch.Tensor | None = None,
        audio_data_seqlens: torch.Tensor | None = None,
    ) -> BatchFeature:
        """Wrap ids (and optional audio tensors) as a batch-of-one feature."""
        data: dict[str, torch.Tensor] = {
            "input_ids": torch.tensor([input_ids], dtype=torch.long),
            "attention_mask": torch.ones((1, len(input_ids)), dtype=torch.long),
        }
        if audio_data is not None and audio_data_seqlens is not None:
            data["audio_data"] = audio_data
            data["audio_data_seqlens"] = audio_data_seqlens
        return BatchFeature(data=data, tensor_type=return_tensors)

    def __call__(
        self,
        text: str | Sequence[str] | None = None,
        audios: Sequence[np.ndarray | torch.Tensor] | None = None,
        audio: Sequence[np.ndarray | torch.Tensor] | None = None,
        return_tensors: str = "pt",
        **kwargs: object,
    ) -> BatchFeature:
        """Build text tokens and audio tensors for one MossAudio prompt.

        Example:
            text="Describe this.", audio=[waveform]
            -> input_ids contains audio_start, N audio tokens, audio_end
            -> audio_data has shape [1, mel_dim, max_time]
            -> mel_dim is the number of mel filter-bank bins, 128 by default
            -> audio_data_seqlens stores the unpadded mel length

        Raises:
            ValueError: On a text batch that is not size 1, on audio too
                short to represent, or when the number of placeholder spans
                does not match the number of audio items.
        """
        del kwargs

        # Step 1. Normalize text input; this processor handles one prompt.
        prompt_text = self._normalize_prompt_text(text)

        # Step 2. Accept either `audios` or `audio` (`audios` wins when both
        # are given) and normalize to a list.
        audio_list = audios if audios is not None else audio
        audio_list = [] if audio_list is None else list(audio_list)

        # Step 3. Featurize waveforms and pad them into one batch tensor.
        audio_batch, audio_data_seqlens, token_lens = self._extract_audio_features(
            audio_list
        )

        # Step 4. Ensure each audio item has a placeholder span in the prompt.
        prompt_text = self._ensure_audio_placeholders(prompt_text, len(audio_list))

        # Step 5. Text-only path: each placeholder span keeps a minimal
        # start / single audio token / end triple.
        if not audio_list:
            minimal_span = [self.audio_start_id, self.audio_token_id, self.audio_end_id]
            input_ids = self._tokenize_with_spans(
                prompt_text,
                [
                    (match.start(), match.end(), minimal_span)
                    for match in MOSS_AUDIO_SPAN_RE.finditer(prompt_text)
                ],
            )
            return self._to_batch_feature(input_ids, return_tensors)

        # Step 6. Audio path: the first len(audio_list) spans pair with the
        # audio items in order.
        matches = list(MOSS_AUDIO_SPAN_RE.finditer(prompt_text))[: len(token_lens)]
        if len(matches) < len(token_lens):
            raise ValueError(
                "Audio placeholder count mismatch: expected one "
                f"{MOSS_AUDIO_PLACEHOLDER!r} span per audio item."
            )

        # Step 7. Reject extra placeholder spans after all audio items are used.
        if MOSS_AUDIO_SPAN_RE.search(prompt_text[matches[-1].end() :]):
            raise ValueError(
                "Audio placeholder count mismatch: found more placeholder spans "
                "than audio items."
            )

        # Step 8. Expand each used span to start + audio tokens + end.
        spans = [
            (
                match.start(),
                match.end(),
                [
                    self.audio_start_id,
                    *self.build_audio_placeholder_ids(num_tokens),
                    self.audio_end_id,
                ],
            )
            for match, num_tokens in zip(matches, token_lens)
        ]
        input_ids = self._tokenize_with_spans(prompt_text, spans)

        # Step 9. Return tokenizer output plus audio tensors for embed_multimodal.
        return self._to_batch_feature(
            input_ids, return_tensors, audio_batch, audio_data_seqlens
        )

    def decode(self, *args: object, **kwargs: object) -> str:
        """Delegate to ``self.tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args: object, **kwargs: object) -> list[str]:
        """Delegate to ``self.tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

__call__(text=None, audios=None, audio=None, return_tensors='pt', **kwargs)

Build text tokens and audio tensors for one MossAudio prompt.

Example

text="Describe this.", audio=[waveform]
  -> input_ids contains audio_start, N audio tokens, audio_end
  -> audio_data has shape [1, mel_dim, max_time]
  -> mel_dim is the number of mel filter-bank bins, 128 by default
  -> audio_data_seqlens stores the unpadded mel length

Source code in vllm/model_executor/models/moss_audio.py
def __call__(
    self,
    text: str | Sequence[str] | None = None,
    audios: Sequence[np.ndarray | torch.Tensor] | None = None,
    audio: Sequence[np.ndarray | torch.Tensor] | None = None,
    return_tensors: str = "pt",
    **kwargs: object,
) -> BatchFeature:
    """Build text tokens and audio tensors for one MossAudio prompt.

    Example:
        text="Describe this.", audio=[waveform]
        -> input_ids contains audio_start, N audio tokens, audio_end
        -> audio_data has shape [1, mel_dim, max_time]
        -> mel_dim is the number of mel filter-bank bins, 128 by default
        -> audio_data_seqlens stores the unpadded mel length

    Raises:
        ValueError: On a text batch that is not size 1, on audio too short
            to represent, or when placeholder spans and audio items differ
            in number.
    """
    del kwargs

    def encode(piece: str) -> list[int]:
        # Text between placeholder spans never receives special tokens.
        return self.tokenizer.encode(piece, add_special_tokens=False)

    def pack(ids: list[int]) -> dict[str, torch.Tensor]:
        # Batch-of-one ids plus an all-ones attention mask.
        return {
            "input_ids": torch.tensor([ids], dtype=torch.long),
            "attention_mask": torch.ones((1, len(ids)), dtype=torch.long),
        }

    # A single prompt only; a list/tuple must contain exactly one string.
    if text is None:
        prompt_text = ""
    elif isinstance(text, (list, tuple)):
        if len(text) != 1:
            raise ValueError(f"Expected text batch size 1, got {len(text)}")
        prompt_text = text[0]
    else:
        prompt_text = text

    # `audios` takes precedence over `audio`; either may be omitted.
    chosen = audio if audios is None else audios
    audio_items = list(chosen) if chosen is not None else []

    # Mel features, unpadded frame counts and audio-token counts per item.
    mel_feats: list[torch.Tensor] = []
    frame_counts: list[int] = []
    token_counts: list[int] = []
    for item in audio_items:
        mel = self._extract_mel(item)
        n_frames = int(mel.shape[-1])
        n_tokens = self.conv3_downsample_len(n_frames)
        if n_frames <= 0 or n_tokens <= 0:
            raise ValueError("The audio is too short to be represented.")
        mel_feats.append(mel)
        frame_counts.append(n_frames)
        token_counts.append(n_tokens)

    # Zero-pad every item to the longest frame count.
    padded = None
    seqlens = None
    if mel_feats:
        padded = torch.zeros(
            (len(mel_feats), self.mel_config["mel_dim"], max(frame_counts)),
            dtype=torch.float32,
        )
        for row, mel in enumerate(mel_feats):
            padded[row, :, : mel.shape[-1]] = mel
        seqlens = torch.tensor(frame_counts, dtype=torch.long)

    prompt_text = self._ensure_audio_placeholders(prompt_text, len(audio_items))
    ids: list[int] = []
    pos = 0

    # Without audio, each span keeps a minimal start / one token / end triple.
    if not audio_items:
        for span in MOSS_AUDIO_SPAN_RE.finditer(prompt_text):
            ids += encode(prompt_text[pos : span.start()])
            ids += [self.audio_start_id, self.audio_token_id, self.audio_end_id]
            pos = span.end()
        ids += encode(prompt_text[pos:])
        return BatchFeature(data=pack(ids), tensor_type=return_tensors)

    # With audio, the i-th span expands to the i-th item's token count.
    spans = MOSS_AUDIO_SPAN_RE.finditer(prompt_text)
    for n_tokens in token_counts:
        span = next(spans, None)
        if span is None:
            raise ValueError(
                "Audio placeholder count mismatch: expected one "
                f"{MOSS_AUDIO_PLACEHOLDER!r} span per audio item."
            )
        ids += encode(prompt_text[pos : span.start()])
        ids.append(self.audio_start_id)
        ids += self.build_audio_placeholder_ids(n_tokens)
        ids.append(self.audio_end_id)
        pos = span.end()

    # Any span left in the tail means more placeholders than audio items.
    tail = prompt_text[pos:]
    if MOSS_AUDIO_SPAN_RE.search(tail):
        raise ValueError(
            "Audio placeholder count mismatch: found more placeholder spans "
            "than audio items."
        )
    ids += encode(tail)

    data = pack(ids)
    if padded is not None and seqlens is not None:
        data["audio_data"] = padded
        data["audio_data_seqlens"] = seqlens
    return BatchFeature(data=data, tensor_type=return_tensors)