speculators.convert.mtp.converter

MTP checkpoint converter.

Extracts only the MTP layer weights from a checkpoint with native MTP layers and saves a Speculators checkpoint that loads with MTPDraftModel.from_pretrained(path).

Only the mtp.* subtree is extracted from the (potentially sharded) safetensors file; the rest of the model is never loaded. The embed_tokens and lm_head are loaded from the verifier at runtime via load_verifier_weights().

Classes:

  • MTPConverter

    Extract the MTP head from a checkpoint with native MTP layers.

MTPConverter

Extract the MTP head from a checkpoint with native MTP layers.

Reads only the mtp.* weights from the source checkpoint; embed_tokens and lm_head are loaded from the verifier at runtime via load_verifier_weights(). Sharded safetensors files are handled transparently via the weight index, so the main transformer stack is never loaded.
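
To illustrate how a weight index can keep the main transformer stack off the load path, here is a minimal sketch. It is not the converter's actual implementation. It assumes the standard sharded layout in which model.safetensors.index.json holds a "weight_map" from tensor name to shard file. The helper name shards_for_prefix, the tensor names, and the shard file names are all hypothetical.

```python
from collections import defaultdict


def shards_for_prefix(
    weight_map: dict[str, str], prefix: str = "mtp."
) -> dict[str, list[str]]:
    """Group the tensor names under `prefix` by the shard file that stores them.

    Hypothetical helper: only the shards returned here would need to be opened.
    """
    by_shard: dict[str, list[str]] = defaultdict(list)
    for name, shard in weight_map.items():
        if name.startswith(prefix):
            by_shard[shard].append(name)
    if not by_shard:
        # Mirrors the idea of a format check: no native MTP weights, nothing to convert.
        raise ValueError(f"no tensors found under prefix {prefix!r}")
    return dict(by_shard)


# Made-up index contents in the shape of index["weight_map"].
weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
    "model.layers.1.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
    "mtp.layers.0.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
    "mtp.norm.weight": "model-00003-of-00003.safetensors",
}

needed = shards_for_prefix(weight_map)
for shard, names in needed.items():
    print(shard, names)
```

In this made-up index, both mtp.* tensors live in the third shard, so the first two shards would never be opened. Each selected shard could then be read tensor by tensor, for example with safetensors' safe_open.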

Methods:

convert_to_state_dict

convert_to_state_dict(
    input_path: str | Path,
    cache_dir: str | Path | None = None,
) -> dict[str, torch.Tensor]

Extract native MTP weights and return them as a state dict.

Performs the full pipeline (download/locate → verify → extract → remap → fuse MoE experts) without writing anything to disk.

Source code in speculators/convert/mtp/converter.py
def convert_to_state_dict(
    self,
    input_path: str | Path,
    cache_dir: str | Path | None = None,
) -> dict[str, torch.Tensor]:
    """Extract native MTP weights and return them as a state dict.

    Performs the full pipeline (download/locate → verify → extract →
    remap → fuse MoE experts) without writing anything to disk.
    """
    logger.info(f"Extracting native MTP weights from {input_path}")

    local_path = ensure_checkpoint_is_local(input_path, cache_dir)
    all_keys = list_checkpoint_keys(local_path)
    self._verify_mtp_format(all_keys)

    weights = self._extract_weights(local_path, all_keys)
    weights = self._fuse_moe_experts(weights)
    logger.info(f"Extracted {len(weights)} MTP weight tensors")
    return weights