speculators.models.attention
Shared attention utilities for speculator models.
This module contains attention functions and utilities shared across different speculator architectures (EAGLE3, DFlash, etc.) to avoid code duplication.
Functions:
- flex_attention_forward – Shared flex attention forward implementation.
flex_attention_forward
flex_attention_forward(
    module: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask,
    scaling: float | None = None,
    **_kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]
Shared flex attention forward implementation.
This function is used by both EAGLE3 and DFlash attention mechanisms to avoid code duplication and ensure consistent behavior.
Args:
    module: The attention module (unused but required for interface compatibility).
    query: Query tensor of shape (batch, num_heads, seq_len, head_dim).
    key: Key tensor of shape (batch, num_heads, seq_len, head_dim).
    value: Value tensor of shape (batch, num_heads, seq_len, head_dim).
    attention_mask: BlockMask for flex attention.
    scaling: Optional scaling factor for attention scores.
    **_kwargs: Additional unused kwargs for interface compatibility.
Returns:
    Tuple of (attention_output, None), where attention_output has shape (batch, seq_len, num_heads, head_dim) and None indicates that no attention weights are returned.
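The shape contract above can be illustrated with a minimal dense sketch. It is not the library's implementation: `dense_attention_reference` is a hypothetical stand-in that takes a boolean keep-mask instead of a BlockMask and uses plain matmul and softmax instead of flex attention. It also assumes that an omitted `scaling` falls back to 1/sqrt(head_dim), which is the documented default of PyTorch's `flex_attention` but is not confirmed here for `flex_attention_forward` itself.

```python
import math

import torch


def dense_attention_reference(query, key, value, keep_mask, scaling=None):
    """Hypothetical dense stand-in that mirrors the documented shape contract.

    query/key/value: (batch, num_heads, seq_len, head_dim)
    keep_mask: boolean tensor broadcastable to (batch, num_heads, seq_len, seq_len);
        True means the query position may attend to the key position.
    """
    if scaling is None:
        # Assumption: default scale is 1/sqrt(head_dim).
        scaling = 1.0 / math.sqrt(query.shape[-1])

    # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) * scaling
    # Masked-out positions get -inf so softmax assigns them zero weight.
    scores = scores.masked_fill(~keep_mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)

    # Weighted sum of values: (batch, num_heads, seq_len, head_dim)
    out = torch.matmul(weights, value)
    # Swap the head and sequence axes to match the documented return layout.
    return out.transpose(1, 2).contiguous(), None


torch.manual_seed(0)
batch, num_heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Causal keep-mask: position i may attend to positions 0..i.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

attn_output, attn_weights = dense_attention_reference(query, key, value, causal)
print(attn_output.shape)  # torch.Size([2, 8, 4, 16])
```

Note that the inputs carry heads before sequence positions, while the output carries sequence positions before heads, so callers can typically reshape the result to (batch, seq_len, num_heads * head_dim) without a further transpose.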