
vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe

Native MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA4 (gfx950) via Triton tl.dot_scaled (hardware microscaling matmul).
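To make the "1x32 block, E8M0 scale" format concrete, here is a minimal pure-Python sketch of quantizing one block. It assumes the OCP Microscaling (MX) v1.0 scale rule: the shared exponent is floor(log2(amax)) minus 8, the largest E4M3 exponent, and elements saturate at ±448. The helper names (`round_to_e4m3`, `mxfp8_quantize_block`, `mxfp8_dequantize_block`) are hypothetical. vLLM's actual quantizer may make different rounding or saturation choices.

```python
import math

E4M3_MAX = 448.0   # largest finite float8 E4M3 ("fn" variant) magnitude
E4M3_EMAX = 8      # exponent of the largest E4M3 binade (448 = 1.75 * 2**8)
MX_BLOCK = 32      # MXFP8 shares one scale per 32 consecutive elements along K


def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(x), E4M3_MAX)
    exp = max(math.frexp(a)[1] - 1, -6)  # floor(log2(a)); -6 is the min normal exponent
    quantum = 2.0 ** (exp - 3)           # 3 mantissa bits -> 8 steps per binade
    q = min(round(a / quantum) * quantum, E4M3_MAX)  # Python round() ties to even
    return math.copysign(q, x)


def mxfp8_quantize_block(block: list[float]) -> tuple[int, list[float]]:
    """Return (biased E8M0 scale byte, E4M3-representable elements) for one block."""
    assert len(block) == MX_BLOCK
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        shared_exp = -127  # smallest E8M0 scale
    else:
        # Assumed OCP MX rule: floor(log2(amax)) minus the element format's emax.
        shared_exp = math.frexp(amax)[1] - 1 - E4M3_EMAX
        shared_exp = max(-127, min(127, shared_exp))
    scale = 2.0 ** shared_exp
    # E8M0 stores only a biased exponent (bias 127): the scale is a power of two.
    return shared_exp + 127, [round_to_e4m3(v / scale) for v in block]


def mxfp8_dequantize_block(e8m0: int, elems: list[float]) -> list[float]:
    """Expand a block back to real values: element * 2**(e8m0 - 127)."""
    scale = 2.0 ** (e8m0 - 127)
    return [scale * v for v in elems]
```

Because the scale is a pure power of two, applying it only shifts exponents. The precision loss comes entirely from rounding the elements onto the E4M3 grid.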

The expert GEMMs consume the FP8 E4M3 weights and their E8M0 block scales directly, with no dequantization to BF16, and activations are MXFP8-quantized per token. On CDNA4, dot_scaled maps to the native MX matrix-core ops. On other architectures Triton upcasts the operands to BF16, so the kernel stays correct but is not faster. In practice the oracle selects this path only on gfx950 and routes every other device to the BF16 Mxfp8EmulationTritonExperts fallback.
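Weights and scales can be consumed directly because each K-block's two scales factor out of that block's partial dot product. The following pure-Python reference for one output element illustrates the idea. It is not Triton's implementation, and the names `mx_scaled_dot` and `dequant_then_dot` are hypothetical. The inputs are assumed to be already-quantized element values plus biased E8M0 scale bytes, one per 32-element block.

```python
def mx_scaled_dot(a_elems: list[float], a_scales: list[int],
                  b_elems: list[float], b_scales: list[int],
                  block: int = 32) -> float:
    """One output element of an MX-scaled GEMM: scales applied per K-block."""
    assert len(a_elems) == len(b_elems) == block * len(a_scales) == block * len(b_scales)
    acc = 0.0
    for kb in range(len(a_scales)):
        lo, hi = kb * block, (kb + 1) * block
        # Dot product of the raw low-precision elements within one block.
        partial = sum(x * y for x, y in zip(a_elems[lo:hi], b_elems[lo:hi]))
        # E8M0 bytes are biased exponents; multiplying two scales adds exponents.
        acc += partial * 2.0 ** ((a_scales[kb] - 127) + (b_scales[kb] - 127))
    return acc


def dequant_then_dot(a_elems: list[float], a_scales: list[int],
                     b_elems: list[float], b_scales: list[int],
                     block: int = 32) -> float:
    """Same value, computed by first expanding every element to a real number."""
    a = [x * 2.0 ** (a_scales[i // block] - 127) for i, x in enumerate(a_elems)]
    b = [y * 2.0 ** (b_scales[i // block] - 127) for i, y in enumerate(b_elems)]
    return sum(x * y for x, y in zip(a, b))
```

In exact arithmetic the two functions agree. The dequantize-first form corresponds conceptually to an emulation path. The per-block form lets hardware keep operands in FP8 and apply the scales inside the accumulation loop.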

Structure mirrors vLLM's fused_moe_kernel: tokens are sorted by expert (moe_align_block_size), and each program computes a [BLOCK_M, BLOCK_N] tile for one expert, accumulating over K with dot_scaled. The SwiGLU-OAI activation runs in PyTorch between the two GEMMs, and the top-k weighted reduction runs after the second.
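The sort-by-expert step can be sketched in pure Python. `align_by_expert` is a hypothetical stand-in that mimics the idea behind moe_align_block_size. It groups the flattened (token, top-k slot) indices by expert and pads each group to a multiple of BLOCK_M with an out-of-range sentinel. It then records one expert id per BLOCK_M tile, which tells each program whose weights to load. The exact output layout of vLLM's op is not reproduced here, including its padding, its handling of idle experts, and expert_map.

```python
def align_by_expert(topk_ids: list[list[int]], num_experts: int,
                    block_m: int) -> tuple[list[int], list[int]]:
    """Hypothetical sketch: expert-sorted, BLOCK_M-padded indices plus per-tile expert ids."""
    flat = [e for row in topk_ids for e in row]  # flattened (token, slot) pairs
    pad = len(flat)                              # out-of-range sentinel for padding slots
    sorted_ids: list[int] = []
    expert_ids: list[int] = []
    for e in range(num_experts):
        idx = [i for i, x in enumerate(flat) if x == e]
        if not idx:
            continue                             # an idle expert gets no tiles
        idx += [pad] * (-len(idx) % block_m)     # pad the group to a multiple of block_m
        sorted_ids += idx
        expert_ids += [e] * (len(idx) // block_m)  # one expert id per BLOCK_M tile
    return sorted_ids, expert_ids
```

Flat index `i` belongs to token `i // top_k`. A program handling tile `t` uses expert `expert_ids[t]` and skips rows whose index equals the sentinel.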

Classes:

Mxfp8NativeTritonExperts

Bases: Mxfp8TritonExpertsBase

Native MXFP8 MoE (CDNA4 dot_scaled) on gfx950.

Source code in vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py
class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase):
    """Native MXFP8 MoE (CDNA4 ``dot_scaled``) on gfx950."""

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        return self.quant_config.quant_dtype

    @property
    def block_shape(self) -> list[int] | None:
        return self.quant_config.block_shape

    @property
    def expects_unquantized_inputs(self) -> bool:
        # Activations are MXFP8-quantized inside ``fused_moe_mxfp8_native``.
        return True

    @staticmethod
    def _supports_current_device() -> bool:
        return current_platform.is_rocm() and current_platform.supports_mx()

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
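    ) -> None:
        # activation, a1q_scale, a2_scale, workspace13, workspace2,
        # expert_tokens_meta and apply_router_weight_on_input are accepted
        # for interface compatibility but not passed on:
        # fused_moe_mxfp8_native quantizes the activations itself.
        # SwiGLU-OAI parameters come from the quant config; fall back to
        # alpha=1.702, beta=1.0 and no clamp when they are unset.
        alpha = self.quant_config.gemm1_alpha
        alpha = 1.702 if alpha is None else float(alpha)
        beta = self.quant_config.gemm1_beta
        beta = 1.0 if beta is None else float(beta)
        limit = self.quant_config.gemm1_clamp_limit
        limit = None if limit is None else float(limit)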
        out = fused_moe_mxfp8_native(
            hidden_states,
            w1,
            self.w1_scale_val,
            w2,
            self.w2_scale_val,
            topk_weights,
            topk_ids,
            alpha=alpha,
            beta=beta,
            limit=limit,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )
        output.copy_(out)