Skip to content

speculators.models

Modules:

Classes:

DFlashDraftModel

DFlashDraftModel(config: DFlashSpeculatorConfig)

Bases: DraftVocabMixin, SpeculatorModel

Methods:

Attributes:

  • target_layer_ids (list[int]) –

    Target layer IDs for auxiliary hidden states.

Source code in speculators/models/dflash/core.py
def __init__(
    self,
    config: DFlashSpeculatorConfig,
) -> None:
    """Build the DFlash draft model from its speculator configuration.

    When the transformer layer config has no attention implementation set,
    it is filled in with ``"simple_flex_attention"`` (mutating ``config``).
    The matching mask factory is chosen before the base class is
    initialised, after which the vocabulary, decoder layers, norms, rotary
    embedding and the bias-free ``fc`` projection are created.

    Args:
        config: DFlash speculator configuration. Its
            ``transformer_layer_config`` supplies the layer count, hidden
            size, RMSNorm epsilon, sliding-window settings and layer types.
    """
    # Forcibly override config settings
    if config.transformer_layer_config._attn_implementation is None:  # noqa: SLF001
        config.transformer_layer_config._attn_implementation = (  # noqa: SLF001
            "simple_flex_attention"
        )
    attn_impl = config.transformer_layer_config._attn_implementation  # noqa: SLF001
    self._attn_impl = attn_impl
    if attn_impl == "simple_flex_attention":
        self._create_mask_fn = create_block_mask
    else:
        self._create_mask_fn = create_mask

    super().__init__(config=config)
    self._init_vocab(config)

    tl_config = config.transformer_layer_config
    hidden_size = tl_config.hidden_size
    norm_eps = tl_config.rms_norm_eps

    # Number of draft layers is encoded in transformer_layer_config
    num_draft_layers = tl_config.num_hidden_layers
    draft_layers = [
        Qwen3DFlashDecoderLayer(tl_config, layer_idx)  # type: ignore[arg-type]
        for layer_idx in range(num_draft_layers)
    ]
    self.layers = nn.ModuleList(draft_layers)

    # Record which layer positions are declared as sliding-window attention
    self.sliding_window = tl_config.sliding_window
    self.sliding_window_indices = [
        position
        for position, kind in enumerate(tl_config.layer_types)
        if kind == "sliding_attention"
    ]
    num_sliding = len(self.sliding_window_indices)
    self.uses_sliding_window_attn = num_sliding > 0
    self.uses_full_attn = num_draft_layers != num_sliding
    self.sliding_window_non_causal = config.sliding_window_non_causal

    self.norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)  # type: ignore[arg-type]
    self.rotary_emb = Qwen3RotaryEmbedding(tl_config)  # type: ignore[arg-type]

    # One hidden_size-wide slice per target layer is projected back down
    # to a single hidden_size vector.
    self.fc = nn.Linear(
        len(self.target_layer_ids) * hidden_size,
        hidden_size,
        bias=False,
    )
    self.hidden_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)  # type: ignore[arg-type]
    self.verifier_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)  # type: ignore[arg-type]
    # The verifier norm weight is excluded from gradient updates
    self.verifier_norm.weight.requires_grad = False
    self.block_size = config.block_size
    self.post_init()

target_layer_ids property

target_layer_ids: list[int]

Target layer IDs for auxiliary hidden states.

from_training_args classmethod

from_training_args(
    verifier_config: PretrainedConfig,
    t2d: Tensor | None = None,
    d2t: Tensor | None = None,
    **kwargs,
) -> DFlashDraftModel

Create DFlash model from training arguments.

Args: verifier_config: Verifier model configuration. This should be a config with num_hidden_layers set to the number of DRAFT layers (created by create_transformer_layer_config in train.py). t2d: Target-to-draft vocabulary mapping tensor (optional) d2t: Draft-to-target vocabulary mapping tensor (optional) **kwargs: Training arguments with DFlash-specific params - draft_vocab_size: Size of draft vocabulary - block_size: Block size for draft predictions (default: 8) - max_anchors: Max anchor positions during training (default: 3072) - verifier_name_or_path: Path to verifier model

Returns: Initialized DFlashDraftModel

Note: The number of draft layers is encoded in verifier_config.num_hidden_layers, following the same pattern as EAGLE3.

Source code in speculators/models/dflash/core.py
@classmethod
def from_training_args(
    cls,
    verifier_config: "PretrainedConfig",
    t2d: torch.Tensor | None = None,
    d2t: torch.Tensor | None = None,
    **kwargs,
) -> "DFlashDraftModel":
    """Create DFlash model from training arguments.

    Args:
        verifier_config: Verifier model configuration. This should be a config
            with num_hidden_layers set to the number of DRAFT layers (created
            by create_transformer_layer_config in train.py). Its
            ``_attn_implementation`` is overwritten in place.
        t2d: Target-to-draft vocabulary mapping tensor (optional)
        d2t: Draft-to-target vocabulary mapping tensor (optional)
        **kwargs: Training arguments with DFlash-specific params
            - draft_vocab_size: Size of draft vocabulary (required)
            - verifier_name_or_path: Path to verifier model (required)
            - block_size: Block size for draft predictions (default: 8)
            - max_anchors: Max anchor positions during training
              (default: 3072)
            - target_layer_ids: Auxiliary hidden state layer IDs; passed to
              resolve_target_layer_ids together with the verifier path
              (default: None)
            - draft_attn_impl: Attention implementation for the draft layers
              (default: "simple_flex_attention")
            - mask_token_id: Forwarded to the config (default: None)
            - sliding_window_non_causal: Forwarded to the config
              (default: False)

    Returns:
        Initialized DFlashDraftModel with vocabulary mappings and verifier
        weights loaded.

    Raises:
        KeyError: If ``draft_vocab_size`` or ``verifier_name_or_path`` is
            missing from ``kwargs``.

    Note:
        The number of draft layers is encoded in verifier_config.num_hidden_layers,
        following the same pattern as EAGLE3.
    """
    from speculators.config import (  # noqa: PLC0415
        SpeculatorsConfig,
        VerifierConfig,
    )
    from speculators.proposals.greedy import (  # noqa: PLC0415
        GreedyTokenProposalConfig,
    )

    verifier_path = kwargs["verifier_name_or_path"]

    # Resolve target layer IDs if not provided
    target_layer_ids = resolve_target_layer_ids(
        kwargs.get("target_layer_ids"), verifier_path
    )

    verifier_config._attn_implementation = kwargs.get(  # noqa: SLF001
        "draft_attn_impl", "simple_flex_attention"
    )

    # Read the block size once so the config value and the number of
    # speculative tokens derived from it can never fall out of sync.
    block_size = kwargs.get("block_size", 8)

    config = DFlashSpeculatorConfig(
        transformer_layer_config=verifier_config,
        draft_vocab_size=kwargs["draft_vocab_size"],
        block_size=block_size,
        max_anchors=kwargs.get("max_anchors", 3072),
        aux_hidden_state_layer_ids=target_layer_ids,
        mask_token_id=kwargs.get("mask_token_id"),
        sliding_window_non_causal=kwargs.get("sliding_window_non_causal", False),
        speculators_config=SpeculatorsConfig(
            algorithm="dflash",
            proposal_methods=[
                GreedyTokenProposalConfig(
                    # DFlash first position is anchor position, not used during gen
                    speculative_tokens=block_size - 1,
                )
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig.from_pretrained(verifier_path),
        ),
    )

    model = cls(config=config)
    model.load_vocab_mappings(t2d, d2t)
    model.load_verifier_weights()
    return model

get_trainer_kwargs staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get training and validation kwargs for DFlash.

Args: **kwargs: Training arguments

Returns: Tuple of (train_call_kwargs, val_call_kwargs)

Source code in speculators/models/dflash/core.py
@staticmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Build the per-call keyword arguments for DFlash training and validation.

    Both phases receive the same settings: the loss function resolved from
    ``kwargs["loss_fn"]`` and ``gamma`` taken from ``dflash_decay_gamma``
    (4.0 when absent).

    Args:
        **kwargs: Training arguments

    Returns:
        Tuple of (train_call_kwargs, val_call_kwargs), two separate dicts
        with equal contents.
    """
    shared = {
        "loss_fn": resolve_loss_fn(kwargs["loss_fn"]),
        "gamma": kwargs.get("dflash_decay_gamma", 4.0),
    }
    # Hand out independent dict objects so callers can mutate one freely
    return dict(shared), dict(shared)

DFlashSpeculatorConfig

DFlashSpeculatorConfig(**kwargs)

Bases: SpeculatorModelConfig

Configuration for DFlash speculator with vocabulary mapping.

DFlash features vocabulary mapping between draft (64K) and target (128K) vocabularies, enabling cross-tokenizer speculation.

Parameters:

  • transformer_layer_config

    Configuration for the transformer decoder layer

  • draft_vocab_size

    Size of draft model vocabulary for speculation

Methods:

Attributes:

Source code in speculators/config.py
def __init__(self, **kwargs):
    """Initialise the Pydantic side first, then the Hugging Face config side.

    The same ``kwargs`` feed both base initialisers; between the two calls
    they are adjusted so PretrainedConfig sees the Pydantic-validated field
    values and none of the ClassVar names.
    """
    # Pydantic runs first so every declared field is validated and set
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    model_cls = self.__class__

    # overwrite the Pydantic-owned entries with the values now stored on
    # self, so PretrainedConfig doesn't override them
    kwargs.update({name: getattr(self, name) for name in model_cls.model_fields})

    # drop ClassVars: PretrainedConfig.__post_init__ would try to setattr
    # them and pydantic blocks setattr on ClassVar names
    for class_var in model_cls.__class_vars__:
        kwargs.pop(class_var, None)

    # initialize the Hugging Face PretrainedConfig arguments for the model
    PretrainedConfig.__init__(self, **kwargs)

    # always stamp the currently installed transformers version
    self.transformers_version = version("transformers")

target_vocab_size property

target_vocab_size: int

Get target vocabulary size from transformer config.

serialize_transformer_config

serialize_transformer_config(
    value: PretrainedConfig,
) -> dict

Serialize transformer config to dict.

Source code in speculators/models/dflash/config.py
@field_serializer("transformer_layer_config")
def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
    """Serialize the transformer layer config via its ``to_diff_dict()``."""
    serialized = value.to_diff_dict()
    return serialized

validate_transformer_config classmethod

validate_transformer_config(value: Any) -> PretrainedConfig

Validate and convert transformer config.

Source code in speculators/models/dflash/config.py
@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
    """Turn a dict into a transformer config object; pass other values through.

    A dict carrying ``model_type`` is built with the config class that
    ``AutoConfig.for_model`` resolves for that type; a dict without it falls
    back to ``Qwen3Config``.
    """
    if not isinstance(value, dict):
        # Not a raw dict: hand it back untouched
        return value

    if "model_type" in value:
        resolved = AutoConfig.for_model(model_type=value["model_type"])
        config_class: type[PretrainedConfig] = resolved.__class__
    else:
        config_class = Qwen3Config
    return config_class(**value)

Eagle3DraftModel

Eagle3DraftModel(config: Eagle3SpeculatorConfig)

Bases: DraftVocabMixin, SpeculatorModel

Methods:

Attributes:

  • target_layer_ids (list[int]) –

    Target layer IDs for auxiliary hidden states.

Source code in speculators/models/eagle3/core.py
def __init__(self, config: Eagle3SpeculatorConfig):
    """Build the Eagle3 draft model from its speculator configuration.

    When the transformer layer config has no attention implementation set,
    it is filled in with ``"simple_flex_attention"`` (mutating ``config``).
    After base-class and vocabulary initialisation, the architecture-specific
    classes are looked up by ``model_type`` and used to create the FC
    projection, decoder layers, rotary embedding and norms.

    Args:
        config: Eagle3 speculator configuration.
    """
    # Forcibly override config settings
    if config.transformer_layer_config._attn_implementation is None:  # noqa: SLF001
        config.transformer_layer_config._attn_implementation = (  # noqa: SLF001
            "simple_flex_attention"
        )
    attn_impl = config.transformer_layer_config._attn_implementation  # noqa: SLF001
    self._attn_impl = attn_impl
    if attn_impl == "simple_flex_attention":
        self._create_mask_fn = create_block_mask
    else:
        self._create_mask_fn = create_mask

    super().__init__(config=config)
    self._init_vocab(config)

    tl_config = self.config.transformer_layer_config
    definitions = model_classes[tl_config.model_type]
    self._model_definitions = definitions

    # Eagle3-specific: embed_tokens grad depends on config
    self.embed_tokens.weight.requires_grad = self.config.embed_requires_grad

    hidden = self.hidden_size

    # FC LAYER
    self.fc = torch.nn.Linear(3 * hidden, hidden, bias=False)

    # DECODER LAYERS
    # Layer 0 uses the dedicated first-layer class; the rest share one class.
    first_cls = definitions.first_layer_class
    later_cls = definitions.decoder_layer_class
    first_layer = first_cls(
        tl_config,
        layer_idx=0,
        norm_before_residual=self.config.norm_before_residual,
    )
    later_layers = [
        later_cls(tl_config, layer_idx)
        for layer_idx in range(1, tl_config.num_hidden_layers)
    ]
    self.layers = torch.nn.ModuleList([first_layer, *later_layers])

    # ROTARY EMBEDDINGS
    # The rotary embedding gets a shallow copy of the config with doubled
    # hidden size, leaving the original config untouched.
    rotary_config = copy.copy(config.transformer_layer_config)
    rotary_config.hidden_size = rotary_config.hidden_size * 2
    self.rotary_emb = definitions.rotary_emb_class(rotary_config)

    # LAYER NORMS
    norm_class = definitions.norm_class
    self.norm = norm_class(hidden, eps=config.transformer_layer_config.rms_norm_eps)
    self.verifier_norm = norm_class(hidden, eps=tl_config.rms_norm_eps)
    self.verifier_norm.weight.requires_grad = False

    # Normalize draft path input (gpt-oss only)
    self.input_norm = (
        definitions.norm_class(
            3 * hidden,
            eps=config.transformer_layer_config.rms_norm_eps,
        )
        if config.norm_before_fc
        else None
    )

    self.post_init()

target_layer_ids property

target_layer_ids: list[int]

Target layer IDs for auxiliary hidden states.

from_training_args classmethod

from_training_args(
    verifier_config: PretrainedConfig,
    t2d: Tensor | None = None,
    d2t: Tensor | None = None,
    **kwargs,
) -> Eagle3DraftModel

Create Eagle3 model from training arguments.

Args: verifier_config: Verifier model configuration **kwargs: Training arguments with Eagle3-specific params - num_layers: Number of decoder layers - norm_before_residual: Whether to normalize before residual connection - t2d: Target-to-draft vocabulary mapping tensor - d2t: Draft-to-target vocabulary mapping tensor - ttt_steps: Number of TTT steps - verifier_name_or_path: Path to verifier model

Returns: Initialized Eagle3DraftModel

Source code in speculators/models/eagle3/core.py
@classmethod
def from_training_args(
    cls,
    verifier_config: PretrainedConfig,
    t2d: torch.Tensor | None = None,
    d2t: torch.Tensor | None = None,
    **kwargs,
) -> "Eagle3DraftModel":
    """Create Eagle3 model from training arguments.

    Args:
        verifier_config: Verifier model configuration, used as the
            transformer layer config. Its ``_attn_implementation`` is
            overwritten in place.
        t2d: Target-to-draft vocabulary mapping tensor (optional)
        d2t: Draft-to-target vocabulary mapping tensor (optional)
        **kwargs: Training arguments with Eagle3-specific params
            - draft_vocab_size: Size of draft vocabulary (required)
            - norm_before_residual: Whether to normalize before residual
              connection (required)
            - ttt_steps: Number of TTT steps, used as the number of
              speculative tokens (required)
            - verifier_name_or_path: Path to verifier model (required)
            - norm_before_fc: Forwarded to the config (default: False)
            - embed_requires_grad: Forwarded to the config (default: False)
            - target_layer_ids: Passed to resolve_target_layer_ids together
              with the verifier path (default: None)
            - draft_attn_impl: Attention implementation for the draft layers
              (default: "simple_flex_attention")

    Returns:
        Initialized Eagle3DraftModel with vocabulary mappings and verifier
        weights loaded.
    """
    verifier_path = kwargs["verifier_name_or_path"]

    # Resolve target layer IDs if not provided
    target_layer_ids = resolve_target_layer_ids(
        kwargs.get("target_layer_ids"), verifier_path
    )

    verifier_config._attn_implementation = kwargs.get(  # noqa: SLF001
        "draft_attn_impl", "simple_flex_attention"
    )

    # Pull the model-level settings first, in the order the config takes them
    draft_vocab_size = kwargs["draft_vocab_size"]
    norm_before_residual = kwargs["norm_before_residual"]
    norm_before_fc = kwargs.get("norm_before_fc", False)
    embed_requires_grad = kwargs.get("embed_requires_grad", False)

    greedy_proposal = GreedyTokenProposalConfig(
        speculative_tokens=kwargs["ttt_steps"],
    )
    speculators_config = SpeculatorsConfig(
        algorithm="eagle3",
        proposal_methods=[greedy_proposal],
        default_proposal_method="greedy",
        verifier=VerifierConfig.from_pretrained(verifier_path),
    )

    config = Eagle3SpeculatorConfig(
        transformer_layer_config=verifier_config,
        draft_vocab_size=draft_vocab_size,
        norm_before_residual=norm_before_residual,
        norm_before_fc=norm_before_fc,
        embed_requires_grad=embed_requires_grad,
        eagle_aux_hidden_state_layer_ids=target_layer_ids,
        speculators_config=speculators_config,
    )

    model = cls(config=config)
    model.load_vocab_mappings(t2d, d2t)
    model.load_verifier_weights()
    return model

get_trainer_kwargs staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get training and validation kwargs for Eagle3.

Args: **kwargs: Training arguments

Returns: Tuple of (train_call_kwargs, val_call_kwargs)

Source code in speculators/models/eagle3/core.py
@staticmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Build the per-call keyword arguments for Eagle3 training and validation.

    The two dicts share ``ttt_steps``, ``ttt_step_loss_decay`` and the
    resolved loss function; validation always sets
    ``use_off_policy_tokens`` to False, while training takes it from
    ``kwargs``.

    Args:
        **kwargs: Training arguments

    Returns:
        Tuple of (train_call_kwargs, val_call_kwargs)
    """
    loss_fn = resolve_loss_fn(kwargs["loss_fn"])
    train_off_policy = kwargs["use_off_policy_tokens"]
    common = {
        "ttt_steps": kwargs["ttt_steps"],
        "ttt_step_loss_decay": kwargs["ttt_step_loss_decay"],
        "loss_fn": loss_fn,
    }
    return (
        {"use_off_policy_tokens": train_off_policy, **common},
        {"use_off_policy_tokens": False, **common},
    )

Eagle3SpeculatorConfig

Eagle3SpeculatorConfig(**kwargs)

Bases: SpeculatorModelConfig

Configuration for EAGLE-3 speculator with vocabulary mapping.

EAGLE-3 features vocabulary mapping between draft (32K) and target (128K) vocabularies, enabling cross-tokenizer speculation.

Parameters:

  • transformer_layer_config

    Configuration for the transformer decoder layer

  • draft_vocab_size

    Size of draft model vocabulary for speculation

  • norm_before_residual

    Apply hidden_norm before storing residual

Methods:

Attributes:

Source code in speculators/config.py
def __init__(self, **kwargs):
    """Run the Pydantic initialiser, then the PretrainedConfig initialiser.

    ``kwargs`` is reconciled between the two calls: Pydantic-validated field
    values replace the raw inputs, and ClassVar names are removed.
    """
    # Pydantic first: validates and sets every declared field
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    own_class = self.__class__

    # feed PretrainedConfig the values Pydantic settled on, so it doesn't
    # override them with the raw inputs
    validated = {name: getattr(self, name) for name in own_class.model_fields}
    kwargs.update(validated)

    # PretrainedConfig.__post_init__ would setattr ClassVars, which pydantic
    # blocks, so they must not be forwarded
    for class_var_name in own_class.__class_vars__:
        kwargs.pop(class_var_name, None)

    # initialize the Hugging Face PretrainedConfig arguments for the model
    PretrainedConfig.__init__(self, **kwargs)

    # record the transformers version installed right now
    self.transformers_version = version("transformers")

target_vocab_size property

target_vocab_size: int

Get target vocabulary size from transformer config.

serialize_transformer_config

serialize_transformer_config(
    value: PretrainedConfig,
) -> dict

Serialize transformer config to dict.

Source code in speculators/models/eagle3/config.py
@field_serializer("transformer_layer_config")
def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
    """Serialize the transformer layer config via its ``to_diff_dict()``."""
    as_dict = value.to_diff_dict()
    return as_dict

validate_transformer_config classmethod

validate_transformer_config(value: Any) -> PretrainedConfig

Validate and convert transformer config.

Source code in speculators/models/eagle3/config.py
@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
    """Convert a dict into a transformer config; return anything else as-is.

    With a ``model_type`` key the config class comes from
    ``AutoConfig.for_model``; without one, ``Qwen3Config`` is used.
    """
    if not isinstance(value, dict):
        # Already something other than a raw dict: leave it alone
        return value

    if "model_type" in value:
        probe = AutoConfig.for_model(model_type=value["model_type"])
        config_class: type[PretrainedConfig] = probe.__class__
    else:
        config_class = Qwen3Config
    return config_class(**value)

MTPDraftModel

MTPDraftModel(config: MTPSpeculatorConfig)

Bases: DraftVocabMixin, SpeculatorModel

MTP speculator model for multi-token prediction.

Predicts multiple future tokens (default: 3) per forward pass using a single layer with weighted multi-step loss for training.

embed_tokens and lm_head are managed by DraftVocabMixin — initialized to NaN, populated via load_verifier_weights() (called automatically by from_pretrained), and excluded from saved checkpoints. verifier_lm_head is created by DraftVocabMixin but not used in the MTP forward pass.

Methods:

  • forward

    Forward pass for MTP multi-token prediction (teacher-forced).

  • get_trainer_kwargs

    Get training and validation kwargs for MTP.

  • load_verifier_weights

    Re-set NaN sentinel before loading — meta-device init may clear it.

Attributes:

  • layers (ModuleList) –

    Expose mtp_layers for FSDP wrapping compatibility.

  • target_layer_ids (list[int]) –

    MTP only uses the last hidden layer (verifier_last_hidden_states).

Source code in speculators/models/mtp/core.py
def __init__(self, config: MTPSpeculatorConfig) -> None:
    """Build the MTP draft model from its speculator configuration.

    A missing attention implementation on the transformer layer config is
    filled in with ``"eager"`` (mutating ``config``). After base-class and
    vocabulary initialisation, a single MTP layer and the rotary embedding
    are created from the classes registered for the config's model type.

    Args:
        config: MTP speculator configuration.

    Raises:
        NotImplementedError: If ``use_draft_vocab`` is truthy after
            vocabulary initialisation.
    """
    if config.transformer_layer_config._attn_implementation is None:  # noqa: SLF001
        config.transformer_layer_config._attn_implementation = "eager"  # noqa: SLF001

    super().__init__(config=config)
    self._init_vocab(config)

    if self.use_draft_vocab:
        raise NotImplementedError(
            "Vocab reduction is not supported for MTP speculators"
        )

    layer_cfg = config.transformer_layer_config
    definitions = mtp_model_classes[resolve_model_type(layer_cfg.model_type)]
    self._model_definitions = definitions

    # Exactly one MTP layer is built, at layer index 0
    only_layer = definitions.first_layer_class(layer_cfg, layer_idx=0)
    self.mtp_layers = nn.ModuleList([only_layer])
    self.rotary_emb = definitions.rotary_emb_class(layer_cfg)

    self.post_init()

layers property

layers: ModuleList

Expose mtp_layers for FSDP wrapping compatibility.

target_layer_ids property

target_layer_ids: list[int]

MTP only uses the last hidden layer (verifier_last_hidden_states).

forward

forward(
    input_ids: Tensor,
    hidden_states: Tensor,
    attention_mask: Tensor | None = None,
    position_ids: Tensor | None = None,
    loss_mask: Tensor | None = None,
    step_weights: list[float] | None = None,
    return_dict: bool = True,
    **kwargs: Any,
) -> tuple

Forward pass for MTP multi-token prediction (teacher-forced).

At step k, uses ground-truth input_ids[t+k+1] as the embedding input and the MTP output from step k-1 (or verifier hidden states for step 0) as the hidden state input. Hidden states are passed recursively: each step's MTP output feeds the next step.

Targets are derived from input_ids via per-step offset slicing -- no separate label tensor is needed. Use loss_mask to exclude positions (e.g. prompt tokens) from the loss.

Parameters:

  • input_ids

    (Tensor) –

    Token IDs [batch, seq_len]. Serves as both the embedding source and the prediction target (offset by step+2).

  • hidden_states

    (Tensor) –

    Hidden states from verifier [batch, seq_len, hidden_size]

  • attention_mask

    (Tensor | None, default: None ) –

    Optional attention mask [batch, seq_len]

  • position_ids

    (Tensor | None, default: None ) –

    Optional position IDs [batch, seq_len]

  • loss_mask

    (Tensor | None, default: None ) –

    Optional binary mask [batch, seq_len]; 1=compute loss, 0=ignore.

  • step_weights

    (list[float] | None, default: None ) –

    Per-step loss weights (None = uniform). Training only.

  • return_dict

    (bool, default: True ) –

    Unused, kept for interface compatibility.

  • kwargs

    (Any, default: {} ) –

    Absorbs unexpected batch keys (lengths, verifier_last_hidden_states)

Returns:

  • tuple

    Tuple of (logits_list, loss, metrics)

Source code in speculators/models/mtp/core.py
@conditional_torch_compile
def forward(
    self,
    input_ids: torch.Tensor,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    loss_mask: torch.Tensor | None = None,
    step_weights: list[float] | None = None,
    return_dict: bool = True,  # noqa: ARG002
    **kwargs: Any,  # noqa: ARG002
) -> tuple:
    """Forward pass for MTP multi-token prediction (teacher-forced).

    At step k, uses ground-truth input_ids[t+k+1] as the embedding input and
    the MTP output from step k-1 (or verifier hidden states for step 0) as the
    hidden state input. Hidden states are passed recursively: each step's MTP
    output feeds the next step.

    Targets are derived from input_ids via per-step offset slicing -- no
    separate label tensor is needed. Use loss_mask to exclude positions
    (e.g. prompt tokens) from the loss.

    The number of steps actually run is capped at ``seq_len - 2``. When no
    step can run (``seq_len <= 2`` or zero configured steps) the method
    returns early with an empty logits list and a zero loss.

    :param input_ids: Token IDs [batch, seq_len]. Serves as both the
        embedding source and the prediction target (offset by step+2).
    :param hidden_states: Hidden states from verifier [batch, seq_len, hidden_size]
    :param attention_mask: Optional attention mask [batch, seq_len]
    :param position_ids: Optional position IDs [batch, seq_len]; defaults to
        ``0..seq_len-1`` for every batch row.
    :param loss_mask: Optional binary mask [batch, seq_len]; 1=compute loss,
        0=ignore.
    :param step_weights: Per-step loss weights (None = uniform). Training only.
    :param return_dict: Unused, kept for interface compatibility.
    :param kwargs: Absorbs unexpected batch keys
        (lengths, verifier_last_hidden_states)
    :return: Tuple of (logits_list, loss, metrics). ``logits_list`` holds one
        logits tensor per executed step, ``loss`` is the sum of the weighted
        per-step losses, and ``metrics`` contains ``loss_step_{k}`` for each
        executed step plus ``loss_sum`` and ``loss_total`` (always 1.0).
    :raises ValueError: If ``step_weights`` is given and its length differs
        from ``config.num_speculative_steps``.
    """
    # Cast once; the ids are used both as embedding indices and as
    # cross-entropy targets below.
    input_ids = input_ids.long()
    device = input_ids.device
    batch_size, seq_len = input_ids.shape
    num_steps = self.config.num_speculative_steps

    # Validate against the configured step count, not the capped one, so a
    # mis-sized weight list is rejected even for short sequences.
    if step_weights is not None and len(step_weights) != num_steps:
        raise ValueError(
            f"step_weights has {len(step_weights)} entries but "
            f"num_speculative_steps={num_steps}; expected exactly "
            f"{num_steps} weights."
        )

    if position_ids is None:
        position_ids = (
            torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        )

    all_logits: list[torch.Tensor] = []
    total_loss = torch.tensor(0.0, device=device)
    metrics: dict[str, float | torch.Tensor] = {}

    # Uniform valid_len keeps tensor shapes identical across loop
    # iterations, which torch.compile requires for stable codegen.
    # Cap steps so short sequences still produce partial results.
    effective_steps = min(num_steps, max(0, seq_len - 2))
    # With this length the final step's target slice
    # [step + 2 : step + 2 + valid_len] ends exactly at seq_len.
    valid_len = seq_len - effective_steps - 1
    if valid_len <= 0 or effective_steps == 0:
        # Nothing to predict: report a zero loss with the same metric keys
        # the normal path emits (minus the per-step entries).
        metrics["loss_sum"] = total_loss.detach().clone()
        metrics["loss_total"] = torch.tensor(1.0, device=device)
        return (all_logits, total_loss, metrics)

    # Positions and mask are built once and reused by every step.
    step_pos_ids = position_ids[:, :valid_len]
    # NOTE(review): attention_mask is passed at its full seq_len while
    # inputs_embeds is truncated to valid_len — confirm create_causal_mask
    # accepts a mask longer than the embeddings.
    causal_mask = create_causal_mask(
        config=self.config.transformer_layer_config,
        inputs_embeds=hidden_states[:, :valid_len],
        attention_mask=attention_mask,
        past_key_values=None,
        position_ids=step_pos_ids,
    )

    # Step 0 consumes the verifier hidden states; later steps consume the
    # previous step's MTP output.
    current_hidden = hidden_states
    for step in range(effective_steps):
        step_hidden = current_hidden[:, :valid_len]
        # Teacher forcing: embed the ground-truth tokens shifted by step + 1
        step_embeds = self.embed_tokens(
            input_ids[:, step + 1 : step + 1 + valid_len]
        )
        step_pos_emb = self.rotary_emb(step_hidden, step_pos_ids)

        mtp_output = self.mtp_layers[0](
            hidden_states=step_hidden,
            token_embeddings=step_embeds,
            attention_mask=causal_mask,
            position_ids=step_pos_ids,
            position_embeddings=step_pos_emb,
        )

        logits = self.lm_head(mtp_output)
        all_logits.append(logits)

        # Targets are the tokens one position beyond the embedded ones
        step_targets = input_ids[:, step + 2 : step + 2 + valid_len]
        if loss_mask is not None:
            step_mask = loss_mask[:, step + 2 : step + 2 + valid_len]
            # Clone before writing so input_ids itself is not modified
            step_targets = step_targets.clone()
            step_targets[step_mask == 0] = _IGNORE_INDEX
        weight = step_weights[step] if step_weights is not None else 1.0
        # Presumably logits are [batch, valid_len, vocab]; the permute moves
        # the last axis to position 1 for cross_entropy — TODO confirm.
        unreduced = nn.functional.cross_entropy(
            logits.permute(0, 2, 1),
            step_targets,
            ignore_index=_IGNORE_INDEX,
            reduction="none",
        )
        # Mean over non-ignored targets; the clamp avoids dividing by zero
        # when every target in this step is ignored.
        valid_count = (step_targets != _IGNORE_INDEX).sum()
        step_loss = weight * unreduced.sum() / valid_count.clamp(min=1)
        total_loss = total_loss + step_loss
        metrics[f"loss_step_{step}"] = step_loss.detach().clone()

        current_hidden = mtp_output

    metrics["loss_sum"] = total_loss.detach().clone()
    metrics["loss_total"] = torch.tensor(1.0, device=device)

    return (all_logits, total_loss, metrics)

get_trainer_kwargs staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get training and validation kwargs for MTP.

Step weights are computed from step_weight_beta and num_speculative_steps using the normalized exponential-decay formula from FastMTP (arXiv:2509.18362), Equation 2.

Pass step_weights to override the computed weights.

Source code in speculators/models/mtp/core.py
@staticmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Get training and validation kwargs for MTP.

    An explicit ``step_weights`` entry (anything other than None) is used
    as given. Otherwise the weights are computed from ``step_weight_beta``
    (0.6 when absent) and ``num_speculative_steps`` using the normalized
    exponential-decay formula from FastMTP (arXiv:2509.18362), Equation 2.

    Returns a tuple of (train_call_kwargs, val_call_kwargs); the validation
    dict is a shallow copy of the training dict.

    Raises ValueError when the weights must be computed but
    ``num_speculative_steps`` is missing from ``kwargs``.
    """
    explicit_weights = kwargs.get("step_weights")
    if explicit_weights is not None:
        step_weights = explicit_weights
    elif "num_speculative_steps" in kwargs:
        step_weights = compute_step_weights(
            beta=kwargs.get("step_weight_beta", 0.6),
            num_steps=kwargs["num_speculative_steps"],
        )
    else:
        raise ValueError(
            "num_speculative_steps must be set from the model config "
            "before calling get_trainer_kwargs"
        )

    train_kwargs: dict[str, Any] = {"step_weights": step_weights}
    return train_kwargs, dict(train_kwargs)

load_verifier_weights

load_verifier_weights() -> None

Re-set NaN sentinel before loading — meta-device init may clear it. Deletes verifier_lm_head after loading since MTP does not use it.

Source code in speculators/models/mtp/core.py
def load_verifier_weights(self) -> None:
    """Load verifier weights with the NaN sentinel freshly re-applied.

    The embedding and LM-head weights are filled with NaN first, because
    meta-device init may clear the sentinel. After the parent loader runs,
    verifier_lm_head is deleted since MTP does not use it.
    """
    with torch.no_grad():
        for sentinel_module in (self.embed_tokens, self.lm_head):
            sentinel_module.weight.fill_(torch.nan)
    super().load_verifier_weights()
    del self.verifier_lm_head

MTPSpeculatorConfig

MTPSpeculatorConfig(**kwargs)

Bases: SpeculatorModelConfig

Configuration for MTP (Multi-Token Prediction) speculator.

Architecture: a single MTP layer with attention and MLP, combining verifier hidden states with token embeddings via an explicit input projection. embed_tokens and lm_head share the verifier's full vocabulary.

Parameters:

  • transformer_layer_config

    Configuration for the underlying transformer architecture (e.g., Qwen2Config). All architecture dimensions are derived from this config.

  • num_nextn_predict_layers

    Number of MTP prediction heads in the checkpoint. vLLM reads this field directly to instantiate the correct number of MTP head instances. Currently only 1 is supported.

Source code in speculators/config.py
def __init__(self, **kwargs):
    """Initialise both bases: Pydantic first, then Hugging Face PretrainedConfig.

    Before the second call ``kwargs`` is rewritten so that every Pydantic
    field carries its validated value and no ClassVar name is forwarded.
    """
    # Pydantic takes the raw kwargs and sets all valid fields
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    config_cls = self.__class__

    # replace raw inputs with the validated field values so
    # PretrainedConfig doesn't override what Pydantic set
    for field_name in config_cls.model_fields:
        kwargs[field_name] = getattr(self, field_name)

    # ClassVars must not reach PretrainedConfig.__post_init__, which would
    # setattr them (pydantic blocks setattr on ClassVar names)
    for class_var_name in config_cls.__class_vars__:
        kwargs.pop(class_var_name, None)

    # initialize the Hugging Face PretrainedConfig arguments for the model
    PretrainedConfig.__init__(self, **kwargs)

    # keep the recorded transformers version in step with the installed one
    self.transformers_version = version("transformers")

PEagleDraftModel

PEagleDraftModel(config: PEagleSpeculatorConfig)

Bases: Eagle3DraftModel

P-EAGLE (Parallel EAGLE) draft model for speculative decoding.

P-EAGLE extends EAGLE-3 with parallel multi-token prediction using Conditional-On-Distribution (COD) sampling for memory-efficient training.

Methods:

  • forward

    Forward pass for P-EAGLE model training with parallel group prediction.

  • from_training_args

    Create P-EAGLE model from training arguments.

  • get_trainer_kwargs

    Get training and validation kwargs for P-EAGLE.

Source code in speculators/models/peagle/core.py
def __init__(
    self,
    config: PEagleSpeculatorConfig,
):
    """
    Build the P-EAGLE draft model on top of the parent initializer.

    Args:
        config: P-EAGLE speculator configuration. Its sampling settings and
            mask token ID are copied onto the model.
    """
    super().__init__(config=config)

    # Mirror the P-EAGLE specific settings from the config onto the model
    self.mask_token_id = config.mask_token_id
    self.num_depths = config.num_depths
    self.down_sample_ratio_min = config.down_sample_ratio_min
    self.down_sample_ratio = config.down_sample_ratio

    # Learnable mask_hidden parameter for padding unsampled positions,
    # randomly initialised with a last dimension of 3 * hidden_size
    mask_hidden_init = torch.randn(1, 1, 3 * self.hidden_size)
    self.mask_hidden = torch.nn.Parameter(mask_hidden_init)

forward

forward(
    hidden_states: Tensor,
    input_ids: Tensor,
    document_ids: Tensor,
    position_ids: Tensor | None = None,
    loss_mask: Tensor | None = None,
    verifier_last_hidden_states: Tensor | None = None,
    loss_fn=kl_div_loss,
    **kwargs,
)

Forward pass for P-EAGLE model training with parallel group prediction.

Args:

  • hidden_states: Verifier hidden states [batch, seq_len, 3*hidden_size]
  • input_ids: Input token IDs [batch, seq_len]
  • document_ids: Document IDs [1, seq_len]; maps positions to document index, padded with -1
  • position_ids: Position IDs [batch, seq_len] (optional)
  • loss_mask: Loss mask selecting which tokens to compute the loss on [batch, seq_len]
  • verifier_last_hidden_states: Verifier final hidden states used for targets [batch, seq_len, hidden_size]

Returns: Tuple of (draft_tokens, loss, metrics)

Source code in speculators/models/peagle/core.py
@conditional_torch_compile
def forward(
    self,
    hidden_states: torch.Tensor,
    input_ids: torch.Tensor,
    document_ids: torch.Tensor,
    position_ids: torch.Tensor | None = None,
    loss_mask: torch.Tensor | None = None,
    verifier_last_hidden_states: torch.Tensor | None = None,
    loss_fn=kl_div_loss,
    **kwargs,
):
    """
    Forward pass for P-EAGLE model training with parallel group prediction.

    Samples (anchor, depth) pairs via COD sampling, builds one input slot per
    sampled pair (real token/hidden state at depth 0, mask token/mask_hidden
    otherwise), runs the draft layers over those slots and scores the
    resulting logits against the verifier's logits.

    Only batch row 0 of ``input_ids`` and ``hidden_states`` is read (both are
    indexed with ``[0, ...]``), which appears to assume a batch size of 1
    (TODO confirm with callers).

    Args:
        hidden_states: Verifier hidden states [batch, seq_len, 3*hidden_size]
        input_ids: Input token IDs [batch, seq_len]
        document_ids: Document IDs [1, seq_len], maps positions to doc index, pad -1
        position_ids: Position IDs [batch, seq_len] (optional). The incoming
            value is never read: it is overwritten with the original
            positions of the sampled slots before first use.
        loss_mask: Loss mask for which tokens to compute loss on
            [batch, seq_len]. Defaults to an all-ones float32 mask shaped
            like ``input_ids``.
        verifier_last_hidden_states: Verifier final hidden states for
            targets [batch, seq_len, hidden_size]. Required despite the
            ``None`` default.
        loss_fn: Loss function handed to ``compute_metrics``.
        **kwargs: Forwarded unchanged to every draft layer call.

    Returns:
        Tuple of (draft_tokens, loss, metrics); ``draft_tokens`` is always
        ``None`` in this training forward pass.

    Raises:
        ValueError: If ``verifier_last_hidden_states`` is ``None``.
    """
    if verifier_last_hidden_states is None:
        raise ValueError("verifier_last_hidden_states required for training")

    device = hidden_states.device
    seq_length = input_ids.shape[1]

    if loss_mask is None:
        loss_mask = torch.ones_like(input_ids, dtype=torch.float32)

    # Generate COD sampling indices
    # (presumably two equal-length 1-D index tensors; confirm in helper)
    anchor_pos, depth = generate_cod_sample_indices(
        seq_length=seq_length,
        loss_mask=loss_mask,
        num_depths=self.num_depths,
        down_sample_ratio=self.down_sample_ratio,
        down_sample_ratio_min=self.down_sample_ratio_min,
    )
    total_sampled = anchor_pos.shape[0]

    # Position in the original sequence of each sampled slot
    orig_positions = anchor_pos + depth
    is_depth_0 = depth == 0  # [total_sampled]

    # Build sampled input_ids: real tokens for depth 0, mask for others
    sampled_ids = torch.where(
        is_depth_0,
        input_ids[0, orig_positions],
        torch.tensor(self.mask_token_id, dtype=input_ids.dtype, device=device),
    ).unsqueeze(0)  # [1, total_sampled]
    inputs_embeds = self.embed_tokens(sampled_ids).to(
        hidden_states.dtype
    )  # [1, total_sampled, hidden_size]

    # Build sampled hidden states: real for depth 0, mask_hidden for others
    mask_hidden = self.mask_hidden.to(device=device, dtype=hidden_states.dtype)
    sampled_hidden = torch.where(
        is_depth_0.unsqueeze(-1),
        hidden_states[0, orig_positions],
        mask_hidden.squeeze(0).expand(orig_positions.shape[0], -1),
    ).unsqueeze(0)  # [1, total_sampled, 3*hidden_size]

    # Project concatenated hidden states (3*hidden_size) -> hidden_size
    sampled_hidden = self.fc(sampled_hidden)  # [1, total_sampled, hidden_size]

    layer_input = torch.cat(
        [inputs_embeds, sampled_hidden], dim=-1
    )  # [1, total_sampled, 2*hidden_size]

    # NOTE: this replaces the position_ids argument; the caller's value is
    # never used
    position_ids = orig_positions.unsqueeze(0)  # [1, total_sampled]

    position_embeddings = self.rotary_emb(layer_input, position_ids)

    mask_mod = create_peagle_mask_mod(
        anchor_pos=anchor_pos,
        depth=depth,
        document_ids=document_ids.squeeze(0).to(device),
    )

    # The mask spans the sampled slots only (Q_LEN == KV_LEN == total_sampled)
    attention_mask = self._create_mask_fn(
        mask_mod,
        B=None,
        H=None,
        Q_LEN=total_sampled,
        KV_LEN=total_sampled,
        device=device,
    )

    # Run the draft layers; the same mask and rotary embeddings are reused
    # for every layer
    hidden_states = layer_input
    for layer in self.layers:
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            **kwargs,
        )

    logits = self.lm_head(
        self.norm(hidden_states)
    )  # [1, total_sampled, vocab_size]

    # Targets come from the verifier's norm + lm_head, computed without
    # gradient tracking
    with torch.no_grad():
        targets = self.verifier_lm_head(
            self.verifier_norm(verifier_last_hidden_states)
        )

    # Keep only the target rows that correspond to the sampled slots
    targets = targets[:, orig_positions, :]  # [1, total_sampled, vocab_size]

    loss, metrics = compute_metrics(
        logits=logits,
        targets=targets,
        loss_mask=loss_mask,
        anchor_pos=anchor_pos,
        depth=depth,
        num_depths=self.num_depths,
        loss_fn=loss_fn,
    )

    # No draft tokens are produced here, so the first element is None
    return None, loss, metrics

from_training_args classmethod

from_training_args(
    verifier_config: PretrainedConfig,
    t2d: Tensor | None = None,
    d2t: Tensor | None = None,
    **kwargs,
) -> PEagleDraftModel

Create P-EAGLE model from training arguments.

Args:

  • verifier_config: Verifier model configuration
  • t2d: Target-to-draft vocabulary mapping
  • d2t: Draft-to-target vocabulary mapping
  • **kwargs: Training arguments with P-EAGLE-specific params:
      - draft_vocab_size: Size of draft vocabulary
      - norm_before_residual: Whether to normalize before residual
      - num_depths: Number of parallel groups (default 8)
      - down_sample_ratio: COD sampling ratio (default 0.7)
      - down_sample_ratio_min: Minimum sampling ratio (default 0.2)
      - mask_token_id: Mask token ID
      - verifier_name_or_path: Path to verifier model

Returns: Initialized PEagleDraftModel

Source code in speculators/models/peagle/core.py
@classmethod
def from_training_args(
    cls,
    verifier_config: PretrainedConfig,
    t2d: torch.Tensor | None = None,
    d2t: torch.Tensor | None = None,
    **kwargs,
) -> "PEagleDraftModel":
    """
    Create P-EAGLE model from training arguments.

    Note that ``verifier_config`` is modified in place: its attention
    implementation is set from ``draft_attn_impl``.

    Args:
        verifier_config: Verifier model configuration
        t2d: Target-to-draft vocabulary mapping
        d2t: Draft-to-target vocabulary mapping
        **kwargs: Training arguments with P-EAGLE-specific params
            - verifier_name_or_path: Path to verifier model (required)
            - draft_vocab_size: Size of draft vocabulary (required)
            - target_layer_ids: Target layer IDs; resolved via
              ``resolve_target_layer_ids`` when absent
            - draft_attn_impl: Attention implementation for the draft
              model (default "simple_flex_attention")
            - norm_before_residual: Whether to normalize before residual
              (default False)
            - num_depths: Number of parallel groups (default 8); also used
              as the number of speculative tokens of the greedy proposal
            - down_sample_ratio: COD sampling ratio (default 0.7)
            - down_sample_ratio_min: Minimum sampling ratio (default 0.2)
            - mask_token_id: Mask token ID (default None)

    Returns:
        Initialized PEagleDraftModel

    Raises:
        KeyError: If ``verifier_name_or_path`` or ``draft_vocab_size`` is
            missing from ``kwargs``.
    """
    verifier_path = kwargs["verifier_name_or_path"]

    # Resolve target layer IDs if not provided
    target_layer_ids = resolve_target_layer_ids(
        kwargs.get("target_layer_ids"), verifier_path
    )

    attn_impl = kwargs.get("draft_attn_impl", "simple_flex_attention")
    verifier_config._attn_implementation = attn_impl  # noqa: SLF001

    # Looked up before any sub-config is built so that a missing key fails
    # ahead of the verifier config being loaded
    draft_vocab_size = kwargs["draft_vocab_size"]
    num_depths = kwargs.get("num_depths", 8)

    greedy_proposal = GreedyTokenProposalConfig(
        speculative_tokens=num_depths,
    )
    speculators_config = SpeculatorsConfig(
        algorithm="peagle",
        proposal_methods=[greedy_proposal],
        default_proposal_method="greedy",
        verifier=VerifierConfig.from_pretrained(verifier_path),
    )

    config = PEagleSpeculatorConfig(
        transformer_layer_config=verifier_config,
        draft_vocab_size=draft_vocab_size,
        norm_before_residual=kwargs.get("norm_before_residual", False),
        eagle_aux_hidden_state_layer_ids=target_layer_ids,
        num_depths=num_depths,
        down_sample_ratio=kwargs.get("down_sample_ratio", 0.7),
        down_sample_ratio_min=kwargs.get("down_sample_ratio_min", 0.2),
        mask_token_id=kwargs.get("mask_token_id"),
        speculators_config=speculators_config,
    )

    # Instantiate, then attach vocabulary mappings and verifier weights
    draft_model = cls(config=config)
    draft_model.load_vocab_mappings(t2d, d2t)
    draft_model.load_verifier_weights()
    return draft_model

get_trainer_kwargs staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get training and validation kwargs for P-EAGLE.

Args: **kwargs: Training arguments

Returns: Tuple of (train_call_kwargs, val_call_kwargs)

Source code in speculators/models/peagle/core.py
@staticmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """
    Get training and validation kwargs for P-EAGLE.

    Both returned dicts carry the same resolved loss function under the
    ``"loss_fn"`` key; they are separate dict objects.

    Args:
        **kwargs: Training arguments; ``loss_fn`` is required and is
            resolved through ``resolve_loss_fn``.

    Returns:
        Tuple of (train_call_kwargs, val_call_kwargs)
    """
    resolved_loss = resolve_loss_fn(kwargs["loss_fn"])
    train_call_kwargs = {"loss_fn": resolved_loss}
    # Copy so that train and validation kwargs stay independent objects
    val_call_kwargs = dict(train_call_kwargs)
    return train_call_kwargs, val_call_kwargs

PEagleSpeculatorConfig

PEagleSpeculatorConfig(**kwargs)

Bases: Eagle3SpeculatorConfig

Configuration for P-EAGLE (Parallel EAGLE) speculator.

P-EAGLE extends EAGLE-3 with parallel multi-token prediction using Conditional Drop Token (COD) sampling for memory-efficient training.

Parameters:

  • num_depths

    Number of parallel prediction groups (typically 8)

  • down_sample_ratio

    Geometric decay ratio for COD sampling (r in [0,1])

  • down_sample_ratio_min

    Minimum retention ratio floor

  • mask_token_id

    Token ID used for masking

Source code in speculators/config.py
def __init__(self, **kwargs):
    """
    Initialize the config through both of its base initializers.

    Pydantic runs first so that every declared field is set; the resulting
    field values are then handed to ``PretrainedConfig`` together with any
    remaining keyword arguments.

    Args:
        **kwargs: Field values and any extra ``PretrainedConfig`` arguments.
    """
    # Pydantic goes first so all valid fields are populated on the instance
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    config_cls = self.__class__

    # Write the Pydantic-handled values back into kwargs so that
    # PretrainedConfig does not override them with the raw inputs
    kwargs.update({name: getattr(self, name) for name in config_cls.model_fields})

    # ClassVar names must not reach PretrainedConfig.__post_init__: it would
    # try to setattr them, and pydantic blocks setattr on ClassVar names
    for class_var in config_cls.__class_vars__:
        kwargs.pop(class_var, None)

    # Hugging Face side of the initialization
    PretrainedConfig.__init__(self, **kwargs)

    # Always record the transformers version that is currently installed
    self.transformers_version = version("transformers")