vllm_omni.diffusion.attention.parallel
Parallel attention strategies.
This package provides communication and resharding strategies for attention. These strategies are orthogonal to the attention kernel backend (SDPA, Flash, Sage).
The goal is to keep vllm_omni.diffusion.attention.layer.Attention small and extensible: adding a new parallelism method should not require editing the core Attention module, only adding a new strategy and selecting it in the factory.
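The division of labour described above can be sketched as follows. This is a minimal, hypothetical illustration, not the actual Attention module: `ToyAttention`, `toy_kernel`, and `PassThroughStrategy` are invented names, and plain Python lists stand in for tensors so the example runs without PyTorch. The layer only sequences `pre_attention`, the kernel, and `post_attention`, so swapping the strategy object changes the parallelism without editing the layer.

```python
from typing import Any, Callable, List, Optional, Tuple

Tensor = List[float]  # hypothetical stand-in for torch.Tensor


class PassThroughStrategy:
    """Hypothetical strategy: no resharding and no context (single-device case)."""

    def pre_attention(self, query: Tensor, key: Tensor, value: Tensor,
                      attn_metadata: Optional[dict]
                      ) -> Tuple[Tensor, Tensor, Tensor, Optional[dict], None]:
        return query, key, value, attn_metadata, None

    def post_attention(self, attn_output: Tensor, ctx: None) -> Tensor:
        return attn_output


class ToyAttention:
    """Hypothetical layer: the kernel and the parallel strategy are injected separately."""

    def __init__(self, kernel: Callable[..., Tensor], strategy: Any) -> None:
        self.kernel = kernel      # stands in for the SDPA/Flash/Sage backend
        self.strategy = strategy  # stands in for a ParallelAttentionStrategy

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                attn_metadata: Optional[dict] = None) -> Tensor:
        # 1) Reshard / communicate the inputs (a no-op for the pass-through strategy).
        q, k, v, meta, ctx = self.strategy.pre_attention(query, key, value, attn_metadata)
        # 2) Run the local kernel on whatever shard this rank now holds.
        out = self.kernel(q, k, v, meta)
        # 3) Undo the resharding, using the context captured in step 1.
        return self.strategy.post_attention(out, ctx)


def toy_kernel(q: Tensor, k: Tensor, v: Tensor, meta: Optional[dict]) -> Tensor:
    # Not real attention: elementwise q * k + v, just a deterministic placeholder.
    return [qi * ki + vi for qi, ki, vi in zip(q, k, v)]


layer = ToyAttention(toy_kernel, PassThroughStrategy())
print(layer.forward([1.0, 2.0], [3.0, 4.0], [0.5, 0.5]))  # [3.5, 8.5]
```

Because the layer depends only on the two strategy methods, a new parallelism method would be a new class with those two methods plus a selection branch in the factory.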
Modules:
| Name | Description |
|---|---|
| base | |
| factory | |
| ring | |
| ulysses | |
NoParallelAttention
Default strategy: a no-op, used on a single device or when sequence parallelism (SP) is disabled.
post_attention
post_attention(
attn_output: Tensor,
ctx: ParallelAttentionContext | None,
) -> Tensor
pre_attention
pre_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_metadata: AttentionMetadata | None,
)
ParallelAttentionContext dataclass
ParallelAttentionStrategy
Bases: Protocol
Pluggable strategy for parallel attention communication and resharding.
This is intentionally orthogonal to the attention kernel backend. The kernel backend implements AttentionImpl.forward() for a given device, while the parallel strategy implements how Q/K/V and outputs are sharded and communicated across ranks.
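To make "sharded and communicated across ranks" concrete, here is a single-process simulation of one well-known resharding pattern: the DeepSpeed-Ulysses-style all-to-all, which turns sequence-sharded Q/K/V into head-sharded Q/K/V before the kernel and reverses the exchange afterwards. It illustrates the general technique only; whether this package's `ulysses` module does exactly this is an assumption not confirmed by this page. Nested lists stand in for tensors, and both function names are hypothetical. On real hardware the exchange would be a collective such as `torch.distributed.all_to_all`.

```python
from typing import List

# Hypothetical layout: rank_shards[r][s][h] is the entry for rank r,
# local sequence position s, attention head h.
Shard = List[List[str]]


def seq_shard_to_head_shard(rank_shards: List[Shard]) -> List[Shard]:
    """Simulated all-to-all before the kernel.

    In:  each rank holds a contiguous slice of the sequence with all heads.
    Out: each rank holds the full sequence for a contiguous slice of heads.
    """
    world = len(rank_shards)
    num_heads = len(rank_shards[0][0])
    assert num_heads % world == 0, "heads must split evenly across ranks"
    hpr = num_heads // world  # heads per rank after resharding
    out = []
    for dst in range(world):
        # Take this rank's head slice from every source rank, in rank order,
        # which reassembles the sequence positions in their original order.
        out.append([row[dst * hpr:(dst + 1) * hpr]
                    for src in range(world) for row in rank_shards[src]])
    return out


def head_shard_to_seq_shard(rank_shards: List[Shard]) -> List[Shard]:
    """Simulated all-to-all after the kernel (inverse of the function above)."""
    world = len(rank_shards)
    seq_len = len(rank_shards[0])
    assert seq_len % world == 0, "sequence must split evenly across ranks"
    spr = seq_len // world  # sequence positions per rank after resharding
    out = []
    for dst in range(world):
        # For each position owned by `dst`, concatenate the head slices held
        # by every rank, in rank order, to recover all heads.
        out.append([[h for src in range(world) for h in rank_shards[src][s]]
                    for s in range(dst * spr, (dst + 1) * spr)])
    return out


# 2 ranks, sequence length 4, 2 heads; entries are labelled "s<pos>h<head>".
world, seq_len, heads = 2, 4, 2
local = seq_len // world
shards = [[[f"s{s}h{h}" for h in range(heads)]
           for s in range(r * local, (r + 1) * local)] for r in range(world)]

by_head = seq_shard_to_head_shard(shards)
print(by_head[0])  # [['s0h0'], ['s1h0'], ['s2h0'], ['s3h0']]
assert head_shard_to_seq_shard(by_head) == shards  # round trip restores the layout
```

After the first exchange each rank sees the whole sequence for its own heads, so an unmodified single-device kernel can run locally; the second exchange restores the sequence-sharded layout the rest of the model expects.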
post_attention
post_attention(
attn_output: Tensor,
ctx: ParallelAttentionContext | None,
) -> Tensor
Runs after the attention kernel.
pre_attention
pre_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_metadata: AttentionMetadata | None,
) -> tuple[
Tensor,
Tensor,
Tensor,
AttentionMetadata | None,
ParallelAttentionContext | None,
]
Runs before the attention kernel.
Returns the (possibly transformed) Q/K/V and metadata, plus an optional context that post_attention receives as ctx.
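The pre/post pairing is easiest to see with a strategy whose post_attention must undo something its pre_attention did. The sketch below is hypothetical: `ParallelAttentionStrategyLike` only mirrors the documented method names using `typing.Protocol` (it is not the package's own class), and `PadToMultipleStrategy` and `PadContext` are invented for illustration. Lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Optional, Protocol, Tuple, runtime_checkable

Tensor = list             # hypothetical stand-in for torch.Tensor
AttentionMetadata = dict  # hypothetical stand-in for the real metadata type


@runtime_checkable
class ParallelAttentionStrategyLike(Protocol):
    """Mirrors the documented method names; not the package's own Protocol."""

    def pre_attention(self, query: Tensor, key: Tensor, value: Tensor,
                      attn_metadata: Optional[AttentionMetadata]
                      ) -> Tuple[Tensor, Tensor, Tensor, Optional[AttentionMetadata], Any]: ...

    def post_attention(self, attn_output: Tensor, ctx: Any) -> Tensor: ...


@dataclass
class PadContext:
    """Hypothetical context: the one fact post_attention needs to undo pre_attention."""
    orig_len: int


class PadToMultipleStrategy:
    """Hypothetical strategy: pad the sequence so it splits evenly over
    `world_size` ranks, then strip the padding from the output."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size

    def pre_attention(self, query, key, value, attn_metadata):
        n = len(query)
        pad = (-n) % self.world_size  # positions needed to reach a multiple
        if pad == 0:
            # Nothing to undo, so no context is returned.
            return query, key, value, attn_metadata, None
        # A real kernel would also need to mask the padded key positions.
        q, k, v = (t + [0.0] * pad for t in (query, key, value))
        return q, k, v, attn_metadata, PadContext(orig_len=n)

    def post_attention(self, attn_output, ctx):
        return attn_output if ctx is None else attn_output[: ctx.orig_len]


strategy = PadToMultipleStrategy(world_size=4)
q, k, v, meta, ctx = strategy.pre_attention([1.0, 2.0, 3.0], [1.0] * 3, [0.0] * 3, None)
print(len(q), ctx)                      # 4 PadContext(orig_len=3)
print(strategy.post_attention(q, ctx))  # [1.0, 2.0, 3.0]
print(isinstance(strategy, ParallelAttentionStrategyLike))  # True
```

Because the protocol is structural, `PadToMultipleStrategy` satisfies it without subclassing. Note that `runtime_checkable` only checks that the methods exist, not that their signatures match.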
build_parallel_attention_strategy
build_parallel_attention_strategy(
*, scatter_idx: int, gather_idx: int, use_sync: bool
) -> ParallelAttentionStrategy
Select a parallel attention strategy based on the current diffusion config.
Design principle:
- Attention kernel backend selection remains in attention/selector.py.
- Parallel attention selection is handled here, based on the distributed config and the initialized process groups.
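A sketch of the selection logic described above, under stated assumptions: `ToyParallelConfig`, its fields, and the `*Sketch` classes are hypothetical, and the config is passed explicitly instead of being read from the current diffusion config as the real function does. Only the keyword-only `scatter_idx`, `gather_idx`, and `use_sync` parameters come from the documented signature; forwarding them to the Ulysses-style stand-in is a guess about where they are used.

```python
from dataclasses import dataclass


@dataclass
class ToyParallelConfig:
    # Hypothetical fields; the real diffusion config may use different names.
    ulysses_degree: int = 1
    ring_degree: int = 1
    groups_initialized: bool = False  # stands in for "process groups are set up"


@dataclass
class NoParallelSketch:
    """Hypothetical stand-in for the no-op strategy."""


@dataclass
class UlyssesSketch:
    """Hypothetical stand-in; only records the forwarded arguments."""
    scatter_idx: int
    gather_idx: int
    use_sync: bool


@dataclass
class RingSketch:
    """Hypothetical stand-in for a ring strategy."""


def build_strategy_sketch(cfg: ToyParallelConfig, *, scatter_idx: int,
                          gather_idx: int, use_sync: bool):
    """Hypothetical selection policy keyed on parallel degrees and group state."""
    if not cfg.groups_initialized:
        # Without process groups there is nobody to communicate with.
        return NoParallelSketch()
    if cfg.ulysses_degree > 1 and cfg.ring_degree > 1:
        raise NotImplementedError("hybrid ring + Ulysses is out of scope for this sketch")
    if cfg.ulysses_degree > 1:
        return UlyssesSketch(scatter_idx=scatter_idx, gather_idx=gather_idx, use_sync=use_sync)
    if cfg.ring_degree > 1:
        return RingSketch()
    return NoParallelSketch()


print(build_strategy_sketch(ToyParallelConfig(), scatter_idx=2, gather_idx=1, use_sync=False))
# NoParallelSketch()
print(build_strategy_sketch(ToyParallelConfig(ulysses_degree=2, groups_initialized=True),
                            scatter_idx=2, gather_idx=1, use_sync=False))
# UlyssesSketch(scatter_idx=2, gather_idx=1, use_sync=False)
```

Keeping this decision in one function mirrors the design principle above: the Attention layer never inspects the distributed config itself, and adding a strategy means adding one class and one branch here.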