
vllm_omni.diffusion.attention.parallel

Parallel attention strategies.

This package provides communication / resharding strategies for attention, orthogonal to the attention kernel backend (SDPA/Flash/Sage).

The goal is to keep vllm_omni.diffusion.attention.layer.Attention small and extensible: adding a new parallelism method should not require editing the core Attention module, only adding a new strategy and selecting it in the factory.

Modules:

base
factory
ring
ulysses

NoParallelAttention

Default strategy: passes Q/K/V and the attention output through unchanged (single device, or sequence parallelism (SP) disabled).
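A minimal sketch of what a pass-through strategy looks like. To keep it dependency-free, `Tensor` and `AttentionMetadata` are aliased to `Any` (the real package uses torch tensors and its own metadata type), and the values returned by `enabled` (`False`) and `name` (`"none"`) are assumptions, not taken from the source.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Stand-ins (assumptions): the real package uses torch.Tensor and its own
# AttentionMetadata type; Any keeps this sketch dependency-free.
Tensor = Any
AttentionMetadata = Any


@dataclass
class ParallelAttentionContext:
    name: str


class NoParallelAttention:
    """No-op strategy: Q/K/V and the output pass through unchanged."""

    @property
    def enabled(self) -> bool:
        return False  # assumed: nothing to shard on a single device

    @property
    def name(self) -> str:
        return "none"  # assumed placeholder name

    def pre_attention(self, query: Tensor, key: Tensor, value: Tensor,
                      attn_metadata: Optional[AttentionMetadata]):
        # Nothing is resharded, so no context is needed for post_attention.
        return query, key, value, attn_metadata, None

    def post_attention(self, attn_output: Tensor,
                       ctx: Optional[ParallelAttentionContext]) -> Tensor:
        return attn_output
```

Because both hooks are identities, the attention layer can always call them unconditionally, and the single-device path needs no special casing.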

enabled property

enabled: bool

name property

name: str

post_attention

post_attention(
    attn_output: Tensor,
    ctx: ParallelAttentionContext | None,
) -> Tensor

pre_attention

pre_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: AttentionMetadata | None,
)

ParallelAttentionContext dataclass

Opaque per-forward context returned by a parallel strategy.

Strategies can store whatever state they need here to finish post-processing after the attention kernel runs, such as what is required to reverse the resharding or to slice metadata.
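One plausible way for a strategy to carry its own state is to subclass the context dataclass. The sketch below assumes a hypothetical strategy that pads the sequence before attention and therefore needs to remember the original length; `PaddingContext` and `orig_seq_len` are illustrative names, not part of the package.

```python
from dataclasses import dataclass


@dataclass
class ParallelAttentionContext:
    name: str


# Hypothetical subclass: a padding strategy records the pre-padding length
# so that post_attention can slice the padding back off.
@dataclass
class PaddingContext(ParallelAttentionContext):
    orig_seq_len: int
```

Since the context is opaque to the attention layer, only the strategy that created it needs to know about `orig_seq_len`.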

name instance-attribute

name: str

ParallelAttentionStrategy

Bases: Protocol

Pluggable strategy for parallel attention communication/resharding.

This is intentionally orthogonal to the attention kernel backend. The kernel backend implements AttentionImpl.forward() for a given device, while the parallel strategy defines how Q/K/V and outputs are sharded and communicated across ranks.
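Because `ParallelAttentionStrategy` is a `typing.Protocol`, any class with matching members satisfies it structurally, with no inheritance required. The sketch below restates the documented interface with `Any` standing in for `Tensor` and `AttentionMetadata`; the `@runtime_checkable` decorator and the `Passthrough` class are additions for illustration and may not match the real module.

```python
from dataclasses import dataclass
from typing import Any, Optional, Protocol, Tuple, runtime_checkable

Tensor = Any             # stand-in for torch.Tensor (assumption)
AttentionMetadata = Any  # stand-in for the package's metadata type


@dataclass
class ParallelAttentionContext:
    name: str


@runtime_checkable  # added for this sketch so isinstance() works
class ParallelAttentionStrategy(Protocol):
    @property
    def enabled(self) -> bool: ...

    @property
    def name(self) -> str: ...

    def pre_attention(
        self, query: Tensor, key: Tensor, value: Tensor,
        attn_metadata: Optional[AttentionMetadata],
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[AttentionMetadata],
               Optional[ParallelAttentionContext]]: ...

    def post_attention(self, attn_output: Tensor,
                       ctx: Optional[ParallelAttentionContext]) -> Tensor: ...


class Passthrough:  # hypothetical; note it does not subclass the Protocol
    enabled = False
    name = "passthrough"

    def pre_attention(self, query, key, value, attn_metadata):
        return query, key, value, attn_metadata, None

    def post_attention(self, attn_output, ctx):
        return attn_output
```

A runtime `isinstance` check against a protocol only verifies that the members exist, not their signatures; static type checkers are what enforce the full contract.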

enabled property

enabled: bool

name property

name: str

post_attention

post_attention(
    attn_output: Tensor,
    ctx: ParallelAttentionContext | None,
) -> Tensor

Runs after the attention kernel.

pre_attention

pre_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: AttentionMetadata | None,
) -> tuple[
    Tensor,
    Tensor,
    Tensor,
    AttentionMetadata | None,
    ParallelAttentionContext | None,
]

Runs before the attention kernel.

Returns the (possibly transformed) Q/K/V and metadata, plus an optional context that is later passed to post_attention.
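The contract between the two hooks can be illustrated end to end. This sketch uses a hypothetical strategy that pads the sequence to a multiple of some size (a stand-in transformation, not the actual Ulysses or ring logic), stores the original length in its context, and strips the padding in `post_attention`. Python lists stand in for tensors along the sequence dimension, and `attention_forward` is a hypothetical helper showing how a layer could bracket the kernel call with the two hooks.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

Seq = List[float]  # stand-in for a tensor; only the sequence dim is modeled


@dataclass
class ParallelAttentionContext:
    name: str


@dataclass
class PadContext(ParallelAttentionContext):
    orig_len: int  # sequence length before padding


class PadToMultiple:
    """Hypothetical strategy: pad Q/K/V so the sequence length is a multiple
    of `multiple`, then strip the padding from the attention output."""

    def __init__(self, multiple: int) -> None:
        self.multiple = multiple

    @property
    def enabled(self) -> bool:
        return True

    @property
    def name(self) -> str:
        return "pad_to_multiple"

    def _pad(self, x: Seq) -> Seq:
        return x + [0.0] * (-len(x) % self.multiple)

    def pre_attention(
        self, query: Seq, key: Seq, value: Seq, attn_metadata: Any
    ) -> Tuple[Seq, Seq, Seq, Any, Optional[ParallelAttentionContext]]:
        # Record what post_attention needs to undo the transformation.
        ctx = PadContext(name=self.name, orig_len=len(query))
        return self._pad(query), self._pad(key), self._pad(value), attn_metadata, ctx

    def post_attention(self, attn_output: Seq,
                       ctx: Optional[ParallelAttentionContext]) -> Seq:
        if isinstance(ctx, PadContext):
            return attn_output[: ctx.orig_len]
        return attn_output


Kernel = Callable[[Seq, Seq, Seq, Any], Seq]


def attention_forward(strategy, kernel: Kernel, query: Seq, key: Seq,
                      value: Seq, attn_metadata: Any = None) -> Seq:
    # The strategy brackets the kernel call, so the kernel never sees
    # any resharding logic.
    q, k, v, meta, ctx = strategy.pre_attention(query, key, value, attn_metadata)
    out = kernel(q, k, v, meta)
    return strategy.post_attention(out, ctx)
```

With a placeholder kernel such as `lambda q, k, v, meta: [a + b + c for a, b, c in zip(q, k, v)]`, a length-5 input is padded to 8 before the kernel runs and sliced back to 5 afterwards. Swapping the strategy changes the communication pattern without touching the kernel, which is the orthogonality the design aims for.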

build_parallel_attention_strategy

build_parallel_attention_strategy(
    *, scatter_idx: int, gather_idx: int, use_sync: bool
) -> ParallelAttentionStrategy

Select a parallel attention strategy based on current diffusion config.

Design principle:

- Attention kernel backend selection remains in attention/selector.py.
- Parallel attention selection is handled here, based on the distributed config and the initialized process groups.
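A rough sketch of how a factory like this can keep the core layer closed to modification: strategies are looked up in a registry, and the no-op strategy is the fallback. Everything here is hypothetical: `ParallelConfig`, its `sp_method` and `sp_size` fields, the registry, and the stub strategy classes are illustrative stand-ins, and the real function's `scatter_idx`, `gather_idx`, and `use_sync` arguments are omitted.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ParallelConfig:  # hypothetical stand-in for the diffusion config
    sp_method: str = "none"  # e.g. "ulysses" or "ring"
    sp_size: int = 1


# Stub strategies; only the name matters for this sketch.
class NoParallel:
    name = "none"


class Ulysses:
    name = "ulysses"


class Ring:
    name = "ring"


# Adding a new parallelism method means writing a strategy and adding
# one registry entry; the attention layer itself is untouched.
_REGISTRY: Dict[str, Callable[[], object]] = {"ulysses": Ulysses, "ring": Ring}


def build_strategy(cfg: ParallelConfig, groups_initialized: bool):
    # Fall back to the no-op strategy on a single device or before the
    # distributed process groups exist.
    if cfg.sp_size <= 1 or not groups_initialized:
        return NoParallel()
    try:
        return _REGISTRY[cfg.sp_method]()
    except KeyError:
        raise ValueError(f"unknown sequence-parallel method: {cfg.sp_method!r}") from None
```

Failing loudly on an unknown method name, rather than silently falling back, is a design choice of this sketch; it surfaces configuration typos early.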