
vllm_omni.platforms.rocm

Modules:

Name Description
patch
platform

RocmOmniPlatform

Bases: OmniPlatform, RocmPlatform

ROCm/AMD GPU implementation of OmniPlatform.

Inherits all ROCm-specific implementations from vLLM's RocmPlatform and adds the Omni-specific interfaces defined by OmniPlatform.
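Because the bases are listed as (OmniPlatform, RocmPlatform), Python's method resolution order (MRO) checks OmniPlatform first and then RocmPlatform. Any generic platform method that OmniPlatform does not define therefore falls through to the ROCm implementation. The sketch below illustrates this with hypothetical stand-in classes. Only the inheritance shape mirrors this page; the method bodies, the returned worker path, and the device count are placeholders, not vLLM or vLLM-Omni code.

```python
# Hypothetical stand-ins: only the inheritance shape mirrors RocmOmniPlatform.
class Platform:
    device_name = "generic"

    @classmethod
    def get_device_count(cls) -> int:
        return 0


class RocmPlatform(Platform):
    # Stand-in for vLLM's ROCm implementation of the generic platform API.
    device_name = "rocm"

    @classmethod
    def get_device_count(cls) -> int:
        return 8  # placeholder value, not a real device query


class OmniPlatform(Platform):
    # Stand-in for the Omni interface: declares only Omni-specific hooks,
    # so it does not shadow RocmPlatform's generic methods in the MRO.
    @classmethod
    def get_omni_ar_worker_cls(cls) -> str:
        raise NotImplementedError


class RocmOmniPlatform(OmniPlatform, RocmPlatform):
    # Implements the Omni hook; everything else resolves to RocmPlatform.
    @classmethod
    def get_omni_ar_worker_cls(cls) -> str:
        return "hypothetical.module.RocmARWorker"  # placeholder path


print([c.__name__ for c in RocmOmniPlatform.__mro__])
# ['RocmOmniPlatform', 'OmniPlatform', 'RocmPlatform', 'Platform', 'object']
```

If OmniPlatform were to define a generic method such as get_device_count itself, that definition would win over RocmPlatform's, because OmniPlatform comes first in the MRO.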

NOTE: AR Attention Backend Overriding Logic:

Since vLLM v0.19.0, the default attention backend on ROCm is ROCM_ATTN. However, ROCM_ATTN is not guaranteed to be compatible with Omni, so Omni still defaults to TRITON_ATTN when selected_backend is not specified.

This override logic currently lives in extract_stage_metadata in vllm_omni/engine/stage_init_utils.py:

if current_omni_platform.is_rocm():
    print(f"engine_args: {str(engine_args)}")
    if engine_args.get("attention_backend") is None:
        from vllm._aiter_ops import rocm_aiter_ops

        if rocm_aiter_ops.is_enabled():
            # AITER kernels are enabled: prefer the AITER flash-attention backend.
            engine_args["attention_backend"] = "ROCM_AITER_FA"
        else:
            # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm.
            # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
            # However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
            # Therefore, we still use TRITON_ATTN as the default attention backend
            # when the selected_backend is not specified.
            engine_args["attention_backend"] = "TRITON_ATTN"
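The excerpt above depends on vLLM internals and a ROCm device, which makes it hard to unit-test. The sketch below isolates the same decision as a pure function. It assumes the AITER check is meant to take precedence over the TRITON_ATTN fallback. The function name resolve_rocm_attention_backend is hypothetical, and the aiter_enabled flag stands in for rocm_aiter_ops.is_enabled().

```python
from typing import Any


def resolve_rocm_attention_backend(
    engine_args: dict[str, Any], aiter_enabled: bool
) -> dict[str, Any]:
    """Return a copy of engine_args with a default ROCm attention backend filled in.

    Hypothetical helper; the real logic lives inline in extract_stage_metadata.
    """
    resolved = dict(engine_args)  # do not mutate the caller's dict
    # An explicitly selected backend is always respected.
    if resolved.get("attention_backend") is not None:
        return resolved
    if aiter_enabled:
        # AITER kernels available: prefer the AITER flash-attention backend.
        resolved["attention_backend"] = "ROCM_AITER_FA"
    else:
        # Fall back to TRITON_ATTN instead of vLLM's newer ROCM_ATTN default.
        resolved["attention_backend"] = "TRITON_ATTN"
    return resolved


print(resolve_rocm_attention_backend({}, aiter_enabled=False))
# {'attention_backend': 'TRITON_ATTN'}
```

Returning a copy rather than mutating engine_args in place keeps the helper side-effect free. An explicit user choice such as ROCM_ATTN is passed through unchanged.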

get_default_ir_op_priority classmethod

get_default_ir_op_priority(
    vllm_config: VllmConfig,
) -> IrOpPriorityConfig

Copied from vllm/platforms/rocm/platform.py (vLLM v0.20.0), modified to always use the vllm_c kernels.

get_default_stage_config_path classmethod

get_default_stage_config_path() -> str

get_device_count classmethod

get_device_count() -> int

get_device_version classmethod

get_device_version() -> str | None

get_diffusion_attn_backend_cls classmethod

get_diffusion_attn_backend_cls(
    selected_backend: str | None, head_size: int
) -> str

get_free_memory classmethod

get_free_memory(device: device | None = None) -> int

get_omni_ar_worker_cls classmethod

get_omni_ar_worker_cls() -> str

get_omni_generation_worker_cls classmethod

get_omni_generation_worker_cls() -> str

get_torch_device classmethod

get_torch_device(local_rank: int | None = None) -> device

has_flash_attn_package classmethod

has_flash_attn_package() -> bool

set_device_control_env_var classmethod

set_device_control_env_var(
    devices: str | int | None,
) -> None

supports_torch_inductor classmethod

supports_torch_inductor() -> bool

synchronize classmethod

synchronize() -> None

unset_device_control_env_var classmethod

unset_device_control_env_var() -> None