
vllm_omni.diffusion.attention.backends.ring_flash_attn

RingFlashAttnFunc

Bases: Function

Ring Flash Attention autograd function (inference only, no backward).

forward staticmethod

forward(
    ctx,
    q,
    k,
    v,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    return_softmax,
    group,
    attn_type,
    attn_processor,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
)

ring_flash_attn_forward

ring_flash_attn_forward(
    process_group,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
)

ring_flash_attn_func

ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
) -> Tensor | tuple[Tensor, Tensor, None]

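Ring Attention forward pass using a Flash Attention backend.

Implements Ring Attention for sequence parallelism. The sequence dimension of q, k and v is sharded across the ranks of the ring group. Each rank keeps its query shard local, while Key/Value shards are passed around the ring with point-to-point (P2P) communication. At every step, a rank attends its queries to the Key/Value shard it currently holds and accumulates the result. After one full round, every local query has attended to every Key/Value shard, subject to any causal mask or window.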
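The rotation described above can be sketched as a schedule of which K/V shard each rank holds at each step. This sketch assumes the usual ring convention, which the description does not spell out: every rank sends its current shard to rank + 1 and receives from rank - 1 (mod world size). `ring_schedule` is a hypothetical helper, not part of `vllm_omni`.

```python
def ring_schedule(world_size: int) -> list[list[int]]:
    """Hypothetical helper: schedule[step][rank] is the index of the K/V
    shard that `rank` holds at `step`, assuming every rank sends its current
    shard to rank + 1 and receives from rank - 1 (mod world_size)."""
    held = list(range(world_size))  # step 0: each rank holds its own shard
    schedule = []
    for _ in range(world_size):
        schedule.append(held.copy())
        # One P2P exchange: rank r now holds what rank r - 1 held.
        held = [held[(r - 1) % world_size] for r in range(world_size)]
    return schedule


for step, shards in enumerate(ring_schedule(4)):
    print(f"step {step}: " + "  ".join(f"rank{r}:kv{s}" for r, s in enumerate(shards)))
```

Under this convention, rank r holds shard (r - step) mod world_size. Over world_size steps each rank sees every shard exactly once, so no single device ever materializes the full K/V sequence.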

Parameters:

q (Tensor, required): Query tensor of shape (batch, seq_len, num_heads, head_dim). The sequence dimension is sharded across the ring group.

k (Tensor, required): Key tensor of shape (batch, seq_len, num_heads, head_dim). The sequence dimension is sharded across the ring group.

v (Tensor, required): Value tensor of shape (batch, seq_len, num_heads, head_dim). The sequence dimension is sharded across the ring group.

dropout_p (float, default 0.0): Dropout probability.

softmax_scale (float | None, default None): Scaling factor applied to the attention scores before the softmax. If None, it is computed as head_dim^(-0.5).

causal (bool, default False): Whether to apply causal masking.

window_size (tuple[int, int], default (-1, -1)): Sliding-window size for local attention. (-1, -1) disables windowing.

softcap (float, default 0.0): Soft-capping value for the attention logits. 0.0 disables soft-capping.

alibi_slopes (Tensor | None, default None): ALiBi slopes for positional bias. Not supported; leave as None.

deterministic (bool, default False): Whether to use deterministic algorithms.

return_attn_probs (bool, default False): If True, return (out, softmax_lse, None) instead of out alone.

group (ProcessGroup | None, default None): Process group used for ring communication.

attn_type (AttnType, default FA): Flash Attention implementation to use (AttnType.FA, AttnType.FA3, etc.).

attn_processor (Callable | None, default None): Custom attention processor, e.g. for sparse attention.

joint_tensor_key (Tensor | None, default None): Additional key tensor for joint attention (e.g., text tokens attended alongside image tokens). Concatenated only at the first ring step (step 0).

joint_tensor_value (Tensor | None, default None): Additional value tensor for joint attention (e.g., text tokens attended alongside image tokens). Concatenated only at the first ring step (step 0).

joint_strategy (str, default 'front'): Where the joint tensors are concatenated relative to the local key/value shard: "front" (before it) or "back" (after it).
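The joint-tensor behaviour described above can be sketched on token labels. Two points are assumptions drawn from the parameter names rather than from this page. First, "front" is taken to mean prepend and "back" to mean append. Second, the joint tensors are taken to be replicated on every rank; under that reading, adding them only at step 0 makes each query attend to the joint tokens exactly once instead of once per ring step. `local_kv_for_step` is a hypothetical helper.

```python
def local_kv_for_step(local_kv, step, joint_kv=None, joint_strategy="front"):
    """Hypothetical helper: the key (or value) sequence a rank attends to at
    a given ring step. Joint tokens are only concatenated at step 0."""
    if joint_kv is None or step != 0:
        return list(local_kv)
    if joint_strategy == "front":
        return list(joint_kv) + list(local_kv)  # joint tokens before the shard
    if joint_strategy == "back":
        return list(local_kv) + list(joint_kv)  # joint tokens after the shard
    raise ValueError(f"unknown joint_strategy: {joint_strategy!r}")


text = ["t0", "t1"]         # joint (e.g. text) tokens
image_shard = ["i0", "i1"]  # this rank's shard of the image tokens
print(local_kv_for_step(image_shard, 0, text))          # step 0, "front"
print(local_kv_for_step(image_shard, 0, text, "back"))  # step 0, "back"
print(local_kv_for_step(image_shard, 1, text))          # later steps: shard only
```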

Returns:

Tensor | tuple[Tensor, Tensor, None]:
- If return_attn_probs is False: the local output shard, of shape (batch, seq_len, num_heads, head_dim), sharded like q.
- If return_attn_probs is True: a tuple (out, softmax_lse, None).
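The softmax_lse value returned with return_attn_probs=True is the kind of quantity a ring implementation typically uses to accumulate results. Two partial outputs computed over disjoint key blocks can be combined exactly using their log-sum-exp values.

The sketch below assumes, as in FlashAttention, that softmax_lse is the natural-log-sum-exp of the scaled scores for each query row. `attend` and `merge` are hypothetical pure-Python helpers for a single query and head; they are not functions from `vllm_omni`.

```python
import math


def attend(q, keys, values, scale):
    """Softmax attention of one query over one K/V block.
    Returns (out, lse) with lse = log(sum_j exp(scale * q . k_j))."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine partial results from two disjoint key blocks."""
    # log(exp(lse_a) + exp(lse_b)), computed stably
    lse = max(lse_a, lse_b) + math.log1p(math.exp(-abs(lse_a - lse_b)))
    wa, wb = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [wa * x + wb * y for x, y in zip(out_a, out_b)], lse


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
scale = len(q) ** -0.5  # default softmax_scale: head_dim ** -0.5

full_out, full_lse = attend(q, keys, values, scale)
out0, lse0 = attend(q, keys[:2], values[:2], scale)  # shard held at step 0
out1, lse1 = attend(q, keys[2:], values[2:], scale)  # shard received at step 1
ring_out, ring_lse = merge(out0, lse0, out1, lse1)

print(ring_out, full_out)
```

The merged result matches full attention over all four keys up to floating-point rounding. This is why each rank can process one shard at a time and keep only a running (out, lse) pair per query.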

ring_flash_attn_kvpacked_func

ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = FA,
)

ring_flash_attn_qkvpacked_func

ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = FA,
)
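The two packed variants are not documented here. Unlike ring_flash_attn_func, their signatures do not expose attn_processor or the joint_* parameters.

Assuming they follow FlashAttention's kvpacked/qkvpacked convention (an assumption, not stated on this page), their layouts would be as follows. `kv` has shape (batch, seq_len, 2, num_heads, head_dim) and `qkv` has shape (batch, seq_len, 3, num_heads, head_dim). Each packed tensor would be split along dim 2 before the ring computation.

The hypothetical `unpack` helper below illustrates that split on nested lists. For a PyTorch tensor, `kv.unbind(dim=2)` performs the same split.

```python
def unpack(packed, n):
    """Hypothetical helper: split a nested-list 'tensor' of shape
    (batch, seq_len, n, num_heads, head_dim) along dim 2 into n tensors
    of shape (batch, seq_len, num_heads, head_dim)."""
    return tuple(
        [[token[i] for token in seq] for seq in packed]  # pick slot i per token
        for i in range(n)
    )


# batch=1, seq_len=2, n=2 (k, v), num_heads=1, head_dim=1
kv = [[[[[1.0]], [[10.0]]], [[[2.0]], [[20.0]]]]]
k, v = unpack(kv, 2)
print(k)
print(v)
```

For the qkv-packed layout, the same helper would be called as `q, k, v = unpack(qkv, 3)`.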