Sampling Parameters#
- class vllm.SamplingParams(n: int = 1, best_of: int | None = None, _real_n: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, min_p: float = 0.0, seed: int | None = None, stop: List[str] | str | None = None, stop_token_ids: List[int] | None = None, bad_words: List[str] | None = None, ignore_eos: bool = False, max_tokens: int | None = 16, min_tokens: int = 0, logprobs: int | None = None, prompt_logprobs: int | None = None, detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Any | None = None, include_stop_str_in_output: bool = False, truncate_prompt_tokens: int | None = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_text_buffer_length: int = 0, _all_stop_token_ids: Set[int] = <factory>, guided_decoding: GuidedDecodingParams | None = None, logit_bias: Dict[int, float] | None = None, allowed_token_ids: List[int] | None = None)[source]#
Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion API (https://platform.openai.com/docs/api-reference/completions/create). In addition, we support beam search, which is not supported by OpenAI.
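For orientation, a minimal sketch of constructing a SamplingParams object and passing it to the offline LLM entry point; the model name, prompt, and parameter values below are illustrative, not prescribed defaults:

```python
from vllm import LLM, SamplingParams

# Illustrative values; any of the OpenAI-style knobs documented below can be combined.
sampling_params = SamplingParams(
    n=1,
    temperature=0.8,   # > 0 for stochastic sampling; 0 means greedy
    top_p=0.95,        # nucleus sampling over the top 95% probability mass
    max_tokens=64,
    stop=["\n\n"],     # generation halts when this string is produced
)

llm = LLM(model="facebook/opt-125m")  # model choice is an assumption
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```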
- Parameters:
n – Number of output sequences to return for the given prompt.
best_of – Number of output sequences that are generated from the prompt. From these best_of sequences, the top n sequences are returned. best_of must be greater than or equal to n. By default, best_of is set to n.
presence_penalty – Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.
frequency_penalty – Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.
repetition_penalty – Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.
temperature – Float that controls the randomness of the sampling. Lower values make the model more deterministic, while higher values make the model more random. Zero means greedy sampling.
top_p – Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k – Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
min_p – Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this.
seed – Random seed to use for the generation.
stop – List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.
stop_token_ids – List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens.
bad_words – List of words that are not allowed to be generated. More precisely, only the last token of a corresponding token sequence is not allowed when the next generated token can complete the sequence.
include_stop_str_in_output – Whether to include the stop strings in output text. Defaults to False.
ignore_eos – Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.
max_tokens – Maximum number of tokens to generate per output sequence.
min_tokens – Minimum number of tokens to generate per output sequence before EOS or stop_token_ids can be generated.
logprobs – Number of log probabilities to return per output token. When set to None, no probability is returned. If set to a non-None value, the result includes the log probabilities of the specified number of most likely tokens, as well as the chosen tokens. Note that the implementation follows the OpenAI API: The API will always return the log probability of the sampled token, so there may be up to logprobs+1 elements in the response.
prompt_logprobs – Number of log probabilities to return per prompt token.
detokenize – Whether to detokenize the output. Defaults to True.
skip_special_tokens – Whether to skip special tokens in the output.
spaces_between_special_tokens – Whether to add spaces between special tokens in the output. Defaults to True.
logits_processors – List of functions that modify the logits based on previously generated tokens, optionally taking the prompt tokens as a first argument (see the sketch after this list).
truncate_prompt_tokens – If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation).
guided_decoding – If provided, the engine will construct a guided decoding logits processor from these parameters. Defaults to None.
logit_bias – If provided, the engine will construct a logits processor that applies these logit biases. Defaults to None.
allowed_token_ids – If provided, the engine will construct a logits processor which only retains scores for the given token ids. Defaults to None.
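To make the logits_processors and logit_bias parameters concrete, here is a hedged sketch; the processor function and the token ids used are hypothetical examples, not part of vLLM:

```python
import torch
from vllm import SamplingParams

def ban_token_42(token_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical processor: forbid token id 42 at the next step.

    A processor receives the tokens generated so far and the raw logits for the
    next position, and returns the modified logits; a three-argument variant
    additionally receives the prompt token ids as its first argument.
    """
    logits[42] = -float("inf")
    return logits

params = SamplingParams(
    temperature=0.7,
    logits_processors=[ban_token_42],
    logit_bias={1234: -5.0},  # illustrative token id and bias value
)
```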
- clone() → SamplingParams [source]#
Deep copy excluding LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an arbitrary, nontrivial amount of data. See vllm-project/vllm#3087
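A brief sketch of how clone() might be used to derive a per-request variant, assuming plain attribute assignment on the copy (the values below are illustrative):

```python
from vllm import SamplingParams

base = SamplingParams(temperature=0.7, max_tokens=128)

# Deep copy; any logits_processors attached to `base` would not be duplicated.
greedy = base.clone()
greedy.temperature = 0.0  # greedy decoding for this variant only
```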