vllm_omni.model_executor.models.indextts2.indextts2_talker
IndexTTS2 Stage 0: GPT-2 AR Talker with vLLM-native PagedAttention.
Predicts mel codes autoregressively and collects the hidden_states as the latent input for Stage 1 (S2Mel + BigVGAN).
IndexTTS2TalkerForConditionalGeneration
Bases: Module
vLLM-native GPT-2 AR talker for IndexTTS2.
Stage 0 of the two-stage pipeline. Predicts mel codes (vocabulary size 8194) and accumulates hidden_states as the latent for the Stage 1 S2Mel decoder.
cond_mask_pad instance-attribute
condition_num_latent instance-attribute
conditioning_encoder instance-attribute
conditioning_encoder = ConformerEncoder(
    input_size=1024,
    output_size=cond_cfg.get("output_size", 512),
    linear_units=cond_cfg.get("linear_units", 2048),
    attention_heads=cond_cfg.get("attention_heads", 8),
    num_blocks=cond_cfg.get("num_blocks", 6),
    input_layer=cond_cfg.get("input_layer", "conv2d2"),
)
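Each Conformer hyperparameter is read from the conditioning config with a fallback default, and the conditioning and emotion encoders differ only in those defaults. A minimal sketch of how the defaults resolve, assuming the config is a plain dict; resolve_conformer_kwargs is a hypothetical helper, not part of the module.

```python
from typing import Any

# Fallback values copied from the two ConformerEncoder constructions on this page.
COND_DEFAULTS: dict[str, Any] = {
    "output_size": 512,
    "linear_units": 2048,
    "attention_heads": 8,
    "num_blocks": 6,
    "input_layer": "conv2d2",
}
EMO_DEFAULTS: dict[str, Any] = {
    "output_size": 512,
    "linear_units": 1024,
    "attention_heads": 4,
    "num_blocks": 4,
    "input_layer": "conv2d2",
}


def resolve_conformer_kwargs(cfg: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: take each key from cfg if present, else from defaults."""
    kwargs: dict[str, Any] = {"input_size": 1024}  # hard-coded, never read from cfg
    for key, default in defaults.items():
        kwargs[key] = cfg.get(key, default)
    return kwargs


# Overriding only num_blocks leaves every other default in place.
print(resolve_conformer_kwargs({"num_blocks": 4}, COND_DEFAULTS))
```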
emo_conditioning_encoder instance-attribute
emo_conditioning_encoder = ConformerEncoder(
    input_size=1024,
    output_size=emo_cond_cfg.get("output_size", 512),
    linear_units=emo_cond_cfg.get("linear_units", 1024),
    attention_heads=emo_cond_cfg.get("attention_heads", 4),
    num_blocks=emo_cond_cfg.get("num_blocks", 4),
    input_layer=emo_cond_cfg.get("input_layer", "conv2d2"),
)
emo_perceiver_encoder instance-attribute
emo_perceiver_encoder = PerceiverResampler(
    1024,
    dim_context=emo_cond_cfg.get("output_size", 512),
    ff_mult=emo_cond_cfg.get("perceiver_mult", 2),
    heads=emo_cond_cfg.get("attention_heads", 4),
    num_latents=1,
)
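With num_latents=1, the emotion perceiver resamples a variable-length sequence of emotion-encoder frames into a single vector. PerceiverResampler's internals are not documented on this page, so the sketch below only illustrates the generic idea of one learned query attending over a context. It omits projections, multiple heads, and feed-forward layers, and it assumes query and context share a dimension (the real module maps dim_context to 1024). attend_single_latent is a hypothetical name.

```python
import math


def attend_single_latent(latent: list[float], context: list[list[float]]) -> list[float]:
    """Pool T context vectors into one vector using a single query (hypothetical sketch).

    latent:  the learned query vector, dimension d.
    context: T vectors of dimension d; used as both keys and values here.
    """
    d = len(latent)
    # Scaled dot-product score between the query and each context row.
    scores = [sum(q * k for q, k in zip(latent, row)) / math.sqrt(d) for row in context]
    # Numerically stable softmax over the T positions.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Attention-weighted sum of the context rows.
    return [sum(w * row[j] for w, row in zip(weights, context)) for j in range(d)]


# Three context frames collapse to one 2-dim vector, weighted toward [1, 0].
pooled = attend_single_latent([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
print(pooled)
```

The output length is fixed by the query, not by T, which is why a single latent yields one emotion vector regardless of reference-audio length.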
enable_update_additional_information instance-attribute
final_norm instance-attribute
gpu_resident_buffer_keys instance-attribute
gpu_resident_buffer_keys: set[tuple[str, str]] = {
    ("codes", "mel"),
    ("hidden_states", "latent"),
    ("meta", "mel_start_offset"),
    ("meta", "latent_acc"),
    ("meta", "mel_code_count"),
}
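gpu_resident_buffer_keys is a set of (section, key) pairs. How the runner consumes it is not described here; assuming each pair addresses an entry payload[section][key] in a nested dict, a membership-based split could look like the sketch below. split_payload and the payload shape are hypothetical.

```python
from typing import Any

GPU_RESIDENT_BUFFER_KEYS: set[tuple[str, str]] = {
    ("codes", "mel"),
    ("hidden_states", "latent"),
    ("meta", "mel_start_offset"),
    ("meta", "latent_acc"),
    ("meta", "mel_code_count"),
}


def split_payload(
    payload: dict[str, dict[str, Any]], resident: set[tuple[str, str]]
) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
    """Hypothetical: partition payload[section][key] entries by membership in resident."""
    keep: dict[str, dict[str, Any]] = {}
    other: dict[str, dict[str, Any]] = {}
    for section, entries in payload.items():
        for key, value in entries.items():
            target = keep if (section, key) in resident else other
            target.setdefault(section, {})[key] = value
    return keep, other


payload = {"codes": {"mel": [1, 2]}, "meta": {"mel_code_count": 2, "debug": "x"}}
keep, other = split_payload(payload, GPU_RESIDENT_BUFFER_KEYS)
print(keep)   # {'codes': {'mel': [1, 2]}, 'meta': {'mel_code_count': 2}}
print(other)  # {'meta': {'debug': 'x'}}
```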
make_empty_intermediate_tensors instance-attribute
make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], self.model_dim
    )
)
mel_embedding instance-attribute
mel_head instance-attribute
mel_head = ParallelLMHead(
    self.number_mel_codes,
    self.model_dim,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "mel_head"),
)
mel_pos_embedding instance-attribute
mel_pos_embedding = LearnedPositionEmbeddings(
    self.max_mel_tokens + 2 + 1, self.model_dim
)
number_text_tokens instance-attribute
perceiver_encoder instance-attribute
perceiver_encoder = PerceiverResampler(
    self.model_dim,
    dim_context=cond_cfg.get("output_size", 512),
    ff_mult=cond_cfg.get("perceiver_mult", 2),
    heads=cond_cfg.get("attention_heads", 8),
    num_latents=self.condition_num_latent,
)
text_embedding instance-attribute
text_pos_embedding instance-attribute
text_pos_embedding = LearnedPositionEmbeddings(
    self.max_text_tokens + 2, self.model_dim
)
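The two positional tables are sized max_text_tokens + 2 and max_mel_tokens + 2 + 1 rows. A worked sizing check, assuming LearnedPositionEmbeddings wraps an ordinary embedding table with that many rows (valid indices 0 to rows - 1). The limits 600 and 1800 are illustrative only, not values read from a real config; pos_table_rows and fits are hypothetical helpers.

```python
def pos_table_rows(max_text_tokens: int, max_mel_tokens: int) -> tuple[int, int]:
    """Row counts passed to the text and mel LearnedPositionEmbeddings tables."""
    text_rows = max_text_tokens + 2
    mel_rows = max_mel_tokens + 2 + 1
    return text_rows, mel_rows


def fits(position: int, rows: int) -> bool:
    # Assumption: a table with `rows` entries accepts indices 0 .. rows - 1.
    return 0 <= position < rows


# Illustrative limits only.
text_rows, mel_rows = pos_table_rows(max_text_tokens=600, max_mel_tokens=1800)
print(text_rows, mel_rows)  # 602 1803
```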
compute_logits
compute_logits(
    hidden_states: Tensor | OmniOutput,
    sampling_metadata: Any = None,
) -> Tensor | None
forward
forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: Any,
) -> Tensor | IntermediateTensors
AR transformer forward.
During prefill, inputs_embeds is set by preprocess() and input_ids is ignored. During decode, inputs_embeds is mel_embedding(input_ids) + mel_pos.
load_weights
Load weights from the IndexTTS2 checkpoint (gpt.pth).
Bypasses vLLM's default .pt iterator because the model directory contains raw-tensor files (feat1.pt, feat2.pt) that are not state dicts and would crash pt_weights_iterator.
Weight mapping from the gpt.pth checkpoint to vLLM model parameters:
- gpt.h.{i}. → h.{i}. (strip the gpt. prefix)
- gpt.ln_f. → ln_f.
- final_norm. → final_norm. (unchanged)
- text_head.* → (skipped)
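The mapping amounts to a renaming pass over checkpoint keys. A minimal sketch, assuming keys arrive as plain strings; map_checkpoint_key is hypothetical, and passing unlisted keys through unchanged is an assumption the docstring does not state.

```python
from typing import Optional


def map_checkpoint_key(name: str) -> Optional[str]:
    """Hypothetical: rename one gpt.pth key; return None for keys to skip."""
    if name.startswith("text_head."):
        return None  # text_head.* is not loaded
    if name.startswith(("gpt.h.", "gpt.ln_f.")):
        return name[len("gpt."):]  # strip the "gpt." prefix
    # final_norm.* keeps its name; any other key is passed through (assumption).
    return name


for key in ["gpt.h.0.attn.c_attn.weight", "gpt.ln_f.bias", "final_norm.weight", "text_head.weight"]:
    print(key, "->", map_checkpoint_key(key))
```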
make_omni_output
make_omni_output(
    model_outputs: Tensor | OmniOutput, **kwargs: Any
) -> OmniOutput
Collect mel_codes and hidden_states from the intermediate buffer.
postprocess
postprocess(
    hidden_states: Tensor,
    multimodal_outputs: Any = None,
    **kwargs: Any,
) -> dict[str, Any]
Store the current decode step's hidden state for Stage 1.
Only the current row is kept; the connector's full-payload accumulator reconstructs the complete latent sequence from the per-step deltas emitted by make_omni_output.
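Because only one hidden-state row leaves the talker per decode step, the full latent has to be rebuilt downstream. A minimal sketch of that reconstruction, assuming deltas are keyed by request ID and arrive in step order; LatentAccumulator and its methods are hypothetical stand-ins, not the connector's actual API.

```python
class LatentAccumulator:
    """Hypothetical stand-in for a full-payload accumulator.

    Each decode step contributes one hidden-state row (a delta); the
    accumulator concatenates rows into the complete latent sequence.
    """

    def __init__(self) -> None:
        self._rows: dict[str, list[list[float]]] = {}

    def add_delta(self, request_id: str, row: list[float]) -> None:
        # Append this step's row after all earlier rows for the request.
        self._rows.setdefault(request_id, []).append(row)

    def full_latent(self, request_id: str) -> list[list[float]]:
        # Return a copy so callers cannot mutate the stored sequence.
        return [list(r) for r in self._rows.get(request_id, [])]


acc = LatentAccumulator()
for step in range(3):
    acc.add_delta("req-0", [float(step), float(step) * 2])
print(acc.full_latent("req-0"))  # [[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]]
```

Emitting deltas keeps each step's transfer to a single row instead of resending the growing sequence.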
preprocess
preprocess(
    input_ids: Tensor,
    input_embeds: Tensor | None,
    **info_dict: Any,
) -> tuple[Tensor, Tensor, dict[str, Any]]
Build prompt embeddings for prefill; compute mel embeddings for decode.
Prefill layout::

    [conds(32) + emo_vec(1) + duration(2)] [text_emb + text_pos] [start_mel + mel_pos(0)]

Decode: mel_embedding(token) + mel_pos_embedding(step).
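The prefill layout puts a 35-row prefix (32 conditioning latents, 1 emotion vector, 2 duration rows) ahead of the text span, followed by one start_mel row. A sketch that computes each segment's row range and the per-step decode input, with plain lists standing in for tensors. prefill_spans and decode_embed are hypothetical names, the embedding tables in the test are toy values, and num_text_tokens simply counts whatever text rows the tokenizer produces.

```python
Vector = list[float]

# Prefix sizes from the documented prefill layout.
NUM_CONDS = 32
NUM_EMO = 1
NUM_DURATION = 2


def prefill_spans(num_text_tokens: int) -> dict[str, tuple[int, int]]:
    """Return the [start, end) row range of each prefill segment."""
    spans: dict[str, tuple[int, int]] = {}
    cursor = 0
    for name, size in (
        ("conds", NUM_CONDS),
        ("emo_vec", NUM_EMO),
        ("duration", NUM_DURATION),
        ("text", num_text_tokens),
        ("start_mel", 1),  # start_mel embedding + mel_pos(0)
    ):
        spans[name] = (cursor, cursor + size)
        cursor += size
    return spans


def decode_embed(token: int, step: int, mel_table: list[Vector], mel_pos_table: list[Vector]) -> Vector:
    """Decode input: mel_embedding(token) + mel_pos_embedding(step), element-wise."""
    return [t + p for t, p in zip(mel_table[token], mel_pos_table[step])]


print(prefill_spans(num_text_tokens=10))
```

With 10 text rows, the text span occupies rows 35 to 44 and start_mel sits at row 45, for a prefill length of 46.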