vllm_omni.diffusion.attention.parallel.base
NoParallelAttention
Default strategy: a no-op, used on a single device or when sequence parallelism (SP) is disabled.
post_attention
post_attention(
attn_output: Tensor,
ctx: ParallelAttentionContext | None,
) -> Tensor
pre_attention
pre_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_metadata: AttentionMetadata | None,
)
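To illustrate what "do nothing" could look like for these two hooks, here is a minimal pass-through sketch. The class name NoOpStrategySketch is hypothetical and the tensors are typed as Any so the snippet runs without extra dependencies; it is an assumption about the behaviour implied by the description above, not the library's actual implementation.

```python
from typing import Any, Optional, Tuple


class NoOpStrategySketch:
    """Illustrative pass-through strategy (hypothetical, not the library's class)."""

    def pre_attention(
        self,
        query: Any,
        key: Any,
        value: Any,
        attn_metadata: Optional[Any],
    ) -> Tuple[Any, Any, Any, Optional[Any], Optional[Any]]:
        # Nothing to reshard on a single device: hand the inputs back
        # unchanged and return no context for post_attention.
        return query, key, value, attn_metadata, None

    def post_attention(self, attn_output: Any, ctx: Optional[Any]) -> Any:
        # Nothing to gather: the kernel output is already the final output.
        return attn_output
```

With such a strategy, the attention layer can always call both hooks unconditionally, whether or not any parallelism is configured.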
ParallelAttentionContext dataclass
ParallelAttentionStrategy
Bases: Protocol
Pluggable strategy for parallel attention communication/resharding.
This is intentionally orthogonal to the attention kernel backend: the kernel backend implements AttentionImpl.forward() for a given device, while the parallel strategy determines how Q/K/V and the attention outputs are sharded and communicated across ranks.
post_attention
post_attention(
attn_output: Tensor,
ctx: ParallelAttentionContext | None,
) -> Tensor
Runs after the attention kernel, on the kernel output and the optional context returned by pre_attention.
pre_attention
pre_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_metadata: AttentionMetadata | None,
) -> tuple[
Tensor,
Tensor,
Tensor,
AttentionMetadata | None,
ParallelAttentionContext | None,
]
Runs before the attention kernel.
Returns the possibly transformed Q/K/V and metadata, plus an optional context that is passed on to post_attention.
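The call order described above (pre_attention, then the kernel, then post_attention) can be sketched as follows. Everything in this snippet is a hypothetical stand-in: plain Python lists replace tensors, PadContext replaces ParallelAttentionContext (whose real fields are not documented here), toy_kernel replaces AttentionImpl.forward(), and PadToMultiple is an invented strategy that pads the sequence to a multiple of a given size and trims it afterwards. It only illustrates how a context can carry state from one hook to the other.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Protocol, Tuple

Seq = List[float]  # stand-in for a tensor's sequence dimension


@dataclass
class PadContext:
    """Hypothetical context: remembers the unpadded sequence length."""
    original_len: int


class StrategySketch(Protocol):
    """Mirrors the two documented hooks, using the stand-in types above."""

    def pre_attention(
        self, query: Seq, key: Seq, value: Seq, attn_metadata: Optional[Any]
    ) -> Tuple[Seq, Seq, Seq, Optional[Any], Optional[PadContext]]: ...

    def post_attention(self, attn_output: Seq, ctx: Optional[PadContext]) -> Seq: ...


class PadToMultiple:
    """Toy strategy: pad Q/K/V so their length divides evenly, then trim."""

    def __init__(self, multiple: int) -> None:
        self.multiple = multiple

    def pre_attention(self, query, key, value, attn_metadata):
        # Assumes Q, K and V share one length (self-attention).
        pad = (-len(query)) % self.multiple
        if pad == 0:
            return query, key, value, attn_metadata, None
        q, k, v = (seq + [0.0] * pad for seq in (query, key, value))
        return q, k, v, attn_metadata, PadContext(original_len=len(query))

    def post_attention(self, attn_output, ctx):
        # No context means pre_attention changed nothing.
        if ctx is None:
            return attn_output
        return attn_output[: ctx.original_len]


def toy_kernel(query: Seq, key: Seq, value: Seq) -> Seq:
    # Stand-in for the kernel backend; elementwise, not real attention.
    return [q * k + v for q, k, v in zip(query, key, value)]


def run_attention(strategy: StrategySketch, query, key, value, attn_metadata=None):
    # 1) reshard/transform, 2) run the kernel, 3) undo the transform.
    query, key, value, attn_metadata, ctx = strategy.pre_attention(
        query, key, value, attn_metadata
    )
    out = toy_kernel(query, key, value)
    return strategy.post_attention(out, ctx)
```

For example, run_attention(PadToMultiple(4), q, k, v) with three-element inputs pads them to length four before the kernel and returns a three-element result. Because the strategy only touches the data before and after the kernel, the same wrapper works with any kernel stand-in.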