vllm_omni.diffusion.attention.backends.ring.ring_selector

AttnType

Bases: Enum

AITER class-attribute instance-attribute

AITER = 'aiter'

FA class-attribute instance-attribute

FA = 'fa'

FA3 class-attribute instance-attribute

FA3 = 'fa3'

FLASHINFER class-attribute instance-attribute

FLASHINFER = 'flashinfer'

SAGE_AUTO class-attribute instance-attribute

SAGE_AUTO = 'sage_auto'

SAGE_FP16 class-attribute instance-attribute

SAGE_FP16 = 'sage_fp16'

SAGE_FP16_TRITON class-attribute instance-attribute

SAGE_FP16_TRITON = 'sage_fp16_triton'

SAGE_FP8 class-attribute instance-attribute

SAGE_FP8 = 'sage_fp8'

SAGE_FP8_SM90 class-attribute instance-attribute

SAGE_FP8_SM90 = 'sage_fp8_sm90'

SPARSE_SAGE class-attribute instance-attribute

SPARSE_SAGE = 'sparse_sage'

TORCH class-attribute instance-attribute

TORCH = 'torch'

from_string classmethod

from_string(s: str)
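
The page lists only the signature of `from_string`, so its exact behavior is not documented here. As a minimal sketch, assuming it maps a string such as `'fa3'` to the member with that value and rejects unknown strings, a stand-in enum could look like this. `AttnTypeSketch` is a hypothetical class for illustration, not the library's `AttnType`, and the `ValueError` on unknown input is an assumption.

```python
from enum import Enum


class AttnTypeSketch(Enum):
    """Hypothetical stand-in mirroring a few of the documented AttnType members."""

    FA = "fa"
    FA3 = "fa3"
    SAGE_FP8 = "sage_fp8"
    TORCH = "torch"

    @classmethod
    def from_string(cls, s: str) -> "AttnTypeSketch":
        # Assumed behavior: compare the input against each member's string value.
        for member in cls:
            if member.value == s:
                return member
        # Assumed error handling: unknown names are rejected explicitly.
        raise ValueError(f"unknown attention type: {s!r}")


selected = AttnTypeSketch.from_string("fa3")
```

With string-valued members, the standard `Enum` value lookup `AttnTypeSketch("fa3")` would return the same member; a dedicated classmethod mainly gives a place for custom normalization or error messages.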

select_flash_attn_impl

select_flash_attn_impl(
    impl_type: AttnType,
    stage: str = "fwd-only",
    attn_processor: Module | None = None,
) -> Callable[..., tuple[Tensor, Tensor | None]]

Select attention implementation for forward pass (inference only).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| impl_type | AttnType | The attention implementation type. | required |
| stage | str | Must be "fwd-only" (backward not supported for inference). | 'fwd-only' |
| attn_processor | Module \| None | Optional custom attention processor. | None |

Returns:

| Type | Description |
| --- | --- |
| Callable[..., tuple[Tensor, Tensor \| None]] | The attention forward function for the specified implementation. |
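
The selector contract described above (an enum in, a forward callable out, with `stage` restricted to `"fwd-only"`) can be sketched as a dispatch table. This is an illustration under stated assumptions, not the library's implementation: `BackendSketch`, `select_impl_sketch`, and the placeholder forward functions are hypothetical names, the placeholders return strings instead of tensors, and raising `ValueError` for an unsupported stage is assumed.

```python
from enum import Enum
from typing import Any, Callable, Optional, Tuple

# Mirrors the documented return shape: (output, optional second value).
ForwardFn = Callable[..., Tuple[Any, Optional[Any]]]


class BackendSketch(Enum):
    """Hypothetical stand-in for AttnType with two members."""

    FA = "fa"
    TORCH = "torch"


def _fa_forward(q: Any, k: Any, v: Any, **kwargs: Any) -> Tuple[Any, Optional[Any]]:
    # Placeholder: a real kernel would return an output tensor
    # plus an optional second tensor.
    return ("fa-output", None)


def _torch_forward(q: Any, k: Any, v: Any, **kwargs: Any) -> Tuple[Any, Optional[Any]]:
    # Placeholder for a reference implementation.
    return ("torch-output", None)


# Dispatch table from backend type to its forward function.
_REGISTRY = {
    BackendSketch.FA: _fa_forward,
    BackendSketch.TORCH: _torch_forward,
}


def select_impl_sketch(impl_type: BackendSketch, stage: str = "fwd-only") -> ForwardFn:
    # Assumed validation: only the forward-only stage is accepted.
    if stage != "fwd-only":
        raise ValueError(f"unsupported stage: {stage!r}")
    try:
        return _REGISTRY[impl_type]
    except KeyError:
        raise ValueError(f"no implementation registered for {impl_type}") from None


forward = select_impl_sketch(BackendSketch.TORCH)
out, aux = forward(None, None, None)
```

Resolving the callable once and reusing it keeps the backend choice out of the per-call path; callers only need to handle the second tuple element possibly being `None`.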