Skip to content

vllm_omni.diffusion.attention.backends.ring.ring_kernels

flash_attn3_func_forward module-attribute

flash_attn3_func_forward = fa3_forward

fa3_forward

fa3_forward(
    q,
    k,
    v,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
)

FA3 forward pass for inference.

FA3 supports Ampere, Ada, and Hopper GPUs. Dropout is ignored because FA3 is inference-only. This function uses the low-level API (`_flash_attn_forward`), which always returns `softmax_lse`; Ring Attention requires that value for correct accumulation.

flash_attn_forward

flash_attn_forward(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=None,
    alibi_slopes=None,
    return_softmax=False,
)

flash_attn_forward_aiter

flash_attn_forward_aiter(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=None,
    alibi_slopes=None,
    return_softmax=False,
)

flashinfer_attn_forward

flashinfer_attn_forward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    dropout_p: float = 0.0,
    softmax_scale: float | None = None,
    causal: bool = False,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float | None = None,
    alibi_slopes: Tensor | None = None,
    return_softmax: bool = False,
) -> tuple[Tensor, Tensor]

pytorch_attn_forward

pytorch_attn_forward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    dropout_p=0.0,
    softmax_scale=None,
    causal=True,
    window_size=(-1, -1),
    softcap=None,
    alibi_slopes=None,
    return_softmax=False,
    op_type="efficient",
)