Skip to content

speculators.models.dflash

Modules:

Classes:

DFlashDraftModel

DFlashDraftModel(config: DFlashSpeculatorConfig)

Bases: DraftVocabMixin, SpeculatorModel

Methods:

Source code in speculators/models/dflash/core.py
def __init__(
    self,
    config: DFlashSpeculatorConfig,
) -> None:
    """Build the DFlash draft model from a speculator config.

    Creates the stack of draft decoder layers, records which of them use
    sliding-window attention, and sets up the normalization layers, rotary
    embedding and the linear projection that maps the concatenated auxiliary
    hidden states down to a single hidden size.

    Args:
        config: DFlash speculator configuration. Its
            ``transformer_layer_config`` supplies the layer count, hidden
            size, RMSNorm epsilon, sliding window and layer types.

    Raises:
        ValueError: If ``config.aux_hidden_state_layer_ids`` is None.
    """
    # Fill in the attention backend only when none was chosen. This has to
    # happen before the parent constructor consumes the config.
    pre_init_cfg = config.transformer_layer_config
    if pre_init_cfg._attn_implementation is None:  # noqa: SLF001
        pre_init_cfg._attn_implementation = "simple_flex_attention"  # noqa: SLF001
    super().__init__(config=config)
    self._init_vocab(config)

    layer_cfg = config.transformer_layer_config

    # The draft depth is carried by num_hidden_layers of the layer config.
    depth = layer_cfg.num_hidden_layers
    decoder_stack = []
    for idx in range(depth):
        decoder_stack.append(
            Qwen3DFlashDecoderLayer(layer_cfg, idx)  # type: ignore[arg-type]
        )
    self.layers = nn.ModuleList(decoder_stack)

    # Collect the positions of all layers declared as sliding-window layers.
    self.sliding_window = layer_cfg.sliding_window
    windowed_positions = []
    for position, kind in enumerate(layer_cfg.layer_types):
        if kind == "sliding_attention":
            windowed_positions.append(position)
    self.sliding_window_indices = windowed_positions
    self.uses_sliding_window_attn = len(windowed_positions) > 0
    # Full attention is in use whenever some layer is not a windowed one.
    self.uses_full_attn = len(windowed_positions) != depth
    self.sliding_window_non_causal = config.sliding_window_non_causal

    aux_layer_ids = config.aux_hidden_state_layer_ids
    if aux_layer_ids is None:
        raise ValueError(
            "aux_hidden_state_layer_ids must be set in DFlashSpeculatorConfig. "
            "Use DFlashDraftModel.from_training_args() to resolve defaults."
        )
    self.target_layer_ids = aux_layer_ids

    hidden_dim = layer_cfg.hidden_size
    norm_eps = layer_cfg.rms_norm_eps

    def _new_rms_norm():
        # All three norms share the same width and epsilon.
        return Qwen3RMSNorm(hidden_dim, eps=norm_eps)  # type: ignore[arg-type]

    self.norm = _new_rms_norm()
    self.rotary_emb = Qwen3RotaryEmbedding(layer_cfg)  # type: ignore[arg-type]

    # One hidden-size slice per target layer goes in; one hidden size comes out.
    self.fc = nn.Linear(
        len(self.target_layer_ids) * hidden_dim,
        hidden_dim,
        bias=False,
    )
    self.hidden_norm = _new_rms_norm()
    self.verifier_norm = _new_rms_norm()
    # The verifier norm weight is excluded from gradient updates.
    self.verifier_norm.weight.requires_grad = False
    self.block_size = config.block_size
    self.post_init()

from_training_args classmethod

from_training_args(
    verifier_config: PretrainedConfig,
    t2d: Tensor | None = None,
    d2t: Tensor | None = None,
    **kwargs,
) -> DFlashDraftModel

Create DFlash model from training arguments.

Args: verifier_config: Verifier model configuration. This should be a config with num_hidden_layers set to the number of DRAFT layers (created by create_transformer_layer_config in train.py). t2d: Target-to-draft vocabulary mapping tensor (optional). d2t: Draft-to-target vocabulary mapping tensor (optional). **kwargs: Training arguments with DFlash-specific params — draft_vocab_size: size of the draft vocabulary; block_size: block size for draft predictions (default: 8); max_anchors: max anchor positions during training (default: 3072); verifier_name_or_path: path to the verifier model.

Returns: Initialized DFlashDraftModel

Note: The number of draft layers is encoded in verifier_config.num_hidden_layers, following the same pattern as EAGLE3.

Source code in speculators/models/dflash/core.py
@classmethod
def from_training_args(
    cls,
    verifier_config: "PretrainedConfig",
    t2d: torch.Tensor | None = None,
    d2t: torch.Tensor | None = None,
    **kwargs,
) -> "DFlashDraftModel":
    """Create DFlash model from training arguments.

    Args:
        verifier_config: Verifier model configuration. This should be a config
            with num_hidden_layers set to the number of DRAFT layers (created
            by create_transformer_layer_config in train.py).
        t2d: Target-to-draft vocabulary mapping tensor (optional)
        d2t: Draft-to-target vocabulary mapping tensor (optional)
        **kwargs: Training arguments with DFlash-specific params
            - draft_vocab_size: Size of draft vocabulary (required)
            - verifier_name_or_path: Path to verifier model (required)
            - block_size: Block size for draft predictions (default: 8)
            - max_anchors: Max anchor positions during training
              (default: 3072)
            - target_layer_ids: Passed to resolve_target_layer_ids together
              with verifier_name_or_path; may be omitted (default: None)
            - mask_token_id: Forwarded to the config (default: None)
            - sliding_window_non_causal: Forwarded to the config
              (default: False)

    Returns:
        Initialized DFlashDraftModel, after load_vocab_mappings(t2d, d2t) and
        load_verifier_weights() have been called on it.

    Raises:
        KeyError: If ``draft_vocab_size`` or ``verifier_name_or_path`` is
            missing from ``kwargs``.

    Note:
        The number of draft layers is encoded in verifier_config.num_hidden_layers,
        following the same pattern as EAGLE3.
    """
    # Imported locally rather than at module level (presumably to avoid an
    # import cycle -- TODO confirm).
    from speculators.config import (  # noqa: PLC0415
        SpeculatorsConfig,
        VerifierConfig,
    )
    from speculators.proposals.greedy import (  # noqa: PLC0415
        GreedyTokenProposalConfig,
    )

    verifier_name_or_path = kwargs["verifier_name_or_path"]
    # Resolve the default once so the config and the proposal method can
    # never disagree about the block size.
    block_size = kwargs.get("block_size", 8)

    target_layer_ids = resolve_target_layer_ids(
        kwargs.get("target_layer_ids"),
        verifier_name_or_path,
    )

    config = DFlashSpeculatorConfig(
        transformer_layer_config=verifier_config,
        draft_vocab_size=kwargs["draft_vocab_size"],
        block_size=block_size,
        max_anchors=kwargs.get("max_anchors", 3072),
        aux_hidden_state_layer_ids=target_layer_ids,
        mask_token_id=kwargs.get("mask_token_id"),
        sliding_window_non_causal=kwargs.get("sliding_window_non_causal", False),
        speculators_config=SpeculatorsConfig(
            algorithm="dflash",
            proposal_methods=[
                GreedyTokenProposalConfig(
                    # DFlash first position is anchor position, not used during gen
                    speculative_tokens=block_size - 1,
                )
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig.from_config(
                verifier_config, name_or_path=verifier_name_or_path
            ),
        ),
    )

    model = cls(config=config)
    model.load_vocab_mappings(t2d, d2t)
    model.load_verifier_weights()
    return model

get_trainer_kwargs staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get training and validation kwargs for DFlash.

Args: **kwargs: Training arguments

Returns: Tuple of (train_call_kwargs, val_call_kwargs)

Source code in speculators/models/dflash/core.py
@staticmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Build the per-call keyword arguments used for training and validation.

    The ``loss_fn`` entry of ``kwargs`` is resolved through
    ``resolve_loss_fn`` and the result is handed to both phases.

    Args:
        **kwargs: Training arguments; must contain ``loss_fn``.

    Returns:
        Tuple of (train_call_kwargs, val_call_kwargs), each a separate dict
        holding the resolved loss function under ``"loss_fn"``.

    Raises:
        KeyError: If ``loss_fn`` is not present in ``kwargs``.
    """
    resolved_loss = resolve_loss_fn(kwargs["loss_fn"])
    train_call_kwargs = {"loss_fn": resolved_loss}
    val_call_kwargs = {"loss_fn": resolved_loss}
    return train_call_kwargs, val_call_kwargs

DFlashSpeculatorConfig

DFlashSpeculatorConfig(**kwargs)

Bases: SpeculatorModelConfig

Configuration for DFlash speculator with vocabulary mapping.

DFlash features vocabulary mapping between draft (64K) and target (128K) vocabularies, enabling cross-tokenizer speculation.

Parameters:

  • transformer_layer_config

    Configuration for the transformer decoder layer

  • draft_vocab_size

    Size of draft model vocabulary for speculation

Methods:

Attributes:

Source code in speculators/config.py
def __init__(self, **kwargs):
    """Initialize the Pydantic side first, then the PretrainedConfig side.

    Args:
        **kwargs: Field values and any additional PretrainedConfig arguments.
    """
    model_cls = self.__class__

    # Pydantic goes first so every declared field ends up with a valid value.
    PydanticClassRegistryMixin.__init__(self, **kwargs)

    # Feed the Pydantic-resolved field values back into kwargs so that
    # PretrainedConfig cannot overwrite them with the raw inputs.
    kwargs.update(
        {field_name: getattr(self, field_name) for field_name in model_cls.model_fields}
    )

    # ClassVars must not reach PretrainedConfig.__post_init__: it would try to
    # setattr them, and pydantic blocks setattr on ClassVar names.
    for class_var_name in model_cls.__class_vars__:
        kwargs.pop(class_var_name, None)

    # Hugging Face PretrainedConfig handles the remaining model arguments.
    PretrainedConfig.__init__(self, **kwargs)

    # Always stamp the currently installed transformers version.
    self.transformers_version = version("transformers")

target_vocab_size property

target_vocab_size: int

Get target vocabulary size from transformer config.

serialize_transformer_config

serialize_transformer_config(
    value: PretrainedConfig,
) -> dict

Serialize transformer config to dict.

Source code in speculators/models/dflash/config.py
@field_serializer("transformer_layer_config")
def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
    """Turn the transformer layer config into a plain dict for serialization.

    Args:
        value: The transformer layer configuration to serialize.

    Returns:
        The result of ``value.to_diff_dict()``.
    """
    serialized = value.to_diff_dict()
    return serialized

validate_transformer_config classmethod

validate_transformer_config(value: Any) -> PretrainedConfig

Validate and convert transformer config.

Source code in speculators/models/dflash/config.py
@field_validator("transformer_layer_config", mode="before")
@classmethod
def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
    """Coerce a dict into a transformer config object before validation.

    Args:
        value: Either a dict of config fields or any other value (e.g. an
            already-built config object).

    Returns:
        For a dict: an instance of the config class selected by its
        ``model_type`` key, or of ``Qwen3Config`` when that key is absent.
        Any non-dict value is returned unchanged.
    """
    # Non-dict inputs are passed through untouched.
    if not isinstance(value, dict):
        return value

    if "model_type" in value:
        # Look up the config class registered for this model type.
        resolved_cls: type[PretrainedConfig] = AutoConfig.for_model(
            model_type=value["model_type"]
        ).__class__
    else:
        resolved_cls = Qwen3Config
    return resolved_cls(**value)