Skip to content

Token Classify

Source: https://github.com/vllm-project/vllm/tree/main/examples/pooling/token_classify.

Forced Alignment Offline

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from Qwen3-ForcedAligner inference:
# https://github.com/QwenLM/Qwen3-ASR

"""
Offline forced alignment example using Qwen3-ForcedAligner-0.6B.

Forced alignment takes audio and reference text as input and produces
word-level timestamps. The model predicts a time bin at each <timestamp>
token position; multiplying by ``timestamp_segment_time`` gives milliseconds.

Usage::

    python forced_alignment_offline.py \\
        --model Qwen/Qwen3-ForcedAligner-0.6B
"""

from argparse import Namespace

import numpy as np

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    """Parse engine CLI flags, defaulting to the Qwen3 forced-aligner setup.

    The defaults select the pooling runner, eager execution, and override the
    architecture so the checkpoint loads as a token-classification model.
    """
    example_defaults = {
        "model": "Qwen/Qwen3-ForcedAligner-0.6B",
        "runner": "pooling",
        "enforce_eager": True,
        "hf_overrides": {
            "architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]
        },
    }
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    parser.set_defaults(**example_defaults)
    return parser.parse_args()


def build_prompt(words: list[str]) -> str:
    """Return the forced-alignment prompt for ``words``.

    The audio placeholder tokens come first, then every word followed by a
    pair of ``<timestamp>`` tokens:

        <|audio_start|><|audio_pad|><|audio_end|>
        word1<timestamp><timestamp>word2<timestamp><timestamp>...
    """
    audio_prefix = "<|audio_start|><|audio_pad|><|audio_end|>"
    timestamp_pair = "<timestamp>" * 2
    return audio_prefix + timestamp_pair.join(words) + timestamp_pair


def main(args: Namespace):
    """Run offline forced alignment on a placeholder clip and print word timings.

    Builds an alignment prompt for a fixed word list, encodes it together with
    a 5-second silent audio buffer through ``LLM.encode`` using the
    ``token_classify`` pooling task, and prints a start/end time per word.

    Args:
        args: Parsed CLI namespace; every attribute is forwarded to ``LLM``.

    Raises:
        RuntimeError: If the model returns fewer ``<timestamp>`` predictions
            than the two per word that the prompt contains.
    """
    llm = LLM(**vars(args))

    # The timestamp token id and time-bin width come from the model's HF config.
    config = llm.llm_engine.vllm_config.model_config.hf_config
    timestamp_token_id = config.timestamp_token_id
    timestamp_segment_time = config.timestamp_segment_time

    # Example: align these words against a 5-second audio clip
    words = ["Hello", "world"]
    prompt = build_prompt(words)

    # Use a 5-second silent audio as placeholder (replace with real audio)
    sample_rate = 16000
    audio = np.zeros(sample_rate * 5, dtype=np.float32)

    outputs = llm.encode(
        [{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
        pooling_task="token_classify",
    )

    for output in outputs:
        logits = output.outputs.data  # [num_tokens, classify_num]
        # Highest-scoring class (time bin) for every prompt position.
        predictions = logits.argmax(dim=-1)
        token_ids = output.prompt_token_ids

        # Extract timestamps at <timestamp> positions
        ts_predictions = [
            pred.item() * timestamp_segment_time
            for tid, pred in zip(token_ids, predictions)
            if tid == timestamp_token_id
        ]

        # build_prompt emits a start/end <timestamp> pair after every word.
        # Fail with a clear message (as the online example does) instead of
        # an IndexError in the pairing loop below.
        if len(ts_predictions) < len(words) * 2:
            raise RuntimeError(
                "The model did not return enough timestamp predictions."
            )

        # Pair up start/end times per word
        for i, word in enumerate(words):
            start_ms = ts_predictions[i * 2]
            end_ms = ts_predictions[i * 2 + 1]
            print(f"{word:15s} {start_ms / 1000:.3f}s - {end_ms / 1000:.3f}s")


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the offline example.
    args = parse_args()
    main(args)

Forced Alignment Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from Qwen3-ForcedAligner inference:
# https://github.com/QwenLM/Qwen3-ASR

"""
Online forced alignment example using Qwen3-ForcedAligner-0.6B.

Forced alignment takes audio and reference text as input and produces
word-level timestamps. The model predicts a time bin at each <timestamp>
token position; multiplying by ``timestamp_segment_time`` gives milliseconds.

Start the server with:

    vllm serve Qwen/Qwen3-ForcedAligner-0.6B \\
        --runner pooling \\
        --enforce-eager \\
        --trust-request-chat-template \\
        --hf-overrides \\
        '{"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]}'

Then run:

    python forced_alignment_online.py
"""

import argparse
import json
import mimetypes
import wave
from io import BytesIO
from pathlib import Path
from typing import Any

import numpy as np
import pybase64 as base64
import requests
import torch
from huggingface_hub import hf_hub_download

# Jinja chat template that renders only ``messages[0]['content']`` (no role
# markers or other wrapping). ``build_payload`` sends it with every request,
# which the server accepts only with --trust-request-chat-template.
RAW_CONTENT_CHAT_TEMPLATE = "{{ messages[0]['content'] }}"


def build_prompt(words: list[str]) -> str:
    """Assemble the forced-alignment prompt text.

    Layout: the audio placeholder tokens, then each word immediately followed
    by two ``<timestamp>`` tokens, i.e.

        <|audio_start|><|audio_pad|><|audio_end|>
        word1<timestamp><timestamp>word2<timestamp><timestamp>...
    """
    pair = "<timestamp><timestamp>"
    aligned_text = f"{pair.join(words)}{pair}"
    return "".join(("<|audio_start|>", "<|audio_pad|>", "<|audio_end|>", aligned_text))


def encode_audio_data_uri(audio_path: Path) -> str:
    """Read ``audio_path`` and return its bytes as a base64 ``data:`` URI.

    The MIME type is guessed from the file name; ``audio/wav`` is used when
    no type can be guessed.
    """
    guessed_type, _encoding = mimetypes.guess_type(audio_path)
    mime_type = guessed_type if guessed_type else "audio/wav"
    encoded_audio = base64.b64encode(audio_path.read_bytes()).decode("utf-8")
    return "data:" + mime_type + ";base64," + encoded_audio


def encode_silent_wav_data_uri(sample_rate: int = 16000, duration_s: int = 5) -> str:
    """Return a silent mono 16-bit PCM WAV clip as a base64 ``data:`` URI.

    Args:
        sample_rate: Frame rate written to the WAV header.
        duration_s: Clip length in whole seconds.
    """
    silence = np.zeros(sample_rate * duration_s, dtype=np.int16)

    with BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(silence.dtype.itemsize)
            writer.setframerate(sample_rate)
            writer.writeframes(silence.tobytes())
        wav_bytes = buffer.getvalue()

    encoded = base64.b64encode(wav_bytes).decode("utf-8")
    return f"data:audio/wav;base64,{encoded}"


def build_payload(model: str, prompt: str, audio_uri: str) -> dict[str, Any]:
    """Assemble the JSON body for a ``token_classify`` pooling request.

    The single user message carries the alignment prompt as a text part and
    the audio as an ``audio_url`` part. The request also supplies its own chat
    template (``RAW_CONTENT_CHAT_TEMPLATE``).
    """
    content_parts = [
        {"type": "text", "text": prompt},
        {"type": "audio_url", "audio_url": {"url": audio_uri}},
    ]
    user_message = {"role": "user", "content": content_parts}
    return dict(
        model=model,
        messages=[user_message],
        task="token_classify",
        chat_template=RAW_CONTENT_CHAT_TEMPLATE,
    )


def post_http_request(
    payload: dict[str, Any],
    api_url: str,
    timeout: float | None = 300.0,
) -> requests.Response:
    """POST ``payload`` as JSON to ``api_url`` and return the raw response.

    Args:
        payload: JSON-serializable request body.
        api_url: Full URL of the pooling endpoint.
        timeout: Seconds to wait, forwarded to ``requests.post``. Without it
            a stalled server would block the script forever; pass ``None`` to
            wait indefinitely.
    """
    headers = {"User-Agent": "Test Client"}
    return requests.post(api_url, headers=headers, json=payload, timeout=timeout)


def parse_response(response: requests.Response) -> dict[str, Any]:
    """Decode a pooling response body, raising on any failure.

    Raises:
        RuntimeError: If the body is not valid JSON, the HTTP status is not
            200, or the decoded body has no ``"data"`` entry.
    """
    try:
        decoded = response.json()
    except ValueError as err:
        raise RuntimeError(
            f"Server returned non-JSON response: {response.text}"
        ) from err

    succeeded = response.status_code == 200 and "data" in decoded
    if not succeeded:
        raise RuntimeError(f"Server error ({response.status_code}): {decoded}")
    return decoded


def load_timestamp_config(model: str) -> tuple[int, float]:
    """Read the timestamp settings from the model's ``config.json``.

    ``model`` is used as a local directory when that path exists; otherwise it
    is treated as a Hugging Face Hub repo id and ``config.json`` is fetched
    with ``hf_hub_download``.

    Returns:
        ``(timestamp_token_id, timestamp_segment_time)`` as stored in the config.

    Raises:
        KeyError: If either key is missing from the config.
    """
    model_path = Path(model)
    if model_path.exists():
        config_path = model_path / "config.json"
    else:
        config_path = Path(hf_hub_download(repo_id=model, filename="config.json"))

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with config_path.open(encoding="utf-8") as f:
        config = json.load(f)

    return config["timestamp_token_id"], config["timestamp_segment_time"]


def parse_args():
    """Parse command-line options for the online forced-alignment client."""
    argument_specs = (
        ("--host", dict(type=str, default="localhost")),
        ("--port", dict(type=int, default=8000)),
        ("--model", dict(type=str, default="Qwen/Qwen3-ForcedAligner-0.6B")),
        (
            "--audio-path",
            dict(
                type=Path,
                default=None,
                help="Optional audio file. Defaults to a 5-second silent WAV.",
            ),
        ),
        (
            "--words",
            dict(
                nargs="+",
                default=["Hello", "world"],
                help="Reference words to align against the audio.",
            ),
        ),
    )
    parser = argparse.ArgumentParser()
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def main(args: argparse.Namespace):
    """Align ``args.words`` against audio via the server's ``/pooling`` API.

    Sends the alignment prompt together with either ``args.audio_path`` or a
    5-second silent WAV, maps the per-position predictions back onto the
    locally tokenized prompt, and prints a start/end time for each word.

    Raises:
        RuntimeError: On a non-JSON or failed server response; when the
            response length disagrees with the reported prompt token count or
            is shorter than the local tokenization; when the prompt lacks the
            audio pad token; or when too few ``<timestamp>`` predictions exist.
    """
    from transformers import AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    prompt = build_prompt(args.words)
    # Use the user-supplied audio file if given, else a silent placeholder.
    audio_uri = (
        encode_audio_data_uri(args.audio_path)
        if args.audio_path
        else encode_silent_wav_data_uri()
    )
    payload = build_payload(args.model, prompt, audio_uri)

    pooling_response = post_http_request(payload=payload, api_url=api_url)
    result = parse_response(pooling_response)

    # Tokenizer and timestamp settings are loaded locally (from the same
    # model name) to interpret the per-position logits the server returns.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    timestamp_token_id, timestamp_segment_time = load_timestamp_config(args.model)

    output = result["data"][0]
    logits = torch.tensor(output["data"])
    # Highest-scoring class per server-side position (a time bin at
    # <timestamp> positions).
    predictions = logits.argmax(dim=-1)
    token_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    audio_pad_token_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>")

    # Sanity check: one prediction per server-side prompt token, when the
    # server reports usage.
    usage = result.get("usage") or {}
    prompt_tokens = usage.get("prompt_tokens")
    if prompt_tokens is not None and prompt_tokens != len(predictions):
        raise RuntimeError(
            "The response length does not match the reported prompt token count."
        )

    try:
        audio_pad_index = token_ids.index(audio_pad_token_id)
    except ValueError as exc:
        raise RuntimeError("The prompt does not contain the audio pad token.") from exc

    # Any extra positions in the response, relative to the local tokenization,
    # are assumed to come from the server expanding the single <|audio_pad|>
    # placeholder into audio positions. Local indices after it are therefore
    # shifted by this amount.
    audio_token_shift = len(predictions) - len(token_ids)
    if audio_token_shift < 0:
        raise RuntimeError(
            "The response is shorter than the locally tokenized prompt. "
            "Check that the server was started with --trust-request-chat-template."
        )

    ts_predictions = []
    for i, token_id in enumerate(token_ids):
        if token_id != timestamp_token_id:
            continue

        # Map the local index onto the server-side (audio-expanded) sequence.
        prediction_index = i + audio_token_shift if i > audio_pad_index else i
        ts_predictions.append(
            predictions[prediction_index].item() * timestamp_segment_time
        )

    # build_prompt emits two <timestamp> tokens (start, end) per word.
    if len(ts_predictions) < len(args.words) * 2:
        raise RuntimeError("The model did not return enough timestamp predictions.")

    for i, word in enumerate(args.words):
        start_ms = ts_predictions[i * 2]
        end_ms = ts_predictions[i * 2 + 1]
        print(f"{word:15s} {start_ms / 1000:.3f}s - {end_ms / 1000:.3f}s")


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then query the running server.
    args = parse_args()
    main(args)

NER Offline

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    """Parse engine CLI flags with defaults for the NeuroBERT NER example.

    The defaults select the pooling runner in eager mode and allow the
    model's remote code to be trusted.
    """
    example_defaults = dict(
        model="boltuix/NeuroBERT-NER",
        runner="pooling",
        enforce_eager=True,
        trust_remote_code=True,
    )
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    parser.set_defaults(**example_defaults)
    return parser.parse_args()


def main(args: Namespace):
    """Run NER on a sample sentence and print each non-special token's label.

    Args:
        args: Parsed CLI namespace; every attribute is forwarded to ``LLM``.
    """
    # Sample prompts.
    prompts = [
        "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    ]

    # Create an LLM.
    llm = LLM(**vars(args))
    tokenizer = llm.get_tokenizer()
    label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
    # Build the special-token set once, outside the loops, so each membership
    # test is O(1) instead of a list scan per token.
    special_tokens = set(tokenizer.all_special_tokens)

    # Run inference
    outputs = llm.encode(prompts, pooling_task="token_classify")

    for output in outputs:
        logits = output.outputs.data
        # Highest-scoring label id per prompt token.
        predictions = logits.argmax(dim=-1)

        # Map predictions to labels
        tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
        labels = [label_map[p.item()] for p in predictions]

        # Print results
        for token, label in zip(tokens, labels):
            if token not in special_tokens:
                print(f"{token:15}{label}")


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the offline NER example.
    args = parse_args()
    main(args)

NER Online

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER

"""
Example online usage of Pooling API for Named Entity Recognition (NER).

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM, e.g.:

vllm serve boltuix/NeuroBERT-NER --runner pooling
"""

import argparse

import requests
import torch


def post_http_request(
    prompt: dict, api_url: str, timeout: float | None = 300.0
) -> requests.Response:
    """POST ``prompt`` as the JSON body to ``api_url`` and return the response.

    Args:
        prompt: JSON-serializable request body.
        api_url: Full URL of the pooling endpoint.
        timeout: Seconds to wait, forwarded to ``requests.post``. Without it
            a stalled server would block the script forever; pass ``None`` to
            wait indefinitely.
    """
    headers = {"User-Agent": "Test Client"}
    return requests.post(api_url, headers=headers, json=prompt, timeout=timeout)


def parse_args():
    """Parse the server location and model name for the NER client."""
    parser = argparse.ArgumentParser()
    for flag, value_type, default in (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "boltuix/NeuroBERT-NER"),
    ):
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()


def main(args):
    """Query the ``/pooling`` endpoint for NER logits and print token labels.

    The text is also tokenized locally so that each returned logits row can be
    paired with its token string and mapped through the model's ``id2label``.

    Raises:
        RuntimeError: If the server returns a non-JSON body, a non-200 status
            or no ``"data"`` entry, or if the number of returned rows differs
            from the local token count.
    """
    from transformers import AutoConfig, AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Load tokenizer and config
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    label_map = config.id2label

    # Input text
    text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    prompt = {"model": model_name, "input": text}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)

    # Report server failures explicitly instead of a bare JSON/KeyError.
    try:
        result = pooling_response.json()
    except ValueError as exc:
        raise RuntimeError(
            f"Server returned non-JSON response: {pooling_response.text}"
        ) from exc
    if pooling_response.status_code != 200 or "data" not in result:
        raise RuntimeError(
            f"Server error ({pooling_response.status_code}): {result}"
        )

    # Run inference
    output = result["data"][0]
    logits = torch.tensor(output["data"])
    predictions = logits.argmax(dim=-1)
    inputs = tokenizer(text, return_tensors="pt")

    # Map predictions to labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Validate before use with a real exception: `assert` is stripped under
    # `python -O`, and zip() below would silently truncate on a mismatch.
    if len(tokens) != len(predictions):
        raise RuntimeError(
            f"Got {len(predictions)} predictions for {len(tokens)} tokens."
        )
    labels = [label_map[p.item()] for p in predictions]

    # Print results
    special_tokens = set(tokenizer.all_special_tokens)
    for token, label in zip(tokens, labels):
        if token not in special_tokens:
            print(f"{token:15}{label}")


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then query the running NER server.
    args = parse_args()
    main(args)