Skip to content

vllm.reasoning.minimax_m3_reasoning_parser

Classes:

MiniMaxM3ReasoningParser

Bases: BaseThinkingReasoningParser

Reasoning parser for MiniMax M3 explicit thinking blocks.

MiniMax M3 emits reasoning as:

<mm:think>reasoning text</mm:think>assistant content

The M3 tokenizer exposes both markers as complete vocabulary entries, but generated marker text may be tokenized into smaller pieces. The streaming parser therefore extracts reasoning using the text markers instead of relying on the single-token vocabulary IDs. The chat template may also prefill the start marker when `thinking_mode="enabled"`, so generated text can begin directly inside a reasoning block without emitting `<mm:think>` again.

Source code in vllm/reasoning/minimax_m3_reasoning_parser.py
class MiniMaxM3ReasoningParser(BaseThinkingReasoningParser):
    """Reasoning parser for MiniMax M3 explicit thinking blocks.

    MiniMax M3 emits reasoning as:

        <mm:think>reasoning text</mm:think>assistant content

    The M3 tokenizer exposes both markers as complete vocabulary entries, but
    generated marker text may be tokenized into smaller pieces. The streaming
    parser therefore uses text markers for extraction instead of relying on the
    single vocabulary IDs. The chat template may also prefill the start marker
    when ``thinking_mode="enabled"``, so generated text can begin directly
    inside a reasoning block without emitting ``<mm:think>`` again.

    The token-id based helpers (``is_reasoning_end``,
    ``is_reasoning_end_streaming``, ``extract_content_ids`` and
    ``count_reasoning_tokens``) match the markers against the token-id
    sequences obtained by encoding the marker text once in ``__init__``.

    Instances carry streaming state. ``extract_reasoning_streaming`` resets it
    whenever it is called with an empty ``previous_text`` (presumably the first
    chunk of a new stream).
    """

    @property
    def start_token(self) -> str:
        """Text marker that opens a reasoning block."""
        return "<mm:think>"

    @property
    def end_token(self) -> str:
        """Text marker that closes a reasoning block."""
        return "</mm:think>"

    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Token-id sequences of the markers, used by the token-id helpers.
        self._start_token_ids = self._encode_marker(self.start_token)
        self._end_token_ids = self._encode_marker(self.end_token)
        # thinking_mode="enabled" means the template prefilled <mm:think>, so
        # generated text begins inside a reasoning block.
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        self._initial_in_reasoning = chat_kwargs.get("thinking_mode") == "enabled"
        # Streaming state, maintained by extract_reasoning_streaming().
        self._reasoning_ended_streaming = False
        self._reasoning_active_streaming = self._initial_in_reasoning
        # True while the whole text so far could still be a partial marker.
        self._pending_marker_streaming = False
        # One-shot cache handing the content ids computed for the last
        # streaming delta over to extract_content_ids().
        self._last_streaming_delta_token_ids: tuple[int, ...] | None = None
        self._last_streaming_content_token_ids: list[int] | None = None

    def _encode_text(self, text: str) -> list[int]:
        """Encode ``text`` to token ids without adding special tokens.

        If the keyword call raises ``TypeError`` (e.g. a tokenizer whose
        ``encode`` does not accept ``add_special_tokens``), retry with a plain
        ``encode(text)`` call.
        """
        try:
            return list(self.model_tokenizer.encode(text, add_special_tokens=False))
        except TypeError:
            return list(self.model_tokenizer.encode(text))

    def _encode_marker(self, marker: str) -> tuple[int, ...]:
        """Return the token-id sequence of ``marker`` as a tuple."""
        return tuple(self._encode_text(marker))

    def _decode_text(self, token_ids: Sequence[int]) -> str:
        """Decode ``token_ids`` to text, keeping special tokens.

        If the keyword call raises ``TypeError``, retry with a plain
        ``decode`` call.
        """
        try:
            return self.model_tokenizer.decode(
                list(token_ids), skip_special_tokens=False
            )
        except TypeError:
            return self.model_tokenizer.decode(list(token_ids))

    def _content_suffix_token_ids(
        self,
        delta_text: str,
        delta_token_ids: Sequence[int],
        content: str | None,
    ) -> list[int]:
        """Map the visible ``content`` of a streaming delta back to token ids.

        Resolution order:
        - ``content is None`` gives ``[]``.
        - ``content == delta_text`` gives all of ``delta_token_ids``.
        - If ``delta_text`` ends with ``content``, return the tokens that
          follow the shortest token prefix decoding exactly to the non-content
          prefix of ``delta_text``.
        - Otherwise, re-encode ``content`` from text.
        """
        if content is None:
            return []
        if content == delta_text:
            return list(delta_token_ids)
        if delta_text.endswith(content):
            prefix_text = delta_text[: len(delta_text) - len(content)]
            for index in range(len(delta_token_ids) + 1):
                if self._decode_text(delta_token_ids[:index]) == prefix_text:
                    return list(delta_token_ids[index:])
        return self._encode_text(content)

    @staticmethod
    def _contains_token_sequence(
        token_ids: Sequence[int], marker_ids: Sequence[int]
    ) -> bool:
        """Return True if ``marker_ids`` occurs contiguously in ``token_ids``.

        An empty marker never matches.
        """
        if not marker_ids or len(marker_ids) > len(token_ids):
            return False
        marker = tuple(marker_ids)  # hoisted: built once, not per position
        marker_len = len(marker)
        return any(
            tuple(token_ids[i : i + marker_len]) == marker
            for i in range(len(token_ids) - marker_len + 1)
        )

    @staticmethod
    def _rfind_token_sequence(
        token_ids: Sequence[int], marker_ids: Sequence[int]
    ) -> int:
        """Return the start index of the last ``marker_ids`` in ``token_ids``.

        Returns -1 when the marker is absent or empty.
        """
        if not marker_ids or len(marker_ids) > len(token_ids):
            return -1
        marker = tuple(marker_ids)
        marker_len = len(marker)
        for i in range(len(token_ids) - marker_len, -1, -1):
            if tuple(token_ids[i : i + marker_len]) == marker:
                return i
        return -1

    @staticmethod
    def _ends_with_token_sequence_prefix(
        token_ids: Sequence[int], marker_ids: Sequence[int]
    ) -> bool:
        """Return True if ``token_ids`` ends with a proper prefix of the marker.

        A proper prefix is non-empty and shorter than the whole marker. An
        empty marker never matches.
        """
        if not marker_ids:
            return False
        marker = tuple(marker_ids)
        max_len = min(len(token_ids), len(marker) - 1)
        for prefix_len in range(max_len, 0, -1):
            if tuple(token_ids[-prefix_len:]) == marker[:prefix_len]:
                return True
        return False

    @staticmethod
    def _strip_partial_marker_suffix(text: str, marker: str) -> str:
        """Drop the longest suffix of ``text`` that is a proper prefix of
        ``marker``.

        Streaming output therefore never shows a half-emitted marker.
        """
        max_len = min(len(text), len(marker) - 1)
        for suffix_len in range(max_len, 0, -1):
            if marker.startswith(text[-suffix_len:]):
                return text[:-suffix_len]
        return text

    @staticmethod
    def _visible_delta(previous: str | None, current: str | None) -> str | None:
        """Return the part of ``current`` not already present in ``previous``.

        - If ``current`` is empty or None, return None.
        - If ``current`` does not extend ``previous``, return all of
          ``current``.
        - An empty extension is reported as None.
        """
        if not current:
            return None
        if not previous:
            return current
        if current.startswith(previous):
            delta = current[len(previous) :]
            return delta or None
        return current

    def _visible_segments(self, text: str) -> tuple[str | None, str | None]:
        """Split accumulated streamed ``text`` into ``(reasoning, content)``.

        Markers are removed. A trailing partial marker is withheld, and empty
        segments are reported as None. When not starting inside reasoning, a
        leading stray ``</mm:think>`` (or a text that is still only a prefix
        of it) is dropped.
        """
        if not text:
            return None, None

        if not self._initial_in_reasoning:
            if self.end_token.startswith(text) and len(text) < len(self.end_token):
                return None, None
            if text.startswith(self.end_token):
                text = text[len(self.end_token) :]
                if not text:
                    return None, None

        # Template-prefilled start marker: everything up to the closer is
        # reasoning.
        if self._initial_in_reasoning and self.start_token not in text:
            reasoning, end, content = text.partition(self.end_token)
            if end:
                return reasoning or None, content or None
            reasoning = self._strip_partial_marker_suffix(reasoning, self.end_token)
            return reasoning or None, None

        if self.start_token not in text:
            content = self._strip_partial_marker_suffix(text, self.start_token)
            return None, content or None

        content_before, _, after_start = text.partition(self.start_token)
        reasoning, end, content_after = after_start.partition(self.end_token)
        if end:
            return reasoning or None, (content_before + content_after) or None

        reasoning = self._strip_partial_marker_suffix(reasoning, self.end_token)
        return reasoning or None, content_before or None

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Split a complete ``model_output`` into ``(reasoning, content)``.

        ``request`` is not consulted. Text before ``<mm:think>`` and after
        ``</mm:think>`` is joined as content.
        """
        # MiniMax M3 can start a response with a stray closer. Drop that first
        # token only; later unmatched closers stay visible as content.
        if not self._initial_in_reasoning and model_output.startswith(self.end_token):
            content = model_output[len(self.end_token) :]
            return None, content or None

        if self._initial_in_reasoning and self.start_token not in model_output:
            reasoning, end, content = model_output.partition(self.end_token)
            if not end:
                # Output stopped before the closer: it is all reasoning.
                return model_output, None
            return reasoning, content or None

        if self.start_token not in model_output:
            return None, model_output

        content_before, _, after_start = model_output.partition(self.start_token)
        reasoning, end, content_after = after_start.partition(self.end_token)
        if not end:
            return reasoning, content_before or None

        return reasoning, (content_before + content_after) or None

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Return whether reasoning has ended during streaming.

        The state from ``extract_reasoning_streaming`` takes precedence:
        - Once reasoning has ended, return True.
        - While reasoning is active or a marker may still be forming, return
          False.

        Otherwise fall back to token-id checks:
        - An end marker in ``delta_ids`` or ``input_ids`` means ended.
        - Starting inside reasoning means not ended.
        - Input ending in a partial marker means not ended.
        - Non-empty input without any start marker means ended.
        """
        if self._reasoning_ended_streaming:
            return True

        if self._reasoning_active_streaming or self._pending_marker_streaming:
            return False

        delta_ids = tuple(delta_ids)
        if self._contains_token_sequence(delta_ids, self._end_token_ids):
            return True
        if self._contains_token_sequence(input_ids, self._end_token_ids):
            return True
        if self._initial_in_reasoning:
            return False
        if self._ends_with_token_sequence_prefix(input_ids, self._start_token_ids):
            return False
        if self._ends_with_token_sequence_prefix(input_ids, self._end_token_ids):
            return False
        if not self._contains_token_sequence(input_ids, self._start_token_ids):
            return bool(input_ids)
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the content token ids contained in ``input_ids``.

        If ``input_ids`` equals the delta ids of the last
        ``extract_reasoning_streaming`` call that ended reasoning, return the
        content ids computed there. That cache is consumed on use. Otherwise:
        - return the ids after the last end marker;
        - or ``[]`` while inside reasoning;
        - or all ids when no start marker is present.
        """
        if (
            self._last_streaming_delta_token_ids == tuple(input_ids)
            and self._last_streaming_content_token_ids is not None
        ):
            content_ids = self._last_streaming_content_token_ids
            self._last_streaming_delta_token_ids = None
            self._last_streaming_content_token_ids = None
            return list(content_ids)

        end_index = self._rfind_token_sequence(input_ids, self._end_token_ids)
        if end_index >= 0:
            return input_ids[end_index + len(self._end_token_ids) :]

        has_start = self._contains_token_sequence(input_ids, self._start_token_ids)
        if self._initial_in_reasoning and not has_start:
            return []

        if not has_start:
            return input_ids
        return []

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Return the reasoning/content newly visible in this streaming step.

        Both ``previous_text`` and ``current_text`` are parsed with
        ``_visible_segments``, and the difference is emitted. This also updates
        the state read by ``is_reasoning_end_streaming`` and
        ``extract_content_ids``. Returns None when nothing new is visible.
        """
        if not delta_text:
            return None

        # Empty previous_text: first chunk of a stream, reset all state.
        if not previous_text:
            self._reasoning_ended_streaming = False
            self._reasoning_active_streaming = self._initial_in_reasoning
            self._pending_marker_streaming = False
            self._last_streaming_delta_token_ids = None
            self._last_streaming_content_token_ids = None
        previous_reasoning, previous_content = self._visible_segments(previous_text)
        current_reasoning, current_content = self._visible_segments(current_text)
        if self.end_token in current_text or current_content is not None:
            # A closing marker or any visible content ends the reasoning phase.
            self._reasoning_ended_streaming = True
            self._reasoning_active_streaming = False
            self._pending_marker_streaming = False
        else:
            self._last_streaming_delta_token_ids = None
            self._last_streaming_content_token_ids = None
            self._reasoning_active_streaming = (
                self._initial_in_reasoning
                or self.start_token in current_text
                or current_reasoning is not None
            )
            # The whole text so far may still grow into a marker.
            self._pending_marker_streaming = not self._reasoning_active_streaming and (
                self.start_token.startswith(current_text)
                or self.end_token.startswith(current_text)
            )
        reasoning = self._visible_delta(previous_reasoning, current_reasoning)
        content = self._visible_delta(previous_content, current_content)
        if self._reasoning_ended_streaming:
            # Remember which of this delta's token ids carry content, for the
            # matching extract_content_ids(delta_token_ids) call.
            self._last_streaming_delta_token_ids = tuple(delta_token_ids)
            self._last_streaming_content_token_ids = self._content_suffix_token_ids(
                delta_text, delta_token_ids, content
            )
        if reasoning is None and content is None:
            return None
        return DeltaMessage(reasoning=reasoning, content=content)

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Count the tokens of ``token_ids`` that lie inside reasoning blocks.

        Marker tokens themselves are not counted.
        - Each start marker increases the nesting depth.
        - Each end marker decreases it, never below zero.
        - Counting starts at depth 1 when the template prefilled the start
          marker.
        - An empty marker sequence is never matched. Matching it would never
          advance the cursor and would loop forever.
        """
        start_ids = self._start_token_ids
        end_ids = self._end_token_ids
        start_len = len(start_ids)
        end_len = len(end_ids)
        total = len(token_ids)
        count = 0
        depth = 1 if self._initial_in_reasoning else 0
        i = 0
        while i < total:
            if start_len and tuple(token_ids[i : i + start_len]) == start_ids:
                depth += 1
                i += start_len
                continue
            if end_len and tuple(token_ids[i : i + end_len]) == end_ids:
                if depth > 0:
                    depth -= 1
                i += end_len
                continue
            if depth > 0:
                count += 1
            i += 1
        return count

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return whether ``input_ids`` ends outside a reasoning block.

        - No end marker: return False.
        - An end marker but no start marker: return True.
        - Both present: return True only if the last end marker comes after
          the last start marker.
        """
        start_index = self._rfind_token_sequence(input_ids, self._start_token_ids)
        end_index = self._rfind_token_sequence(input_ids, self._end_token_ids)
        if end_index < 0:
            return False
        if start_index < 0:
            return True
        return end_index > start_index