Skip to content

vllm.model_executor.models.mimo_v2

Classes:

MiMoV2Model

Bases: Module

Source code in vllm/model_executor/models/mimo_v2.py
@support_torch_compile
class MiMoV2Model(nn.Module):
    """MiMo-V2 decoder backbone: token embedding, decoder layers, final norm.

    Under pipeline parallelism each rank materializes only its own slice of
    decoder layers (``self.start_layer`` to ``self.end_layer``, as returned by
    ``make_layers``). The token embedding is built on the first PP rank (and
    on the last rank when ``config.tie_word_embeddings`` is set), the final
    norm on the last PP rank; other ranks hold ``PPMissingLayer`` placeholders.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.get_text_config()
        quant_config = vllm_config.quant_config
        eplb_config = vllm_config.parallel_config.eplb_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size
        # Forwarded to the fused-MoE expert parameter mapping in
        # get_expert_mapping().
        self.num_redundant_experts = eplb_config.num_redundant_experts

        # The first PP rank embeds input ids. The last rank also keeps the
        # embedding when word embeddings are tied (presumably so a tied LM
        # head outside this module can share it).
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiMoV2FlashDecoderLayer(
                vllm_config=vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        # Non-last PP ranks hand off both the hidden states and the pending
        # residual to the next rank.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this PP rank's decoder layers.

        On the first rank the input is ``inputs_embeds`` when given, otherwise
        the embedding of ``input_ids``; later ranks read ``hidden_states`` and
        ``residual`` from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
            on non-last ranks; the final-normed hidden states on the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # The norm folds in the pending residual; its second output is unused.
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the fused-MoE checkpoint-to-parameter name mapping.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)`` and
        covers weights, fp8 weight scales and fp8 activation scales for the
        ``gate_proj`` / ``up_proj`` / ``down_proj`` checkpoint names.
        """
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this rank's parameters.

        Each ``(name, tensor)`` pair is routed, in order, through:

        1. skip rules (rotary caches and any name containing ``"mtp"``),
        2. the fused-MoE expert mapping,
        3. the fp8 fused ``qkv_proj`` path (weight + ``weight_scale_inv``),
        4. the stacked q/k/v and gate/up mapping,
        5. a direct load by (possibly kv-scale-remapped) name.

        Tensors with no matching parameter on this rank are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        # Pro-format fused qkv_proj arrives as two tensors (weight and
        # weight_scale_inv). Store them per-layer so that they can be
        # sharded together.
        # NOTE(review): entries still pending when the loop ends (a weight
        # whose scale never arrived, or vice versa) are silently dropped.
        pending_fp8_qkv_proj: dict[str, dict[str, torch.Tensor]] = {}
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue
            if "mtp" in name:
                continue

            expert_matched = False
            for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                if weight_name not in name:
                    continue

                name_rewritten = name.replace(weight_name, param_name)

                # Skip layers owned by another PP rank, and checkpoint tensors
                # (e.g. biases) that have no counterpart in this model.
                if (
                    is_pp_missing_parameter(name_rewritten, self)
                    or name_rewritten not in params_dict
                ):
                    continue

                param = params_dict[name_rewritten]
                weight_loader = param.weight_loader

                weight_loader(
                    param,
                    loaded_weight,
                    name_rewritten,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                loaded_params.add(name_rewritten)
                expert_matched = True
                break

            if expert_matched:
                continue
            # Support fused qkv_proj checkpoint (Pro format)
            if self._try_load_fp8_qkv_proj(
                name,
                loaded_weight,
                pending_fp8_qkv_proj,
                params_dict,
                loaded_params,
                tp_rank,
                tp_size,
            ):
                continue
            stacked_matched = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name_rewritten = name.replace(weight_name, param_name)

                # Skip layers owned by another PP rank, and checkpoint tensors
                # (e.g. biases) that have no counterpart in this model.
                if (
                    is_pp_missing_parameter(name_rewritten, self)
                    or name_rewritten not in params_dict
                ):
                    continue

                param = params_dict[name_rewritten]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name_rewritten)

                stacked_matched = True
                break

            if stacked_matched:
                continue

            if name.endswith(".bias") and name not in params_dict:
                continue

            orig_name = name
            mapped_name = maybe_remap_kv_scale_name(name, params_dict)
            name = mapped_name if mapped_name is not None else orig_name

            if name not in params_dict:
                continue

            param = params_dict[name]

            if "attention_sink_bias" in name:
                # Shard along dim 0 by contiguous equal slices per TP rank.
                # NOTE(review): assumes the dim-0 size is divisible by tp_size;
                # any remainder rows are dropped.
                total_heads = loaded_weight.shape[0]
                heads_per_rank = total_heads // tp_size
                head_start = tp_rank * heads_per_rank
                narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)

                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            else:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params

    def _try_load_fp8_qkv_proj(
        self,
        name: str,
        tensor: torch.Tensor,
        fp8_qkv_proj_dict: dict[str, dict[str, torch.Tensor]],
        params_dict: dict[str, torch.nn.Parameter],
        loaded_params: set[str],
        tp_rank: int,
        tp_size: int,
    ) -> bool:
        """
        The fused fp8 QKV projection weights and scale are stored separately.
        Special care must be taken while sharding these tensors across TP ranks.
        See _shard_fp8_qkv_proj for more details.

        The first of the pair to arrive is parked in ``fp8_qkv_proj_dict``
        (keyed by the ``...qkv_proj`` prefix); once both are present they are
        sharded together, copied into ``params_dict`` and recorded in
        ``loaded_params``.

        Returns:
            True if ``tensor`` was an fp8 qkv_proj weight/scale and was consumed
            (caller should skip it); False otherwise, so the caller falls
            through to its normal loading path.
        """
        is_weight = (
            name.endswith("qkv_proj.weight") and tensor.dtype == torch.float8_e4m3fn
        )
        is_scale = name.endswith("qkv_proj.weight_scale_inv")
        if not is_weight and not is_scale:
            # Weight is not in FP8 format. Ignore.
            return False

        if is_pp_missing_parameter(name, self):
            # This qkv_proj is for a layer not on this PP rank.
            return True

        prefix, qkv_kind = name.rsplit(".", 1)
        entry = fp8_qkv_proj_dict.setdefault(prefix, {})
        entry[qkv_kind] = tensor
        if "weight" not in entry or "weight_scale_inv" not in entry:
            # Still waiting for the other param.
            return True
        del fp8_qkv_proj_dict[prefix]

        # Get self_attn module, which is a parent of qkv_proj.
        attn = self.get_submodule(prefix.rsplit(".", 1)[0])

        # Shard the qkv_proj per-rank.
        w_rank, s_rank = _shard_fp8_qkv_proj(
            entry["weight"],
            entry["weight_scale_inv"],
            num_heads=attn.total_num_heads,
            num_kv_heads=attn.total_num_kv_heads,
            head_dim=attn.head_dim,
            v_head_dim=attn.v_head_dim,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        sharded = {"weight": w_rank, "weight_scale_inv": s_rank}
        for kind, shard in sharded.items():
            param_name = f"{prefix}.{kind}"
            param = params_dict[param_name]
            # Drop any trailing rows beyond the parameter's allocated size.
            if shard.shape[0] > param.shape[0]:
                shard = shard[: param.shape[0]]
            default_weight_loader(param, shard)
            loaded_params.add(param_name)
        return True

_try_load_fp8_qkv_proj(name, tensor, fp8_qkv_proj_dict, params_dict, loaded_params, tp_rank, tp_size)

The fused fp8 QKV projection weights and scale are stored separately. Special care must be taken while sharding these tensors across TP ranks. See _shard_fp8_qkv_proj for more details.

Returns:

  • bool

    True if tensor was an fp8 qkv_proj weight/scale and was consumed (caller should skip it); False otherwise, so the caller falls through to its normal loading path.

Source code in vllm/model_executor/models/mimo_v2.py
def _try_load_fp8_qkv_proj(
    self,
    name: str,
    tensor: torch.Tensor,
    fp8_qkv_proj_dict: dict[str, dict[str, torch.Tensor]],
    params_dict: dict[str, torch.nn.Parameter],
    loaded_params: set[str],
    tp_rank: int,
    tp_size: int,
) -> bool:
    """
    The fused fp8 QKV projection weights and scale are stored separately.
    Special care must be taken while sharding these tensors across TP ranks.
    See _shard_fp8_qkv_proj for more details.

    Returns:
        True if ``tensor`` was an fp8 qkv_proj weight/scale and was consumed
        (caller should skip it); False otherwise, so the caller falls
        through to its normal loading path.
    """
    is_weight = (
        name.endswith("qkv_proj.weight") and tensor.dtype == torch.float8_e4m3fn
    )
    is_scale = name.endswith("qkv_proj.weight_scale_inv")
    if not is_weight and not is_scale:
        # Weight is not in FP8 format. Ignore.
        return False

    if is_pp_missing_parameter(name, self):
        # This qkv_proj is for a layer not on this PP rank.
        return True

    prefix, qkv_kind = name.rsplit(".", 1)
    entry = fp8_qkv_proj_dict.setdefault(prefix, {})
    entry[qkv_kind] = tensor
    if "weight" not in entry or "weight_scale_inv" not in entry:
        # Still waiting for the other param.
        return True
    del fp8_qkv_proj_dict[prefix]

    # Get self_attn module, which is a parent of qkv_proj.
    attn = self.get_submodule(prefix.rsplit(".", 1)[0])

    # Shard the qkv_proj per-rank.
    w_rank, s_rank = _shard_fp8_qkv_proj(
        entry["weight"],
        entry["weight_scale_inv"],
        num_heads=attn.total_num_heads,
        num_kv_heads=attn.total_num_kv_heads,
        head_dim=attn.head_dim,
        v_head_dim=attn.v_head_dim,
        tp_rank=tp_rank,
        tp_size=tp_size,
    )
    sharded = {"weight": w_rank, "weight_scale_inv": s_rank}
    for kind, tensor in sharded.items():
        param_name = f"{prefix}.{kind}"
        param = params_dict[param_name]
        if tensor.shape[0] > param.shape[0]:
            tensor = tensor[: param.shape[0]]
        default_weight_loader(param, tensor)
        loaded_params.add(param_name)
    return True

_shard_fp8_qkv_proj(w_full, s_full, num_heads, num_kv_heads, head_dim, v_head_dim, tp_rank, tp_size, block=128)

Shard the fp8 qkv_proj weights for tp_rank.

The checkpoint stores the fused QKV as num_kv_heads contiguous groups (one per KV head; n below), each ordered [Q | K | V]:

[Q_1 | K_1 | V_1 | Q_2 | K_2 | V_2 | ... | Q_n | K_n | V_n]

Per group, Q has (num_heads / num_kv_heads) * head_dim rows, K has head_dim rows, and V has v_head_dim rows.

Each TP rank owns g = num_kv_heads / tp_size of these groups, and the forward expects them de-interleaved into a single Q, K, and V block:

[Q_1 | Q_2 | ... | Q_g | K_1 | K_2 | ... | K_g | V_1 | V_2 | ... | V_g]

When g == 1 the rank's slice is already [Q | K | V], so a plain chunk suffices. When g > 1 we cannot reach the de-interleaved layout by re-permuting the fp8 block scales: each scale covers a 128-row block, and since K is 192 rows (1.5 blocks) a block straddles the K/V boundary, so no whole-block permutation produces it. Instead we dequantize this rank's groups to float (dropping the block constraint), reorder the rows into the layout above (Q, K, and V then each span a whole number of blocks), and re-quantize to fp8.

Source code in vllm/model_executor/models/mimo_v2.py
def _shard_fp8_qkv_proj(
    w_full: torch.Tensor,
    s_full: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    v_head_dim: int,
    tp_rank: int,
    tp_size: int,
    block: int = 128,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shard the fp8 qkv_proj weights for ``tp_rank``.

    The checkpoint stores the fused QKV as ``num_kv_heads`` contiguous groups
    (one per KV head; ``n`` below), each ordered ``[Q | K | V]``:

        [Q_1 | K_1 | V_1 | Q_2 | K_2 | V_2 | ... | Q_n | K_n | V_n]

    Per group, Q has ``(num_heads / num_kv_heads) * head_dim`` rows, K has
    ``head_dim`` rows, and V has ``v_head_dim`` rows.

    Each TP rank owns ``g = num_kv_heads / tp_size`` of these groups, and the
    forward expects them de-interleaved into a single Q, K, and V block:

        [Q_1 | Q_2 | ... | Q_g | K_1 | K_2 | ... | K_g | V_1 | V_2 | ... | V_g]

    When ``g == 1`` the rank's slice is already ``[Q | K | V]``, so a plain
    chunk suffices. When ``g > 1`` we cannot reach the de-interleaved layout by
    re-permuting the fp8 block scales: each scale covers a 128-row block, and
    since K is 192 rows (1.5 blocks) a block straddles the K/V boundary, so no
    whole-block permutation produces it. Instead we dequantize this rank's
    groups to float (dropping the block constraint), reorder the rows into the
    layout above (Q, K, and V then each span a whole number of blocks), and
    re-quantize to fp8.

    Raises:
        ValueError: if ``tp_size`` does not evenly split ``num_kv_heads``, or
            (when ``g > 1``) if a group's row count is not a multiple of
            ``block``. In the second case the per-group scale rows cannot be
            located by simple division.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if tp_size > num_kv_heads or num_kv_heads % tp_size != 0:
        raise ValueError(
            "TP size must evenly split the number of KV heads "
            f"(tp_size={tp_size}, num_kv_heads={num_kv_heads})."
        )

    kv_heads_per_rank = num_kv_heads // tp_size
    if kv_heads_per_rank == 1:
        # One KV head per rank. The weights and scale can be trivially sharded
        # without re-quantization.
        w = w_full.chunk(tp_size, dim=0)[tp_rank]
        s = s_full.chunk(tp_size, dim=0)[tp_rank]
        return w, s

    q_rows_per_group = (num_heads // num_kv_heads) * head_dim
    k_rows_per_group = head_dim
    v_rows_per_group = v_head_dim
    rows_per_group = q_rows_per_group + k_rows_per_group + v_rows_per_group
    # The per-group scale slice below is found by dividing the scale rows
    # evenly among groups. That is only valid if every group starts on a block
    # boundary. Otherwise later groups would either fail with a shape mismatch
    # or silently pick up a neighbouring block's scale.
    if rows_per_group % block != 0:
        raise ValueError(
            f"fp8 qkv_proj group of {rows_per_group} rows is not a multiple of "
            f"the {block}-row quantization block; cannot shard "
            f"{kv_heads_per_rank} KV heads per rank."
        )
    scale_rows_per_group = s_full.shape[0] // num_kv_heads
    in_features = w_full.shape[1]
    qs, ks, vs = [], [], []
    for g_idx in range(tp_rank * kv_heads_per_rank, (tp_rank + 1) * kv_heads_per_rank):
        row_start = g_idx * rows_per_group
        scale_row_start = g_idx * scale_rows_per_group
        # Dequantize this group's weights.
        w_g = w_full[row_start : row_start + rows_per_group].to(torch.float32)
        s_g = s_full[scale_row_start : scale_row_start + scale_rows_per_group].to(
            torch.float32
        )
        # Expand each block scale over its block x block tile. Trim both axes
        # so a trailing partial block (in rows or columns) still lines up.
        s_g_expanded = s_g.repeat_interleave(block, dim=0).repeat_interleave(
            block, dim=1
        )[:rows_per_group, :in_features]
        w_g_dequant = w_g * s_g_expanded
        # Track the dequantized q, k, and v weights separately.
        qs.append(w_g_dequant[:q_rows_per_group])
        ks.append(w_g_dequant[q_rows_per_group : q_rows_per_group + k_rows_per_group])
        vs.append(w_g_dequant[q_rows_per_group + k_rows_per_group :])

    # Combine the q, k, and v weights into the following layout:
    # [Q_1, Q_2, .., Q_g, K_1, K_2, ..., K_g, V_1, V_2, ..., V_g]
    grouped = torch.cat([torch.cat(qs), torch.cat(ks), torch.cat(vs)], dim=0)
    # Quantize back to fp8.
    return scaled_quantize(
        grouped, GroupShape(block, block), w_full.dtype, compute_dtype=torch.float32
    )