Skip to content

vllm.models.minimax_m3.amd.mtp

MiniMax M3 MTP (multi-token prediction) draft model -- ROCm/AMD variant.

Byte-identical to `nvidia/mtp.py`, except that this file lives under `amd/` so that its `from .model import ...` resolves to the self-contained AMD model (native Gemma RMSNorm, native MXFP8 MoE, Triton sparse attention). The MTP logic is platform-agnostic. (Mirrors `vllm.models.deepseek_v4.amd.mtp`.)

TODO(future, separate diff): since this is byte-identical to nvidia/mtp.py, both copies could be consolidated into a single common/mtp.py that dispatches its model import (..amd.model vs ..nvidia.model) via current_platform.is_rocm() -- the same dispatch minimax_m3/__init__.py uses. This was prototyped and VERIFIED working (MiniMaxM3MTP resolves through common.mtp to the AMD decoder layer / RMSNorm on ROCm), but it deletes the upstream nvidia/mtp.py and touches the NVIDIA load path, so it is deferred to a dedicated refactor diff to keep this AMD-enablement change NVIDIA-untouched.

Classes:

MiniMaxM3MTP

Bases: Module

Source code in vllm/models/minimax_m3/amd/mtp.py
class MiniMaxM3MTP(nn.Module):
    """MiniMax M3 multi-token-prediction (MTP) draft model.

    Bundles a ``MiniMaxM3MultiTokenPredictor`` with an LM head and a logits
    processor, and selects its own weights out of a checkpoint stream in
    ``load_weights``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the predictor, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration. ``speculative_config`` must be
                set; the draft model's HF config is read from it.
            prefix: Name prefix under which the submodules are registered.
        """
        super().__init__()

        assert vllm_config.speculative_config is not None
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.model = MiniMaxM3MultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` by delegating to the wrapped predictor."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the wrapped predictor for one speculative step.

        ``intermediate_tensors`` is accepted but not forwarded to the
        predictor; every other argument is passed through positionally.
        """
        return self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute logits for ``hidden_states`` at the given speculative step.

        The step index wraps around the number of MTP layers; the selected
        layer's ``final_layernorm`` is applied before the LM head.
        """
        current_step_idx = spec_step_idx % self.model.num_mtp_layers
        mtp_layer = self.model.layers[str(current_step_idx)]
        return self.logits_processor(
            self.lm_head, mtp_layer.final_layernorm(hidden_states)
        )

    def _get_mtp_layer_idx_from_weight_name(self, name: str) -> int | None:
        """Return the MTP layer index in *.mtp.layers.{idx}.*, else None."""
        match = re.search(r"\.mtp\.layers\.(\d+)\.", name)
        return int(match.group(1)) if match else None

    def _map_checkpoint_name(self, name: str) -> str | None:
        """Map a full checkpoint key to this MTP module's parameter name.

        The MTP module only owns the *.mtp.layers.* weights plus the token
        embedding and LM head, which the checkpoint shares with the main model.
        Everything else belongs to other modules and is ignored here by returning
        None.
        """
        # In the bundled checkpoint, the MTP weights are prefixed with
        # "language_model". The standalone MTP checkpoint has no such prefix.
        # Strip it if present.
        name = name.removeprefix("language_model.")

        if name == "model.embed_tokens.weight":
            return "model.embed_tokens.weight"
        if name == "lm_head.weight":
            return "lm_head.weight"
        if "model.mtp.layers" in name:
            if "weight_scale_inv" in name:
                # The checkpoint stores block scales as "weight_scale_inv".
                # The ModelOpt MXFP8 layers expose them as "weight_scale".
                name = name.replace("weight_scale_inv", "weight_scale")
            # Strip "mtp" from prefix.
            return name.replace(".mtp.", ".")
        return None

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the checkpoint weights that belong to this MTP module.

        Each checkpoint key is first translated by ``_map_checkpoint_name``;
        keys that map to ``None`` are skipped. A surviving key is then tried,
        in order, against the stacked-projection mapping, the routed-expert
        mapping, and finally a direct (optionally kv-scale-remapped) lookup.

        Args:
            weights: Iterable of ``(checkpoint_key, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            ValueError: If no weight was loaded for one of the MTP layers
                ``0 .. num_mtp_layers - 1``.
        """
        # Map q/k/v projections to qkv_proj, and gate/up projections to gate_up_proj.
        stacked_params_mapping: list[tuple[str, str, int | str]] = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".qkv_proj", ".index_q_proj", "index_q"),
            (".qkv_proj", ".index_k_proj", "index_k"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        # Map expert weights w1/w2/w3 to gate/down/up.
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        loaded_mtp_layers: set[int] = set()
        for name, loaded_weight in weights:
            # The layer index is read from the raw checkpoint key, because the
            # mapping below strips the ".mtp." segment.
            mtp_layer = self._get_mtp_layer_idx_from_weight_name(name)
            mapped_name = self._map_checkpoint_name(name)
            if mapped_name is None:
                # This weight does not belong to the MTP module, so skip it.
                continue
            name = mapped_name

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Routed experts (w1/w2/w3) are handled below. Don't let the
                # stacked mapping rewrite them.
                if ("block_sparse_moe.experts." in name) and name not in params_dict:
                    continue
                # Keep the rewritten name in a separate variable: the loop may
                # continue, and a rejected candidate must not leak into later
                # iterations or into the fallbacks below.
                candidate = name.replace(weight_name, param_name)
                if candidate not in params_dict:
                    continue
                param = params_dict[candidate]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                name = candidate
                break
            else:
                for (
                    param_name,
                    weight_name,
                    expert_id,
                    expert_shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    # Same rule as above: only commit the rewritten name once
                    # it is known to be a real parameter.
                    candidate = name.replace(weight_name, param_name)
                    if candidate not in params_dict:
                        continue
                    param = params_dict[candidate]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        candidate,
                        shard_id=expert_shard_id,
                        expert_id=expert_id,
                    )
                    name = candidate
                    break
                else:
                    # Neither mapping applied: try the name directly, after
                    # the optional kv-scale remap.
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None or remapped_name not in params_dict:
                        continue
                    name = remapped_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

            loaded_params.add(name)
            if mtp_layer is not None:
                loaded_mtp_layers.add(mtp_layer)

        # Validate that weights were loaded for each MTP layer.
        for layer_idx in range(self.model.num_mtp_layers):
            if layer_idx not in loaded_mtp_layers:
                raise ValueError(
                    f"Failed to load MTP layer {layer_idx} weights from checkpoint."
                )

        return loaded_params

_get_mtp_layer_idx_from_weight_name(name)

Return the MTP layer index in `*.mtp.layers.{idx}.*`, else None.

Source code in vllm/models/minimax_m3/amd/mtp.py
def _get_mtp_layer_idx_from_weight_name(self, name: str) -> int | None:
    """Return the MTP layer index in *.mtp.layers.{idx}.*, else None."""
    match = re.search(r"\.mtp\.layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None

_map_checkpoint_name(name)

Map a full checkpoint key to this MTP module's parameter name.

The MTP module only owns the `*.mtp.layers.*` weights plus the token embedding and LM head, which the checkpoint shares with the main model. Everything else belongs to other modules and is ignored here by returning None.

Source code in vllm/models/minimax_m3/amd/mtp.py
def _map_checkpoint_name(self, name: str) -> str | None:
    """Map a full checkpoint key to this MTP module's parameter name.

    The MTP module only owns the *.mtp.layers.* weights plus the token
    embedding and LM head, which the checkpoint shares with the main model.
    Everything else belongs to other modules and is ignored here by returning
    None.
    """
    # In the bundled checkpoint, the MTP weights are prefixed with
    # "language_model". The standalone MTP checkpoint has no such prefix.
    # Strip it if present.
    name = name.removeprefix("language_model.")

    if name == "model.embed_tokens.weight":
        return "model.embed_tokens.weight"
    if name == "lm_head.weight":
        return "lm_head.weight"
    if "model.mtp.layers" in name:
        if "weight_scale_inv" in name:
            # The checkpoint stores block scales as "weight_scale_inv".
            # The ModelOpt MXFP8 layers expose them as "weight_scale".
            name = name.replace("weight_scale_inv", "weight_scale")
        # Strip "mtp" from prefix.
        return name.replace(".mtp.", ".")
    return None