speculators

Speculators: A Unified Library for Speculative Decoding Algorithms for LLMs

Speculators provides a standardized framework for creating, representing, and storing speculative decoding algorithms for large language model (LLM) inference. It enables developers to implement and productize various speculative decoding approaches with a consistent interface, making them ready for integration with LLM inference servers like vLLM.

Speculative decoding is a technique that can significantly reduce LLM inference latency: a smaller, speculative model predicts multiple tokens, and the original, larger model then verifies those predictions in a single forward pass. This approach trades off extra computation for reduced latency, making it suitable for real-time applications on deployments that are not compute-constrained.
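The draft-then-verify loop described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the Speculators API: the two models are hypothetical toy functions that return a greedy next token, acceptance uses exact greedy matching (sampling-based acceptance rules are not shown), and the verifier is called once per position for readability, whereas a real server scores all proposed positions in one forward pass.

```python
from typing import Callable, List

# Hypothetical toy "model": maps a token prefix to its greedy next token.
NextToken = Callable[[List[int]], int]


def speculative_step(
    prefix: List[int], draft: NextToken, verifier: NextToken, k: int
) -> List[int]:
    """Run one draft-then-verify step and return the tokens emitted this step."""
    # 1. The cheap draft model proposes k tokens autoregressively.
    proposal: List[int] = []
    for _ in range(k):
        proposal.append(draft(prefix + proposal))

    # 2. The verifier checks each proposed position in order.
    accepted: List[int] = []
    for token in proposal:
        expected = verifier(prefix + accepted)
        if token != expected:
            # First mismatch: keep the verifier's token and stop this step.
            accepted.append(expected)
            return accepted
        accepted.append(token)

    # 3. Every proposal matched, so the verifier's next token is a free bonus.
    accepted.append(verifier(prefix + accepted))
    return accepted
```

With a draft that is wrong only after token 3, a four-token proposal from `[0]` keeps the three matching tokens and replaces the fourth with the verifier's choice, so one step emits `[1, 2, 3, 4]`.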

The library offers a modular architecture with components for:

  • Standardized interfaces for working with speculative decoding algorithms that build on top of Transformers pathways for simple integration.
  • Centralized definition, configuration, and validation of speculative decoding algorithms.

Modules:

  • config

    Configuration classes for the Speculators library.

  • convert

    Checkpoint conversion utilities for Speculators.

  • data_generation

    Data generation utilities for EAGLE-style speculative decoding training.

  • model

    Base model classes for the Speculators library.

  • models
  • proposals
  • train
  • utils

Functions:

  • reload_schemas

    Automatically populates the registry for all PydanticClassRegistryMixin subclasses

Eagle3DraftModel

Eagle3DraftModel(config: Eagle3SpeculatorConfig)

Bases: DraftVocabMixin, SpeculatorModel

Source code in speculators/models/eagle3/core.py
def __init__(self, config: Eagle3SpeculatorConfig):
    # Forcibly override config settings
    impl = "simple_flex_attention"
    config.transformer_layer_config._attn_implementation = impl  # noqa: SLF001
    super().__init__(config=config)
    self._init_vocab(config)

    tl_config = self.config.transformer_layer_config
    self._model_definitions = model_classes[tl_config.model_type]

    # Eagle3-specific: embed_tokens grad depends on config
    self.embed_tokens.weight.requires_grad = self.config.embed_requires_grad

    # FC LAYER
    self.fc = torch.nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False)

    # DECODER LAYERS
    num_layers = tl_config.num_hidden_layers
    fl_class = self._model_definitions.first_layer_class
    dl_class = self._model_definitions.decoder_layer_class
    layers = [
        fl_class(  # first layer
            tl_config,
            layer_idx=0,
            norm_before_residual=self.config.norm_before_residual,
        )
    ]
    layers.extend(  # remaining layers
        [dl_class(tl_config, layer_idx) for layer_idx in range(1, num_layers)]
    )
    self.layers = torch.nn.ModuleList(layers)

    # ROTARY EMBEDDINGS
    # Create a modified config for the rotary embedding to use 2x the hidden size
    modified_tl_config = copy.copy(config.transformer_layer_config)
    modified_tl_config.hidden_size *= 2
    self.rotary_emb = self._model_definitions.rotary_emb_class(modified_tl_config)

    # LAYER NORMS
    norm_class = self._model_definitions.norm_class
    self.norm = norm_class(
        self.hidden_size, eps=config.transformer_layer_config.rms_norm_eps
    )
    self.verifier_norm = norm_class(self.hidden_size, eps=tl_config.rms_norm_eps)
    self.verifier_norm.weight.requires_grad = False

    # Normalize draft path input (gpt-oss only)
    if config.norm_before_fc:
        self.input_norm = self._model_definitions.norm_class(
            3 * self.hidden_size,
            eps=config.transformer_layer_config.rms_norm_eps,
        )
    else:
        self.input_norm = None

    self.post_init()
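The FC layer above maps `3 * hidden_size` inputs to `hidden_size` outputs. A reasonable reading, given that `from_training_args` sets `eagle_aux_hidden_state_layer_ids`, is that the three inputs are auxiliary hidden states captured from verifier layers and concatenated before projection; treat that as an assumption. The sketch below shows only the shape arithmetic in pure Python, with a hypothetical `fuse_aux_hidden_states` helper standing in for `torch.nn.Linear(3 * H, H, bias=False)`.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # H rows, each with 3*H columns


def fuse_aux_hidden_states(aux_states: List[Vector], fc_weight: Matrix) -> Vector:
    """Concatenate three hidden states (3*H values) and project them to H values.

    Same shape logic as torch.nn.Linear(3 * H, H, bias=False): the weight has
    H rows and 3*H columns, and the output is weight @ concat with no bias.
    """
    if len(aux_states) != 3:
        raise ValueError("expected exactly three auxiliary hidden states")
    hidden_size = len(aux_states[0])
    if any(len(h) != hidden_size for h in aux_states):
        raise ValueError("all hidden states must have the same size")
    if len(fc_weight) != hidden_size or any(
        len(row) != 3 * hidden_size for row in fc_weight
    ):
        raise ValueError("fc_weight must have shape (H, 3*H)")

    # Concatenate along the feature dimension -> length 3*H.
    concat = [x for h in aux_states for x in h]
    # One dot product per output feature; no bias term is added.
    return [sum(w * x for w, x in zip(row, concat)) for row in fc_weight]
```

With `H = 2`, a weight that sums matching positions across the three states turns `[1, 2]`, `[3, 4]`, `[5, 6]` into `[9, 12]`.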

from_training_args classmethod

from_training_args(
    verifier_config: PretrainedConfig,
    t2d: Tensor | None = None,
    d2t: Tensor | None = None,
    **kwargs,
) -> Eagle3DraftModel

Create Eagle3 model from training arguments.

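Args:

  • verifier_config: Verifier model configuration
  • t2d: Optional target-to-draft vocabulary mapping tensor, passed to load_vocab_mappings()
  • d2t: Optional draft-to-target vocabulary mapping tensor, passed to load_vocab_mappings()
  • **kwargs: Training arguments with Eagle3-specific params:
      - verifier_name_or_path: Path to verifier model (required)
      - draft_vocab_size: Size of the draft vocabulary (required)
      - norm_before_residual: Whether to normalize before residual connection (required)
      - ttt_steps: Number of TTT (training-time test) steps, also used as speculative_tokens for the greedy proposal method (required)
      - target_layer_ids: Verifier layer IDs for the auxiliary hidden states, resolved with resolve_target_layer_ids() (optional)
      - norm_before_fc: Whether to normalize the draft input before the FC layer (optional, default False)
      - embed_requires_grad: Whether embed_tokens is trainable (optional, default False)
      - num_layers: Number of decoder layers (not read directly by the code shown; the layer count comes from transformer_layer_config.num_hidden_layers)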

Returns: Initialized Eagle3DraftModel

Source code in speculators/models/eagle3/core.py
@classmethod
def from_training_args(
    cls,
    verifier_config: PretrainedConfig,
    t2d: torch.Tensor | None = None,
    d2t: torch.Tensor | None = None,
    **kwargs,
) -> "Eagle3DraftModel":
    """Create Eagle3 model from training arguments.

    Args:
        verifier_config: Verifier model configuration
        **kwargs: Training arguments with Eagle3-specific params
            - num_layers: Number of decoder layers
            - norm_before_residual: Whether to normalize before residual connection
            - t2d: Target-to-draft vocabulary mapping tensor
            - d2t: Draft-to-target vocabulary mapping tensor
            - ttt_steps: Number of TTT steps
            - verifier_name_or_path: Path to verifier model

    Returns:
        Initialized Eagle3DraftModel
    """
    target_layer_ids = resolve_target_layer_ids(
        kwargs.get("target_layer_ids"),
        kwargs["verifier_name_or_path"],
    )

    config = Eagle3SpeculatorConfig(
        transformer_layer_config=verifier_config,
        draft_vocab_size=kwargs["draft_vocab_size"],
        norm_before_residual=kwargs["norm_before_residual"],
        norm_before_fc=kwargs.get("norm_before_fc", False),
        embed_requires_grad=kwargs.get("embed_requires_grad", False),
        eagle_aux_hidden_state_layer_ids=target_layer_ids,
        speculators_config=SpeculatorsConfig(
            algorithm="eagle3",
            proposal_methods=[
                GreedyTokenProposalConfig(
                    speculative_tokens=kwargs["ttt_steps"],
                )
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig.from_config(
                verifier_config, name_or_path=kwargs["verifier_name_or_path"]
            ),
        ),
    )
    model = cls(config=config)
    model.load_vocab_mappings(t2d, d2t)
    model.load_verifier_weights()
    return model
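The keys that `from_training_args` reads can be checked up front. This hypothetical helper is not part of Speculators; it only restates what the source above does: `verifier_name_or_path`, `draft_vocab_size`, `norm_before_residual`, and `ttt_steps` are read with `kwargs[...]` and raise `KeyError` when missing, while `target_layer_ids`, `norm_before_fc`, and `embed_requires_grad` are read with `kwargs.get(...)` and fall back to `None`, `False`, and `False`.

```python
from typing import Any, Dict

# Keys the source reads with kwargs[...]: a missing key raises KeyError.
REQUIRED_KEYS = (
    "verifier_name_or_path",
    "draft_vocab_size",
    "norm_before_residual",
    "ttt_steps",
)
# Keys the source reads with kwargs.get(...), with the fallback it uses.
OPTIONAL_DEFAULTS: Dict[str, Any] = {
    "target_layer_ids": None,
    "norm_before_fc": False,
    "embed_requires_grad": False,
}


def check_eagle3_training_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: fail early on missing keys and fill optional defaults."""
    missing = [key for key in REQUIRED_KEYS if key not in kwargs]
    if missing:
        raise KeyError(f"missing required training arguments: {missing}")
    # Explicit values win over the defaults.
    return {**OPTIONAL_DEFAULTS, **kwargs}
```

`verifier_config`, `t2d`, and `d2t` are separate parameters of `from_training_args`, so they are not part of this dictionary.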

get_trainer_kwargs staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get training and validation kwargs for Eagle3.

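Args: **kwargs: Training arguments. Must include loss_fn, use_off_policy_tokens, ttt_steps, and ttt_step_loss_decay.

Returns: Tuple of (train_call_kwargs, val_call_kwargs). Both share ttt_steps, ttt_step_loss_decay, and the resolved loss_fn; the validation kwargs always set use_off_policy_tokens to False.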

Source code in speculators/models/eagle3/core.py
@staticmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Get training and validation kwargs for Eagle3.

    Args:
        **kwargs: Training arguments

    Returns:
        Tuple of (train_call_kwargs, val_call_kwargs)
    """
    loss_fn = resolve_loss_fn(kwargs["loss_fn"])
    train_kwargs = {
        "use_off_policy_tokens": kwargs["use_off_policy_tokens"],
        "ttt_steps": kwargs["ttt_steps"],
        "ttt_step_loss_decay": kwargs["ttt_step_loss_decay"],
        "loss_fn": loss_fn,
    }
    val_kwargs = {
        "use_off_policy_tokens": False,
        "ttt_steps": kwargs["ttt_steps"],
        "ttt_step_loss_decay": kwargs["ttt_step_loss_decay"],
        "loss_fn": loss_fn,
    }
    return train_kwargs, val_kwargs
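The split performed by `get_trainer_kwargs` can be reproduced without the library. In this sketch the `LOSS_FNS` table and its `"mse"` entry are hypothetical stand-ins for `resolve_loss_fn`, whose supported names are not listed here; the point is that both dictionaries share `ttt_steps`, `ttt_step_loss_decay`, and a single resolved loss function, while validation always disables off-policy tokens.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical stand-in for resolve_loss_fn; the real names are not shown here.
LOSS_FNS: Dict[str, Callable[[float, float], float]] = {
    "mse": lambda pred, target: (pred - target) ** 2,
}


def split_trainer_kwargs(**kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Mirror the train/val split performed by Eagle3DraftModel.get_trainer_kwargs."""
    loss_fn = LOSS_FNS[kwargs["loss_fn"]]  # resolved once, shared by both dicts
    shared = {
        "ttt_steps": kwargs["ttt_steps"],
        "ttt_step_loss_decay": kwargs["ttt_step_loss_decay"],
        "loss_fn": loss_fn,
    }
    train_kwargs = {"use_off_policy_tokens": kwargs["use_off_policy_tokens"], **shared}
    # Validation never uses off-policy tokens, regardless of the training flag.
    val_kwargs = {"use_off_policy_tokens": False, **shared}
    return train_kwargs, val_kwargs
```

Per the base-class docstring, a trainer passes the first dictionary to `model.forward()` during training and the second during validation.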

Eagle3SpeculatorConfig

Eagle3SpeculatorConfig(**kwargs)

Bases: SpeculatorModelConfig

Configuration for EAGLE-3 speculator with vocabulary mapping.

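EAGLE-3 maps between a reduced draft vocabulary (for example, 32K tokens) and the target vocabulary (for example, 128K tokens) using the t2d and d2t mapping tensors, so the draft model predicts over fewer tokens while its outputs can still be translated back to target token IDs.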

Parameters:

  • transformer_layer_config

    Configuration for the transformer decoder layer

  • draft_vocab_size

    Size of draft model vocabulary for speculation

  • norm_before_residual

    Apply hidden_norm before storing residual
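The draft/target vocabulary mapping can be illustrated with a toy vocabulary. This sketch assumes a simplified encoding in which `d2t` lists the target ID of each draft ID and `t2d` is a boolean mask over the target vocabulary; the tensors that `load_vocab_mappings()` actually expects may use a different layout (some EAGLE-3 implementations store `d2t` as per-token offsets), so check the format before reusing it. The helper names are hypothetical.

```python
from typing import List, Tuple


def build_vocab_mappings(
    d2t: List[int], target_vocab_size: int
) -> Tuple[List[int], List[bool]]:
    """Derive a target-to-draft membership mask from a draft-to-target index list.

    Simplified encoding (an assumption, not necessarily the library's tensor
    format): d2t[i] is the target token id of draft token id i, and t2d[j] is
    True when target token id j is part of the draft vocabulary.
    """
    if len(set(d2t)) != len(d2t):
        raise ValueError("each draft id must map to a distinct target id")
    if any(not 0 <= t < target_vocab_size for t in d2t):
        raise ValueError("target ids must lie in [0, target_vocab_size)")
    t2d = [False] * target_vocab_size
    for target_id in d2t:
        t2d[target_id] = True
    return list(d2t), t2d


def draft_to_target_id(draft_id: int, d2t: List[int]) -> int:
    """Translate a token predicted by the draft head into a verifier token id."""
    return d2t[draft_id]
```

Here a 10-token target vocabulary is reduced to a 4-token draft vocabulary `[0, 2, 5, 7]`, and draft ID 2 maps back to target ID 5.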

Source code in speculators/config.py
def __init__(self, **kwargs):
    # initialize the Pydantic arguments first to set all valid fields
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    # reset kwargs handled by Pydantic so PretrainedConfig doesn't override
    for field in self.__class__.model_fields:
        kwargs[field] = getattr(self, field)

    # strip ClassVars so PretrainedConfig.__post_init__ doesn't try to
    # setattr them (pydantic blocks setattr on ClassVar names)
    class_vars = self.__class__.__class_vars__
    for cv in class_vars:
        kwargs.pop(cv, None)

    # initialize the Hugging Face PretrainedConfig arguments for the model
    PretrainedConfig.__init__(self, **kwargs)

    # ensure we always update the transformers version
    self.transformers_version = version("transformers")

target_vocab_size property

target_vocab_size: int

Get target vocabulary size from transformer config.

serialize_transformer_config

serialize_transformer_config(
    value: PretrainedConfig,
) -> dict

Serialize transformer config to dict.

Source code in speculators/models/eagle3/config.py
@field_serializer("transformer_layer_config")
def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
    """Serialize transformer config to dict."""
    return value.to_diff_dict()

validate_transformer_config classmethod

validate_transformer_config(value: Any) -> PretrainedConfig

Validate and convert transformer config.

Source code in speculators/models/eagle3/config.py
@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
    """Validate and convert transformer config."""
    if isinstance(value, dict):
        config_class: type[PretrainedConfig] = LlamaConfig
        if "model_type" in value:
            config_class = AutoConfig.for_model(
                model_type=value["model_type"]
            ).__class__
        return config_class(**value)
    return value
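The serializer and validator above form a round trip: `to_diff_dict()` keeps only non-default values, and the validator rebuilds a config object from a dict, choosing the class from `model_type` and falling back to `LlamaConfig`. The sketch mimics that behavior with toy dataclasses instead of transformers classes; the class names, defaults, and the `CONFIG_CLASSES` lookup (a stand-in for `AutoConfig.for_model`) are assumptions for illustration. Keeping `model_type` in the serialized dict is what lets the validator pick the right class again.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ToyLlamaConfig:
    # Toy stand-in for LlamaConfig; the defaults are illustrative only.
    model_type: str = "llama"
    hidden_size: int = 64
    rms_norm_eps: float = 1e-5


@dataclass
class ToyOtherConfig:
    # Toy stand-in for another architecture's config class.
    model_type: str = "toy_other"
    hidden_size: int = 32
    rms_norm_eps: float = 1e-6


# Stand-in for AutoConfig.for_model(model_type=...).__class__
CONFIG_CLASSES = {"llama": ToyLlamaConfig, "toy_other": ToyOtherConfig}


def validate_layer_config(value: Any) -> Any:
    """Turn a dict into a config object, dispatching on model_type (Llama default)."""
    if isinstance(value, dict):
        config_class = ToyLlamaConfig
        if "model_type" in value:
            config_class = CONFIG_CLASSES[value["model_type"]]
        return config_class(**value)
    return value  # already a config object: pass it through unchanged


def serialize_layer_config(value: Any) -> Dict[str, Any]:
    """Keep model_type plus every field that differs from the class defaults."""
    defaults = type(value)()
    return {
        f.name: getattr(value, f.name)
        for f in fields(value)
        if f.name == "model_type" or getattr(value, f.name) != getattr(defaults, f.name)
    }
```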

SpeculatorModel

SpeculatorModel(config: SpeculatorModelConfig, **kwargs)

Bases: ClassRegistryMixin, PreTrainedModel

Abstract base class for all speculator models in the Speculators library.

This class provides the foundation for implementing speculative decoding models that can generate candidate tokens to be verified by a base verifier model. It combines the functionality of Hugging Face's PreTrainedModel and GenerationMixin with automatic model registration and discovery capabilities. All concrete speculator model implementations must inherit from this class, register with SpeculatorModel.register(NAME), and implement the abstract forward method.

Example:

# Load a speculator model with automatic class resolution
model = SpeculatorModel.from_pretrained("path/to/speculator")

Initialize a SpeculatorModel instance.

Parameters:

  • config

    (SpeculatorModelConfig) –

    The configuration for the speculator model. Must be a SpeculatorModelConfig instance containing model hyperparameters and speculative decoding settings.

  • kwargs

    Additional keyword arguments passed to the parent PreTrainedModel constructor.

Source code in speculators/model.py
def __init__(self, config: SpeculatorModelConfig, **kwargs):
    """
    Initialize a SpeculatorModel instance.

    :param config: The configuration for the speculator model. Must be a
        SpeculatorModelConfig instance containing model hyperparameters and
        speculative decoding settings.
    :param kwargs: Additional keyword arguments passed to the parent
        PreTrainedModel constructor.
    """
    if not config:
        raise ValueError(
            "Config must be provided to initialize a SpeculatorModel. "
            "Use SpeculatorModelConfig to create a valid configuration."
        )

    if not isinstance(config, SpeculatorModelConfig):
        raise TypeError(
            f"Expected config to be an instance of SpeculatorModelConfig, "
            f"got {type(config)} {config}."
        )

    config.tie_word_embeddings = False
    super().__init__(config, **kwargs)
    self.config: SpeculatorModelConfig = config

from_pretrained classmethod

from_pretrained(
    pretrained_model_name_or_path: str | PathLike | None,
    *model_args,
    config: PretrainedConfig | str | PathLike | None = None,
    cache_dir: str | PathLike | None = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    use_safetensors: bool | None = None,
    weights_only: bool = True,
    t2d: Tensor | None = None,
    d2t: Tensor | None = None,
    **kwargs,
) -> SpeculatorModel

Load a pretrained speculator model from the Hugging Face Hub or local directory.

This method automatically resolves the correct speculator model class based on the configuration type and loads the model with the appropriate weights. If called on the base SpeculatorModel class, it will automatically determine and instantiate the correct subclass based on the model configuration.

Example:

# Load with automatic class resolution
model = SpeculatorModel.from_pretrained("RedHatAI/speculator-llama-7b")

# Load from local directory
model = SpeculatorModel.from_pretrained("./my_speculator")

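# Load with custom config and an already-loaded state_dict
# (state_dict must be defined beforehand; it is required when the path is None)
config = SpeculatorModelConfig.from_pretrained("RedHatAI/eagle-llama-7b")
model = SpeculatorModel.from_pretrained(
    None, config=config, state_dict=state_dict
)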

Parameters:

  • pretrained_model_name_or_path

    (str | PathLike | None) –

    The model identifier on Hugging Face Hub, or path to a local directory containing the model files. Can be None only if config is provided as a SpeculatorModelConfig instance and a state_dict is passed through kwargs.

  • model_args

    Additional positional arguments passed to the model constructor.

  • config

    (PretrainedConfig | str | PathLike | None, default: None ) –

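    Optional configuration for the model: a SpeculatorModelConfig instance, or None to load the config from pretrained_model_name_or_path. The signature also accepts a path, but the implementation below raises a TypeError for any value that is not a SpeculatorModelConfig instance.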

  • cache_dir

    (str | PathLike | None, default: None ) –

    Directory to cache downloaded files. If None, uses default transformers cache directory.

  • ignore_mismatched_sizes

    (bool, default: False ) –

    Whether to ignore size mismatches when loading pretrained weights. Useful for loading models with different architectures.

  • force_download

    (bool, default: False ) –

    Whether to force re-download of model files even if they exist in cache.

  • local_files_only

    (bool, default: False ) –

    Whether to avoid downloading files and only use local cached files. Raises an error if files are not found locally.

  • token

    (str | bool | None, default: None ) –

    Optional authentication token for accessing private models on Hugging Face Hub. Can be a string token or True to use saved token.

  • revision

    (str, default: 'main' ) –

    The specific model revision to load (branch name, tag, or commit hash). Defaults to "main".

  • use_safetensors

    (bool | None, default: None ) –

    Whether to use safetensors format for loading weights. If None, automatically detects the available format.

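  • weights_only

    (bool, default: True ) –

    Whether to restrict checkpoint unpickling to tensors and other safe types. Forwarded to the Transformers PreTrainedModel.from_pretrained loader.

  • t2d

    (Tensor | None, default: None ) –

    Optional target-to-draft vocabulary mapping, passed to the loaded model's load_vocab_mappings() method if it defines one.

  • d2t

    (Tensor | None, default: None ) –

    Optional draft-to-target vocabulary mapping, passed to load_vocab_mappings() together with t2d.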

  • kwargs

    Additional keyword arguments passed to the model constructor and loading process.

Returns:

  • SpeculatorModel

    A SpeculatorModel instance of the appropriate subclass, loaded with the pretrained weights and configuration.

Source code in speculators/model.py
@classmethod
def from_pretrained(
    cls: type["SpeculatorModel"],
    pretrained_model_name_or_path: str | os.PathLike | None,
    *model_args,
    config: PretrainedConfig | str | os.PathLike | None = None,
    cache_dir: str | os.PathLike | None = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    use_safetensors: bool | None = None,
    weights_only: bool = True,
    t2d: torch.Tensor | None = None,
    d2t: torch.Tensor | None = None,
    **kwargs,
) -> "SpeculatorModel":
    """
    Load a pretrained speculator model from the Hugging Face Hub or local directory.

    This method automatically resolves the correct speculator model class based on
    the configuration type and loads the model with the appropriate weights. If
    called on the base SpeculatorModel class, it will automatically determine and
    instantiate the correct subclass based on the model configuration.

    Example:
        ```python
        # Load with automatic class resolution
        model = SpeculatorModel.from_pretrained("RedHatAI/speculator-llama-7b")

        # Load from local directory
        model = SpeculatorModel.from_pretrained("./my_speculator")

        # Load with custom config
        config = SpeculatorModelConfig.from_pretrained("RedHatAI/eagle-llama-7b")
        model = SpeculatorModel.from_pretrained(
            None, config=config, state_dict=state_dict
        )
        ```

    :param pretrained_model_name_or_path: The model identifier on Hugging Face Hub,
        or path to a local directory containing the model files. Can be None if
        config is provided as a path.
    :param model_args: Additional positional arguments passed to the model
        constructor.
    :param config: Optional configuration for the model. Can be a
        SpeculatorModelConfig instance, a path to a config file, or None to load
        from model directory.
    :param cache_dir: Directory to cache downloaded files. If None, uses default
        transformers cache directory.
    :param ignore_mismatched_sizes: Whether to ignore size mismatches when loading
        pretrained weights. Useful for loading models with different architectures.
    :param force_download: Whether to force re-download of model files even if
        they exist in cache.
    :param local_files_only: Whether to avoid downloading files and only use local
        cached files. Raises an error if files are not found locally.
    :param token: Optional authentication token for accessing private models on
        Hugging Face Hub. Can be a string token or True to use saved token.
    :param revision: The specific model revision to load (branch name, tag, or
        commit hash). Defaults to "main".
    :param use_safetensors: Whether to use safetensors format for loading weights.
        If None, automatically detects the available format.
    :param weights_only: Whether to only load model weights without optimizer
        states or other training artifacts.
    :param kwargs: Additional keyword arguments passed to the model constructor
        and loading process.
    :return: A SpeculatorModel instance of the appropriate subclass, loaded with
        the pretrained weights and configuration.
    """
    if not config:
        if not pretrained_model_name_or_path:
            raise ValueError(
                "Either `config` or `pretrained_model_name_or_path` must be "
                "provided to load a SpeculatorModel."
            )
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    if not isinstance(config, SpeculatorModelConfig):
        # once conversion is added, need to handle the case where a non speculator
        # config is passed in as a kwarg and auto convert
        raise TypeError(
            f"Expected config to be an instance of SpeculatorModelConfig, "
            f"got {type(config)}."
        )

    if not pretrained_model_name_or_path and not kwargs.get("state_dict"):
        raise ValueError(
            "Either `pretrained_model_name_or_path` or `state_dict` must be "
            "provided to load a SpeculatorModel."
        )

    if cls is SpeculatorModel:
        # generic call to from_pretrained on this class, need to resolve the
        # specific model class to use for loading based on the config and registry
        model_class = cls.registered_model_class_from_config(config)
        return model_class.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            t2d=t2d,
            d2t=d2t,
            **kwargs,
        )

    model: SpeculatorModel = super().from_pretrained(  # type: ignore[misc]
        pretrained_model_name_or_path,
        *model_args,
        config=config,
        cache_dir=cache_dir,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        use_safetensors=use_safetensors,
        weights_only=weights_only,
        **kwargs,
    )
    if hasattr(model, "load_vocab_mappings"):
        model.load_vocab_mappings(t2d, d2t)  # type: ignore[operator,attr-defined]
    if hasattr(model, "load_verifier_weights"):
        model.load_verifier_weights()  # type: ignore[operator,attr-defined]
    return model
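The argument checks at the top of `from_pretrained` reduce to two rules: something must supply the config, and something must supply the weights. This hypothetical helper replays just those rules with plain values (`state_dict` is an explicit parameter here, whereas the real method reads it from `**kwargs`) and omits the `SpeculatorModelConfig` type check.

```python
from typing import Any, Optional


def check_speculator_load_args(
    pretrained_model_name_or_path: Optional[str],
    config: Optional[Any],
    state_dict: Optional[dict],
) -> None:
    """Hypothetical replay of the two argument rules in from_pretrained."""
    # Rule 1: the config comes either from `config` or from the path.
    if not config and not pretrained_model_name_or_path:
        raise ValueError(
            "Either `config` or `pretrained_model_name_or_path` must be provided."
        )
    # Rule 2: the weights come either from the path or from a state_dict.
    if not pretrained_model_name_or_path and not state_dict:
        raise ValueError(
            "Either `pretrained_model_name_or_path` or `state_dict` must be provided."
        )
```

Both checks use truthiness, as the source does, so an empty string or empty dict counts as missing.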

from_training_args abstractmethod classmethod

from_training_args(
    verifier_config: PretrainedConfig, **kwargs
) -> SpeculatorModel

Create model instance from training arguments.

This factory method is used by the training script to instantiate models from command-line arguments. Each algorithm must implement this to support the training infrastructure.

Args:

  • verifier_config: Configuration from the verifier/base model.
  • **kwargs: Training arguments as keyword arguments. Each algorithm extracts the parameters it needs.

Returns: Initialized model instance ready for training.

Example:

@classmethod
def from_training_args(cls, verifier_config, **kwargs):
    config = MySpeculatorConfig(
        transformer_layer_config=verifier_config,
        num_layers=kwargs['num_layers'],
        ...
    )
    return cls(config=config, t2d=kwargs.get('t2d'), d2t=kwargs.get('d2t'))

Source code in speculators/model.py
@classmethod
@abstractmethod
def from_training_args(
    cls, verifier_config: PretrainedConfig, **kwargs
) -> "SpeculatorModel":
    """Create model instance from training arguments.

    This factory method is used by the training script to instantiate models
    from command-line arguments. Each algorithm must implement this to support
    the training infrastructure.

    Args:
        verifier_config: Configuration from the verifier/base model.
        **kwargs: Training arguments as keyword arguments. Each algorithm
            extracts the parameters it needs.

    Returns:
        Initialized model instance ready for training.

    Example:
        ```python
        @classmethod
        def from_training_args(cls, verifier_config, **kwargs):
            config = MySpeculatorConfig(
                transformer_layer_config=verifier_config,
                num_layers=kwargs['num_layers'],
                ...
            )
            return cls(config=config, t2d=kwargs.get('t2d'), d2t=kwargs.get('d2t'))
        ```
    """
    raise NotImplementedError(
        f"{cls.__name__} must implement from_training_args() classmethod "
        "to support training infrastructure."
    )

get_trainer_kwargs abstractmethod staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get algorithm-specific kwargs for training and validation.

This method extracts algorithm-specific parameters from the training arguments and returns separate kwargs dictionaries for training and validation forward passes.

Args: **kwargs: Training arguments containing algorithm-specific parameters.

Returns: Tuple of (train_kwargs, val_kwargs) where:

  • train_kwargs: Dict passed to model.forward() during training
  • val_kwargs: Dict passed to model.forward() during validation

Example:

@staticmethod
def get_trainer_kwargs(**kwargs):
    train_kwargs = {
        "num_steps": kwargs["num_steps"],
        "use_special_mode": True,
    }
    val_kwargs = {
        "num_steps": kwargs["num_steps"],
        "use_special_mode": False,
    }
    return train_kwargs, val_kwargs

Source code in speculators/model.py
@staticmethod
@abstractmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Get algorithm-specific kwargs for training and validation.

    This method extracts algorithm-specific parameters from the training
    arguments and returns separate kwargs dictionaries for training and
    validation forward passes.

    Args:
        **kwargs: Training arguments containing algorithm-specific parameters.

    Returns:
        Tuple of (train_kwargs, val_kwargs) where:
            - train_kwargs: Dict passed to model.forward() during training
            - val_kwargs: Dict passed to model.forward() during validation

    Example:
        ```python
        @staticmethod
        def get_trainer_kwargs(**kwargs):
            train_kwargs = {
                "num_steps": kwargs["num_steps"],
                "use_special_mode": True,
            }
            val_kwargs = {
                "num_steps": kwargs["num_steps"],
                "use_special_mode": False,
            }
            return train_kwargs, val_kwargs
        ```
    """
    raise NotImplementedError(
        "Model must implement get_trainer_kwargs() staticmethod "
        "to support training infrastructure."
    )
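Every algorithm must supply both training hooks described above. The sketch below restates that contract with Python's `abc` module on hypothetical toy classes; the `ToySpeculator` implementations follow the docstring examples. The documented base class raises `NotImplementedError` from its default bodies; in this `abc` version an incomplete subclass is also rejected at instantiation time (whether the real base does the same depends on its metaclass, which this page does not show).

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple


class ToySpeculatorBase(ABC):
    """Hypothetical stand-in for the training contract of SpeculatorModel."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config

    @classmethod
    @abstractmethod
    def from_training_args(
        cls, verifier_config: Dict[str, Any], **kwargs: Any
    ) -> "ToySpeculatorBase":
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_trainer_kwargs(**kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        raise NotImplementedError


class ToySpeculator(ToySpeculatorBase):
    @classmethod
    def from_training_args(cls, verifier_config, **kwargs):
        # Take only what this algorithm needs; extra kwargs are ignored.
        config = {
            "hidden_size": verifier_config["hidden_size"],
            "num_layers": kwargs["num_layers"],
        }
        return cls(config)

    @staticmethod
    def get_trainer_kwargs(**kwargs):
        train_kwargs = {"num_steps": kwargs["num_steps"], "use_special_mode": True}
        val_kwargs = {"num_steps": kwargs["num_steps"], "use_special_mode": False}
        return train_kwargs, val_kwargs
```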

registered_model_class_from_config classmethod

registered_model_class_from_config(
    config: SpeculatorModelConfig,
) -> type[SpeculatorModel]

Looks up the appropriate speculator model class from the registry based on the configuration type. It matches the config class to the corresponding model class that was registered during auto-discovery or manual registration.

Parameters:

  • config

    (SpeculatorModelConfig) –

    The configuration for which to find the registered model class. Must be an instance of a SpeculatorModelConfig subclass.

Returns:

  • type[SpeculatorModel]

    The registered model class that matches the configuration type.

Source code in speculators/model.py
@classmethod
def registered_model_class_from_config(
    cls, config: SpeculatorModelConfig
) -> type["SpeculatorModel"]:
    """
    Looks up the appropriate speculator model class from the registry
    based on the configuration type. It matches the config class to the
    corresponding model class that was registered during auto-discovery or manual
    registration.

    :param config: The configuration for which to find the registered model class.
        Must be an instance of a SpeculatorModelConfig subclass.
    :return: The registered model class that matches the configuration type.
    """
    if not isinstance(config, SpeculatorModelConfig):
        raise TypeError(
            f"Expected config to be an instance of SpeculatorModelConfig, "
            f"got {type(config)} {config}."
        )

    if type(config) is SpeculatorModelConfig:
        raise TypeError(
            "Received a SpeculatorModelConfig instance but expected a subclass. "
            "Use the specific subclass of SpeculatorModelConfig instead. "
            f"Received: {type(config)} {config}"
        )

    if not cls.registry:
        raise ValueError(
            "No registered model classes found. "
            "Ensure that models are registered with "
            "`SpeculatorModel.register(NAME)` or that auto-discovery is enabled."
        )

    for _, model_class in cls.registry.items():
        model_config_class: type[SpeculatorModelConfig] = model_class.config_class

        if type(config) is model_config_class:
            return model_class

    raise ValueError(
        f"No registered model class found for config type {type(config)}. "
        f"Available registered model classes: {list(cls.registry.keys())}."
    )
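The lookup in `registered_model_class_from_config` can be sketched with a tiny registry. This toy version assumes `register(NAME)` works as a class decorator, which this page implies but does not show; the class names are hypothetical. Like the real method, it rejects the bare base config and matches on the exact config type, so a config subclass does not resolve to its parent's model class.

```python
from typing import Callable, Dict, Type


class ToyConfig:
    """Base config; lookups require a concrete subclass."""


class ToyModel:
    registry: Dict[str, Type["ToyModel"]] = {}
    config_class: Type[ToyConfig] = ToyConfig

    @classmethod
    def register(cls, name: str) -> Callable[[Type["ToyModel"]], Type["ToyModel"]]:
        def decorator(model_cls: Type["ToyModel"]) -> Type["ToyModel"]:
            cls.registry[name] = model_cls
            return model_cls

        return decorator

    @classmethod
    def registered_model_class_from_config(cls, config: ToyConfig) -> Type["ToyModel"]:
        if type(config) is ToyConfig:
            raise TypeError("expected a subclass of ToyConfig")
        for model_cls in cls.registry.values():
            # Exact type match, mirroring `type(config) is model_config_class`.
            if type(config) is model_cls.config_class:
                return model_cls
        raise ValueError(f"no model registered for {type(config).__name__}")


class EagleLikeConfig(ToyConfig):
    pass


@ToyModel.register("eagle_like")
class EagleLikeModel(ToyModel):
    config_class = EagleLikeConfig
```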

verify_training_compatible classmethod

verify_training_compatible(model: SpeculatorModel) -> None

Verify that a model instance is compatible with training infrastructure.

This method validates that the given model:

  1. is an instance of SpeculatorModel,
  2. is registered in the SpeculatorModel registry, and
  3. has a layers attribute (required for FSDP wrapping).

Args: model: The model instance to verify

Raises:

  • TypeError: If model is not a SpeculatorModel instance
  • ValueError: If the model's class is not in the registry
  • AttributeError: If the model doesn't have a layers attribute

Source code in speculators/model.py
@classmethod
def verify_training_compatible(cls, model: "SpeculatorModel") -> None:
    """Verify that a model instance is compatible with training infrastructure.

    This method validates that the given model is:
    1. An instance of SpeculatorModel
    2. Registered in the SpeculatorModel registry
    3. Has a `layers` attribute (required for FSDP wrapping)

    Args:
        model: The model instance to verify

    Raises:
        TypeError: If model is not a SpeculatorModel instance
        ValueError: If model's class is not in the registry
        AttributeError: If model doesn't have a `layers` attribute
    """
    if not isinstance(model, SpeculatorModel):
        raise TypeError(
            f"Model must be a SpeculatorModel, got {type(model).__name__}"
        )

    model_class = type(model)
    registry = cls.registry
    if registry is None or model_class not in registry.values():
        raise ValueError(
            f"Model {model_class.__name__} is not registered in "
            f"SpeculatorModel.registry. "
            f"Available models: {list(registry.keys()) if registry else []}"
        )

    if not hasattr(model, "layers"):
        raise AttributeError(
            f"Model {model_class.__name__} must have a 'layers' attribute "
            f"containing decoder layers for FSDP wrapping"
        )
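The three checks in `verify_training_compatible` run in a fixed order, so each exception type identifies which requirement failed. The toy classes and registry below are hypothetical stand-ins for `SpeculatorModel` and its registry.

```python
from typing import Dict, List


class ToyBase:
    registry: Dict[str, type] = {}


class RegisteredModel(ToyBase):
    def __init__(self) -> None:
        self.layers: List[object] = [object(), object()]  # toy decoder layers


class NoLayersModel(ToyBase):
    pass


ToyBase.registry["registered"] = RegisteredModel
ToyBase.registry["no_layers"] = NoLayersModel


def verify_training_compatible(model: object) -> None:
    """Apply the three documented checks in order (toy version)."""
    if not isinstance(model, ToyBase):
        raise TypeError(f"Model must be a ToyBase, got {type(model).__name__}")
    if type(model) not in ToyBase.registry.values():
        raise ValueError(f"Model {type(model).__name__} is not registered")
    if not hasattr(model, "layers"):
        raise AttributeError(f"Model {type(model).__name__} has no 'layers' attribute")
```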

SpeculatorModelConfig

SpeculatorModelConfig(**kwargs)

Bases: PydanticClassRegistryMixin, PretrainedConfig

The base config for a speculator model implementation. It defines the hyperparameters and settings required to implement a speculator model, including the speculator model architecture and the speculators config that describes the algorithm, token proposals, and verifier model.

It inherits from the Transformers PretrainedConfig class to ensure full compatibility with standard Transformers model pathways while building on the standard methods for PretrainedConfigs to load, save, and push to the HF hub.

This is the main config which maps to the config.json file for saved speculators.

Methods:

  • from_dict

    Create a SpeculatorModelConfig from a dictionary, automatically instantiating the correct subclass based on the speculators_model_type field.

  • from_pretrained

    Load a SpeculatorModelConfig from the name/id of a model on the Hugging Face Hub or from a local directory.

  • to_dict

    :return: A dictionary representation of the full config, including the

  • to_diff_dict

    :return: A dictionary representation of a simplified config,

Source code in speculators/config.py
def __init__(self, **kwargs):
    # initialize the Pydantic arguments first to set all valid fields
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    # reset kwargs handled by Pydantic so PretrainedConfig doesn't override
    for field in self.__class__.model_fields:
        kwargs[field] = getattr(self, field)

    # strip ClassVars so PretrainedConfig.__post_init__ doesn't try to
    # setattr them (pydantic blocks setattr on ClassVar names)
    class_vars = self.__class__.__class_vars__
    for cv in class_vars:
        kwargs.pop(cv, None)

    # initialize the Hugging Face PretrainedConfig arguments for the model
    PretrainedConfig.__init__(self, **kwargs)

    # ensure we always update the transformers version
    self.transformers_version = version("transformers")
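The __init__ above hands kwargs from the Pydantic side to PretrainedConfig in two steps: validated field values replace the raw kwargs, and ClassVar names are dropped before PretrainedConfig sees them. The stdlib sketch below reproduces only that handoff on plain dicts; `prepare_pretrained_kwargs` and the sample field names are hypothetical, and no Pydantic or Transformers objects are involved.

```python
from typing import Any


def prepare_pretrained_kwargs(
    kwargs: dict[str, Any],
    validated_fields: dict[str, Any],
    class_vars: set[str],
) -> dict[str, Any]:
    """Hypothetical helper mirroring the kwargs handoff in __init__."""
    prepared = dict(kwargs)  # copy so the caller's dict is untouched
    # Validated (possibly coerced) values replace the raw inputs, so the
    # second initializer cannot overwrite them with unvalidated values.
    prepared.update(validated_fields)
    # Drop ClassVar names; the second initializer would try to setattr them.
    for name in class_vars:
        prepared.pop(name, None)
    return prepared


raw = {"draft_layers": "2", "registry": {}, "vocab_size": 128}
validated = {"draft_layers": 2, "speculators_model_type": "toy"}
print(prepare_pretrained_kwargs(raw, validated, {"registry"}))
```

The string "2" is replaced by the validated integer, and the ClassVar-like "registry" key never reaches the second initializer.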

from_dict classmethod

from_dict(
    config_dict: dict[str, Any], **kwargs
) -> SpeculatorModelConfig

Create a SpeculatorModelConfig from a dictionary, automatically instantiating the correct subclass based on the speculators_model_type field.

Parameters:

  • config_dict

    (dict[str, Any]) –

    Dictionary containing the configuration

  • kwargs

    Additional keyword arguments that override config values

Returns:

  • SpeculatorModelConfig

    A SpeculatorModelConfig instance of the appropriate subclass.

Source code in speculators/config.py
@classmethod
def from_dict(
    cls, config_dict: dict[str, Any], **kwargs
) -> "SpeculatorModelConfig":
    """
    Create a SpeculatorModelConfig from a dictionary, automatically instantiating
    the correct subclass based on the speculators_model_type field.

    :param config_dict: Dictionary containing the configuration
    :param kwargs: Additional keyword arguments that override config values
    :return: A SpeculatorModelConfig instance of the appropriate subclass
    """
    dict_obj = {**config_dict, **kwargs}

    if "speculators_model_type" not in dict_obj:
        raise ValueError(
            "The config dictionary must contain the 'speculators_model_type' field "
            "for loading a SpeculatorModelConfig in the Speculators library."
        )

    return cls.model_validate(dict_obj)
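from_dict merges kwargs over the stored values, requires the speculators_model_type discriminator, and then lets cls.model_validate pick the subclass through the Pydantic class registry. The sketch below models that flow with a plain dict registry and dataclasses; `TOY_REGISTRY`, the toy config classes, and `toy_from_dict` are hypothetical stand-ins, not the library's registry.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyEagleConfig:
    speculators_model_type: str = "eagle"
    num_layers: int = 1


@dataclass
class ToyEagle3Config:
    speculators_model_type: str = "eagle3"
    num_layers: int = 1


# Hypothetical registry: discriminator value -> config class
TOY_REGISTRY: dict[str, type] = {"eagle": ToyEagleConfig, "eagle3": ToyEagle3Config}


def toy_from_dict(config_dict: dict[str, Any], **kwargs: Any) -> Any:
    merged = {**config_dict, **kwargs}  # kwargs override stored values
    if "speculators_model_type" not in merged:
        raise ValueError("config must contain 'speculators_model_type'")
    config_cls = TOY_REGISTRY[merged["speculators_model_type"]]
    return config_cls(**merged)


cfg = toy_from_dict({"speculators_model_type": "eagle3", "num_layers": 1}, num_layers=2)
print(type(cfg).__name__, cfg.num_layers)
```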

from_pretrained classmethod

from_pretrained(
    pretrained_model_name_or_path: str | PathLike,
    cache_dir: str | PathLike | None = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    **kwargs,
) -> SpeculatorModelConfig

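Load a SpeculatorModelConfig from the name/id of a model on the Hugging Face Hub or from a local directory. The correct config subclass from the speculators.models package is instantiated automatically; configs without a speculators_model_type field raise NotImplementedError, since converting non-speculator configs is not supported yet.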

Parameters:

  • pretrained_model_name_or_path

    (str | PathLike) –

    The name or path to the pretrained model.

  • cache_dir

    (str | PathLike | None, default: None ) –

    The directory to cache the config in.

  • force_download

    (bool, default: False ) –

    Whether to force download the config from the Hub.

  • local_files_only

    (bool, default: False ) –

    Whether to use local files, not download from the Hub.

  • token

    (str | bool | None, default: None ) –

    The token to use for authentication with the Hub.

  • revision

    (str, default: 'main' ) –

    The revision of the config to load from the Hub.

  • kwargs

    Additional keyword arguments to pass to the config.

Returns:

  • SpeculatorModelConfig

    A SpeculatorModelConfig object with the loaded parameters.

Source code in speculators/config.py
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: str | os.PathLike,
    cache_dir: str | os.PathLike | None = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    **kwargs,
) -> "SpeculatorModelConfig":
    """
    Load a SpeculatorModelConfig from the name/id of a model on the Hugging Face Hub
    or from a local directory. Will automatically instantiate the correct config
    from speculators.models package.

    :param pretrained_model_name_or_path: The name or path to the pretrained model.
    :param cache_dir: The directory to cache the config in.
    :param force_download: Whether to force download the config from the Hub.
    :param local_files_only: Whether to use local files, not download from the Hub.
    :param token: The token to use for authentication with the Hub.
    :param revision: The revision of the config to load from the Hub.
    :param kwargs: Additional keyword arguments to pass to the config.
    :return: A SpeculatorModelConfig object with the loaded parameters.
    """
    # Transformers config loading
    config_dict, kwargs = cls.get_config_dict(
        pretrained_model_name_or_path,
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        **kwargs,
    )

    if "speculators_model_type" not in config_dict:
        # Conversion pathway
        raise NotImplementedError(
            "Loading a non-speculator model config is not supported yet."
        )

    return cls.from_dict(config_dict, **kwargs)

to_dict

to_dict() -> dict[str, Any]

Returns:

  • dict[str, Any]

    A dictionary representation of the full config, including the PretrainedConfig variables and Pydantic model fields.

Source code in speculators/config.py
def to_dict(self) -> dict[str, Any]:
    """
    :return: A dictionary representation of the full config, including the
        PretrainedConfig variables and Pydantic model fields.
    """
    pretrained_dict = super().to_dict()
    model_dict = self.model_dump()
    config_dict = {**pretrained_dict, **model_dict}

    # strip all class variables and metadata that are not needed in the output
    for key in (
        "model_config",
        "auto_package",
        "registry_auto_discovery",
        "schema_discriminator",
        "model_type",
        "base_config_key",
        "sub_configs",
        "is_composition",
        "attribute_map",
        "base_model_tp_plan",
        "base_model_pp_plan",
        "base_model_ep_plan",
        "has_no_defaults_at_init",
        "keys_to_ignore_at_inference",
        "_auto_class",
    ):
        config_dict.pop(key, None)

    return config_dict
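In to_dict, the Pydantic model_dump is spread after the PretrainedConfig dict, so Pydantic field values win on any key collision, and class-level metadata is then stripped. The sketch below shows that precedence on plain dicts; `merge_config_dicts` and the sample values are hypothetical, and `STRIPPED_KEYS` is only a subset of the keys the real method removes.

```python
from typing import Any

# A subset of the keys stripped by to_dict, for illustration only
STRIPPED_KEYS = ("model_config", "auto_package", "model_type", "attribute_map")


def merge_config_dicts(
    pretrained_dict: dict[str, Any], model_dict: dict[str, Any]
) -> dict[str, Any]:
    merged = {**pretrained_dict, **model_dict}  # Pydantic fields win on conflicts
    for key in STRIPPED_KEYS:
        merged.pop(key, None)  # drop metadata not needed in config.json
    return merged


pretrained = {"model_type": "toy", "hidden_size": 4, "torch_dtype": "bfloat16"}
pydantic_fields = {"hidden_size": 8, "speculators_model_type": "eagle", "auto_package": "toy.package"}
print(merge_config_dicts(pretrained, pydantic_fields))
```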

to_diff_dict

to_diff_dict() -> dict[str, Any]

Returns:

  • dict[str, Any]

    A dictionary representation of a simplified config, including only the PretrainedConfig fields that have been modified or set, along with all Pydantic fields.

Source code in speculators/config.py
def to_diff_dict(self) -> dict[str, Any]:
    """
    :return: A dictionary representation of a simplified config,
        including only the PretrainedConfig fields that have been modified
        or set, along with all Pydantic fields.
    """
    return super().to_diff_dict()
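The method delegates to PretrainedConfig.to_diff_dict from Transformers, which compares values against defaults. The sketch below is a simplified model of the documented contract (modified or non-default base fields plus every Pydantic field), not the Transformers implementation; `toy_diff_dict` and the example defaults are assumptions for illustration.

```python
from typing import Any


def toy_diff_dict(
    full: dict[str, Any], defaults: dict[str, Any], pydantic_fields: set[str]
) -> dict[str, Any]:
    diff: dict[str, Any] = {}
    for key, value in full.items():
        if key in pydantic_fields:
            diff[key] = value  # Pydantic fields are always kept
        elif key not in defaults or defaults[key] != value:
            diff[key] = value  # base-config values that were set or modified
    return diff


defaults = {"output_attentions": False, "return_dict": True}  # example defaults
full = {
    "output_attentions": False,  # equals default, dropped
    "return_dict": False,  # modified, kept
    "speculators_model_type": "eagle",
    "custom_flag": 1,  # no default, kept
}
print(toy_diff_dict(full, defaults, {"speculators_model_type"}))
```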

SpeculatorsConfig

Bases: ReloadableBaseModel

The base config for a speculative decoding implementation, which defines the parameters required to implement a speculators algorithm for the parent speculator model. It includes details on the algorithm, token proposals, and the verifier model.

Methods:

check_default_proposal_method

check_default_proposal_method() -> SpeculatorsConfig

Validate default_proposal_method is one of the proposal_methods.

Source code in speculators/config.py
@model_validator(mode="after")
def check_default_proposal_method(self) -> "SpeculatorsConfig":
    """Validate default_proposal_method is one of the proposal_methods."""
    available = [method.proposal_type for method in self.proposal_methods]
    if self.default_proposal_method not in available:
        raise ValueError(
            "default_proposal_method "
            f"'{self.default_proposal_method}' must match the proposal_type of "
            f"one of the configured proposal_methods. Available proposal types: "
            f"{available}."
        )
    return self
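check_default_proposal_method is a Pydantic "after" validator: it runs once all fields are populated and cross-checks two of them. A dataclass __post_init__ gives the same timing in the stdlib, so the sketch below uses it; the toy classes and the "tree" proposal type are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ToyProposal:
    proposal_type: str


@dataclass
class ToySpeculatorsConfig:
    proposal_methods: list[ToyProposal]
    default_proposal_method: str

    def __post_init__(self) -> None:
        # Same rule as check_default_proposal_method, run after construction
        available = [m.proposal_type for m in self.proposal_methods]
        if self.default_proposal_method not in available:
            raise ValueError(
                f"default_proposal_method '{self.default_proposal_method}' must "
                f"match one of the configured proposal types: {available}"
            )


ok = ToySpeculatorsConfig([ToyProposal("greedy")], "greedy")
try:
    ToySpeculatorsConfig([ToyProposal("greedy")], "tree")
except ValueError as err:
    print(err)
```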

TokenProposalConfig

Bases: PydanticClassRegistryMixin

The base config for a token proposal method which defines how tokens are generated by the speculator, how they are passed to the verifier, and how they are scored for acceptance or rejection. All implementations of token proposal methods must inherit from this class, set the proposal_type to a unique value, and add any additional parameters needed to instantiate and implement the method.

It uses Pydantic to validate the parameters, provide default values, and enable automatic serialization and deserialization of the correct class types based on the proposal_type field.
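The contract above (each subclass sets a unique proposal_type, and deserialization picks the class from that field) can be sketched without Pydantic using __init_subclass__. This is a minimal illustration under that assumption, not the PydanticClassRegistryMixin mechanism; `ToyProposalConfig`, `GreedyToyProposal`, and the `speculative_tokens` parameter are hypothetical.

```python
from typing import Any, ClassVar


class ToyProposalConfig:
    """Hypothetical base class: subclasses self-register by proposal_type."""

    registry: ClassVar[dict[str, type["ToyProposalConfig"]]] = {}
    proposal_type: ClassVar[str] = ""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if not cls.proposal_type:
            raise TypeError(f"{cls.__name__} must set a proposal_type")
        if cls.proposal_type in ToyProposalConfig.registry:
            raise TypeError(f"proposal_type {cls.proposal_type!r} is already registered")
        ToyProposalConfig.registry[cls.proposal_type] = cls

    def __init__(self, **params: Any) -> None:
        self.params = params

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToyProposalConfig":
        data = dict(data)  # avoid mutating the caller's dict
        subclass = cls.registry[data.pop("proposal_type")]
        return subclass(**data)


class GreedyToyProposal(ToyProposalConfig):
    proposal_type = "greedy"


proposal = ToyProposalConfig.from_dict({"proposal_type": "greedy", "speculative_tokens": 5})
print(type(proposal).__name__, proposal.params)
```

Registering at class-creation time means a duplicate proposal_type fails as soon as the offending subclass is defined, rather than at load time.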

VerifierConfig

Bases: BaseModel

The base config for a verifier model, which defines the parameters required either to load the original verifier model or to validate compatibility with a new verifier based on the architecture and tokenizer/processor properties. It provides convenience methods to extract the required parameters from a PretrainedConfig object.

Methods:

  • from_config

    Create a VerifierConfig from a PretrainedConfig object.

from_config classmethod

from_config(
    config: PretrainedConfig,
    name_or_path: str | None = "UNSET",
) -> VerifierConfig

Create a VerifierConfig from a PretrainedConfig object. Used to extract the required parameters from the original verifier config and create a VerifierConfig object.

Parameters:

  • config

    (PretrainedConfig) –

    The PretrainedConfig object to extract the parameters from.

  • name_or_path

    (str | None, default: 'UNSET' ) –

    The name or path for the verifier model. Set to None to omit a specific name_or_path. If not provided (left as 'UNSET'), the name_or_path from the config is used.

Returns:

  • VerifierConfig

    A VerifierConfig object with the extracted parameters.

Source code in speculators/config.py
@classmethod
def from_config(
    cls, config: PretrainedConfig, name_or_path: str | None = "UNSET"
) -> "VerifierConfig":
    """
    Create a VerifierConfig from a PretrainedConfig object.
    Used to extract the required parameters from the original verifier
    config and create a VerifierConfig object.

    :param config: The PretrainedConfig object to extract the parameters from.
    :param name_or_path: The name or path for the verifier model.
        Set to None to not add a specific name_or_path.
        If not provided, the name_or_path from the config will be used.
    :return: A VerifierConfig object with the extracted parameters.
    """
    config_dict = config.to_dict()

    if name_or_path == "UNSET":
        name_or_path = (
            getattr(config, "name_or_path", None)
            or config_dict.get("_name_or_path", None)
            or config_dict.get("name_or_path", None)
        )

    return cls(
        name_or_path=name_or_path,
        architectures=config_dict.get("architectures") or [],
    )

reload_schemas

reload_schemas()

Automatically populates the registry for all PydanticClassRegistryMixin subclasses and reloads schemas for all Config classes to ensure their schemas are up-to-date with the current registry state.

Source code in speculators/config.py
def reload_schemas():
    """
    Automatically populates the registry for all PydanticClassRegistryMixin subclasses
    and reloads schemas for all Config classes to ensure their schemas are up-to-date
    with the current registry state.
    """
    TokenProposalConfig.reload_schema()
    SpeculatorsConfig.reload_schema()
    SpeculatorModelConfig.reload_schema()
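The docstring says schemas must be refreshed to match the current registry state. The toy class below illustrates why, assuming the validation schema is a snapshot taken from the registry: a subclass registered after the snapshot is rejected until the schema is reloaded. `ToySchemaRegistry` is hypothetical and does not reproduce Pydantic's schema machinery.

```python
class ToySchemaRegistry:
    """Hypothetical registry whose validation schema is a cached snapshot."""

    def __init__(self) -> None:
        self.registry: dict[str, type] = {}
        self.schema: frozenset[str] = frozenset()

    def register(self, key: str, cls: type) -> None:
        self.registry[key] = cls  # does NOT refresh the cached schema

    def reload_schema(self) -> None:
        self.schema = frozenset(self.registry)  # rebuild from current registry

    def validate(self, data: dict) -> type:
        key = data["type"]
        if key not in self.schema:
            raise ValueError(f"unknown type {key!r}; allowed: {sorted(self.schema)}")
        return self.registry[key]


class GreedyToy:
    pass


proposals = ToySchemaRegistry()
proposals.register("greedy", GreedyToy)
try:
    proposals.validate({"type": "greedy"})  # stale schema rejects it
except ValueError as err:
    print(err)
proposals.reload_schema()
print(proposals.validate({"type": "greedy"}).__name__)
```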