class DeepEPV2PrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
    """
    Prepare/Finalize using DeepEP v2 ElasticBuffer (unified API).

    Supports two modes controlled by the `use_cudagraph` constructor arg:

    **Decode mode (use_cudagraph=True):**
      - do_expand=False, do_cpu_sync=False
      - Tokens returned in original order together with recv_topk_idx.
        recv_topk_idx carries LOCAL expert IDs (-1 for non-local and
        padding rows); `_receiver` shifts the valid entries by
        `rank_expert_offset` to obtain global IDs.
      - Worst-case tensor allocation; padding rows are NOT zeroed, they
        keep expert ID -1 and are left for the expert kernels to skip
      - Fully cudagraph-capturable
      - Expert kernel sorts internally.
        NOTE(review): `_receiver` still builds expert_tokens_meta from
        handle.num_recv_tokens_per_expert_list in this mode; confirm what
        that list contains when do_cpu_sync=False.

    **Prefill mode (use_cudagraph=False):**
      - do_expand=True, do_cpu_sync=True
      - Per-expert-contiguous layout; exact memory allocation
      - Saves GPU memory (no worst-case allocation)
      - Not cudagraph-capturable (CPU polling), but prefill doesn't
        use cudagraphs anyway
      - Provides expert_tokens_meta for efficient batched expert kernels

    Both modes use async_with_compute_stream=False (synchronous from
    caller's perspective). The ElasticBuffer handles comm internally.
    """

    @staticmethod
    def maybe_roundup_layer_hidden_size(hidden_size: int, dtype: torch.dtype) -> int:
        """
        Pad `hidden_size` so one row spans a whole number of transfer atoms.

        Args:
            hidden_size: Number of elements in one hidden-state row.
            dtype: Element dtype; `dtype.itemsize` converts elements to bytes.

        Returns:
            `hidden_size` unchanged when the row's byte size is already a
            multiple of 512; otherwise the element count obtained after
            passing the byte size through `round_up(..., 512)`.
        """
        xfer_atom_size = 512  # 32 * 16 (size(int4))
        hidden_size_bytes = hidden_size * dtype.itemsize
        if hidden_size_bytes % xfer_atom_size == 0:
            return hidden_size
        return round_up(hidden_size_bytes, xfer_atom_size) // dtype.itemsize

    def __init__(
        self,
        buffer: deep_ep.ElasticBuffer,
        num_dispatchers: int,
        dp_size: int,
        rank_expert_offset: int,
        num_experts: int,
        num_topk: int,
        use_fp8_dispatch: bool = False,
        use_cudagraph: bool = False,
    ):
        """
        Args:
            buffer: ElasticBuffer whose `dispatch` / `combine` are used.
            num_dispatchers: Value reported by `num_dispatchers()`.
            dp_size: Stored on the instance; not read by this class.
            rank_expert_offset: Added to local expert IDs in `_receiver`
                to form global expert IDs.
            num_experts: Stored on the instance; not read by this class
                (dispatch uses the `num_experts` passed to `prepare*`).
            num_topk: Stored on the instance; not read by this class.
            use_fp8_dispatch: Stored on the instance; not read by this class.
            use_cudagraph: True selects decode mode (do_expand=False,
                do_cpu_sync=False); False selects prefill mode (both True).
        """
        super().__init__()
        self.buffer = buffer
        self.num_dispatchers_ = num_dispatchers
        self.dp_size = dp_size
        self.rank_expert_offset = rank_expert_offset
        self.num_experts = num_experts
        self.num_topk = num_topk
        self.use_fp8_dispatch = use_fp8_dispatch
        self.use_cudagraph = use_cudagraph

        # DBO microbatching: one handle slot per micro-batch. A slot is
        # written by `_do_dispatch` and read back by `_finalize`.
        self.handles: list[deep_ep.EPHandle | None] = [None, None]

    def num_dispatchers(self) -> int:
        """Return the `num_dispatchers` value given to the constructor."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always True: `_finalize` runs weight-and-reduce before combine."""
        return True

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format of the prepared tensors (always `Standard`)."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """Return None; this implementation reports no per-rank token cap."""
        return None

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Return the dtype expected for top-k indices (`torch.int64`)."""
        return torch.int64

    def _do_dispatch(
        self,
        tokens: torch.Tensor,
        token_scales: torch.Tensor | None,
        rank_topk_ids: torch.Tensor,
        rank_topk_weights: torch.Tensor,
        num_experts: int,
        a1_scale: torch.Tensor | None,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool,
    ) -> Callable:
        """
        Issue the ElasticBuffer dispatch and return a receiver callable.

        The dispatch handle is stored in `self.handles` under the current
        DBO micro-batch id so that `_finalize` can reuse it for combine.

        Args:
            tokens: Activations to dispatch (possibly already quantized).
            token_scales: Scales dispatched alongside `tokens`, or None.
            rank_topk_ids: Top-k expert ids passed as `topk_idx`.
            rank_topk_weights: Top-k weights passed as `topk_weights`.
            num_experts: Expert count forwarded to `buffer.dispatch`.
            a1_scale: Scale forwarded to `_receiver` for post-dispatch
                quantization.
            quant_config: Quantization config forwarded to `_receiver`.
            defer_input_quant: Forwarded to `_receiver`.

        Returns:
            A zero-argument callable that runs `_receiver` on the
            dispatch results.
        """
        has_scales = token_scales is not None
        token_data = (tokens, token_scales) if has_scales else tokens

        # Decode:  do_expand=False + do_cpu_sync=False (cudagraph-safe)
        # Prefill: do_expand=True  + do_cpu_sync=True  (memory-efficient)
        do_expand = not self.use_cudagraph
        do_cpu_sync = not self.use_cudagraph

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            handle,
            event,
        ) = self.buffer.dispatch(
            x=token_data,
            topk_idx=rank_topk_ids,
            topk_weights=rank_topk_weights,
            num_experts=num_experts,
            do_expand=do_expand,
            do_cpu_sync=do_cpu_sync,
            async_with_compute_stream=False,
        )

        a2a_idx = dbo_current_ubatch_id()
        self.handles[a2a_idx] = handle

        # The handle attributes are read when the receiver is invoked,
        # not when this closure is created.
        return lambda: self._receiver(
            event,
            has_scales,
            recv_x,
            recv_topk_idx,
            num_experts,
            handle.num_recv_tokens_per_expert_list,
            recv_topk_weights,
            handle.psum_num_recv_tokens_per_scaleup_rank,
            a1_scale,
            quant_config,
            defer_input_quant=defer_input_quant,
        )

    def _receiver(
        self,
        event: deep_ep.EventOverlap,
        has_scales: bool,
        recv_x: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
        recv_topk_idx: torch.Tensor | None,
        num_experts: int,
        recv_expert_num_tokens: list[int],
        recv_topk_weights: torch.Tensor | None,
        psum_recv_per_rank: torch.Tensor,
        a1_scale: torch.Tensor | None,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool,
    ) -> mk.PrepareResultType:
        """
        Turn raw dispatch outputs into the prepare result tuple.

        Args:
            event: Dispatch event; waited on when it wraps a real event.
            has_scales: Unused here (tuple-ness of `recv_x` is checked
                directly).
            recv_x: Received activations, or `(activations, scales)`.
            recv_topk_idx: Received expert ids, or None when the dispatch
                returned none (the do_expand=True / prefill path).
            num_experts: Unused here.
            recv_expert_num_tokens: Per-local-expert received token counts.
            recv_topk_weights: Received top-k weights, or None.
            psum_recv_per_rank: Unused here.
            a1_scale: Scale passed to post-dispatch quantization.
            quant_config: Quantization configuration.
            defer_input_quant: When True, no quantization happens here.

        Returns:
            `(expert_x, expert_x_scale, expert_tokens_meta, recv_topk_idx,
            recv_topk_weights)`.
        """
        if event.event is not None:
            event.current_stream_wait()

        if isinstance(recv_x, tuple):
            expert_x, expert_x_scale = recv_x
        else:
            expert_x, expert_x_scale = recv_x, None

        if recv_topk_idx is None:
            # do_expand=True (prefill mode): build topk_ids from
            # per-expert token counts. Rows are per-expert-contiguous, so
            # each expert owns one consecutive run of `count` rows.
            # With zero received tokens this allocates an empty tensor and
            # the loop below does nothing, so no special case is needed.
            total_tokens = sum(recv_expert_num_tokens)
            recv_topk_idx = torch.empty(
                total_tokens,
                dtype=torch.int64,
                device=expert_x.device,
            )
            offset = 0
            for local_expert, count in enumerate(recv_expert_num_tokens):
                if count > 0:
                    recv_topk_idx[offset : offset + count].fill_(
                        local_expert + self.rank_expert_offset
                    )
                    offset += count
            recv_topk_idx = recv_topk_idx.unsqueeze(1)
        else:
            # do_expand=False (decode/cudagraph mode): recv_topk_idx has
            # LOCAL expert IDs (-1 for non-local and padding rows).
            # Convert valid local IDs to global. Rows with -1 are
            # skipped by expert kernels (TrtLLM tile-level skipping,
            # DeepGemm is_computation_valid), so no need to zero
            # hidden states, scales, or weights for padding rows.
            valid_mask = recv_topk_idx >= 0
            recv_topk_idx = torch.where(
                valid_mask,
                recv_topk_idx + self.rank_expert_offset,
                recv_topk_idx,
            )

        # Reshape recv_topk_weights to match recv_topk_idx shape [N, 1]
        if recv_topk_weights is not None and recv_topk_weights.ndim == 1:
            recv_topk_weights = recv_topk_weights.unsqueeze(1)

        expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
            recv_expert_num_tokens,
            device=expert_x.device,
        )

        # Block-quantized inputs were quantized before dispatch in
        # `prepare_async`; everything else is quantized here unless the
        # caller asked to defer input quantization.
        if not quant_config.is_block_quantized and not defer_input_quant:
            expert_x_scale = None
            if expert_x.numel() != 0:
                expert_x, expert_x_scale = moe_kernel_quantize_input(
                    expert_x,
                    a1_scale,
                    quant_dtype=quant_config.quant_dtype,
                    per_act_token_quant=False,
                    block_shape=quant_config.block_shape,
                    is_scale_swizzled=quant_config.is_scale_swizzled,
                )

        return (
            expert_x,
            expert_x_scale,
            expert_tokens_meta,
            recv_topk_idx,
            recv_topk_weights,
        )

    def supports_async(self) -> bool:
        """Always True: `prepare_async` / `finalize_async` are provided."""
        return True

    def prepare_async(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.ReceiverType:
        """
        Optionally pre-quantize `a1`, dispatch it, and return the receiver.

        Args:
            a1: Input activations.
            topk_weights: Router weights for the selected experts.
            topk_ids: Selected expert ids; `size(1)` is the top-k.
            num_experts: Expert count forwarded to the dispatch.
            expert_map: Unused here.
            apply_router_weight_on_input: When True, `a1` is multiplied by
                `topk_weights` before dispatch (top-k must be 1).
            quant_config: Quantization configuration.
            defer_input_quant: When True, no quantization is performed
                before dispatch nor in the receiver.

        Returns:
            The receiver callable produced by `_do_dispatch`.

        Raises:
            AssertionError: `apply_router_weight_on_input` with top-k != 1.
        """
        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            a1 = a1 * topk_weights.to(a1.dtype)

        if quant_config.is_block_quantized and not defer_input_quant:
            # Quantize before dispatch; the scales travel with the tokens.
            a1q, a1q_scale = moe_kernel_quantize_input(
                a1,
                quant_config.a1_scale,
                quant_dtype=quant_config.quant_dtype,
                per_act_token_quant=quant_config.per_act_token_quant,
                block_shape=quant_config.block_shape,
            )
            if a1q_scale is not None and a1q_scale.numel() == 1:
                a1q_scale = a1q_scale.view(1, 1)
            a1_post_scale = None
        else:
            # Dispatch the raw activations; `_receiver` quantizes them
            # afterwards (unless deferred) using `a1_post_scale`.
            a1q = a1
            a1q_scale = None
            a1_post_scale = (
                quant_config.a1_gscale
                if quant_config.quant_dtype == "nvfp4"
                else quant_config.a1_scale
            )

        return self._do_dispatch(
            tokens=a1q,
            token_scales=a1q_scale,
            rank_topk_ids=topk_ids,
            rank_topk_weights=topk_weights,
            num_experts=num_experts,
            a1_scale=a1_post_scale,
            quant_config=quant_config,
            defer_input_quant=defer_input_quant,
        )

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """
        Blocking variant of `prepare_async`.

        Takes the same arguments, invokes the returned receiver
        immediately, and returns its result tuple.
        """
        receiver = self.prepare_async(
            a1,
            topk_weights,
            topk_ids,
            num_experts,
            expert_map,
            apply_router_weight_on_input,
            quant_config,
            defer_input_quant,
        )
        return receiver()

    def _finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
        do_async: bool,
    ) -> Callable | None:
        """
        Weight-and-reduce the expert output, combine it, copy to `output`.

        Args:
            output: Destination tensor; overwritten with the combined
                result.
            fused_expert_output: Expert kernel output.
            topk_weights: Top-k weights given to `weight_and_reduce_impl`.
            topk_ids: Top-k ids given to `weight_and_reduce_impl`.
            apply_router_weight_on_input: Forwarded to
                `weight_and_reduce_impl.apply`.
            weight_and_reduce_impl: Reduction strategy; a
                `TopKWeightAndReduceDelegate` is replaced by
                `TopKWeightAndReduceContiguous`.
            do_async: Unused; the combine is always issued with
                async_with_compute_stream=False.

        Returns:
            Always None.

        Raises:
            AssertionError: No dispatch handle stored for the current
                micro-batch.
            ValueError: `fused_expert_output` is not bfloat16 at combine
                time.
        """
        a2a_idx = dbo_current_ubatch_id()
        handle = self.handles[a2a_idx]
        assert handle is not None, (
            "no dispatch handle stored for the current micro-batch"
        )

        if fused_expert_output.numel() != 0:
            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
            fused_expert_output = weight_and_reduce_impl.apply(
                output=None,
                fused_expert_output=fused_expert_output,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        if fused_expert_output.dtype != torch.bfloat16:
            raise ValueError(
                f"DeepEP v2 combine requires bfloat16 input, "
                f"got {fused_expert_output.dtype}"
            )

        # topk_weights were handed to weight_and_reduce_impl above, so the
        # combine itself receives none. The returned event is not needed
        # because the call is made with async_with_compute_stream=False.
        combined_x, _, _ = self.buffer.combine(
            x=fused_expert_output,
            handle=handle,
            topk_weights=None,
            async_with_compute_stream=False,
        )
        output.copy_(combined_x, non_blocking=True)
        return None

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> Callable:
        """
        Run `_finalize` eagerly and return a no-op receiver.

        All work (reduce, combine, copy into `output`) is issued before
        this method returns; the returned callable does nothing.
        """
        self._finalize(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
            False,
        )
        return lambda: None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """
        Blocking finalize: reduce, combine, and write into `output`.

        Delegates to `_finalize` with do_async=False.
        """
        self._finalize(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
            False,
        )