Skip to content

llmcompressor.modeling.gemma4

Classes:

Gemma4TextExpertsList

Gemma4TextExpertsList(
    config: Gemma4TextConfig, original: Gemma4TextExperts
)

Bases: ModuleList

Unpacks the 3D expert parameter tensors into individual `Gemma4TextMLP` modules, so that each expert's weights are held in `nn.Linear` layers and can be targeted by quantization with `targets="Linear"`.

Source code in src/llmcompressor/modeling/gemma4.py
def __init__(self, config: Gemma4TextConfig, original: Gemma4TextExperts):
    """Build one ``Gemma4TextMLP`` per expert from ``original``'s packed weights.

    Args:
        config: Text config. ``num_experts`` sets how many expert MLPs are
            created, and ``moe_intermediate_size`` sets where each expert's
            packed gate/up tensor is split into gate and up halves.
        original: Packed experts module. Its ``gate_up_proj`` and
            ``down_proj`` parameters are indexed per expert along dim 0 and
            copied (cloned) into the new MLPs; ``original`` is not modified.
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextMLP

    # Keep the count in a local until nn.Module.__init__ (run by
    # ModuleList.__init__) has executed: PyTorch requires the parent
    # __init__ to be called before assigning attributes on a Module.
    num_experts = config.num_experts
    intermediate_size = config.moe_intermediate_size

    # The MLP weights are overwritten below, so construct them under
    # skip_weights_initialize (presumably skips default weight init).
    with skip_weights_initialize():
        super().__init__(
            [Gemma4TextMLP(config, layer_idx=0) for _ in range(num_experts)]
        )
    self.num_experts = num_experts

    gate_up_data = original.gate_up_proj.data  # [num_experts, 2*inter, hidden]
    down_data = original.down_proj.data  # [num_experts, hidden, inter]

    # NOTE(review): `.weight.data = ...` does not check shapes. This assumes
    # each Gemma4TextMLP built above has gate/up/down projections sized by
    # config.moe_intermediate_size. Confirm; otherwise the Linear layers'
    # in_features/out_features would disagree with their weights.
    for i, expert in enumerate(self):
        gate_up = gate_up_data[i]  # [2*intermediate, hidden]
        down = down_data[i]  # [hidden, intermediate]

        # gate_up_proj stores [gate; up] stacked along dim 0.
        # nn.Linear weight is [out_features, in_features].
        # clone() detaches the new weights from original's storage.
        expert.gate_proj.weight.data = (
            gate_up[:intermediate_size, :].clone().contiguous()
        )
        expert.up_proj.weight.data = (
            gate_up[intermediate_size:, :].clone().contiguous()
        )
        expert.down_proj.weight.data = down.clone().contiguous()

SequentialGemma4TextExperts

SequentialGemma4TextExperts(
    original: Gemma4TextExperts,
    config: Gemma4Config,
    calibrate_all_experts: bool = True,
)

Bases: MoECalibrationModule

Calibration version of Gemma4TextExperts that unpacks experts.

This module unpacks the packed expert weights (3D -> 2D) for calibration and keeps them in this unpacked form permanently, for vLLM compatibility.

Source code in src/llmcompressor/modeling/gemma4.py
def __init__(
    self,
    original: Gemma4TextExperts,
    config: Gemma4Config,
    calibrate_all_experts: bool = True,
):
    """Replace a packed ``Gemma4TextExperts`` with per-expert child modules.

    Args:
        original: Packed experts module to unpack. Its ``num_experts``,
            ``hidden_dim`` and ``intermediate_dim`` are copied onto this
            module.
        config: Top-level Gemma4 config; only ``config.text_config`` is
            used here, to build the unpacked expert list.
        calibrate_all_experts: Stored unchanged as
            ``self.calibrate_all_experts``.
    """
    super().__init__()
    self.calibrate_all_experts = calibrate_all_experts
    self.num_experts, self.hidden_dim, self.intermediate_dim = (
        original.num_experts,
        original.hidden_dim,
        original.intermediate_dim,
    )

    # Each unpacked expert becomes a child named "0", "1", ... of this
    # module itself. HF exposes these as layers[i].experts, so the result is
    # layers[i].experts.0 rather than the doubly nested
    # layers[i].experts.experts.0. The list is discarded afterwards.
    unpacked = Gemma4TextExpertsList(config.text_config, original)
    for index, expert_mlp in enumerate(unpacked):
        self.add_module(f"{index}", expert_mlp)