Skip to content

vllm.v1.core.single_type_kv_cache_manager

Classes:

Functions:

ChunkedLocalAttentionManager

Bases: SingleTypeKVCacheManager

Methods:

Source code in vllm/v1/core/single_type_kv_cache_manager.py
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
    """KV cache manager for chunked local attention layers.

    With chunked local attention, tokens to the left of the chunk (of
    ``attention_chunk_size`` tokens) that contains the next token are outside
    its attention window. Such tokens are reported as skipped by
    ``get_num_skipped_tokens``, and blocks lying entirely before the window
    are returned as null blocks by ``find_longest_cache_hit``.
    """

    def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None:
        super().__init__(kv_cache_spec, **kwargs)
        # Kept on the instance for `get_num_skipped_tokens`.
        self.attention_chunk_size = kv_cache_spec.attention_chunk_size

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        drop_eagle_block: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Find the longest cache-hit prefix of at most ``max_length`` tokens
        that is common to all KV cache groups in ``kv_cache_group_ids``.

        Blocks that lie entirely before the local attention window of the
        next token (i.e. before the start of the chunk containing token
        ``max_length``) are treated as computed without any lookup and are
        represented by ``block_pool.null_block``. Only the full blocks from
        that point up to ``max_length`` are looked up, and the lookup stops
        at the first miss. Consequently the result is NOT necessarily empty
        when no real block is hit: it still holds the leading null blocks.

        Examples (attention chunk size 8, block size 4):

        1. ``max_length`` = 15: the next token is token 15 (zero-indexed), so
           tokens 8-14 are in its window and tokens 0-7 are not. The first
           two blocks (tokens 0-7) become null blocks. The third block
           (tokens 8-11) is the only full block inside the window and is
           looked up: on a hit the result is ``[null, null, <third block>]``,
           otherwise ``[null, null]``.

        2. ``max_length`` = 16: the next token is token 16, so tokens 0-15
           are all outside its window and the result is four null blocks,
           ``[null, null, null, null]``.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec; must be a
                ``ChunkedLocalAttentionSpec``.
            drop_eagle_block: Whether to drop the last matched block for
                EAGLE/MTP. Must be False (not supported here).
            alignment_tokens: The returned cache hit length (in tokens) should
                be a multiple of this value (in tokens). Must equal the
                block size of ``kv_cache_spec``.
            dcp_world_size: The world size of decode context parallelism.
                Must be 1.
            pcp_world_size: The world size of prefill context parallelism.
                Must be 1.

        Returns:
            A tuple with one list of blocks per group in
            ``kv_cache_group_ids``. Each list starts with the null blocks
            outside the local window, followed by the consecutive cache hits
            inside it.

        Raises:
            AssertionError: If any of the unsupported configurations above is
                requested.
        """
        assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
            "ChunkedLocalAttentionManager can only be used for "
            "chunked local attention groups"
        )
        assert drop_eagle_block is False, (
            "Hybrid KV cache is not supported for " + "eagle + chunked local attention."
        )
        assert dcp_world_size == 1, "DCP not support chunked local attn now."
        assert pcp_world_size == 1, "PCP not support chunked local attn now."
        assert kv_cache_spec.block_size == alignment_tokens, (
            "KV cache groups with different block sizes are not compatible with "
            "chunked local attention now"
        )
        block_size = kv_cache_spec.block_size
        chunk_size = kv_cache_spec.attention_chunk_size
        # Only full blocks can be cache hits.
        max_num_blocks = max_length // block_size
        # Start of the chunk that contains the next token (index `max_length`).
        # Non-positive lengths are clamped so the window starts at token 0.
        local_attention_start_idx = max(max_length, 0) // chunk_size * chunk_size
        # Floor division: when the chunk size is not a multiple of the block
        # size, a block straddling the window start still holds in-window
        # tokens, so it is looked up instead of being nulled out.
        local_attention_start_block_idx = local_attention_start_idx // block_size
        # Resulting layout per group:
        # [null] ... [null] [hit block (first block of the window)] ... [hit x]
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [block_pool.null_block] * local_attention_start_block_idx
            for _ in kv_cache_group_ids
        )
        for i in range(local_attention_start_block_idx, max_num_blocks):
            block_hash = block_hashes[i]
            if cached_block := block_pool.get_cached_block(
                block_hash, kv_cache_group_ids
            ):
                for computed, cached in zip(computed_blocks, cached_block):
                    computed.append(cached)
            else:
                # The hit must be a contiguous prefix; stop at the first miss.
                break
        return computed_blocks

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        For chunked local attention, this corresponds to the tokens that are on
        the left side of the current chunk.

        Example 1:
        chunk size = 8, num_computed_tokens = 13
        Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
                 | ----- computed ---------------|
                                                  ^^ next token to be computed
                                   |----------------| <-- attention window for
                                                          next token
                 |--- skipped -----|
        Output: get_num_skipped_tokens(13) == 8

        Example 2:
        chunk size = 8, num_computed_tokens = 8
        Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
                 | --- computed ---|
                                     ^ next token to be computed
                                   |--| <-- attention window for next token
                 | --- skipped ----|
        Output: get_num_skipped_tokens(8) == 8

        Example 3:
        chunk size = 8, num_computed_tokens = 7
        Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
                 |---computed---|
                                 ^ next token to be computed
                 |-----------------| <-- attention window for next token
                 no token should be skipped.
        Output: get_num_skipped_tokens(7) == 0

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        # Round down to the start of the chunk containing the next token.
        num_skipped_tokens = (
            num_computed_tokens // self.attention_chunk_size
        ) * self.attention_chunk_size
        return num_skipped_tokens

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        Cascade attention is not supported by chunked local attention, so
        this always returns 0.
        """
        return 0

find_longest_cache_hit(block_hashes, max_length, kv_cache_group_ids, block_pool, kv_cache_spec, drop_eagle_block, alignment_tokens, dcp_world_size=1, pcp_world_size=1) classmethod

For chunked local attention, we need to find the longest cache hit prefix of the blocks that is not longer than max_length. The prefix should be a common prefix hit for all the kv cache groups in kv_cache_group_ids. Blocks that lie entirely outside the local attention window are marked as computed and set to the null block, so the result can contain leading null blocks even when no block is actually hit. Examples:

  1. Attention chunk size of 8, block size of 4, max length of 15: the next token is token 15 (zero-indexed), so tokens 8-14 are in its window (and need a lookup) while tokens 0-7 are outside it and are already marked as computed. Only the third block (tokens 8-11) is checked: if it is hit, we return [null, null, <third block>]; otherwise we return [null, null].

  2. Attention chunk size of 8, block size of 4, max length of 16: the next token is token 16 (zero-indexed), so tokens 0-15 are all outside its window and are already marked as computed. We return 4 blocks: [null, null, null, null].

Parameters:

  • block_hashes

    (BlockHashList) –

    The block hashes of the request.

  • max_length

    (int) –

    The maximum length of the cache hit prefix.

  • kv_cache_group_ids

    (list[int]) –

    The ids of the kv cache groups.

  • block_pool

    (BlockPool) –

    The block pool.

  • kv_cache_spec

    (KVCacheSpec) –

    The kv cache spec.

  • drop_eagle_block

    (bool) –

    Whether to drop the last matched block for EAGLE/MTP.

  • dcp_world_size

    (int, default: 1 ) –

    The world size of decode context parallelism.

  • pcp_world_size

    (int, default: 1 ) –

    The world size of prefill context parallelism.

  • alignment_tokens

    (int) –

    The returned cache hit length (in tokens) should be a multiple of this value (in tokens).

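Returns:

  • tuple[list[KVCacheBlock], ...]

    One list of cached blocks per KV cache group in kv_cache_group_ids, each starting with the null blocks that lie outside the local attention window, followed by the consecutive cache hits inside it.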

Source code in vllm/v1/core/single_type_kv_cache_manager.py
@classmethod
def find_longest_cache_hit(
    cls,
    block_hashes: BlockHashList,
    max_length: int,
    kv_cache_group_ids: list[int],
    block_pool: BlockPool,
    kv_cache_spec: KVCacheSpec,
    drop_eagle_block: bool,
    alignment_tokens: int,
    dcp_world_size: int = 1,
    pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
    """
    Find the longest cache-hit prefix of at most ``max_length`` tokens
    that is common to all KV cache groups in ``kv_cache_group_ids``.

    Blocks that lie entirely before the local attention window of the
    next token (i.e. before the start of the chunk containing token
    ``max_length``) are treated as computed without any lookup and are
    represented by ``block_pool.null_block``. Only the full blocks from
    that point up to ``max_length`` are looked up, and the lookup stops
    at the first miss. Consequently the result is NOT necessarily empty
    when no real block is hit: it still holds the leading null blocks.

    Examples (attention chunk size 8, block size 4):

    1. ``max_length`` = 15: the next token is token 15 (zero-indexed), so
       tokens 8-14 are in its window and tokens 0-7 are not. The first
       two blocks (tokens 0-7) become null blocks. The third block
       (tokens 8-11) is the only full block inside the window and is
       looked up: on a hit the result is ``[null, null, <third block>]``,
       otherwise ``[null, null]``.

    2. ``max_length`` = 16: the next token is token 16, so tokens 0-15
       are all outside its window and the result is four null blocks,
       ``[null, null, null, null]``.

    Args:
        block_hashes: The block hashes of the request.
        max_length: The maximum length of the cache hit prefix.
        kv_cache_group_ids: The ids of the kv cache groups.
        block_pool: The block pool.
        kv_cache_spec: The kv cache spec; must be a
            ``ChunkedLocalAttentionSpec``.
        drop_eagle_block: Whether to drop the last matched block for
            EAGLE/MTP. Must be False (not supported here).
        alignment_tokens: The returned cache hit length (in tokens) should
            be a multiple of this value (in tokens). Must equal the
            block size of ``kv_cache_spec``.
        dcp_world_size: The world size of decode context parallelism.
            Must be 1.
        pcp_world_size: The world size of prefill context parallelism.
            Must be 1.

    Returns:
        A tuple with one list of blocks per group in
        ``kv_cache_group_ids``. Each list starts with the null blocks
        outside the local window, followed by the consecutive cache hits
        inside it.

    Raises:
        AssertionError: If any of the unsupported configurations above is
            requested.
    """
    assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
        "ChunkedLocalAttentionManager can only be used for "
        "chunked local attention groups"
    )
    assert drop_eagle_block is False, (
        "Hybrid KV cache is not supported for " + "eagle + chunked local attention."
    )
    assert dcp_world_size == 1, "DCP not support chunked local attn now."
    assert pcp_world_size == 1, "PCP not support chunked local attn now."
    assert kv_cache_spec.block_size == alignment_tokens, (
        "KV cache groups with different block sizes are not compatible with "
        "chunked local attention now"
    )
    block_size = kv_cache_spec.block_size
    chunk_size = kv_cache_spec.attention_chunk_size
    # Only full blocks can be cache hits.
    max_num_blocks = max_length // block_size
    # Start of the chunk that contains the next token (index `max_length`).
    # Non-positive lengths are clamped so the window starts at token 0.
    local_attention_start_idx = max(max_length, 0) // chunk_size * chunk_size
    # Floor division: when the chunk size is not a multiple of the block
    # size, a block straddling the window start still holds in-window
    # tokens, so it is looked up instead of being nulled out.
    local_attention_start_block_idx = local_attention_start_idx // block_size
    # Resulting layout per group:
    # [null] ... [null] [hit block (first block of the window)] ... [hit x]
    computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
        [block_pool.null_block] * local_attention_start_block_idx
        for _ in kv_cache_group_ids
    )
    for i in range(local_attention_start_block_idx, max_num_blocks):
        block_hash = block_hashes[i]
        if cached_block := block_pool.get_cached_block(
            block_hash, kv_cache_group_ids
        ):
            for computed, cached in zip(computed_blocks, cached_block):
                computed.append(cached)
        else:
            # The hit must be a contiguous prefix; stop at the first miss.
            break
    return computed_blocks

get_num_common_prefix_blocks(running_request_id)

Cascade attention is not supported by chunked local attention.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
    """Report how many leading blocks are shared as a common prefix.

    Chunked local attention has no cascade-attention support, so the answer
    is always zero, whatever ``running_request_id`` is.
    """
    shared_prefix_blocks = 0
    return shared_prefix_blocks

get_num_skipped_tokens(num_computed_tokens)

Get the number of tokens that will be skipped for attention computation.

For chunked local attention, this corresponds to the tokens that are on the left side of the current chunk.

Example 1: chunk size = 8, num_computed_tokens = 13. Tokens 0-12 are computed and token 13 is the next token to be computed; its attention window covers tokens 8-13, so tokens 0-7 are skipped. Output: get_num_skipped_tokens(13) == 8

Example 2: chunk size = 8, num_computed_tokens = 8. Tokens 0-7 are computed and token 8 is the next token to be computed; it starts a new chunk, so its attention window contains only token 8 and tokens 0-7 are skipped. Output: get_num_skipped_tokens(8) == 8

Example 3: chunk size = 8, num_computed_tokens = 7. Tokens 0-6 are computed and token 7 is the next token to be computed; its attention window covers tokens 0-7, so no token is skipped. Output: get_num_skipped_tokens(7) == 0

Parameters:

  • num_computed_tokens

    (int) –

    The number of tokens that have been computed.

Returns:

  • int

    The number of tokens that will be skipped for attention computation.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Return how many leading tokens are excluded from attention.

    Under chunked local attention the next token only attends to tokens of
    its own chunk, so every token before the start of that chunk is skipped.

    Examples (chunk size 8):
        * 13 tokens computed -> next token 13, window 8..13 -> 8 skipped.
        * 8 tokens computed  -> next token 8 opens a new chunk -> 8 skipped.
        * 7 tokens computed  -> next token 7, window 0..7 -> 0 skipped.

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention computation.
    """
    chunk = self.attention_chunk_size
    # Strip the partial-chunk remainder to land on the chunk start
    # (equal to floor(num_computed_tokens / chunk) * chunk for ints).
    remainder = num_computed_tokens % chunk
    return num_computed_tokens - remainder

CrossAttentionManager

Bases: SingleTypeKVCacheManager

Manager for cross-attention KV cache in encoder-decoder models.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
class CrossAttentionManager(SingleTypeKVCacheManager):
    """Manager for cross-attention KV cache in encoder-decoder models.

    Cross-attention blocks hold request-specific encoder states, so they are
    never shared through prefix caching: no computed blocks are accepted,
    the caching/lookup hooks raise if called, and cascade attention gets no
    common-prefix blocks.
    """

    def allocate_new_computed_blocks(
        self,
        request_id: str,
        new_computed_blocks: Sequence[KVCacheBlock],
        num_local_computed_tokens: int,
        num_external_computed_tokens: int,
    ) -> None:
        """Verify that no prefix-cache hits were supplied; allocates nothing.

        Raises:
            AssertionError: If ``new_computed_blocks`` is non-empty.
        """
        # We do not cache blocks for cross-attention to be shared between
        # requests, so `new_computed_blocks` should always be empty.
        assert len(new_computed_blocks) == 0

    def cache_blocks(
        self,
        request: Request,
        num_tokens: int,
        retention_interval: int | None = None,
    ) -> None:
        """Not supported: cross-attention blocks are never cached.

        Raises:
            ValueError: Always.
        """
        # We do not cache blocks for cross-attention to be shared between
        # requests, so this method is not relevant.
        raise ValueError("Should not be called as prefix caching is disabled.")

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """Always return 0: cross-attention blocks are never shared."""
        # Cross-attention blocks contain request-specific encoder states
        # and are not shared between different requests
        return 0

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        drop_eagle_block: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Not supported: prefix caching is disabled for cross-attention.

        Raises:
            AssertionError: If ``kv_cache_spec`` is not a
                ``CrossAttentionSpec``.
            NotImplementedError: Always, otherwise.
        """
        assert isinstance(kv_cache_spec, CrossAttentionSpec), (
            "CrossAttentionManager can only be used for cross-attention groups"
        )
        # Cross-attention does not benefit from prefix caching since:
        # 1. Encoder states are unique per request (different audio/image
        #    inputs)
        # 2. Encoder states are computed once per request, not incrementally
        # 3. No reusable prefix exists between different multimodal inputs
        # This lookup is therefore not expected to be reached; raise instead
        # of returning an (empty) hit so that an unexpected call is visible.
        raise NotImplementedError("CrossAttentionManager does not support caching")

MambaManager

Bases: SingleTypeKVCacheManager

Methods:

Source code in vllm/v1/core/single_type_kv_cache_manager.py
class MambaManager(SingleTypeKVCacheManager):
    """KV cache manager for Mamba layers.

    Mamba only needs to keep the state of the last computed token (see
    `get_num_skipped_tokens`), so a prefix-cache hit is represented by a
    single cached block at the end of the hit prefix, preceded by null
    blocks.

    In ``"align"`` cache mode the manager additionally tracks, per request,
    the index of the block that held the running state in the previous step
    (``last_state_block_idx``) so that the stale block can be freed, and the
    set of requests that have already been allocated blocks
    (``_allocated_block_reqs``).
    """

    def __init__(
        self, kv_cache_spec: MambaSpec, block_pool: BlockPool, **kwargs
    ) -> None:
        super().__init__(kv_cache_spec, block_pool, **kwargs)
        # Hashes of blocks cached during the current step. Cleared by
        # `new_step_starts`; used by `get_num_blocks_to_allocate` to avoid
        # relying on blocks produced by other requests in the same step.
        self.cached_blocks_this_step: set[BlockHashWithGroupId] = set()
        self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode
        self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks
        if self.mamba_cache_mode == "align":
            # Mapping from request ID to the index of the block
            # allocated in the previous step
            self.last_state_block_idx: dict[str, int] = {}
            # The set of the requests that have been allocated blocks
            self._allocated_block_reqs: set[str] = set()

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        drop_eagle_block: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Find the longest prefix-cache hit of at most ``max_length`` tokens.

        Full blocks are scanned from right to left. The scan stops at the
        first block ``i`` for which ``block_pool.get_cached_block`` returns a
        hit and whose end ``(i + 1) * block_size`` is a multiple of
        ``alignment_tokens``. Hits that are not aligned are skipped.

        ``drop_eagle_block`` is accepted for interface compatibility but is
        not used by this method.

        Returns:
            A tuple with one list per group in ``kv_cache_group_ids``. On a
            hit at block ``i``, each list holds ``i`` null blocks followed by
            the cached block, so ``len(list) * block_size`` equals the hit
            length. Without a hit, every list is empty.

        Raises:
            AssertionError: If ``kv_cache_spec`` is not a ``MambaSpec`` or a
                DCP/PCP world size other than 1 is requested.
        """
        assert isinstance(kv_cache_spec, MambaSpec), (
            "MambaManager can only be used for mamba groups"
        )
        assert dcp_world_size == 1, "DCP not support mamba now."
        assert pcp_world_size == 1, "PCP not support mamba now."
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(len(kv_cache_group_ids))
        )

        block_size = kv_cache_spec.block_size
        max_num_blocks = max_length // block_size
        # Search from right to left and early stop when a match is found.
        for i in range(max_num_blocks - 1, -1, -1):
            if cached_block := block_pool.get_cached_block(
                block_hashes[i], kv_cache_group_ids
            ):
                # When enable Mamba prefix caching, `block_size` will be aligned
                # across full attention layers and Mamba layers to ensure the
                # prefix hit length aligned at block
                if (
                    block_size != alignment_tokens  # Faster for common case.
                    and (i + 1) * block_size % alignment_tokens != 0
                ):
                    continue
                for computed, cached in zip(computed_blocks, cached_block):
                    # the hit length logic later assumes:
                    #  hit_length = len(hit_blocks_other_attn[0])
                    #               * self.other_block_size
                    # so we insert dummy blocks at the beginning:
                    computed.extend([block_pool.null_block] * i)
                    computed.append(cached)
                break  # we just need the last match - early stopping

        return computed_blocks

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        """Free blocks the request no longer needs.

        ``num_computed_tokens`` is first rolled back conservatively (see the
        note below) and passed to the base implementation. In ``"align"``
        mode, the block recorded in ``last_state_block_idx`` is additionally
        freed and replaced by the null block once it lies before the block
        containing the last computed token.
        """
        assert isinstance(self.kv_cache_spec, MambaSpec)

        # NOTE (tdoublep) with async scheduling, the num_computed_tokens can contain
        # draft tokens from the previous step that may or may not be rejected later.
        # This can make us think we are further ahead in the sequence than we actually
        # are, so let's assume that all tokens are rejected so we don't free blocks
        # that we might actually need.
        # NOTE(review): this subtracts a *block* count from a *token* count;
        # presumably there is one speculative block per speculative token —
        # confirm against MambaSpec.
        num_computed_tokens = max(0, num_computed_tokens - self.num_speculative_blocks)

        super().remove_skipped_blocks(request_id, num_computed_tokens)
        if self.mamba_cache_mode == "align":
            # `last_state_block_idx` refers to the block index allocated two steps ago.
            # The block allocated in the previous step is used to copy Mamba states
            # into the block allocated in the current step; the earlier block is
            # no longer needed and should be freed here.
            last_state_block_idx = self.last_state_block_idx.get(request_id)
            # Blocks allocated during prefill may be non-contiguous. Use
            # `last_state_block_idx` to free the appropriate block and replace it
            # with a null block.
            # `cdiv(num_computed_tokens, block_size) - 1` is the index of the
            # block holding the last computed token.
            if (
                last_state_block_idx is not None
                and last_state_block_idx
                < cdiv(num_computed_tokens, self.block_size) - 1
            ):
                blocks = self.req_to_blocks[request_id]
                # Skip if the slot was already nulled (avoids a double free).
                if blocks[last_state_block_idx] != self._null_block:
                    self.block_pool.free_blocks([blocks[last_state_block_idx]])
                    blocks[last_state_block_idx] = self._null_block

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        cascade attention is not supported by mamba
        """
        return 0

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
        total_computed_tokens: int,
        num_tokens_main_model: int,
        apply_admission_cap: bool = False,
    ) -> int:
        """Return how many blocks must be allocated to schedule the request.

        If the last prefix-cache hit was cached by another request during the
        current step, ``num_gpu_blocks + 1`` is returned so the caller treats
        the allocation as impossible and defers the request to a later step.

        Outside ``"align"`` mode, room for ``num_speculative_blocks`` extra
        blocks is added to ``num_tokens`` and the base implementation is
        used. In ``"align"`` mode, the count is computed from
        ``num_tokens_main_model``: at most 1 new block for a request that was
        allocated before, or ``1 + num_speculative_blocks`` on its first
        allocation, plus the value of ``_get_num_evictable_blocks`` for the
        computed blocks (presumably hits that must be taken out of the
        evictable pool — confirm in the base class).
        """
        assert isinstance(self.kv_cache_spec, MambaSpec)
        if (
            len(new_computed_blocks) > 0
            and new_computed_blocks[-1].block_hash in self.cached_blocks_this_step
        ):
            # Mamba can't rely on blocks generated by other requests in the current step
            # To put it in the next step, we return num_gpu_blocks + 1 so
            # that kv_cache_manager will think there is no enough blocks to allocate now
            # and don't schedule it in the current step.
            return self.block_pool.num_gpu_blocks + 1
        if self.mamba_cache_mode != "align":
            # Allocate extra `num_speculative_blocks` blocks for
            # speculative decoding (MTP/EAGLE) with linear attention.
            if self.num_speculative_blocks > 0:
                num_tokens += (
                    self.kv_cache_spec.block_size * self.num_speculative_blocks
                )
            return super().get_num_blocks_to_allocate(
                request_id,
                num_tokens,
                new_computed_blocks,
                total_computed_tokens,
                num_tokens_main_model,
                apply_admission_cap=apply_admission_cap,
            )
        else:
            # We don't allocate blocks for lookahead tokens in align mode, because if
            # x * block_size tokens are scheduled, num_tokens is
            # x * block_size + num_lookahead_tokens and breaks the alignment.
            # We can ignore lookahead tokens because current draft models don't have
            # mamba layers.
            num_tokens = num_tokens_main_model

            # NOTE(tdouble): this is an over-estimate of how many blocks we need because
            # num_tokens can include draft tokens that will later be rejected.
            num_required_blocks = (
                cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
            )
            num_new_blocks = (
                num_required_blocks
                - len(new_computed_blocks)
                - len(self.req_to_blocks[request_id])
            )
            if num_new_blocks > 0:
                if request_id in self._allocated_block_reqs:
                    # Old request. Needs at most 1 more block as we can reuse the
                    # speculative blocks in previous step.
                    num_new_blocks = 1
                else:
                    # First prefill. Allocate 1 block for running state and the
                    # speculative blocks.
                    num_new_blocks = 1 + self.num_speculative_blocks

            num_evictable_computed_blocks = self._get_num_evictable_blocks(
                new_computed_blocks
            )
            return num_new_blocks + num_evictable_computed_blocks

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int, num_tokens_main_model: int
    ) -> list[KVCacheBlock]:
        """Allocate blocks for the request and return the newly added entries.

        Outside ``"align"`` mode this reserves ``num_speculative_blocks``
        extra blocks and delegates to the base implementation.

        In ``"align"`` mode, ``req_to_blocks[request_id]`` is grown to
        ``cdiv(num_tokens_main_model, block_size) + num_speculative_blocks``
        entries: positions that are not needed are padded with null blocks,
        speculative blocks from the previous step are moved to the end for
        reuse (their old slots become null blocks), and the remainder is
        taken from the block pool. The previous running-state block index is
        recorded in ``last_state_block_idx`` so `remove_skipped_blocks` can
        free it later. Returns ``[]`` when enough entries already exist;
        otherwise returns every entry appended in this call (null padding,
        moved speculative blocks and new blocks).
        """
        assert isinstance(self.kv_cache_spec, MambaSpec)
        if self.mamba_cache_mode != "align":
            # Allocate extra `num_speculative_blocks` blocks for
            # speculative decoding (MTP/EAGLE) with linear attention.
            if self.num_speculative_blocks > 0:
                num_tokens += self.block_size * self.num_speculative_blocks
            return super().allocate_new_blocks(
                request_id, num_tokens, num_tokens_main_model
            )
        else:
            # We don't allocate blocks for lookahead tokens in align mode, because if
            # x * block_size tokens are scheduled, num_tokens is
            # x * block_size + num_lookahead_tokens and breaks the alignment.
            # We can ignore lookahead tokens because current draft models don't have
            # mamba layers.
            num_tokens = num_tokens_main_model
            req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
            # NOTE(tdouble): this is an over-estimate of how many blocks we need because
            # num_tokens can include draft tokens that will later be rejected.
            num_required_blocks = (
                cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
            )
            # `num_required_blocks` might be less than `len(req_blocks)` if blocks are
            # over-allocated at last round.
            if num_required_blocks <= len(req_blocks):
                return []
            else:
                prev_block_len = len(req_blocks)
                blocks_allocated = request_id in self._allocated_block_reqs
                # Record the last state block
                if blocks_allocated:
                    # We always save the running state at the last
                    # (1 + num_speculative_blocks) block
                    self.last_state_block_idx[request_id] = (
                        prev_block_len - 1 - self.num_speculative_blocks
                    )
                elif prev_block_len > 0:
                    # When a new request hits the prefix cache, the last block
                    # saves the hit state.
                    self.last_state_block_idx[request_id] = prev_block_len - 1

                # Every position before the final (1 + num_speculative_blocks)
                # entries does not need a real block.
                num_skipped_blocks = (
                    num_required_blocks - self.num_speculative_blocks - 1
                )
                # null blocks
                if prev_block_len < num_skipped_blocks:
                    req_blocks.extend(
                        [
                            self._null_block
                            for _ in range(prev_block_len, num_skipped_blocks)
                        ]
                    )

                if blocks_allocated:
                    # reuse previous speculative blocks in this step
                    # (move each one to the end and null out its old slot, as
                    # long as that slot now falls in the skipped range)
                    for block_idx in range(
                        prev_block_len - self.num_speculative_blocks, prev_block_len
                    ):
                        if block_idx < num_skipped_blocks:
                            req_blocks.append(req_blocks[block_idx])
                            req_blocks[block_idx] = self._null_block
                        else:
                            break
                num_new_blocks = num_required_blocks - len(req_blocks)
                # Must stay within the bound reported by
                # `get_num_blocks_to_allocate`.
                if blocks_allocated:
                    assert num_new_blocks <= 1
                else:
                    assert num_new_blocks <= self.num_speculative_blocks + 1
                new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
                req_blocks.extend(new_blocks)
                self._allocated_block_reqs.add(request_id)
                return req_blocks[prev_block_len:]

    def free(self, request_id: str) -> None:
        """Drop align-mode bookkeeping for the request, then free its blocks."""
        if self.mamba_cache_mode == "align":
            self._allocated_block_reqs.discard(request_id)
            self.last_state_block_idx.pop(request_id, None)
        super().free(request_id)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens whose mamba state are not needed anymore. Mamba only
        need to keep the state of the last computed token, so we return
        num_computed_tokens - 1.
        """
        # NOTE(review): yields -1 when num_computed_tokens == 0; presumably
        # callers tolerate this — confirm.
        return num_computed_tokens - 1

    def cache_blocks(
        self,
        request: Request,
        num_tokens: int,
        retention_interval: int | None = None,
    ) -> None:
        """Cache the request's blocks via the base implementation and record
        the hashes of the non-null blocks newly cached by this call in
        ``cached_blocks_this_step``."""
        num_cached_blocks_before = self.num_cached_block.get(request.request_id, 0)
        super().cache_blocks(request, num_tokens, retention_interval=retention_interval)
        num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0)
        if num_cached_blocks_after > num_cached_blocks_before:
            for block in self.req_to_blocks[request.request_id][
                num_cached_blocks_before:num_cached_blocks_after
            ]:
                if block.is_null:
                    continue
                assert block.block_hash is not None
                self.cached_blocks_this_step.add(block.block_hash)

    def new_step_starts(self) -> None:
        """Reset the set of block hashes cached during the current step."""
        self.cached_blocks_this_step.clear()

get_num_common_prefix_blocks(running_request_id)

Cascade attention is not supported by Mamba.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
    """
    Mamba does not support cascade attention, so there is never a shared
    prefix to report: this always returns 0, whatever the request.
    """
    no_common_prefix = 0
    return no_common_prefix

get_num_skipped_tokens(num_computed_tokens)

Get the number of tokens whose Mamba state is no longer needed. Mamba only needs to keep the state of the last computed token, so this returns num_computed_tokens - 1.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Return how many computed tokens no longer need their Mamba state.

    Only the state of the most recently computed token is kept, so the
    result is ``num_computed_tokens - 1``.
    """
    num_tokens_to_keep = 1
    return num_computed_tokens - num_tokens_to_keep

SingleTypeKVCacheManager

Bases: ABC

An abstract base class for a manager that handles the KV cache management logic of one specific type of attention layer.

Methods:

Source code in vllm/v1/core/single_type_kv_cache_manager.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handles the KV cache management
    logic of one specific type of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        enable_caching: bool,
        kv_cache_group_id: int,
        scheduler_block_size: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
        max_admission_blocks_per_request: int | None = None,
    ) -> None:
        """
        Initializes the SingleTypeKVCacheManager.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
            enable_caching: Whether prefix caching is enabled. When False,
                `allocate_new_computed_blocks` requires that no computed
                blocks are passed in.
            kv_cache_group_id: The id of the kv cache group of this manager.
            scheduler_block_size: The scheduling granularity (LCM of all group
                block sizes); a multiple of this manager's ``block_size``.
            dcp_world_size: The world size of decode context parallelism.
            pcp_world_size: The world size of prefill context parallelism.
                When ``dcp_world_size * pcp_world_size > 1``, this manager's
                ``block_size`` is the spec's block size multiplied by that
                product.
            max_admission_blocks_per_request: Recycling-aware per-request
                block cap used by `get_num_blocks_to_allocate`. Only set for
                spec types that recycle blocks across chunks (SWA,
                chunked-local); `None` (the default) means no cap, which is
                correct for full-attention-style specs that hold every
                block until the request finishes.
        """
        self.scheduler_block_size = scheduler_block_size
        # The block size for this manager; used for actual block allocation.
        self.block_size = kv_cache_spec.block_size
        self.dcp_world_size = dcp_world_size
        self.pcp_world_size = pcp_world_size
        if dcp_world_size * pcp_world_size > 1:
            self.block_size *= dcp_world_size * pcp_world_size
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool
        self.enable_caching = enable_caching
        self._max_admission_blocks_per_request = max_admission_blocks_per_request
        # Block IDs allocated since the last `take_new_block_ids` call (only
        # filled for the spec types accepted by `_tracks_new_block_ids`).
        self.new_block_ids: list[int] = []

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)

        # {req_id: The number of cached blocks for this given request}
        # This is used to track the number of cached blocks for each request.
        # This is only used to track the RUNNING requests, we do not track the
        # data for preempted ones.
        self.num_cached_block: dict[str, int] = {}

        self.kv_cache_group_id = kv_cache_group_id
        # The block pool's shared null block; used to pad skipped positions.
        self._null_block = block_pool.null_block

        # Whether this group's prefix-cache hits drop the EAGLE/MTP lookahead
        # block. Only consulted by managers whose hit logic is sparse within an
        # aligned segment (SWA). Initialized lazily by the coordinator after
        # determining the attention groups.
        self.use_eagle = False

    @classmethod
    def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]) -> int:
        """Count the non-null blocks in ``blocks`` whose ``ref_cnt`` is 0."""
        return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks)

    def _tracks_new_block_ids(self) -> bool:
        """Whether allocations for this manager are recorded in `new_block_ids`."""
        # Exact-type match: subclasses of these spec types are not tracked.
        return type(self.kv_cache_spec) in (
            FullAttentionSpec,
            TQFullAttentionSpec,
            MLAAttentionSpec,
        )

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
        total_computed_tokens: int,
        num_tokens_main_model: int,
        apply_admission_cap: bool = False,
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix caching.
            total_computed_tokens: Include both local and external computed
                tokens.
            num_tokens_main_model: The number of tokens for the main model (aka target
                model in spec decode). w/o spec decode, it is num_tokens;
                with spec decode, it is num_tokens - num_lookahead_tokens.
            apply_admission_cap: If True, clamp `num_required_blocks` to
                `_max_admission_blocks_per_request` when that cap is set
                (only for recycling-aware specs such as SWA and
                chunked-local).

        Returns:
            The number of blocks to allocate.
        """

        num_required_blocks = cdiv(num_tokens, self.block_size)
        if apply_admission_cap and self._max_admission_blocks_per_request is not None:
            # Recycling-aware specs (SWA, chunked-local) cap the per-request
            # reservation here so admission matches the startup pool sizer
            # (`SlidingWindowSpec.max_admission_blocks_per_request` / its
            # chunked-local counterpart). `remove_skipped_blocks` runs from
            # `allocate_slots` before each chunk's `get_num_blocks_to_allocate`,
            # so per-request peak real-held blocks <= this cap, which keeps
            # `sum(reservations) <= pool` <=> `sum(peak_real_held) <= pool`.
            # Drift between the two would re-introduce the deadlock from
            # issue #39734 or, worse, mid-prefill OOM.
            num_required_blocks = min(
                num_required_blocks, self._max_admission_blocks_per_request
            )
        # `.get` avoids inserting an empty entry into the defaultdict.
        num_req_blocks = len(self.req_to_blocks.get(request_id, ()))

        if request_id in self.num_cached_block:
            # Fast-path: a running request won't have any new prefix-cache hits.
            assert len(new_computed_blocks) == 0
            # NOTE: With speculative decoding, request's blocks may be allocated
            # for draft tokens which are later rejected. In this case,
            # num_required_blocks may be smaller than num_req_blocks.
            return max(num_required_blocks - num_req_blocks, 0)

        num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
        num_local_computed_blocks = len(new_computed_blocks) + num_req_blocks
        # Number of whole blocks that are skipped by the attention window.
        # If nothing is skipped, this is 0.
        num_skipped_blocks = num_skipped_tokens // self.block_size
        # We need blocks for the non-skipped suffix. If there are still
        # local-computed blocks inside the window, they contribute to the
        # required capacity; otherwise, skipped blocks dominate.
        num_new_blocks = max(
            num_required_blocks - max(num_skipped_blocks, num_local_computed_blocks),
            0,
        )

        # Among the `new_computed_blocks`, the first `num_skipped_blocks` worth
        # of blocks are skipped; `num_req_blocks` of those may already be in
        # `req_to_blocks`, so only skip the remainder from `new_computed_blocks`.
        num_skipped_new_computed_blocks = max(0, num_skipped_blocks - num_req_blocks)

        # If a computed block is an eviction candidate (in the free queue and
        # ref_cnt == 0), it will be removed from the free queue when touched by
        # the allocated request, so we must count it in the free-capacity check.
        num_evictable_blocks = self._get_num_evictable_blocks(
            new_computed_blocks[num_skipped_new_computed_blocks:]
        )
        return num_new_blocks + num_evictable_blocks

    def allocate_new_computed_blocks(
        self,
        request_id: str,
        new_computed_blocks: Sequence[KVCacheBlock],
        num_local_computed_tokens: int,
        num_external_computed_tokens: int,
    ) -> None:
        """
        Add the new computed blocks to the request. This involves three steps:
        1. Touch the computed blocks to make sure they won't be evicted.
        1.5. (Optional) For sliding window, skip blocks are padded with null blocks.
        2. Add the remaining computed blocks.
        3. (Optional) For KV connectors, allocate new blocks for external computed
            tokens (if any).

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
            num_local_computed_tokens: The number of local computed tokens.
            num_external_computed_tokens: The number of external computed tokens.
        """

        if request_id in self.num_cached_block:
            # Fast-path: a running request won't have any new prefix-cache hits.
            # It should not have any new computed blocks.
            assert len(new_computed_blocks) == 0
            return

        # A new request.
        req_blocks = self.req_to_blocks[request_id]
        assert len(req_blocks) == 0
        num_total_computed_tokens = (
            num_local_computed_tokens + num_external_computed_tokens
        )
        num_skipped_tokens = self.get_num_skipped_tokens(num_total_computed_tokens)
        num_skipped_blocks = num_skipped_tokens // self.block_size
        if num_skipped_blocks > 0:
            # It is possible that all new computed blocks are skipped when
            # num_skipped_blocks > len(new_computed_blocks).
            new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
            # Some external computed tokens may be skipped too.
            num_external_computed_tokens = min(
                num_total_computed_tokens - num_skipped_tokens,
                num_external_computed_tokens,
            )

        # Touch the computed blocks to make sure they won't be evicted.
        if self.enable_caching:
            self.block_pool.touch(new_computed_blocks)
        else:
            # Without prefix caching there can be no prefix-cache hits.
            assert len(new_computed_blocks) == 0, (
                "Computed blocks should be empty when prefix caching is disabled"
            )

        # Skip blocks are padded with null blocks.
        req_blocks.extend([self._null_block] * num_skipped_blocks)
        # Add the remaining computed blocks.
        req_blocks.extend(new_computed_blocks)
        # All cached hits (including skipped nulls) are already cached; mark
        # them so cache_blocks() will not try to re-cache blocks that already
        # have a block_hash set.
        self.num_cached_block[request_id] = len(req_blocks)

        if num_external_computed_tokens > 0:
            # Allocate new blocks for external computed tokens.
            allocated_blocks = self.block_pool.get_new_blocks(
                cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
            )
            req_blocks.extend(allocated_blocks)
            if self._tracks_new_block_ids():
                self.new_block_ids.extend(b.block_id for b in allocated_blocks)

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int, num_tokens_main_model: int
    ) -> list[KVCacheBlock]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            num_tokens_main_model: The number of tokens for the main model (aka target
                model in spec decode). w/o spec decode, it is num_tokens;
                with spec decode, it is num_tokens - num_lookahead_tokens.
        Returns:
            The newly allocated blocks (empty if the request already has
            enough blocks).
        """
        req_blocks = self.req_to_blocks[request_id]
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = num_required_blocks - len(req_blocks)
        if num_new_blocks <= 0:
            return []
        new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
        req_blocks.extend(new_blocks)
        if self._tracks_new_block_ids():
            self.new_block_ids.extend(b.block_id for b in new_blocks)
        return new_blocks

    def take_new_block_ids(self) -> list[int]:
        """Drain and return block IDs allocated since the last call."""
        ids = self.new_block_ids
        self.new_block_ids = []
        return ids

    def cache_blocks(
        self,
        request: Request,
        num_tokens: int,
        retention_interval: int | None = None,
    ) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            num_tokens: The total number of tokens that need to be cached
                (including tokens that are already cached).
            retention_interval: Sparse local-checkpoint granularity. ``None``
                keeps dense checkpointing; ``0`` keeps only the latest replay
                boundary; a positive multiple of ``scheduler_block_size`` keeps
                a tail once per that-sized segment. Only SWA acts on it.
        """
        num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
        num_full_blocks = num_tokens // self.block_size

        if num_cached_blocks >= num_full_blocks:
            return

        block_mask = self.reachable_block_mask(
            start_block=num_cached_blocks,
            end_block=num_full_blocks,
            alignment_tokens=self.scheduler_block_size,
            kv_cache_spec=self.kv_cache_spec,
            use_eagle=self.use_eagle,
            retention_interval=retention_interval,
            num_prompt_tokens=request.num_prompt_tokens,
        )
        self.block_pool.cache_full_blocks(
            request=request,
            blocks=self.req_to_blocks[request.request_id],
            num_cached_blocks=num_cached_blocks,
            num_full_blocks=num_full_blocks,
            block_size=self.block_size,
            kv_cache_group_id=self.kv_cache_group_id,
            block_mask=block_mask,
        )

        self.num_cached_block[request.request_id] = num_full_blocks

    @classmethod
    def reachable_block_mask(
        cls,
        start_block: int,
        end_block: int,
        alignment_tokens: int | None,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        retention_interval: int | None = None,
        num_prompt_tokens: int | None = None,
    ) -> list[bool] | None:
        """Per-block mask for ``cache_full_blocks``. ``None`` means cache
        every (non-null) block — the default for full attention.

        Subclasses with sparse hit semantics (SWA) override this to skip
        blocks that can never serve a hit at any alignment-aligned prefix
        length.
        """
        return None

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        req_blocks = self.req_to_blocks.pop(request_id, [])

        # Free blocks in reverse order so that the tail blocks are
        # freed first.
        ordered_blocks = reversed(req_blocks)

        self.block_pool.free_blocks(ordered_blocks)
        self.num_cached_block.pop(request_id, None)

    @abstractmethod
    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        Get the number of common prefix blocks for all requests with allocated
        KV cache.

        Args:
            running_request_id: The request ID.

        Returns:
            The number of common prefix blocks for all requests with allocated
            KV cache.
        """

        raise NotImplementedError

    @classmethod
    @abstractmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        drop_eagle_block: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the longest cache hit prefix of the blocks that is not longer than
        `max_length`. The prefix should be a common prefix hit for all the
        kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
        return an empty list.
        If eagle is enabled, drop the last matched block to force recompute the
        last block to get the required hidden states for eagle drafting head.
        Need to be customized for each attention type.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec.
            drop_eagle_block: Whether to drop the last matched block for EAGLE/MTP.
                Always False for non-EAGLE/MTP groups, but can be False for EAGLE/MTP
                groups too if the last block is already dropped (e.g., in a
                convergence loop in `find_longest_cache_hit`).
            alignment_tokens: The returned cache hit length (in tokens) should
                be a multiple of this value (in tokens). By default, it should
                be set to the block_size.
            dcp_world_size: The world size of decode context parallelism.
            pcp_world_size: The world size of prefill context parallelism.

        Returns:
            A list of cached blocks with skipped blocks replaced by null block
            for each kv cache group in `kv_cache_group_ids`.
            Return a list of length `len(kv_cache_group_ids)`, where the i-th
            element is a list of cached blocks for the i-th kv cache group
            in `kv_cache_group_ids`.
            For example, sliding window manager should return a list like
            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
            and sliding window 8 and len(kv_cache_group_ids) = 1.
        """

        raise NotImplementedError

    def remove_skipped_blocks(
        self, request_id: str, total_computed_tokens: int
    ) -> None:
        """
        Remove and free the blocks that are no longer needed for attention computation.
        The removed blocks should be replaced by null_block.

        This function depends on `get_num_skipped_tokens`, which needs to be
        implemented differently for each attention type.

        Args:
            request_id: The request ID.
            total_computed_tokens: The total number of computed tokens, including
                local computed tokens and external computed tokens.
        """
        # Remove the blocks that will be skipped during attention computation.
        num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
        if num_skipped_tokens <= 0:
            # This indicates that ALL tokens are inside attention window.
            # Thus we do not need to free any blocks outside attention window.
            # A typical case is full attention that we never free any token
            # before the request is finished.
            return
        blocks = self.req_to_blocks[request_id]
        num_skipped_blocks = num_skipped_tokens // self.block_size
        # `num_skipped_tokens` may include tokens that haven't been allocated yet
        # (e.g., when the attention window moves into the external computed tokens
        # range), so we must cap to the number of blocks that currently exist for
        # this request.
        num_skipped_blocks = min(num_skipped_blocks, len(blocks))

        # Reuse skipped local blocks in order:
        #   scratch blocks: no prefix-cache value, reuse first.
        #   cached blocks: reusable prefix-cache value, reuse last.
        removed_cached_blocks: list[KVCacheBlock] = []
        removed_uncached_blocks: list[KVCacheBlock] = []
        # Because the block starts from index 0, the num_skipped_block-th block
        # corresponds to index num_skipped_blocks - 1.
        for i in range(num_skipped_blocks - 1, -1, -1):
            # Identity check: padding always uses the pool's single shared
            # null block object.
            if blocks[i] is self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous calls
                # to this function.
                break
            if blocks[i].block_hash is None:
                removed_uncached_blocks.append(blocks[i])
            else:
                removed_cached_blocks.append(blocks[i])
            blocks[i] = self._null_block
        # `prepend=True` makes uncached scratch blocks the next allocation
        # candidates, while cached blocks stay behind them as best-effort
        # prefix-cache entries.
        self.block_pool.free_blocks(removed_cached_blocks)
        self.block_pool.free_blocks(removed_uncached_blocks, prepend=True)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        # The default behavior is to not skip any tokens.
        return 0

    def new_step_starts(self) -> None:
        """Hook called at the start of a scheduling step; a no-op by default."""
        return None

__init__(kv_cache_spec, block_pool, enable_caching, kv_cache_group_id, scheduler_block_size, dcp_world_size=1, pcp_world_size=1, max_admission_blocks_per_request=None)

Initializes the SingleTypeKVCacheManager. Arguments: kv_cache_spec — the KV cache spec for this manager; block_pool — the block pool; kv_cache_group_id — the id of this manager's KV cache group; scheduler_block_size — the scheduling granularity (the LCM of all group block sizes), which is a multiple of this manager's block_size; max_admission_blocks_per_request — a recycling-aware per-request block cap used by get_num_blocks_to_allocate. It is only set for spec types that recycle blocks across chunks (SWA, chunked-local); None (the default) means no cap, which is correct for full-attention-style specs that hold every block until the request finishes.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def __init__(
    self,
    kv_cache_spec: KVCacheSpec,
    block_pool: BlockPool,
    enable_caching: bool,
    kv_cache_group_id: int,
    scheduler_block_size: int,
    dcp_world_size: int = 1,
    pcp_world_size: int = 1,
    max_admission_blocks_per_request: int | None = None,
) -> None:
    """
    Set up the per-group KV cache manager state.

    Args:
        kv_cache_spec: The KV cache spec this manager serves.
        block_pool: Pool that blocks are allocated from and freed to.
        enable_caching: Whether prefix caching is enabled.
        kv_cache_group_id: Id of this manager's KV cache group.
        scheduler_block_size: The scheduling granularity (LCM of all group
            block sizes); a multiple of this manager's ``block_size``.
        dcp_world_size: Decode context parallel world size.
        pcp_world_size: Prefill context parallel world size. When the
            product ``dcp_world_size * pcp_world_size`` exceeds 1, the
            manager's ``block_size`` is the spec's block size times it.
        max_admission_blocks_per_request: Recycling-aware per-request
            block cap used by `get_num_blocks_to_allocate`. Only set for
            spec types that recycle blocks across chunks (SWA,
            chunked-local); `None` (the default) means no cap, which is
            correct for full-attention-style specs that hold every
            block until the request finishes.
    """
    cp_world_size = dcp_world_size * pcp_world_size
    self.scheduler_block_size = scheduler_block_size
    # Allocation block size: the spec's size, scaled by the combined
    # context-parallel world size when that exceeds 1.
    if cp_world_size > 1:
        self.block_size = kv_cache_spec.block_size * cp_world_size
    else:
        self.block_size = kv_cache_spec.block_size
    self.dcp_world_size = dcp_world_size
    self.pcp_world_size = pcp_world_size
    self.kv_cache_spec = kv_cache_spec
    self.block_pool = block_pool
    self.enable_caching = enable_caching
    self._max_admission_blocks_per_request = max_admission_blocks_per_request
    self.new_block_ids: list[int] = []

    # request id -> blocks held by that request, so they can all be
    # released once the request finishes.
    self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)

    # request id -> how many of its blocks are already cached. Only RUNNING
    # requests are tracked here; preempted ones are not.
    self.num_cached_block: dict[str, int] = {}

    self.kv_cache_group_id = kv_cache_group_id
    self._null_block = block_pool.null_block

    # Whether prefix-cache hits for this group drop the EAGLE/MTP lookahead
    # block. Only consulted by managers with sparse hit logic inside an
    # aligned segment (SWA); the coordinator sets it after determining the
    # attention groups.
    self.use_eagle = False

allocate_new_blocks(request_id, num_tokens, num_tokens_main_model)

Allocate new blocks for the request to give it at least num_tokens token slots.

Parameters:

  • request_id

    (str) –

    The request ID.

  • num_tokens

    (int) –

    The total number of tokens that need a slot (including tokens that are already allocated).

  • num_tokens_main_model

    (int) –

    The number of tokens for the main model (aka target model in spec decode). w/o spec decode, it is num_tokens; with spec decode, it is num_tokens - num_lookahead_tokens.

Returns: The newly allocated blocks.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def allocate_new_blocks(
    self, request_id: str, num_tokens: int, num_tokens_main_model: int
) -> list[KVCacheBlock]:
    """
    Grow the request's block list so it covers at least `num_tokens` slots.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot (including
            tokens that are already allocated).
        num_tokens_main_model: The number of tokens for the main model (aka target
            model in spec decode). w/o spec decode, it is num_tokens;
            with spec decode, it is num_tokens - num_lookahead_tokens.
    Returns:
        Only the blocks allocated by this call; empty when the request
        already holds enough blocks.
    """
    blocks = self.req_to_blocks[request_id]
    shortfall = cdiv(num_tokens, self.block_size) - len(blocks)
    if shortfall <= 0:
        return []
    fresh = self.block_pool.get_new_blocks(shortfall)
    blocks.extend(fresh)
    # Exact-type match: only these spec classes record new block IDs.
    tracked_spec_types = (FullAttentionSpec, TQFullAttentionSpec, MLAAttentionSpec)
    if type(self.kv_cache_spec) in tracked_spec_types:
        self.new_block_ids.extend(blk.block_id for blk in fresh)
    return fresh

allocate_new_computed_blocks(request_id, new_computed_blocks, num_local_computed_tokens, num_external_computed_tokens)

Add the new computed blocks to the request. This involves the following steps: (1) touch the computed blocks so that they won't be evicted; (1.5, optional) for sliding window, pad the skipped blocks with null blocks; (2) add the remaining computed blocks; (3, optional) for KV connectors, allocate new blocks for the external computed tokens, if any.

Parameters:

  • request_id

    (str) –

    The request ID.

  • new_computed_blocks

    (Sequence[KVCacheBlock]) –

    The new computed blocks just hitting the prefix cache.

  • num_local_computed_tokens

    (int) –

    The number of local computed tokens.

  • num_external_computed_tokens

    (int) –

    The number of external computed tokens.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def allocate_new_computed_blocks(
    self,
    request_id: str,
    new_computed_blocks: Sequence[KVCacheBlock],
    num_local_computed_tokens: int,
    num_external_computed_tokens: int,
) -> None:
    """
    Add the new computed blocks to the request. This involves three steps:
    1. Touch the computed blocks to make sure they won't be evicted.
    1.5. (Optional) For sliding window, skipped blocks are padded with null
        blocks.
    2. Add the remaining computed blocks.
    3. (Optional) For KV connectors, allocate new blocks for external computed
        tokens (if any).

    Only a new request (one without a `num_cached_block` entry) is
    processed; for a running request this is a no-op.

    Args:
        request_id: The request ID.
        new_computed_blocks: The new computed blocks just hitting the
            prefix cache. Entries covering skipped tokens are dropped
            rather than touched.
        num_local_computed_tokens: The number of local computed tokens.
        num_external_computed_tokens: The number of external computed tokens.

    Raises:
        AssertionError: If a running request passes new computed blocks, if
            a new request already holds blocks, or if (non-skipped) computed
            blocks are passed while prefix caching is disabled.
    """

    if request_id in self.num_cached_block:
        # Fast-path: a running request won't have any new prefix-cache hits.
        # It should not have any new computed blocks.
        assert len(new_computed_blocks) == 0
        return

    # A new request.
    req_blocks = self.req_to_blocks[request_id]
    assert len(req_blocks) == 0
    num_total_computed_tokens = (
        num_local_computed_tokens + num_external_computed_tokens
    )
    num_skipped_tokens = self.get_num_skipped_tokens(num_total_computed_tokens)
    # Only whole blocks lying entirely in the skipped range are dropped.
    num_skipped_blocks = num_skipped_tokens // self.block_size
    if num_skipped_blocks > 0:
        # It is possible that all new computed blocks are skipped when
        # num_skipped_blocks > len(new_computed_blocks).
        new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
        # Some external computed tokens may be skipped too.
        num_external_computed_tokens = min(
            num_total_computed_tokens - num_skipped_tokens,
            num_external_computed_tokens,
        )

    # Touch the computed blocks to make sure they won't be evicted.
    # (Done after slicing, so skipped blocks are never touched.)
    if self.enable_caching:
        self.block_pool.touch(new_computed_blocks)
    else:
        assert not any(new_computed_blocks), (
            "Computed blocks should be empty when prefix caching is disabled"
        )

    # Skip blocks are padded with null blocks.
    req_blocks.extend([self._null_block] * num_skipped_blocks)
    # Add the remaining computed blocks.
    req_blocks.extend(new_computed_blocks)
    # All cached hits (including skipped nulls) are already cached; mark
    # them so cache_blocks() will not try to re-cache blocks that already
    # have a block_hash set.
    # This must run before the external-token blocks below are appended:
    # those are freshly allocated and still need caching later.
    self.num_cached_block[request_id] = len(req_blocks)

    if num_external_computed_tokens > 0:
        # Allocate new blocks for external computed tokens: enough to cover
        # every computed token, minus the blocks the request already holds.
        allocated_blocks = self.block_pool.get_new_blocks(
            cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
        )
        req_blocks.extend(allocated_blocks)
        # Exact type match (subclasses excluded): only these specs record
        # freshly allocated block IDs, drained by `take_new_block_ids`.
        if type(self.kv_cache_spec) in (
            FullAttentionSpec,
            TQFullAttentionSpec,
            MLAAttentionSpec,
        ):
            self.new_block_ids.extend(b.block_id for b in allocated_blocks)

cache_blocks(request, num_tokens, retention_interval=None)

Cache the blocks for the request.

Parameters:

  • request

    (Request) –

    The request.

  • num_tokens

    (int) –

    The total number of tokens that need to be cached (including tokens that are already cached).

  • retention_interval

    (int | None, default: None ) –

    Sparse local-checkpoint granularity. None keeps dense checkpointing; 0 keeps only the latest replay boundary; a positive multiple of scheduler_block_size keeps a tail once per that-sized segment. Only SWA acts on it.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def cache_blocks(
    self,
    request: Request,
    num_tokens: int,
    retention_interval: int | None = None,
) -> None:
    """
    Register the request's newly filled blocks with the prefix cache.

    Only blocks past the previously recorded high-water mark are handed to
    the block pool, filtered by `reachable_block_mask`.

    Args:
        request: The request whose blocks should be cached.
        num_tokens: Total number of tokens that should be backed by cached
            blocks, counting tokens whose blocks were cached earlier.
        retention_interval: Sparse local-checkpoint granularity. ``None``
            keeps dense checkpointing; ``0`` keeps only the latest replay
            boundary; a positive multiple of ``scheduler_block_size`` keeps
            a tail once per that-sized segment. Only SWA acts on it.
    """
    req_id = request.request_id
    already_cached = self.num_cached_block.get(req_id, 0)
    full_blocks = num_tokens // self.block_size

    # No additional block has been completely filled since the last call.
    if full_blocks <= already_cached:
        return

    mask = self.reachable_block_mask(
        start_block=already_cached,
        end_block=full_blocks,
        alignment_tokens=self.scheduler_block_size,
        kv_cache_spec=self.kv_cache_spec,
        use_eagle=self.use_eagle,
        retention_interval=retention_interval,
        num_prompt_tokens=request.num_prompt_tokens,
    )
    self.block_pool.cache_full_blocks(
        request=request,
        blocks=self.req_to_blocks[req_id],
        num_cached_blocks=already_cached,
        num_full_blocks=full_blocks,
        block_size=self.block_size,
        kv_cache_group_id=self.kv_cache_group_id,
        block_mask=mask,
    )

    # Advance the high-water mark so the next call starts after these blocks.
    self.num_cached_block[req_id] = full_blocks

find_longest_cache_hit(block_hashes, max_length, kv_cache_group_ids, block_pool, kv_cache_spec, drop_eagle_block, alignment_tokens, dcp_world_size=1, pcp_world_size=1) abstractmethod classmethod

Get the longest cache hit prefix of the blocks that is not longer than max_length. The prefix should be a common prefix hit for all the kv cache groups in kv_cache_group_ids. If no cache hit is found, return an empty list. If eagle is enabled, drop the last matched block to force recompute the last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type.

Parameters:

  • block_hashes

    (BlockHashList) –

    The block hashes of the request.

  • max_length

    (int) –

    The maximum length of the cache hit prefix.

  • kv_cache_group_ids

    (list[int]) –

    The ids of the kv cache groups.

  • block_pool

    (BlockPool) –

    The block pool.

  • kv_cache_spec

    (KVCacheSpec) –

    The kv cache spec.

  • drop_eagle_block

    (bool) –

    Whether to drop the last matched block for EAGLE/MTP. Always False for non-EAGLE/MTP groups, but can be False for EAGLE/MTP groups too if the last block is already dropped (e.g., in a convergence loop in find_longest_cache_hit).

  • alignment_tokens

    (int) –

    The returned cache hit length (in tokens) should be a multiple of this value (in tokens). By default, it should be set to the block_size.

  • dcp_world_size

    (int, default: 1 ) –

    The world size of decode context parallelism.

  • pcp_world_size

    (int, default: 1 ) –

    The world size of prefill context parallelism.

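Returns:

  • tuple[list[KVCacheBlock], ...]

    One list of cached blocks per kv cache group in kv_cache_group_ids, with skipped blocks replaced by the null block.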

Source code in vllm/v1/core/single_type_kv_cache_manager.py
@classmethod
@abstractmethod
def find_longest_cache_hit(
    cls,
    block_hashes: BlockHashList,
    max_length: int,
    kv_cache_group_ids: list[int],
    block_pool: BlockPool,
    kv_cache_spec: KVCacheSpec,
    drop_eagle_block: bool,
    alignment_tokens: int,
    dcp_world_size: int = 1,
    pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
    """
    Return the longest prefix-cache hit no longer than `max_length`.

    The hit must be a prefix shared by every kv cache group in
    `kv_cache_group_ids`. When nothing matches, the returned lists are
    empty. If eagle is enabled, the last matched block is dropped. That
    forces the block to be recomputed, so the eagle drafting head receives
    the hidden states it needs. Every attention type provides its own
    implementation.

    Args:
        block_hashes: The block hashes of the request.
        max_length: Upper bound (in tokens) on the cache hit prefix.
        kv_cache_group_ids: The ids of the kv cache groups.
        block_pool: The block pool.
        kv_cache_spec: The kv cache spec.
        drop_eagle_block: Whether to drop the last matched block for
            EAGLE/MTP. Always False for non-EAGLE/MTP groups. It can also
            be False for EAGLE/MTP groups when the last block was already
            dropped (e.g., in a convergence loop in `find_longest_cache_hit`).
        alignment_tokens: The hit length (in tokens) must be a multiple of
            this value. It normally equals the block_size.
        dcp_world_size: The world size of decode context parallelism.
        pcp_world_size: The world size of prefill context parallelism.

    Returns:
        A tuple with one entry per group in `kv_cache_group_ids`. The i-th
        entry lists the cached blocks for the i-th group, with skipped
        blocks replaced by the null block. For example, a sliding window
        manager with block size 4, sliding window 8 and a single group
        would return ``([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)],)``.
    """
    raise NotImplementedError()

free(request_id)

Free the blocks for the request.

Parameters:

  • request_id

    (str) –

    The request ID.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def free(self, request_id: str) -> None:
    """
    Return every block owned by the request to the block pool.

    Args:
        request_id: The request ID.
    """
    # A request may be aborted before anything was allocated for it, so a
    # missing entry is treated as owning no blocks.
    owned = self.req_to_blocks.pop(request_id, [])
    # Hand the blocks back tail-first: later blocks are released before
    # the earlier (prefix) ones.
    self.block_pool.free_blocks(reversed(owned))
    self.num_cached_block.pop(request_id, None)

get_num_blocks_to_allocate(request_id, num_tokens, new_computed_blocks, total_computed_tokens, num_tokens_main_model, apply_admission_cap=False)

Get the number of blocks needed to be allocated for the request.

Parameters:

  • request_id

    (str) –

    The request ID.

  • num_tokens

    (int) –

    The total number of tokens that need a slot (including tokens that are already allocated).

  • new_computed_blocks

    (Sequence[KVCacheBlock]) –

    The new computed blocks just hitting the prefix caching.

  • total_computed_tokens

    (int) –

    Include both local and external computed tokens.

  • num_tokens_main_model

    (int) –

    The number of tokens for the main model (aka target model in spec decode). w/o spec decode, it is num_tokens; with spec decode, it is num_tokens - num_lookahead_tokens.

  • apply_admission_cap

    (bool, default: False ) –

    If True, clamp num_required_blocks to _max_admission_blocks_per_request for recycling-aware specs (SWA, chunked-local).

Returns:

  • int

    The number of blocks to allocate.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_blocks_to_allocate(
    self,
    request_id: str,
    num_tokens: int,
    new_computed_blocks: Sequence[KVCacheBlock],
    total_computed_tokens: int,
    num_tokens_main_model: int,
    apply_admission_cap: bool = False,
) -> int:
    """
    Compute how many blocks the request must take from the block pool now.

    The count has two parts. The first is brand-new blocks for slots that
    no block covers yet. The second is prefix-cache hits that are currently
    eviction candidates, since touching those removes them from the free
    queue.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot, including
            tokens that already have one.
        new_computed_blocks: The new computed blocks just hitting the
            prefix cache.
        total_computed_tokens: Computed tokens, local and external combined.
        num_tokens_main_model: The number of tokens for the main (aka target)
            model in spec decode. Without spec decode it equals `num_tokens`;
            with spec decode it is `num_tokens - num_lookahead_tokens`. Not
            read by this base implementation.
        apply_admission_cap: If True, clamp `num_required_blocks` to
            `_max_admission_blocks_per_request` for recycling-aware specs
            (SWA, chunked-local).

    Returns:
        The number of blocks to allocate.
    """
    block_size = self.block_size
    required = cdiv(num_tokens, block_size)
    cap = self._max_admission_blocks_per_request if apply_admission_cap else None
    if cap is not None:
        # Recycling-aware specs (SWA, chunked-local) cap the per-request
        # reservation so admission matches the startup pool sizer
        # (`SlidingWindowSpec.max_admission_blocks_per_request` / its
        # chunked-local counterpart). `allocate_slots` runs
        # `remove_skipped_blocks` before each chunk's
        # `get_num_blocks_to_allocate`. So the peak number of real blocks a
        # request holds stays <= this cap, and
        # `sum(reservations) <= pool` <=> `sum(peak_real_held) <= pool`.
        # If the two drift apart, the deadlock from issue #39734 can return,
        # or worse, a mid-prefill OOM.
        required = min(required, cap)
    held = len(self.req_to_blocks.get(request_id, ()))

    if request_id in self.num_cached_block:
        # A running request never receives fresh prefix-cache hits.
        assert len(new_computed_blocks) == 0
        # With speculative decoding, blocks reserved for draft tokens that
        # were later rejected can leave `held` above `required`, so clamp
        # the result at zero.
        return max(required - held, 0)

    # Only whole blocks that lie outside the attention window count as
    # skipped; zero when nothing is skipped.
    skipped = self.get_num_skipped_tokens(total_computed_tokens) // block_size
    # Slots already accounted for, either by skipped (null) blocks or by
    # local-computed blocks. Whichever reaches further wins.
    covered = max(skipped, len(new_computed_blocks) + held)
    fresh = max(required - covered, 0)

    # Up to `held` of the skipped positions may already be in
    # `req_to_blocks`; the remainder of the skipped prefix is dropped from
    # `new_computed_blocks`.
    first_kept = max(0, skipped - held)
    # A touched eviction candidate (in the free queue with ref_cnt == 0)
    # leaves the free queue, so it also counts against free capacity.
    evictable = self._get_num_evictable_blocks(new_computed_blocks[first_kept:])
    return fresh + evictable

get_num_common_prefix_blocks(running_request_id) abstractmethod

Get the number of common prefix blocks for all requests with allocated KV cache.

Parameters:

  • running_request_id

    (str) –

    The request ID.

Returns:

  • int

    The number of common prefix blocks for all requests with allocated KV cache.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
@abstractmethod
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
    """
    Count the leading blocks shared by all requests that hold KV cache.

    Subclasses must implement this; the base version only raises.

    Args:
        running_request_id: The ID of a running request.

    Returns:
        How many prefix blocks are common to every request with allocated
        KV cache.
    """
    raise NotImplementedError()

get_num_skipped_tokens(num_computed_tokens)

Get the number of tokens that will be skipped for attention computation.

Parameters:

  • num_computed_tokens

    (int) –

    The number of tokens that have been computed.

Returns:

  • int

    The number of tokens that will be skipped for attention computation.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Report how many leading tokens attention no longer needs to look at.

    The base manager keeps every token inside the attention window, so
    nothing is ever skipped regardless of progress.

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention computation
        (always 0 here).
    """
    no_skipped_tokens = 0
    return no_skipped_tokens

reachable_block_mask(start_block, end_block, alignment_tokens, kv_cache_spec, use_eagle, retention_interval=None, num_prompt_tokens=None) classmethod

Per-block mask for cache_full_blocks. None means cache every (non-null) block — the default for full attention.

Subclasses with sparse hit semantics (SWA) override this to skip blocks that can never serve a hit at any alignment-aligned prefix length.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
@classmethod
def reachable_block_mask(
    cls,
    start_block: int,
    end_block: int,
    alignment_tokens: int | None,
    kv_cache_spec: KVCacheSpec,
    use_eagle: bool,
    retention_interval: int | None = None,
    num_prompt_tokens: int | None = None,
) -> list[bool] | None:
    """Select which blocks in ``[start_block, end_block)`` are worth caching.

    This base version returns ``None``, which tells ``cache_full_blocks``
    to cache every (non-null) block, as full attention requires. Managers
    with sparse hit semantics (SWA) override it. Their overrides leave out
    blocks that can never serve a hit at any alignment-aligned prefix
    length.
    """
    # No attention-specific pruning at this level: cache everything.
    cache_everything = None
    return cache_everything

remove_skipped_blocks(request_id, total_computed_tokens)

Remove and free the blocks that are no longer needed for attention computation. The removed blocks should be replaced by null_block.

This function depends on get_num_skipped_tokens, which needs to be implemented differently for each attention type.

Parameters:

  • request_id

    (str) –

    The request ID.

  • total_computed_tokens

    (int) –

    The total number of computed tokens, including local computed tokens and external computed tokens.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def remove_skipped_blocks(
    self, request_id: str, total_computed_tokens: int
) -> None:
    """
    Remove and free the blocks that are no longer needed for attention computation.
    The removed blocks should be replaced by null_block.

    This function depends on `get_num_skipped_tokens`, which needs to be
    implemented differently for each attention type.

    Uncached (scratch) blocks are freed with `prepend=True`; blocks that
    already carry a `block_hash` are freed normally so they can still serve
    prefix-cache hits.

    Args:
        request_id: The request ID.
        total_computed_tokens: The total number of computed tokens, including
            local computed tokens and external computed tokens.
    """
    # Remove the blocks that will be skipped during attention computation.
    num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
    if num_skipped_tokens <= 0:
        # This indicates that ALL tokens are inside attention window.
        # Thus we do not need to free any blocks outside attention window.
        # A typical case is full attention that we never free any token
        # before the request is finished.
        return
    blocks = self.req_to_blocks[request_id]
    # Only whole blocks that lie entirely in the skipped range are removed.
    num_skipped_blocks = num_skipped_tokens // self.block_size
    # `num_skipped_tokens` may include tokens that haven't been allocated yet
    # (e.g., when the attention window moves into the external computed tokens
    # range), so we must cap to the number of blocks that currently exist for
    # this request.
    num_skipped_blocks = min(num_skipped_blocks, len(blocks))

    # Reuse skipped local blocks in order:
    #   scratch blocks: no prefix-cache value, reuse first.
    #   cached blocks: reusable prefix-cache value, reuse last.
    removed_cached_blocks: list[KVCacheBlock] = []
    removed_uncached_blocks: list[KVCacheBlock] = []
    # Because the block starts from index 0, the num_skipped_block-th block
    # corresponds to index num_skipped_blocks - 1.
    # Walk backwards so the scan can stop at the first null block, which
    # marks where a previous call already stopped.
    for i in range(num_skipped_blocks - 1, -1, -1):
        if blocks[i] == self._null_block:
            # If the block is already a null block, the blocks before it
            # should also have been set to null blocks by the previous calls
            # to this function.
            break
        if blocks[i].block_hash is None:
            removed_uncached_blocks.append(blocks[i])
        else:
            removed_cached_blocks.append(blocks[i])
        blocks[i] = self._null_block
    # `prepend=True` makes uncached scratch blocks the next allocation
    # candidates, while cached blocks stay behind them as best-effort
    # prefix-cache entries.
    self.block_pool.free_blocks(removed_cached_blocks)
    self.block_pool.free_blocks(removed_uncached_blocks, prepend=True)

take_new_block_ids()

Drain and return block IDs allocated since the last call.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def take_new_block_ids(self) -> list[int]:
    """Hand over the block IDs recorded since the previous call.

    The accumulated list itself is returned, and the instance starts
    collecting into a fresh empty list.
    """
    drained, self.new_block_ids = self.new_block_ids, []
    return drained

SlidingWindowManager

Bases: SingleTypeKVCacheManager

Methods:

Source code in vllm/v1/core/single_type_kv_cache_manager.py
class SlidingWindowManager(SingleTypeKVCacheManager):
    """KV cache manager for sliding-window attention layers.

    Tokens before the current window (see `get_num_skipped_tokens`) are not
    needed for attention. The blocks holding them are represented by the
    pool's null block, both in prefix-cache hits and, via the base class's
    `remove_skipped_blocks`, in the request's block list.
    """

    def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None:
        super().__init__(kv_cache_spec, **kwargs)
        self.sliding_window = kv_cache_spec.sliding_window

    @classmethod
    def _contiguous_blocks_for_hit(
        cls, window_size: int, block_size: int, use_eagle: bool
    ) -> int:
        """Return how many consecutive cached blocks a hit needs.

        Besides itself, the next token attends to the previous
        `window_size - 1` tokens, which span
        `cdiv(window_size - 1, block_size)` blocks. One extra block is
        required when `use_eagle` is set.
        """
        blocks = cdiv(window_size - 1, block_size)
        if use_eagle:
            # Need to drop the last matched block if eagle is enabled. For
            # sliding window layer, we achieve this by increasing the number of
            # contiguous blocks needed for prefix cache hit by one and dropping
            # the last matched block.
            blocks += 1
        return blocks

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        drop_eagle_block: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Find the longest sliding-window cache hit no longer than `max_length`.

        The block hashes are scanned from right to left for a run of
        `_contiguous_blocks_for_hit(...)` consecutive cached blocks. A run
        may only begin (at its rightmost block) where the resulting hit
        length is a multiple of `alignment_tokens`. Positions before the run
        are filled with `block_pool.null_block`, and positions after it are
        trimmed. If no full-length run exists, only the run of hits starting
        at block 0 (possibly empty) is kept and re-aligned. When
        `drop_eagle_block` is set, the last matched block is then dropped
        and the result re-aligned again.

        Returns:
            A tuple with one list of blocks per group in `kv_cache_group_ids`.
            All lists have the same length.
        """
        assert isinstance(kv_cache_spec, SlidingWindowSpec), (
            "SlidingWindowManager can only be used for sliding window groups"
        )
        assert dcp_world_size == 1, "DCP not support sliding window attn now."
        assert pcp_world_size == 1, "PCP not support sliding window attn now."

        # The number of contiguous blocks needed for a prefix cache hit.
        sliding_window_contiguous_blocks = cls._contiguous_blocks_for_hit(
            kv_cache_spec.sliding_window, kv_cache_spec.block_size, drop_eagle_block
        )

        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(max_num_blocks) to
        # O(max_num_blocks / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        max_num_blocks = max_length // kv_cache_spec.block_size
        computed_blocks = tuple(
            [block_pool.null_block] * max_num_blocks
            for _ in range(len(kv_cache_group_ids))
        )
        block_size = kv_cache_spec.block_size
        num_contiguous_blocks = 0
        match_found = False
        # Search from right to left and early stop when a match is found.
        for i in range(max_num_blocks - 1, -1, -1):
            if cached_block := block_pool.get_cached_block(
                block_hashes[i], kv_cache_group_ids
            ):
                # Skip prefix matching check if the block is not aligned with
                # `alignment_tokens`.
                if num_contiguous_blocks == 0 and block_size != alignment_tokens:
                    # Hit length (in blocks) after the optional eagle pop.
                    post_pop_blocks = i if drop_eagle_block else i + 1
                    if (post_pop_blocks * block_size) % alignment_tokens != 0:
                        continue
                # Add the cached block to the computed blocks.
                for computed, cached in zip(computed_blocks, cached_block):
                    computed[i] = cached
                num_contiguous_blocks += 1
                if num_contiguous_blocks >= sliding_window_contiguous_blocks:
                    # Trim the trailing blocks.
                    # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                    # when sliding_window_contiguous_blocks=2.
                    for computed in computed_blocks:
                        del computed[i + num_contiguous_blocks :]
                    match_found = True
                    break
            else:
                num_contiguous_blocks = 0
        if not match_found:
            # The first `num_contiguous_blocks` is a cache hit even if
            # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
            # (The loop ran down to i == 0, so the current run, if any,
            # starts at block 0.)
            for computed in computed_blocks:
                del computed[num_contiguous_blocks:]
            while (
                block_size != alignment_tokens  # Faster for common case.
                and len(computed_blocks[0]) * block_size % alignment_tokens != 0
            ):
                for computed in computed_blocks:
                    computed.pop()
        if drop_eagle_block and computed_blocks[0]:
            for computed in computed_blocks:
                computed.pop()
            # Re-align after eagle pop: the pop may break the alignment
            # when block_size != alignment_tokens (hybrid models with
            # different page sizes, e.g. Gemma4).
            while (
                block_size != alignment_tokens
                and len(computed_blocks[0]) * block_size % alignment_tokens != 0
            ):
                for computed in computed_blocks:
                    computed.pop()
        return computed_blocks

    @classmethod
    def reachable_block_mask(
        cls,
        start_block: int,
        end_block: int,
        alignment_tokens: int | None,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        retention_interval: int | None = None,
        num_prompt_tokens: int | None = None,
    ) -> list[bool] | None:
        """Mark which blocks in ``[start_block, end_block)`` should be cached.

        A sliding-window hit at an aligned boundary only uses the `need`
        blocks ending at that boundary (see `_contiguous_blocks_for_hit`).
        The other blocks can never serve a hit and are left uncached.

        Returns:
            ``None`` (cache every block) when no alignment is imposed or when
            every block is reachable anyway. Otherwise, a list of booleans
            where entry ``k`` refers to block ``start_block + k``.
        """
        assert isinstance(kv_cache_spec, SlidingWindowSpec)
        if alignment_tokens is None:
            # Fast path: when the coordinator imposes no alignment constraint.
            return None
        assert alignment_tokens % kv_cache_spec.block_size == 0

        block_size = kv_cache_spec.block_size
        # Contiguous blocks a hit needs at a boundary (incl. the EAGLE peek).
        need = cls._contiguous_blocks_for_hit(
            window_size=kv_cache_spec.sliding_window,
            block_size=block_size,
            use_eagle=use_eagle,
        )
        # The matched run's right edge sits on the aligned boundary block when
        # EAGLE peeks one block past it (shift=1), otherwise on the last block
        # before the boundary (shift=0).
        shift = 1 if use_eagle else 0

        mask = [False] * (end_block - start_block)

        # (1) Segment-boundary tails. ``retention_interval``:
        #   None -> dense (a tail at every ``alignment_tokens`` boundary);
        #   0    -> no dense tails (only the replay boundary below);
        #   >0   -> a tail once per ``retention_interval``-sized segment.
        segment_tokens = (
            alignment_tokens
            if retention_interval is None
            else (None if retention_interval == 0 else retention_interval)
        )
        if segment_tokens is not None:
            per_segment = segment_tokens // block_size
            if need >= per_segment:
                # Every block is reachable; cache them all.
                return None
            for i in range(start_block, end_block):
                # Keep block i if it is among the last `need` blocks of its
                # (shifted) segment.
                if i >= shift and (i - shift) % per_segment >= per_segment - need:
                    mask[i - start_block] = True

        # (2) Replay-boundary tail. ``get_computed_blocks`` caps hits at
        # ``num_prompt - 1`` (to recompute the last token's logits), so an exact
        # prompt replay can only land on the latest *fine*-aligned boundary.
        # Sparse retention would otherwise skip it, so keep its tail explicitly.
        if retention_interval is not None and num_prompt_tokens is not None:
            latest = (num_prompt_tokens - 1) // alignment_tokens * alignment_tokens
            prompt_end_block = latest // block_size + shift
            for i in range(
                max(start_block, prompt_end_block - need),
                min(end_block, prompt_end_block),
            ):
                mask[i - start_block] = True

        return mask

    def free(self, request_id: str) -> None:
        """Release all blocks of the request, keeping cached ones reusable.

        Blocks without a `block_hash` are freed with `prepend=True`. Blocks
        that carry a hash are freed normally, so they can remain available
        as prefix-cache entries.
        """
        # similar to remove_skipped_blocks(), prepend the uncached blocks
        # and append the cached blocks to the free queue
        req_blocks = self.req_to_blocks.pop(request_id, [])
        if req_blocks:
            cached_blocks: list[KVCacheBlock] = []
            uncached_blocks: list[KVCacheBlock] = []
            for block in reversed(req_blocks):
                if block.block_hash is None:
                    uncached_blocks.append(block)
                else:
                    cached_blocks.append(block)
            self.block_pool.free_blocks(cached_blocks)
            self.block_pool.free_blocks(uncached_blocks, prepend=True)
        self.num_cached_block.pop(request_id, None)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        For sliding window, this corresponds to the tokens that are prior to
        the current sliding window.

        Example:
        sliding_window=4, num_computed_tokens=7

        Tokens:   [ 0  1  2  3  4  5  6  7 ]
                  | ---- computed -----|
                                         ^ next token to be computed
                               |-----------| sliding window for next token
                  |--skipped---|

        The current window contains tokens 4~7. Tokens 0~3 will be skipped for
        attention computation since they are outside the sliding window.
        Thus, get_num_skipped_tokens(7) == 4.

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        return max(0, num_computed_tokens - self.sliding_window + 1)

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
        So it's not correct to count ref_cnt like FullAttentionManager. Return
        0 here for correctness. Need to support cascade attention + sliding
        window in the future.
        """
        return 0

get_num_common_prefix_blocks(running_request_id)

NOTE(Chen): The prefix blocks are null blocks for sliding window layers. So it's not correct to count ref_cnt like FullAttentionManager. Return 0 here for correctness. Need to support cascade attention + sliding window in the future.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
    """
    Report the number of leading blocks shared by running requests.

    NOTE(Chen): for sliding window layers the prefix blocks are null
    blocks, so counting ref_cnt the way FullAttentionManager does would
    be incorrect. Zero is always reported for correctness; cascade
    attention combined with sliding window still needs to be supported.

    Args:
        running_request_id: Ignored by this implementation.

    Returns:
        Always 0.
    """
    del running_request_id  # intentionally unused
    return 0

get_num_skipped_tokens(num_computed_tokens)

Get the number of tokens that will be skipped for attention computation.

For sliding window, this corresponds to the tokens that are prior to the current sliding window.

Example: sliding_window=4, num_computed_tokens=7

Tokens: [ 0 1 2 3 4 5 6 7 ]

Computed: tokens 0–6. Next token to be computed: token 7. Sliding window for the next token: tokens 4–7. Skipped: tokens 0–3.

The current window contains tokens 4~7. Tokens 0~3 will be skipped for attention computation since they are outside the sliding window. Thus, get_num_skipped_tokens(7) == 4.

Parameters:

  • num_computed_tokens

    (int) –

    The number of tokens that have been computed.

Returns:

  • int

    The number of tokens that will be skipped for attention computation.

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """
    Count the leading tokens that fall outside the sliding window.

    The next token to be computed attends to itself plus the previous
    ``sliding_window - 1`` tokens. Every token before that range is
    skipped during attention computation.

    Example:
    sliding_window=4, num_computed_tokens=7

    Tokens:   [ 0  1  2  3  4  5  6  7 ]
              | ---- computed -----|
                                     ^ next token to be computed
                           |-----------| sliding window for next token
              |--skipped---|

    Tokens 4~7 form the window for the next token, so tokens 0~3 are
    skipped and get_num_skipped_tokens(7) == 4.

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention
        computation; never negative.
    """
    skipped = num_computed_tokens - self.sliding_window + 1
    # Clamp at zero while the window still covers the whole prefix.
    return skipped if skipped > 0 else 0

get_manager_for_kv_cache_spec(kv_cache_spec, max_num_batched_tokens, max_model_len, **kwargs)

Get the appropriate manager for a given KVCacheSpec.

Uses the KVCacheSpecRegistry to look up the manager class, supporting both built-in and custom specs registered via @register_kv_cache_spec and KVCacheSpecRegistry.register.

Parameters:

  • kv_cache_spec

    (KVCacheSpec) –

    The KVCacheSpec instance

  • max_num_batched_tokens

    (int) –

    The maximum number of tokens in a batch

  • max_model_len

    (int) –

    The maximum context length the model could serve

Returns: An instance of the appropriate SingleTypeKVCacheManager subclass

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def get_manager_for_kv_cache_spec(
    kv_cache_spec: KVCacheSpec,
    max_num_batched_tokens: int,
    max_model_len: int,
    **kwargs,
) -> SingleTypeKVCacheManager:
    """
    Get the appropriate manager for a given KVCacheSpec.

    Uses the KVCacheSpecRegistry to look up the manager class, supporting
    both built-in and custom specs registered via @register_kv_cache_spec
    and KVCacheSpecRegistry.register.

    Args:
        kv_cache_spec: The KVCacheSpec instance
        max_num_batched_tokens: The maximum number of tokens in a batch
        max_model_len: The maximum context length the model could serve
        **kwargs: Extra keyword arguments forwarded to the manager class's
            constructor. For SlidingWindowSpec / ChunkedLocalAttentionSpec,
            ``max_admission_blocks_per_request`` is computed here and
            overwrites any caller-supplied value.

    Returns:
        An instance of the appropriate SingleTypeKVCacheManager subclass

    Raises:
        ValueError: If no manager class is registered for the spec's type.
    """
    manager_class = KVCacheSpecRegistry.get_manager_class(kv_cache_spec)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would turn a missing registration into an opaque
    # "'NoneType' object is not callable" error at construction time.
    if manager_class is None:
        raise ValueError(
            f"No manager registered for KVCacheSpec {type(kv_cache_spec)}"
        )
    # SlidingWindow / ChunkedLocalAttention managers recycle blocks across
    # chunks; the runtime admission cap must match the recycling-aware bound
    # the startup pool sizer uses (single source of truth: the spec method).
    if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)):
        kwargs["max_admission_blocks_per_request"] = (
            kv_cache_spec.max_admission_blocks_per_request(
                max_num_batched_tokens=max_num_batched_tokens,
                max_model_len=max_model_len,
            )
        )
    return manager_class(kv_cache_spec, **kwargs)

register_all_kvcache_specs(vllm_config)

Built-in spec registration

Source code in vllm/v1/core/single_type_kv_cache_manager.py
def register_all_kvcache_specs(vllm_config):
    """
    Register every built-in KVCacheSpec with its manager class, then let the
    current platform register any platform-specific specs.

    Each entry of the table below is
    ``(spec class, manager class, uniform_type_base_spec)``. The entries are
    passed to KVCacheSpecRegistry.register in the listed order.

    Args:
        vllm_config: Forwarded unchanged to
            ``current_platform.register_custom_kv_cache_specs``.
    """
    builtin_specs = (
        (FullAttentionSpec, FullAttentionManager, FullAttentionSpec),
        (SlidingWindowSpec, SlidingWindowManager, SlidingWindowSpec),
        (SlidingWindowMLASpec, SlidingWindowManager, SlidingWindowMLASpec),
        (MambaSpec, MambaManager, MambaSpec),
        (
            ChunkedLocalAttentionSpec,
            ChunkedLocalAttentionManager,
            ChunkedLocalAttentionSpec,
        ),
        (CrossAttentionSpec, CrossAttentionManager, CrossAttentionSpec),
        # FullAttentionSpec subclasses — grouped with FullAttentionSpec
        (TQFullAttentionSpec, FullAttentionManager, FullAttentionSpec),
        (MLAAttentionSpec, FullAttentionManager, FullAttentionSpec),
        # NOTE(Mengqing): HiddenStateCacheSpec won't take part in grouping,
        # thus the uniform_type_base_spec is just a placeholder.
        (HiddenStateCacheSpec, FullAttentionManager, FullAttentionSpec),
        (SinkFullAttentionSpec, SinkFullAttentionManager, FullAttentionSpec),
    )
    for spec_cls, manager_cls, base_spec_cls in builtin_specs:
        KVCacheSpecRegistry.register(
            spec_cls,
            manager_cls,
            uniform_type_base_spec=base_spec_cls,
        )

    # Function-scope import; presumably deferred to avoid an import cycle at
    # module load time — verify before hoisting.
    from vllm.platforms import current_platform

    current_platform.register_custom_kv_cache_specs(vllm_config)