vllm_omni.diffusion.layers.mot.ops.mot_gemm
get_best_mot_config
get_best_mot_config(
M: int, N: int, K: int, dtype_str: str | None = None
) -> tuple[int, dict[str, int]]
get_device_name
get_device_name() -> str
Return the sanitized GPU device name, in the same format that mot_linear_benchmarks.py produces.
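To illustrate what "sanitized" can mean, here is a minimal sketch. It assumes the convention used by vLLM's fused-MoE config files, where spaces in the platform-reported device name become underscores. The helper sanitize_device_name is hypothetical, and the exact rules applied by get_device_name() and mot_linear_benchmarks.py may differ.

```python
def sanitize_device_name(raw_name: str) -> str:
    """Hypothetical sanitizer that turns a raw GPU name into a filename-safe token.

    Assumption: surrounding whitespace is stripped and inner spaces become
    underscores. The real get_device_name() may apply different rules.
    """
    return raw_name.strip().replace(" ", "_")


print(sanitize_device_name("NVIDIA H100 80GB HBM3"))  # NVIDIA_H100_80GB_HBM3
```

A stable, space-free token like this can be embedded directly in a config file name, so the benchmark script that writes configs and the runtime that reads them agree on the key.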
get_mot_configs (cached)
Return {M: tile_config} for a given (K, N) shape, or None.
The return value maps an irregular grid of batch sizes (M) to Triton tile configurations. The caller should pick the entry whose M is closest to the actual batch size.
The config file is selected by the device name and the dtype.
get_mot_default_config
get_mot_default_config(
M: int,
N: int,
K: int,
dtype: str | None = None,
block_quant_shape: list[int] | None = None,
) -> dict[str, int]
Return a conservative fallback config that is guaranteed to compile on all supported hardware.
It trades peak performance for universal compatibility (T4, V100, A100, H100; CUDA and ROCm).
invoke_mot_gemm
invoke_mot_gemm(
A: Tensor,
B_text: Tensor,
B_vae: Tensor,
C: Tensor,
bias_text: Tensor | None,
bias_vae: Tensor | None,
text_indices: Tensor,
vae_indices: Tensor,
A_scale: Tensor | None,
B_text_scale: Tensor | None,
B_vae_scale: Tensor | None,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
A_per_channel_quant: bool,
B_per_channel_quant: bool,
config: dict[str, Any] | None = None,
)
mot_unified_gemm_kernel
mot_unified_gemm_kernel(
a_ptr,
b_text_ptr,
b_vae_ptr,
c_ptr,
bias_text_ptr,
bias_vae_ptr,
text_indices_ptr,
vae_indices_ptr,
M_text,
M_vae,
N,
K,
stride_am,
stride_ak,
stride_bk_text,
stride_bn_text,
stride_bk_vae,
stride_bn_vae,
stride_cm,
stride_cn,
scale_a_ptr,
scale_b_text_ptr,
scale_b_vae_ptr,
stride_scale_a,
stride_scale_b,
BLOCK_SIZE_M: constexpr,
BLOCK_SIZE_N: constexpr,
BLOCK_SIZE_K: constexpr,
GROUP_SIZE_M: constexpr,
EVEN_K: constexpr,
EVEN_N: constexpr,
STRIDE_AK_IS_1: constexpr,
STRIDE_BK_IS_1: constexpr,
STRIDE_BN_IS_1: constexpr,
ACCUMULATOR_DTYPE: constexpr,
COMPUTE_DTYPE: constexpr,
OUTPUT_DTYPE: constexpr,
QUANT_TYPE: constexpr,
HAS_BIAS: constexpr = False,
)
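Two of the kernel's parameters can be illustrated on the host side, though neither reading is confirmed by this reference. First, GROUP_SIZE_M suggests the grouped tile ordering from the Triton matrix-multiplication tutorial. In that scheme, consecutive program IDs walk down GROUP_SIZE_M rows of tiles before moving to the next column, which improves L2 reuse. Second, the EVEN_K, EVEN_N, and STRIDE_*_IS_1 constexprs look like specializations computed from the shapes and strides before launch. The helpers grouped_tile_coords and specialization_flags are hypothetical sketches.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used to count tiles along a dimension.
    return (a + b - 1) // b


def grouped_tile_coords(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Map a flat program ID to (pid_m, pid_n) using Triton-tutorial grouped ordering."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M rows of tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def specialization_flags(K, N, stride_ak, stride_bk, stride_bn, BLOCK_SIZE_K, BLOCK_SIZE_N):
    """Hypothetical host-side derivation of the kernel's boolean constexprs."""
    return {
        "EVEN_K": K % BLOCK_SIZE_K == 0,  # No masking needed along K.
        "EVEN_N": N % BLOCK_SIZE_N == 0,  # No masking needed along N.
        "STRIDE_AK_IS_1": stride_ak == 1,  # A is contiguous along K.
        "STRIDE_BK_IS_1": stride_bk == 1,
        "STRIDE_BN_IS_1": stride_bn == 1,
    }


print([grouped_tile_coords(p, 256, 192, 64, 64, 2) for p in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
```

Passing these flags as constexprs lets Triton compile away bounds masks and stride multiplications for the common aligned, contiguous case.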