Skip to content

vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize

DEEPEP_QUANT_BLOCK_SIZE module-attribute

# Block width, in elements along the last (hidden) dimension, used by DeepEP
# low-latency fp8 dispatch: scales are applied per block of this many values.
DEEPEP_QUANT_BLOCK_SIZE: int = 128

DeepEPLLPrepareAndFinalize

Bases: FusedMoEPrepareAndFinalize

Prepare/Finalize using DeepEP low-latency kernels.

Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize using DeepEP low-latency kernels.

    ``prepare`` dispatches tokens to experts with
    ``buffer.low_latency_dispatch`` (optionally in fp8) and then applies the
    quantization scheme configured on this object. ``finalize`` combines the
    expert outputs into the caller's ``output`` tensor with
    ``buffer.low_latency_combine``.
    """

    # DeepEP low-latency kernels are compiled only for certain
    # specific hidden sizes.
    SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]

    def __init__(self,
                 buffer: deep_ep.Buffer,
                 world_size: int,
                 dp_size: int,
                 max_tokens_per_rank: int,
                 quant_dtype: Optional[torch.dtype] = None,
                 block_shape: Optional[list[int]] = None,
                 use_fp8_dispatch: bool = False):
        """
        Args:
            buffer: DeepEP buffer providing ``low_latency_dispatch`` and
                ``low_latency_combine``.
            world_size: Stored on the instance; not read by this class.
            dp_size: Stored on the instance; not read by this class.
            max_tokens_per_rank: Per-rank token cap forwarded to
                ``low_latency_dispatch`` and reported by
                ``max_num_tokens_per_rank``.
            quant_dtype: Target dtype forwarded to
                ``moe_kernel_quantize_input`` after dispatch.
            block_shape: Optional block-quantization shape. ``block_shape[1]``
                (``block_k``) is compared with ``DEEPEP_QUANT_BLOCK_SIZE``.
            use_fp8_dispatch: Ask DeepEP to dispatch activations in fp8.
        """
        super().__init__()

        self.buffer = buffer
        self.world_size = world_size
        self.dp_size = dp_size
        self.quant_dtype = quant_dtype
        self.block_shape = block_shape
        self.max_tokens_per_rank = max_tokens_per_rank
        self.use_fp8_dispatch = use_fp8_dispatch
        # The dispatch function returns a handle that the combine function
        # requires. We store the handle here so it is available to the
        # combine function.
        self.handle = None

    def max_num_tokens_per_rank(self) -> Optional[int]:
        """Return the per-rank token cap used for dispatch."""
        return self.max_tokens_per_rank

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """Return the dtype used for top-k expert indices (``torch.int64``)."""
        return torch.int64

    def _do_quant(
            self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
            a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
            a1_dtype: torch.dtype
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Convert dispatched activations to the configured quantization scheme.

        Args:
            x: Dispatch output. A ``(values, scales)`` pair when
                ``use_fp8_dispatch`` is set, otherwise a single tensor. It is
                unpacked as ``(num_experts, max_tokens, hidden_dim)``.
            a1_scale: Optional activation scale; a non-scalar scale selects
                per-token/group quantization.
            a2_scale: Optional scale, inspected the same way as ``a1_scale``.
            a1_dtype: Dtype to restore after dequantizing an fp8 dispatch.

        Returns:
            ``(x, x_scales)``. For per-token/group quantization, ``x_scales``
            is reshaped to ``(num_experts, max_tokens, -1)``.
        """
        block_k = self.block_shape[1] if self.block_shape is not None else None
        if self.use_fp8_dispatch:
            x_fp8, x_scales = x
            if block_k == DEEPEP_QUANT_BLOCK_SIZE:
                # DeepEP kernels did the quantization for us, with the same
                # block size the experts expect: pass it straight through.
                return x_fp8, x_scales

            # Dequant to get back the tokens in the datatype we dispatched
            # in; they are re-quantized below with the configured scheme.
            x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

        assert isinstance(x, torch.Tensor)

        # Check if there is a block_shape / or if we can infer the
        # quantization schemes from the scales.
        no_quant_hints = (self.block_shape is None and a1_scale is None
                          and a2_scale is None)
        if no_quant_hints and self.quant_dtype is not None:
            # Quantization required despite none of the inputs suggesting
            # quantization. Fallback to per_token_dynamic quant.
            per_token_quant = True
        else:
            per_token_quant = ((self.block_shape is not None) or
                               (a1_scale is not None and a1_scale.numel() != 1)
                               or (a2_scale is not None
                                   and a2_scale.numel() != 1))

        num_experts, max_tokens, hidden_dim = x.size()

        # TODO (varun): Optimization - Use a batched version of quant
        x = x.view((-1, hidden_dim))
        x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
                                                per_token_quant,
                                                self.block_shape)
        x = x.view((num_experts, -1, hidden_dim))

        if per_token_quant:
            assert x_scales is not None
            x_scales = x_scales.view(num_experts, max_tokens, -1)

        return x, x_scales

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        rank_topk_weights: torch.Tensor,
        rank_topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Dispatch tokens to experts and quantize them for the experts.

        Args:
            a1: Input activations; ``a1.size(1)`` is the hidden size.
            a1_scale: Optional activation scale; must be scalar if given.
            a2_scale: Optional scale; must be scalar if given and
                ``a1_scale`` is None.
            rank_topk_weights: Router weights, multiplied into ``a1`` when
                ``apply_router_weight_on_input`` is set.
            rank_topk_ids: Selected expert ids; ``size(1)`` is top-k.
            num_experts: Total number of experts, forwarded to dispatch.
            expert_map: Not used by this implementation.
            apply_router_weight_on_input: Apply router weights before
                dispatch (only supported for top-k == 1).

        Returns:
            ``(expert_x, expert_x_scale, expert_num_tokens, None, None)``.

        Raises:
            AssertionError: on an unsupported hidden size, per-token input
                scales, or router-weight-on-input with top-k != 1.
        """
        hidden_size = a1.size(1)
        assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
            f"Hidden Size {hidden_size} not in supported list of hidden sizes "
            f"{self.SUPPORTED_HIDDEN_SIZES}")

        if self.use_fp8_dispatch:
            assert hidden_size % DEEPEP_QUANT_BLOCK_SIZE == 0, (
                "DeepEP kernels quantize the inputs in blocks of shape 128")

        # Only a1_scale is inspected when present; a2_scale is the fallback.
        if a1_scale is not None:
            has_per_token_scales = a1_scale.numel() != 1
        elif a2_scale is not None:
            has_per_token_scales = a2_scale.numel() != 1
        else:
            has_per_token_scales = False
        assert not has_per_token_scales, (
            "low_latency kernels doesn't support dispatching per-token scales")

        if apply_router_weight_on_input:
            topk = rank_topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1")
            a1 = a1 * rank_topk_weights.to(a1.dtype)

        # Dispatch. The handle is kept for finalize(); the returned event and
        # hook are unused here (async_finish / return_recv_hook are False).
        expert_x, expert_num_tokens, self.handle, _, _ = \
            self.buffer.low_latency_dispatch(a1,
                                             rank_topk_ids,
                                             self.max_tokens_per_rank,
                                             num_experts,
                                             use_fp8=self.use_fp8_dispatch,
                                             async_finish=False,
                                             return_recv_hook=False)

        expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
                                                  a1.dtype)

        return (expert_x, expert_x_scale, expert_num_tokens, None, None)

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool) -> None:
        """
        Combine expert outputs back into ``output`` (passed as ``out=``).

        Requires a prior ``prepare`` call, which stores the dispatch handle.
        """
        assert self.handle is not None

        combine_topk_weights = topk_weights
        if apply_router_weight_on_input:
            # weights have already been applied to the inputs in prepare().
            combine_topk_weights = torch.ones_like(topk_weights)

        # TODO (varun) : Enable zero copy mode
        # The returned event/hook are unused: async_finish and
        # return_recv_hook are both False.
        self.buffer.low_latency_combine(
            fused_expert_output,
            topk_ids,
            combine_topk_weights,
            self.handle,
            async_finish=False,
            zero_copy=False,
            return_recv_hook=False,
            out=output)

SUPPORTED_HIDDEN_SIZES class-attribute instance-attribute

SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]

block_shape instance-attribute

block_shape = block_shape

buffer instance-attribute

buffer = buffer

dp_size instance-attribute

dp_size = dp_size

handle instance-attribute

handle = None

max_tokens_per_rank instance-attribute

max_tokens_per_rank = max_tokens_per_rank

quant_dtype instance-attribute

quant_dtype = quant_dtype

use_fp8_dispatch instance-attribute

use_fp8_dispatch = use_fp8_dispatch

world_size instance-attribute

world_size = world_size

__init__

__init__(
    buffer: Buffer,
    world_size: int,
    dp_size: int,
    max_tokens_per_rank: int,
    quant_dtype: Optional[dtype] = None,
    block_shape: Optional[list[int]] = None,
    use_fp8_dispatch: bool = False,
)
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def __init__(self,
             buffer: deep_ep.Buffer,
             world_size: int,
             dp_size: int,
             max_tokens_per_rank: int,
             quant_dtype: Optional[torch.dtype] = None,
             block_shape: Optional[list[int]] = None,
             use_fp8_dispatch: bool = False):
    """
    Record the DeepEP buffer and the dispatch/quantization settings.

    Args:
        buffer: DeepEP buffer used for low-latency dispatch and combine.
        world_size: Kept on the instance.
        dp_size: Kept on the instance.
        max_tokens_per_rank: Per-rank token cap used at dispatch time.
        quant_dtype: Target dtype for post-dispatch quantization.
        block_shape: Optional block-quantization shape.
        use_fp8_dispatch: Whether DeepEP should dispatch in fp8.
    """
    super().__init__()

    # DeepEP communication settings.
    self.buffer = buffer
    self.world_size = world_size
    self.dp_size = dp_size
    self.max_tokens_per_rank = max_tokens_per_rank

    # Quantization settings.
    self.quant_dtype = quant_dtype
    self.block_shape = block_shape
    self.use_fp8_dispatch = use_fp8_dispatch

    # Filled in by prepare() from the dispatch result and consumed by
    # finalize(), whose combine call needs it.
    self.handle = None

_do_quant

_do_quant(
    x: Union[Tensor, tuple[Tensor, Tensor]],
    a1_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    a1_dtype: dtype,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def _do_quant(
        self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
        a1_dtype: torch.dtype
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Convert dispatched activations to the configured quantization scheme.

    Args:
        x: Dispatch output. A ``(values, scales)`` pair when
            ``use_fp8_dispatch`` is set, otherwise a single tensor. It is
            unpacked as ``(num_experts, max_tokens, hidden_dim)``.
        a1_scale: Optional activation scale; a non-scalar scale selects
            per-token/group quantization.
        a2_scale: Optional scale, inspected the same way as ``a1_scale``.
        a1_dtype: Dtype to restore after dequantizing an fp8 dispatch.

    Returns:
        ``(x, x_scales)``. For per-token/group quantization, ``x_scales`` is
        reshaped to ``(num_experts, max_tokens, -1)``.
    """
    block_k = self.block_shape[1] if self.block_shape is not None else None
    if self.use_fp8_dispatch:
        x_fp8, x_scales = x
        if block_k == DEEPEP_QUANT_BLOCK_SIZE:
            # DeepEP kernels did the quantization for us, with the same block
            # size the experts expect: pass it straight through.
            return x_fp8, x_scales

        # Dequant to get back the tokens in the datatype we dispatched in;
        # they are re-quantized below with the configured scheme.
        x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

    assert isinstance(x, torch.Tensor)

    # Check if there is a block_shape / or if we can infer the quantization
    # schemes from the scales.
    no_quant_hints = (self.block_shape is None and a1_scale is None
                      and a2_scale is None)
    if no_quant_hints and self.quant_dtype is not None:
        # Quantization required despite none of the inputs suggesting
        # quantization. Fallback to per_token_dynamic quant.
        per_token_quant = True
    else:
        per_token_quant = ((self.block_shape is not None) or
                           (a1_scale is not None and a1_scale.numel() != 1)
                           or (a2_scale is not None
                               and a2_scale.numel() != 1))

    num_experts, max_tokens, hidden_dim = x.size()

    # TODO (varun): Optimization - Use a batched version of quant
    x = x.view((-1, hidden_dim))
    x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
                                            per_token_quant,
                                            self.block_shape)
    x = x.view((num_experts, -1, hidden_dim))

    if per_token_quant:
        assert x_scales is not None
        x_scales = x_scales.view(num_experts, max_tokens, -1)

    return x, x_scales

finalize

finalize(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
) -> None
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
             topk_weights: torch.Tensor, topk_ids: torch.Tensor,
             apply_router_weight_on_input: bool) -> None:
    """
    Combine expert outputs back into ``output`` (passed as ``out=``).

    Args:
        output: Destination tensor handed to ``low_latency_combine``.
        fused_expert_output: Per-expert outputs to combine.
        topk_weights: Router weights used for the combine.
        topk_ids: Selected expert ids.
        apply_router_weight_on_input: If set, the weights were already
            multiplied into the inputs during prepare(), so unit weights
            are used here instead.

    Requires a prior prepare() call, which stores the dispatch handle.
    """
    assert self.handle is not None

    combine_topk_weights = topk_weights
    if apply_router_weight_on_input:
        # weights have already been applied.
        combine_topk_weights = torch.ones_like(topk_weights)

    # TODO (varun) : Enable zero copy mode
    # The returned event/hook are unused: async_finish and return_recv_hook
    # are both False.
    self.buffer.low_latency_combine(
        fused_expert_output,
        topk_ids,
        combine_topk_weights,
        self.handle,
        async_finish=False,
        zero_copy=False,
        return_recv_hook=False,
        out=output)

max_num_tokens_per_rank

max_num_tokens_per_rank() -> Optional[int]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def max_num_tokens_per_rank(self) -> Optional[int]:
    """Report the per-rank token cap this object was configured with."""
    cap = self.max_tokens_per_rank
    return cap

prepare

prepare(
    a1: Tensor,
    a1_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    rank_topk_weights: Tensor,
    rank_topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
) -> tuple[
    Tensor,
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    rank_topk_weights: torch.Tensor,
    rank_topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
           Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Dispatch tokens to experts and quantize them for the experts.

    Args:
        a1: Input activations; ``a1.size(1)`` is the hidden size.
        a1_scale: Optional activation scale; must be scalar if given.
        a2_scale: Optional scale; must be scalar if given and ``a1_scale``
            is None.
        rank_topk_weights: Router weights, multiplied into ``a1`` when
            ``apply_router_weight_on_input`` is set.
        rank_topk_ids: Selected expert ids; ``size(1)`` is top-k.
        num_experts: Total number of experts, forwarded to dispatch.
        expert_map: Not used by this implementation.
        apply_router_weight_on_input: Apply router weights before dispatch
            (only supported for top-k == 1).

    Returns:
        ``(expert_x, expert_x_scale, expert_num_tokens, None, None)``.

    Raises:
        AssertionError: on an unsupported hidden size, per-token input
            scales, or router-weight-on-input with top-k != 1.
    """
    hidden_size = a1.size(1)
    assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
        f"Hidden Size {hidden_size} not in supported list of hidden sizes "
        f"{self.SUPPORTED_HIDDEN_SIZES}")

    if self.use_fp8_dispatch:
        assert hidden_size % DEEPEP_QUANT_BLOCK_SIZE == 0, (
            "DeepEP kernels quantize the inputs in blocks of shape 128")

    # Only a1_scale is inspected when present; a2_scale is the fallback.
    if a1_scale is not None:
        has_per_token_scales = a1_scale.numel() != 1
    elif a2_scale is not None:
        has_per_token_scales = a2_scale.numel() != 1
    else:
        has_per_token_scales = False
    assert not has_per_token_scales, (
        "low_latency kernels doesn't support dispatching per-token scales")

    if apply_router_weight_on_input:
        topk = rank_topk_ids.size(1)
        # TODO: this only works for topK=1, will need to update for topK>1
        assert topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1")
        a1 = a1 * rank_topk_weights.to(a1.dtype)

    # Dispatch. The handle is kept for finalize(); the returned event and
    # hook are unused here (async_finish / return_recv_hook are False).
    expert_x, expert_num_tokens, self.handle, _, _ = \
        self.buffer.low_latency_dispatch(a1,
                                         rank_topk_ids,
                                         self.max_tokens_per_rank,
                                         num_experts,
                                         use_fp8=self.use_fp8_dispatch,
                                         async_finish=False,
                                         return_recv_hook=False)

    expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
                                              a1.dtype)

    return (expert_x, expert_x_scale, expert_num_tokens, None, None)

topk_indices_dtype

topk_indices_dtype() -> Optional[dtype]
Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    """Return the dtype used for top-k expert indices: 64-bit integers."""
    return torch.long

dequant_fp8

dequant_fp8(
    expert_x_fp8: Tensor, expert_x_scales: Tensor
) -> Tensor

Return dequantized tensor in fp32

Source code in vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
def dequant_fp8(expert_x_fp8: torch.Tensor,
                expert_x_scales: torch.Tensor) -> torch.Tensor:
    """
    Return dequantized tensor in fp32

    The (contiguous) fp8 input is split, per expert, into consecutive runs of
    DEEPEP_QUANT_BLOCK_SIZE elements. Each run is multiplied by its own scale,
    and the result is returned in the input's original shape.

    NOTE(review): assumes ``expert_x_scales`` holds exactly one scale per
    block, in the same row-major order as the blocks. Presumably its shape
    is (num_experts, tokens, hidden // DEEPEP_QUANT_BLOCK_SIZE); confirm
    against the DeepEP dispatch output.
    """
    # TODO (varun) : Optimize leverage num_tokens_per_expert counts
    assert expert_x_fp8.is_contiguous()
    scales = expert_x_scales.contiguous()
    n_experts = expert_x_fp8.size(0)
    out_shape = expert_x_fp8.shape

    # One row per quantization block; one broadcastable scale per row.
    blocks = expert_x_fp8.to(torch.float32).view(n_experts, -1,
                                                DEEPEP_QUANT_BLOCK_SIZE)
    scales = scales.view(n_experts, -1, 1)
    dequantized = blocks * scales
    return dequantized.view(out_shape)