Skip to content

vllm_omni.diffusion.layers.mot.ops.mot_gemm

logger module-attribute

logger = logging.getLogger(__name__)

build_config_filename

build_config_filename(
    device_name: str, dtype_str: str
) -> str

get_best_mot_config

get_best_mot_config(
    M: int, N: int, K: int, dtype_str: str | None = None
) -> tuple[int, dict[str, int]]

get_device_name

get_device_name() -> str

Return the sanitized GPU device name, matching the output of mot_linear_benchmarks.py.

get_mot_configs cached

get_mot_configs(
    K: int, N: int, dtype_str: str | None = None
) -> dict[int, dict[str, int]] | None

Return {M: tile_config} for a given (K, N) shape, or None.

The return value maps an irregular grid of batch sizes (M) to Triton tile configurations. The caller should pick the entry whose M is closest to the actual batch size.

The config file is selected by device name and dtype.

get_mot_default_config

get_mot_default_config(
    M: int,
    N: int,
    K: int,
    dtype: str | None = None,
    block_quant_shape: list[int] | None = None,
) -> dict[str, int]

Return a conservative fallback config that is guaranteed to compile on all hardware.

Trades peak performance for universal compatibility (T4 / V100 / A100 / H100, CUDA & ROCm).

invoke_mot_gemm

invoke_mot_gemm(
    A: Tensor,
    B_text: Tensor,
    B_vae: Tensor,
    C: Tensor,
    bias_text: Tensor | None,
    bias_vae: Tensor | None,
    text_indices: Tensor,
    vae_indices: Tensor,
    A_scale: Tensor | None,
    B_text_scale: Tensor | None,
    B_vae_scale: Tensor | None,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    A_per_channel_quant: bool,
    B_per_channel_quant: bool,
    config: dict[str, Any] | None = None,
)

is_weak_contiguous

is_weak_contiguous(x: Tensor)

mot_unified_gemm_kernel

mot_unified_gemm_kernel(
    a_ptr,
    b_text_ptr,
    b_vae_ptr,
    c_ptr,
    bias_text_ptr,
    bias_vae_ptr,
    text_indices_ptr,
    vae_indices_ptr,
    M_text,
    M_vae,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk_text,
    stride_bn_text,
    stride_bk_vae,
    stride_bn_vae,
    stride_cm,
    stride_cn,
    scale_a_ptr,
    scale_b_text_ptr,
    scale_b_vae_ptr,
    stride_scale_a,
    stride_scale_b,
    BLOCK_SIZE_M: constexpr,
    BLOCK_SIZE_N: constexpr,
    BLOCK_SIZE_K: constexpr,
    GROUP_SIZE_M: constexpr,
    EVEN_K: constexpr,
    EVEN_N: constexpr,
    STRIDE_AK_IS_1: constexpr,
    STRIDE_BK_IS_1: constexpr,
    STRIDE_BN_IS_1: constexpr,
    ACCUMULATOR_DTYPE: constexpr,
    COMPUTE_DTYPE: constexpr,
    OUTPUT_DTYPE: constexpr,
    QUANT_TYPE: constexpr,
    HAS_BIAS: constexpr = False,
)