
vllm_omni.model_executor.models.fish_speech.fish_speech_slow_ar

Fish Speech S2 Pro -- Slow AR model (Stage 0).

Uses vLLM's Qwen3Model as the transformer backbone. Adds:
- Multi-codebook input embedding (text + summed codebook embeddings at semantic-token positions).
- Semantic logit masking.
- Nested Fast AR for residual codebook prediction (talker_mtp).
- preprocess / postprocess hooks for vLLM-omni's AR scheduler.
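Semantic logit masking restricts sampling so the Slow AR can only emit semantic tokens. The sketch below is a minimal illustration, assuming semantic tokens occupy one contiguous ID range in the vocabulary plus a few extra allowed IDs (for example an end-of-turn token). The names `mask_to_semantic`, `semantic_begin_id`, `semantic_end_id` and `extra_allowed_ids` are hypothetical and not taken from the vllm_omni source.

```python
import math
from typing import Iterable, Sequence


def mask_to_semantic(
    logits: Sequence[float],
    semantic_begin_id: int,
    semantic_end_id: int,
    extra_allowed_ids: Iterable[int] = (),
) -> list[float]:
    """Return a copy of `logits` where every ID outside the inclusive range
    [semantic_begin_id, semantic_end_id] and outside `extra_allowed_ids`
    is set to -inf, so sampling can only pick the allowed IDs."""
    allowed = set(extra_allowed_ids)
    masked = []
    for token_id, value in enumerate(logits):
        in_range = semantic_begin_id <= token_id <= semantic_end_id
        masked.append(value if in_range or token_id in allowed else -math.inf)
    return masked


logits = [0.5, 2.0, 1.0, 3.0, -1.0, 0.0]
masked = mask_to_semantic(logits, semantic_begin_id=2, semantic_end_id=4, extra_allowed_ids=[5])
# Greedy pick over the masked logits; IDs 0 and 1 can no longer win.
best = max(range(len(masked)), key=masked.__getitem__)
```

On a torch tensor the same idea is usually a single `masked_fill` with a boolean mask of disallowed positions.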

Analogous to Qwen3TTSTalkerForConditionalGeneration in qwen3_tts.

logger module-attribute

logger = init_logger(__name__)

FishSpeechSlowARForConditionalGeneration

Bases: Module

vLLM-AR Slow AR model for Fish Speech S2 Pro.

Stage 0: text → semantic tokens (+ residual codebook codes via Fast AR).

codebook_embeddings instance-attribute

codebook_embeddings = Embedding(
    _codebook_size * _num_codebooks, hidden_size
)
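All codebooks share one embedding table of `_codebook_size * _num_codebooks` rows; per the talker_mtp description, codebook i's code is looked up at row `code_i + i * codebook_size`, so codebook i owns rows `[i * codebook_size, (i + 1) * codebook_size)`. A minimal sketch of that index layout follows; the helper names are hypothetical and the codebook size 1024 is a toy value, not the model's actual configuration.

```python
def flat_codebook_indices(codes: list[int], codebook_size: int) -> list[int]:
    """Map per-codebook codes [c_0, ..., c_{n-1}] to rows of one shared table:
    codebook i owns rows [i * codebook_size, (i + 1) * codebook_size)."""
    indices = []
    for i, code in enumerate(codes):
        if not 0 <= code < codebook_size:
            raise ValueError(f"code {code} out of range for codebook {i}")
        indices.append(code + i * codebook_size)
    return indices


def split_flat_index(index: int, codebook_size: int) -> tuple[int, int]:
    """Inverse mapping: flat row -> (codebook i, code within codebook i)."""
    return divmod(index, codebook_size)


indices = flat_codebook_indices([5, 5, 1023], codebook_size=1024)
```

The offset keeps equal code values from different codebooks (the two `5`s above) on distinct rows, so a single `Embedding` can serve every codebook.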

config instance-attribute

config = config

fast_ar instance-attribute

fast_ar = FishSpeechFastAR(
    vllm_config=_fast_ar_vllm_config,
    config=fast_ar_config,
    slow_ar_config=text_config,
    prefix="fast_ar",
)
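The nested Fast AR predicts the residual codebooks for one audio frame. The sketch below shows a greedy version of that inner loop, assuming the Fast AR is seeded with the Slow AR hidden state and the first code, then predicts codebooks 1 to num_codebooks - 1 one at a time, conditioning on each code it has already produced. `fast_ar_decode`, `StepFn` and `toy_step` are hypothetical stand-ins, not the FishSpeechFastAR API.

```python
from typing import Callable, Sequence

# Hypothetical step function: (context, codebook_index) -> logits over that codebook's codes.
StepFn = Callable[[list, int], Sequence[float]]


def fast_ar_decode(
    slow_hidden: object,
    first_code: int,
    num_codebooks: int,
    step_fn: StepFn,
) -> list[int]:
    """Greedy sketch: start from the Slow AR hidden state and the first code,
    then pick codebooks 1..num_codebooks-1 one at a time, appending each
    chosen code to the context before predicting the next one."""
    context: list = [slow_hidden, first_code]
    codes = [first_code]
    for i in range(1, num_codebooks):
        logits = step_fn(context, i)
        code = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
        codes.append(code)
        context.append(code)
    return codes


def toy_step(context: list, codebook_index: int) -> list[float]:
    # Toy stand-in for a Fast AR forward pass: prefers (previous code + 1) mod 4.
    nxt = (context[-1] + 1) % 4
    return [1.0 if c == nxt else 0.0 for c in range(4)]


codes = fast_ar_decode(slow_hidden=[0.1, 0.2], first_code=0, num_codebooks=4, step_fn=toy_step)
```

A real implementation would sample with temperature/top-p and keep a KV cache instead of re-reading the whole context.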

fast_ar_config instance-attribute

fast_ar_config: FishSpeechFastARConfig = (
    audio_decoder_config
)

gpu_resident_buffer_keys instance-attribute

gpu_resident_buffer_keys: set[tuple[str, str]] = {
    ("hidden_states", "last")
}

has_postprocess instance-attribute

has_postprocess = True

has_preprocess instance-attribute

has_preprocess = True

have_multimodal_outputs instance-attribute

have_multimodal_outputs = True

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size,
    hidden_size,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "lm_head"),
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = Qwen3Model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

model_path instance-attribute

model_path = model

mtp_hidden_size instance-attribute

mtp_hidden_size = int(hidden_size)

talker_mtp_graph_safe instance-attribute

talker_mtp_graph_safe = True

talker_mtp_output_key instance-attribute

talker_mtp_output_key = ('codes', 'audio')

text_config instance-attribute

text_config: FishSpeechSlowARConfig = text_config

vllm_config instance-attribute

vllm_config = vllm_config

compute_logits

compute_logits(
    hidden_states: Tensor | OmniOutput,
    sampling_metadata: Any = None,
) -> Tensor | None

embed_input_ids

embed_input_ids(input_ids: Tensor, **_: Any) -> Tensor

estimate_prompt_len_from_additional_information staticmethod

estimate_prompt_len_from_additional_information(
    additional_information: dict[str, Any] | None,
    **kwargs: Any,
) -> int

Estimate prompt length for placeholder allocation.

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **_: Any,
) -> Tensor | IntermediateTensors

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights with Fish Speech → Qwen3 format transformation.

Transforms weight names (wqkv → q/k/v split; w1/w3/w2 → gate/up/down; etc.) and routes each weight to the correct sub-module.
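The renaming and the wqkv split can be sketched as below. This assumes gpt-fast-style checkpoint keys (`layers.N.attention.wqkv.weight`, `attention.wo`, `feed_forward.w1/w2/w3`, `attention_norm`, `ffn_norm`) and that the fused wqkv weight stacks Q, then K, then V along the output dimension. The exact key names and the `RENAMES`, `rename` and `split_wqkv` helpers are assumptions for illustration, not the vllm_omni loader. Rows are plain lists here; with torch the split would be `weight.split([q_size, kv_size, kv_size], dim=0)`.

```python
# Hypothetical Fish Speech -> Qwen3 name mapping (assumed key layout).
RENAMES = {
    "attention.wo": "self_attn.o_proj",
    "feed_forward.w1": "mlp.gate_proj",   # w1: gate branch of SwiGLU
    "feed_forward.w3": "mlp.up_proj",     # w3: up branch of SwiGLU
    "feed_forward.w2": "mlp.down_proj",   # w2: output projection
    "attention_norm": "input_layernorm",
    "ffn_norm": "post_attention_layernorm",
}


def rename(name: str) -> str:
    """Rewrite the first matching Fish Speech fragment to its Qwen3 name."""
    for old, new in RENAMES.items():
        if old in name:
            return name.replace(old, new)
    return name


def split_wqkv(name: str, rows: list, n_heads: int, n_kv_heads: int, head_dim: int) -> dict:
    """Split a fused [Q; K; V] weight (stacked along rows) into q/k/v_proj entries."""
    q_size = n_heads * head_dim
    kv_size = n_kv_heads * head_dim
    if len(rows) != q_size + 2 * kv_size:
        raise ValueError(f"{name}: expected {q_size + 2 * kv_size} rows, got {len(rows)}")
    prefix, _, suffix = name.partition("attention.wqkv")
    return {
        f"{prefix}self_attn.q_proj{suffix}": rows[:q_size],
        f"{prefix}self_attn.k_proj{suffix}": rows[q_size:q_size + kv_size],
        f"{prefix}self_attn.v_proj{suffix}": rows[q_size + kv_size:],
    }


parts = split_wqkv("layers.0.attention.wqkv.weight", [[0], [1], [2], [3]],
                   n_heads=2, n_kv_heads=1, head_dim=1)
```

In vLLM, Qwen3 itself keeps fused `qkv_proj` / `gate_up_proj` layers, so the split q/k/v and gate/up names are typically re-fused by the model's stacked-parameter loading logic.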

make_omni_output

make_omni_output(
    model_outputs: Tensor | OmniOutput, **kwargs: Any
) -> OmniOutput

postprocess

postprocess(
    hidden_states: Tensor, **_: Any
) -> dict[str, Any]

preprocess

preprocess(
    input_ids: Tensor,
    input_embeds: Tensor | None,
    **info_dict: Any,
) -> tuple[Tensor, Tensor, dict[str, Any]]

talker_mtp

talker_mtp(
    input_ids: Tensor,
    input_embeds: Tensor,
    last_talker_hidden: Tensor,
    text_step: Tensor,
    seed: int | None = None,
    **kwargs: Any,
) -> tuple[Tensor, Tensor]

GPU fast-path: run Fast AR to predict residual codebook codes.

Returns (inputs_embeds, audio_codes).

The returned embedding is text_embed(token) + sum over i of codebook_embed(code_i + i * codebook_size), where the codes code_i come from FastAR(last_talker_hidden).

This matches the reference Fish Speech inference flow:
- At step t, the Slow AR embedding includes codes from FastAR(hidden_{t-1}).
- last_talker_hidden is hidden_{t-1} (from postprocess of the previous step).
- Preprocess provides the plain text embedding; the codebook embeddings are added here.
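The per-step embedding described above can be traced with toy data. This is a minimal sketch using plain Python lists in place of tensors: `step_embedding`, `toy_fast_ar` and the tiny tables are hypothetical, and the toy Fast AR is a deterministic stand-in that only illustrates that the codes are derived from the previous step's hidden state.

```python
from typing import Callable, Mapping, Sequence

Vector = list[float]


def add(a: Sequence[float], b: Sequence[float]) -> Vector:
    return [x + y for x, y in zip(a, b)]


def step_embedding(
    token: int,
    last_hidden: Vector,
    text_embed: Mapping[int, Vector],
    codebook_embed: Sequence[Vector],
    codebook_size: int,
    fast_ar: Callable[[Vector], list[int]],
) -> tuple[Vector, list[int]]:
    codes = fast_ar(last_hidden)        # codes from FastAR(hidden_{t-1})
    emb = list(text_embed[token])       # plain text embedding (from preprocess)
    for i, code in enumerate(codes):    # one row per codebook, offset by i * codebook_size
        emb = add(emb, codebook_embed[code + i * codebook_size])
    return emb, codes                   # (inputs_embeds, audio_codes)


def toy_fast_ar(hidden: Vector) -> list[int]:
    # Hypothetical stand-in: derive one code per codebook from the hidden state.
    return [int(hidden[0]) % 2, int(hidden[1]) % 2]


# Toy data: 2-dim embeddings, 2 codebooks of size 2 -> 4 rows in the flat table.
codebook_table = [[1.0, 0.0], [2.0, 0.0], [0.0, 10.0], [0.0, 20.0]]
text_table = {7: [100.0, 100.0]}

emb, codes = step_embedding(7, [1.0, 0.0], text_table, codebook_table, 2, toy_fast_ar)
```

Here the codes are `[1, 0]`, so the text row is summed with flat rows 1 and 2 of the shared codebook table.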