class DeepGemmExperts(mk.FusedMoEExpertsModular):
    """Fused MoE experts that run on DeepGEMM's contiguous grouped FP8 GEMM.

    Two quantization layouts are accepted at construction time:

    * FP8 block quantization whose block shape equals
      ``get_mk_alignment_for_contiguous_layout()``.
    * MXFP8 with a ``[1, 32]`` block shape (FP8 e4m3 values with UE8M0 1x32
      block scales), which drives the same grouped GEMM with recipe (1, 32).
    """

    def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
        """Validate the quantization config and cache the activation params.

        Args:
            moe_config: Forwarded unchanged to the base class.
            quant_config: Forwarded to the base class; also supplies the block
                shape, quant dtype and the GEMM1 clamp/alpha/beta values.

        Raises:
            AssertionError: If the quant config is neither of the two
                supported layouts, or requests per-activation-token or
                per-output-channel quantization.
        """
        super().__init__(moe_config=moe_config, quant_config=quant_config)

        # A [1, 32] block shape selects the MXFP8 (Blackwell) flavour.
        self.mxfp8 = quant_config.block_shape == [1, 32]
        if self.mxfp8:
            assert quant_config.quant_dtype == "mxfp8"
        else:
            assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
            assert quant_config.quant_dtype == torch.float8_e4m3fn
        assert not quant_config.per_act_token_quant
        assert not quant_config.per_out_ch_quant

        self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit

        # silu is swigluoai with alpha=1, beta=0, so configs that leave these
        # unset (the FP8 silu case) fall back to values reproducing plain silu.
        alpha = quant_config.gemm1_alpha
        self.gemm1_alpha = 1.0 if alpha is None else alpha
        beta = quant_config.gemm1_beta
        self.gemm1_beta = 0.0 if beta is None else beta

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format used by this implementation."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        """Return whether DeepGEMM reports itself as supported."""
        return is_deep_gemm_supported()

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Return False: the no-act-and-mul mode is not supported."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Return whether the (weight, activation) quant key pair is handled."""
        scheme = (weight_key, activation_key)
        if scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
            return True
        if scheme == (kMxfp8Static, kMxfp8Dynamic):
            # MXFP8 1x32 goes through the fp8_fp4 grouped GEMM with recipe
            # (1, 32), which is only available on Blackwell (SM100).
            return current_platform.is_device_capability_family(100)
        return False

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Return whether ``activation`` can be executed by this class.

        silu and swigluoai are served by the fused alpha/beta kernel, while
        swiglustep takes the unfused activation path. The fused kernel reads
        packed w13 (gate = first half, up = second half), hence only the
        *uninterleaved* SwiGLU-OAI variant is listed.
        """
        supported = (
            MoEActivation.SILU,
            MoEActivation.SWIGLUSTEP,
            MoEActivation.SWIGLUOAI_UNINTERLEAVE,
        )
        return activation in supported

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Return False when either FI NVL one-/two-sided kernel flag is set."""
        # NOTE(rob): discovered an IMA with this combination. Needs investigation.
        return (
            not moe_parallel_config.use_fi_nvl_two_sided_kernels
            and not moe_parallel_config.use_fi_nvl_one_sided_kernels
        )

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op top-k weight-and-reduce implementation."""
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Compute the shapes of the two workspaces and of the output.

        Returns:
            ``(workspace1, workspace2, output)`` where both workspaces have the
            aligned row count as their first dimension and ``output`` is
            ``(M, K)``.
        """
        assert self.block_shape is not None

        # Take the M alignment from the contiguous layout (as apply() does);
        # block_shape[0] is the quant block (1 for MXFP8) and would
        # under-size the workspace.
        m_alignment = get_mk_alignment_for_contiguous_layout()[0]
        padded_m, alignment = compute_aligned_M_and_alignment(
            M, topk, local_num_experts, m_alignment, expert_tokens_meta
        )
        assert padded_m % alignment == 0

        act_cols = self.adjust_N_for_activation(N, activation)
        return (
            (padded_m, max(act_cols, K)),
            (padded_m, max(N, K)),
            (M, K),
        )

    def _act_mul_quant(
        self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the activation to ``input`` and quantize the result to FP8.

        Four kernels are possible, chosen by whether the activation can use
        the fused gated kernel and by the scale format reported by
        ``DeepGemmQuantScaleFMT.from_oracle()``.

        Args:
            input: 2-D tensor; its size is unpacked as (rows, N).
            output: Buffer handed to the selected kernel as its quantized
                output destination.
            activation: Activation to apply.

        Returns:
            The two-tensor result of the selected quantization kernel.
        """
        assert self.block_shape is not None
        group = self.block_shape[1]
        scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
        num_rows, gate_up_dim = input.size()
        act_cols = self.adjust_N_for_activation(gate_up_dim, activation)

        # silu and swigluoai are both expressible by the fused gated kernel via
        # (alpha, beta): silu is alpha=1, beta=0; swigluoai uses config values.
        # That kernel reads packed w13, hence SWIGLUOAI_UNINTERLEAVE.
        fused_gated = activation in (
            MoEActivation.SILU,
            MoEActivation.SWIGLUOAI_UNINTERLEAVE,
        )
        packed_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.UE8M0

        if fused_gated:
            if packed_ue8m0:
                # DeepGemm UE8M0: fused gate+mul+clamp+quant+pack.
                return fused_silu_mul_fp8_quant_packed(
                    input=input,
                    output_q=output,
                    group_size=group,
                    clamp_limit=self.gemm1_clamp_limit,
                    alpha=self.gemm1_alpha,
                    beta=self.gemm1_beta,
                )
            # Hopper / non-E8M0: fused gate+mul+quant kernel.
            return silu_mul_per_token_group_quant_fp8_colmajor(
                input=input,
                output=output,
                use_ue8m0=scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
                clamp_limit=self.gemm1_clamp_limit,
                group_size=group,
                alpha=self.gemm1_alpha,
                beta=self.gemm1_beta,
            )

        # Unfused path: run the activation into a scratch tensor, then quantize.
        act_out = torch.empty(
            (num_rows, act_cols), dtype=input.dtype, device=input.device
        )
        self.activation(activation, act_out, input)

        if packed_ue8m0:
            a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
                act_out,
                group,
                out_q=output,
            )
            return a2q, a2q_scale

        return per_token_group_quant_fp8(
            act_out, group, column_major_scales=True, out_q=output
        )

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the expert computation and hand ``output`` to the final reduce.

        Steps, in order: permute the quantized input, grouped GEMM with
        ``w1``, activation + re-quantization, grouped GEMM with ``w2``, then
        unpermute-and-reduce with ``output`` as the destination argument.
        ``workspace13`` (viewed as FP8) backs the permuted input and later the
        quantized activation; ``workspace2`` backs both GEMM outputs.

        Raises:
            AssertionError: If ``a1q_scale`` is None, ``a2_scale`` is not
                None, the block shape or weight scales are missing, or the
                tensor sizes are inconsistent.
        """
        assert a1q_scale is not None
        assert a2_scale is None
        assert self.block_shape is not None
        assert self.w1_scale is not None
        assert self.w2_scale is not None

        local_num_experts, N, K = w1.size()
        if global_num_experts == -1:
            global_num_experts = local_num_experts
        assert w2.size(1) == K

        M_sum, _ = compute_aligned_M_and_alignment(
            M=topk_ids.size(0),
            num_topk=topk_ids.size(1),
            local_num_experts=local_num_experts,
            alignment=get_mk_alignment_for_contiguous_layout()[0],
            expert_tokens_meta=expert_tokens_meta,
        )

        # MXFP8 uses a 32-element activation-scale group (block_shape[1]) and
        # recipe (1, 32) for the fp8_fp4-aliased grouped GEMM; the FP8 block
        # path keeps the defaults (128) for both.
        permute_block_size = None
        gemm_kwargs: dict[str, tuple[int, int]] = {}
        if self.mxfp8:
            permute_block_size = self.block_shape[1]
            recipe = (1, self.block_shape[1])
            gemm_kwargs = {"recipe_a": recipe, "recipe_b": recipe}

        fp8_workspace13 = workspace13.view(dtype=torch.float8_e4m3fn)
        permuted_buf = _resize_cache(fp8_workspace13, (M_sum, K))
        a1q, a1q_scale, expert_ids, inv_perm, align_used = deepgemm_moe_permute(
            aq=hidden_states,
            aq_scale=a1q_scale,
            topk_ids=topk_ids,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            expert_tokens_meta=expert_tokens_meta,
            aq_out=permuted_buf,
            block_size=permute_block_size,
        )
        assert a1q.size(0) == M_sum

        # Cap DG's BLOCK_M heuristic at the workspace's per-expert alignment;
        # otherwise the scheduler can pick the wrong expert id from m_indices
        # under cudagraph replay.
        with mk_alignment_scope(align_used):
            # GEMM 1: permuted input x w1.
            mm1_out = _resize_cache(workspace2, (M_sum, N))
            m_grouped_fp8_gemm_nt_contiguous(
                (a1q, a1q_scale),
                (w1, self.w1_scale),
                mm1_out,
                expert_ids,
                **gemm_kwargs,
            )

            # Activation + quantization, written back into workspace13.
            act_cols = self.adjust_N_for_activation(N, activation)
            quant_out = _resize_cache(
                workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, act_cols)
            )
            a2q, a2q_scale = self._act_mul_quant(
                input=mm1_out.view(-1, N), output=quant_out, activation=activation
            )

            # GEMM 2: quantized activation x w2, reusing workspace2.
            mm2_out = _resize_cache(workspace2, (M_sum, K))
            m_grouped_fp8_gemm_nt_contiguous(
                (a2q, a2q_scale),
                (w2, self.w2_scale),
                mm2_out,
                expert_ids,
                **gemm_kwargs,
            )

        # When the flag is set, unit weights are used for the final reduce.
        if apply_router_weight_on_input:
            topk_weights = torch.ones_like(topk_weights)

        deepgemm_unpermute_and_reduce(
            a=mm2_out,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            inv_perm=inv_perm,
            expert_map=expert_map,
            output=output,
        )