
speculators.utils.pydantic_utils

General pydantic utilities for Speculators.

This module provides integration between Pydantic and the Speculators library. It enables features such as polymorphic serialization and deserialization of Pydantic models through a discriminator field and a class registry.

Classes:

  • PydanticClassRegistryMixin: A mixin that combines Pydantic models with the ClassRegistryMixin to support polymorphic model instantiation based on a discriminator field.

Classes:

PydanticClassRegistryMixin

Bases: ReloadableBaseModel, ABC, ClassRegistryMixin

A mixin class that integrates Pydantic models with the ClassRegistryMixin to enable polymorphic serialization and deserialization based on a discriminator field.

This mixin allows Pydantic models to be registered in a registry and dynamically instantiated based on a discriminator field in the input data. It overrides Pydantic's validation so that, when the base class validates input, the subclass registered under the discriminator value is instantiated.

The mixin is particularly useful for implementing base registry classes that need to support multiple implementations, such as different token proposal methods or speculative decoding algorithms.
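The dispatch idea can be sketched without Pydantic. The following is a minimal, hypothetical stand-in, not the Speculators implementation: the names `RegistryBase`, `register`, and `from_dict` are illustrative only. It shows how a discriminator value in the input selects a registered subclass, with the discriminator field name overridable per base class.

```python
from __future__ import annotations

from typing import Any, ClassVar


class RegistryBase:
    """Hypothetical stand-in for a discriminator-dispatching base class."""

    # Illustrative analogues of schema_discriminator and registry.
    schema_discriminator: ClassVar[str] = "model_type"
    registry: ClassVar[dict[str, type] | None] = None

    def __init__(self, **data: Any) -> None:
        # Store input fields as plain attributes (no validation in this sketch).
        self.__dict__.update(data)

    @classmethod
    def register(cls, name: str | None = None):
        def decorator(clazz: type) -> type:
            # Give each base that calls register() its own registry dict.
            if cls.__dict__.get("registry") is None:
                cls.registry = {}
            cls.registry[clazz.__name__ if name is None else name] = clazz
            return clazz

        return decorator

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> RegistryBase:
        # Look up the subclass named by the discriminator field.
        key = data.get(cls.schema_discriminator)
        target = (cls.registry or {}).get(key)
        if target is None:
            raise ValueError(
                f"no class registered for {cls.schema_discriminator}={key!r}"
            )
        return target(**data)


class BaseConfig(RegistryBase):
    schema_discriminator: ClassVar[str] = "config_type"


@BaseConfig.register("config_a")
class ConfigA(BaseConfig):
    pass


@BaseConfig.register("config_b")
class ConfigB(BaseConfig):
    pass


config = BaseConfig.from_dict({"config_type": "config_a", "value_a": "test"})
print(type(config).__name__, config.value_a)  # ConfigA test
```

In the real mixin, Pydantic validates the fields of the selected subclass; this sketch only illustrates the lookup step.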

Usage Example:

from typing import ClassVar
from pydantic import Field
from speculators.utils import PydanticClassRegistryMixin

class BaseConfig(PydanticClassRegistryMixin):
    @classmethod
    def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]:
        if cls.__name__ == "BaseConfig":
            return cls
        return BaseConfig

    schema_discriminator: ClassVar[str] = "config_type"
    config_type: str = Field(description="The type of configuration")

@BaseConfig.register("config_a")
class ConfigA(BaseConfig):
    config_type: str = "config_a"
    value_a: str = Field(description="A value specific to ConfigA")

@BaseConfig.register("config_b")
class ConfigB(BaseConfig):
    config_type: str = "config_b"
    value_b: int = Field(description="A value specific to ConfigB")

BaseConfig.reload_schema()  # Ensures the schema is up-to-date with registry

# Dynamic instantiation based on config_type
config_data = {"config_type": "config_a", "value_a": "test"}
config = BaseConfig.model_validate(config_data)  # Returns ConfigA instance
print(config)
dump_data = config.model_dump()  # Dumps the data to a dictionary
print(dump_data)  # Output: {'config_type': 'config_a', 'value_a': 'test'}

Attributes:

  • schema_discriminator (str) –

    The field name used as the discriminator in the JSON schema. Default is "model_type".

  • registry (dict[str, BaseModel] | None) –

    A dictionary mapping discriminator values to pydantic model classes.

Methods:

register_decorator classmethod

register_decorator(
    clazz: type[BaseModel], name: str | None = None
) -> type[BaseModel]

Registers a Pydantic model class with the registry.

This method extends the ClassRegistryMixin.register_decorator method by adding a type check to ensure only Pydantic BaseModel subclasses can be registered.

Parameters:

  • clazz

    (type[BaseModel]) –

    The Pydantic model class to register

  • name

    (str | None, default: None) –

    Optional name to register the class under. If None, the class name is used as the registry key.

Returns:

  • type[BaseModel]

    The registered class.

Raises:

  • TypeError

    If clazz is not a subclass of Pydantic BaseModel

Source code in speculators/utils/pydantic_utils.py
@classmethod
def register_decorator(
    cls, clazz: type[BaseModel], name: str | None = None
) -> type[BaseModel]:
    """
    Registers a Pydantic model class with the registry.

    This method extends the ClassRegistryMixin.register_decorator method by adding
    a type check to ensure only Pydantic BaseModel subclasses can be registered.

    :param clazz: The Pydantic model class to register
    :param name: Optional name to register the class under. If None, the class name
        is used as the registry key.
    :return: The registered class.
    :raises TypeError: If clazz is not a subclass of Pydantic BaseModel
    """
    if not issubclass(clazz, BaseModel):
        raise TypeError(
            f"Cannot register {clazz.__name__} as it is not a subclass of "
            "Pydantic BaseModel"
        )

    return super().register_decorator(clazz, name=name)
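The guard above can be illustrated in isolation. This is a minimal sketch under stated assumptions: `Model` is a hypothetical stand-in for `pydantic.BaseModel`, and `Registry` is a hypothetical registry rather than `ClassRegistryMixin`. It shows the type check and the fallback to the class name when `name` is None.

```python
from __future__ import annotations


class Model:
    """Hypothetical stand-in for pydantic.BaseModel."""


class Registry:
    """Hypothetical registry that only accepts Model subclasses."""

    registry: dict[str, type] = {}

    @classmethod
    def register_decorator(cls, clazz: type, name: str | None = None) -> type:
        # Same kind of guard as above: reject anything not derived from Model.
        if not issubclass(clazz, Model):
            raise TypeError(
                f"Cannot register {clazz.__name__} as it is not a subclass of Model"
            )
        # Use the class name as the key when no explicit name is given.
        key = clazz.__name__ if name is None else name
        cls.registry[key] = clazz
        return clazz


class Good(Model):
    pass


class Bad:
    pass


Registry.register_decorator(Good)
Registry.register_decorator(Good, name="good")
print(sorted(Registry.registry))  # ['Good', 'good']

try:
    Registry.register_decorator(Bad)
except TypeError as err:
    print(err)  # Cannot register Bad as it is not a subclass of Model
```

Note that `issubclass` itself raises `TypeError` when its first argument is not a class at all, so passing an instance also surfaces as `TypeError`, but with Python's own message rather than the custom one.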

ReloadableBaseModel

Bases: BaseModel

Methods:

  • reload_schema

    Reloads the schema for the class, ensuring that the registry is populated and that the schema is up-to-date.

reload_schema classmethod

reload_schema()

Reloads the schema for the class, ensuring that the registry is populated and that the schema is up-to-date.

This method is useful when the registry has been modified or when the class needs to be re-validated with the latest schema.

Source code in speculators/utils/pydantic_utils.py
@classmethod
def reload_schema(cls):
    """
    Reloads the schema for the class, ensuring that the registry is populated
    and that the schema is up-to-date.

    This method is useful when the registry has been modified or when the
    class needs to be re-validated with the latest schema.
    """
    # transformers 5.4+ uses torch.dtype in annotations without importing torch,
    # causing model_rebuild() to raise NameError. Ignored on older versions.
    cls.model_rebuild(force=True, _types_namespace={"torch": torch})
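The failure mode that the comment guards against can be reproduced with the standard library. `typing.get_type_hints` resolves string annotations in a similar spirit to how Pydantic resolves them during `model_rebuild()`, and it raises `NameError` when a referenced name is not in scope. This is an analogy, not Pydantic's actual code path; `Settings` and `fake_torch` are hypothetical names.

```python
import types
import typing


class Settings:
    # String annotation referring to a name ("torch") this module never imports.
    dtype: "torch.dtype"


try:
    typing.get_type_hints(Settings)
except NameError as err:
    print("unresolved:", err)

# Hypothetical stand-in for the torch module, supplied through a namespace,
# much like the _types_namespace argument passed to model_rebuild() above.
fake_torch = types.SimpleNamespace(dtype=type("dtype", (), {}))
hints = typing.get_type_hints(Settings, localns={"torch": fake_torch})
print(hints["dtype"].__name__)  # dtype
```

Supplying the missing name through an explicit namespace lets resolution succeed without changing the annotated class.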