vllm_omni.diffusion.models.mistral_encoder

Modules:

Name             Description
mistral_encoder  TP-aware Mistral model for use as a text encoder in diffusion pipelines.

MistralEncoderModel

Bases: Module

TP-aware Mistral encoder for use as a text encoder in diffusion pipelines.

Accepts a HuggingFace Mistral3Config (or its text_config). Uses vLLM parallel layers for tensor parallelism (TP), but computes attention with plain scaled dot-product attention (SDPA) rather than PagedAttention.
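
For reference, SDPA as mentioned above is the standard computation softmax(Q K^T / sqrt(d)) V. The following is a minimal pure-Python sketch for a single head, not this module's implementation; the function name `sdpa` and the boolean `mask` convention are assumptions made for illustration.

```python
import math


def sdpa(q, k, v, mask=None):
    """Single-head scaled dot-product attention on nested lists.

    q: [Tq][d], k: [Tk][d], v: [Tk][dv].
    mask[i][j] is True where query i may attend to key j
    (hypothetical convention; each query needs at least one visible key).
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scaled similarity between this query and every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            scores = [s if mask[i][j] else float("-inf") for j, s in enumerate(scores)]
        # Numerically stable softmax over the key axis.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out
```

Unlike PagedAttention, this formulation needs no block tables or paged KV memory, which is usually sufficient when the model only encodes a prompt once per request.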

config instance-attribute

config = text_config

device property

device: device

dtype property

dtype: dtype

head_dim instance-attribute

head_dim = (
    getattr(text_config, "head_dim", None)
    or hidden_size // num_heads
)
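
The `or` in the expression above means an explicit `head_dim` takes precedence, while a missing, `None`, or zero value falls back to `hidden_size // num_heads`. A small sketch of that behavior, using `types.SimpleNamespace` as a stand-in for the real config; the helper name `resolve_head_dim` and the numeric values are hypothetical.

```python
from types import SimpleNamespace


def resolve_head_dim(text_config):
    # Mirrors the documented expression: an explicit head_dim wins;
    # a missing, None, or 0 value falls back to hidden_size // num_heads.
    return (
        getattr(text_config, "head_dim", None)
        or text_config.hidden_size // text_config.num_attention_heads
    )


# Hypothetical config values, chosen only for illustration.
explicit = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)
derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

print(resolve_head_dim(explicit))  # explicit value, not 5120 // 32
print(resolve_head_dim(derived))   # derived from hidden_size // num_heads
```

Note that with an explicit `head_dim`, `num_heads * head_dim` does not have to equal `hidden_size`.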

hidden_size instance-attribute

hidden_size = hidden_size

intermediate_size instance-attribute

intermediate_size = intermediate_size

language_model instance-attribute

language_model = Module()

max_position_embeddings instance-attribute

max_position_embeddings = getattr(
    text_config, "max_position_embeddings", 131072
)

num_heads instance-attribute

num_heads = num_attention_heads

num_kv_heads instance-attribute

num_kv_heads = getattr(
    text_config, "num_key_value_heads", num_attention_heads
)
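
Because `num_kv_heads` defaults to `num_attention_heads`, a config without `num_key_value_heads` is treated as ordinary multi-head attention, while a smaller value gives grouped-query attention. The sketch below computes how many query heads share each KV head; the helper name `kv_layout` and the divisibility check are assumptions, not part of the documented class.

```python
from types import SimpleNamespace


def kv_layout(text_config):
    """Return (num_kv_heads, query_heads_per_kv_head) for a config-like object."""
    num_heads = text_config.num_attention_heads
    # Same fallback as documented: no num_key_value_heads means
    # one KV head per query head.
    num_kv_heads = getattr(text_config, "num_key_value_heads", num_heads)
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    return num_kv_heads, num_heads // num_kv_heads


# Hypothetical configs for illustration only.
gqa = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)
mha = SimpleNamespace(num_attention_heads=32)
```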

num_layers instance-attribute

num_layers = num_hidden_layers

rms_norm_eps instance-attribute

rms_norm_eps = getattr(text_config, 'rms_norm_eps', 1e-05)

rope_theta instance-attribute

rope_theta = getattr(text_config, 'rope_theta', 1000000.0)

vocab_size instance-attribute

vocab_size = vocab_size

forward

forward(
    input_ids: Tensor,
    attention_mask: Tensor | None = None,
    output_hidden_states: bool = False,
    use_cache: bool = False,
    past_key_values: list[tuple[Tensor, Tensor]] | None = None,
    **kwargs,
) -> MistralEncoderOutput

generate

generate(
    input_ids: Tensor,
    attention_mask: Tensor | None = None,
    max_new_tokens: int = 512,
    do_sample: bool = True,
    temperature: float = 1.0,
    eos_token_id: int | list[int] | None = None,
    **kwargs,
) -> Tensor

Autoregressive text generation with KV caching.

Accepts the keyword arguments of the HuggingFace GenerationMixin.generate interface that the pipeline uses; extra arguments such as pixel_values are accepted and ignored.

Returns the full token sequence including the input prompt.
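
The general shape of such a loop can be sketched in pure Python for a single sequence. This is an illustration under stated assumptions, not the method's actual code: the `step(new_ids, cache)` callback, `sample_token`, and `toy_step` are hypothetical, and whether the real method appends the EOS token to its output is not stated here.

```python
import math
import random


def sample_token(logits, do_sample=True, temperature=1.0):
    """Pick a token id from raw logits (greedy when do_sample is False)."""
    if not do_sample:
        return max(range(len(logits)), key=logits.__getitem__)
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]  # unnormalized softmax
    return random.choices(range(len(logits)), weights=weights, k=1)[0]


def generate(step, input_ids, max_new_tokens=512, do_sample=True,
             temperature=1.0, eos_token_id=None):
    """Toy loop; step(new_ids, cache) -> (logits, cache) is hypothetical."""
    if eos_token_id is None:
        eos_ids = set()
    elif isinstance(eos_token_id, int):
        eos_ids = {eos_token_id}
    else:
        eos_ids = set(eos_token_id)

    tokens = list(input_ids)
    # Prefill: run the whole prompt once and keep the cache.
    logits, cache = step(tokens, None)
    for _ in range(max_new_tokens):
        next_id = sample_token(logits, do_sample, temperature)
        tokens.append(next_id)  # this sketch keeps the EOS token in the output
        if next_id in eos_ids:
            break
        # Decode: feed only the new token; earlier state comes from the cache.
        logits, cache = step([next_id], cache)
    return tokens  # prompt followed by the generated tokens


def toy_step(new_ids, cache):
    """Stand-in model over a 5-token vocabulary: predicts last token + 1."""
    seen = (cache or []) + list(new_ids)
    logits = [0.0] * 5
    logits[(seen[-1] + 1) % 5] = 10.0
    return logits, seen


print(generate(toy_step, [0], max_new_tokens=10, do_sample=False, eos_token_id=3))
```

The point of the cache is that each decode step processes one new token instead of re-encoding the whole sequence.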

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
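
The signature suggests the common pattern of consuming `(name, tensor)` pairs and returning the names of the parameters that were loaded. A minimal sketch of that pattern, using plain lists in place of tensors; the `params` dictionary, the skipping of unknown names, and the length check are assumptions of this sketch rather than documented behavior.

```python
from collections.abc import Iterable


def load_weights(
    params: dict[str, list[float]],
    weights: Iterable[tuple[str, list[float]]],
) -> set[str]:
    """Copy checkpoint values into params and report which names were loaded."""
    loaded: set[str] = set()
    for name, value in weights:
        if name not in params:
            # Assumption: checkpoint entries with no matching parameter are skipped.
            continue
        if len(params[name]) != len(value):
            raise ValueError(f"shape mismatch for {name}")
        params[name][:] = value  # copy in place, keeping the same list object
        loaded.add(name)
    return loaded


# Hypothetical parameter names and values.
params = {"a": [0.0, 0.0], "b": [0.0]}
loaded = load_weights(params, [("a", [1.0, 2.0]), ("extra", [9.0])])
missing = set(params) - loaded  # parameters the checkpoint did not provide
```

Returning the loaded names lets a caller detect parameters that the checkpoint never populated, as with `missing` above.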

set_processor

set_processor(
    processor,
    system_message_t2i: str | None = None,
    system_message_i2i: str | None = None,
) -> None

upsample_prompt

upsample_prompt(
    prompt: str | list[str],
    images: list | None = None,
    temperature: float = 0.15,
    device: device | None = None,
    max_new_tokens: int = 512,
    max_length: int = 2048,
) -> list[str]