vllm_omni.diffusion.attention.layer
Attention
Bases: Module
attention (instance attribute)
attention = attn_impl_cls(
num_heads=num_heads,
head_size=head_size,
softmax_scale=softmax_scale,
causal=causal,
num_kv_heads=num_kv_heads,
qkv_layout=qkv_layout,
backend_kwargs=backend_kwargs,
)
parallel_strategy (instance attribute)
parallel_strategy = build_parallel_attention_strategy(
scatter_idx=scatter_idx,
gather_idx=gather_idx,
use_sync=use_sync,
)
ring_runner (instance attribute)
ring_runner = RingParallelAttention(
sp_group, attn_backend_pref=backend_pref
)
sdpa_fallback (instance attribute)
sdpa_fallback = get_impl_cls()(
num_heads=num_heads,
head_size=head_size,
softmax_scale=softmax_scale,
causal=causal,
num_kv_heads=num_kv_heads,
qkv_layout=qkv_layout,
)
forward
forward(
query: Tensor,
key: Tensor,
value: Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> Tensor