Skip to content

vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe

__all__ module-attribute

# Names exported by `from <module> import *`: the MoE method base class
# plus the concrete method classes listed below.
__all__ = [
    "CompressedTensorsMoEMethod",
    "CompressedTensorsW8A8Fp8MoEMethod",
    "CompressedTensorsW8A8Int8MoEMethod",
    "CompressedTensorsWNA16MarlinMoEMethod",
    "CompressedTensorsWNA16MoEMethod",
    "CompressedTensorsW4A4MoeMethod",
]

logger module-attribute

# Module-level logger, named after this module.
logger = init_logger(__name__)

CompressedTensorsMoEMethod

Bases: FusedMoEMethodBase

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    """Base class and factory for compressed-tensors fused-MoE methods.

    ``get_moe_method`` inspects the quantization scheme configured for a
    layer and returns an instance of the subclass that implements it.
    """

    def __init__(self, moe: FusedMoEConfig):
        """Forward the fused-MoE config to the base-class initializer."""
        super().__init__(moe)

    @staticmethod
    def get_moe_method(
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module
    ) -> "CompressedTensorsMoEMethod":
        """Select the MoE method matching the layer's quantization scheme.

        The scheme is looked up under the ``"Linear"`` target when that
        target exists. Otherwise the ``down_proj`` / ``gate_proj`` /
        ``up_proj`` patterns are matched against the configured targets and
        must all resolve to the same scheme.

        Args:
            quant_config: compressed-tensors config providing
                ``target_scheme_map``, ``packed_modules_mapping`` and the
                ``_is_*`` scheme predicates used below.
            layer: fused-MoE layer; its ``moe_config`` is handed to the
                selected method.

        Returns:
            An instance of the subclass implementing the matched scheme.

        Raises:
            ValueError: the non-Marlin WNA16 method would be selected for a
                group-strategy scheme with ``actorder`` group/dynamic.
            RuntimeError: no supported scheme matches.
        """
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        # Check if a using "Linear" to select schemes
        if "Linear" in quant_config.target_scheme_map:
            matched_target = "Linear"
        else:
            # May have instead defined the linear layers in the fused model

            fused_layers = [
                "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
            ]
            current_scheme = None
            for fused_layer in fused_layers:
                # Check if one of the fused layers are defined in quant_config
                matched_target = find_matched_target(
                    layer_name=fused_layer,
                    module=layer,
                    targets=quant_config.target_scheme_map.keys(),
                    fused_mapping=quant_config.packed_modules_mapping)

                # Only valid if down_proj, gate_proj, and up_proj
                # are mapped to the same quant scheme in the quant_config
                matched_scheme = quant_config.target_scheme_map.get(
                    matched_target)
                if current_scheme is None:
                    current_scheme = matched_scheme
                else:
                    assert current_scheme == matched_scheme, (
                        "down_proj, gate_proj and up_proj must share the "
                        "same quantization scheme")

        scheme = quant_config.target_scheme_map[matched_target]
        weight_quant = scheme.get("weights")
        input_quant = scheme.get("input_activations")

        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            # group_size=None means channelwise
            group_size = weight_quant.group_size or -1
            # Prefer to use the MarlinMoE kernel when it is supported.
            if not check_moe_marlin_supports_layer(layer, group_size):
                # Compare with ``==``: ``in`` on a string-valued enum member
                # is a substring test, not a strategy comparison.
                if (weight_quant.strategy == QuantizationStrategy.GROUP and
                        weight_quant.actorder in (ActivationOrdering.GROUP,
                                                  ActivationOrdering.DYNAMIC)):
                    raise ValueError(
                        "WNA16MoE is not supported with actorder=group/dynamic."
                    )
                logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                return CompressedTensorsWNA16MoEMethod(quant_config,
                                                       layer.moe_config)
            else:
                logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                return CompressedTensorsWNA16MarlinMoEMethod(
                    quant_config, layer.moe_config)
        elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
            return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
        elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
              or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
              or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
            return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
                                                     layer.moe_config)
        elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Int8MoEMethod(quant_config,
                                                      layer.moe_config)
        else:
            raise RuntimeError(
                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")

__init_

__init_(moe: FusedMoEConfig)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def __init__(self, moe: FusedMoEConfig):
    """Forward the fused-MoE config to the base-class initializer."""
    super().__init__(moe)

get_moe_method staticmethod

get_moe_method(
    quant_config: CompressedTensorsConfig, layer: Module
) -> CompressedTensorsMoEMethod
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@staticmethod
def get_moe_method(
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    layer: torch.nn.Module
) -> "CompressedTensorsMoEMethod":
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    # Check if a using "Linear" to select schemes
    if "Linear" in quant_config.target_scheme_map:
        matched_target = "Linear"
    else:
        # May have instead defined the linear layers in the fused model

        fused_layers = [
            "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
        ]
        current_scheme = None
        for fused_layer in fused_layers:
            # Check if one of the fused layers are defined in quant_config
            matched_target = find_matched_target(
                layer_name=fused_layer,
                module=layer,
                targets=quant_config.target_scheme_map.keys(),
                fused_mapping=quant_config.packed_modules_mapping)

            # Only valid if down_proj, gate_proj, and up_proj
            # are mapped to the same quant scheme in the quant_config
            if current_scheme is None:
                current_scheme = quant_config.target_scheme_map.get(
                    matched_target)
            else:
                assert current_scheme == quant_config.target_scheme_map.get(
                    matched_target)

    weight_quant = quant_config.target_scheme_map[matched_target].get(
        "weights")
    input_quant = quant_config.target_scheme_map[matched_target].get(
        "input_activations")

    if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
        # group_size=None means channelwise
        group_size = weight_quant.group_size or -1
        # Prefer to use the MarlinMoE kernel when it is supported.
        if not check_moe_marlin_supports_layer(layer, group_size):
            if (weight_quant.strategy in QuantizationStrategy.GROUP and
                    weight_quant.actorder in (ActivationOrdering.GROUP,
                                              ActivationOrdering.DYNAMIC)):
                raise ValueError(
                    "WNA16MoE is not supported with actorder=group/dynamic."
                )
            logger.info_once("Using CompressedTensorsWNA16MoEMethod")
            return CompressedTensorsWNA16MoEMethod(quant_config,
                                                   layer.moe_config)
        else:
            logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
            return CompressedTensorsWNA16MarlinMoEMethod(
                quant_config, layer.moe_config)
    elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
        return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
    elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
          or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
          or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
        return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
                                                 layer.moe_config)
    elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
        return CompressedTensorsW8A8Int8MoEMethod(quant_config,
                                                  layer.moe_config)
    else:
        raise RuntimeError(
            f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")

CompressedTensorsW4A4MoeMethod

Bases: CompressedTensorsMoEMethod

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for FP4 weights with FP4 activations (NVFP4).

    Depending on what ``detect_nvfp4_moe_support`` reports, ``apply``
    dispatches to the Marlin kernel, to a FlashInfer CUTLASS path, or to
    ``cutlass_moe_fp4``.
    """

    def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
        """Record backend capabilities and keep a reference to ``layer``.

        Args:
            moe: fused-MoE config forwarded to the base class.
            layer: the MoE layer; kept because ``maybe_make_prepare_finalize``
                and ``select_gemm_impl`` read its processed scales.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support)
        super().__init__(moe)
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        # Number of input-dimension elements sharing one block scale
        # (see the scale shapes in create_weights).
        self.group_size = 16
        self.layer = layer

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the packed weights and their scales.

        Registers on ``layer``: ``w13_weight_packed``, ``w2_weight_packed``,
        ``w13_weight_scale``, ``w2_weight_scale``,
        ``w13_weight_global_scale``, ``w2_weight_global_scale``,
        ``w13_input_global_scale`` and ``w2_input_global_scale``. All are
        uninitialized and non-trainable.

        Args:
            layer: module that receives the parameters.
            num_experts: size of the leading (expert) dimension.
            hidden_size: model hidden size.
            intermediate_size_per_partition: intermediate size held by this
                partition.
            params_dtype: stored on ``layer.params_dtype``.
            **extra_weight_attrs: attributes attached to every parameter via
                ``set_weight_attrs``; ``quant_method`` is overridden for the
                scale parameters.
        """
        layer.num_experts = num_experts
        layer.params_dtype = params_dtype

        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=torch.uint8),
            requires_grad=False)
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=torch.uint8),
            requires_grad=False)
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Weight Scales: tagged as group-wise from here on.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value})

        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # one scale per group_size elements of the input dimension
                hidden_size // self.group_size,
                dtype=torch.float8_e4m3fn),
            requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # one scale per group_size elements of the input dimension
                intermediate_size_per_partition // self.group_size,
                dtype=torch.float8_e4m3fn),
            requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Global scales (weights and inputs): tagged as per-tensor.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # Weight Global Scales
        w13_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, 2, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)

        w2_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)

        # Input Global Scales
        w13_input_scale = torch.nn.Parameter(
            torch.empty(num_experts, 2, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_input_global_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w2_input_global_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive the tensors that ``apply`` consumes from the loaded ones.

        Sets ``w13_weight`` / ``w2_weight`` from the packed parameters and
        the inverted weight global scales ``w13_weight_scale_2`` /
        ``w2_weight_scale_2``. On the Marlin path it then hands the layer to
        ``prepare_moe_fp4_layer_for_marlin`` and stops. Otherwise it
        swizzles the block scales and sets ``g1_alphas``, ``g2_alphas``,
        ``w13_input_scale_quant`` and ``w2_input_scale_quant``.
        """

        # From packed to weight
        layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
                                              requires_grad=False)

        layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
                                             requires_grad=False)

        # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
        if self.allow_flashinfer:
            w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
                                        layer.w13_weight_scale.data,
                                        dim=-2)
            layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)

        # Only column 0 of the w13 global scale is used below, so warn when
        # the two columns disagree.
        if not torch.allclose(layer.w13_weight_global_scale[:, 0],
                              layer.w13_weight_global_scale[:, 1]):
            logger.warning_once(
                "w1_weight_global_scale must match w3_weight_global_scale. "
                "Accuracy may be affected.")

        # Take inverse of global scale saved to disk
        layer.w13_weight_scale_2 = torch.nn.Parameter(
            1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)

        layer.w2_weight_scale_2 = torch.nn.Parameter(
            1 / layer.w2_weight_global_scale.data, requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            return

        # swizzle weight scales
        layer.w13_weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.w13_weight_scale), requires_grad=False)

        layer.w2_weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.w2_weight_scale), requires_grad=False)

        # w13: a single input scale per expert, the max over the two columns
        w13_input_global_scale = layer.w13_input_global_scale.max(
            dim=1).values.to(torch.float32)

        layer.g1_alphas = torch.nn.Parameter(
            ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
            requires_grad=False)

        layer.w13_input_scale_quant = torch.nn.Parameter(
            (w13_input_global_scale), requires_grad=False)

        # w2
        layer.g2_alphas = torch.nn.Parameter(
            ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
                torch.float32),
            requires_grad=False)

        layer.w2_input_scale_quant = torch.nn.Parameter(
            (layer.w2_input_global_scale), requires_grad=False)

    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Return the prepare/finalize object for this method.

        Defers to the base class unless FlashInfer is allowed, in which
        case the FlashInfer FP4 CUTLASS prepare/finalize is built with the
        layer's ``w13_input_scale_quant``.
        """
        if not self.allow_flashinfer:
            return super().maybe_make_prepare_finalize(moe)

        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
            moe,
            a1_gscale=self.layer.w13_input_scale_quant,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Return the appropriate GEMM experts implementation."""
        # Scales are read from the layer captured in __init__, not from the
        # ``layer`` argument.
        experts = select_nvfp4_gemm_impl(
            moe,
            g1_alphas=self.layer.g1_alphas,
            g2_alphas=self.layer.g2_alphas,
            a1_gscale=self.layer.w13_input_scale_quant,
            a2_gscale=self.layer.w2_input_scale_quant,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Route tokens to experts and run the quantized MoE computation.

        Backends are tried in order: Marlin (``use_marlin``), the
        pre-built ``self.fused_experts`` callable, the direct FlashInfer
        CUTLASS function (``allow_flashinfer``), and finally
        ``cutlass_moe_fp4``.

        Raises:
            NotImplementedError: ``enable_eplb`` is set.
            AssertionError: ``activation`` is not ``"silu"``, a FlashInfer
                path is not applicable to the inputs, or ``expert_map`` is
                given on the ``cutlass_moe_fp4`` path.
        """
        if enable_eplb:
            raise NotImplementedError("EPLB not supported for "
                                      "`CompressedTensorsW4A4MoeMethod` yet.")
        assert activation == "silu", "Only SiLU activation is supported."

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.use_marlin:
            # The Marlin path does not go through ``self.fused_experts``.
            # Asserting here rather than at function entry keeps the
            # fused-experts branch below reachable.
            assert self.fused_experts is None
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        # FlashInfer fused experts path
        if self.fused_experts is not None:
            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")

            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        elif self.allow_flashinfer:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4)

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")

            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        assert expert_map is None, ("Expert Parallelism / expert_map "
                                    "is currently not supported for "
                                    "CompressedTensorsW4A4MoeMethod.")
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_fp4)

        # Cutlass moe takes in activations in BF16/Half precision
        # and fp4 quantized weights loaded from the checkpoint
        return cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            w1_blockscale=layer.w13_weight_scale,
            w2_blockscale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=x.shape[0],
            # w2's last dimension is packed two-per-byte; undo the packing
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
            apply_router_weight_on_input=apply_router_weight_on_input).to(
                x.dtype)

allow_flashinfer instance-attribute

allow_flashinfer = allow_flashinfer

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_supported

group_size instance-attribute

group_size = 16

layer instance-attribute

layer = layer

use_marlin instance-attribute

use_marlin = use_marlin

__init__

__init__(moe: FusedMoEConfig, layer: Module)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
    """Initialize the method and record which NVFP4 backends may be used.

    Args:
        moe: fused-MoE config forwarded to the base-class initializer.
        layer: the MoE layer this method is attached to; stored on
            ``self.layer``.
    """
    from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
        detect_nvfp4_moe_support)

    super().__init__(moe)

    # Probe backend support once, keyed by the concrete class name.
    support = detect_nvfp4_moe_support(type(self).__name__)
    self.cutlass_nvfp4_supported = support.cutlass_supported
    self.allow_flashinfer = support.allow_flashinfer
    self.use_marlin = support.use_marlin

    self.group_size = 16
    self.layer = layer

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Route tokens to experts and run the quantized MoE computation.

    Backends are tried in order: Marlin (``use_marlin``), the pre-built
    ``self.fused_experts`` callable, the direct FlashInfer CUTLASS
    function (``allow_flashinfer``), and finally ``cutlass_moe_fp4``.

    Raises:
        NotImplementedError: ``enable_eplb`` is set.
        AssertionError: ``activation`` is not ``"silu"``, a FlashInfer
            path is not applicable to the inputs, or ``expert_map`` is
            given on the ``cutlass_moe_fp4`` path.
    """
    if enable_eplb:
        raise NotImplementedError("EPLB not supported for "
                                  "`CompressedTensorsW4A4MoeMethod` yet.")
    assert activation == "silu", "Only SiLU activation is supported."

    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

    if self.use_marlin:
        # The Marlin path does not go through ``self.fused_experts``.
        # Asserting here rather than at function entry keeps the
        # fused-experts branch below reachable.
        assert self.fused_experts is None
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=layer.w13_weight_scale_2,
            global_scale2=layer.w2_weight_scale_2,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    # FlashInfer fused experts path
    if self.fused_experts is not None:
        assert is_valid_flashinfer_cutlass_fused_moe(
            x, layer.w13_weight, layer.w2_weight), (
                "Flashinfer CUTLASS Fused MoE not applicable!")

        return self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    elif self.allow_flashinfer:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            flashinfer_cutlass_moe_fp4)

        assert is_valid_flashinfer_cutlass_fused_moe(
            x, layer.w13_weight, layer.w2_weight), (
                "Flashinfer CUTLASS Fused MoE not applicable!")

        return flashinfer_cutlass_moe_fp4(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    assert expert_map is None, ("Expert Parallelism / expert_map "
                                "is currently not supported for "
                                "CompressedTensorsW4A4MoeMethod.")
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        cutlass_moe_fp4)

    # Cutlass moe takes in activations in BF16/Half precision
    # and fp4 quantized weights loaded from the checkpoint
    return cutlass_moe_fp4(
        a=x,
        w1_fp4=layer.w13_weight,
        w2_fp4=layer.w2_weight,
        w1_blockscale=layer.w13_weight_scale,
        w2_blockscale=layer.w2_weight_scale,
        g1_alphas=layer.g1_alphas,
        g2_alphas=layer.g2_alphas,
        a1_gscale=layer.w13_input_scale_quant,
        a2_gscale=layer.w2_input_scale_quant,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        m=x.shape[0],
        # w2's last dimension is packed two-per-byte; undo the packing
        n=layer.w2_weight.shape[2] * 2,
        k=x.shape[1],
        e=layer.w13_weight.shape[0],
        apply_router_weight_on_input=apply_router_weight_on_input).to(
            x.dtype)

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate and register the packed FP4 expert weights and their scales.

    Every tensor is created uninitialised with ``torch.empty``, wrapped in a
    non-trainable ``torch.nn.Parameter``, registered on ``layer`` under the
    name given below, and tagged through ``set_weight_attrs``.

    Args:
        layer: Module that receives the parameters; ``num_experts`` and
            ``params_dtype`` are also stored on it as plain attributes.
        num_experts: Size of the leading (expert) dimension of every tensor.
        hidden_size: Model hidden size.
        intermediate_size_per_partition: Intermediate size of this partition.
        params_dtype: Only recorded on ``layer``; the tensors created here
            use fixed dtypes (uint8 / float8_e4m3fn / float32).
        **extra_weight_attrs: Attributes attached to every parameter. The
            ``"quant_method"`` entry is overwritten before each scale tensor
            is tagged, so the packed weights keep the caller's value.
    """
    layer.num_experts = num_experts
    layer.params_dtype = params_dtype

    group_value = FusedMoeWeightScaleSupported.GROUP.value
    tensor_value = FusedMoeWeightScaleSupported.TENSOR.value
    gate_up_rows = 2 * intermediate_size_per_partition

    def _add_param(name, shape, dtype, quant_method=None):
        # Allocate, register and tag one parameter. When ``quant_method``
        # is given it replaces the shared "quant_method" attribute first.
        param = torch.nn.Parameter(torch.empty(shape, dtype=dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        if quant_method is not None:
            extra_weight_attrs["quant_method"] = quant_method
        set_weight_attrs(param, extra_weight_attrs)

    # Packed weights: two fp4 values share one uint8 along the input
    # dimension, hence the halved last dimension.
    _add_param("w13_weight_packed",
               (num_experts, gate_up_rows, hidden_size // 2), torch.uint8)
    _add_param("w2_weight_packed",
               (num_experts, hidden_size,
                intermediate_size_per_partition // 2), torch.uint8)

    # Weight scales: one entry per ``self.group_size`` input elements.
    _add_param("w13_weight_scale",
               (num_experts, gate_up_rows, hidden_size // self.group_size),
               torch.float8_e4m3fn, group_value)
    _add_param("w2_weight_scale",
               (num_experts, hidden_size,
                intermediate_size_per_partition // self.group_size),
               torch.float8_e4m3fn, group_value)

    # Weight global scales: two entries per expert for w13, one for w2.
    _add_param("w13_weight_global_scale", (num_experts, 2), torch.float32,
               tensor_value)
    _add_param("w2_weight_global_scale", (num_experts, ), torch.float32,
               tensor_value)

    # Input global scales, laid out like the weight global scales.
    _add_param("w13_input_global_scale", (num_experts, 2), torch.float32,
               tensor_value)
    _add_param("w2_input_global_scale", (num_experts, ), torch.float32,
               tensor_value)

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    """Build the prepare/finalize object for this MoE method.

    When ``self.allow_flashinfer`` is set, the object comes from
    ``build_flashinfer_fp4_cutlass_moe_prepare_finalize`` using the input
    scale stored on ``self.layer``; otherwise the parent implementation
    decides.
    """
    if self.allow_flashinfer:
        flashinfer_pf = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
            moe,
            a1_gscale=self.layer.w13_input_scale_quant,
        )
        logger.debug_once("%s", flashinfer_pf.__class__.__name__)
        return flashinfer_pf

    return super().maybe_make_prepare_finalize(moe)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Convert the loaded checkpoint tensors into the kernel-side attributes.

    Steps, in order:
      1. Expose the ``*_weight_packed`` tensors as ``w13_weight`` /
         ``w2_weight``.
      2. If ``self.allow_flashinfer``, reorder w13 weights and block scales
         with ``reorder_w1w3_to_w3w1``.
      3. Store the reciprocal of the weight global scales as
         ``w13_weight_scale_2`` / ``w2_weight_scale_2`` (a warning is logged
         when the two w13 columns differ; column 0 is used).
      4. If ``self.use_marlin``, hand over to
         ``prepare_moe_fp4_layer_for_marlin`` and stop.
      5. Otherwise swizzle the block scales and derive ``g1_alphas`` /
         ``g2_alphas`` and the ``*_input_scale_quant`` parameters.
    """

    def _frozen(tensor):
        # Wrap a tensor as a parameter that does not require gradients.
        return torch.nn.Parameter(tensor, requires_grad=False)

    # From packed to weight
    layer.w13_weight = _frozen(layer.w13_weight_packed.data)
    layer.w2_weight = _frozen(layer.w2_weight_packed.data)

    # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
    if self.allow_flashinfer:
        reordered_weight, reordered_scale = reorder_w1w3_to_w3w1(
            layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2)
        layer.w13_weight = _frozen(reordered_weight)
        layer.w13_weight_scale = _frozen(reordered_scale)

    # Only column 0 is used below, so a mismatch between the two columns
    # is reported rather than silently ignored.
    first_col = layer.w13_weight_global_scale[:, 0]
    second_col = layer.w13_weight_global_scale[:, 1]
    if not torch.allclose(first_col, second_col):
        logger.warning_once(
            "w1_weight_global_scale must match w3_weight_global_scale. "
            "Accuracy may be affected.")

    # Take inverse of global scale saved to disk
    layer.w13_weight_scale_2 = _frozen(
        1 / layer.w13_weight_global_scale[:, 0])
    layer.w2_weight_scale_2 = _frozen(1 / layer.w2_weight_global_scale.data)

    if self.use_marlin:
        # The Marlin helper takes over from here.
        prepare_moe_fp4_layer_for_marlin(layer)
        return

    # swizzle weight scales
    layer.w13_weight_scale = _frozen(
        swizzle_blockscale(layer.w13_weight_scale))
    layer.w2_weight_scale = _frozen(swizzle_blockscale(layer.w2_weight_scale))

    # w13: collapse the two per-expert input scales to their maximum.
    w13_in_scale = layer.w13_input_global_scale.max(dim=1).values.to(
        torch.float32)
    layer.g1_alphas = _frozen((1 / w13_in_scale) * layer.w13_weight_scale_2)
    layer.w13_input_scale_quant = _frozen(w13_in_scale)

    # w2
    w2_alpha = (1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2
    layer.g2_alphas = _frozen(w2_alpha.to(torch.float32))
    layer.w2_input_scale_quant = _frozen(layer.w2_input_global_scale)

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute

Return the appropriate GEMM experts implementation.

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
    layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Return the appropriate GEMM experts implementation.

    The choice is delegated to ``select_nvfp4_gemm_impl``, which receives
    the alpha and input-scale tensors stored on ``self.layer`` together
    with ``self.allow_flashinfer``.
    """
    # The scales are read from ``self.layer``; the ``layer`` and
    # ``prepare_finalize`` arguments are not used by this implementation.
    scale_kwargs = dict(
        g1_alphas=self.layer.g1_alphas,
        g2_alphas=self.layer.g2_alphas,
        a1_gscale=self.layer.w13_input_scale_quant,
        a2_gscale=self.layer.w2_input_scale_quant,
    )
    gemm_experts = select_nvfp4_gemm_impl(
        moe,
        **scale_kwargs,
        allow_flashinfer=self.allow_flashinfer,
    )
    logger.debug_once("Using %s", gemm_experts.__class__.__name__)
    return gemm_experts

CompressedTensorsW8A8Fp8MoEMethod

Bases: CompressedTensorsMoEMethod

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for compressed-tensors FP8-weight / FP8-activation.

    Two scheme combinations are accepted (see ``__init__``):

    * per-tensor weights with per-tensor activations, or
    * per-channel weights with dynamic per-token activations.

    At run time ``apply`` dispatches to one of four back ends depending on
    the flags computed in ``__init__``: CUTLASS, ROCm AITER, Marlin, or the
    default ``fused_experts`` function.
    """

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        moe: FusedMoEConfig,
    ):
        """Read the "Linear" quantization scheme and select the back end.

        Args:
            quant_config: Quantization config; its ``target_scheme_map``
                entry for ``"Linear"`` provides the weight and
                input-activation quantization arguments.
            moe: MoE configuration forwarded to the parent constructor.

        Raises:
            ValueError: If the weight/activation strategies are neither
                (tensor, tensor) nor (channel, token), or if a
                (channel, token) scheme uses static input scales.
        """
        super().__init__(moe)
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
                      and self.input_quant.strategy
                      == QuantizationStrategy.TENSOR)
        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
        if not (per_tensor or per_channel):
            raise ValueError(
                "For FP8 Fused MoE layers, we require per tensor "
                "or channelwise, dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales and per_channel:
            raise ValueError(
                "For FP8 Fused MoE layer, we require either per tensor or "
                "channelwise, dynamic per token quantization.")

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # cutlass path
        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
            self.weight_quant, self.input_quant)
        self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
            self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
        # May be switched on later by ``select_gemm_impl`` (cutlass path).
        self.disable_expert_map = False

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register FP8 expert weights, weight scales and input scales.

        Args:
            layer: Module that receives the parameters and a few bookkeeping
                attributes (sizes, ``orig_dtype``, ``weight_block_size``).
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Intermediate size of this
                partition.
            params_dtype: Stored as ``layer.orig_dtype``; the weights
                themselves are always allocated as ``torch.float8_e4m3fn``.
            **extra_weight_attrs: Attributes attached to every parameter;
                ``"quant_method"`` is overwritten once the weight-scale
                strategy is known.
        """

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
            # Allocate 2 scales for w1 and w3 respectively.
            # They are combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                1,
                dtype=torch.float32),
                                                  requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, hidden_size, 1, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            # Dynamic activation quantization: no stored input scales.
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded weights and scales for the selected back end.

        In order: collapse static input scales to a single maximum, convert
        to the fnuz FP8 variant when the platform requires it, requantize
        per-tensor w13 shards to one shared scale per expert, apply the
        back-end specific preparation (AITER shuffle, Marlin repack, or
        binding ``fused_experts``), and finally precompute the stride
        tensors used by the CUTLASS kernels.

        Raises:
            ValueError: If static input scales are configured but the layer
                has no input-scale tensors.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
                                                           requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                          requires_grad=False)

        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
        # for w13 per expert. Use max then dequant and requant each expert.
        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # Two shards per expert, each ``shard_size`` rows, each
                # with its own loaded scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size
            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

        # Property to determine if AITER is used
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
                rocm_aiter_fused_experts, shuffle_weights)

            # reshaping weights is required for aiter moe kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
        elif self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
            self.fused_experts_func = None
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            self.fused_experts_func = fused_experts

        if self.use_cutlass:
            device = layer.w13_weight.device
            # ab_strides1 and c_strides2 are the same
            self.ab_strides1_c_strides2 = torch.full(
                (layer.local_num_experts, ),
                layer.hidden_size,
                device=device,
                dtype=torch.int64)
            self.ab_strides2 = torch.full(
                (layer.local_num_experts, ),
                layer.intermediate_size_per_partition,
                device=device,
                dtype=torch.int64)
            self.c_strides1 = torch.full(
                (layer.local_num_experts, ),
                2 * layer.intermediate_size_per_partition,
                device=device,
                dtype=torch.int64)

    def select_gemm_impl(
            self, prepare_finalize: FusedMoEPrepareAndFinalize,
            moe: FusedMoEConfig,
            layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
        """Return the experts implementation matching ``prepare_finalize``.

        With ``self.use_cutlass`` a CUTLASS FP8 experts object is built
        (batched or not, according to the activation format) and
        ``self.disable_expert_map`` is updated. Otherwise a Triton experts
        object is returned; that path asserts that neither AITER nor Marlin
        is active.
        """
        # cutlass path
        if self.use_cutlass:
            from vllm.model_executor.layers.fused_moe import (
                CutlassBatchedExpertsFp8, CutlassExpertsFp8)

            experts: FusedMoEPermuteExpertsUnpermute

            num_dispatchers = prepare_finalize.num_dispatchers()

            if (prepare_finalize.activation_format ==
                    FusedMoEActivationFormat.BatchedExperts):
                logger.debug("CutlassBatchedExpertsFp8(%s)",
                             self.__class__.__name__)
                experts = CutlassBatchedExpertsFp8(
                    moe.num_local_experts,
                    num_dispatchers,
                    moe.in_dtype,
                    self.input_quant.strategy == QuantizationStrategy.TOKEN,
                    self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
                    ab_strides1=self.ab_strides1_c_strides2,
                    ab_strides2=self.ab_strides2,
                    c_strides1=self.c_strides1,
                    c_strides2=self.ab_strides1_c_strides2,
                )
            else:
                logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
                experts = CutlassExpertsFp8(
                    moe.in_dtype,
                    self.input_quant.strategy == QuantizationStrategy.TOKEN,
                    self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
                    ab_strides1=self.ab_strides1_c_strides2,
                    ab_strides2=self.ab_strides2,
                    c_strides1=self.c_strides1,
                    c_strides2=self.ab_strides1_c_strides2,
                )

            # ``apply`` passes ``expert_map=None`` on the cutlass path when
            # this flag is set.
            self.disable_expert_map = (num_dispatchers > 1
                                       or not experts.supports_expert_map())

            return experts

        # triton path
        from vllm.model_executor.layers.fused_moe import TritonExperts
        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
            BatchedTritonExperts)

        assert not self.rocm_aiter_moe_enabled and not self.use_marlin

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            # Log inside the branch so the message names the class that is
            # actually returned (mirrors the cutlass path above).
            logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
            )
            assert max_num_tokens_per_rank is not None

            return BatchedTritonExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                per_act_token_quant=(
                    self.input_quant.strategy == QuantizationStrategy.TOKEN),
            )
        else:
            logger.debug("TritonExperts(%s)", self.__class__.__name__)
            return TritonExperts(
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                per_act_token_quant=(
                    self.input_quant.strategy == QuantizationStrategy.TOKEN),
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Route tokens to experts and run the FP8 fused-MoE computation.

        Expert selection is done by ``FusedMoE.select_experts``; the result
        is then passed to the first matching back end, checked in this
        order: CUTLASS (with a ``fused_experts`` fallback on SM100 when
        ``topk_ids`` has at most 8 rows), ROCm AITER, Marlin, and finally
        ``self.fused_experts_func``.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used here.

        Raises:
            NotImplementedError: If ``enable_eplb`` is true.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsW8A8Fp8MoEMethod` yet.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        # cutlass path
        if self.use_cutlass:
            per_act_token = (
                self.input_quant.strategy == QuantizationStrategy.TOKEN)
            per_channel_quant = (
                self.weight_quant.strategy == QuantizationStrategy.CHANNEL)

            # small-batch fallback on SM100
            if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
                from vllm.model_executor.layers.fused_moe import fused_experts
                return fused_experts(
                    hidden_states=x,
                    w1=layer.w13_weight,
                    w2=layer.w2_weight,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    inplace=True,
                    activation=activation,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                    use_fp8_w8a8=True,
                    per_channel_quant=per_channel_quant,
                    global_num_experts=global_num_experts,
                    expert_map=None if self.disable_expert_map else expert_map,
                    w1_scale=layer.w13_weight_scale,
                    w2_scale=layer.w2_weight_scale,
                    a1_scale=layer.w13_input_scale,
                    a2_scale=layer.w2_input_scale)

            if self.fused_experts is None:
                from vllm.model_executor.layers.fused_moe.cutlass_moe import (
                    cutlass_moe_fp8)
                return cutlass_moe_fp8(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    per_act_token=per_act_token,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=None if self.disable_expert_map else expert_map,
                    w1_scale=layer.w13_weight_scale,
                    w2_scale=layer.w2_weight_scale,
                    ab_strides1=self.ab_strides1_c_strides2,
                    ab_strides2=self.ab_strides2,
                    c_strides1=self.c_strides1,
                    c_strides2=self.ab_strides1_c_strides2,
                    a1_scale=layer.w13_input_scale,
                    a2_scale=layer.w2_input_scale,
                )
            else:
                return self.fused_experts(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=None if self.disable_expert_map else expert_map,
                    w1_scale=layer.w13_weight_scale,
                    w2_scale=layer.w2_weight_scale,
                    a1_scale=layer.w13_input_scale,
                    a2_scale=layer.w2_input_scale,
                )

        if self.rocm_aiter_moe_enabled:
            return self.rocm_aiter_fused_experts_func(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                use_fp8_w8a8=True,
                per_channel_quant=self.weight_quant.strategy ==
                QuantizationStrategy.CHANNEL,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                expert_map=expert_map)
        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        assert self.fused_experts_func is not None

        return self.fused_experts_func(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=True,
            per_channel_quant=self.weight_quant.strategy ==
            QuantizationStrategy.CHANNEL,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale)

disable_expert_map instance-attribute

disable_expert_map = False

input_quant instance-attribute

input_quant = get('input_activations')

is_fp8_w8a8_sm100 instance-attribute

is_fp8_w8a8_sm100 = _is_fp8_w8a8_sm100(
    weight_quant, input_quant
)

quant_config instance-attribute

quant_config = quant_config

rocm_aiter_moe_enabled instance-attribute

rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

static_input_scales instance-attribute

static_input_scales = not dynamic

use_cutlass instance-attribute

use_cutlass = (
    _is_fp8_w8a8_sm90(weight_quant, input_quant)
    or is_fp8_w8a8_sm100
)

use_marlin instance-attribute

use_marlin = (
    not has_device_capability(89)
    or VLLM_TEST_FORCE_FP8_MARLIN
)

weight_quant instance-attribute

weight_quant = get('weights')

__init__

__init__(
    quant_config: CompressedTensorsConfig,
    moe: FusedMoEConfig,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Validate the FP8 W8A8 scheme and record which backends may be used.

    Reads the weight and input-activation quantization arguments of the
    "Linear" target from ``quant_config`` and sets the flags that steer
    the Marlin, ROCm AITER and CUTLASS paths.

    Raises:
        ValueError: If the scheme is neither per-tensor weights with
            per-tensor activations nor channelwise weights with per-token
            activations, or if the channelwise/per-token scheme is paired
            with static input scales.
    """
    super().__init__(moe)
    self.quant_config = quant_config

    linear_scheme = self.quant_config.target_scheme_map["Linear"]
    self.weight_quant = linear_scheme.get("weights")
    self.input_quant = linear_scheme.get("input_activations")

    # Exactly two weight/activation strategy pairings are accepted.
    tensor_scheme = (
        self.weight_quant.strategy == QuantizationStrategy.TENSOR
        and self.input_quant.strategy == QuantizationStrategy.TENSOR)
    channel_token_scheme = (
        self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        and self.input_quant.strategy == QuantizationStrategy.TOKEN)
    if not tensor_scheme and not channel_token_scheme:
        raise ValueError(
            "For FP8 Fused MoE layers, we require per tensor "
            "or channelwise, dynamic per token quantization. Found "
            f"{self.weight_quant}, {self.input_quant}")

    # Static activation scales are only allowed with the per-tensor scheme.
    self.static_input_scales = not self.input_quant.dynamic
    if channel_token_scheme and self.static_input_scales:
        raise ValueError(
            "For FP8 Fused MoE layer, we require either per tensor or "
            "channelwise, dynamic per token quantization.")

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization. Marlin is always
    # disabled on ROCm.
    marlin_requested = (not current_platform.has_device_capability(89)
                        or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    self.use_marlin = (False
                       if current_platform.is_rocm() else marlin_requested)

    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled)
    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

    # cutlass path
    quant_args = (self.weight_quant, self.input_quant)
    self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(*quant_args)
    self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(*quant_args)
                        or self.is_fp8_w8a8_sm100)
    # Starts off; select_gemm_impl may turn it on for the CUTLASS path.
    self.disable_expert_map = False

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Select experts for ``x`` and run the FP8 W8A8 fused MoE computation.

    Routing is done by ``FusedMoE.select_experts``. The expert computation
    is then handed to the first backend that applies, checked in this
    order: CUTLASS, ROCm AITER, Marlin, and finally the callable stored in
    ``self.fused_experts_func``.

    Raises:
        NotImplementedError: If ``enable_eplb`` is set.
    """
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for "
            "`CompressedTensorsW8A8Fp8MoEMethod` yet.")

    routing_weights, expert_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

    # cutlass path
    if self.use_cutlass:
        tokenwise_activations = (
            self.input_quant.strategy == QuantizationStrategy.TOKEN)
        channelwise_weights = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
        # The expert map is dropped when select_gemm_impl disabled it.
        cutlass_expert_map = (None
                              if self.disable_expert_map else expert_map)

        # small-batch fallback on SM100
        if self.is_fp8_w8a8_sm100 and expert_ids.shape[0] <= 8:
            from vllm.model_executor.layers.fused_moe import fused_experts
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=routing_weights,
                topk_ids=expert_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                use_fp8_w8a8=True,
                per_channel_quant=channelwise_weights,
                global_num_experts=global_num_experts,
                expert_map=cutlass_expert_map,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale)

        # A pre-built experts object takes precedence over the plain
        # cutlass_moe_fp8 entry point.
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                routing_weights,
                expert_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=cutlass_expert_map,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )

        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_fp8)
        return cutlass_moe_fp8(
            x,
            layer.w13_weight,
            layer.w2_weight,
            routing_weights,
            expert_ids,
            per_act_token=tokenwise_activations,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=cutlass_expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            ab_strides1=self.ab_strides1_c_strides2,
            ab_strides2=self.ab_strides2,
            c_strides1=self.c_strides1,
            c_strides2=self.ab_strides1_c_strides2,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

    # ROCm AITER path
    if self.rocm_aiter_moe_enabled:
        aiter_per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
        return self.rocm_aiter_fused_experts_func(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=routing_weights,
            topk_ids=expert_ids,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=True,
            per_channel_quant=aiter_per_channel,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            expert_map=expert_map)

    # Marlin path: no activation scales are passed to this kernel.
    if self.use_marlin:
        assert activation == "silu", (
            f"{activation} not supported for Marlin MoE.")
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            routing_weights,
            expert_ids,
            quant_type_id=scalar_types.float8_e4m3fn.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    # Default path: the callable chosen in process_weights_after_loading.
    assert self.fused_experts_func is not None
    experts_fn = self.fused_experts_func
    default_per_channel = (
        self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
    return experts_fn(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=routing_weights,
        topk_ids=expert_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=True,
        per_channel_quant=default_per_channel,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale)

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register the FP8 expert weights and their scales on ``layer``.

    Creates ``w13_weight`` and ``w2_weight`` as ``torch.float8_e4m3fn``
    parameters, float32 weight scales whose shape depends on the weight
    quantization strategy, and per-expert input scales when static input
    scales are configured (otherwise both input scales are set to
    ``None``). The incoming ``params_dtype`` is only recorded on the layer
    as ``orig_dtype``.
    """

    def frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
        # Every parameter created here is excluded from autograd.
        return torch.nn.Parameter(tensor, requires_grad=False)

    # Record the layer geometry for later post-processing.
    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    fp8_dtype = torch.float8_e4m3fn
    w13_rows = 2 * intermediate_size_per_partition

    # WEIGHTS
    w13_weight = frozen(
        torch.empty(num_experts, w13_rows, hidden_size, dtype=fp8_dtype))
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = frozen(
        torch.empty(num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                    dtype=fp8_dtype))
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    strategy = self.weight_quant.strategy
    if strategy == QuantizationStrategy.TENSOR:
        # Allocate 2 scales for w1 and w3 respectively.
        # They are combined to a single scale after weight loading.
        scale_shapes = ((num_experts, 2), (num_experts, ))
        scale_kind = FusedMoeWeightScaleSupported.TENSOR
    elif strategy == QuantizationStrategy.CHANNEL:
        scale_shapes = ((num_experts, w13_rows, 1),
                        (num_experts, hidden_size, 1))
        scale_kind = FusedMoeWeightScaleSupported.CHANNEL
    else:
        # Any other strategy gets no weight scales registered here.
        scale_shapes = None

    if scale_shapes is not None:
        w13_scale_shape, w2_scale_shape = scale_shapes
        w13_weight_scale = frozen(
            torch.ones(w13_scale_shape, dtype=torch.float32))
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = frozen(
            torch.ones(w2_scale_shape, dtype=torch.float32))
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Tag the scales with their granularity for
        # FusedMoE.weight_loader. The tag stays in extra_weight_attrs, so
        # the input scales created below receive it as well.
        extra_weight_attrs.update({"quant_method": scale_kind.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    # INPUT_SCALES
    if not self.static_input_scales:
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        return

    for scale_name in ("w13_input_scale", "w2_input_scale"):
        input_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
        layer.register_parameter(scale_name, input_scale)
        set_weight_attrs(input_scale, extra_weight_attrs)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-process the loaded FP8 expert weights and scales on ``layer``.

    In order, this:
      1. collapses static input scales to their maximum across experts,
      2. when ``current_platform.is_fp8_fnuz()`` holds, converts weights
         and scales with ``normalize_e4m3fn_to_e4m3fnuz``,
      3. for per-tensor weight quantization, requantizes both w13 shards
         of every expert against the larger of their two scales,
      4. prepares the weights for the selected backend (ROCm AITER
         shuffle, Marlin preparation, or the default ``fused_experts``),
      5. precomputes the stride tensors used by the CUTLASS path.

    Raises:
        ValueError: If static input scales are configured but the layer's
            input scales are ``None``.
    """
    # Fp8 moe kernels require a single activation scale.
    # We take the max of all the scales in case they differ.
    if self.static_input_scales:
        assert self.input_quant.strategy == QuantizationStrategy.TENSOR
        if (layer.w13_input_scale is None or layer.w2_input_scale is None):
            raise ValueError(
                "QuantConfig has static quantization, but found "
                "activation scales are None.")
        if (not all_close_1d(layer.w13_input_scale)
                or not all_close_1d(layer.w2_input_scale)):
            logger.warning_once(
                "Found input_scales that are not equal for "
                "fp8 MoE layer. Using the maximum across experts "
                "for each layer.")
        # Replace the per-expert scales with their maximum.
        layer.w13_input_scale = torch.nn.Parameter(
            layer.w13_input_scale.max(), requires_grad=False)
        layer.w2_input_scale = torch.nn.Parameter(
            layer.w2_input_scale.max(), requires_grad=False)

    if current_platform.is_fp8_fnuz():
        # Normalize the weights and scales
        w13_weight, w13_weight_scale, w13_input_scale = \
            normalize_e4m3fn_to_e4m3fnuz(
                layer.w13_weight, layer.w13_weight_scale,
                layer.w13_input_scale)
        w2_weight, w2_weight_scale, w2_input_scale = \
            normalize_e4m3fn_to_e4m3fnuz(
                layer.w2_weight, layer.w2_weight_scale,
                layer.w2_input_scale)
        # Reset the parameter
        layer.w13_weight = torch.nn.Parameter(w13_weight,
                                              requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                    requires_grad=False)
        # Input scales are only re-wrapped when the helper returned one.
        if w13_input_scale is not None:
            layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
                                                       requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight,
                                             requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                   requires_grad=False)
        if w2_input_scale is not None:
            layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                      requires_grad=False)

    # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
    # for w13 per expert. Use max then dequant and requant each expert.
    if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        # One scale per expert: the larger of that expert's two w13 scales.
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            start = 0
            # The two shards occupy consecutive row ranges of w13_weight.
            for shard_id in range(2):
                # Dequantize with the shard's own scale ...
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                # ... and write it back in place, requantized with the
                # expert's shared scale.
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size
        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    # Property to determine if AITER is used
    if self.rocm_aiter_moe_enabled:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
            rocm_aiter_fused_experts, shuffle_weights)

        # reshaping weights is required for aiter moe kernel.
        shuffled_w13, shuffled_w2 = shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data)

        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                             requires_grad=False)

        self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
    elif self.use_marlin:
        prepare_moe_fp8_layer_for_marlin(layer, False)
        # Activations not quantized for marlin.
        del layer.w13_input_scale
        del layer.w2_input_scale
        self.fused_experts_func = None
    else:
        from vllm.model_executor.layers.fused_moe import fused_experts
        self.fused_experts_func = fused_experts

    # Stride tensors hold one int64 entry per local expert and live on the
    # same device as the weights.
    if self.use_cutlass:
        device = layer.w13_weight.device
        # ab_strides1 and c_strides2 are the same
        self.ab_strides1_c_strides2 = torch.full(
            (layer.local_num_experts, ),
            layer.hidden_size,
            device=device,
            dtype=torch.int64)
        self.ab_strides2 = torch.full(
            (layer.local_num_experts, ),
            layer.intermediate_size_per_partition,
            device=device,
            dtype=torch.int64)
        self.c_strides1 = torch.full(
            (layer.local_num_experts, ),
            2 * layer.intermediate_size_per_partition,
            device=device,
            dtype=torch.int64)

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
    layer: Module,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def select_gemm_impl(
        self, prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
    """Build the experts implementation that matches ``prepare_finalize``.

    On the CUTLASS path this returns ``CutlassBatchedExpertsFp8`` or
    ``CutlassExpertsFp8`` and updates ``self.disable_expert_map``.
    Otherwise it returns ``BatchedTritonExperts`` or ``TritonExperts``.
    In both cases the batched variant is chosen when the activation format
    of ``prepare_finalize`` is ``BatchedExperts``.

    Args:
        prepare_finalize: Supplies the activation format, the number of
            dispatchers and, for the batched Triton variant, the maximum
            number of tokens per rank.
        moe: Supplies ``num_local_experts`` and ``in_dtype`` for the
            CUTLASS implementations.
        layer: Unused here.

    Returns:
        The selected experts object.
    """
    # Both flags are passed to whichever implementation is built below.
    per_act_token = (
        self.input_quant.strategy == QuantizationStrategy.TOKEN)
    per_channel = (
        self.weight_quant.strategy == QuantizationStrategy.CHANNEL)

    # cutlass path
    if self.use_cutlass:
        from vllm.model_executor.layers.fused_moe import (
            CutlassBatchedExpertsFp8, CutlassExpertsFp8)

        experts: FusedMoEPermuteExpertsUnpermute

        num_dispatchers = prepare_finalize.num_dispatchers()

        # Stride tensors precomputed in process_weights_after_loading;
        # ab_strides1 and c_strides2 share one tensor.
        stride_kwargs = dict(
            ab_strides1=self.ab_strides1_c_strides2,
            ab_strides2=self.ab_strides2,
            c_strides1=self.c_strides1,
            c_strides2=self.ab_strides1_c_strides2,
        )

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            logger.debug("CutlassBatchedExpertsFp8(%s)",
                         self.__class__.__name__)
            experts = CutlassBatchedExpertsFp8(
                moe.num_local_experts,
                num_dispatchers,
                moe.in_dtype,
                per_act_token,
                per_channel,
                **stride_kwargs,
            )
        else:
            logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
            experts = CutlassExpertsFp8(
                moe.in_dtype,
                per_act_token,
                per_channel,
                **stride_kwargs,
            )

        # apply() drops the expert map when this flag is set.
        self.disable_expert_map = (num_dispatchers > 1
                                   or not experts.supports_expert_map())

        return experts

    # triton path
    from vllm.model_executor.layers.fused_moe import TritonExperts
    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedTritonExperts)

    assert not self.rocm_aiter_moe_enabled and not self.use_marlin

    if (prepare_finalize.activation_format ==
            FusedMoEActivationFormat.BatchedExperts):
        max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
        )
        assert max_num_tokens_per_rank is not None

        # Log inside the branch so the message names the class that is
        # actually built.
        logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
        return BatchedTritonExperts(
            max_num_tokens=max_num_tokens_per_rank,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            per_act_token_quant=per_act_token,
        )

    logger.debug("TritonExperts(%s)", self.__class__.__name__)
    return TritonExperts(
        use_fp8_w8a8=True,
        block_shape=self.quant_config.weight_block_size,
        per_act_token_quant=per_act_token,
    )

CompressedTensorsW8A8Int8MoEMethod

Bases: CompressedTensorsMoEMethod

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for INT8 W8A8 compressed-tensors schemes.

    Only channelwise weight quantization combined with dynamic per-token
    activation quantization is accepted; the constructor rejects anything
    else.
    """

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        moe: FusedMoEConfig,
    ):
        """Read the "Linear" scheme from ``quant_config`` and validate it.

        Raises:
            ValueError: If the weights are not channelwise quantized, the
                activations are not per-token quantized, or the input
                scales are static.
        """
        super().__init__(moe)
        self.quant_config = quant_config

        linear_scheme = self.quant_config.target_scheme_map["Linear"]
        self.weight_quant = linear_scheme.get("weights")
        self.input_quant = linear_scheme.get("input_activations")

        channelwise_weights = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
        if not (channelwise_weights and
                self.input_quant.strategy == QuantizationStrategy.TOKEN):
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register int8 expert weights and float32 channel scales.

        ``params_dtype`` is not used: weights are always ``torch.int8``.
        Both input scales are set to ``None`` on ``layer``.
        """

        def frozen(tensor: torch.Tensor) -> torch.nn.Parameter:
            # Every parameter created here is excluded from autograd.
            return torch.nn.Parameter(tensor, requires_grad=False)

        int8_dtype = torch.int8
        w13_rows = 2 * intermediate_size_per_partition

        # WEIGHTS
        w13_weight = frozen(
            torch.empty(num_experts, w13_rows, hidden_size,
                        dtype=int8_dtype))
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = frozen(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=int8_dtype))
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = frozen(
            torch.ones(num_experts, w13_rows, 1, dtype=torch.float32))
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = frozen(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32))
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        for scale_param in (w13_weight_scale, w2_weight_scale):
            set_weight_attrs(scale_param, extra_weight_attrs)

        # INPUT_SCALES
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Do nothing; this method performs no post-load processing."""

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Select experts for ``x`` and run ``fused_experts`` in int8 mode.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsW8A8Int8MoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        routing_weights, expert_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=routing_weights,
            topk_ids=expert_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale)

input_quant instance-attribute

input_quant = get('input_activations')

quant_config instance-attribute

quant_config = quant_config

static_input_scales instance-attribute

static_input_scales = not dynamic

weight_quant instance-attribute

weight_quant = get('weights')

__init__

__init__(
    quant_config: CompressedTensorsConfig,
    moe: FusedMoEConfig,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Read the "Linear" scheme from ``quant_config`` and validate it.

    Raises:
        ValueError: If the weights are not channelwise quantized, the
            activations are not per-token quantized, or the input scales
            are static.
    """
    super().__init__(moe)
    self.quant_config = quant_config

    linear_scheme = self.quant_config.target_scheme_map["Linear"]
    self.weight_quant = linear_scheme.get("weights")
    self.input_quant = linear_scheme.get("input_activations")

    channelwise_weights = (
        self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
    if not (channelwise_weights
            and self.input_quant.strategy == QuantizationStrategy.TOKEN):
        raise ValueError(
            "For INT8 Fused MoE layers, we require channelwise, "
            "dynamic per token quantization. Found "
            f"{self.weight_quant}, {self.input_quant}")

    # Only dynamic activation quantization is accepted for INT8.
    self.static_input_scales = not self.input_quant.dynamic
    if self.static_input_scales:
        raise ValueError(
            "For INT8 Fused MoE layers, we require channelwise, "
            "dynamic per token quantization. Found static input scales.")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Select experts for ``x`` and run the INT8 W8A8 fused-experts path.

    Routing is delegated to ``FusedMoE.select_experts``; the expert
    computation is delegated to ``fused_experts`` with
    ``use_int8_w8a8=True`` and ``per_channel_quant=True``.

    ``expert_load_view``, ``logical_to_physical_map`` and
    ``logical_replica_count`` are not read by this implementation.

    Raises:
        NotImplementedError: If ``enable_eplb`` is true.
    """
    assert self.fused_experts is None

    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for "
            "`CompressedTensorsW8A8Int8MoEMethod` yet.")

    # Imported only once the EPLB guard has passed.
    from vllm.model_executor.layers.fused_moe import fused_experts

    routing_weights, expert_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype)

    moe_output = fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=routing_weights,
        topk_ids=expert_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_int8_w8a8=True,
        per_channel_quant=True,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale)
    return moe_output

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register INT8 expert weights and per-channel float32 scales.

    The incoming ``params_dtype`` is not used: weights are always
    allocated as ``torch.int8`` and weight scales as ``torch.float32``.

    Args:
        layer: Module that receives the parameters.
        num_experts: Size of the leading dimension of every parameter.
        hidden_size: Hidden dimension of the layer.
        intermediate_size_per_partition: Intermediate dimension held by
            this partition; the ``w13`` tensors use twice this size.
        params_dtype: Ignored (see above).
        **extra_weight_attrs: Attributes attached to every parameter via
            ``set_weight_attrs``. The scale parameters additionally get
            ``quant_method`` set to the channel strategy value.
    """
    int8_dtype = torch.int8
    w13_rows = 2 * intermediate_size_per_partition

    def _frozen(data: torch.Tensor) -> torch.nn.Parameter:
        # Every tensor registered here is a non-trainable parameter.
        return torch.nn.Parameter(data, requires_grad=False)

    # WEIGHTS
    w13_weight = _frozen(
        torch.empty(num_experts, w13_rows, hidden_size, dtype=int8_dtype))
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = _frozen(
        torch.empty(num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                    dtype=int8_dtype))
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
    w13_weight_scale = _frozen(
        torch.ones(num_experts, w13_rows, 1, dtype=torch.float32))
    layer.register_parameter("w13_weight_scale", w13_weight_scale)
    w2_weight_scale = _frozen(
        torch.ones(num_experts, hidden_size, 1, dtype=torch.float32))
    layer.register_parameter("w2_weight_scale", w2_weight_scale)

    # Add PER-CHANNEL quantization for FusedMoE.weight_loader. This is
    # done after the weights were tagged, so only the scales carry it.
    extra_weight_attrs["quant_method"] = (
        FusedMoeWeightScaleSupported.CHANNEL.value)
    for scale_param in (w13_weight_scale, w2_weight_scale):
        set_weight_attrs(scale_param, extra_weight_attrs)

    # INPUT_SCALES
    assert not self.static_input_scales
    layer.w13_input_scale = None
    layer.w2_input_scale = None

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Do nothing: this method performs no post-load weight processing."""
    return None

CompressedTensorsWNA16MarlinMoEMethod

Bases: CompressedTensorsMoEMethod

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
    """Compressed-tensors WNA16 fused-MoE method using the Marlin MoE op.

    Weights are registered in packed int32 form, repacked after loading
    with ``ops.gptq_marlin_moe_repack`` and executed through
    ``torch.ops.vllm.fused_marlin_moe``.
    """

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        moe: FusedMoEConfig,
    ):
        """Read the "Linear" weight scheme and validate it.

        Args:
            quant_config: Compressed-tensors config. The "weights" entry
                of ``target_scheme_map["Linear"]`` supplies ``num_bits``,
                ``strategy``, ``group_size``, ``actorder`` and
                ``symmetric``.
            moe: Fused-MoE configuration, forwarded to the base class.

        Raises:
            AssertionError: If the weight scheme is not symmetric.
            ValueError: If the quant format is not ``pack_quantized`` or
                ``num_bits`` is not in ``WNA16_SUPPORTED_BITS``.
        """
        super().__init__(moe)
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.num_bits = config.num_bits
        # Number of quantized values stored in one int32 word.
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy
        self.group_size = config.group_size
        self.actorder = config.actorder
        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        if not (self.quant_config.quant_format
                == CompressionFormat.pack_quantized.value
                and self.num_bits in WNA16_SUPPORTED_BITS):
            # Adjacent literals are concatenated into ONE message; comma
            # separation would make ValueError carry a tuple of fragments.
            raise ValueError(
                "For Fused MoE layers, only "
                f"{CompressionFormat.pack_quantized.value} "
                "is supported for the following bits: "
                f"{WNA16_SUPPORTED_BITS}")
        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register packed weights, scales, shapes and g_idx parameters.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading dimension of every parameter.
            hidden_size: Hidden dimension of the layer.
            intermediate_size_per_partition: Intermediate dimension held
                by this partition.
            params_dtype: dtype used for the weight-scale parameters.
            **extra_weight_attrs: Attributes attached to each parameter.
                Must contain ``intermediate_size_full``, which is popped
                here and not attached to the parameters.

        Side effects:
            Sets ``self.is_k_full`` and, for the "channel" strategy,
            resets ``self.group_size`` to -1.
        """

        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full")

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size // self.packed_factor,
            2 * intermediate_size_per_partition,
            dtype=torch.int32),
                                        requires_grad=False)
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            intermediate_size_per_partition // self.packed_factor,
            hidden_size,
            dtype=torch.int32),
                                       requires_grad=False)
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # In the case where we have actorder/g_idx,
        # we do not partition the w2 scales
        # NOTE: evaluated before group_size may be reset to -1 below.
        load_full_w2 = self.actorder and self.group_size != -1
        w2_scales_size = (intermediate_size_full
                          if load_full_w2 else intermediate_size_per_partition)

        self.is_k_full = (not self.actorder) or (
            intermediate_size_per_partition == intermediate_size_full)

        if self.strategy == "channel":
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = w2_scales_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        w13_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            num_groups_w13,
            2 * intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_scale)
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 num_groups_w2,
                                                 hidden_size,
                                                 dtype=params_dtype),
                                      requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_scale)
        set_weight_attrs(w2_scale, extra_weight_attrs)
        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})

        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)

        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)

        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)

        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)

        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        layer.a13_scale = None
        layer.a2_scale = None
        layer.marlin_state = GPTQMarlinState.REPACK

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Sort g_idx (if needed), repack weights and permute scales.

        For ``actorder == "group"`` the per-expert g_idx tensors are
        argsorted and stored together with their sort indices; otherwise
        all four g_idx-related parameters are replaced by empty
        ``(num_experts, 0)`` int32 tensors. The packed weights and the
        scales are then replaced by their Marlin-repacked counterparts and
        a workspace is stored on ``layer.workspace``.
        """
        num_experts = layer.w13_weight_g_idx.shape[0]
        device = layer.w13_weight_g_idx.device

        # when running models with grouped act order,
        # resort to g_idx values provided in checkpoint
        if self.actorder == "group":
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)

            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(
                    layer.w13_weight_g_idx[e]).to(torch.int32)
                w2_g_idx_sort_indices[e] = torch.argsort(
                    layer.w2_weight_g_idx[e]).to(torch.int32)
                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
                    w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
                    w2_g_idx_sort_indices[e]]

            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices",
                              w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices",
                              w2_g_idx_sort_indices)

        else:
            layer.w13_weight_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_weight_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )

        # The K size passed to the repack op is the packed dim scaled
        # back up by the packing factor.
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
        # Repack scales
        # NOTE(review): size_k is read from w13_weight_packed AFTER it was
        # replaced by the repacked tensor above - confirm this is intended.
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_weight_scale,
            size_k=layer.w13_weight_packed.shape[2],
            size_n=layer.w13_weight_scale.shape[2],
            group_size=self.group_size,
        )
        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_weight_scale,
            size_k=layer.w2_weight_scale.shape[1] *
            (self.group_size if self.group_size != -1 else self.packed_factor),
            size_n=layer.w2_weight_scale.shape[2],
            group_size=self.group_size,
        )
        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

        layer.workspace = marlin_make_workspace_new(device, 4)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Select experts for ``x`` and run the fused Marlin MoE op.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are not read by this implementation.

        Raises:
            NotImplementedError: If ``enable_eplb`` is true.
            AssertionError: If ``self.fused_experts`` is set or
                ``activation`` is not ``"silu"``.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsWNA16MarlinMoEMethod` yet.")

        assert activation == "silu", (
            f"{activation} not supported for Marlin MoE.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            workspace=layer.workspace,
            is_k_full=self.is_k_full)

actorder instance-attribute

actorder = actorder

group_size instance-attribute

group_size = group_size

num_bits instance-attribute

num_bits = num_bits

packed_factor instance-attribute

packed_factor = 32 // num_bits

quant_config instance-attribute

quant_config = quant_config

quant_type instance-attribute

quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

strategy instance-attribute

strategy = strategy

__init__

__init__(
    quant_config: CompressedTensorsConfig,
    moe: FusedMoEConfig,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Read the "Linear" weight scheme and validate it for Marlin MoE.

    Args:
        quant_config: Compressed-tensors config. The "weights" entry of
            ``target_scheme_map["Linear"]`` supplies ``num_bits``,
            ``strategy``, ``group_size``, ``actorder`` and ``symmetric``.
        moe: Fused-MoE configuration, forwarded to the base class.

    Raises:
        AssertionError: If the weight scheme is not symmetric.
        ValueError: If the quant format is not ``pack_quantized`` or
            ``num_bits`` is not in ``WNA16_SUPPORTED_BITS``.
    """
    super().__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    # Number of quantized values stored in one int32 word.
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    self.group_size = config.group_size
    self.actorder = config.actorder
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Adjacent literals are concatenated into ONE message; comma
        # separation would make ValueError carry a tuple of fragments.
        raise ValueError(
            "For Fused MoE layers, only "
            f"{CompressionFormat.pack_quantized.value} "
            "is supported for the following bits: "
            f"{WNA16_SUPPORTED_BITS}")
    self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Select experts for ``x`` and run the fused Marlin MoE op.

    Routing is delegated to ``FusedMoE.select_experts``; the expert
    computation is delegated to ``torch.ops.vllm.fused_marlin_moe`` using
    the packed weights, scales, g_idx tensors and workspace stored on
    ``layer``.

    ``expert_load_view``, ``logical_to_physical_map`` and
    ``logical_replica_count`` are not read by this implementation.

    Raises:
        NotImplementedError: If ``enable_eplb`` is true.
        AssertionError: If ``self.fused_experts`` is set or ``activation``
            is not ``"silu"``.
    """
    assert self.fused_experts is None

    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for "
            "`CompressedTensorsWNA16MarlinMoEMethod` yet.")

    assert activation == "silu", (
        f"{activation} not supported for Marlin MoE.")

    routing_weights, expert_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        x,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        None,
        None,
        layer.w13_weight_scale,
        layer.w2_weight_scale,
        router_logits,
        routing_weights,
        expert_ids,
        quant_type_id=self.quant_type.id,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        g_idx1=layer.w13_weight_g_idx,
        g_idx2=layer.w2_weight_g_idx,
        sort_indices1=layer.w13_g_idx_sort_indices,
        sort_indices2=layer.w2_g_idx_sort_indices,
        workspace=layer.workspace,
        is_k_full=self.is_k_full)
    return marlin_output

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register packed weights, scales, shapes and g_idx parameters.

    Args:
        layer: Module that receives the parameters.
        num_experts: Size of the leading dimension of every parameter.
        hidden_size: Hidden dimension of the layer.
        intermediate_size_per_partition: Intermediate dimension held by
            this partition.
        params_dtype: dtype used for the weight-scale parameters.
        **extra_weight_attrs: Attributes attached to each parameter. Must
            contain ``intermediate_size_full``, which is popped here and
            not attached to the parameters.

    Side effects:
        Sets ``self.is_k_full`` and, for the "channel" strategy, resets
        ``self.group_size`` to -1.
    """
    intermediate_size_full = extra_weight_attrs.pop(
        "intermediate_size_full")

    # Will transpose the loaded weight along the
    # intermediate and hidden dim sizes. Will
    # shard for TP along the transposed dims
    extra_weight_attrs.update(is_transposed=True,
                              quant_method=self.strategy)

    def _add_param(name: str, data: torch.Tensor) -> torch.nn.Parameter:
        # Wrap as a frozen parameter, register it, attach shared attrs.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)
        return param

    w13_out_size = 2 * intermediate_size_per_partition

    _add_param(
        "w13_weight_packed",
        torch.empty(num_experts,
                    hidden_size // self.packed_factor,
                    w13_out_size,
                    dtype=torch.int32))
    _add_param(
        "w2_weight_packed",
        torch.empty(num_experts,
                    intermediate_size_per_partition // self.packed_factor,
                    hidden_size,
                    dtype=torch.int32))

    # In the case where we have actorder/g_idx,
    # we do not partition the w2 scales
    # (evaluated before group_size may be reset to -1 below)
    load_full_w2 = self.actorder and self.group_size != -1
    if load_full_w2:
        w2_scales_size = intermediate_size_full
    else:
        w2_scales_size = intermediate_size_per_partition

    self.is_k_full = (not self.actorder) or (
        intermediate_size_per_partition == intermediate_size_full)

    if self.strategy == "channel":
        num_groups_w13 = 1
        num_groups_w2 = 1
        self.group_size = -1
    else:
        num_groups_w2 = w2_scales_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    _add_param(
        "w13_weight_scale",
        torch.ones(num_experts,
                   num_groups_w13,
                   w13_out_size,
                   dtype=params_dtype))
    w2_scale = _add_param(
        "w2_weight_scale",
        torch.ones(num_experts,
                   num_groups_w2,
                   hidden_size,
                   dtype=params_dtype))
    set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})

    _add_param("w2_weight_shape", torch.empty(num_experts, 2))
    _add_param("w13_weight_shape", torch.empty(num_experts, 2))

    _add_param(
        "w13_weight_g_idx",
        torch.empty(num_experts, hidden_size, dtype=torch.int32))
    _add_param(
        "w2_weight_g_idx",
        torch.empty(num_experts,
                    intermediate_size_per_partition,
                    dtype=torch.int32))
    _add_param(
        "w13_g_idx_sort_indices",
        torch.empty(num_experts, hidden_size, dtype=torch.int32))
    _add_param(
        "w2_g_idx_sort_indices",
        torch.empty(num_experts,
                    intermediate_size_per_partition,
                    dtype=torch.int32))

    layer.a13_scale = None
    layer.a2_scale = None
    layer.marlin_state = GPTQMarlinState.REPACK

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Sort g_idx (if needed), repack weights and permute scales.

    For ``actorder == "group"`` the per-expert g_idx tensors are argsorted
    and stored together with their sort indices; otherwise all four
    g_idx-related parameters are replaced by empty ``(num_experts, 0)``
    int32 tensors. The packed weights and the scales are then replaced by
    the outputs of ``ops.gptq_marlin_moe_repack`` and
    ``marlin_moe_permute_scales``, and a workspace is stored on
    ``layer.workspace``.
    """
    num_experts = layer.w13_weight_g_idx.shape[0]
    device = layer.w13_weight_g_idx.device

    # when running models with grouped act order,
    # resort to g_idx values provided in checkpoint
    if self.actorder == "group":
        w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
        w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
        w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
        w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)

        # Per expert: compute the sorting permutation, then gather the
        # g_idx values in that sorted order.
        for e in range(num_experts):
            w13_g_idx_sort_indices[e] = torch.argsort(
                layer.w13_weight_g_idx[e]).to(torch.int32)
            w2_g_idx_sort_indices[e] = torch.argsort(
                layer.w2_weight_g_idx[e]).to(torch.int32)
            w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
                w13_g_idx_sort_indices[e]]
            w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
                w2_g_idx_sort_indices[e]]

        replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
        replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
        replace_parameter(layer, "w13_g_idx_sort_indices",
                          w13_g_idx_sort_indices)
        replace_parameter(layer, "w2_g_idx_sort_indices",
                          w2_g_idx_sort_indices)

    else:
        # No grouped act order: swap in zero-width placeholders on the
        # same device for all g_idx-related parameters.
        layer.w13_weight_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
        layer.w2_weight_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )

    # The third argument scales the packed dim back up by the packing
    # factor (32 // num_bits values per int32 word).
    marlin_w13_qweight = ops.gptq_marlin_moe_repack(
        layer.w13_weight_packed,
        layer.w13_g_idx_sort_indices,
        layer.w13_weight_packed.shape[1] * self.packed_factor,
        layer.w13_weight_packed.shape[2],
        self.num_bits,
    )
    replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
    marlin_w2_qweight = ops.gptq_marlin_moe_repack(
        layer.w2_weight_packed,
        layer.w2_g_idx_sort_indices,
        layer.w2_weight_packed.shape[1] * self.packed_factor,
        layer.w2_weight_packed.shape[2],
        self.num_bits,
    )
    replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
    # Repack scales
    # NOTE(review): size_k below is read from w13_weight_packed AFTER it
    # was replaced by the repacked tensor - confirm this is intended.
    marlin_w13_scales = marlin_moe_permute_scales(
        s=layer.w13_weight_scale,
        size_k=layer.w13_weight_packed.shape[2],
        size_n=layer.w13_weight_scale.shape[2],
        group_size=self.group_size,
    )
    replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
    # For w2, size_k is derived from the scale tensor's group dimension
    # rather than from the packed weight.
    marlin_w2_scales = marlin_moe_permute_scales(
        s=layer.w2_weight_scale,
        size_k=layer.w2_weight_scale.shape[1] *
        (self.group_size if self.group_size != -1 else self.packed_factor),
        size_n=layer.w2_weight_scale.shape[2],
        group_size=self.group_size,
    )
    replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

    layer.workspace = marlin_make_workspace_new(device, 4)

CompressedTensorsWNA16MoEMethod

Bases: CompressedTensorsMoEMethod

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
    """Weight-only N-bit (WNA16) fused-MoE method for compressed-tensors.

    The method registers packed int32 expert weights, per-group scales and
    g_idx placeholders on the layer (``create_weights``), rearranges the
    packed weights and scales after loading
    (``process_weights_after_loading``), and finally routes tokens and runs
    the experts through ``fused_experts`` (``apply``).
    """

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        moe: FusedMoEConfig,
    ):
        """Read the weight scheme from ``quant_config`` and validate it.

        Args:
            quant_config: compressed-tensors config; its
                ``target_scheme_map["Linear"]["weights"]`` entry supplies
                ``num_bits``, ``strategy``, ``group_size``, ``actorder`` and
                ``symmetric``.
            moe: fused-MoE config forwarded to the base class.

        Raises:
            AssertionError: if the strategy is not ``"group"``, actorder is
                ``"group"``, or the scheme is not symmetric.
            ValueError: if the checkpoint format is not ``pack_quantized`` or
                ``num_bits`` is not in ``WNA16_SUPPORTED_BITS``.
        """
        super().__init__(moe)
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.num_bits = config.num_bits
        # Number of quantized values stored in one int32 word.
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy
        # channelwise is not supported by this kernel
        assert config.strategy == "group"
        self.group_size = config.group_size
        # grouped actorder isn't supported by this kernel
        assert config.actorder != "group"
        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        if not (self.quant_config.quant_format
                == CompressionFormat.pack_quantized.value
                and self.num_bits in WNA16_SUPPORTED_BITS):
            # Build ONE message via implicit string concatenation; separating
            # the pieces with commas would hand ValueError four arguments and
            # the error would print as a tuple of fragments.
            raise ValueError(
                "For Fused MoE layers, only "
                f"{CompressionFormat.pack_quantized.value} "
                "is supported for the following bits: "
                f"{WNA16_SUPPORTED_BITS}")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the quantized expert parameters on ``layer``.

        Every parameter is created with ``requires_grad=False`` and tagged
        with ``extra_weight_attrs`` (extended with ``is_transposed`` and
        ``quant_method``). Registered names and shapes:

        * ``w13_weight_packed``: int32,
          ``(num_experts, hidden_size // packed_factor,
          2 * intermediate_size_per_partition)``
        * ``w2_weight_packed``: int32,
          ``(num_experts, intermediate_size_per_partition // packed_factor,
          hidden_size)``
        * ``w13_weight_scale`` / ``w2_weight_scale``: ``params_dtype``, one
          row per quantization group
        * ``w13_weight_shape`` / ``w2_weight_shape``: ``(num_experts, 2)``
        * ``w13_weight_g_idx`` / ``w2_weight_g_idx`` and
          ``w13_g_idx_sort_indices`` / ``w2_g_idx_sort_indices``: int32

        ``layer.a13_scale`` and ``layer.a2_scale`` are set to ``None``.
        """

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size // self.packed_factor,
            2 * intermediate_size_per_partition,
            dtype=torch.int32),
                                        requires_grad=False)
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            intermediate_size_per_partition // self.packed_factor,
            hidden_size,
            dtype=torch.int32),
                                       requires_grad=False)
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w2_scales_size = intermediate_size_per_partition

        # NOTE(review): __init__ asserts strategy == "group", so with
        # assertions enabled the "channel" branch is not reached.
        if self.strategy == "channel":
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = w2_scales_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        w13_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            num_groups_w13,
            2 * intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_scale)
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 num_groups_w2,
                                                 hidden_size,
                                                 dtype=params_dtype),
                                      requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_scale)
        set_weight_attrs(w2_scale, extra_weight_attrs)
        # Extra loader attribute carried only by the w2 scale.
        set_weight_attrs(w2_scale, {"load_full_w2": False})

        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)

        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)

        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)

        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)

        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        layer.a13_scale = None
        layer.a2_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swap dims 1 and 2 of the packed weights and scales on ``layer``.

        The packed weights are additionally reinterpreted as ``uint8``; each
        result replaces the existing attribute as a new non-trainable
        ``torch.nn.Parameter``.
        """
        # Reconfigure packed weights and scales to match moe_wna16 format
        layer.w13_weight_packed = torch.nn.Parameter(
            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
                torch.uint8),
            requires_grad=False)
        layer.w2_weight_packed = torch.nn.Parameter(
            layer.w2_weight_packed.transpose(1,
                                             2).contiguous().view(torch.uint8),
            requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(
            layer.w13_weight_scale.transpose(1, 2).contiguous(),
            requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(
            layer.w2_weight_scale.transpose(1, 2).contiguous(),
            requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Route ``x`` to experts and run the quantized fused-MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        resulting weights/ids, the packed weights and the scales stored on
        ``layer`` are passed to ``fused_experts``, whose result is returned.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used here.

        Raises:
            AssertionError: if ``self.fused_experts`` is not ``None``.
            NotImplementedError: if ``enable_eplb`` is true.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError("EPLB not supported for "
                                      "`CompressedTensorsWNA16MoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        # No zero points are passed: __init__ only admits symmetric schemes.
        return fused_experts(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_int4_w4a16=self.num_bits == 4,
            use_int8_w8a16=self.num_bits == 8,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_zp=None,
            w2_zp=None,
            block_shape=[0, self.group_size])

group_size instance-attribute

group_size = group_size

num_bits instance-attribute

num_bits = num_bits

packed_factor instance-attribute

packed_factor = 32 // num_bits

quant_config instance-attribute

quant_config = quant_config

strategy instance-attribute

strategy = strategy

__init__

__init__(
    quant_config: CompressedTensorsConfig,
    moe: FusedMoEConfig,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """Read the weight scheme from ``quant_config`` and validate it.

    Args:
        quant_config: compressed-tensors config; its
            ``target_scheme_map["Linear"]["weights"]`` entry supplies
            ``num_bits``, ``strategy``, ``group_size``, ``actorder`` and
            ``symmetric``.
        moe: fused-MoE config forwarded to the base class.

    Raises:
        AssertionError: if the strategy is not ``"group"``, actorder is
            ``"group"``, or the scheme is not symmetric.
        ValueError: if the checkpoint format is not ``pack_quantized`` or
            ``num_bits`` is not in ``WNA16_SUPPORTED_BITS``.
    """
    super().__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    # Number of quantized values stored in one int32 word.
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    # channelwise is not supported by this kernel
    assert config.strategy == "group"
    self.group_size = config.group_size
    # grouped actorder isn't supported by this kernel
    assert config.actorder != "group"
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Build ONE message via implicit string concatenation; separating the
        # pieces with commas would hand ValueError four arguments and the
        # error would print as a tuple of fragments.
        raise ValueError(
            "For Fused MoE layers, only "
            f"{CompressionFormat.pack_quantized.value} "
            "is supported for the following bits: "
            f"{WNA16_SUPPORTED_BITS}")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Pick experts for ``x`` and evaluate them with ``fused_experts``.

    Routing is delegated to ``FusedMoE.select_experts``. The packed weights
    and scales are read from ``layer`` and handed to ``fused_experts``,
    whose result is returned unchanged.

    ``expert_load_view``, ``logical_to_physical_map`` and
    ``logical_replica_count`` are accepted but not used here.

    Raises:
        AssertionError: if ``self.fused_experts`` is not ``None``.
        NotImplementedError: if ``enable_eplb`` is true.
    """
    assert self.fused_experts is None

    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet.")

    # Imported lazily, after the guards above, exactly as before.
    from vllm.model_executor.layers.fused_moe import fused_experts

    routing_kwargs = dict(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )
    selected_weights, selected_ids = FusedMoE.select_experts(**routing_kwargs)

    # Exactly one of these flags is set for 4-bit / 8-bit weights.
    bit_width = self.num_bits
    is_int4 = bit_width == 4
    is_int8 = bit_width == 8
    quant_block = [0, self.group_size]

    return fused_experts(
        x,
        layer.w13_weight_packed,
        layer.w2_weight_packed,
        topk_weights=selected_weights,
        topk_ids=selected_ids,
        inplace=True,
        activation=activation,
        use_int4_w4a16=is_int4,
        use_int8_w8a16=is_int8,
        global_num_experts=global_num_experts,
        apply_router_weight_on_input=apply_router_weight_on_input,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        w1_zp=None,
        w2_zp=None,
        block_shape=quant_block,
    )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register the quantized expert parameters on ``layer``.

    Every parameter is non-trainable and tagged with ``extra_weight_attrs``
    (extended with ``is_transposed`` and ``quant_method``). Packed weights,
    g_idx tensors and sort indices are int32; scales use ``params_dtype``
    and start at one; the ``*_weight_shape`` tensors are
    ``(num_experts, 2)``. ``layer.a13_scale`` and ``layer.a2_scale`` are set
    to ``None``.
    """
    # Will transpose the loaded weight along the intermediate and hidden
    # dim sizes. Will shard for TP along the transposed dims.
    extra_weight_attrs["is_transposed"] = True
    extra_weight_attrs["quant_method"] = self.strategy

    def _add_param(name: str, data: torch.Tensor) -> torch.nn.Parameter:
        # Wrap, register under ``name`` and tag with the shared attributes.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)
        return param

    w13_out_dim = 2 * intermediate_size_per_partition

    # Packed weights: ``packed_factor`` values per int32 along dim 1.
    _add_param(
        "w13_weight_packed",
        torch.empty(num_experts,
                    hidden_size // self.packed_factor,
                    w13_out_dim,
                    dtype=torch.int32))
    _add_param(
        "w2_weight_packed",
        torch.empty(num_experts,
                    intermediate_size_per_partition // self.packed_factor,
                    hidden_size,
                    dtype=torch.int32))

    # One scale row per quantization group (a single row for "channel").
    if self.strategy == "channel":
        self.group_size = -1
        num_groups_w13 = num_groups_w2 = 1
    else:
        num_groups_w2 = intermediate_size_per_partition // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    _add_param(
        "w13_weight_scale",
        torch.ones(num_experts,
                   num_groups_w13,
                   w13_out_dim,
                   dtype=params_dtype))
    w2_scale = _add_param(
        "w2_weight_scale",
        torch.ones(num_experts,
                   num_groups_w2,
                   hidden_size,
                   dtype=params_dtype))
    # Extra loader attribute carried only by the w2 scale.
    set_weight_attrs(w2_scale, {"load_full_w2": False})

    for shape_name in ("w2_weight_shape", "w13_weight_shape"):
        _add_param(shape_name, torch.empty(num_experts, 2))

    # g_idx tensors first, then their sort indices; w13 before w2 each time.
    int32_vectors = (
        ("w13_weight_g_idx", hidden_size),
        ("w2_weight_g_idx", intermediate_size_per_partition),
        ("w13_g_idx_sort_indices", hidden_size),
        ("w2_g_idx_sort_indices", intermediate_size_per_partition),
    )
    for vec_name, width in int32_vectors:
        _add_param(vec_name,
                   torch.empty(num_experts, width, dtype=torch.int32))

    layer.a13_scale = None
    layer.a2_scale = None

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Swap dims 1 and 2 of the packed weights and scales on ``layer``.

    Packed weights are additionally reinterpreted as ``uint8``. Each result
    replaces the existing attribute as a new non-trainable
    ``torch.nn.Parameter``.
    """

    def _swapped(tensor: torch.Tensor) -> torch.Tensor:
        # Exchange dims 1 and 2 and materialize the new memory layout.
        return tensor.transpose(1, 2).contiguous()

    # Reconfigure packed weights and scales to match moe_wna16 format
    for weight_name in ("w13_weight_packed", "w2_weight_packed"):
        as_bytes = _swapped(getattr(layer, weight_name)).view(torch.uint8)
        setattr(layer, weight_name,
                torch.nn.Parameter(as_bytes, requires_grad=False))

    for scale_name in ("w13_weight_scale", "w2_weight_scale"):
        swapped_scale = _swapped(getattr(layer, scale_name))
        setattr(layer, scale_name,
                torch.nn.Parameter(swapped_scale, requires_grad=False))

GPTQMarlinState

Bases: Enum

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
class GPTQMarlinState(Enum):
    """Two-member state enumeration: ``REPACK`` and ``READY``.

    NOTE(review): presumably tracks whether GPTQ-Marlin weights still need
    repacking or are ready for use; confirm against the code that reads it.
    """

    # Explicit values equal what automatic numbering assigns (1, 2, ...).
    REPACK = 1
    READY = 2

READY class-attribute instance-attribute

READY = auto()

REPACK class-attribute instance-attribute

REPACK = auto()