vllm_omni.model_executor.models.indextts2.indextts2_talker

IndexTTS2 Stage 0: GPT-2 AR Talker with vLLM-native PagedAttention.

Predicts mel codes autoregressively and collects hidden_states as latent for Stage 1 (S2Mel + BigVGAN).

logger module-attribute

logger = init_logger(__name__)

IndexTTS2TalkerForConditionalGeneration

Bases: Module

vLLM-native GPT-2 AR talker for IndexTTS2.

Stage 0 of the two-stage pipeline. Predicts mel codes (vocabulary size 8194) and accumulates hidden_states as the latent input for the Stage 1 S2Mel decoder.

cond_mask_pad instance-attribute

cond_mask_pad = nn.ConstantPad1d(
    (self.condition_num_latent, 0), True
)
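As a rough illustration of what this padding does to one mask row, the sketch below emulates a left pad of `condition_num_latent` `True` entries in plain Python. The helper name `left_pad_mask` is hypothetical, and the sketch assumes the mask marks attendable positions with `True`; the module itself relies on `nn.ConstantPad1d`.

```python
def left_pad_mask(mask_row, num_latent=32):
    """Emulate ConstantPad1d((num_latent, 0), True) on a single mask row.

    The padding tuple is (left, right), so only the left side grows.
    """
    return [True] * num_latent + list(mask_row)


# Hypothetical text mask: three valid positions followed by one padded one.
text_mask = [True, True, True, False]
padded = left_pad_mask(text_mask, num_latent=32)

print(len(padded))  # 36: 32 conditioning slots + 4 original positions
```

The original entries keep their values and order; only the prepended conditioning slots are forced to `True`.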

condition_num_latent instance-attribute

condition_num_latent = gpt_cfg.get(
    "condition_num_latent", 32
)

conditioning_encoder instance-attribute

conditioning_encoder = ConformerEncoder(
    input_size=1024,
    output_size=cond_cfg.get("output_size", 512),
    linear_units=cond_cfg.get("linear_units", 2048),
    attention_heads=cond_cfg.get("attention_heads", 8),
    num_blocks=cond_cfg.get("num_blocks", 6),
    input_layer=cond_cfg.get("input_layer", "conv2d2"),
)

config instance-attribute

config: IndexTTS2Config = vllm_config.model_config.hf_config

emo_cond_mask_pad instance-attribute

emo_cond_mask_pad = nn.ConstantPad1d((1, 0), True)

emo_conditioning_encoder instance-attribute

emo_conditioning_encoder = ConformerEncoder(
    input_size=1024,
    output_size=emo_cond_cfg.get("output_size", 512),
    linear_units=emo_cond_cfg.get("linear_units", 1024),
    attention_heads=emo_cond_cfg.get("attention_heads", 4),
    num_blocks=emo_cond_cfg.get("num_blocks", 4),
    input_layer=emo_cond_cfg.get("input_layer", "conv2d2"),
)

emo_layer instance-attribute

emo_layer = nn.Linear(self.model_dim, self.model_dim)

emo_perceiver_encoder instance-attribute

emo_perceiver_encoder = PerceiverResampler(
    1024,
    dim_context=emo_cond_cfg.get("output_size", 512),
    ff_mult=emo_cond_cfg.get("perceiver_mult", 2),
    heads=emo_cond_cfg.get("attention_heads", 4),
    num_latents=1,
)

emovec_layer instance-attribute

emovec_layer = nn.Linear(1024, self.model_dim)

enable_update_additional_information instance-attribute

enable_update_additional_information = True

final_norm instance-attribute

final_norm = nn.LayerNorm(
    self.model_dim, eps=gpt2_config.layer_norm_epsilon
)

gpu_resident_buffer_keys instance-attribute

gpu_resident_buffer_keys: set[tuple[str, str]] = {
    ("codes", "mel"),
    ("hidden_states", "latent"),
    ("meta", "mel_start_offset"),
    ("meta", "latent_acc"),
    ("meta", "mel_code_count"),
}

has_postprocess instance-attribute

has_postprocess = True

has_preprocess instance-attribute

has_preprocess = True

have_multimodal_outputs instance-attribute

have_multimodal_outputs = True

ln_f instance-attribute

ln_f = nn.LayerNorm(
    self.model_dim, eps=gpt2_config.layer_norm_epsilon
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(self.number_mel_codes)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], self.model_dim
    )
)

max_mel_tokens instance-attribute

max_mel_tokens = gpt_cfg['max_mel_tokens']

max_text_tokens instance-attribute

max_text_tokens = gpt_cfg['max_text_tokens']

mel_embedding instance-attribute

mel_embedding = nn.Embedding(
    self.number_mel_codes, self.model_dim
)

mel_head instance-attribute

mel_head = ParallelLMHead(
    self.number_mel_codes,
    self.model_dim,
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "mel_head"),
)

mel_pos_embedding instance-attribute

mel_pos_embedding = LearnedPositionEmbeddings(
    self.max_mel_tokens + 2 + 1, self.model_dim
)
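The table size follows directly from the expression above. A minimal arithmetic sketch, assuming an illustrative `max_mel_tokens` of 1815 (a placeholder chosen for this example; the real value is read from `gpt_cfg['max_mel_tokens']`):

```python
def mel_position_rows(max_mel_tokens):
    """Number of rows in the learned mel position table: max_mel_tokens + 2 + 1."""
    return max_mel_tokens + 2 + 1


# Illustrative value only, not taken from a checkpoint.
rows = mel_position_rows(1815)
print(rows)  # 1818
```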

model_dim instance-attribute

model_dim = gpt_cfg['model_dim']

model_path instance-attribute

model_path = vllm_config.model_config.model

num_heads instance-attribute

num_heads = gpt_cfg['heads']

num_layers instance-attribute

num_layers = gpt_cfg['layers']

number_mel_codes instance-attribute

number_mel_codes = gpt_cfg['number_mel_codes']

number_text_tokens instance-attribute

number_text_tokens = gpt_cfg.get(
    "number_text_tokens", 12000
)

omni_payload_at_request_end instance-attribute

omni_payload_at_request_end = False

omni_request_end_token_ids instance-attribute

omni_request_end_token_ids = ()

perceiver_encoder instance-attribute

perceiver_encoder = PerceiverResampler(
    self.model_dim,
    dim_context=cond_cfg.get("output_size", 512),
    ff_mult=cond_cfg.get("perceiver_mult", 2),
    heads=cond_cfg.get("attention_heads", 8),
    num_latents=self.condition_num_latent,
)

requires_raw_input_tokens instance-attribute

requires_raw_input_tokens = True

speed_emb instance-attribute

speed_emb = nn.Embedding(2, self.model_dim)

start_mel_token instance-attribute

start_mel_token = gpt_cfg['start_mel_token']

stop_mel_token instance-attribute

stop_mel_token = gpt_cfg['stop_mel_token']

text_embedding instance-attribute

text_embedding = nn.Embedding(
    self.number_text_tokens + 1, self.model_dim
)

text_pos_embedding instance-attribute

text_pos_embedding = LearnedPositionEmbeddings(
    self.max_text_tokens + 2, self.model_dim
)

vllm_config instance-attribute

vllm_config = vllm_config

compute_logits

compute_logits(
    hidden_states: Tensor | OmniOutput,
    sampling_metadata: Any = None,
) -> Tensor | None

embed_input_ids

embed_input_ids(input_ids: Tensor, **_: Any) -> Tensor

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: Any,
) -> Tensor | IntermediateTensors

AR transformer forward.

During prefill, inputs_embeds is set by preprocess() and input_ids is ignored. During decode, inputs_embeds is mel_embedding(input_ids) + mel_pos.
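The decode-time sum can be sketched with toy lookup tables in plain Python. The function name `decode_embedding` and the two-dimensional tables are hypothetical stand-ins for `mel_embedding` and `mel_pos_embedding`; the real module operates on tensors of width `model_dim`.

```python
def decode_embedding(token, step, mel_table, pos_table):
    """Return mel_table[token] + pos_table[step], element by element."""
    tok_vec = mel_table[token]  # token embedding lookup
    pos_vec = pos_table[step]   # learned position lookup for this decode step
    return [t + p for t, p in zip(tok_vec, pos_vec)]


# Toy tables: 3 mel codes and 2 positions, embedding width 2.
mel_table = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
pos_table = [[0.5, 0.5], [1.0, -1.0]]

emb = decode_embedding(token=2, step=1, mel_table=mel_table, pos_table=pos_table)
print(emb)  # [5.0, 4.0]
```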

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights from IndexTTS2 checkpoint (gpt.pth).

Bypasses vllm's default .pt iterator because the model directory contains raw-tensor files (feat1.pt, feat2.pt) that are not state dicts and would crash pt_weights_iterator.

Weight mapping from the gpt.pth checkpoint → vLLM model params:

    gpt.h.{i}.   → h.{i}.       (strip gpt. prefix)
    gpt.ln_f.    → ln_f.
    final_norm.  → final_norm.
    text_head.*  → (skipped) → identity
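The name mapping above can be sketched as a small pure-Python helper. `remap_checkpoint_key` is a hypothetical name, not the function used by `load_weights`, and passing unlisted names through unchanged is an assumption of this sketch.

```python
from typing import Optional


def remap_checkpoint_key(name: str) -> Optional[str]:
    """Map a gpt.pth parameter name to a model parameter name.

    Returns None for names that should be skipped.
    """
    if name.startswith("text_head."):
        return None  # listed as skipped in the mapping
    for prefix in ("gpt.h.", "gpt.ln_f."):
        if name.startswith(prefix):
            return name[len("gpt."):]  # strip only the leading "gpt."
    # final_norm.* maps to itself; other names pass through (assumption).
    return name


print(remap_checkpoint_key("gpt.h.3.attn.c_attn.weight"))  # h.3.attn.c_attn.weight
```

Returning `None` for skipped names lets a caller filter them out in the same loop that renames the rest.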

make_omni_output

make_omni_output(
    model_outputs: Tensor | OmniOutput, **kwargs: Any
) -> OmniOutput

Collect mel_codes and hidden_states from intermediate buffer.

postprocess

postprocess(
    hidden_states: Tensor,
    multimodal_outputs: Any = None,
    **kwargs: Any,
) -> dict[str, Any]

Store current decode-step hidden state for Stage 1.

Only the current row is kept; the connector's full-payload accumulator reconstructs the complete latent sequence from per-step deltas emitted by make_omni_output.
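To illustrate the idea of rebuilding a full sequence from per-step deltas, here is a minimal sketch in plain Python. The class `LatentAccumulator` and its methods are hypothetical; the connector's actual accumulator is not shown in this page and may differ.

```python
class LatentAccumulator:
    """Rebuild a full latent sequence from per-step deltas (illustrative only)."""

    def __init__(self):
        self._rows = []

    def add_delta(self, delta_rows):
        # Each delta carries only the rows produced since the previous step.
        self._rows.extend(delta_rows)

    def full_sequence(self):
        # Return a copy so callers cannot mutate the internal state.
        return list(self._rows)


acc = LatentAccumulator()
for step_row in ([0.1, 0.2], [0.3, 0.4], [0.5, 0.6]):
    acc.add_delta([step_row])  # one hidden-state row per decode step

print(len(acc.full_sequence()))  # 3
```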

preprocess

preprocess(
    input_ids: Tensor,
    input_embeds: Tensor | None,
    **info_dict: Any,
) -> tuple[Tensor, Tensor, dict[str, Any]]

Build prompt embeddings for prefill; compute mel embeddings for decode.

Prefill layout::

[conds(32) + emo_vec(1) + duration(2)] [text_emb + text_pos] [start_mel + mel_pos(0)]

Decode: mel_embedding(token) + mel_pos_embedding(step).
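The prefill layout can be turned into segment offsets with a few lines of arithmetic. This is a sketch only: `prefill_segments` is a hypothetical helper, and `num_text_positions` is assumed to already count whatever special tokens the text segment contains.

```python
def prefill_segments(num_text_positions, condition_num_latent=32):
    """Return ({segment: (start, end)}, total_length) for the prefill prompt."""
    segments = [
        ("conds", condition_num_latent),  # conditioning latents
        ("emo_vec", 1),                   # emotion vector
        ("duration", 2),                  # duration slots
        ("text", num_text_positions),     # text_emb + text_pos
        ("start_mel", 1),                 # start_mel + mel_pos(0)
    ]
    offsets = {}
    cursor = 0
    for name, length in segments:
        offsets[name] = (cursor, cursor + length)  # half-open range
        cursor += length
    return offsets, cursor


offsets, total = prefill_segments(num_text_positions=10)
print(total)            # 46
print(offsets["text"])  # (35, 45)
```

With the default of 32 conditioning latents, the fixed prefix occupies 35 positions before the first text position.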