vllm_omni.diffusion.attention.backends.ring_flash_attn ¶
RingFlashAttnFunc ¶
Bases: Function
Ring Flash Attention autograd function (inference only, no backward).
forward staticmethod ¶
forward(
    ctx,
    q,
    k,
    v,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    return_softmax,
    group,
    attn_type,
    attn_processor,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
)
ring_flash_attn_forward ¶
ring_flash_attn_forward(
    process_group,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
)
ring_flash_attn_func ¶
ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
) -> Tensor | tuple[Tensor, Tensor, None]
Ring Attention forward pass using Flash Attention backend.
Implements Ring Attention with sequence parallelism using a ring-based P2P communication pattern. The sequence dimension is sharded across devices, and Key/Value blocks are circulated through the ring to accumulate attention results.
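The accumulation described above can be illustrated with a single-process NumPy sketch. This is not the module's implementation: the helper names (`block_attention`, `merge_blocks`, `simulate_ring_attention`) are hypothetical, the P2P exchange is simulated by list indexing, the ring direction is an assumption, and causal masking, dropout, windowing, and soft capping are omitted. It only shows how per-block results can be combined through their log-sum-exp values so that the final output matches full attention.

```python
import numpy as np


def block_attention(q, k, v, scale):
    """Attend q to one K/V block; return the output and the log-sum-exp.

    q: (batch, sq, heads, dim); k, v: (batch, sk, heads, dim).
    """
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    row_max = scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores - row_max)
    row_sum = probs.sum(axis=-1, keepdims=True)
    out = np.einsum("bhqk,bkhd->bqhd", probs / row_sum, v)
    lse = (row_max + np.log(row_sum))[..., 0]  # (batch, heads, sq)
    return out, lse


def merge_blocks(out, lse, block_out, block_lse):
    """Combine two partial attention results using their log-sum-exp values."""
    new_lse = np.logaddexp(lse, block_lse)
    # Reshape weights from (batch, heads, sq) to (batch, sq, heads, 1).
    w_old = np.exp(lse - new_lse).transpose(0, 2, 1)[..., None]
    w_new = np.exp(block_lse - new_lse).transpose(0, 2, 1)[..., None]
    return w_old * out + w_new * block_out, new_lse


def simulate_ring_attention(q, k, v, world_size):
    """Simulate non-causal ring attention over `world_size` sequence shards."""
    scale = q.shape[-1] ** -0.5  # same default as softmax_scale=None
    q_shards = np.array_split(q, world_size, axis=1)
    k_shards = np.array_split(k, world_size, axis=1)
    v_shards = np.array_split(v, world_size, axis=1)

    outputs = []
    for rank in range(world_size):
        out, lse = None, None
        for step in range(world_size):
            # At each step the rank sees the K/V block that started on `src`.
            # The direction of rotation is an assumption of this sketch.
            src = (rank - step) % world_size
            block_out, block_lse = block_attention(
                q_shards[rank], k_shards[src], v_shards[src], scale
            )
            if out is None:
                out, lse = block_out, block_lse
            else:
                out, lse = merge_blocks(out, lse, block_out, block_lse)
        outputs.append(out)
    # Each rank only holds its own query shard; concatenate for comparison.
    return np.concatenate(outputs, axis=1)
```

Because every rank eventually sees every K/V block, the merged result should agree with attention computed over the unsharded sequence, up to floating-point error.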
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `q` | `Tensor` | Query tensor of shape (batch, seq_len, num_heads, head_dim). Sequence dimension is sharded across the ring group. | required |
| `k` | `Tensor` | Key tensor of shape (batch, seq_len, num_heads, head_dim). Sequence dimension is sharded across the ring group. | required |
| `v` | `Tensor` | Value tensor of shape (batch, seq_len, num_heads, head_dim). Sequence dimension is sharded across the ring group. | required |
| `dropout_p` | `float` | Dropout probability. Defaults to 0.0. | `0.0` |
| `softmax_scale` | `float \| None` | Scaling factor for softmax. If None, computed as head_dim^(-0.5). | `None` |
| `causal` | `bool` | Whether to apply causal masking. Defaults to False. | `False` |
| `window_size` | `tuple[int, int]` | Sliding window size for attention. (-1, -1) means no windowing. | `(-1, -1)` |
| `softcap` | `float` | Soft capping value for attention logits. Defaults to 0.0. | `0.0` |
| `alibi_slopes` | `Tensor \| None` | ALiBi slopes for positional bias. Not supported. | `None` |
| `deterministic` | `bool` | Whether to use deterministic algorithms. Defaults to False. | `False` |
| `return_attn_probs` | `bool` | If True, returns (out, softmax_lse, None). Defaults to False. | `False` |
| `group` | `ProcessGroup \| None` | Process group for ring communication. Defaults to None. | `None` |
| `attn_type` | `AttnType` | Flash Attention implementation type (AttnType.FA, AttnType.FA3, etc.). | `FA` |
| `attn_processor` | `Callable \| None` | Custom attention processor for sparse attention. Defaults to None. | `None` |
| `joint_tensor_key` | `Tensor \| None` | Additional key tensor for joint attention (e.g., text + image). Concatenated only at step=0. Defaults to None. | `None` |
| `joint_tensor_value` | `Tensor \| None` | Additional value tensor for joint attention (e.g., text + image). Concatenated only at step=0. Defaults to None. | `None` |
| `joint_strategy` | `str` | Concatenation strategy ("front" or "back"). Defaults to "front". | `'front'` |
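As a rough illustration of the `joint_tensor_key` / `joint_strategy` description, the sketch below shows one plausible reading: at ring step 0 the joint tensor is concatenated to the local block along the sequence dimension, either before it ("front") or after it ("back"). The helper `concat_joint` is hypothetical, and the choice of axis 1 is an assumption based on the documented (batch, seq_len, num_heads, head_dim) layout; the module's actual handling may differ.

```python
import numpy as np


def concat_joint(block, joint, joint_strategy="front", step=0):
    """Attach a joint tensor to a K or V block at ring step 0 only.

    block, joint: arrays laid out as (batch, seq_len, num_heads, head_dim).
    """
    if joint is None or step != 0:
        # Later ring steps use the circulated block unchanged.
        return block
    if joint_strategy == "front":
        return np.concatenate([joint, block], axis=1)
    if joint_strategy == "back":
        return np.concatenate([block, joint], axis=1)
    raise ValueError(f"unknown joint_strategy: {joint_strategy!r}")
```

Restricting the concatenation to a single step means the joint tokens contribute to the accumulated result exactly once, rather than once per ring step.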
Returns:
| Type | Description |
|---|---|
| `Tensor \| tuple[Tensor, Tensor, None]` | If `return_attn_probs` is False: output tensor of shape (batch, seq_len, num_heads, head_dim). If `return_attn_probs` is True: a tuple (out, softmax_lse, None). |
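Because the return type depends on `return_attn_probs`, calling code that toggles the flag may want to normalize the result. The helper below is a hypothetical convenience, not part of this module; it relies only on the documented return shapes.

```python
def split_attention_result(result):
    """Normalize the return value to an (out, softmax_lse) pair.

    `result` is either the output tensor alone, or the documented
    3-tuple (out, softmax_lse, None).
    """
    if isinstance(result, tuple):
        out, softmax_lse, _ = result
        return out, softmax_lse
    # return_attn_probs=False: only the output tensor is available.
    return result, None
```

With this, `out, lse = split_attention_result(ring_flash_attn_func(...))` works for either flag setting, with `lse` set to `None` when the probabilities were not requested.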