
vllm.reasoning

Modules:

abs_reasoning_parsers
deepseek_r1_reasoning_parser
granite_reasoning_parser
qwen3_reasoning_parser

__all__ module-attribute

__all__ = [
    "ReasoningParser",
    "ReasoningParserManager",
    "DeepSeekR1ReasoningParser",
    "GraniteReasoningParser",
    "Qwen3ReasoningParser",
]
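
Parsers register themselves with ReasoningParserManager under a short name (the argument to register_module, e.g. "deepseek_r1", "granite", "qwen3") and are looked up by that name when serving is configured. A minimal lookup sketch follows; get_reasoning_parser is assumed here to be the lookup counterpart of register_module, so verify the exact method name against the ReasoningParserManager source.

from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParserManager

# Resolve a registered parser class by the name it was registered under,
# then construct it with the model's tokenizer (constructor signature is
# shown in the class docs below). The lookup method name is an assumption.
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
parser = parser_cls(tokenizer)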

DeepSeekR1ReasoningParser

Bases: ReasoningParser

Reasoning parser for DeepSeek R1 model.

The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning text. This parser extracts the reasoning content from the model output.
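
As a quick, self-contained illustration of that convention (not the parser class itself), the following sketch reproduces the non-streaming split with plain string operations:

from typing import Optional

# Illustrative only: mirrors the logic of extract_reasoning_content below
# using plain string operations instead of the parser / tokenizer machinery.
def split_think(text: str) -> tuple[Optional[str], Optional[str]]:
    # Drop a leading <think> marker if the model emitted one.
    before, sep, after = text.partition("<think>")
    text = after if sep else before
    # Without a closing </think>, everything is treated as reasoning.
    if "</think>" not in text:
        return text, None
    reasoning, _, content = text.partition("</think>")
    return reasoning, (content or None)

print(split_think("<think>abc</think>xyz"))  # ('abc', 'xyz')
print(split_think("abc</think>xyz"))         # ('abc', 'xyz') -- no <think> emitted
print(split_think("abc"))                    # ('abc', None)  -- reasoning only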

Source code in vllm/reasoning/deepseek_r1_reasoning_parser.py
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    start_token_id: int
    end_token_id: int

    start_token: str = "<think>"
    end_token: str = "</think>"

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.start_token_id = self.vocab.get(self.start_token)
        self.end_token_id = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
        if self.end_token_id not in input_ids[:-1]:
            return []
        else:
            return input_ids[input_ids.index(self.end_token_id) + 1:]

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.start_token_id, self.end_token_id
        ]):
            return None

        # Check if <think> is present in previous or delta.
        # Keep compatibility with models that don't generate <think> tokens.
        if self.start_token_id in previous_token_ids:
            if self.end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token):]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            elif self.end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                # reasoning content continues
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.start_token_id in delta_token_ids:
            if self.end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.start_token)
                end_index = delta_text.find(self.end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.start_token):end_index]
                content = delta_text[end_index + len(self.end_token):]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta, also need to check for </think>.
            # Because the model may have generated </think> without <think>
            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
            if self.end_token_id in delta_token_ids:
                # </think> in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token):]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            elif self.end_token_id in previous_token_ids:
                # </think> in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
                # no </think> in previous or delta, reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from the model output.

        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content
        """

        # Check if the start token is present in the model output, remove it
        # if it is present.
        model_output_parts = model_output.partition(self.start_token)
        model_output = model_output_parts[2] if model_output_parts[
            1] else model_output_parts[0]

        # DeepSeek R1 doesn't generate <think> now.
        # Thus we assume the reasoning content is always at the start.
        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
        if self.end_token not in model_output:
            return model_output, None
        else:
            reasoning_content, _, content = model_output.partition(
                self.end_token)
            # If the end token is not found, return the model output as is.
            # It should not happen since we already checked for the presence
            # of the end token.
            # If generation stops right after end-of-think, return null content
            final_content = content or None
            return reasoning_content, final_content

end_token class-attribute instance-attribute

end_token: str = '</think>'

end_token_id instance-attribute

end_token_id: int = get(end_token)

start_token class-attribute instance-attribute

start_token: str = '<think>'

start_token_id instance-attribute

start_token_id: int = get(start_token)

__init__

__init__(tokenizer: PreTrainedTokenizerBase)
Source code in vllm/reasoning/deepseek_r1_reasoning_parser.py
def __init__(self, tokenizer: PreTrainedTokenizerBase):
    super().__init__(tokenizer)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ReasoningParser "
            "constructor during construction.")

    self.start_token_id = self.vocab.get(self.start_token)
    self.end_token_id = self.vocab.get(self.end_token)
    if self.start_token_id is None or self.end_token_id is None:
        raise RuntimeError(
            "DeepSeek R1 reasoning parser could not locate think start/end "
            "tokens in the tokenizer!")

extract_content_ids

extract_content_ids(input_ids: list[int]) -> list[int]

Extract the content after the end tokens

Source code in vllm/reasoning/deepseek_r1_reasoning_parser.py
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Extract the content after the end tokens
    """
    if self.end_token_id not in input_ids[:-1]:
        return []
    else:
        return input_ids[input_ids.index(self.end_token_id) + 1:]
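
A small illustrative sketch of the slicing above, using a made-up end-of-think token id; note that nothing is returned until at least one token follows the end marker:

# Illustrative only: 100 stands in for the real </think> token id.
def content_ids(ids: list[int], end_token_id: int = 100) -> list[int]:
    if end_token_id not in ids[:-1]:
        return []
    return ids[ids.index(end_token_id) + 1:]

print(content_ids([7, 8, 100, 9, 10]))  # [9, 10]
print(content_ids([7, 8, 100]))         # [] -- end marker is last, no content yet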

extract_reasoning_content

extract_reasoning_content(
    model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]

Extract reasoning content from the model output.

For text <think>abc</think>xyz: 'abc' goes to reasoning_content and 'xyz' goes to content.

Returns:

Type Description
tuple[Optional[str], Optional[str]]

Reasoning content and content.

Source code in vllm/reasoning/deepseek_r1_reasoning_parser.py
def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
    """
    Extract reasoning content from the model output.

    For text <think>abc</think>xyz:
    - 'abc' goes to reasoning_content
    - 'xyz' goes to content

    Returns:
        tuple[Optional[str], Optional[str]]: reasoning content and content
    """

    # Check if the start token is present in the model output, remove it
    # if it is present.
    model_output_parts = model_output.partition(self.start_token)
    model_output = model_output_parts[2] if model_output_parts[
        1] else model_output_parts[0]

    # DeepSeek R1 doesn't generate <think> now.
    # Thus we assume the reasoning content is always at the start.
    # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
    if self.end_token not in model_output:
        return model_output, None
    else:
        reasoning_content, _, content = model_output.partition(
            self.end_token)
        # If the end token is not found, return the model output as is.
        # It should not happen since we already checked for the presence
        # of the end token.
        # If generation stops right after end-of-think, return null content
        final_content = content or None
        return reasoning_content, final_content

extract_reasoning_content_streaming

extract_reasoning_content_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]

Extract reasoning content from a delta message. Handles streaming output where previous + delta = current. Uses token IDs for faster processing. For text <think>abc</think>xyz: 'abc' goes to reasoning_content and 'xyz' goes to content.

Source code in vllm/reasoning/deepseek_r1_reasoning_parser.py
def extract_reasoning_content_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """
    Extract reasoning content from a delta message.
    Handles streaming output where previous + delta = current.
    Uses token IDs for faster processing.
    For text <think>abc</think>xyz:
    - 'abc' goes to reasoning_content
    - 'xyz' goes to content
    """
    # Skip single special tokens
    if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
            self.start_token_id, self.end_token_id
    ]):
        return None

    # Check if <think> is present in previous or delta.
    # Keep compatibility with models that don't generate <think> tokens.
    if self.start_token_id in previous_token_ids:
        if self.end_token_id in delta_token_ids:
            # <think> in previous, </think> in delta,
            # extract reasoning content
            end_index = delta_text.find(self.end_token)
            reasoning_content = delta_text[:end_index]
            content = delta_text[end_index + len(self.end_token):]
            return DeltaMessage(
                reasoning_content=reasoning_content,
                content=content if content else None,
            )
        elif self.end_token_id in previous_token_ids:
            # <think> in previous, </think> in previous,
            # reasoning content continues
            return DeltaMessage(content=delta_text)
        else:
            # <think> in previous, no </think> in previous or delta,
            # reasoning content continues
            return DeltaMessage(reasoning_content=delta_text)
    elif self.start_token_id in delta_token_ids:
        if self.end_token_id in delta_token_ids:
            # <think> in delta, </think> in delta, extract reasoning content
            start_index = delta_text.find(self.start_token)
            end_index = delta_text.find(self.end_token)
            reasoning_content = delta_text[start_index +
                                           len(self.start_token):end_index]
            content = delta_text[end_index + len(self.end_token):]
            return DeltaMessage(
                reasoning_content=reasoning_content,
                content=content if content else None,
            )
        else:
            # <think> in delta, no </think> in delta,
            # reasoning content continues
            return DeltaMessage(reasoning_content=delta_text)
    else:
        # No <think> in previous or delta, also need to check for </think>.
        # Because the model may have generated </think> without <think>
        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
        if self.end_token_id in delta_token_ids:
            # </think> in delta with more tokens,
            # extract reasoning content and content
            end_index = delta_text.find(self.end_token)
            reasoning_content = delta_text[:end_index]
            content = delta_text[end_index + len(self.end_token):]
            return DeltaMessage(
                reasoning_content=reasoning_content,
                content=content if content else None,
            )
        elif self.end_token_id in previous_token_ids:
            # </think> in previous, thinking content ends
            return DeltaMessage(content=delta_text)
        else:
            # no </think> in previous or delta, reasoning content continues
            return DeltaMessage(reasoning_content=delta_text)
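
For the common case where <think> and </think> arrive as standalone single-token deltas, the routing above can be re-enacted at the text level. The sketch below is a simplified illustration only; the real method keys off token ids and also covers deltas that mix markers with other text:

def route_delta(previous_text: str, delta_text: str):
    # Standalone marker deltas produce no message (mirrors the single
    # special-token skip above).
    if delta_text in ("<think>", "</think>"):
        return None
    if "</think>" in previous_text:
        # Thinking already ended: everything is normal content.
        return {"content": delta_text}
    if "</think>" in delta_text:
        # This delta closes the thinking section.
        reasoning, _, content = delta_text.partition("</think>")
        return {"reasoning_content": reasoning, "content": content or None}
    # Still inside (or before) the thinking section.
    return {"reasoning_content": delta_text}

text = ""
for chunk in ["<think>", "Let me think.", "</think>", "The answer is 4."]:
    print(repr(chunk), "->", route_delta(text, chunk))
    text += chunk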

is_reasoning_end

is_reasoning_end(input_ids: list[int]) -> bool
Source code in vllm/reasoning/deepseek_r1_reasoning_parser.py
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    return self.end_token_id in input_ids

GraniteReasoningParser

Bases: ReasoningParser

Reasoning parser for IBM Granite.

IBM Granite models currently use "Here is my thought process:" and "Here is my response:" to separate their thinking / response outputs.
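
Because the boundaries are ordinary text rather than special tokens, the non-streaming extraction is essentially a regex match. The sketch below applies the same pattern the parser compiles in __init__ to a complete output:

import re

# The same pattern GraniteReasoningParser builds in __init__.
pattern = re.compile(
    r"(?:Here's|Here is) my thought process:(.*?)"
    r"(?:Here's|Here is) my response:(.*)",
    re.DOTALL,
)

output = "Here is my thought process: think step by step. Here is my response: 42."
reasoning, response = pattern.findall(output)[0]
print(repr(reasoning))  # ' think step by step. '
print(repr(response))   # ' 42.'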

Source code in vllm/reasoning/granite_reasoning_parser.py
@ReasoningParserManager.register_module("granite")
class GraniteReasoningParser(ReasoningParser):
    """
    Reasoning parser for IBM Granite.

    IBM granite models currently use "Here is my thought process:"
    and "Here is my response:" to separate their thinking / response outputs.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # NOTE: There have been some observed occurrences of quantized
        # instances of the current models using "Here's" instead of "Here is",
        # so to be safe, we match on both.
        self.think_start_expr = r"(?:Here's|Here is) my thought process:"
        self.response_start_expr = r"(?:Here's|Here is) my response:"

        self.reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)

        self.valid_think_starts = [
            "Here's my thought process:", "Here is my thought process:"
        ]
        self.valid_response_starts = [
            "Here's my response:", "Here is my response:"
        ]

        # Substrings to match for sequence boundaries on raw text
        self.seq_boundary_end = ":"
        self.seq_boundary_start = "Here"

        # The longest any thinking / start of response message can be
        self.longest_think_start = max(
            len(think_start) for think_start in self.valid_think_starts)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        re_match = self.reasoning_regex.findall(model_output)
        if not re_match:
            return None, model_output
        reasoning_content, response_content = re_match[0]
        if not response_content:
            return reasoning_content, None
        return reasoning_content, response_content

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Extract the reasoning content / content emitted by granite models;
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        NOTE: Granite models do not use a special token to start their reasoning
        and response sections; instead they have token sequences, e.g.,

                Here is my thought process: Foo Here is my response: Bar

        This increases the complexity of correctly handling streams, since we
        need to watch for specific sequences and correctly parse them without
        dropping content that is potentially overlapping & spanning multiple
        delta messages.

        Args:
            previous_text (str): Previous text outside of this delta message.
            current_text (str): Previous text + delta text.
            delta_text (str): Text to consider and parse content from.
            previous_token_ids (Sequence[int]): Token IDs of previous_text.
            current_token_ids (Sequence[int]): Token IDs of current_text.
            delta_token_ids (Sequence[int]): Token IDs of delta_text.

        Returns:
            Union[DeltaMessage, None]
                DeltaMessage with either reasoning content or content, or None.
        """
        reasoning_content, resp_seq_len, content = self._get_content_sections(
            current_text)
        # Either we haven't finished the start of the reasoning sequence,
        # or the model is generating something unexpected.
        if not reasoning_content:
            delta_message = self._get_delta_message_with_no_reasoning_bounds(
                current_text, delta_text)
        # We have a start of reasoning message, but have not yet finished
        # the start of response sequence.
        elif not content:
            delta_message = self._get_delta_message_with_no_response_bounds(
                current_text, reasoning_content, delta_text)
        # We've finished both the start of reasoning and start of response seq.
        else:
            # This should never happen since we matched on the response
            assert resp_seq_len is not None
            delta_message = self._get_delta_message_with_both_bounds(
                delta_text, reasoning_content, content, current_text,
                resp_seq_len)
        if not delta_message.content and not delta_message.reasoning_content:
            return None
        return delta_message

    #### Implementation details of stream parsing for granite models
    def _is_reasoning_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start reasoning seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if any of the possible reasoning start seqs match.
        """
        return any(
            think_start.startswith(text)
            for think_start in self.valid_think_starts)

    def _is_response_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start response seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if any of the possible response start seqs match.
        """
        return any(
            response_start.startswith(text)
            for response_start in self.valid_response_starts)

    def _get_delta_message_with_no_reasoning_bounds(
        self,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has not yet completed
        its start of reasoning sequence.

        Args:
            current_text (str): The full previous + delta text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        prev_longest_length = len(current_text) - len(delta_text)
        is_substr = self._is_reasoning_start_substr(current_text)
        was_substr = self._is_reasoning_start_substr(
            current_text[:prev_longest_length])

        # Check if we just generated something NOT in the special token seq;
        # if so, add everything that we previously skipped with this delta
        # message and append everything to content in the future.
        if was_substr and not is_substr:
            return DeltaMessage(
                reasoning_content=None,
                content=current_text,
            )
        if is_substr:
            # Might still be in the special token sequence; return nothing
            return DeltaMessage(reasoning_content=None, content=None)
        # Otherwise the sequence has already been broken and we already
        # corrected; just return the delta text as normal content.
        return DeltaMessage(reasoning_content=None, content=delta_text)

    def _get_delta_message_with_no_response_bounds(
        self,
        current_text: str,
        reasoning_content: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has both reasoning
        content with no (response) content. NOTE that we may have overlapping
        tokens with the start of reasoning / start of response sequences on
        either side of the delta text.

        Args:
            current_text (str): The full previous + delta text.
            reasoning_content (str): reasoning content from current_text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        # If we have no reasoning content or explicitly end with the start of
        # response sequence, we are in transition to the response; need to be
        # careful here, since the final token (:) will match the reasoning
        # content and fully parse it out; we should not pass the : back.
        ends_with_start_response_seq = any(
            current_text.endswith(response_start)
            for response_start in self.valid_response_starts)
        if reasoning_content is None or ends_with_start_response_seq:
            return DeltaMessage(reasoning_content=None, content=None)

        # Consider previous / current text only within context of the reasoning
        previous_text = reasoning_content[:-len(delta_text)]
        current_text = reasoning_content

        # We need to be careful about adding unfinished response sequences;
        # Find the place at which we MIGHT be starting a response sequence
        prev_idx = previous_text.rfind(self.seq_boundary_start)
        delta_idx = delta_text.rfind(self.seq_boundary_start)

        # Check the state of potential start of response substring matches.
        prev_was_substr = self._is_response_start_substr(
            previous_text[prev_idx:]) if prev_idx >= 0 else False
        delta_continues_substr = self._is_response_start_substr(
            current_text[prev_idx:]) if prev_idx >= 0 else False
        delta_new_substr = self._is_response_start_substr(
            delta_text[delta_idx:]) if delta_idx >= 0 else False

        # Delta only contains potential continued response sequence text.
        if delta_continues_substr:
            return DeltaMessage(reasoning_content=None, content=None)

        if not prev_was_substr:
            # Delta may be starting a new response seq but has other text too.
            if delta_new_substr:
                return DeltaMessage(reasoning_content=delta_text[:delta_idx],
                                    content=None)
            # Normal case for most reasoning text (no potential special seqs).
            return DeltaMessage(reasoning_content=delta_text, content=None)
        # The substring that previously seemed to be a potential response
        # seq wasn't one; we need to add the content to the delta message,
        # and also slice off the potential response sequence
        elif delta_new_substr:
            reasoning_content = previous_text[
                prev_idx:] + delta_text[:delta_idx]
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=None)
        # No new substring yet, and we broke our old one; take the whole delta
        return DeltaMessage(
            reasoning_content=previous_text[prev_idx:] + delta_text,
            content=None,
        )

    def _get_delta_message_with_both_bounds(
        self,
        delta_text: str,
        reasoning_content: str,
        response_content: str,
        current_text: str,
        response_seq_len: int,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has both reasoning
        content and normal (response) content.

        Args:
            delta_text (str): Text to consider and parse content from.
            reasoning_content (str): reasoning content from current_text.
            response_content (str): response content from current_text.
            current_text (str): The full previous + delta text.
            response_seq_len (int): Length of the complete response sequence used.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        # Always have content; take length to the end
        delta_content = delta_text[-len(response_content):]
        reasoning_end_idx = len(delta_text) - (len(response_content) +
                                               response_seq_len)

        if reasoning_end_idx < 0:
            delta_reasoning_content = None
        else:
            # Get the starting offset
            start_reasoning_content_idx = len(
                reasoning_content) + response_seq_len + len(
                    response_content) - 1
            delta_offset = len(current_text) - len(delta_text)
            start_offset = start_reasoning_content_idx - delta_offset
            if start_offset < 0:
                start_offset = 0
            delta_reasoning_content = delta_text[
                start_offset:reasoning_end_idx]

        return DeltaMessage(
            reasoning_content=delta_reasoning_content,
            content=delta_content,
        )

    def _get_content_sections(
        self, current_text: str
    ) -> tuple[Optional[str], Optional[int], Optional[str]]:
        """Parse the text to extract the reasoning content / content
        if we have them.

        Args:
            current_text (str): The full previous + delta text.

        Returns:
            tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
            containing the reasoning content, the length of the response seq
            (if there is one) and the non-reasoning content.
        """
        current_chunk_start = 0
        start_reasoning_content = None
        parsed_content = False
        delimiter_idxs = [
            idx for idx, char in enumerate(current_text)
            if char == self.seq_boundary_end
        ]

        for current_chunk_end in delimiter_idxs:
            current_chunk = current_text[current_chunk_start:current_chunk_end]
            # Check to see if the start of reasoning seq is complete
            if start_reasoning_content is None:
                for think_start in self.valid_think_starts:
                    if current_chunk == think_start[:-1]:
                        start_reasoning_content = current_chunk_end + 1
                        current_chunk_start = current_chunk_end + 1
                        break

            # Check to see if the start of response seq is complete
            elif not parsed_content:
                for response_start in self.valid_response_starts:
                    if current_chunk[-len(response_start) +
                                     1:] == response_start[:-1]:
                        # Mark end of reasoning and start response content
                        # after the start of response sequence.
                        end_reasoning_content = current_chunk_end - len(
                            response_start)
                        reasoning_content = current_text[
                            start_reasoning_content:end_reasoning_content]
                        response_content = current_text[current_chunk_end + 1:]
                        return reasoning_content, len(
                            response_start), response_content

        if start_reasoning_content and not parsed_content:
            return current_text[start_reasoning_content:], None, None
        return None, None, None

longest_think_start instance-attribute

longest_think_start = max(
    len(think_start) for think_start in valid_think_starts
)

reasoning_regex instance-attribute

reasoning_regex = compile(
    f"{think_start_expr}(.*?){response_start_expr}(.*)",
    DOTALL,
)

response_start_expr instance-attribute

response_start_expr = "(?:Here's|Here is) my response:"

seq_boundary_end instance-attribute

seq_boundary_end = ':'

seq_boundary_start instance-attribute

seq_boundary_start = 'Here'

think_start_expr instance-attribute

think_start_expr = "(?:Here's|Here is) my thought process:"

valid_response_starts instance-attribute

valid_response_starts = [
    "Here's my response:",
    "Here is my response:",
]

valid_think_starts instance-attribute

valid_think_starts = [
    "Here's my thought process:",
    "Here is my thought process:",
]

__init__

__init__(tokenizer: PreTrainedTokenizerBase)
Source code in vllm/reasoning/granite_reasoning_parser.py
def __init__(self, tokenizer: PreTrainedTokenizerBase):
    super().__init__(tokenizer)

    # NOTE: There have been some observed occurrences of quantized
    # instances of the current models using "Here's" instead of "Here is",
    # so to be safe, we match on both.
    self.think_start_expr = r"(?:Here's|Here is) my thought process:"
    self.response_start_expr = r"(?:Here's|Here is) my response:"

    self.reasoning_regex = re.compile(
        rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
        re.DOTALL)

    self.valid_think_starts = [
        "Here's my thought process:", "Here is my thought process:"
    ]
    self.valid_response_starts = [
        "Here's my response:", "Here is my response:"
    ]

    # Substrings to match for sequence boundaries on raw text
    self.seq_boundary_end = ":"
    self.seq_boundary_start = "Here"

    # The longest any thinking / start of response message can be
    self.longest_think_start = max(
        len(think_start) for think_start in self.valid_think_starts)

_get_content_sections

_get_content_sections(
    current_text: str,
) -> tuple[Optional[str], Optional[int], Optional[str]]

Parse the text to extract the reasoning content / content if we have them.

Parameters:

Name Type Description Default
current_text str

The full previous + delta text.

required

Returns:

Type Description
tuple[Optional[str], Optional[int], Optional[str]]

Tuple of length 3 containing the reasoning content, the length of the response seq (if there is one), and the non-reasoning content.

Source code in vllm/reasoning/granite_reasoning_parser.py
def _get_content_sections(
    self, current_text: str
) -> tuple[Optional[str], Optional[int], Optional[str]]:
    """Parse the text to extract the reasoning content / content
    if we have them.

    Args:
        current_text (str): The full previous + delta text.

    Returns:
        tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
        containing the reasoning content, the length of the response seq
        (if there is one) and the non-reasoning content.
    """
    current_chunk_start = 0
    start_reasoning_content = None
    parsed_content = False
    delimiter_idxs = [
        idx for idx, char in enumerate(current_text)
        if char == self.seq_boundary_end
    ]

    for current_chunk_end in delimiter_idxs:
        current_chunk = current_text[current_chunk_start:current_chunk_end]
        # Check to see if the start of reasoning seq is complete
        if start_reasoning_content is None:
            for think_start in self.valid_think_starts:
                if current_chunk == think_start[:-1]:
                    start_reasoning_content = current_chunk_end + 1
                    current_chunk_start = current_chunk_end + 1
                    break

        # Check to see if the start of response seq is complete
        elif not parsed_content:
            for response_start in self.valid_response_starts:
                if current_chunk[-len(response_start) +
                                 1:] == response_start[:-1]:
                    # Mark end of reasoning and start response content
                    # after the start of response sequence.
                    end_reasoning_content = current_chunk_end - len(
                        response_start)
                    reasoning_content = current_text[
                        start_reasoning_content:end_reasoning_content]
                    response_content = current_text[current_chunk_end + 1:]
                    return reasoning_content, len(
                        response_start), response_content

    if start_reasoning_content and not parsed_content:
        return current_text[start_reasoning_content:], None, None
    return None, None, None

_get_delta_message_with_both_bounds

_get_delta_message_with_both_bounds(
    delta_text: str,
    reasoning_content: str,
    response_content: str,
    current_text: str,
    response_seq_len: int,
) -> DeltaMessage

Parse the delta message when the current text has both reasoning content and normal (response) content.

Parameters:

Name Type Description Default
delta_text str

Text to consider and parse content from.

required
reasoning_content str

reasoning content from current_text.

required
response_content str

response content from current_text.

required
current_text str

The full previous + delta text.

required
response_seq_len int

Length of the complete response sequence used.

required

Returns:

Name Type Description
DeltaMessage DeltaMessage

Message containing the parsed content.

Source code in vllm/reasoning/granite_reasoning_parser.py
def _get_delta_message_with_both_bounds(
    self,
    delta_text: str,
    reasoning_content: str,
    response_content: str,
    current_text: str,
    response_seq_len: int,
) -> DeltaMessage:
    """Parse the delta message when the current text has both reasoning
    content and normal (response) content.

    Args:
        delta_text (str): Text to consider and parse content from.
        reasoning_content (str): reasoning content from current_text.
        response_content (str): response content from current_text.
        current_text (str): The full previous + delta text.
        response_seq_len (int): Length of the complete response sequence used.

    Returns:
        DeltaMessage: Message containing the parsed content.
    """
    # Always have content; take length to the end
    delta_content = delta_text[-len(response_content):]
    reasoning_end_idx = len(delta_text) - (len(response_content) +
                                           response_seq_len)

    if reasoning_end_idx < 0:
        delta_reasoning_content = None
    else:
        # Get the starting offset
        start_reasoning_content_idx = len(
            reasoning_content) + response_seq_len + len(
                response_content) - 1
        delta_offset = len(current_text) - len(delta_text)
        start_offset = start_reasoning_content_idx - delta_offset
        if start_offset < 0:
            start_offset = 0
        delta_reasoning_content = delta_text[
            start_offset:reasoning_end_idx]

    return DeltaMessage(
        reasoning_content=delta_reasoning_content,
        content=delta_content,
    )

_get_delta_message_with_no_reasoning_bounds

_get_delta_message_with_no_reasoning_bounds(
    current_text: str, delta_text: str
) -> DeltaMessage

Parse the delta message when the current text has not yet completed its start of reasoning sequence.

Parameters:

Name Type Description Default
current_text str

The full previous + delta text.

required
delta_text str

Text to consider and parse content from.

required

Returns:

Name Type Description
DeltaMessage DeltaMessage

Message containing the parsed content.

Source code in vllm/reasoning/granite_reasoning_parser.py
def _get_delta_message_with_no_reasoning_bounds(
    self,
    current_text: str,
    delta_text: str,
) -> DeltaMessage:
    """Parse the delta message when the current text has not yet completed
    its start of reasoning sequence.

    Args:
        current_text (str): The full previous + delta text.
        delta_text (str): Text to consider and parse content from.

    Returns:
        DeltaMessage: Message containing the parsed content.
    """
    prev_longest_length = len(current_text) - len(delta_text)
    is_substr = self._is_reasoning_start_substr(current_text)
    was_substr = self._is_reasoning_start_substr(
        current_text[:prev_longest_length])

    # Check if we just generated something NOT in the special token seq;
    # if so, add everything that we previously skipped with this delta
    # message and append everything to content in the future.
    if was_substr and not is_substr:
        return DeltaMessage(
            reasoning_content=None,
            content=current_text,
        )
    if is_substr:
        # Might still be in the special token sequence; return nothing
        return DeltaMessage(reasoning_content=None, content=None)
    # Otherwise the sequence has already been broken and we already
    # corrected; just return the delta text as normal content.
    return DeltaMessage(reasoning_content=None, content=delta_text)

_get_delta_message_with_no_response_bounds

_get_delta_message_with_no_response_bounds(
    current_text: str,
    reasoning_content: str,
    delta_text: str,
) -> DeltaMessage

Parse the delta message when the current text has reasoning content but no (response) content yet. NOTE that we may have overlapping tokens with the start of reasoning / start of response sequences on either side of the delta text.

Parameters:

Name Type Description Default
current_text str

The full previous + delta text.

required
reasoning_content str

reasoning content from current_text.

required
delta_text str

Text to consider and parse content from.

required

Returns:

Name Type Description
DeltaMessage DeltaMessage

Message containing the parsed content.

Source code in vllm/reasoning/granite_reasoning_parser.py
def _get_delta_message_with_no_response_bounds(
    self,
    current_text: str,
    reasoning_content: str,
    delta_text: str,
) -> DeltaMessage:
    """Parse the delta message when the current text has both reasoning
    content with no (response) content. NOTE that we may have overlapping
    tokens with the start of reasoning / start of response sequences on
    either side of the delta text.

    Args:
        current_text (str): The full previous + delta text.
        reasoning_content (str): reasoning content from current_text.
        delta_text (str): Text to consider and parse content from.

    Returns:
        DeltaMessage: Message containing the parsed content.
    """
    # If we have no reasoning content or explicitly end with the start of
    # response sequence, we are in transition to the response; need to be
    # careful here, since the final token (:) will match the reasoning
    # content and fully parse it out; we should not pass the : back.
    ends_with_start_response_seq = any(
        current_text.endswith(response_start)
        for response_start in self.valid_response_starts)
    if reasoning_content is None or ends_with_start_response_seq:
        return DeltaMessage(reasoning_content=None, content=None)

    # Consider previous / current text only within context of the reasoning
    previous_text = reasoning_content[:-len(delta_text)]
    current_text = reasoning_content

    # We need to be careful about adding unfinished response sequences;
    # Find the place at which we MIGHT be starting a response sequence
    prev_idx = previous_text.rfind(self.seq_boundary_start)
    delta_idx = delta_text.rfind(self.seq_boundary_start)

    # Check the state of potential start of response substring matches.
    prev_was_substr = self._is_response_start_substr(
        previous_text[prev_idx:]) if prev_idx >= 0 else False
    delta_continues_substr = self._is_response_start_substr(
        current_text[prev_idx:]) if prev_idx >= 0 else False
    delta_new_substr = self._is_response_start_substr(
        delta_text[delta_idx:]) if delta_idx >= 0 else False

    # Delta only contains potential continued response sequence text.
    if delta_continues_substr:
        return DeltaMessage(reasoning_content=None, content=None)

    if not prev_was_substr:
        # Delta may be starting a new response seq but has other text too.
        if delta_new_substr:
            return DeltaMessage(reasoning_content=delta_text[:delta_idx],
                                content=None)
        # Normal case for most reasoning text (no potential special seqs).
        return DeltaMessage(reasoning_content=delta_text, content=None)
    # The substring that previously seemed to be a potential response
    # seq wasn't one; we need to add the content to the delta message,
    # and also slice off the potential response sequence
    elif delta_new_substr:
        reasoning_content = previous_text[
            prev_idx:] + delta_text[:delta_idx]
        return DeltaMessage(reasoning_content=reasoning_content,
                            content=None)
    # No new substring yet, and we broke our old one; take the whole delta
    return DeltaMessage(
        reasoning_content=previous_text[prev_idx:] + delta_text,
        content=None,
    )

_is_reasoning_start_substr

_is_reasoning_start_substr(text: str) -> bool

Check if a text matches one of the possible start reasoning seqs.

Parameters:

Name Type Description Default
text str

Text to check for leading substr.

required

Returns:

Name Type Description
bool bool

True if any of the possible reasoning start seqs match.

Source code in vllm/reasoning/granite_reasoning_parser.py
def _is_reasoning_start_substr(self, text: str) -> bool:
    """Check if a text matches one of the possible start reasoning seqs.

    Args:
        text (str): Text to check for leading substr.

    Returns:
        bool: True if any of the possible reasoning start seqs match.
    """
    return any(
        think_start.startswith(text)
        for think_start in self.valid_think_starts)

_is_response_start_substr

_is_response_start_substr(text: str) -> bool

Check if a text matches one of the possible start response seqs.

Parameters:

Name Type Description Default
text str

Text to check for leading substr.

required

Returns:

Name Type Description
bool bool

True if any of the possible response start seqs match.

Source code in vllm/reasoning/granite_reasoning_parser.py
def _is_response_start_substr(self, text: str) -> bool:
    """Check if a text matches one of the possible start response seqs.

    Args:
        text (str): Text to check for leading substr.

    Returns:
        bool: True if any of the possible response start seqs match.
    """
    return any(
        response_start.startswith(text)
        for response_start in self.valid_response_starts)

extract_reasoning_content

extract_reasoning_content(
    model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]

Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content.

Parameters:

Name Type Description Default
model_output str

Output of the model to be parsed.

required
request ChatCompletionRequest

Request being processed.

required

Returns:

Type Description
tuple[Optional[str], Optional[str]]

Tuple pair containing the reasoning content and non-reasoning content.

Source code in vllm/reasoning/granite_reasoning_parser.py
def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
    """Extract the reasoning content & content sections, respectively.
    If the sequence doesn't match what we expect, i.e., the model generates
    something else, all content is considered non-reasoning content.

    Args:
        model_output (str): Output of the model to be parsed.
        request (ChatCompletionRequest): Request being processed.

    Returns:
        tuple[Optional[str], Optional[str]]: Tuple pair containing the
        reasoning content and non-reasoning content.
    """
    re_match = self.reasoning_regex.findall(model_output)
    if not re_match:
        return None, model_output
    reasoning_content, response_content = re_match[0]
    if not response_content:
        return reasoning_content, None
    return reasoning_content, response_content

extract_reasoning_content_streaming

extract_reasoning_content_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]

Extract the reasoning content / content emitted by granite models; If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content.

NOTE: Granite models do not use a special token to start their reasoning and response sections; instead they have token sequences, e.g.,

    Here is my thought process: Foo Here is my response: Bar

This increases the complexity of correctly handling streams, since we need to watch for specific sequences and correctly parse them without dropping content that is potentially overlapping & spanning multiple delta messages.

Parameters:

Name Type Description Default
previous_text str

Previous text outside of this delta message.

required
current_text str

Previous text + delta text.

required
delta_text str

Text to consider and parse content from.

required
previous_token_ids Sequence[int]

Token IDs of previous_text.

required
current_token_ids Sequence[int]

Token IDs of current_text.

required
delta_token_ids Sequence[int]

Token IDs of delta_text.

required

Returns:

Type Description
Union[DeltaMessage, None]

DeltaMessage with either reasoning content or content, or None.

Source code in vllm/reasoning/granite_reasoning_parser.py
def extract_reasoning_content_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """Extract the reasoning content / content emitted by granite models;
    If the sequence doesn't match what we expect, i.e., the model generates
    something else, all content is considered non-reasoning content.

    NOTE: Granite models do not use a special token to start their reasoning
    and response sections; instead they have token sequences, e.g.,

            Here is my thought process: Foo Here is my response: Bar

    This increases the complexity of correctly handling streams, since we
    need to watch for specific sequences and correctly parse them without
    dropping content that is potentially overlapping & spanning multiple
    delta messages.

    Args:
        previous_text (str): Previous text outside of this delta message.
        current_text (str): Previous text + delta text.
        delta_text (str): Text to consider and parse content from.
        previous_token_ids (Sequence[int]): Token IDs of previous_text.
        current_token_ids (Sequence[int]): Token IDs of current_text.
        delta_token_ids (Sequence[int]): Token IDs of delta_text.

    Returns:
        Union[DeltaMessage, None]
            DeltaMessage with either reasoning content or content, or None.
    """
    reasoning_content, resp_seq_len, content = self._get_content_sections(
        current_text)
    # Either we haven't finished the start of the reasoning sequence,
    # or the model is generating something unexpected.
    if not reasoning_content:
        delta_message = self._get_delta_message_with_no_reasoning_bounds(
            current_text, delta_text)
    # We have a start of reasoning message, but have not yet finished
    # the start of response sequence.
    elif not content:
        delta_message = self._get_delta_message_with_no_response_bounds(
            current_text, reasoning_content, delta_text)
    # We've finished both the start of reasoning and start of response seq.
    else:
        # This should never happen since we matched on the response
        assert resp_seq_len is not None
        delta_message = self._get_delta_message_with_both_bounds(
            delta_text, reasoning_content, content, current_text,
            resp_seq_len)
    if not delta_message.content and not delta_message.reasoning_content:
        return None
    return delta_message

Qwen3ReasoningParser

Bases: ReasoningParser

Reasoning parser for the Qwen3 model.

The Qwen3 model uses <think>...</think> tokens to denote reasoning text within its output. The model provides a strict switch to disable reasoning output via the 'enable_thinking=False' parameter. This parser extracts the reasoning content enclosed by <think> and </think> tokens from the model's output.
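
A hedged usage sketch: assuming a vLLM OpenAI-compatible server launched with the qwen3 reasoning parser (for example vllm serve Qwen/Qwen3-8B --reasoning-parser qwen3; check the serving docs for the exact flags on your version), a client reads the separated fields and can pass enable_thinking=False through chat_template_kwargs to suppress the thinking section:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# With thinking enabled (the default), the parser splits the output into
# reasoning_content and content.
resp = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(resp.choices[0].message.reasoning_content)  # text between <think> and </think>
print(resp.choices[0].message.content)            # final answer

# Disable thinking for this request via the chat template switch.
resp = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
print(resp.choices[0].message.content)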

Source code in vllm/reasoning/qwen3_reasoning_parser.py
@ReasoningParserManager.register_module("qwen3")
class Qwen3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Qwen3 model.

    The Qwen3 model uses <think>...</think> tokens to denote reasoning text
    within its output. The model provides a strict switch to disable reasoning
    output via the 'enable_thinking=False' parameter. This parser extracts the
    reasoning content enclosed by <think> and </think> tokens from the model's
    output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "Qwen3 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
        if self.think_end_token_id not in input_ids[:-1]:
            return []
        else:
            return input_ids[input_ids.index(self.think_end_token_id) + 1:]

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                # reasoning content continues
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # thinking is disabled, just content
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from the model output.

        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content
        """

        # Check if the model output contains the <think> and </think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output
        # Check if the <think> is present in the model output, remove it
        # if it is present.
        model_output_parts = model_output.partition(self.think_start_token)
        model_output = model_output_parts[2] if model_output_parts[
            1] else model_output_parts[0]
        # Check if the model output contains the </think> tokens.
        # If the end token is not found, return the model output as is.
        if self.think_end_token not in model_output:
            return None, model_output

        # Extract reasoning content from the model output.
        reasoning_content, _, content = model_output.partition(
            self.think_end_token)

        final_content = content or None
        return reasoning_content, final_content
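
Example (not part of the vLLM source): a minimal non-streaming usage sketch. It assumes a Hugging Face checkpoint whose tokenizer defines the <think> and </think> special tokens (the Qwen3 checkpoint name below is illustrative), and it passes request=None because this parser does not inspect the request.

from transformers import AutoTokenizer
from vllm.reasoning import Qwen3ReasoningParser

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # assumed checkpoint
parser = Qwen3ReasoningParser(tokenizer)

reasoning, content = parser.extract_reasoning_content(
    "<think>Add 2 and 2.</think>The answer is 4.", request=None)
print(reasoning)  # 'Add 2 and 2.'
print(content)    # 'The answer is 4.'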

think_end_token instance-attribute

think_end_token = '</think>'

think_end_token_id instance-attribute

think_end_token_id = get(think_end_token)

think_start_token instance-attribute

think_start_token = '<think>'

think_start_token_id instance-attribute

think_start_token_id = get(think_start_token)

__init__

__init__(tokenizer: PreTrainedTokenizerBase)
Source code in vllm/reasoning/qwen3_reasoning_parser.py
def __init__(self, tokenizer: PreTrainedTokenizerBase):
    super().__init__(tokenizer)
    self.think_start_token = "<think>"
    self.think_end_token = "</think>"

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ReasoningParser "
            "constructor during construction.")

    self.think_start_token_id = self.vocab.get(self.think_start_token)
    self.think_end_token_id = self.vocab.get(self.think_end_token)
    if (self.think_start_token_id is None
            or self.think_end_token_id is None):
        raise RuntimeError(
            "Qwen3 reasoning parser could not locate think start/end "
            "tokens in the tokenizer!")

extract_content_ids

extract_content_ids(input_ids: list[int]) -> list[int]

Extract the content after the end tokens

Source code in vllm/reasoning/qwen3_reasoning_parser.py
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Extract the content after the end tokens
    """
    if self.think_end_token_id not in input_ids[:-1]:
        return []
    else:
        return input_ids[input_ids.index(self.think_end_token_id) + 1:]
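
Example (illustrative, not from the vLLM source): token-level behaviour of extract_content_ids and is_reasoning_end, reusing the parser and tokenizer built in the earlier sketch; 111 and 222 stand in for arbitrary content token ids.

end_id = tokenizer.get_vocab()["</think>"]

print(parser.is_reasoning_end([111, end_id, 222]))     # True
print(parser.extract_content_ids([111, end_id, 222]))  # [222]
print(parser.extract_content_ids([111, 222, end_id]))  # [] (end token must not be last)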

extract_reasoning_content

extract_reasoning_content(
    model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]

Extract reasoning content from the model output.

For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content

Returns:

tuple[Optional[str], Optional[str]]: reasoning content and content

Source code in vllm/reasoning/qwen3_reasoning_parser.py
def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
    """
    Extract reasoning content from the model output.

    For text <think>abc</think>xyz:
    - 'abc' goes to reasoning_content
    - 'xyz' goes to content

    Returns:
        tuple[Optional[str], Optional[str]]: reasoning content and content
    """

    # Check if the model output contains the <think> and </think> tokens.
    if (self.think_start_token not in model_output
            or self.think_end_token not in model_output):
        return None, model_output
    # Check if the <think> is present in the model output, remove it
    # if it is present.
    model_output_parts = model_output.partition(self.think_start_token)
    model_output = model_output_parts[2] if model_output_parts[
        1] else model_output_parts[0]
    # Check if the model output contains the </think> tokens.
    # If the end token is not found, return the model output as is.
    if self.think_end_token not in model_output:
        return None, model_output

    # Extract reasoning content from the model output.
    reasoning_content, _, content = model_output.partition(
        self.think_end_token)

    final_content = content or None
    return reasoning_content, final_content
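
Example (illustrative): when the output contains no think block at all, the method falls back to returning everything as plain content, reusing the parser from the earlier sketch.

reasoning, content = parser.extract_reasoning_content(
    "Just the answer, no reasoning.", request=None)
print(reasoning)  # None
print(content)    # 'Just the answer, no reasoning.'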

extract_reasoning_content_streaming

extract_reasoning_content_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]

Extract reasoning content from a delta message. Handles streaming output where previous + delta = current. Uses token IDs for faster processing.

For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content

Source code in vllm/reasoning/qwen3_reasoning_parser.py
def extract_reasoning_content_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """
    Extract reasoning content from a delta message.
    Handles streaming output where previous + delta = current.
    Uses token IDs for faster processing.
    For text <think>abc</think>xyz:
    - 'abc' goes to reasoning_content
    - 'xyz' goes to content
    """
    # Skip single special tokens
    if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
            self.think_start_token_id, self.think_end_token_id
    ]):
        return None

    if self.think_start_token_id in previous_token_ids:
        if self.think_end_token_id in delta_token_ids:
            # <think> in previous, </think> in delta,
            # extract reasoning content
            end_index = delta_text.find(self.think_end_token)
            reasoning_content = delta_text[:end_index]
            content = delta_text[end_index + len(self.think_end_token):]
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=content if content else None)
        elif self.think_end_token_id in previous_token_ids:
        # <think> in previous, </think> in previous:
        # reasoning has ended, content continues
            return DeltaMessage(content=delta_text)
        else:
            # <think> in previous, no </think> in previous or delta,
            # reasoning content continues
            return DeltaMessage(reasoning_content=delta_text)
    elif self.think_start_token_id in delta_token_ids:
        if self.think_end_token_id in delta_token_ids:
            # <think> in delta, </think> in delta, extract reasoning content
            start_index = delta_text.find(self.think_start_token)
            end_index = delta_text.find(self.think_end_token)
            reasoning_content = delta_text[start_index +
                                           len(self.think_start_token
                                               ):end_index]
            content = delta_text[end_index + len(self.think_end_token):]
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=content if content else None)
        else:
            # <think> in delta, no </think> in delta,
            # reasoning content continues
            return DeltaMessage(reasoning_content=delta_text)
    else:
        # thinking is disabled, just content
        return DeltaMessage(content=delta_text)
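
Example (illustrative, not from the vLLM source): two streaming deltas fed through the parser built in the earlier sketch. The token ids 11 and 22 are placeholders for ordinary content tokens; the method only checks membership of the think start/end ids and otherwise works on the delta text.

think_id = tokenizer.get_vocab()["<think>"]
end_id = tokenizer.get_vocab()["</think>"]

# Delta arrives while still inside the think block: routed to reasoning_content.
d1 = parser.extract_reasoning_content_streaming(
    previous_text="<think>", current_text="<think>step 1", delta_text="step 1",
    previous_token_ids=[think_id], current_token_ids=[think_id, 11],
    delta_token_ids=[11])
print(d1.reasoning_content)  # 'step 1'

# Delta that closes the block: split between reasoning_content and content.
d2 = parser.extract_reasoning_content_streaming(
    previous_text="<think>step 1", current_text="<think>step 1</think>done",
    delta_text="</think>done",
    previous_token_ids=[think_id, 11], current_token_ids=[think_id, 11, end_id, 22],
    delta_token_ids=[end_id, 22])
print(d2.reasoning_content, d2.content)  # '' 'done'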

is_reasoning_end

is_reasoning_end(input_ids: list[int]) -> bool
Source code in vllm/reasoning/qwen3_reasoning_parser.py
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    return self.think_end_token_id in input_ids

ReasoningParser

Abstract reasoning parser class that should not be used directly. Provided methods should be used in derived classes.

It is used to extract reasoning content from the model output.

Source code in vllm/reasoning/abs_reasoning_parsers.py
class ReasoningParser:
    """
    Abstract reasoning parser class that should not be used directly.
    Provided methods should be used in derived classes.

    It is used to extract reasoning content from the model output.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> dict[str, int]:
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    @abstractmethod
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """
        Check if the reasoning content ends in the input_ids.

        It is used in structured engines like `xgrammar` to check if the
        reasoning content ends in the model output.

        Parameters:
        input_ids: list[int]
            The input_ids of the model output.

        Returns:
        bool
            True if the reasoning content ends in the input_ids.
        """

    @abstractmethod
    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract content token ids from the input_ids.
        Parameters:
        input_ids: list[int]
            The input_ids of the model output.
        Returns:
        list[int]
            The extracted content from the input_ids.
        """

    @abstractmethod
    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Parameters:
        model_output: str
            The model-generated string to extract reasoning content from.

        request: ChatCompletionRequest
            The request object that was used to generate the model_output.

        Returns:
        tuple[Optional[str], Optional[str]]
            A tuple containing the reasoning content and the content.
        """

    @abstractmethod
    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming. Has to be an instance method because it requires state -
        the current tokens/diffs, but also the information about what has
        previously been parsed and extracted (see constructor)
        """

model_tokenizer instance-attribute

model_tokenizer = tokenizer

vocab cached property

vocab: dict[str, int]

__init__

__init__(tokenizer: AnyTokenizer)
Source code in vllm/reasoning/abs_reasoning_parsers.py
def __init__(self, tokenizer: AnyTokenizer):
    self.model_tokenizer = tokenizer

extract_content_ids abstractmethod

extract_content_ids(input_ids: list[int]) -> list[int]

Extract content token ids from the input_ids.

Parameters:

input_ids (list[int]): The input_ids of the model output.

Returns:

list[int]: The extracted content from the input_ids.

Source code in vllm/reasoning/abs_reasoning_parsers.py
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Extract content token ids from the input_ids.
    Parameters:
    input_ids: list[int]
        The input_ids of the model output.
    Returns:
    list[int]
        The extracted content from the input_ids.
    """

extract_reasoning_content abstractmethod

extract_reasoning_content(
    model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]

Extract reasoning content from a complete model-generated string.

Used for non-streaming responses where we have the entire model response available before sending to the client.

Parameters:

model_output (str): The model-generated string to extract reasoning content from.

request (ChatCompletionRequest): The request object that was used to generate the model_output.

Returns:

tuple[Optional[str], Optional[str]]: A tuple containing the reasoning content and the content.

Source code in vllm/reasoning/abs_reasoning_parsers.py
@abstractmethod
def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
    """
    Extract reasoning content from a complete model-generated string.

    Used for non-streaming responses where we have the entire model response
    available before sending to the client.

    Parameters:
    model_output: str
        The model-generated string to extract reasoning content from.

    request: ChatCompletionRequest
        The request object that was used to generate the model_output.

    Returns:
    tuple[Optional[str], Optional[str]]
        A tuple containing the reasoning content and the content.
    """

extract_reasoning_content_streaming abstractmethod

extract_reasoning_content_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]

Instance method that should be implemented for extracting reasoning from an incomplete response; for use when handling reasoning calls and streaming. Has to be an instance method because it requires state - the current tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor)

Source code in vllm/reasoning/abs_reasoning_parsers.py
@abstractmethod
def extract_reasoning_content_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """
    Instance method that should be implemented for extracting reasoning
    from an incomplete response; for use when handling reasoning calls and
    streaming. Has to be an instance method because it requires state -
    the current tokens/diffs, but also the information about what has
    previously been parsed and extracted (see constructor)
    """

is_reasoning_end abstractmethod

is_reasoning_end(input_ids: Sequence[int]) -> bool

Check if the reasoning content ends in the input_ids.

It is used in structured engines like xgrammar to check if the reasoning content ends in the model output.

Parameters:

input_ids (list[int]): The input_ids of the model output.

Returns:

bool: True if the reasoning content ends in the input_ids.

Source code in vllm/reasoning/abs_reasoning_parsers.py
@abstractmethod
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
    """
    Check if the reasoning content ends in the input_ids.

    It is used in structured engines like `xgrammar` to check if the
    reasoning content ends in the model output.

    Parameters:
    input_ids: list[int]
        The input_ids of the model output.

    Returns:
    bool
        True if the reasoning content ends in the input_ids.
    """

ReasoningParserManager

Source code in vllm/reasoning/abs_reasoning_parsers.py
class ReasoningParserManager:
    reasoning_parsers: dict[str, type] = {}

    @classmethod
    def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
        """
        Get reasoning parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        if name in cls.reasoning_parsers:
            return cls.reasoning_parsers[name]

        raise KeyError(
            f"reasoning helper: '{name}' not found in reasoning_parsers")

    @classmethod
    def _register_module(
        cls,
        module: type,
        module_name: Optional[Union[str, list[str]]] = None,
        force: bool = True,
    ) -> None:
        if not issubclass(module, ReasoningParser):
            raise TypeError("module must be subclass of ReasoningParser, "
                            f"but got {type(module)}")
        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in cls.reasoning_parsers:
                existed_module = cls.reasoning_parsers[name]
                raise KeyError(f"{name} is already registered "
                               f"at {existed_module.__module__}")
            cls.reasoning_parsers[name] = module

    @classmethod
    def register_module(
        cls,
        name: Optional[Union[str, list[str]]] = None,
        force: bool = True,
        module: Union[type, None] = None,
    ) -> Union[type, Callable]:
        """
        Register module with the given name or name list. It can be used as
        a decorator (with module as None) or as a normal function (with
        module not None).
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)
                or is_list_of(name, str)):
            raise TypeError(
                "name must be None, an instance of str, or a sequence of str, "
                f"but got {type(name)}")

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        return _register

    @classmethod
    def import_reasoning_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined reasoning parser by the path
        of the reasoning parser definition file.
        """
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]

        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception("Failed to load module '%s' from %s.",
                             module_name, plugin_path)
            return
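
Example (illustrative; the registry key "reason_tags" and the ReasonTagParser class come from the subclass sketch shown earlier): the decorator registers the class, and get_reasoning_parser retrieves it by name.

from vllm.reasoning import ReasoningParserManager

parser_cls = ReasoningParserManager.get_reasoning_parser("reason_tags")
print(parser_cls.__name__)  # 'ReasonTagParser'

# Unknown names raise KeyError.
try:
    ReasoningParserManager.get_reasoning_parser("does-not-exist")
except KeyError as exc:
    print(exc)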

reasoning_parsers class-attribute instance-attribute

reasoning_parsers: dict[str, type] = {}

_register_module classmethod

_register_module(
    module: type,
    module_name: Optional[Union[str, list[str]]] = None,
    force: bool = True,
) -> None
Source code in vllm/reasoning/abs_reasoning_parsers.py
@classmethod
def _register_module(
    cls,
    module: type,
    module_name: Optional[Union[str, list[str]]] = None,
    force: bool = True,
) -> None:
    if not issubclass(module, ReasoningParser):
        raise TypeError("module must be subclass of ReasoningParser, "
                        f"but got {type(module)}")
    if module_name is None:
        module_name = module.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in cls.reasoning_parsers:
            existed_module = cls.reasoning_parsers[name]
            raise KeyError(f"{name} is already registered "
                           f"at {existed_module.__module__}")
        cls.reasoning_parsers[name] = module

get_reasoning_parser classmethod

get_reasoning_parser(
    name: str | None,
) -> type[ReasoningParser]

Get reasoning parser by name which is registered by register_module.

Raise a KeyError exception if the name is not registered.

Source code in vllm/reasoning/abs_reasoning_parsers.py
@classmethod
def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
    """
    Get reasoning parser by name which is registered by `register_module`.

    Raise a KeyError exception if the name is not registered.
    """
    if name in cls.reasoning_parsers:
        return cls.reasoning_parsers[name]

    raise KeyError(
        f"reasoning helper: '{name}' not found in reasoning_parsers")

import_reasoning_parser classmethod

import_reasoning_parser(plugin_path: str) -> None

Import a user-defined reasoning parser by the path of the reasoning parser definition file.

Source code in vllm/reasoning/abs_reasoning_parsers.py
@classmethod
def import_reasoning_parser(cls, plugin_path: str) -> None:
    """
    Import a user-defined reasoning parser by the path
    of the reasoning parser define file.
    """
    module_name = os.path.splitext(os.path.basename(plugin_path))[0]

    try:
        import_from_path(module_name, plugin_path)
    except Exception:
        logger.exception("Failed to load module '%s' from %s.",
                         module_name, plugin_path)
        return
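
Example (illustrative; the file path is a placeholder): loading a user-defined parser module from disk. The imported file is expected to register its parser class itself, for instance via the @ReasoningParserManager.register_module(...) decorator, so that it becomes available by name afterwards.

from vllm.reasoning import ReasoningParserManager

ReasoningParserManager.import_reasoning_parser("/path/to/my_reasoning_parser.py")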

register_module classmethod

register_module(
    name: Optional[Union[str, list[str]]] = None,
    force: bool = True,
    module: Union[type, None] = None,
) -> Union[type, Callable]

Register module with the given name or name list. It can be used as a decorator (with module as None) or as a normal function (with module not None).

Source code in vllm/reasoning/abs_reasoning_parsers.py
@classmethod
def register_module(
    cls,
    name: Optional[Union[str, list[str]]] = None,
    force: bool = True,
    module: Union[type, None] = None,
) -> Union[type, Callable]:
    """
    Register module with the given name or name list. It can be used as
    a decorator (with module as None) or as a normal function (with
    module not None).
    """
    if not isinstance(force, bool):
        raise TypeError(f"force must be a boolean, but got {type(force)}")

    # raise the error ahead of time
    if not (name is None or isinstance(name, str)
            or is_list_of(name, str)):
        raise TypeError(
            "name must be None, an instance of str, or a sequence of str, "
            f"but got {type(name)}")

    # use it as a normal method: x.register_module(module=SomeClass)
    if module is not None:
        cls._register_module(module=module, module_name=name, force=force)
        return module

    # use it as a decorator: @x.register_module()
    def _register(module):
        cls._register_module(module=module, module_name=name, force=force)
        return module

    return _register
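
Example (illustrative; ToyParser stands for any ReasoningParser subclass): the same method used as a plain call instead of a decorator. With name=None the class is registered under its own __name__; with force=False a duplicate name raises KeyError instead of being overwritten.

ReasoningParserManager.register_module(module=ToyParser)          # key 'ToyParser'
ReasoningParserManager.register_module("toy", module=ToyParser)   # key 'toy'
ReasoningParserManager.register_module("toy", force=False, module=ToyParser)  # KeyError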