Skip to content

vllm.v1.worker.gpu.spec_decode.utils

Functions:

get_parallel_drafting_token_id(hf_config)

Resolve the mask token id used for parallel drafting slots.

Checks (in order): dflash_config.mask_token_id, pard_token, ptd_token_id. Raises ValueError if none are present.

Source code in vllm/v1/worker/gpu/spec_decode/utils.py
def get_parallel_drafting_token_id(hf_config) -> int:
    """Resolve the mask token id used for parallel drafting slots.

    Checks (in order): `dflash_config.mask_token_id`, `pard_token`,
    `ptd_token_id`. The first one that is set to a non-None value wins;
    a field that exists but is None is treated the same as a missing one.

    Args:
        hf_config: Model config object exposing the fields above as
            attributes (`dflash_config` is read as a mapping).

    Returns:
        The resolved token id, converted with `int()`.

    Raises:
        ValueError: If none of the fields is present with a non-None value.
    """
    # A missing or empty/None `dflash_config` is normalized to an empty dict.
    dflash_config = getattr(hf_config, "dflash_config", None) or {}

    # Compare against None explicitly (not truthiness): 0 is a valid token id.
    token_id = dflash_config.get("mask_token_id")
    if token_id is None:
        token_id = getattr(hf_config, "pard_token", None)
    if token_id is None:
        token_id = getattr(hf_config, "ptd_token_id", None)
    if token_id is None:
        raise ValueError(
            "Model config must specify `dflash_config.mask_token_id`,"
            " `pard_token`, or `ptd_token_id` for parallel drafting."
        )
    return int(token_id)