Skip to content

vllm_gaudi.ops.hpu_fused_moe

_MOE_COMPILE module-attribute

_MOE_COMPILE = getenv('HPU_FUSED_MOE', '1') == '1'

_orig_default_moe_runner_forward module-attribute

_orig_default_moe_runner_forward = forward

_orig_fused_moe_init module-attribute

_orig_fused_moe_init = __init__

HPUUnquantizedFusedMoEMethod

Bases: UnquantizedFusedMoEMethod

MoE method without quantization.

Source code in vllm_gaudi/ops/hpu_fused_moe.py
@UnquantizedFusedMoEMethod.register_oot
class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """MoE method without quantization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_dispatch_fn = get_config().use_dispatch_fn
        torch.hpu.synchronize()
        vllm_config = get_current_vllm_config()
        self.model_type = None
        self.is_mxfp4 = False
        if (vllm_config is not None and vllm_config.model_config is not None
                and vllm_config.model_config.hf_config is not None):
            self.model_type = vllm_config.model_config.hf_config.model_type
            if (hasattr(vllm_config.model_config.hf_config, "quantization_config")
                    and vllm_config.model_config.hf_config.quantization_config is not None):
                self.is_mxfp4 = vllm_config.model_config.hf_config.quantization_config.get("quant_method") == "mxfp4"

    def _select_monolithic(self) -> Callable:
        """Overriding base method"""
        return self.apply_monolithic

    @property
    def is_monolithic(self) -> bool:
        return True

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)
        # custom handling for HPU
        num_experts = layer.local_num_experts
        ep_shift = layer.ep_rank * num_experts
        has_bias = hasattr(layer, "w13_bias") and hasattr(layer, "w2_bias")

        experts_min, experts_max = ep_shift, num_experts + ep_shift - 1

        if layer.moe_config.dp_size > 1 and self.use_dispatch_fn:
            dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.moe_config.is_sequence_parallel)
        else:
            dispatch_fn = None

        bias = has_bias if has_bias is True else None

        is_bf16 = getattr(layer, "w13_weight", None) is not None and layer.w13_weight.dtype == torch.bfloat16

        model_config = None
        if getattr(layer, "vllm_config", None) is not None:
            model_config = getattr(layer.vllm_config, "model_config", None)

        is_unquantized = (model_config is None) or (not has_quant_config(model_config))

        cache_weight_lists = bool(is_bf16 and is_unquantized)

        # Pass cache flag into moe_op (requires ops.py __init__ signature update)
        layer.moe_op = VllmMixtureOfExpertsOp(layer.global_num_experts, num_experts, experts_min, experts_max, bias,
                                              dispatch_fn)

        for expert_id in range(layer.local_num_experts):
            layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id])
            layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id])
            if has_bias:
                layer.moe_op.w13_list[expert_id].set_bias(layer.w13_bias.data[expert_id])
                layer.moe_op.w2_list[expert_id].set_bias(layer.w2_bias.data[expert_id])

        # Build cache once AFTER weights/bias are set (BF16 + unquantized only)
        if cache_weight_lists and hasattr(layer.moe_op, "_cache_weight_lists"):
            layer.moe_op._cache_weight_lists()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        if self.model_type in ["gpt_oss"] and self.is_mxfp4:
            from vllm.utils.math_utils import round_up

            # Fused gate_up_proj (column parallel)
            w13_weight = torch.nn.Parameter(
                torch.zeros(num_experts,
                            2 * round_up(intermediate_size_per_partition, 32),
                            hidden_size,
                            dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight", w13_weight)
            set_weight_attrs(w13_weight, extra_weight_attrs)

            w13_bias = torch.nn.Parameter(
                torch.zeros(num_experts, 2 * round_up(intermediate_size_per_partition, 32), dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)

            # down_proj (row parallel)
            w2_weight = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, round_up(intermediate_size_per_partition, 32),
                            dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_weight", w2_weight)
            set_weight_attrs(w2_weight, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, dtype=params_dtype), requires_grad=False)
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)
        else:
            super().create_weights(layer, num_experts, hidden_size, intermediate_size_per_partition, params_dtype,
                                   **extra_weight_attrs)

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        **kwargs,
    ):
        input_shape = x.shape
        x = x.view(-1, x.shape[-1])
        if layer.use_grouped_topk or getattr(layer, "custom_routing_function", None) is not None:
            topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits)
        else:
            import torch.nn.functional as F

            if self.model_type == "gpt_oss":
                topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
                topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
            else:
                topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
                topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
                topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
            topk_weights = topk_weights.to(x.dtype)

        if not layer.use_grouped_topk:
            topk_ids = topk_ids.to(torch.int64)
            topk_weights = topk_weights.to(x.dtype)

        if layer.moe_config.dp_size > 1:
            dp_metadata = get_hpu_dp_metadata()
            if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn):
                hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None
                x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)

            topk_ids_across_dp = dp_metadata.topk_ids_across_dp if dp_metadata is not None else None
            topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)

            topk_weights_across_dp = dp_metadata.topk_weights_across_dp if dp_metadata is not None else None
            topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)

        topk_ids = topk_ids.view(-1, topk_ids.shape[-1])
        topk_weights = topk_weights.view(-1, topk_weights.shape[-1])

        output = layer.moe_op(
            x,
            topk_ids,
            topk_weights,
            permuted_weights=True,
            activation=_normalize_moe_activation(layer.activation),
        )
        if layer.moe_config.dp_size > 1:
            return output.view(*(output.size(0), *input_shape[1:]))
        else:
            return output.view(*input_shape)

    def forward_oot(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        **kwargs,
    ):
        input_shape = x.shape
        x = x.view(-1, x.shape[-1])
        if layer.use_grouped_topk or getattr(layer, "custom_routing_function", None) is not None:
            topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits)
        else:
            import torch.nn.functional as F

            if self.model_type is not None and self.model_type in ["gpt_oss"]:
                topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
                topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
            else:
                topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
                topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
                topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
            topk_weights = topk_weights.to(x.dtype)

        if not layer.use_grouped_topk:
            topk_ids = topk_ids.to(torch.int64)
            topk_weights = topk_weights.to(x.dtype)

        if layer.moe_config.dp_size > 1:
            dp_metadata = get_hpu_dp_metadata()
            if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn):
                hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None
                x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)

            topk_ids_across_dp = dp_metadata.topk_ids_across_dp if dp_metadata is not None else None
            topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)

            topk_weights_across_dp = dp_metadata.topk_weights_across_dp if dp_metadata is not None else None
            topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)

        topk_ids = topk_ids.view(-1, topk_ids.shape[-1])
        topk_weights = topk_weights.view(-1, topk_weights.shape[-1])

        if self.model_type in ["gpt_oss"]:
            return layer.moe_op(
                x,
                topk_ids.to(torch.int64),
                topk_weights.to(x.dtype),
                permuted_weights=True,
                activation=_normalize_moe_activation(layer.activation),
            ).view(*input_shape)

        output = layer.moe_op(
            x,
            topk_ids,
            topk_weights,
            permuted_weights=True,
            activation=_normalize_moe_activation(layer.activation),
        )
        if layer.moe_config.dp_size > 1:
            return output.view(*(output.size(0), *input_shape[1:]))
        else:
            return output.view(*input_shape)

is_monolithic property

is_monolithic: bool

is_mxfp4 instance-attribute

is_mxfp4 = False

model_type instance-attribute

model_type = None

use_dispatch_fn instance-attribute

use_dispatch_fn = use_dispatch_fn

__init__

__init__(*args, **kwargs)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.use_dispatch_fn = get_config().use_dispatch_fn
    torch.hpu.synchronize()
    vllm_config = get_current_vllm_config()
    self.model_type = None
    self.is_mxfp4 = False
    if (vllm_config is not None and vllm_config.model_config is not None
            and vllm_config.model_config.hf_config is not None):
        self.model_type = vllm_config.model_config.hf_config.model_type
        if (hasattr(vllm_config.model_config.hf_config, "quantization_config")
                and vllm_config.model_config.hf_config.quantization_config is not None):
            self.is_mxfp4 = vllm_config.model_config.hf_config.quantization_config.get("quant_method") == "mxfp4"

_select_monolithic

_select_monolithic() -> Callable

Overriding base method

Source code in vllm_gaudi/ops/hpu_fused_moe.py
def _select_monolithic(self) -> Callable:
    """Overriding base method"""
    return self.apply_monolithic

apply_monolithic

apply_monolithic(
    layer: FusedMoE,
    x: Tensor,
    router_logits: Tensor,
    **kwargs,
)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def apply_monolithic(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    **kwargs,
):
    input_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if layer.use_grouped_topk or getattr(layer, "custom_routing_function", None) is not None:
        topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits)
    else:
        import torch.nn.functional as F

        if self.model_type == "gpt_oss":
            topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
            topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
        else:
            topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
            topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)

    if not layer.use_grouped_topk:
        topk_ids = topk_ids.to(torch.int64)
        topk_weights = topk_weights.to(x.dtype)

    if layer.moe_config.dp_size > 1:
        dp_metadata = get_hpu_dp_metadata()
        if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn):
            hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None
            x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)

        topk_ids_across_dp = dp_metadata.topk_ids_across_dp if dp_metadata is not None else None
        topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)

        topk_weights_across_dp = dp_metadata.topk_weights_across_dp if dp_metadata is not None else None
        topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)

    topk_ids = topk_ids.view(-1, topk_ids.shape[-1])
    topk_weights = topk_weights.view(-1, topk_weights.shape[-1])

    output = layer.moe_op(
        x,
        topk_ids,
        topk_weights,
        permuted_weights=True,
        activation=_normalize_moe_activation(layer.activation),
    )
    if layer.moe_config.dp_size > 1:
        return output.view(*(output.size(0), *input_shape[1:]))
    else:
        return output.view(*input_shape)

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):

    if self.model_type in ["gpt_oss"] and self.is_mxfp4:
        from vllm.utils.math_utils import round_up

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(num_experts,
                        2 * round_up(intermediate_size_per_partition, 32),
                        hidden_size,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w13_bias = torch.nn.Parameter(
            torch.zeros(num_experts, 2 * round_up(intermediate_size_per_partition, 32), dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_bias", w13_bias)
        set_weight_attrs(w13_bias, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, round_up(intermediate_size_per_partition, 32),
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w2_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, dtype=params_dtype), requires_grad=False)
        layer.register_parameter("w2_bias", w2_bias)
        set_weight_attrs(w2_bias, extra_weight_attrs)
    else:
        super().create_weights(layer, num_experts, hidden_size, intermediate_size_per_partition, params_dtype,
                               **extra_weight_attrs)

forward_oot

forward_oot(
    layer: FusedMoE,
    x: Tensor,
    router_logits: Tensor,
    **kwargs,
)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    **kwargs,
):
    input_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if layer.use_grouped_topk or getattr(layer, "custom_routing_function", None) is not None:
        topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits)
    else:
        import torch.nn.functional as F

        if self.model_type is not None and self.model_type in ["gpt_oss"]:
            topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
            topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
        else:
            topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
            topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)

    if not layer.use_grouped_topk:
        topk_ids = topk_ids.to(torch.int64)
        topk_weights = topk_weights.to(x.dtype)

    if layer.moe_config.dp_size > 1:
        dp_metadata = get_hpu_dp_metadata()
        if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn):
            hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None
            x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)

        topk_ids_across_dp = dp_metadata.topk_ids_across_dp if dp_metadata is not None else None
        topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)

        topk_weights_across_dp = dp_metadata.topk_weights_across_dp if dp_metadata is not None else None
        topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)

    topk_ids = topk_ids.view(-1, topk_ids.shape[-1])
    topk_weights = topk_weights.view(-1, topk_weights.shape[-1])

    if self.model_type in ["gpt_oss"]:
        return layer.moe_op(
            x,
            topk_ids.to(torch.int64),
            topk_weights.to(x.dtype),
            permuted_weights=True,
            activation=_normalize_moe_activation(layer.activation),
        ).view(*input_shape)

    output = layer.moe_op(
        x,
        topk_ids,
        topk_weights,
        permuted_weights=True,
        activation=_normalize_moe_activation(layer.activation),
    )
    if layer.moe_config.dp_size > 1:
        return output.view(*(output.size(0), *input_shape[1:]))
    else:
        return output.view(*input_shape)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    super().process_weights_after_loading(layer)
    # custom handling for HPU
    num_experts = layer.local_num_experts
    ep_shift = layer.ep_rank * num_experts
    has_bias = hasattr(layer, "w13_bias") and hasattr(layer, "w2_bias")

    experts_min, experts_max = ep_shift, num_experts + ep_shift - 1

    if layer.moe_config.dp_size > 1 and self.use_dispatch_fn:
        dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.moe_config.is_sequence_parallel)
    else:
        dispatch_fn = None

    bias = has_bias if has_bias is True else None

    is_bf16 = getattr(layer, "w13_weight", None) is not None and layer.w13_weight.dtype == torch.bfloat16

    model_config = None
    if getattr(layer, "vllm_config", None) is not None:
        model_config = getattr(layer.vllm_config, "model_config", None)

    is_unquantized = (model_config is None) or (not has_quant_config(model_config))

    cache_weight_lists = bool(is_bf16 and is_unquantized)

    # Pass cache flag into moe_op (requires ops.py __init__ signature update)
    layer.moe_op = VllmMixtureOfExpertsOp(layer.global_num_experts, num_experts, experts_min, experts_max, bias,
                                          dispatch_fn)

    for expert_id in range(layer.local_num_experts):
        layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id])
        layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id])
        if has_bias:
            layer.moe_op.w13_list[expert_id].set_bias(layer.w13_bias.data[expert_id])
            layer.moe_op.w2_list[expert_id].set_bias(layer.w2_bias.data[expert_id])

    # Build cache once AFTER weights/bias are set (BF16 + unquantized only)
    if cache_weight_lists and hasattr(layer.moe_op, "_cache_weight_lists"):
        layer.moe_op._cache_weight_lists()

_hpu_fused_moe_init

_hpu_fused_moe_init(self, *args, **kwargs)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def _hpu_fused_moe_init(self, *args, **kwargs):
    _orig_fused_moe_init(self, *args, **kwargs)
    if hasattr(self, "runner"):
        object.__setattr__(self.runner, "_hpu_layer_ref", self)
        if self.runner.gate is not None:
            object.__setattr__(self.runner, "_hpu_gate_ref", self.runner.gate)

_normalize_moe_activation

_normalize_moe_activation(activation)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def _normalize_moe_activation(activation):
    return activation.value if isinstance(activation, Enum) else activation

_patched_default_moe_runner_forward

_patched_default_moe_runner_forward(self, *args, **kwargs)
Source code in vllm_gaudi/ops/hpu_fused_moe.py
def _patched_default_moe_runner_forward(self, *args, **kwargs):
    if _MOE_COMPILE:
        return patched_fused_moe_forward(self, *args, **kwargs)
    return _orig_default_moe_runner_forward(self, *args, **kwargs)

create_fused_moe_router

create_fused_moe_router(
    top_k: int,
    global_num_experts: int,
    renormalize: bool = True,
    indices_type_getter: Callable[[], dtype | None]
    | None = None,
    use_grouped_topk: bool = False,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    scoring_func: str = "softmax",
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Tensor | None = None,
    custom_routing_function: Callable | None = None,
    enable_eplb: bool = False,
    eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
    zero_expert_type: str | None = None,
    num_logical_experts: int | None = None,
    hash_indices_table: Tensor | None = None,
) -> FusedMoERouter

Factory function to create the appropriate FusedMoERouter subclass based on the provided parameters.

The selection logic follows this priority order: 1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set 2. ZeroExpertRouter - if zero_expert_type is not None 3. GroupedTopKRouter - if use_grouped_topk is True 4. CustomRoutingRouter - if custom_routing_function is not None 5. FusedTopKBiasRouter - if e_score_correction_bias is not None 6. FusedTopKRouter - default fallback

Common arguments

top_k: Number of experts to select per token global_num_experts: Total number of experts in the model renormalize: Whether to renormalize the routing weights indices_type_getter: Function to get the desired indices dtype

Grouped topk arguments

use_grouped_topk: Whether to use grouped top-k routing num_expert_group: Number of expert groups (for grouped routing) topk_group: Top-k within each group (for grouped routing) scoring_func: Scoring function to use ("softmax" or "sigmoid") num_fused_shared_experts: Number of fused shared experts (for ROCm AITER)

Grouped topk and fused topk bias arguments

routed_scaling_factor: Scaling factor for routed weights e_score_correction_bias: Optional bias correction for expert scores

Custom routing arguments

custom_routing_function: Optional custom routing function

EPLB arguments

enable_eplb: Whether EPLB is enabled eplb_state: EPLB (Expert Parallelism Load Balancing) state

Zero expert arguments

zero_expert_type: Type of zero expert (e.g. identity). If not None, creates a ZeroExpertRouter. num_logical_experts: Number of real (non-zero) experts. Required when zero_expert_type is not None.

Hash Indices Table

hash_indices_table: Used to map input_ids to experts, needed for Deepseek V4

Returns:

Type Description
FusedMoERouter

An instance of the appropriate FusedMoERouter subclass

Source code in vllm_gaudi/ops/hpu_fused_moe.py
def create_fused_moe_router(
    # common parameters
    top_k: int,
    global_num_experts: int,
    renormalize: bool = True,
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    # grouped topk parameters
    use_grouped_topk: bool = False,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    scoring_func: str = "softmax",
    num_fused_shared_experts: int = 0,
    # grouped topk + fused topk bias parameters
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    # custom routing parameters
    custom_routing_function: Callable | None = None,
    # eplb parameters
    enable_eplb: bool = False,
    eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
    # zero expert parameters
    zero_expert_type: str | None = None,
    num_logical_experts: int | None = None,
    hash_indices_table: torch.Tensor | None = None,
) -> FusedMoERouter:
    """
    Factory function to create the appropriate FusedMoERouter subclass based on
    the provided parameters.

    The selection logic follows this priority order:
    1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
    2. ZeroExpertRouter - if zero_expert_type is not None
    3. GroupedTopKRouter - if use_grouped_topk is True
    4. CustomRoutingRouter - if custom_routing_function is not None
    5. FusedTopKBiasRouter - if e_score_correction_bias is not None
    6. FusedTopKRouter - default fallback

    Common arguments:
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts in the model
        renormalize: Whether to renormalize the routing weights
        indices_type_getter: Function to get the desired indices dtype

    Grouped topk arguments:
        use_grouped_topk: Whether to use grouped top-k routing
        num_expert_group: Number of expert groups (for grouped routing)
        topk_group: Top-k within each group (for grouped routing)
        scoring_func: Scoring function to use ("softmax" or "sigmoid")
        num_fused_shared_experts: Number of fused shared experts (for ROCm AITER)

    Grouped topk and fused topk bias arguments:
        routed_scaling_factor: Scaling factor for routed weights
        e_score_correction_bias: Optional bias correction for expert scores

    Custom routing arguments:
        custom_routing_function: Optional custom routing function

    EPLB arguments:
        enable_eplb: Whether EPLB is enabled
        eplb_state: EPLB (Expert Parallelism Load Balancing) state

    Zero expert arguments:
        zero_expert_type: Type of zero expert (e.g. identity). If not None,
            creates a ZeroExpertRouter.
        num_logical_experts: Number of real (non-zero) experts. Required when
            zero_expert_type is not None.

    Hash Indices Table:
        hash_indices_table: Used to map input_ids to experts, needed for
            Deepseek V4

    Returns:
        An instance of the appropriate FusedMoERouter subclass
    """

    routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
    if routing_strategy != "":
        return RoutingSimulatorRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    if zero_expert_type is not None:
        assert num_logical_experts is not None, "num_logical_experts is required when zero_expert_type is set"
        assert e_score_correction_bias is not None, "e_score_correction_bias is required when zero_expert_type is set"
        return ZeroExpertRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            e_score_correction_bias=e_score_correction_bias,
            num_logical_experts=num_logical_experts,
            zero_expert_type=zero_expert_type,
            scoring_func=scoring_func,
            renormalize=renormalize,
            routed_scaling_factor=routed_scaling_factor,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    if use_grouped_topk:
        assert custom_routing_function is None
        if num_expert_group is None or topk_group is None:
            raise ValueError("num_expert_group and topk_group must be provided when use_grouped_topk is True")
        grouped_topk_router = GroupedTopKRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            renormalize=renormalize,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            num_fused_shared_experts=num_fused_shared_experts,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        return grouped_topk_router

    if custom_routing_function is not None:
        return CustomRoutingRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            custom_routing_function=custom_routing_function,
            renormalize=renormalize,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"]

    if e_score_correction_bias is not None or hash_indices_table is not None:
        return FusedTopKBiasRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            e_score_correction_bias=e_score_correction_bias,
            scoring_func=scoring_func,
            renormalize=renormalize,
            routed_scaling_factor=routed_scaling_factor,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
            hash_indices_table=hash_indices_table,
        )

    return FusedTopKRouter(
        top_k=top_k,
        global_num_experts=global_num_experts,
        eplb_state=eplb_state,
        renormalize=renormalize,
        scoring_func=scoring_func,
        enable_eplb=enable_eplb,
        indices_type_getter=indices_type_getter,
    )

get_compressed_expert_map

get_compressed_expert_map(expert_map: Tensor) -> str

Compresses the expert map by removing any -1 entries.

This implementation uses a standard Python loop, which is compatible with graph compilation modes that do not support dynamic shapes resulting from operations like torch.where.

Parameters:

Name Type Description Default
expert_map Tensor

A tensor of shape (global_num_experts,) mapping a global expert index to its local index. Contains -1 for experts that are not assigned to the current rank.

required

Returns:

Name Type Description
str str

A string mapping from local to global index,

str

ordered by global index. (e.g., "0->5, 1->12, 2->23")

Source code in vllm_gaudi/ops/hpu_fused_moe.py
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
    """
    Compresses the expert map by removing any -1 entries.

    This implementation uses a standard Python loop, which is compatible with
    graph compilation modes that do not support dynamic shapes resulting from
    operations like `torch.where`.

    Args:
        expert_map (torch.Tensor): A tensor of shape (global_num_experts,)
            mapping a global expert index to its local index. Contains -1 for
            experts that are not assigned to the current rank.

    Returns:
        str: A string mapping from local to global index,
        ordered by global index.
            (e.g., "0->5, 1->12, 2->23")
    """
    mappings = []
    # A standard loop over a tensor with a known shape is statically analyzable.
    # `enumerate` provides the global_index (the position in the tensor) and
    # `local_index_tensor` (the value at that position).
    for global_index, local_index_tensor in enumerate(expert_map):
        local_index = local_index_tensor.item()
        # We only build strings for valid experts (those not marked as -1).
        if local_index != -1:
            mappings.append(f"{local_index}->{global_index}")

    return ", ".join(mappings)

patched_fused_moe_forward

patched_fused_moe_forward(
    self,
    hidden_states: Tensor,
    router_logits: Tensor,
    input_ids: Tensor | None = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]

Patched forward that avoids graph breaks from ForwardContext lookups and dynamo per-layer string guards.

Instead of calling forward_dispatch (which uses get_layer_from_name, ensure_moe_quant_config_init, and _sequence_parallel_context — all of which access ForwardContext and cause torch.compile graph breaks), we use a layer reference stashed on the runner at FusedMoE.init time (self._hpu_layer_ref) and bypass _forward_impl for dp_size==1, calling _apply_quant_method + _maybe_combine directly. This also bypasses self.layer_name (a per-layer string) so dynamo no longer emits per-layer string guards that trigger recompilation.

The post-forward reduction sequence mirrors upstream MoERunner.forward (vllm/model_executor/layers/fused_moe/runner/ moe_runner.py) so we stay in sync with the new shared/fused output combination logic introduced by upstream PR #35949.

Source code in vllm_gaudi/ops/hpu_fused_moe.py
def patched_fused_moe_forward(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    input_ids: torch.Tensor | None = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Patched forward that avoids graph breaks from ForwardContext lookups
    and dynamo per-layer string guards.

    Instead of calling forward_dispatch (which uses get_layer_from_name,
    ensure_moe_quant_config_init, and _sequence_parallel_context — all of
    which access ForwardContext and cause torch.compile graph breaks), we
    use a layer reference stashed on the runner at FusedMoE.__init__ time
    (self._hpu_layer_ref) and bypass _forward_impl for dp_size==1,
    calling _apply_quant_method + _maybe_combine directly. This also
    bypasses self.layer_name (a per-layer string) so dynamo no longer
    emits per-layer string guards that trigger recompilation.

    The post-forward reduction sequence mirrors upstream
    MoERunner.forward (vllm/model_executor/layers/fused_moe/runner/
    moe_runner.py) so we stay in sync with the new shared/fused
    output combination logic introduced by upstream PR #35949.
    """
    hidden_states, shared_experts_input = self.apply_routed_input_transform(hidden_states)
    hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(shared_experts_input, hidden_states)

    if self.moe_config.dp_size == 1:
        # Bypass _forward_impl entirely for dp_size==1 to eliminate
        # graph breaks from _sequence_parallel_context() (which calls
        # get_forward_context()), skip the no-op _maybe_dispatch(), and
        # avoid double gate / stream-sync calls that _forward_impl
        # would redundantly repeat.
        if self.moe_config.pcp_size > 1:
            raise RuntimeError("dp_size==1 fast path does not support pcp_size > 1")
        layer = self._hpu_layer_ref
        layer.ensure_moe_quant_config_init()
        self._maybe_sync_shared_experts_stream(shared_experts_input)
        gate = self.gate or getattr(self, "_hpu_gate_ref", None)
        if gate is not None:
            router_logits, _ = gate(hidden_states)
        shared_output, fused_hidden = self._apply_quant_method(
            layer=layer,
            hidden_states=hidden_states,
            router_logits=router_logits,
            shared_experts_input=shared_experts_input,
            input_ids=input_ids,
        )
        result = self._maybe_combine(shared_output, fused_hidden)
    else:
        result = self._forward_entry(hidden_states, router_logits, shared_experts_input, input_ids,
                                     self._encode_layer_name(), self._trtllm_mxfp4_unpadded_dim())

    # Mirror upstream MoERunner.forward post-_forward_entry pipeline.
    if isinstance(result, tuple):
        shared_output, fused_output = result
    else:
        shared_output, fused_output = None, result

    shared_output = self._maybe_reduce_shared_expert_output(shared_output)
    shared_output, fused_output = self._maybe_apply_routed_scale_to_output(shared_output, fused_output)
    fused_output = self.apply_routed_output_transform(fused_output)

    combined = (shared_output + fused_output) if shared_output is not None else fused_output

    combined = self._maybe_reduce_final_output(combined, og_hidden_dims)
    return self._maybe_add_zero_expert_output(combined)