Skip to content

speculators.data_generation.vllm_client

Functions:

generate_hidden_states

generate_hidden_states(
    client: Client,
    model: str,
    token_ids: list[int],
    timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
) -> str

Runs a single decode step (`max_tokens=1`) to generate hidden states and returns the path to the hidden states file.

Source code in speculators/data_generation/vllm_client.py
@with_retries
def generate_hidden_states(
    client: openai.Client,
    model: str,
    token_ids: list[int],
    timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
) -> str:
    """
    Run a one-token decode to generate hidden states and return the path to
    the hidden states file.

    Wrapped by ``with_retries``, so callers may also pass ``max_retries``.

    Args:
        client: The synchronous OpenAI client.
        model: The model ID.
        token_ids: The input token IDs, sent as the prompt.
        timeout: Timeout in seconds for each request attempt, forwarded to
            the client. None for no timeout.

    Returns:
        The hidden states file path, as extracted by ``extract_output``.
    """
    request = {
        "model": model,
        "prompt": token_ids,
        # A single decode step is all that is needed to produce hidden states.
        "max_tokens": 1,
        "extra_body": {"return_token_ids": True},
        "timeout": timeout,
    }
    completion = client.completions.create(**request)
    return extract_output(completion, token_ids)

generate_hidden_states_async async

generate_hidden_states_async(
    client: AsyncClient,
    model: str,
    token_ids: list[int],
    timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
) -> str

Runs a single decode step (`max_tokens=1`) to generate hidden states and returns the path to the hidden states file.

Args:

- client: The async OpenAI client.
- model: The model ID.
- token_ids: The input token IDs.
- timeout: Timeout in seconds for each request attempt; `None` for no timeout.

Source code in speculators/data_generation/vllm_client.py
@with_retries
async def generate_hidden_states_async(
    client: openai.AsyncClient,
    model: str,
    token_ids: list[int],
    timeout: float | None = DEFAULT_REQUEST_TIMEOUT,
) -> str:
    """
    Run a one-token decode to generate hidden states and return the path to
    the hidden states file (async variant).

    Wrapped by ``with_retries``, so callers may also pass ``max_retries``.

    Args:
        client: The async OpenAI client.
        model: The model ID.
        token_ids: The input token IDs.
        timeout: Timeout in seconds for each request attempt. None for no timeout.

    Returns:
        The hidden states file path, as extracted by ``extract_output``.
    """
    request = {
        "model": model,
        "prompt": token_ids,
        # A single decode step is all that is needed to produce hidden states.
        "max_tokens": 1,
        "extra_body": {"return_token_ids": True},
        "timeout": timeout,
    }
    pending = client.completions.create(**request)
    if timeout is None:
        completion = await pending
    else:
        # The same timeout is handed to the client above; wait_for also caps
        # the whole awaited call on the asyncio side.
        completion = await asyncio.wait_for(pending, timeout=timeout)
    return extract_output(completion, token_ids)

with_retries

with_retries(fn)

Decorator that adds retry logic with exponential backoff.

The decorated function gains a `max_retries` keyword argument (default `DEFAULT_MAX_RETRIES`). `InvalidResponseError` is never retried. Works for both sync and async functions.

Source code in speculators/data_generation/vllm_client.py
def with_retries(fn):
    """Decorator that adds retry logic with exponential backoff.

    The decorated function gains a ``max_retries`` keyword argument
    (default ``DEFAULT_MAX_RETRIES``): the call is attempted at most
    ``max_retries + 1`` times. After each failure, ``_handle_retry_error``
    decides what happens next. It either raises, which stops retrying, or
    returns a backoff delay in seconds (or None for no sleep) before the next
    attempt. ``InvalidResponseError`` is never retried. Works for both sync
    and async functions.

    Raises:
        ValueError: If ``max_retries`` is negative.
        Exception: The last error raised by ``fn`` once all attempts are
            exhausted, or any error ``_handle_retry_error`` chooses to raise.
    """
    # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
    # which is deprecated as of Python 3.14.
    import inspect

    def _attempt_count(max_retries: int) -> int:
        """Validate ``max_retries`` and return the total number of attempts."""
        # Without this check a negative value yields zero attempts, and the
        # final ``raise last_error`` would execute ``raise None`` (TypeError).
        if max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {max_retries}")
        return max_retries + 1

    if inspect.iscoroutinefunction(fn):

        @functools.wraps(fn)
        async def async_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs):
            total_attempts = _attempt_count(max_retries)
            last_error: Exception | None = None
            for attempt in range(1, total_attempts + 1):
                try:
                    return await fn(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    backoff = _handle_retry_error(e, attempt, total_attempts)
                    if backoff is not None:
                        await asyncio.sleep(backoff)
            # At least one attempt ran and failed, so last_error is set.
            raise last_error  # type: ignore[misc]

        return async_wrapper

    @functools.wraps(fn)
    def sync_wrapper(*args, max_retries=DEFAULT_MAX_RETRIES, **kwargs):
        total_attempts = _attempt_count(max_retries)
        last_error: Exception | None = None
        for attempt in range(1, total_attempts + 1):
            try:
                return fn(*args, **kwargs)
            except Exception as e:
                last_error = e
                backoff = _handle_retry_error(e, attempt, total_attempts)
                if backoff is not None:
                    time.sleep(backoff)
        # At least one attempt ran and failed, so last_error is set.
        raise last_error  # type: ignore[misc]

    return sync_wrapper