Skip to content

vllm.models.minimax_m3.amd.ops

AMD/ROCm fused Triton ops for MiniMax-M3.

These replace per-element PyTorch fallbacks (FlashInfer / fused HIP kernels are unavailable on ROCm) with single-pass Triton kernels to cut launch overhead and intermediate-tensor traffic during decode.

Modules:

  • gemma_rmsnorm

    Fused Gemma-style RMSNorm for AMD ROCm via Triton.

  • swiglu_oai

    Fused SwiGLU-OAI activation (split layout) for AMD ROCm via Triton.

Functions:

swiglu_oai_quantize_mxfp8(gate_up, alpha, beta, limit, block_m=64)

SwiGLU-OAI on split-layout [M, 2I] fused with MXFP8 activation-quant.

Returns (act_q [M, I] float8_e4m3fn, act_scale [M, I//32] uint8 E8M0), identical to mxfp8_e4m3_quantize(swiglu_oai_split(gate_up)) but in a single Triton pass (no bf16 intermediate). Used between the two GEMMs of the native MXFP8 MoE. Numerically equivalent to the unfused chain (bit-exact on measured MoE shapes); marginally more accurate (fp32 act, no bf16 round-trip).

Source code in vllm/models/minimax_m3/amd/ops/swiglu_oai.py
def swiglu_oai_quantize_mxfp8(
    gate_up: torch.Tensor,
    alpha: float,
    beta: float,
    limit: float | None,
    block_m: int = 64,
) -> tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU-OAI on split-layout ``[M, 2I]`` fused with MXFP8 activation-quant.

    Returns ``(act_q [M, I] float8_e4m3fn, act_scale [M, I//32] uint8 E8M0)``,
    identical to ``mxfp8_e4m3_quantize(swiglu_oai_split(gate_up))`` but in a
    single Triton pass (no bf16 intermediate). Used between the two GEMMs of the
    native MXFP8 MoE. Numerically equivalent to the unfused chain (bit-exact on
    measured MoE shapes); marginally more accurate (fp32 act, no bf16 round-trip).
    """
    from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
        MXFP8_BLOCK_SIZE,
        MXFP8_SCALE_DTYPE,
        MXFP8_VALUE_DTYPE,
    )

    two_i = gate_up.shape[-1]
    n_inter = two_i // 2
    assert n_inter % MXFP8_BLOCK_SIZE == 0, (
        f"fused swiglu+quant needs I % {MXFP8_BLOCK_SIZE} == 0, got I={n_inter}"
    )
    g1 = gate_up.reshape(-1, two_i).contiguous()
    M = g1.shape[0]
    aq = torch.empty((M, n_inter), dtype=MXFP8_VALUE_DTYPE, device=g1.device)
    asc = torch.empty(
        (M, n_inter // MXFP8_BLOCK_SIZE), dtype=MXFP8_SCALE_DTYPE, device=g1.device
    )
    grid = (triton.cdiv(M, block_m), n_inter // MXFP8_BLOCK_SIZE)
    _swiglu_oai_quant_kernel[grid](
        g1,
        aq,
        asc,
        M,
        n_inter,
        g1.stride(0),
        g1.stride(1),
        aq.stride(0),
        aq.stride(1),
        asc.stride(0),
        asc.stride(1),
        float(alpha),
        float(beta),
        0.0 if limit is None else float(limit),
        HAS_LIMIT=limit is not None,
        BLOCK_M=block_m,
        num_warps=4,
    )
    return aq, asc

swiglu_oai_split(gate_up, alpha, beta, limit, out_dtype=None)

SwiGLU-OAI on a split-layout [*, 2I] tensor -> [*, I].

Source code in vllm/models/minimax_m3/amd/ops/swiglu_oai.py
def swiglu_oai_split(
    gate_up: torch.Tensor,
    alpha: float,
    beta: float,
    limit: float | None,
    out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """SwiGLU-OAI on a split-layout ``[*, 2I]`` tensor -> ``[*, I]``."""
    orig_shape = gate_up.shape
    two_i = orig_shape[-1]
    n_inter = two_i // 2
    x2 = gate_up.reshape(-1, two_i)
    m = x2.shape[0]
    dt = out_dtype if out_dtype is not None else gate_up.dtype
    out = torch.empty((m, n_inter), dtype=dt, device=gate_up.device)
    # Tile tuned on gfx950. The SwiGLU intermediate is sharded across tensor
    # parallel ranks (per-rank n_inter = I / tp: dense I=12288, MoE I=3072), and
    # a 512-wide tile (4 warps, ~2 elems/lane) only helps once the per-rank slice
    # is large enough to be bandwidth-bound — at TP=1 prefill that is ~1.25-1.35x
    # faster than 256. For small sharded slices (high TP) the kernel is launch-
    # bound (~12us) and a wide tile can slightly regress, so fall back to 256.
    # Decode is launch-bound at every TP. num_warps=8 underfills this tile, so it
    # is pinned to 4.
    block_i = 512 if n_inter >= 2048 else 256
    grid = (m, triton.cdiv(n_inter, block_i))
    _swiglu_oai_kernel[grid](
        x2,
        out,
        n_inter,
        x2.stride(0),
        x2.stride(1),
        out.stride(0),
        out.stride(1),
        float(alpha),
        float(beta),
        0.0 if limit is None else float(limit),
        HAS_LIMIT=limit is not None,
        BLOCK_I=block_i,
        num_warps=4,
    )
    return out.reshape(*orig_shape[:-1], n_inter)