vllm_omni.model_executor.models.fish_speech.fish_speech_slow_ar
Fish Speech S2 Pro -- Slow AR model (Stage 0).
Uses vLLM's Qwen3Model as the transformer backbone. Adds:
- Multi-codebook input embedding (text + summed codebook embeddings at semantic-token positions).
- Semantic logit masking.
- Nested Fast AR for residual codebook prediction (talker_mtp).
- preprocess / postprocess hooks for vLLM-omni's AR scheduler.
Analogous to Qwen3TTSTalkerForConditionalGeneration in qwen3_tts.
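This page does not spell out how semantic logit masking works. A minimal sketch follows. It assumes that semantic tokens occupy one contiguous id range and that masking means forcing every other logit to -inf, optionally leaving a few extra ids allowed, such as an end-of-sequence token. `mask_to_semantic`, `greedy_pick`, `semantic_begin_id`, `semantic_end_id` and `allowed_extra_ids` are hypothetical names and are not part of vllm_omni. The real compute_logits operates on batched tensors, not Python lists.

```python
import math
from collections.abc import Iterable, Sequence


def mask_to_semantic(
    logits: Sequence[float],
    semantic_begin_id: int,
    semantic_end_id: int,
    allowed_extra_ids: Iterable[int] = (),
) -> list[float]:
    """Copy `logits`, setting ids outside [semantic_begin_id, semantic_end_id]
    (inclusive) to -inf unless they appear in `allowed_extra_ids`."""
    extra = set(allowed_extra_ids)
    return [
        value
        if semantic_begin_id <= token_id <= semantic_end_id or token_id in extra
        else -math.inf
        for token_id, value in enumerate(logits)
    ]


def greedy_pick(logits: Sequence[float]) -> int:
    """Index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [5.0, 0.1, 0.3, 0.2, 9.0, 0.0]
masked = mask_to_semantic(logits, semantic_begin_id=1, semantic_end_id=3, allowed_extra_ids=[5])
print(greedy_pick(logits), greedy_pick(masked))  # 4 2
```

Without the mask, the largest logit (id 4) lies outside the assumed semantic range and would be selected. After masking, only ids 1 to 3 and the extra id 5 remain eligible. In a tensor implementation, the same effect is usually achieved with a single masked fill over the vocabulary dimension.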
FishSpeechSlowARForConditionalGeneration
Bases: Module
vLLM-AR Slow AR model for Fish Speech S2 Pro.
Stage 0: text → semantic tokens (+ residual codebook codes via Fast AR).
codebook_embeddings instance-attribute
fast_ar instance-attribute
fast_ar = FishSpeechFastAR(
vllm_config=_fast_ar_vllm_config,
config=fast_ar_config,
slow_ar_config=text_config,
prefix="fast_ar",
)
fast_ar_config instance-attribute
fast_ar_config: FishSpeechFastARConfig = (
audio_decoder_config
)
gpu_resident_buffer_keys instance-attribute
lm_head instance-attribute
lm_head = ParallelLMHead(
vocab_size,
hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
make_empty_intermediate_tensors instance-attribute
model instance-attribute
compute_logits
compute_logits(
hidden_states: Tensor | OmniOutput,
sampling_metadata: Any = None,
) -> Tensor | None
estimate_prompt_len_from_additional_information staticmethod
estimate_prompt_len_from_additional_information(
additional_information: dict[str, Any] | None,
**kwargs: Any,
) -> int
Estimate prompt length for placeholder allocation.
forward
forward(
input_ids: Tensor,
positions: Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: Tensor | None = None,
**_: Any,
) -> Tensor | IntermediateTensors
load_weights
Load weights with Fish Speech → Qwen3 format transformation.
Transforms weight names (wqkv → q/k/v split, w1/w2/w3 → gate/up/down, etc.) and routes to the correct sub-modules.
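The following is one possible shape for such a rename-and-split step, offered as a sketch only. It assumes Fish-Speech-style parameter names (`attention.wqkv`, `attention.wo`, `feed_forward.w1/w2/w3`, `attention_norm`, `ffn_norm`). It also assumes the LLaMA feed-forward convention `w2(silu(w1(x)) * w3(x))`, under which w1 is the gate, w3 the up projection and w2 the down projection. Finally, it assumes a fused wqkv laid out as q rows, then k rows, then v rows. The `convert` function and the `RENAMES` table are hypothetical. Tensors are plain lists of rows so the example runs without PyTorch. vLLM's Qwen3 layers fuse q/k/v into one qkv_proj parameter and gate/up into gate_up_proj, so the real loader may route shards into those fused parameters instead of materializing separate tensors.

```python
from collections.abc import Iterator
from typing import Any

# Hypothetical rename table: Fish-Speech-style fragment -> Qwen3/HF-style fragment.
# Assumes the LLaMA FFN convention: out = w2(silu(w1(x)) * w3(x)).
RENAMES = {
    "attention.wo": "self_attn.o_proj",
    "feed_forward.w1": "mlp.gate_proj",
    "feed_forward.w3": "mlp.up_proj",
    "feed_forward.w2": "mlp.down_proj",
    "attention_norm": "input_layernorm",
    "ffn_norm": "post_attention_layernorm",
}


def convert(
    name: str, tensor: Any, num_heads: int, num_kv_heads: int, head_dim: int
) -> Iterator[tuple[str, Any]]:
    """Yield (new_name, tensor) pairs; dim 0 of `tensor` is the output features."""
    if ".attention.wqkv." in name:
        q_rows = num_heads * head_dim
        kv_rows = num_kv_heads * head_dim
        if len(tensor) != q_rows + 2 * kv_rows:
            raise ValueError(f"unexpected fused qkv size {len(tensor)} for {name}")
        prefix, suffix = name.split(".attention.wqkv.")
        # Split the fused rows into q, k and v (grouped-query attention sizes).
        yield f"{prefix}.self_attn.q_proj.{suffix}", tensor[:q_rows]
        yield f"{prefix}.self_attn.k_proj.{suffix}", tensor[q_rows : q_rows + kv_rows]
        yield f"{prefix}.self_attn.v_proj.{suffix}", tensor[q_rows + kv_rows :]
        return
    for old, new in RENAMES.items():
        if f".{old}." in name:
            yield name.replace(f".{old}.", f".{new}.", 1), tensor
            return
    yield name, tensor  # names with no mapping pass through unchanged


rows = [[float(i)] for i in range(8)]  # fused wqkv: 2 q heads, 1 kv head, head_dim 2
for new_name, part in convert("layers.0.attention.wqkv.weight", rows, 2, 1, 2):
    print(new_name, len(part))
```

The loop prints three names (q_proj with 4 rows, k_proj and v_proj with 2 rows each). The size check makes a mismatched head configuration fail loudly instead of silently producing mis-shaped projections.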
make_omni_output
make_omni_output(
model_outputs: Tensor | OmniOutput, **kwargs: Any
) -> OmniOutput
preprocess
preprocess(
input_ids: Tensor,
input_embeds: Tensor | None,
**info_dict: Any,
) -> tuple[Tensor, Tensor, dict[str, Any]]
talker_mtp
talker_mtp(
input_ids: Tensor,
input_embeds: Tensor,
last_talker_hidden: Tensor,
text_step: Tensor,
seed: int | None = None,
**kwargs: Any,
) -> tuple[Tensor, Tensor]
GPU fast-path: run Fast AR to predict residual codebook codes.
Returns (inputs_embeds, audio_codes).
The embedding is: text_embed(token) + sum(codebook_embed(code_i + i * codebook_size)) where codes come from FastAR(last_talker_hidden).
This matches the reference Fish Speech inference flow:
- At step t, the Slow AR embedding includes codes from FastAR(hidden_{t-1}).
- last_talker_hidden IS hidden_{t-1} (stored by postprocess at the previous step).
- Preprocess provides the plain text embed; we add the codebook embeddings here.
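The embedding formula for talker_mtp can be sketched with plain Python lists. The sketch assumes that all codebooks share one embedding table with `num_codebooks * codebook_size` rows, so the code for codebook i is looked up at row `code_i + i * codebook_size`. Under the flow above, `codes` would be the Fast AR output for hidden_{t-1}. `slow_ar_input_embedding` and the list-based tables are illustrative only, not the module's API.

```python
def slow_ar_input_embedding(
    token: int,
    codes: list[int],
    text_embed: list[list[float]],
    codebook_embed: list[list[float]],
    codebook_size: int,
) -> list[float]:
    """text_embed(token) + sum_i codebook_embed(codes[i] + i * codebook_size)."""
    out = list(text_embed[token])  # copy so the table row is not mutated
    for i, code in enumerate(codes):
        if not 0 <= code < codebook_size:
            raise ValueError(f"code {code} out of range for codebook {i}")
        row = codebook_embed[code + i * codebook_size]  # offset into the shared table
        out = [a + b for a, b in zip(out, row)]
    return out


# Tiny example: 2 codebooks of size 4, embedding dim 2.
text_embed = [[0.0, 0.0], [10.0, 20.0]]
codebook_embed = [[float(k), 0.0] for k in range(2 * 4)]  # row k = [k, 0]
print(slow_ar_input_embedding(1, [3, 1], text_embed, codebook_embed, 4))
# [18.0, 20.0]: rows 3 and 1 + 4 = 5 are added to the text row
```

The per-codebook offset keeps equal code values in different codebooks from colliding. Without it, code 1 in codebook 1 would reuse the row for code 1 in codebook 0.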