Skip to content

vllm.v1.attention.backends.triton_attn

Attention layer with PagedAttention and Triton prefix prefill.

logger module-attribute

# Module-level logger, created through `init_logger` with this module's name.
logger = init_logger(__name__)

TritonAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/triton_attn.py
class TritonAttentionBackend(AttentionBackend):
    """Backend descriptor exposing the Triton attention classes and limits.

    Every method is static: the class is a lookup table for the backend
    name, the implementation / metadata / builder classes, the supported
    head sizes and the KV-cache layout.
    """

    # `TritonAttentionImpl.forward` asserts that an output tensor is passed in.
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the supported head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        backend_name = "TRITON_ATTN_VLLM_V1"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        impl_cls = TritonAttentionImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class (the FlashAttention metadata is reused)."""
        metadata_cls = FlashAttentionMetadata
        return metadata_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the 5-D KV-cache shape for the given geometry.

        The result is `(2, num_blocks, block_size, num_kv_heads, head_size)`.

        Raises:
            ValueError: if `block_size` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        per_block_dims = (block_size, num_kv_heads, head_size)
        return (2, num_blocks, *per_block_dims)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always return False, whatever arguments are passed."""
        cascade_enabled = False
        return cascade_enabled

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        """Return the metadata-builder class of this backend."""
        builder_cls = TritonAttentionMetadataBuilder
        return builder_cls

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[TritonAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
    """Return the metadata-builder class of this backend."""
    builder_cls = TritonAttentionMetadataBuilder
    return builder_cls

get_impl_cls staticmethod

get_impl_cls() -> type[TritonAttentionImpl]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_impl_cls() -> type["TritonAttentionImpl"]:
    """Return the attention implementation class of this backend."""
    impl_cls = TritonAttentionImpl
    return impl_cls

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    """Return the 5-D KV-cache shape for the given geometry.

    The result is `(2, num_blocks, block_size, num_kv_heads, head_size)`.

    Raises:
        ValueError: if `block_size` is not a multiple of 16.
    """
    if block_size % 16:
        raise ValueError("Block size must be a multiple of 16.")
    per_block_dims = (block_size, num_kv_heads, head_size)
    return (2, num_blocks, *per_block_dims)

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    """Return the metadata class (the FlashAttention metadata is reused)."""
    metadata_cls = FlashAttentionMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_name() -> str:
    """Return the identifier string of this backend."""
    backend_name = "TRITON_ATTN_VLLM_V1"
    return backend_name

get_supported_head_sizes staticmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def get_supported_head_sizes() -> list[int]:
    """Return the supported head sizes: multiples of 32 from 32 to 256."""
    return list(range(32, 257, 32))

use_cascade_attention staticmethod

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/triton_attn.py
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    """Always return False, whatever arguments are passed."""
    cascade_enabled = False
    return cascade_enabled

TritonAttentionImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/triton_attn.py
class TritonAttentionImpl(AttentionImpl):
    """Decoder-only attention implementation for the Triton backend.

    `forward` first writes the new keys/values into the paged KV cache and
    then runs one of two kernels: `chunked_prefill_paged_decode` or
    `unified_attention`, selected from the query-heads-per-KV-head ratio.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and store it on the instance.

        Args:
            num_heads: number of query heads; must be divisible by
                `num_kv_heads` (checked with an assert).
            head_size: per-head dimension; must be one of
                `TritonAttentionBackend.get_supported_head_sizes()`.
            scale: softmax scale, stored as a Python float.
            num_kv_heads: number of key/value heads.
            alibi_slopes: optional slopes, converted to a float32 tensor.
            sliding_window: optional window size, stored as the tuple
                `(sliding_window - 1, 0)`, or `(-1, -1)` when None.
            kv_cache_dtype: dtype name of the KV cache; `forward` checks
                for an "fp8" prefix.
            blocksparse_params: must be None.
            logits_soft_cap: soft-cap value; None is stored as 0.
            attn_type: must be `AttentionType.DECODER`.
            use_irope: enables the local-attention metadata path in
                `forward` when such metadata is present.

        Raises:
            ValueError: if `blocksparse_params` is given or `head_size` is
                not supported.
            NotImplementedError: if `attn_type` is not DECODER.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "TritonAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # NOTE(review): the tuple looks like a (left, right) window in the
        # flash-attn convention, with (-1, -1) meaning "no window" — confirm.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        self.use_irope = use_irope

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by TritonAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        # Platform-specific fp8 dtype, used to reinterpret the cache in
        # `forward` when `kv_cache_dtype` starts with "fp8".
        self.fp8_dtype = current_platform.fp8_dtype()

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of the Triton attention backend.

        Writes `key`/`value` into the paged KV cache, then computes
        attention for the first `attn_metadata.num_actual_tokens` tokens
        into `output`.

        Args:
            layer: module whose `_k_scale`, `_v_scale` and `_q_scale`
                attributes are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. When None (profiling
                run) `output` is returned untouched.
            output: output tensor; must be provided (asserted) even though
                the signature defaults it to None.
        Returns:
            The `output` tensor passed in,
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        # `n & (n - 1)` is non-zero exactly when n has more than one bit set,
        # so the prefill/decode kernel is chosen when the query-per-KV ratio
        # is not a power of two.
        # NOTE(review): presumably the unified kernel needs a power-of-two
        # ratio — confirm against the kernel.
        num_queries_per_kv = query.shape[1] // key.shape[1]
        use_prefill_decode_attn = (num_queries_per_kv &
                                   (num_queries_per_kv - 1)) != 0

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Each kernel path uses its own cache split and its own cache-write op.
        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        else:
            key_cache, value_cache = kv_cache.unbind(0)
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache tensors with the platform fp8 dtype.
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if not current_platform.is_rocm():
                # Skip Q quantization on ROCm, since dequantizing back to
                # f32 in the attention kernel is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
            # Restore the 3-D layout after the 2-D quantization call (on ROCm
            # this reshapes the untouched query to its own shape).
            query = query.reshape((num_tokens, num_heads, head_size))

        # Local-attention metadata is used only when iRoPE is enabled and the
        # metadata object actually carries it.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.
            # Note: `self.logits_soft_cap` is not passed on this path, and
            # only the first element of `self.sliding_window` is used.
            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
                                         key=key[:num_actual_tokens],
                                         value=value[:num_actual_tokens],
                                         output=output[:num_actual_tokens],
                                         kv_cache_dtype=self.kv_cache_dtype,
                                         key_cache=key_cache,
                                         value_cache=value_cache,
                                         block_table=block_table,
                                         query_start_loc=cu_seqlens_q,
                                         seq_lens=seqused_k,
                                         max_seq_len=max_seqlen_k,
                                         max_query_len=max_seqlen_q,
                                         k_scale=layer._k_scale,
                                         v_scale=layer._v_scale,
                                         alibi_slopes=self.alibi_slopes,
                                         sliding_window=self.sliding_window[0],
                                         sm_scale=self.scale)

        else:
            # The K/V scales are expanded to
            # (len(cu_seqlens_q) - 1, num_kv_heads).
            # NOTE(review): the first dim is presumably the number of
            # sequences, assuming cu_seqlens_q is a cumulative-offset array
            # — confirm.
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            unified_attention(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        return output

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

fp8_dtype instance-attribute

fp8_dtype = fp8_dtype()

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = (-1, -1)

use_irope instance-attribute

use_irope = use_irope

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = DECODER,
    use_irope: bool = False,
) -> None
Source code in vllm/v1/attention/backends/triton_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = AttentionType.DECODER,
    use_irope: bool = False,
) -> None:
    """Validate the configuration and store it on the instance.

    Block-sparse parameters, unsupported head sizes and non-decoder
    attention types are rejected. `alibi_slopes` becomes a float32 tensor,
    `sliding_window` becomes a 2-tuple, and a missing `logits_soft_cap`
    is stored as 0.

    Raises:
        ValueError: if `blocksparse_params` is given or `head_size` is not
            supported.
        NotImplementedError: if `attn_type` is not DECODER.
    """
    if blocksparse_params is not None:
        raise ValueError(
            "TritonAttention does not support block-sparse attention.")

    # Plain copies of the basic geometry.
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads

    # Slopes are kept as a float32 tensor, or None when ALiBi is unused.
    self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.float32))

    # (-1, -1) is stored when no sliding window is configured.
    self.sliding_window = ((-1, -1) if sliding_window is None else
                           (sliding_window - 1, 0))
    self.kv_cache_dtype = kv_cache_dtype

    # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
    self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap

    self.use_irope = use_irope

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
    if head_size not in support_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by TritonAttention. "
            f"Supported head sizes are: {support_head_sizes}.")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "TritonAttentionImpl")

    # Platform-specific fp8 dtype, used later to reinterpret the cache.
    self.fp8_dtype = current_platform.fp8_dtype()

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashAttentionMetadata,
    output: Optional[Tensor] = None,
) -> Tensor

Forward pass of the Triton attention backend.

Parameters:

Name Type Description Default
query Tensor

shape = [num_tokens, num_heads, head_size]

required
key Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
value Tensor

shape = [num_tokens, num_kv_heads, head_size]

required
kv_cache Tensor

shape = [2, num_blocks, block_size, num_kv_heads, head_size]

required
attn_metadata FlashAttentionMetadata

Metadata for attention.

required

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm/v1/attention/backends/triton_attn.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass of the Triton attention backend.

    Writes `key`/`value` into the paged KV cache, then computes attention
    for the first `attn_metadata.num_actual_tokens` tokens into `output`.

    Args:
        layer: module whose `_k_scale`, `_v_scale` and `_q_scale`
            attributes are read here.
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            [2, num_blocks, block_size, num_kv_heads, head_size]
        attn_metadata: Metadata for attention. When None (profiling run)
            `output` is returned untouched.
        output: output tensor; must be provided (asserted) even though
            the signature defaults it to None.
    Returns:
        The `output` tensor passed in,
        shape = [num_tokens, num_heads * head_size]
    """
    assert output is not None, "Output tensor must be provided."

    if attn_metadata is None:
        # Profiling run.
        return output

    assert attn_metadata.use_cascade is False

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    # `n & (n - 1)` is non-zero exactly when n has more than one bit set,
    # so the prefill/decode kernel is chosen when the query-per-KV ratio
    # is not a power of two.
    # NOTE(review): presumably the unified kernel needs a power-of-two
    # ratio — confirm against the kernel.
    num_queries_per_kv = query.shape[1] // key.shape[1]
    use_prefill_decode_attn = (num_queries_per_kv &
                               (num_queries_per_kv - 1)) != 0

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Each kernel path uses its own cache split and its own cache-write op.
    if use_prefill_decode_attn:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)

        # Reshape the input keys and values and store them in the cache.
        PagedAttention.write_to_paged_cache(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    else:
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    if self.kv_cache_dtype.startswith("fp8"):
        # Reinterpret the cache tensors with the platform fp8 dtype.
        key_cache = key_cache.view(self.fp8_dtype)
        value_cache = value_cache.view(self.fp8_dtype)
        num_tokens, num_heads, head_size = query.shape
        assert layer._q_scale == 1.0, \
            "A non 1.0 q_scale is not currently supported."
        if not current_platform.is_rocm():
            # Skip Q quantization on ROCm, since dequantizing back to
            # f32 in the attention kernel is not supported.
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
        # Restore the 3-D layout after the 2-D quantization call (on ROCm
        # this reshapes the untouched query to its own shape).
        query = query.reshape((num_tokens, num_heads, head_size))

    # Local-attention metadata is used only when iRoPE is enabled and the
    # metadata object actually carries it.
    use_local_attn = \
        (self.use_irope and attn_metadata.local_attn_metadata is not None)

    if use_local_attn:
        assert attn_metadata.local_attn_metadata is not None
        local_metadata = attn_metadata.local_attn_metadata
        cu_seqlens_q = local_metadata.local_query_start_loc
        seqused_k = local_metadata.local_seqused_k
        max_seqlen_q = local_metadata.local_max_query_len
        max_seqlen_k = local_metadata.local_max_seq_len
        block_table = local_metadata.local_block_table
    else:
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

    if use_prefill_decode_attn:
        # Compute attention and update output up to `num_actual_tokens`.
        # Note: `self.logits_soft_cap` is not passed on this path, and only
        # the first element of `self.sliding_window` is used.
        chunked_prefill_paged_decode(query=query[:num_actual_tokens],
                                     key=key[:num_actual_tokens],
                                     value=value[:num_actual_tokens],
                                     output=output[:num_actual_tokens],
                                     kv_cache_dtype=self.kv_cache_dtype,
                                     key_cache=key_cache,
                                     value_cache=value_cache,
                                     block_table=block_table,
                                     query_start_loc=cu_seqlens_q,
                                     seq_lens=seqused_k,
                                     max_seq_len=max_seqlen_k,
                                     max_query_len=max_seqlen_q,
                                     k_scale=layer._k_scale,
                                     v_scale=layer._v_scale,
                                     alibi_slopes=self.alibi_slopes,
                                     sliding_window=self.sliding_window[0],
                                     sm_scale=self.scale)

    else:
        # The K/V scales are expanded to (len(cu_seqlens_q) - 1, num_kv_heads).
        # NOTE(review): the first dim is presumably the number of sequences,
        # assuming cu_seqlens_q is a cumulative-offset array — confirm.
        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        unified_attention(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )

    return output

TritonAttentionMetadataBuilder

Bases: FlashAttentionMetadataBuilder

Source code in vllm/v1/attention/backends/triton_attn.py
class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
    """FlashAttention metadata builder with `aot_schedule` forced to False."""

    def __init__(
        self,
        runner: "GPUModelRunner",
        kv_cache_spec: AttentionSpec,
        block_table: BlockTable,
    ):
        """Run the parent constructor, then set `aot_schedule` to False."""
        super().__init__(runner, kv_cache_spec, block_table)
        # Assigned after the parent constructor so this value is the one kept.
        self.aot_schedule = False

aot_schedule instance-attribute

aot_schedule = False

__init__

__init__(
    runner: GPUModelRunner,
    kv_cache_spec: AttentionSpec,
    block_table: BlockTable,
)
Source code in vllm/v1/attention/backends/triton_attn.py
def __init__(
    self,
    runner: "GPUModelRunner",
    kv_cache_spec: AttentionSpec,
    block_table: BlockTable,
):
    """Run the parent constructor, then set `aot_schedule` to False."""
    super().__init__(runner, kv_cache_spec, block_table)
    # Assigned after the parent constructor so this value is the one kept.
    self.aot_schedule = False