
speculators.models.attention

Shared attention utilities for speculator models.

This module contains attention functions and utilities shared across different speculator architectures (EAGLE3, DFlash, etc.) to avoid code duplication.

Functions:

flex_attention_forward

flex_attention_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask,
    scaling: float | None = None,
    **_kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]

Shared flex attention forward implementation.

This function is used by both EAGLE3 and DFlash attention mechanisms to avoid code duplication and ensure consistent behavior.

Args:
    module: The attention module (unused, but required for interface compatibility).
    query: Query tensor of shape (batch, num_heads, seq_len, head_dim).
    key: Key tensor of shape (batch, num_kv_heads, seq_len, head_dim).
    value: Value tensor of shape (batch, num_kv_heads, seq_len, head_dim).
        When num_kv_heads differs from num_heads, grouped-query attention (GQA) is enabled.
    attention_mask: BlockMask for flex attention.
    scaling: Optional scaling factor for attention scores. When None, flex_attention falls back to its default of 1/sqrt(head_dim).
    **_kwargs: Additional unused kwargs, accepted for interface compatibility.

Returns:
    Tuple of (attention_output, None), where attention_output has shape (batch, seq_len, num_heads, head_dim) and None indicates that no attention weights are returned.
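
The argument and return shapes above can be illustrated with a dense reference written in plain PyTorch operations. This is a minimal sketch, not the library's implementation: `reference_attention_forward` and its boolean `keep` mask are hypothetical names introduced here, and the dense mask stands in for the `BlockMask` that the real function expects. It assumes that key/value heads are shared across contiguous groups of query heads and that the default scale is 1/sqrt(head_dim), as in `flex_attention`.

```python
import math

import torch


def reference_attention_forward(query, key, value, keep=None, scaling=None):
    """Dense stand-in (hypothetical) mirroring the documented input/output layout."""
    num_query_heads, num_kv_heads = query.shape[1], key.shape[1]
    if num_query_heads != num_kv_heads:
        # GQA: each key/value head serves a contiguous group of query heads.
        group_size = num_query_heads // num_kv_heads
        key = key.repeat_interleave(group_size, dim=1)
        value = value.repeat_interleave(group_size, dim=1)
    if scaling is None:
        # Assumed default, matching flex_attention: 1/sqrt(head_dim).
        scaling = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if keep is not None:
        # Positions where keep is False are excluded from the softmax.
        scores = scores.masked_fill(~keep, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)  # (batch, num_heads, seq_len, head_dim)
    # Match the documented return layout: (batch, seq_len, num_heads, head_dim).
    return output.transpose(1, 2).contiguous(), None


batch, q_heads, kv_heads, seq_len, head_dim = 2, 8, 2, 16, 32
q = torch.randn(batch, q_heads, seq_len, head_dim)
k = torch.randn(batch, kv_heads, seq_len, head_dim)
v = torch.randn(batch, kv_heads, seq_len, head_dim)
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

out, attn_weights = reference_attention_forward(q, k, v, keep=causal)
print(out.shape)  # torch.Size([2, 16, 8, 32])
```

Note that the head and sequence axes are swapped on the way out, so callers can reshape the result to (batch, seq_len, num_heads * head_dim) without a further transpose.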

Source code in speculators/models/attention.py
import torch
from torch.nn.attention.flex_attention import flex_attention


def flex_attention_forward(
    module: torch.nn.Module,  # noqa: ARG001
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask,
    scaling: float | None = None,
    **_kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Shared flex attention forward implementation.

    This function is used by both EAGLE3 and DFlash attention mechanisms to avoid
    code duplication and ensure consistent behavior.

    Args:
        module: The attention module (unused but required for interface compatibility).
        query: Query tensor of shape (batch, num_heads, seq_len, head_dim).
        key: Key tensor of shape (batch, num_kv_heads, seq_len, head_dim).
        value: Value tensor of shape (batch, num_kv_heads, seq_len, head_dim).
            When num_kv_heads differs from num_heads, GQA is enabled.
        attention_mask: BlockMask for flex attention.
        scaling: Optional scaling factor for attention scores. When None,
            flex_attention falls back to its default of 1/sqrt(head_dim).
        **_kwargs: Additional unused kwargs for interface compatibility.

    Returns:
        Tuple of (attention_output, None) where attention_output has shape
        (batch, seq_len, num_heads, head_dim) and None represents no attention weights.
    """
    # Enable grouped-query attention when the query and key/value head counts differ.
    num_query_heads = query.shape[1]
    num_key_value_heads = key.shape[1]
    enable_gqa = num_query_heads != num_key_value_heads

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    flex_attention_output = flex_attention(
        query,
        key,
        value,
        score_mod=None,
        block_mask=attention_mask,
        enable_gqa=enable_gqa,
        scale=scaling,
    )
    attention_output: torch.Tensor = flex_attention_output
    # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)
    attention_output = attention_output.transpose(1, 2).contiguous()
    return attention_output, None