
vllm_omni.model_executor.models.common.whisper_vq

WhisperVQEncoder: HF WhisperEncoder + VQ codebook + inter-layer pooling.

Built on the standard WhisperConfig. VQ-specific parameters are patched onto the config object after from_pretrained, so callers never need a separate config class. GLM-TTS (and future VQ-based TTS models) only need:

from transformers import WhisperConfig
from vllm_omni.model_executor.models.common.whisper_vq import WhisperVQEncoder

cfg = WhisperConfig.from_pretrained(checkpoint_dir)
cfg.pooling_kernel_size  = None       # passed through from checkpoint
cfg.pooling_type         = "max"      #       or set by caller
cfg.pooling_position     = 0
cfg.quantize_vocab_size  = 32768
cfg.quantize_position    = 16
cfg.quantize_encoder_only = True
model = WhisperVQEncoder(cfg)

QuantizedBaseModelOutput dataclass

Bases: BaseModelOutput

quantized_token_ids class-attribute instance-attribute

quantized_token_ids: LongTensor | None = None
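
The field layout can be pictured with a dependency-free sketch. `FakeBaseModelOutput` is a hypothetical stand-in for Hugging Face's `BaseModelOutput` (whose `last_hidden_state`, `hidden_states` and `attentions` fields are all optional), and plain lists stand in for tensors. The real class lives in this module and may differ in detail.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeBaseModelOutput:
    """Hypothetical stand-in for transformers' BaseModelOutput."""

    last_hidden_state: list | None = None
    hidden_states: tuple | None = None
    attentions: tuple | None = None


@dataclass
class FakeQuantizedOutput(FakeBaseModelOutput):
    """Adds the one extra optional field documented above."""

    quantized_token_ids: list[int] | None = None


# Token ids are present only when the caller (or the encoder) supplies them.
out = FakeQuantizedOutput(last_hidden_state=[[0.1, 0.2]], quantized_token_ids=[7])
plain = FakeQuantizedOutput(last_hidden_state=[[0.1, 0.2]])
```

Because the extra field defaults to None, code that only reads the inherited fields keeps working whether or not token ids were produced.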

WhisperVQEncoder

Bases: WhisperEncoder

HF Whisper encoder with optional VQ codebook and pooling.

Uses a standard WhisperConfig with the following VQ-specific attrs patched on by the caller (or loaded from the checkpoint's config.json):

  • pooling_kernel_size -- int | None
  • pooling_type -- "max" | "avg"
  • pooling_position -- int (0-based layer index)
  • quantize_vocab_size -- int | None
  • quantize_position -- int (0-based layer index)
  • quantize_encoder_only -- bool
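
Since these attrs are patched on rather than declared by a config class, a caller may want to check that they are all present before building the encoder. The following is a minimal sketch of such a check, not part of this module: `read_vq_settings` and `VQ_ATTRS` are hypothetical names, and `SimpleNamespace` stands in for a `WhisperConfig`. Whether the encoder itself validates its config this way is not stated here.

```python
from types import SimpleNamespace

VQ_ATTRS = (
    "pooling_kernel_size",
    "pooling_type",
    "pooling_position",
    "quantize_vocab_size",
    "quantize_position",
    "quantize_encoder_only",
)


def read_vq_settings(cfg):
    """Collect the VQ attrs from a config-like object (hypothetical helper)."""
    missing = [name for name in VQ_ATTRS if not hasattr(cfg, name)]
    if missing:
        raise AttributeError(f"config is missing VQ attrs: {missing}")
    settings = {name: getattr(cfg, name) for name in VQ_ATTRS}
    if settings["pooling_type"] not in ("max", "avg"):
        raise ValueError(f"unsupported pooling_type: {settings['pooling_type']!r}")
    return settings


# SimpleNamespace stands in for a WhisperConfig with the attrs patched on.
cfg = SimpleNamespace(
    pooling_kernel_size=None,
    pooling_type="max",
    pooling_position=0,
    quantize_vocab_size=32768,
    quantize_position=16,
    quantize_encoder_only=True,
)
settings = read_vq_settings(cfg)
```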

codebook instance-attribute

codebook: Embedding | None = None

device property

device: device

dtype property

dtype: dtype

embed_positions2 instance-attribute

embed_positions2: Embedding | None = None

layer_norm instance-attribute

layer_norm = None

layers instance-attribute

layers = ModuleList(list(layers[:qpos]))

pooling_layer instance-attribute

pooling_layer: Module | None = None
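
To illustrate what "max" and "avg" pooling over encoder frames does, here is a dependency-free sketch on plain lists. It assumes that the stride equals the kernel size, that an incomplete trailing window is dropped, and that a kernel size of None disables pooling. None of these details are confirmed by this page, and `pool_frames` is a hypothetical name; the real layer is a torch Module.

```python
def pool_frames(frames, kernel_size, pooling_type="max"):
    """Pool a list of per-frame feature vectors along the time axis.

    Assumes stride == kernel_size and drops an incomplete trailing window.
    """
    if kernel_size is None:
        return frames  # pooling disabled (assumption)
    pooled = []
    for start in range(0, len(frames) - kernel_size + 1, kernel_size):
        window = frames[start:start + kernel_size]
        columns = zip(*window)  # group values by feature dimension
        if pooling_type == "max":
            pooled.append([max(col) for col in columns])
        elif pooling_type == "avg":
            pooled.append([sum(col) / kernel_size for col in columns])
        else:
            raise ValueError(f"unsupported pooling_type: {pooling_type!r}")
    return pooled


# Five frames with two features each; kernel_size=2 yields two pooled frames.
frames = [[1.0, 8.0], [3.0, 6.0], [5.0, 4.0], [7.0, 2.0], [9.0, 0.0]]
max_pooled = pool_frames(frames, kernel_size=2, pooling_type="max")
avg_pooled = pool_frames(frames, kernel_size=2, pooling_type="avg")
```

Either mode shortens the sequence by the kernel factor, so fewer frames reach the later layers and the quantizer.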

forward

forward(
    input_features: Tensor,
    attention_mask: Tensor | None = None,
    **_: Any,
) -> QuantizedBaseModelOutput

load_state_dict

load_state_dict(
    state_dict: dict[str, Tensor],
    strict: bool = True,
    assign: bool = False,
)

remap_legacy_whisper_vq_state_dict

remap_legacy_whisper_vq_state_dict(
    state_dict: dict[str, Tensor],
) -> dict[str, Tensor]

Map CogAudio/GLM-TTS fork parameter names onto HF WhisperEncoder names.
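
The concrete name mapping is not listed on this page. As a rough illustration of this kind of remapping, the sketch below rewrites key prefixes in a state dict. The prefix pairs and the name `remap_keys` are hypothetical, and integers stand in for tensors.

```python
# Hypothetical prefix pairs; the real legacy-to-HF mapping is not listed here.
EXAMPLE_PREFIX_MAP = {
    "blocks.": "layers.",
    "ln_post.": "layer_norm.",
}


def remap_keys(state_dict, prefix_map=EXAMPLE_PREFIX_MAP):
    """Return a new dict with matching key prefixes rewritten."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        if new_key in remapped:
            raise KeyError(f"two source keys map to {new_key!r}")
        remapped[new_key] = value
    return remapped


legacy = {"blocks.0.attn.weight": 1, "ln_post.bias": 2, "conv1.weight": 3}
converted = remap_keys(legacy)
```

Keys that match no prefix pass through unchanged, and the input dict is left unmodified.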

vector_quantize

vector_quantize(
    inputs: Tensor, codebook: Tensor
) -> tuple[Tensor, Tensor, Tensor]

Nearest-neighbour codebook lookup.
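
A nearest-neighbour lookup assigns each input vector the index of the closest codebook row. The sketch below shows the idea on plain lists, assuming squared Euclidean distance. The meaning of the three tensors returned by the real function is not documented here, so this hypothetical `nearest_codes` returns only the indices and the selected codebook rows.

```python
def nearest_codes(inputs, codebook):
    """Return (indices, quantized) using squared Euclidean distance."""
    indices = []
    for vec in inputs:
        distances = [
            sum((a - b) ** 2 for a, b in zip(vec, code)) for code in codebook
        ]
        indices.append(distances.index(min(distances)))  # first minimum wins ties
    quantized = [codebook[i] for i in indices]
    return indices, quantized


codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]]
inputs = [[0.9, 1.2], [3.5, 0.1], [0.1, -0.2]]
indices, quantized = nearest_codes(inputs, codebook)
```

A tensor implementation would compute all pairwise distances at once and take an argmin over the codebook axis; the loop form is used here only for readability.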