class AiterFlashAttnPrefillBackend(MLAPrefillBackend):
"""AITER FlashAttention backend for MLA prefill"""
@staticmethod
def get_name() -> str:
return "ROCM_AITER_FA"
@classmethod
def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
if not current_platform.is_rocm():
return False
from vllm.platforms.rocm import on_mi3xx
return on_mi3xx()
@classmethod
def is_available(cls) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
return rocm_aiter_ops.is_enabled()
def __init__(
self,
num_heads: int,
scale: float,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
vllm_config: "VllmConfig",
) -> None:
super().__init__(
num_heads=num_heads,
scale=scale,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
vllm_config=vllm_config,
)
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def run_prefill_new_tokens(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
return_softmax_lse: bool,
out: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert output_scale is None, (
"AiterFlashAttnPrefillBackend does not support fused quantized output."
)
result = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=self._prefill_metadata.query_start_loc,
cu_seqlens_k=self._prefill_metadata.query_start_loc,
max_seqlen_q=self._prefill_metadata.max_query_len,
max_seqlen_k=self._prefill_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
return_lse=return_softmax_lse,
out=out,
)
# aiter returns the bare output tensor when return_lse is False, and
# (out, softmax_lse) when it is True.
if return_softmax_lse:
return result[0], result[1]
return result
def run_prefill_context_chunk(
self,
chunk_idx: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self._prefill_metadata.chunked_context is not None
chunked = self._prefill_metadata.chunked_context
out, lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=self._prefill_metadata.query_start_loc,
cu_seqlens_k=chunked.cu_seq_lens[chunk_idx],
max_seqlen_q=self._prefill_metadata.max_query_len,
max_seqlen_k=chunked.max_seq_lens[chunk_idx],
softmax_scale=self.scale,
causal=False,
return_lse=True,
)
return out, lse