Native MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA4 (gfx950) via Triton tl.dot_scaled (hardware microscaling matmul).
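To make the "1x32 block, E8M0 scale" layout concrete, here is a minimal pure-Python sketch of quantizing one block of 32 values to MXFP8 and back. It assumes the scale-selection rule from the OCP Microscaling convention (shared exponent = floor(log2(amax)) minus the largest E4M3 exponent), which may differ from the rounding this module actually uses. The helper names `round_to_e4m3`, `quantize_block`, and `dequantize_block` are hypothetical and exist only for illustration.

```python
import math

E4M3_MAX = 448.0          # largest finite E4M3 magnitude (1.75 * 2**8)
E4M3_EMAX = 8             # exponent of the largest E4M3 binade
E4M3_MIN_NORMAL_EXP = -6  # exponent of the smallest normal E4M3 value
E8M0_BIAS = 127           # E8M0 stores only a biased power-of-two exponent
BLOCK = 32                # elements sharing one scale (the "1x32" block)


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    # frexp gives |x| = m * 2**e with 0.5 <= m < 1, so floor(log2|x|) = e - 1.
    exp = max(math.frexp(abs(x))[1] - 1, E4M3_MIN_NORMAL_EXP)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits: 8 steps per binade
    return round(x / step) * step


def quantize_block(block):
    """Return (E8M0 scale byte, list of E4M3-rounded elements) for 32 values."""
    assert len(block) == BLOCK
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        shared_exp = -E8M0_BIAS  # assumption: all-zero block uses the smallest scale
    else:
        shared_exp = (math.frexp(amax)[1] - 1) - E4M3_EMAX
        shared_exp = max(-127, min(127, shared_exp))
    scale = 2.0 ** shared_exp
    return shared_exp + E8M0_BIAS, [round_to_e4m3(v / scale) for v in block]


def dequantize_block(scale_byte, elems):
    """Invert quantize_block: multiply every element by the shared 2**exp scale."""
    scale = 2.0 ** (scale_byte - E8M0_BIAS)
    return [e * scale for e in elems]
```

Because the scale is a pure power of two, applying it changes only the exponent of each element, which is what lets a microscaling matmul consume the FP8 payload and the scale side by side. Under this scale rule, a block whose largest magnitude has a mantissa above 1.75 saturates at 448 after scaling; the sketch clamps in that case rather than bumping the scale.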
The expert GEMMs consume the FP8 E4M3 weights and their E8M0 block scales directly, with no dequantization to BF16, and activations are MXFP8-quantized per token. On CDNA4, dot_scaled maps to the native MX matrix-core ops. On other architectures Triton upcasts to BF16, so the kernel stays correct but gains no speed. In practice the oracle selects this path only on gfx950 and routes everything else to the BF16 Mxfp8EmulationTritonExperts fallback.
Structure mirrors vLLM's fused_moe_kernel: tokens are sorted by expert (moe_align_block_size), and each program computes a [BLOCK_M, BLOCK_N] tile for one expert, accumulating over K with dot_scaled. The SwiGLU-OAI activation runs in PyTorch between the two GEMMs, and the top-k weighted reduction runs in PyTorch after the second.
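The expert-sorting step can be illustrated with a small pure-Python sketch. It is not vLLM's moe_align_block_size; the function name `align_by_expert`, its return layout, and the choice of padding sentinel are assumptions made for illustration. The idea it shows is that each (token, top-k slot) pair is bucketed by expert and each bucket is padded to a multiple of the tile height, so every [BLOCK_M, BLOCK_N] tile touches exactly one expert's weights.

```python
def align_by_expert(topk_ids, num_experts, block_m):
    """Group flat (token, k) slots by expert and pad each group to block_m.

    topk_ids: list of per-token lists of expert ids.
    Returns (sorted_slots, block_experts), where sorted_slots holds flat slot
    indices (token * top_k + k) with a sentinel for padding, and
    block_experts[i] is the expert owning rows [i*block_m, (i+1)*block_m).
    """
    num_slots = sum(len(row) for row in topk_ids)
    pad = num_slots  # assumption: an out-of-range index marks padded rows

    # Bucket every slot under the expert it was routed to.
    buckets = [[] for _ in range(num_experts)]
    slot = 0
    for row in topk_ids:
        for expert in row:
            buckets[expert].append(slot)
            slot += 1

    sorted_slots, block_experts = [], []
    for expert, slots in enumerate(buckets):
        n_blocks = -(-len(slots) // block_m)  # ceiling division
        sorted_slots.extend(slots + [pad] * (n_blocks * block_m - len(slots)))
        block_experts.extend([expert] * n_blocks)
    return sorted_slots, block_experts
```

For example, with three tokens routed as `[[0, 2], [2, 1], [0, 2]]`, three experts, and `block_m=2`, expert 2 receives three slots and is padded to two full blocks. A kernel program handling block `i` would read `block_experts[i]` to pick the weight matrix and skip rows whose slot equals the sentinel.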
Classes:
Mxfp8NativeTritonExperts
Bases: Mxfp8TritonExpertsBase
Native MXFP8 MoE (CDNA4 dot_scaled) on gfx950.
Source code in vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py
class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase):
    """Native MXFP8 MoE (CDNA4 ``dot_scaled``) on gfx950."""

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        return self.quant_config.quant_dtype

    @property
    def block_shape(self) -> list[int] | None:
        return self.quant_config.block_shape

    @property
    def expects_unquantized_inputs(self) -> bool:
        # Activations are MXFP8-quantized inside ``fused_moe_mxfp8_native``.
        return True

    @staticmethod
    def _supports_current_device() -> bool:
        return current_platform.is_rocm() and current_platform.supports_mx()

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        alpha = self.quant_config.gemm1_alpha
        alpha = 1.702 if alpha is None else float(alpha)
        beta = self.quant_config.gemm1_beta
        beta = 1.0 if beta is None else float(beta)
        limit = self.quant_config.gemm1_clamp_limit
        limit = None if limit is None else float(limit)
        out = fused_moe_mxfp8_native(
            hidden_states,
            w1,
            self.w1_scale_val,
            w2,
            self.w2_scale_val,
            topk_weights,
            topk_ids,
            alpha=alpha,
            beta=beta,
            limit=limit,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )
        output.copy_(out)
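The `alpha`, `beta`, and `limit` values resolved in `apply` parameterize the SwiGLU-OAI activation. The listing does not show how `fused_moe_mxfp8_native` uses them, so the sketch below assumes the commonly used formulation: the gate is clamped from above, the linear branch is clamped symmetrically, and the result is `gate * sigmoid(alpha * gate) * (up + beta)`. The function names are hypothetical, scalars stand in for tensors, and how gate and linear values are laid out in the first GEMM's output is deliberately left out. A plain top-k weighted reduction is included for completeness.

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def swiglu_oai(gate, up, alpha=1.702, beta=1.0, limit=None):
    """Assumed SwiGLU-OAI form: gate * sigmoid(alpha * gate) * (up + beta)."""
    if limit is not None:
        gate = min(gate, limit)            # gate is clamped from above only
        up = max(-limit, min(up, limit))   # linear branch is clamped both ways
    return gate * _sigmoid(alpha * gate) * (up + beta)


def weighted_topk_reduce(expert_outputs, topk_weights):
    """Combine per-expert outputs: out[t] = sum_k weights[t][k] * outputs[t][k].

    expert_outputs[t][k] is the hidden vector produced for token t by its
    k-th selected expert; topk_weights[t][k] is the matching router weight.
    """
    return [
        [sum(w * vec[d] for w, vec in zip(weights, outs))
         for d in range(len(outs[0]))]
        for outs, weights in zip(expert_outputs, topk_weights)
    ]
```

The defaults in the sketch match the fallbacks in `apply` (`alpha=1.702`, `beta=1.0`, no clamp when `limit` is `None`), so calling `swiglu_oai(g, u)` corresponds to a quant config that leaves all three fields unset.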