Skip to content

vllm.v1.worker.gpu.model_states.interface

Classes:

ModelSpecificAttnMetadata

Base class for model-specific attention metadata.

Source code in vllm/v1/worker/gpu/model_states/interface.py
class ModelSpecificAttnMetadata:
    """Base class for model-specific attention metadata."""

    def get_extra_common_attn_kwargs(
        self,
        kv_cache_group_id: int,
        num_reqs: int,
    ) -> dict[str, Any]:
        return {}

    def get_extra_attn_kwargs(
        self,
        attn_metadata_builder: Any,
        num_reqs: int,
    ) -> dict[str, Any]:
        return {}

ModelState

Bases: ABC

Methods:

Attributes:

Source code in vllm/v1/worker/gpu/model_states/interface.py
class ModelState(ABC):
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.device = device

        self.max_model_len = self.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
        self.dtype = self.model_config.dtype

        self.supports_mm_inputs = encoder_cache is not None
        if encoder_cache is not None:
            self.encoder_cache = encoder_cache
            self.encoder_runner = EncoderRunner(
                model=self.model,
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                encoder_cache=encoder_cache,
                dtype=self.dtype,
                device=self.device,
            )

    def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
        from vllm.model_executor.models.interfaces import (
            supports_realtime,
            supports_transcription,
        )
        from vllm.model_executor.models.interfaces_base import is_text_generation_model

        supported_tasks = list[GenerationTask]()
        if is_text_generation_model(self.model):
            supported_tasks.append("generate")
        if supports_transcription(self.model):
            if self.model.supports_transcription_only:
                return ("transcription",)
            supported_tasks.append("transcription")
        if supports_realtime(self.model):
            supported_tasks.append("realtime")
        return tuple(supported_tasks)

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        return None

    def remove_request(self, req_id: str) -> None:
        return None

    def apply_staged_writes(self) -> None:
        return None

    def preprocess_state(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        kv_cache_config: KVCacheConfig,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        """Hook run on real batches before the forward pass (after block tables
        are gathered). Used by mamba "align" prefix caching to pre-copy state
        across block boundaries. No-op by default."""
        return None

    def postprocess_state(
        self,
        idx_mapping: torch.Tensor,
        num_sampled: torch.Tensor,
        num_computed_tokens: torch.Tensor | None = None,
    ) -> None:
        return None

    @abstractmethod
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor | None:
        raise NotImplementedError

    def dummy_inputs_embeds(self, num_tokens: int) -> torch.Tensor | None:
        """Pre-allocated inputs_embeds buffer for dummy runs (contents unused)."""
        return None

    def gather_mm_embeddings(
        self, input_batch: InputBatch, draft_lookahead: int = 0
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Gather cached multimodal embeddings."""
        return self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            input_batch.prefill_len_np,
            input_batch.num_computed_tokens_np,
            draft_lookahead=draft_lookahead,
        )

    @abstractmethod
    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, Any]:
        raise NotImplementedError

    @abstractmethod
    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        raise NotImplementedError

    @abstractmethod
    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        raise NotImplementedError

    def custom_sampler(self, sampler: Any) -> tuple[Any, Any] | None:
        """Wrap or replace the default sampler.

        Called after model loading with the already-constructed base
        ``Sampler``.  Return ``None`` to keep the defaults, or
        ``(sampler, rejection_sampler | None)`` to override.
        """
        return None

    num_new_sampled_tokens_per_step: int = 1
    """New tokens sampled on each decode step 
    (excluding accepted draft tokens, a.k.a num bonus tokens)."""

num_new_sampled_tokens_per_step = 1 class-attribute instance-attribute

New tokens sampled on each decode step (excluding accepted draft tokens, a.k.a num bonus tokens).

custom_sampler(sampler)

Wrap or replace the default sampler.

Called after model loading with the already-constructed base Sampler. Return None to keep the defaults, or (sampler, rejection_sampler | None) to override.

Source code in vllm/v1/worker/gpu/model_states/interface.py
def custom_sampler(self, sampler: Any) -> tuple[Any, Any] | None:
    """Wrap or replace the default sampler.

    Called after model loading with the already-constructed base
    ``Sampler``.  Return ``None`` to keep the defaults, or
    ``(sampler, rejection_sampler | None)`` to override.
    """
    return None

dummy_inputs_embeds(num_tokens)

Pre-allocated inputs_embeds buffer for dummy runs (contents unused).

Source code in vllm/v1/worker/gpu/model_states/interface.py
def dummy_inputs_embeds(self, num_tokens: int) -> torch.Tensor | None:
    """Pre-allocated inputs_embeds buffer for dummy runs (contents unused)."""
    return None

gather_mm_embeddings(input_batch, draft_lookahead=0)

Gather cached multimodal embeddings.

Source code in vllm/v1/worker/gpu/model_states/interface.py
def gather_mm_embeddings(
    self, input_batch: InputBatch, draft_lookahead: int = 0
) -> tuple[list[torch.Tensor], torch.Tensor]:
    """Gather cached multimodal embeddings."""
    return self.encoder_runner.gather_mm_embeddings(
        input_batch.req_ids,
        input_batch.num_tokens,
        input_batch.num_scheduled_tokens,
        input_batch.query_start_loc_np,
        input_batch.prefill_len_np,
        input_batch.num_computed_tokens_np,
        draft_lookahead=draft_lookahead,
    )

preprocess_state(input_batch, block_tables, kv_cache_config, num_computed_tokens)

Hook run on real batches before the forward pass (after block tables are gathered). Used by mamba "align" prefix caching to pre-copy state across block boundaries. No-op by default.

Source code in vllm/v1/worker/gpu/model_states/interface.py
def preprocess_state(
    self,
    input_batch: InputBatch,
    block_tables: tuple[torch.Tensor, ...],
    kv_cache_config: KVCacheConfig,
    num_computed_tokens: torch.Tensor,
) -> None:
    """Hook run on real batches before the forward pass (after block tables
    are gathered). Used by mamba "align" prefix caching to pre-copy state
    across block boundaries. No-op by default."""
    return None