vllm_omni.diffusion.attention.layer

logger module-attribute

logger = init_logger(__name__)

class Attention

Bases: Module

attention instance-attribute

attention = attn_impl_cls(
    num_heads=num_heads,
    head_size=head_size,
    softmax_scale=softmax_scale,
    causal=causal,
    num_kv_heads=num_kv_heads,
    qkv_layout=qkv_layout,
    backend_kwargs=backend_kwargs,
)

attn_backend instance-attribute

attn_backend = attn_backend_cls

attn_impl_cls instance-attribute

attn_impl_cls = get_impl_cls()

backend_pref instance-attribute

backend_pref = None

causal instance-attribute

causal = causal

gather_idx instance-attribute

gather_idx = gather_idx

layer_idx instance-attribute

layer_idx: int | None = _try_extract_layer_index(prefix)

parallel_strategy instance-attribute

parallel_strategy = build_parallel_attention_strategy(
    scatter_idx=scatter_idx,
    gather_idx=gather_idx,
    use_sync=use_sync,
)

qkv_layout instance-attribute

qkv_layout = qkv_layout

ring_pg instance-attribute

ring_pg = ring_group

ring_runner instance-attribute

ring_runner = RingParallelAttention(
    sp_group, attn_backend_pref=backend_pref
)

role instance-attribute

role = role

role_category instance-attribute

role_category = role_category

scatter_idx instance-attribute

scatter_idx = scatter_idx

sdpa_fallback instance-attribute

sdpa_fallback = get_impl_cls()(
    num_heads=num_heads,
    head_size=head_size,
    softmax_scale=softmax_scale,
    causal=causal,
    num_kv_heads=num_kv_heads,
    qkv_layout=qkv_layout,
)

skip_sequence_parallel instance-attribute

skip_sequence_parallel = skip_sequence_parallel

softmax_scale instance-attribute

softmax_scale = softmax_scale

use_ring instance-attribute

use_ring = False

use_sync instance-attribute

use_sync = use_sync

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: AttentionMetadata | None = None,
) -> Tensor