Skip to content

vllm.v1.attention.backends.flash_attn

Attention layer with FlashAttention.

logger module-attribute

logger = init_logger(__name__)

FlashAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/flash_attn.py
class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for FlashAttention.

    Exposes the implementation, metadata and metadata-builder classes that
    belong to this backend, together with the head sizes it accepts and the
    layout of its KV cache.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        backend_name = "FLASH_ATTN_VLLM_V1"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the class implementing the attention forward pass."""
        impl_cls = FlashAttentionImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        metadata_cls = FlashAttentionMetadata
        return metadata_cls

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        builder_cls = FlashAttentionMetadataBuilder
        return builder_cls

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the accepted head sizes: every multiple of 32 up to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the shape of the KV cache tensor.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        # The leading axis of size 2 is what FlashAttentionImpl.forward
        # unbinds into the key cache and the value cache.
        kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[FlashAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/flash_attn.py
@staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
    """Return the class that builds the attention metadata."""
    builder_cls = FlashAttentionMetadataBuilder
    return builder_cls

get_impl_cls staticmethod

get_impl_cls() -> type[FlashAttentionImpl]
Source code in vllm/v1/attention/backends/flash_attn.py
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
    """Return the class implementing the attention forward pass."""
    impl_cls = FlashAttentionImpl
    return impl_cls

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flash_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    """Return the shape of the KV cache tensor.

    Raises:
        ValueError: If ``block_size`` is not a multiple of 16.
    """
    if block_size % 16:
        raise ValueError("Block size must be a multiple of 16.")
    # The leading axis of size 2 is what FlashAttentionImpl.forward unbinds
    # into the key cache and the value cache.
    kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    return kv_cache_shape

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/flash_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    """Return the per-step attention metadata class."""
    metadata_cls = FlashAttentionMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/flash_attn.py
@staticmethod
def get_name() -> str:
    """Return the identifier string of this backend."""
    backend_name = "FLASH_ATTN_VLLM_V1"
    return backend_name

get_supported_head_sizes staticmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/flash_attn.py
@staticmethod
def get_supported_head_sizes() -> list[int]:
    """Return the accepted head sizes: every multiple of 32 up to 256."""
    return list(range(32, 257, 32))

FlashAttentionImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/flash_attn.py
class FlashAttentionImpl(AttentionImpl):
    """Decoder attention implemented on top of FlashAttention kernels.

    ``__init__`` validates and stores the per-layer configuration;
    ``forward`` writes the new keys/values into the paged KV cache and then
    runs either a regular varlen attention call or cascade attention.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Validate and store the attention configuration.

        Args:
            num_heads: Number of query heads; must be divisible by
                ``num_kv_heads``.
            head_size: Size of each head; must be one of
                ``FlashAttentionBackend.get_supported_head_sizes()``.
            scale: Softmax scale, stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional slopes, converted to a float32 tensor.
            sliding_window: Optional window size; stored as the tuple
                ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when ``None``.
            kv_cache_dtype: KV cache dtype string.
            blocksparse_params: Must be ``None``; block-sparse attention is
                rejected.
            logits_soft_cap: Optional soft cap; ``None`` is stored as 0.
            attn_type: Must be ``AttentionType.DECODER``.
            use_irope: When true, ``forward`` uses the local attention
                metadata if it is present.

        Raises:
            ValueError: If ``blocksparse_params`` is given or ``head_size``
                is not supported.
            NotImplementedError: If ``attn_type`` is not DECODER, or the KV
                cache is quantized and ``flash_attn_supports_fp8()`` is
                false.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # Stored as a 2-tuple because it is passed as `window_size` to
        # flash_attn_varlen_func in forward().
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Attention layer; its `_q_scale`, `_k_scale` and
                `_v_scale` attributes are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. `None` during a
                profiling run, in which case `output` is returned untouched.
            output: Pre-allocated output tensor; required. It is filled in
                place up to `num_actual_tokens` and returned.

        Returns:
            shape = [num_tokens, num_heads * head_size]

        NOTE: For FP8 quantization, flash-attn expects the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret both caches as float8_e4m3fn and quantize the query
            # with the layer's q scale so q, k and v share the fp8 dtype.
            key_cache = key_cache.view(torch.float8_e4m3fn)
            value_cache = value_cache.view(torch.float8_e4m3fn)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        # Local attention takes precedence over cascade attention.
        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
                scheduler_metadata = local_metadata.local_scheduler_metadata
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table
                scheduler_metadata = attn_metadata.scheduler_metadata

            # (num_sequences, num_kv_heads); see the NOTE in the docstring.
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
            return output

        # NOTE(review): this point is only reached when use_local_attn is
        # False (the branch above returns otherwise), so the assert documents
        # the invariant rather than guarding against it.
        assert not use_local_attn, (
            "Cascade attention does not support local attention.")
        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = (-1, -1)

use_irope instance-attribute

use_irope = use_irope

vllm_flash_attn_version instance-attribute

vllm_flash_attn_version = get_flash_attn_version()

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = DECODER,
    use_irope: bool = False,
) -> None
Source code in vllm/v1/attention/backends/flash_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = AttentionType.DECODER,
    use_irope: bool = False,
) -> None:
    """Validate and store the attention configuration.

    Raises:
        ValueError: If ``blocksparse_params`` is given or ``head_size`` is
            not one of the supported head sizes.
        NotImplementedError: If ``attn_type`` is not DECODER, or the KV
            cache is quantized and fp8 is not supported by flash-attn.
    """
    if blocksparse_params is not None:
        raise ValueError(
            "FlashAttention does not support block-sparse attention.")

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    # Slopes, when given, are kept as a float32 tensor.
    self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.float32))
    # (-1, -1) is stored when no sliding window is configured.
    self.sliding_window = ((-1, -1) if sliding_window is None else
                           (sliding_window - 1, 0))
    self.kv_cache_dtype = kv_cache_dtype
    # In flash-attn, a logits soft cap of 0 means no soft cap.
    self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
    head_size_ok = head_size in support_head_sizes
    if not head_size_ok:
        raise ValueError(
            f"Head size {head_size} is not supported by FlashAttention. "
            f"Supported head sizes are: {support_head_sizes}. "
            "Set VLLM_USE_V1=0 to use another attention backend.")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashAttentionImpl")

    self.use_irope = use_irope
    self.vllm_flash_attn_version = get_flash_attn_version()

    # The fp8 capability probe only runs for a quantized KV cache.
    fp8_unsupported = (is_quantized_kv_cache(self.kv_cache_dtype)
                       and not flash_attn_supports_fp8())
    if fp8_unsupported:
        raise NotImplementedError(
            "FlashAttention does not support fp8 kv-cache on this device.")

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashAttentionMetadata,
    output: Optional[Tensor] = None,
) -> Tensor

Forward pass with FlashAttention.

Parameters:

Name Type Description Default
query Tensor

shape = [num_tokens, num_heads, head_size]

required
key Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
value Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
attn_metadata FlashAttentionMetadata

Metadata for attention.

required

Returns: shape = [num_tokens, num_heads * head_size]. NOTE: For FP8 quantization, flash-attn expects the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values.

Source code in vllm/v1/attention/backends/flash_attn.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with FlashAttention.

    Args:
        layer: Attention layer; its `_q_scale`, `_k_scale` and `_v_scale`
            attributes are read here.
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            [2, num_blocks, block_size, num_kv_heads, head_size]
        attn_metadata: Metadata for attention. `None` during a profiling
            run, in which case `output` is returned untouched.
        output: Pre-allocated output tensor; required. It is filled in
            place up to `num_actual_tokens` and returned.

    Returns:
        shape = [num_tokens, num_heads * head_size]

    NOTE: For FP8 quantization, flash-attn expects the size of
          {q,k,v}_descale to be (num_sequences, num_kv_heads).
          We use torch's .expand() to avoid duplicating values.
    """
    assert output is not None, "Output tensor must be provided."

    if attn_metadata is None:
        # Profiling run.
        return output

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens
    # Reshape the input keys and values and store them in the cache.
    # NOTE(woosuk): Here, key and value are padded while slot_mapping is
    # not padded. However, we don't need to do key[:num_actual_tokens] and
    # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
    # the slot_mapping's shape to determine the number of actual tokens.
    key_cache, value_cache = kv_cache.unbind(0)
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    if self.kv_cache_dtype.startswith("fp8"):
        # Reinterpret both caches as float8_e4m3fn and quantize the query
        # with the layer's q scale so q, k and v share the fp8 dtype.
        key_cache = key_cache.view(torch.float8_e4m3fn)
        value_cache = value_cache.view(torch.float8_e4m3fn)
        num_tokens, num_heads, head_size = query.shape
        query, _ = ops.scaled_fp8_quant(
            query.reshape(
                (num_tokens, num_heads * head_size)).contiguous(),
            layer._q_scale)
        query = query.reshape((num_tokens, num_heads, head_size))

    # Compute attention and update output up to `num_actual_tokens`.
    use_local_attn = \
        (self.use_irope and attn_metadata.local_attn_metadata is not None)

    # Local attention takes precedence over cascade attention.
    if not attn_metadata.use_cascade or use_local_attn:
        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
            scheduler_metadata = local_metadata.local_scheduler_metadata
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

        # (num_sequences, num_kv_heads); see the NOTE in the docstring.
        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        flash_attn_varlen_func(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            scheduler_metadata=scheduler_metadata,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )
        return output

    # NOTE(review): this point is only reached when use_local_attn is False
    # (the branch above returns otherwise), so the assert documents the
    # invariant rather than guarding against it.
    assert not use_local_attn, (
        "Cascade attention does not support local attention.")
    # Cascade attention (rare case).
    cascade_attention(
        output[:num_actual_tokens],
        query[:num_actual_tokens],
        key_cache,
        value_cache,
        cu_query_lens=attn_metadata.query_start_loc,
        max_query_len=attn_metadata.max_query_len,
        cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
        prefix_kv_lens=attn_metadata.prefix_kv_lens,
        suffix_kv_lens=attn_metadata.suffix_kv_lens,
        max_kv_len=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        alibi_slopes=self.alibi_slopes,
        sliding_window=self.sliding_window,
        logits_soft_cap=self.logits_soft_cap,
        block_table=attn_metadata.block_table,
        common_prefix_len=attn_metadata.common_prefix_len,
        fa_version=self.vllm_flash_attn_version,
        prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
        suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
        q_descale=layer._q_scale,
        k_descale=layer._k_scale,
        v_descale=layer._v_scale,
    )
    return output

FlashAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/flash_attn.py
@dataclass
class FlashAttentionMetadata:
    """Per-step inputs consumed by ``FlashAttentionImpl.forward``.

    Instances are produced by ``FlashAttentionMetadataBuilder.build``.
    """

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Passed to the kernel as `cu_seqlens_q`.
    query_start_loc: torch.Tensor
    max_seq_len: int
    # Passed to the kernel as `seqused_k`.
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    # Consumed by reshape_and_cache_flash when writing new keys/values.
    slot_mapping: torch.Tensor

    # For cascade attention.
    # The three Optional tensors below are None when use_cascade is False.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    # Both stay None when the builder's AOT scheduling is disabled.
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        # Local counterparts of query_start_loc, seq_lens, block_table,
        # max_query_len, max_seq_len and scheduler_metadata; forward() uses
        # them in place of the regular fields when local attention is active.
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    # Set by the builder only when the runner has an attention chunk size.
    local_attn_metadata: Optional[LocalAttentionMetadata] = None

block_table instance-attribute

block_table: Tensor

common_prefix_len instance-attribute

common_prefix_len: int

cu_prefix_query_lens instance-attribute

cu_prefix_query_lens: Optional[Tensor]

local_attn_metadata class-attribute instance-attribute

local_attn_metadata: Optional[LocalAttentionMetadata] = None

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

prefix_kv_lens instance-attribute

prefix_kv_lens: Optional[Tensor]

prefix_scheduler_metadata class-attribute instance-attribute

prefix_scheduler_metadata: Optional[Tensor] = None

query_start_loc instance-attribute

query_start_loc: Tensor

scheduler_metadata class-attribute instance-attribute

scheduler_metadata: Optional[Tensor] = None

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

suffix_kv_lens instance-attribute

suffix_kv_lens: Optional[Tensor]

use_cascade instance-attribute

use_cascade: bool

LocalAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/flash_attn.py
@dataclass
class LocalAttentionMetadata:
    # Local-attention counterparts of the regular metadata fields;
    # FlashAttentionImpl.forward passes them to flash_attn_varlen_func as
    # cu_seqlens_q, seqused_k, block_table, max_seqlen_q, max_seqlen_k and
    # scheduler_metadata respectively.
    local_query_start_loc: torch.Tensor
    local_seqused_k: torch.Tensor
    local_block_table: torch.Tensor
    local_max_query_len: int
    local_max_seq_len: int
    # None when the builder's AOT scheduling is disabled.
    local_scheduler_metadata: Optional[torch.Tensor]

local_block_table instance-attribute

local_block_table: Tensor

local_max_query_len instance-attribute

local_max_query_len: int

local_max_seq_len instance-attribute

local_max_seq_len: int

local_query_start_loc instance-attribute

local_query_start_loc: Tensor

local_scheduler_metadata instance-attribute

local_scheduler_metadata: Optional[Tensor]

local_seqused_k instance-attribute

local_seqused_k: Tensor

__init__

__init__(
    local_query_start_loc: Tensor,
    local_seqused_k: Tensor,
    local_block_table: Tensor,
    local_max_query_len: int,
    local_max_seq_len: int,
    local_scheduler_metadata: Optional[Tensor],
) -> None

__init__

__init__(
    num_actual_tokens: int,
    max_query_len: int,
    query_start_loc: Tensor,
    max_seq_len: int,
    seq_lens: Tensor,
    block_table: Tensor,
    slot_mapping: Tensor,
    use_cascade: bool,
    common_prefix_len: int,
    cu_prefix_query_lens: Optional[Tensor],
    prefix_kv_lens: Optional[Tensor],
    suffix_kv_lens: Optional[Tensor],
    scheduler_metadata: Optional[Tensor] = None,
    prefix_scheduler_metadata: Optional[Tensor] = None,
    local_attn_metadata: Optional[
        LocalAttentionMetadata
    ] = None,
) -> None

FlashAttentionMetadataBuilder

Source code in vllm/v1/attention/backends/flash_attn.py
class FlashAttentionMetadataBuilder:
    """Builds a ``FlashAttentionMetadata`` for each model step.

    Holds the model runner, the KV cache spec and the block table, and
    decides whether ahead-of-time (AOT) scheduling metadata is computed.
    """

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        """Capture model dimensions and decide whether AOT scheduling is on.

        AOT scheduling is enabled only when the flash-attn version is 3 and
        ``full_cuda_graph`` is off in the compilation config.
        """
        model_config = runner.model_config
        compilation_config = runner.vllm_config.compilation_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

        if get_flash_attn_version() == 3:
            self.aot_schedule = not compilation_config.full_cuda_graph
            if not self.aot_schedule:
                logger.warning(
                    "AOT Schedule is disabled when using full_cuda_graph")
        else:
            self.aot_schedule = False

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Return False: this builder never reorders the batch."""
        return False

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble the attention metadata for one step.

        Args:
            num_reqs: Number of requests; used to slice the runner's
                per-request arrays and the block table.
            num_actual_tokens: Number of tokens excluding padding.
            max_query_len: Maximum query length in the batch.
            common_prefix_len: Length of the shared prefix; a value greater
                than 0 enables cascade attention.
            common_attn_metadata: Supplies ``query_start_loc`` and
                ``seq_lens``.

        Returns:
            A ``FlashAttentionMetadata`` instance.
        """
        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]

        # Copy the first num_actual_tokens entries of slot_mapping_cpu into
        # slot_mapping without blocking.
        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
        # mode.
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)

        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        # One-time initialization, executed on the first build() call only.
        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if self.aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.runner.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    # More than one distinct configuration: turn AOT
                    # scheduling off for this builder.
                    self.aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # Returns scheduler metadata when AOT scheduling is enabled,
            # otherwise None.
            if self.aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    cache_seqlens=seqlens,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    page_size=self.block_size,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                )
            return None

        # for local attention
        local_attn_metadata = None
        if self.runner.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.runner.attention_chunk_size,
                    self.runner.query_start_loc_np[:num_reqs + 1],
                    self.runner.seq_lens_np[:num_reqs],
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max()
            local_max_seq_len = virt_k_seqlens_np.max()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
                max_query_len=local_max_query_len,
                seqlens=local_seqused_k,
                max_seq_len=local_max_seq_len,
                causal=True)

            local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                local_query_start_loc=local_query_start_loc,
                local_seqused_k=local_seqused_k,
                local_block_table=virt_block_table_tensor,
                local_max_query_len=local_max_query_len,
                local_max_seq_len=local_max_seq_len,
                local_scheduler_metadata=local_scheduler_metadata,
            )

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # The prefix pass treats all tokens as a single sequence of
            # length num_actual_tokens attending to the common prefix.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.runner.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.runner.device)
            # The suffix pass covers whatever remains of each sequence after
            # the common prefix.
            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                              common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                self.runner.device)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=True)

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Forward all arguments to the module-level use_cascade_attention."""
        return use_cascade_attention(*args, **kwargs)

aot_schedule instance-attribute

aot_schedule = not full_cuda_graph

aot_sliding_window instance-attribute

aot_sliding_window: Optional[tuple[int, int]] = None

block_size instance-attribute

block_size = block_size

block_table instance-attribute

block_table = block_table

headdim instance-attribute

headdim = get_head_size()

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

num_heads_kv instance-attribute

num_heads_kv = get_num_kv_heads(parallel_config)

num_heads_q instance-attribute

num_heads_q = get_num_attention_heads(parallel_config)

runner instance-attribute

runner = runner

__init__

__init__(
    runner: GPUModelRunner,
    kv_cache_spec: AttentionSpec,
    block_table: BlockTable,
)
Source code in vllm/v1/attention/backends/flash_attn.py
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
             block_table: BlockTable):
    """Capture the model geometry and decide whether AOT scheduling is used.

    Args:
        runner: Model runner; its ``model_config``, ``parallel_config`` and
            ``vllm_config.compilation_config`` are read here, and the runner
            itself is kept on ``self.runner``.
        kv_cache_spec: KV-cache spec; provides ``block_size`` and is stored.
        block_table: Block table object stored for later use.
    """
    model_config = runner.model_config
    compilation_config = runner.vllm_config.compilation_config
    parallel_config = runner.parallel_config

    # Plain references handed in by the caller.
    self.runner = runner
    self.kv_cache_spec = kv_cache_spec
    self.block_table = block_table
    self.block_size = kv_cache_spec.block_size

    # Head geometry as reported by the model config.
    self.num_heads_q = model_config.get_num_attention_heads(parallel_config)
    self.num_heads_kv = model_config.get_num_kv_heads(parallel_config)
    self.headdim = model_config.get_head_size()

    # AOT scheduling is only considered for FlashAttention 3, and even then
    # it is switched off when full CUDA graphs are requested.
    is_fa3 = get_flash_attn_version() == 3
    self.aot_schedule = is_fa3 and not compilation_config.full_cuda_graph
    if is_fa3 and not self.aot_schedule:
        logger.warning(
            "AOT Schedule is disabled when using full_cuda_graph")

    # The sliding window used with the AOT scheduler is resolved lazily on
    # the first build() call.
    self.aot_sliding_window: Optional[tuple[int, int]] = None

build

build(
    num_reqs: int,
    num_actual_tokens: int,
    max_query_len: int,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
)
Source code in vllm/v1/attention/backends/flash_attn.py
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata):
    """Assemble the FlashAttentionMetadata for the current batch.

    Args:
        num_reqs: Number of requests in the batch; used to slice the
            runner's ``seq_lens_np``, the block table rows, and as the
            scheduler batch size.
        num_actual_tokens: Number of valid tokens; slot-mapping entries
            up to this index are copied from the CPU buffer, entries past
            it are filled with -1.
        max_query_len: Maximum query length, forwarded to the scheduler and
            stored in the metadata.
        common_prefix_len: Length of the shared prefix. Any value > 0
            enables the cascade-attention fields.
        common_attn_metadata: Supplies ``query_start_loc`` and ``seq_lens``.

    Returns:
        A populated ``FlashAttentionMetadata``.

    Side effects:
        Writes into ``self.block_table.slot_mapping``. On the first call it
        also sets ``self.aot_sliding_window`` and may set
        ``self.aot_schedule`` to False when the layers report more than one
        sliding-window configuration.
    """
    max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    block_table = self.block_table
    block_table_tensor = block_table.get_device_tensor()[:num_reqs]

    # Copy the valid part of the slot mapping from the CPU buffer; the copy
    # is requested as non-blocking.
    block_table.slot_mapping[:num_actual_tokens].copy_(
        block_table.slot_mapping_cpu[:num_actual_tokens],
        non_blocking=True)
    # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
    # mode.
    block_table.slot_mapping[num_actual_tokens:].fill_(-1)

    slot_mapping = block_table.slot_mapping[:num_actual_tokens]

    # One-time lazy initialisation: None means build() has not run yet.
    if self.aot_sliding_window is None:
        self.aot_sliding_window = (-1, -1)
        # For the AOT scheduler we need the sliding window value to be
        # constant for all layers. We have to populate this on the first
        # build() call, once the layers have been constructed (it cannot be
        # populated in __init__).
        if self.aot_schedule:
            sliding_window_configs = _get_sliding_window_configs(
                self.runner.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                # Layers disagree on the window, so AOT scheduling is
                # turned off for this builder from now on.
                self.aot_schedule = False

    def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len, causal):
        # Returns scheduler metadata when AOT scheduling is enabled,
        # otherwise None.
        if self.aot_schedule:
            return get_scheduler_metadata(
                batch_size=batch_size,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                cache_seqlens=seqlens,
                num_heads_q=self.num_heads_q,
                num_heads_kv=self.num_heads_kv,
                headdim=self.headdim,
                page_size=self.block_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
                window_size=self.aot_sliding_window,
            )
        return None

    # for local attention
    local_attn_metadata = None
    if self.runner.attention_chunk_size is not None:
        # Split every request into per-chunk "virtual" requests.
        seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
            virt_block_table_tensor = make_local_attention_virtual_batches(
                self.runner.attention_chunk_size,
                self.runner.query_start_loc_np[:num_reqs + 1],
                self.runner.seq_lens_np[:num_reqs],
                block_table_tensor,
                self.block_size,
            )
        local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
            self.runner.device, non_blocking=True)
        local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
            self.runner.device, non_blocking=True)
        local_max_query_len = seqlens_q_local_np.max()
        local_max_seq_len = virt_k_seqlens_np.max()
        local_scheduler_metadata = schedule(
            batch_size=local_query_start_loc.shape[0] - 1,
            cu_query_lens=local_query_start_loc,
            max_query_len=local_max_query_len,
            seqlens=local_seqused_k,
            max_seq_len=local_max_seq_len,
            causal=True)

        local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
            local_query_start_loc=local_query_start_loc,
            local_seqused_k=local_seqused_k,
            local_block_table=virt_block_table_tensor,
            local_max_query_len=local_max_query_len,
            local_max_seq_len=local_max_seq_len,
            local_scheduler_metadata=local_scheduler_metadata,
        )

    use_cascade = common_prefix_len > 0

    if use_cascade:
        # Prefix pass: all tokens form a single query sequence of length
        # num_actual_tokens attending to common_prefix_len keys.
        cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device=self.runner.device)
        prefix_kv_lens = torch.tensor([common_prefix_len],
                                      dtype=torch.int32,
                                      device=self.runner.device)
        # Suffix pass: per-request KV length with the shared prefix removed.
        suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                          common_prefix_len)
        suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
            self.runner.device)
        prefix_scheduler_metadata = schedule(
            batch_size=1,
            cu_query_lens=cu_prefix_query_lens,
            max_query_len=num_actual_tokens,
            seqlens=prefix_kv_lens,
            max_seq_len=common_prefix_len,
            causal=False)
        scheduler_metadata = schedule(batch_size=num_reqs,
                                      cu_query_lens=query_start_loc,
                                      max_query_len=max_query_len,
                                      seqlens=suffix_kv_lens,
                                      max_seq_len=max_seq_len -
                                      common_prefix_len,
                                      causal=True)
    else:
        # No shared prefix: all cascade-specific fields stay unset.
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None
        scheduler_metadata = schedule(batch_size=num_reqs,
                                      cu_query_lens=query_start_loc,
                                      max_query_len=max_query_len,
                                      seqlens=seq_lens,
                                      max_seq_len=max_seq_len,
                                      causal=True)

    attn_metadata = FlashAttentionMetadata(
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        query_start_loc=query_start_loc,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        block_table=block_table_tensor,
        slot_mapping=slot_mapping,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        scheduler_metadata=scheduler_metadata,
        cu_prefix_query_lens=cu_prefix_query_lens,
        prefix_kv_lens=prefix_kv_lens,
        suffix_kv_lens=suffix_kv_lens,
        local_attn_metadata=local_attn_metadata,
        prefix_scheduler_metadata=prefix_scheduler_metadata,
    )
    return attn_metadata

reorder_batch

reorder_batch(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
) -> bool
Source code in vllm/v1/attention/backends/flash_attn.py
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Always return False without inspecting either argument.

    NOTE(review): presumably the return value tells the caller whether the
    batch was reordered; confirm against the caller.
    """
    batch_was_reordered = False
    return batch_was_reordered

use_cascade_attention

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/flash_attn.py
def use_cascade_attention(self, *args, **kwargs) -> bool:
    """Forward every argument to the module-level use_cascade_attention.

    No instance state is consulted; the positional and keyword arguments are
    passed through untouched and the callee's result is returned.
    """
    should_cascade = use_cascade_attention(*args, **kwargs)
    return should_cascade

_get_sliding_window_configs

_get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[Optional[tuple[int, int]]]

Get the set of all sliding window configs used in the model.

Source code in vllm/v1/attention/backends/flash_attn.py
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every layer returned by ``get_layers_from_vllm_config`` for ``Attention``
    must use a ``FlashAttentionImpl``; its ``sliding_window`` value is added
    to the returned set.
    """
    window_configs: set[Optional[tuple[int, int]]] = set()
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for attn_layer in attn_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        window_configs.add(impl.sliding_window)
    return window_configs

cascade_attention

cascade_attention(
    output: Tensor,
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    cu_query_lens: Tensor,
    max_query_len: int,
    cu_prefix_query_lens: Tensor,
    prefix_kv_lens: Tensor,
    suffix_kv_lens: Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[Tensor] = None,
    suffix_scheduler_metadata: Optional[Tensor] = None,
    q_descale: Optional[Tensor] = None,
    k_descale: Optional[Tensor] = None,
    v_descale: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/v1/attention/backends/flash_attn.py
def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run attention in two passes (shared prefix, then suffix) and merge.

    The first ``flash_attn_varlen_func`` call is non-causal and attends the
    queries to the shared prefix, using only the first row of
    ``block_table``. The second call is causal and attends each request to
    its own suffix, using ``block_table`` with the leading common-prefix
    blocks dropped. The two partial results are combined with
    ``merge_attn_states`` and stored in ``output``.

    Args:
        output: Destination tensor for the merged result.
        query: Query tensor; its first dimension is used as the token count.
        key_cache: Key cache; ``shape[-3]`` is used as the block size and
            ``shape[-2]`` as the second dimension of the descale shape
            (presumably the number of KV heads -- confirm with the caller).
        value_cache: Value cache passed through to the kernels.
        cu_query_lens: Cumulative query lengths for the suffix pass.
        max_query_len: Maximum query length for the suffix pass.
        cu_prefix_query_lens: Cumulative query lengths for the prefix pass.
        prefix_kv_lens: KV lengths used by the prefix pass.
        suffix_kv_lens: KV lengths used by the suffix pass.
        max_kv_len: Maximum KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len``.
        softmax_scale: Softmax scale passed to both kernel calls.
        alibi_slopes: Must be None.
        sliding_window: Must be ``(-1, -1)``.
        logits_soft_cap: Passed to the kernels as ``softcap``.
        block_table: Block table shared by both passes.
        common_prefix_len: Shared prefix length; must be a positive multiple
            of the block size.
        fa_version: FlashAttention version passed to the kernels.
        prefix_scheduler_metadata: Optional scheduler metadata, prefix pass.
        suffix_scheduler_metadata: Optional scheduler metadata, suffix pass.
        q_descale: Optional descale tensor, expanded per pass.
        k_descale: Optional descale tensor, expanded per pass.
        v_descale: Optional descale tensor, expanded per pass.

    Returns:
        The ``output`` tensor that was passed in, after the merge. Earlier
        versions fell off the end and returned None despite the annotation.

    Raises:
        AssertionError: If ALiBi or a sliding window is requested, or if
            ``common_prefix_len`` is not a positive multiple of the block
            size.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    def _expand_descale(descale, cu_lens):
        # Broadcast an optional descale tensor to
        # (number of sequences in cu_lens, key_cache.shape[-2]).
        if descale is None:
            return None
        return descale.expand((cu_lens.shape[0] - 1, key_cache.shape[-2]))

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=_expand_descale(q_descale, cu_prefix_query_lens),
        k_descale=_expand_descale(k_descale, cu_prefix_query_lens),
        v_descale=_expand_descale(v_descale, cu_prefix_query_lens),
    )

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        # Skip the blocks that hold the shared prefix.
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=_expand_descale(q_descale, cu_query_lens),
        k_descale=_expand_descale(k_descale, cu_query_lens),
        v_descale=_expand_descale(v_descale, cu_query_lens),
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    # Honour the declared return type; callers that ignore the return value
    # are unaffected.
    return output

make_local_attention_virtual_batches

make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: ndarray,
    seq_lens_np: ndarray,
    block_table: Tensor,
    block_size: int = 0,
) -> tuple[ndarray, ndarray, ndarray, Tensor]
Source code in vllm/v1/attention/backends/flash_attn.py
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    """Split each request into per-chunk "virtual" requests.

    Every request is cut at multiples of ``attn_chunk_size`` so that each
    resulting virtual request only covers one local attention block.

    Args:
        attn_chunk_size: Size of a local attention block, in tokens.
        query_start_loc_np: Cumulative query start offsets; consecutive
            differences give the per-request query lengths.
        seq_lens_np: Per-request total sequence (key) lengths.
        block_table: Per-request block table, indexed as
            ``block_table[request, block]``.
        block_size: KV-cache block size in tokens. Must be positive and
            divide ``attn_chunk_size``. The default of 0 is kept for
            signature compatibility but is rejected.

    Returns:
        ``(seqlens_q_local, cu_seqlens_q_local, seqlens_k_local,
        block_table_local)`` where the first three are NumPy arrays with one
        entry per virtual request (``cu_seqlens_q_local`` has one extra
        leading 0) and ``block_table_local`` has one row per virtual request.

    Raises:
        ValueError: If ``block_size`` is not positive.
        AssertionError: If ``attn_chunk_size`` is not divisible by
            ``block_size``.
    """
    # Validate block_size before it is used as a divisor. Previously the
    # check ran after `k_seqstarts_absolute // block_size`, so the default of
    # 0 produced a NumPy divide-by-zero followed by a bare ZeroDivisionError.
    if block_size <= 0:
        raise ValueError(
            f"block_size must be a positive integer, got {block_size}")
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
        .astype(np.int32)

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch)) \
            + np.expand_dims(block_starts, axis=1)
    # Clamp so indices never run past the last column of block_table.
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local

use_cascade_attention

use_cascade_attention(
    common_prefix_len: int,
    query_lens: ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool

Decide whether to use cascade attention.

This function 1) checks whether cascade attention is supported with the given configuration, and 2) heuristically decides whether using cascade attention can improve performance.

Source code in vllm/v1/attention/backends/flash_attn.py
def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Return True when cascade attention should be used for this batch.

    The decision has two parts: first rule out configurations where cascade
    attention is unsupported or unlikely to pay off, then compare a rough
    cost estimate of cascade attention against FlashDecoding.
    """
    # NOTE(woosuk): A short common prefix is the common case, so it is
    # rejected first to avoid any unnecessary computation. The 256-token
    # threshold is arbitrary. TODO: Tune this threshold.
    if common_prefix_len < 256:
        return False
    # ALiBi and sliding window are not supported by cascade attention.
    if use_alibi or use_sliding_window:
        return False
    # With fewer than 8 requests it is probably not worth it. The threshold
    # is arbitrary. TODO: Tune this threshold.
    batch_size = len(query_lens)
    if batch_size < 8:
        return False

    # Would normal attention use FlashDecoding? The criteria can be found in
    # the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # (use_alibi and use_sliding_window are already known to be falsy here.)
    heads_per_kv = num_query_heads // num_kv_heads
    flash_decoding_applies = (heads_per_kv > 1
                              and np.all(query_lens == 1))

    if flash_decoding_applies:
        # FlashDecoding can launch more CTAs than cascade attention, so the
        # winner is unclear; compare the two with a simple cost model.
        # NOTE(woosuk): The performance model is very rough and may not be
        # accurate. The tile sizes are defaults; flash-attn might use
        # different tile sizes (e.g., 64 or 256) depending on the
        # configuration.
        q_tile = 128
        kv_tile = 128
        prefix_tiles = cdiv(common_prefix_len, kv_tile)

        # All queries have length 1 here, so tokens == requests.
        total_tokens = batch_size
        cascade_cost = cdiv(num_query_heads * cdiv(total_tokens, q_tile),
                            num_sms) * prefix_tiles

        flash_decoding_cost = cdiv(
            batch_size * num_kv_heads * cdiv(heads_per_kv, q_tile) *
            prefix_tiles, num_sms)

        # Pick cascade attention only if it is estimated to be faster.
        return cascade_cost < flash_decoding_cost

    # Without FlashDecoding, cascade attention is likely to be faster since
    # it saves memory bandwidth.
    return True