class StreamingParserEngine:
    """Turn streamed ``(delta_text, delta_token_ids)`` pairs into
    :class:`SemanticEvent` instances.

    This is the main entry point for streaming parsing.  The engine is
    stateful: create one per request, or call :meth:`reset` before
    reusing it.

    The pipeline is::

        delta_text + delta_token_ids
          → TokenIDScanner    (special token pre-lexing)
          → IncrementalLexer  (text → terminal tokens with prefix buffering)
          → State Machine     (terminal → semantic events)
          → list[SemanticEvent]

    Usage::

        engine = StreamingParserEngine(config, tokenizer)
        for each streaming delta:
            events = engine.feed(delta_text, delta_token_ids)
            # convert events to DeltaMessage
    """

    # States that belong to a tool call.  Used to derive the set of
    # tool-related terminals and to decide what finish() must close.
    _TOOL_STATES = frozenset(
        {
            ParserState.TOOL_PREAMBLE,
            ParserState.TOOL_NAME,
            ParserState.TOOL_ARGS,
            ParserState.TOOL_BETWEEN,
        }
    )

    def __init__(
        self,
        config: ParserEngineConfig,
        tokenizer,
        initial_state: ParserState | None = None,
        vocab: dict[str, int] | None = None,
    ) -> None:
        """Build the scanner, lexer and lookup tables, then reset state.

        Args:
            config: Engine configuration (transitions, lexer shape,
                token-ID terminals, drop tokens, content events, ...).
            tokenizer: Tokenizer used to resolve special-token IDs; may be
                ``None``, in which case no token IDs are resolved.
            initial_state: State to start in; ``config.initial_state`` is
                used when omitted.
            vocab: Optional token-text → token-ID mapping.  Only consulted
                when ``tokenizer`` is not ``None``; if omitted it is read
                from ``tokenizer.get_vocab()``.
        """
        self.config = config

        terminal_by_id: dict[int, str] = {}
        dropped_ids: set[int] = set()

        if tokenizer is not None:
            if vocab is None:
                vocab = tokenizer.get_vocab()

            # Map each configured token-ID terminal to its vocabulary ID;
            # terminals whose text is not in the vocabulary are ignored.
            for name, literal in (config.token_id_terminals or {}).items():
                token_id = vocab.get(literal)
                if token_id is not None:
                    terminal_by_id[token_id] = name

            # Tokens to drop: configured + structural ones found in the
            # vocabulary, plus the tokenizer's EOS/BOS/PAD IDs when set.
            vocab_hits = (
                vocab.get(literal)
                for literal in config.drop_tokens | STRUCTURAL_DROP_TOKENS
            )
            dropped_ids.update(tid for tid in vocab_hits if tid is not None)

            special_ids = (
                getattr(tokenizer, attr, None)
                for attr in ("eos_token_id", "bos_token_id", "pad_token_id")
            )
            dropped_ids.update(tid for tid in special_ids if tid is not None)

        self._resolved_token_ids = terminal_by_id
        self._drop_token_ids = dropped_ids
        self._scanner = TokenIDScanner(terminal_by_id, tokenizer, dropped_ids)
        self._token_id_terminal_names: frozenset[str] = frozenset(
            terminal_by_id.values()
        )
        self._lexer = IncrementalLexer(
            config.lexer_shape, content_terminal=CONTENT_TERMINAL
        )

        # Every terminal that either starts from, or leads into, a tool state.
        tool_states = self._TOOL_STATES
        self._tool_terminals: frozenset[str] = frozenset(
            terminal
            for (state, terminal), tr in config.transitions.items()
            if tr.next_state in tool_states or state in tool_states
        )

        self.skip_tool_parsing = False
        self.reset(initial_state=initial_state)

    def _reset_args_state(self) -> None:
        """Clear the JSON-argument buffer and its scanning flags."""
        self._args_buffer: str = ""
        self._args_safe_end: int = 0
        self._args_brace_depth: int = 0
        self._args_in_string: bool = False
        self._args_escape_next: bool = False

    def reset(self, initial_state: ParserState | None = None) -> None:
        """Reset mutable state for reuse across requests.

        Preserves cached immutable structures (compiled terminals,
        resolved token IDs, lexer shape, token text cache) to avoid
        redundant initialization work.

        Args:
            initial_state: State to restart in; falls back to
                ``config.initial_state`` when ``None``.
        """
        if initial_state is None:
            initial_state = self.config.initial_state
        self.state = initial_state
        self.tool_index = -1
        self._ever_had_token_ids = False
        # DO NOT reset skip_tool_parsing here — callers set it before
        # calling methods that trigger reset() (e.g. extract_reasoning),
        # and clearing it silently breaks non-streaming tool-call-as-
        # implicit-reasoning-end (content returns None).
        self._scanner.reset()
        self._lexer.reset()
        self._reset_args_state()

    def feed(
        self,
        delta_text: str,
        delta_token_ids: Sequence[int],
    ) -> list[SemanticEvent]:
        """Consume one streaming delta and return the events it produces.

        Args:
            delta_text: Newly decoded text for this step.
            delta_token_ids: Token IDs corresponding to ``delta_text``;
                may be empty.

        Returns:
            The semantic events generated by this delta (possibly empty).
        """
        if delta_token_ids:
            self._ever_had_token_ids = True

        lexer = self._lexer

        # Fast path: skip scanner and lexer when the delta is plain
        # content with no special tokens and no terminal-starting chars,
        # and neither the lexer nor the scanner is holding anything back.
        if (
            delta_text
            and not lexer.buffer
            and not self._scanner._deferred_terminals
            and lexer._literal_first_chars.isdisjoint(delta_text)
            and not any(
                tid in self._resolved_token_ids or tid in self._drop_token_ids
                for tid in delta_token_ids
            )
        ):
            return self._emit_for_state(delta_text)

        scanned = self._scanner.scan(delta_text, delta_token_ids)
        if len(scanned) != 1 or not isinstance(scanned[0], TextChunk):
            return self._process_scanner_items(scanned)

        # Exactly one text chunk: lex it directly, and short-circuit when
        # the lexer hands back a single content token.
        lexed = lexer.feed(scanned[0].text)
        if len(lexed) == 1 and lexed[0].terminal == CONTENT_TERMINAL:
            return self._emit_for_state(lexed[0].value)
        return self._process_lex_tokens(lexed)

    def _process_scanner_items(
        self, items: Sequence[LexerInput]
    ) -> list[SemanticEvent]:
        """Run scanner output through the lexer and the state machine.

        A pre-lexed terminal first flushes whatever the lexer is still
        buffering, so events keep their input order; text chunks are fed
        to the lexer.
        """
        collected: list[SemanticEvent] = []
        for item in items:
            if isinstance(item, PreLexedTerminal):
                collected += self._process_lex_tokens(self._lexer.flush())
                collected += self._on_terminal(item.terminal, item.text)
            elif isinstance(item, TextChunk):
                collected += self._process_lex_tokens(self._lexer.feed(item.text))
        return collected

    def finish(self) -> list[SemanticEvent]:
        """Flush all pending input and close any open construct.

        Drains the scanner and lexer, emits any held-back argument text,
        and then ends an open tool call or reasoning section, leaving the
        engine in ``ParserState.CONTENT``.

        Returns:
            The events produced while finishing (possibly empty).
        """
        events = self._process_scanner_items(self._scanner.flush_pending())
        events.extend(self._process_lex_tokens(self._lexer.flush()))

        held_back = self._args_buffer
        if held_back:
            events.append(
                SemanticEvent(
                    EventType.ARG_VALUE_CHUNK,
                    value=held_back,
                    tool_index=self.tool_index,
                )
            )
            self._args_buffer = ""
            self._args_safe_end = 0

        if self.state in self._TOOL_STATES:
            # Only announce an end if a tool call was actually started.
            if self.tool_index >= 0:
                events.append(
                    SemanticEvent(
                        EventType.TOOL_CALL_END,
                        tool_index=self.tool_index,
                    )
                )
            self.state = ParserState.CONTENT
        elif self.state == ParserState.REASONING:
            events.append(
                SemanticEvent(EventType.REASONING_END, tool_index=self.tool_index)
            )
            self.state = ParserState.CONTENT
        return events

    def parse_complete(self, text: str) -> list[SemanticEvent]:
        """Parse a complete text in one shot (no token IDs).

        Equivalent to a single :meth:`feed` with an empty token-ID list
        followed by :meth:`finish`.
        """
        no_token_ids: list[int] = []
        events = self.feed(text, no_token_ids)
        events.extend(self.finish())
        return events

    def _process_lex_tokens(self, tokens: list[LexToken]) -> list[SemanticEvent]:
        """Dispatch lexer tokens to the content or terminal handler.

        Once token IDs have been seen, terminals that are resolved by
        token ID are treated as plain content when they show up as text.
        """
        id_only = self._token_id_terminal_names if self._ever_had_token_ids else None
        collected: list[SemanticEvent] = []
        for token in tokens:
            as_content = token.terminal == CONTENT_TERMINAL or (
                id_only and token.terminal in id_only
            )
            if as_content:
                collected += self._on_content(token.value)
            else:
                collected += self._on_terminal(token.terminal, token.value)
        return collected

    def _on_terminal(self, terminal: str, value: str) -> list[SemanticEvent]:
        """Handle a recognised terminal in the current state.

        Args:
            terminal: Terminal name.
            value: The text that matched the terminal.

        Returns:
            Events produced by the terminal; when no transition applies
            the text is emitted as ordinary output for the current state.
        """
        transition = self.config.transitions.get((self.state, terminal))
        if transition is None:
            return self._emit_for_state(value)

        if self.skip_tool_parsing and terminal in self._tool_terminals:
            # Tool parsing is disabled: never enter a tool state.  A tool
            # terminal that would end reasoning still ends it, and its
            # text is passed on as a text chunk.
            if EventType.REASONING_END in transition.events:
                self.state = ParserState.CONTENT
                return [
                    SemanticEvent(
                        kind,
                        value=value,
                        tool_index=self.tool_index,
                    )
                    for kind in (EventType.REASONING_END, EventType.TEXT_CHUNK)
                ]
            passthrough = self.config.content_events.get(self.state)
            if passthrough is None:
                return []
            return [SemanticEvent(passthrough, value=value, tool_index=self.tool_index)]

        if transition.skip_in_token_id_mode and self._ever_had_token_ids:
            return self._emit_for_state(value)
        return self._apply_transition(transition, value)

    def _emit_for_state(self, text: str) -> list[SemanticEvent]:
        """Emit ``text`` as whatever output the current state produces.

        In ``TOOL_ARGS`` the text becomes argument chunks (buffered through
        the JSON streamer when ``config.tool_args_json`` is set); otherwise
        the event type comes from ``config.content_events``, and nothing is
        emitted for states that have no entry there.
        """
        current = self.state
        if current == ParserState.TOOL_ARGS:
            if self.config.tool_args_json:
                return self._feed_args_text(text)
            kind = EventType.ARG_VALUE_CHUNK
        else:
            kind = self.config.content_events.get(current)
            if kind is None:
                return []
        return [SemanticEvent(kind, value=text, tool_index=self.tool_index)]

    def _on_content(self, text: str) -> list[SemanticEvent]:
        """Emit non-empty content text for the current state."""
        return self._emit_for_state(text) if text else []

    def _apply_transition(
        self,
        transition: Transition,
        value: str,
    ) -> list[SemanticEvent]:
        """Move to ``transition.next_state`` and emit its events.

        Args:
            transition: The transition to take.
            value: Text of the terminal that triggered it; attached to
                every emitted event.

        Returns:
            Any held-back argument text (when leaving ``TOOL_ARGS``)
            followed by one event per entry in ``transition.events``.
        """
        events: list[SemanticEvent] = []

        leaving_args = (
            self.state == ParserState.TOOL_ARGS
            and transition.next_state != ParserState.TOOL_ARGS
        )
        if leaving_args and self._args_buffer:
            # Release argument text that was being held back.
            events.append(
                SemanticEvent(
                    EventType.ARG_VALUE_CHUNK,
                    value=self._args_buffer,
                    tool_index=self.tool_index,
                )
            )
            self._args_buffer = ""

        self.state = transition.next_state

        for kind in transition.events:
            # A new tool call bumps the index before its event is built,
            # so TOOL_CALL_START already carries the new index.
            if kind == EventType.TOOL_CALL_START:
                self.tool_index += 1
            events.append(
                SemanticEvent(kind, value=value, tool_index=self.tool_index)
            )

        if self.state == ParserState.TOOL_ARGS:
            # Restart JSON scanning; the buffer itself is left untouched.
            self._args_safe_end = 0
            self._args_escape_next = False
            self._args_in_string = False
            self._args_brace_depth = 0
        return events

    def _feed_args_text(self, text: str) -> list[SemanticEvent]:
        """Feed text into the JSON argument streaming buffer.

        Streams argument characters incrementally while holding back
        closing braces/brackets that might change as more input arrives.
        Characters are processed one at a time, in order.
        """
        return [event for ch in text for event in self._feed_args_char(ch)]

    def _feed_args_char(self, ch: str) -> list[SemanticEvent]:
        """Advance the JSON scanner by one character.

        Tracks string/escape state and bracket depth.  Every character is
        released immediately, together with anything held back before it,
        except a closing ``}``/``]`` that leaves the depth at zero: that
        one stays in the buffer until a later character, a transition out
        of ``TOOL_ARGS``, or :meth:`finish` releases it.
        """
        self._args_buffer += ch

        if self._args_escape_next:
            # Character following a backslash inside a string: no meaning.
            self._args_escape_next = False
        elif self._args_in_string:
            if ch == "\\":
                self._args_escape_next = True
            elif ch == '"':
                self._args_in_string = False
        elif ch == '"':
            self._args_in_string = True
        elif ch in ("{", "["):
            self._args_brace_depth += 1
        elif ch in ("}", "]"):
            if self._args_brace_depth > 0:
                self._args_brace_depth -= 1
            if self._args_brace_depth == 0:
                # Top-level closer: keep it buffered for now.
                return []

        self._args_safe_end = len(self._args_buffer)
        return self._flush_safe_args()

    def _flush_safe_args(self) -> list[SemanticEvent]:
        """Emit buffered argument characters up to the safe-end watermark.

        Top-level closing braces are held back (safe_end not advanced)
        until confirmed safe by a subsequent character or finish().

        Returns:
            A single ``ARG_VALUE_CHUNK`` event, or an empty list when the
            watermark is zero.
        """
        watermark = self._args_safe_end
        if watermark == 0:
            return []

        buffered = self._args_buffer
        self._args_buffer = buffered[watermark:]
        self._args_safe_end = 0
        return [
            SemanticEvent(
                EventType.ARG_VALUE_CHUNK,
                value=buffered[:watermark],
                tool_index=self.tool_index,
            )
        ]