Skip to content

vllm_omni.diffusion.attention.backends.abstract

T module-attribute

T = TypeVar('T', bound=AttentionMetadata)

AttentionBackend

Bases: ABC

Abstract class for diffusion attention backends.

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = False

get_builder_cls abstractmethod staticmethod

get_builder_cls()

get_impl_cls abstractmethod staticmethod

get_impl_cls() -> type[AttentionImpl]

get_metadata_cls abstractmethod staticmethod

get_metadata_cls() -> type[AttentionMetadata]

get_name abstractmethod staticmethod

get_name() -> str

get_supported_head_sizes abstractmethod staticmethod

get_supported_head_sizes() -> list[int]

Return the list of head sizes supported by this backend.

supports_attention_mask classmethod

supports_attention_mask() -> bool

supports_head_size classmethod

supports_head_size(head_size: int) -> bool

AttentionImpl

Bases: ABC, Generic[T]

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: T | None = None,
) -> Tensor

Dispatch to the platform-specific forward implementation.

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: T | None = None,
) -> Tensor

forward_hip

forward_hip(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: T | None = None,
) -> Tensor

forward_musa

forward_musa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: T | None = None,
) -> Tensor

forward_npu

forward_npu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: T | None = None,
) -> Tensor

forward_xpu

forward_xpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: T | None = None,
) -> Tensor

supports_kv_cache_dtype classmethod

supports_kv_cache_dtype(
    kv_cache_dtype: str | None, platform_key: str
) -> bool

AttentionMetadata dataclass

attn_mask class-attribute instance-attribute

attn_mask: Tensor | None = None

extra class-attribute instance-attribute

extra: dict[str, Any] = field(default_factory=dict)

full_attn_spans class-attribute instance-attribute

full_attn_spans: list[list[tuple[int, int]]] | None = None

joint_attn_mask class-attribute instance-attribute

joint_attn_mask: Tensor | None = None

joint_key class-attribute instance-attribute

joint_key: Tensor | None = None

joint_query class-attribute instance-attribute

joint_query: Tensor | None = None

joint_strategy class-attribute instance-attribute

joint_strategy: str = 'front'

joint_value class-attribute instance-attribute

joint_value: Tensor | None = None