Skip to content

vllm.models.minimax_m3

MiniMax M3 model — hardware-isolated entry point.

The implementation lives under nvidia/ and amd/; this module picks the right one for the current platform and re-exports the public classes used by the model registry. (Mirrors vllm.models.deepseek_v4.)

Modules:

Classes:

MiniMaxM3MTP

Bases: Module

Source code in vllm/models/minimax_m3/amd/mtp.py
class MiniMaxM3MTP(nn.Module):
    """Multi-token-prediction (MTP) draft module for MiniMax M3.

    Wraps ``MiniMaxM3MultiTokenPredictor`` with its own LM head and logits
    processor. Only the ``*.mtp.layers.*`` weights plus the token embedding
    and LM head are loaded from the checkpoint, which may be either the
    bundled checkpoint (keys prefixed with ``language_model.``) or a
    standalone MTP checkpoint.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        # The HF config is taken from the speculative-decoding draft model
        # config, so a speculative config must be present.
        assert vllm_config.speculative_config is not None
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.model = MiniMaxM3MultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids using the predictor's embedding table."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP predictor for speculative step ``spec_step_idx``.

        ``intermediate_tensors`` is accepted but not forwarded to
        ``self.model``.
        """
        # NOTE: arguments are passed positionally; the order must match
        # MiniMaxM3MultiTokenPredictor.forward.
        return self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute draft logits for speculative step ``spec_step_idx``.

        The MTP layer is chosen as ``spec_step_idx % num_mtp_layers``. That
        layer's ``final_layernorm`` is applied before ``lm_head``.
        """
        # Wrap around when there are more speculative steps than MTP layers.
        current_step_idx = spec_step_idx % self.model.num_mtp_layers
        # ``layers`` is indexed with string keys.
        mtp_layer = self.model.layers[str(current_step_idx)]
        return self.logits_processor(
            self.lm_head, mtp_layer.final_layernorm(hidden_states)
        )

    def _get_mtp_layer_idx_from_weight_name(self, name: str) -> int | None:
        """Return the MTP layer index in *.mtp.layers.{idx}.*, else None."""
        match = re.search(r"\.mtp\.layers\.(\d+)\.", name)
        return int(match.group(1)) if match else None

    def _map_checkpoint_name(self, name: str) -> str | None:
        """Map a full checkpoint key to this MTP module's parameter name.

        The MTP module only owns the *.mtp.layers.* weights plus the token
        embedding and LM head, which the checkpoint shares with the main model.
        Everything else belongs to other modules and is ignored here by returning
        None.
        """
        # In the bundled checkpoint, the MTP weights are prefixed with
        # "language_model". The standalone MTP checkpoint has no such prefix.
        # Strip it if present.
        name = name.removeprefix("language_model.")

        if name == "model.embed_tokens.weight":
            return "model.embed_tokens.weight"
        if name == "lm_head.weight":
            return "lm_head.weight"
        if "model.mtp.layers" in name:
            if "weight_scale_inv" in name:
                # The checkpoint stores block scales as "weight_scale_inv".
                # The ModelOpt MXFP8 layers expose them as "weight_scale".
                name = name.replace("weight_scale_inv", "weight_scale")
            # Strip "mtp" from prefix.
            return name.replace(".mtp.", ".")
        return None

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load MTP weights and return the set of loaded parameter names.

        Each checkpoint key is first mapped by ``_map_checkpoint_name``. Keys
        it rejects are skipped. A mapped key is then loaded by the first path
        that applies, in order:

        1. the stacked q/k/v (plus indexer) and gate/up mapping;
        2. the routed-expert (w1/w2/w3) mapping;
        3. a direct lookup, after ``maybe_remap_kv_scale_name``.

        Raises:
            ValueError: if no weight was loaded for some MTP layer index in
                ``range(self.model.num_mtp_layers)``.
        """
        # Map q/k/v projections to qkv_proj, and gate/up projections to gate_up_proj.
        stacked_params_mapping: list[tuple[str, str, int | str]] = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".qkv_proj", ".index_q_proj", "index_q"),
            (".qkv_proj", ".index_k_proj", "index_k"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        # Map expert weights w1/w2/w3 to gate/down/up.
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # MTP layer indices that received at least one weight, checked below.
        loaded_mtp_layers: set[int] = set()
        for name, loaded_weight in weights:
            # The layer index is read from the raw (unmapped) checkpoint key.
            mtp_layer = self._get_mtp_layer_idx_from_weight_name(name)
            mapped_name = self._map_checkpoint_name(name)
            if mapped_name is None:
                # This weight does not belong to the MTP module, so skip it.
                continue
            name = mapped_name

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                # Routed experts (w1/w2/w3) are handled below. Don't let the
                # stacked mapping rewrite them.
                if ("block_sparse_moe.experts." in name) and name not in params_dict:
                    continue
                # NOTE(review): ``name`` is rewritten in place before the
                # membership check. If the check fails, the remaining mappings
                # and the fallback branches below see the rewritten name.
                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping loaded this weight; try the expert mapping.
                for (
                    param_name,
                    weight_name,
                    expert_id,
                    expert_shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    # Same in-place rewrite caveat as the stacked loop above.
                    name = name.replace(weight_name, param_name)
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=expert_shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain (unstacked, non-expert) parameter, possibly a
                    # kv-scale whose name needs remapping. Unknown names are
                    # skipped silently.
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None or remapped_name not in params_dict:
                        continue
                    name = remapped_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

            # Reached only after a weight_loader call: every skip path above
            # ``continue``s the outer loop.
            loaded_params.add(name)
            if mtp_layer is not None:
                loaded_mtp_layers.add(mtp_layer)

        # Validate that weights were loaded for each MTP layer.
        for layer_idx in range(self.model.num_mtp_layers):
            if layer_idx not in loaded_mtp_layers:
                raise ValueError(
                    f"Failed to load MTP layer {layer_idx} weights from checkpoint."
                )

        return loaded_params

_get_mtp_layer_idx_from_weight_name(name)

Return the MTP layer index from a weight name matching `*.mtp.layers.{idx}.*`, else None.

Source code in vllm/models/minimax_m3/amd/mtp.py
def _get_mtp_layer_idx_from_weight_name(self, name: str) -> int | None:
    """Return the MTP layer index in *.mtp.layers.{idx}.*, else None."""
    match = re.search(r"\.mtp\.layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None

_map_checkpoint_name(name)

Map a full checkpoint key to this MTP module's parameter name.

The MTP module only owns the `*.mtp.layers.*` weights plus the token embedding and LM head, which the checkpoint shares with the main model. Everything else belongs to other modules and is ignored here by returning None.

Source code in vllm/models/minimax_m3/amd/mtp.py
def _map_checkpoint_name(self, name: str) -> str | None:
    """Map a full checkpoint key to this MTP module's parameter name.

    The MTP module only owns the *.mtp.layers.* weights plus the token
    embedding and LM head, which the checkpoint shares with the main model.
    Everything else belongs to other modules and is ignored here by returning
    None.
    """
    # In the bundled checkpoint, the MTP weights are prefixed with
    # "language_model". The standalone MTP checkpoint has no such prefix.
    # Strip it if present.
    name = name.removeprefix("language_model.")

    if name == "model.embed_tokens.weight":
        return "model.embed_tokens.weight"
    if name == "lm_head.weight":
        return "lm_head.weight"
    if "model.mtp.layers" in name:
        if "weight_scale_inv" in name:
            # The checkpoint stores block scales as "weight_scale_inv".
            # The ModelOpt MXFP8 layers expose them as "weight_scale".
            name = name.replace("weight_scale_inv", "weight_scale")
        # Strip "mtp" from prefix.
        return name.replace(".mtp.", ".")
    return None

MiniMaxM3SparseForCausalLM

Bases: Module, SupportsEagle3

MiniMax M3 (sparse/dense backbone) for causal language modeling.

Source code in vllm/models/minimax_m3/amd/model.py
class MiniMaxM3SparseForCausalLM(nn.Module, SupportsEagle3):
    """MiniMax M3 (sparse/dense backbone) for causal language modeling."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # Text-only HF config; the LM head is sized from its vocab/hidden dims.
        self.config = text_config = vllm_config.model_config.hf_text_config
        self.quant_config = vllm_config.quant_config

        self.model = MiniMaxM3Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        vocab_size = text_config.vocab_size
        self.lm_head = ParallelLMHead(
            vocab_size,
            text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(vocab_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings through the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the backbone and return its hidden states.

        Extra keyword arguments are accepted and not forwarded.
        """
        hidden_states = self.model(input_ids, positions, inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the backbone's MoE expert parameter mapping."""
        return self.model.get_expert_mapping()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights recursively through ``AutoWeightsLoader``."""
        weights_loader = AutoWeightsLoader(self)
        loaded = weights_loader.load_weights(weights)
        return loaded

MiniMaxM3SparseForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsEagle3

Top-level (VL) entry point for MiniMax M3.

Owns the shared MiniMax-M3 vision tower on ROCm and delegates text generation to the AMD language-model path.

Source code in vllm/models/minimax_m3/amd/model.py
@MULTIMODAL_REGISTRY.register_processor(
    MiniMaxM3VLMultiModalProcessor,
    info=MiniMaxM3VLProcessingInfo,
    dummy_inputs=MiniMaxM3VLDummyInputsBuilder,
)
class MiniMaxM3SparseForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsEagle3
):
    """Top-level (VL) entry point for MiniMax M3.

    Owns the shared MiniMax-M3 vision tower on ROCm and delegates text
    generation to the AMD language-model path.
    """

    # The vision tower runs replicated per rank under ``--mm-encoder-tp-mode
    # data``; ``run_dp_sharded_mrope_vision_model`` shards the work across
    # ranks (see ``_encode_visual_input``).
    supports_encoder_tp_data = True

    # Checkpoint -> vLLM name rewrites: the projector and patch-merge MLP
    # prefixes are moved under ``vision_tower.``, and the ``.mlp.fc1.`` /
    # ``.mlp.fc2.`` substrings are flattened to ``.fc1.`` / ``.fc2.``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "multi_modal_projector.": "vision_tower.multi_modal_projector.",
            "patch_merge_mlp.": "vision_tower.patch_merge_mlp.",
        },
        orig_to_new_substr={
            ".mlp.fc1.": ".fc1.",
            ".mlp.fc2.": ".fc2.",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder token for ``modality``.

        Raises:
            ValueError: for modalities other than ``"image"`` and ``"video"``.
        """
        if modality == "image":
            return MiniMaxM3VLProcessingInfo.IMAGE_TOKEN
        if modality == "video":
            return MiniMaxM3VLProcessingInfo.VIDEO_TOKEN
        raise ValueError(f"Unsupported modality: {modality!r}")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = vllm_config.model_config.multimodal_config
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if self.multimodal_config is None:
            raise ValueError(
                "MiniMaxM3SparseForConditionalGeneration requires a "
                "multimodal config"
            )
        self.use_data_parallel = self.multimodal_config.mm_encoder_tp_mode == "data"

        text_hidden_size = getattr(config.text_config, "hidden_size", None)
        if text_hidden_size is None:
            raise ValueError("text_config.hidden_size is required")
        projector_hidden_size = getattr(config, "projector_hidden_size", None)

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            vision_config = config.vision_config
            self.vision_tower = MiniMaxVLVisionModel(
                config=PretrainedConfig.from_dict(vision_config),
                text_hidden_size=text_hidden_size,
                projector_hidden_size=projector_hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["MiniMaxM3SparseForCausalLM"],
        )

    def _parse_and_validate_image_input(self, **kwargs: object) -> dict | None:
        """Collect image tensors from ``kwargs``; ``None`` if there are none.

        Raises:
            ValueError: if ``pixel_values`` is given without ``image_grid_thw``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        if pixel_values is None:
            return None
        if image_grid_thw is None:
            # Fail early with a clear message instead of an AttributeError
            # on ``None.ndim`` in the vision-encoding path.
            raise ValueError("`image_grid_thw` is required with `pixel_values`")
        return {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}

    def _parse_and_validate_video_input(self, **kwargs: object) -> dict | None:
        """Collect video tensors from ``kwargs``; ``None`` if there are none.

        Raises:
            ValueError: if ``pixel_values_videos`` is given without
                ``video_grid_thw``.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        if pixel_values_videos is None:
            return None
        if video_grid_thw is None:
            raise ValueError(
                "`video_grid_thw` is required with `pixel_values_videos`"
            )
        return {
            "pixel_values_videos": pixel_values_videos,
            "video_grid_thw": video_grid_thw,
        }

    def _encode_visual_input(
        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, ...]:
        """Run the vision tower and return one embedding tensor per item.

        Shared by the image and video paths, which differ only in the keys
        their inputs are stored under.
        """
        pixel_values = pixel_values.type(self.vision_tower.dtype)
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if self.use_data_parallel:
            # Already returns a per-item tuple of embeddings.
            return run_dp_sharded_mrope_vision_model(
                self.vision_tower,
                pixel_values,
                grid_thw_list,
                rope_type="rope_3d",
            )

        embeds = self.vision_tower(pixel_values=pixel_values, grid_thw=grid_thw_list)

        # The tower output is concatenated across items. Each item's share is
        # the product of its grid_thw row (presumably t*h*w patches) divided
        # by the spatial merge factor merge_size**2.
        merge_size = self.vision_tower.spatial_merge_size
        sizes = (grid_thw.prod(-1) // (merge_size * merge_size)).tolist()
        return embeds.split(sizes)

    def _process_image_input(self, image_input: dict) -> tuple[torch.Tensor, ...]:
        """Encode validated image input into per-image embeddings."""
        return self._encode_visual_input(
            image_input["pixel_values"], image_input["image_grid_thw"]
        )

    def _process_video_input(self, video_input: dict) -> tuple[torch.Tensor, ...]:
        """Encode validated video input into per-video embeddings."""
        return self._encode_visual_input(
            video_input["pixel_values_videos"], video_input["video_grid_thw"]
        )

    def _parse_and_validate_multimodal_inputs(
        self, **kwargs: object
    ) -> dict[str, dict]:
        """Group validated inputs by modality.

        The result's insertion order follows the order in which each
        modality's pixel key first appears in ``kwargs``.
        """
        mm_input_by_modality: dict[str, dict] = {}
        for input_key in kwargs:
            if input_key == "pixel_values" and "image" not in mm_input_by_modality:
                image_input = self._parse_and_validate_image_input(**kwargs)
                if image_input is not None:
                    mm_input_by_modality["image"] = image_input
            if (
                input_key == "pixel_values_videos"
                and "video" not in mm_input_by_modality
            ):
                video_input = self._parse_and_validate_video_input(**kwargs)
                if video_input is not None:
                    mm_input_by_modality["video"] = video_input
        return mm_input_by_modality

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all image/video inputs into a flat sequence of embeddings."""
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        multimodal_embeddings: list[torch.Tensor] = []
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                multimodal_embeddings.extend(
                    self._process_image_input(multimodal_input)
                )
            elif modality == "video":
                multimodal_embeddings.extend(
                    self._process_video_input(multimodal_input)
                )

        return tuple(multimodal_embeddings)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the text model. Extra keyword arguments are not forwarded."""
        return self.language_model(input_ids, positions, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the language model's MoE expert parameter mapping."""
        return self.language_model.get_expert_mapping()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights after applying ``hf_to_vllm_mapper`` name rewrites."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)