Skip to content

vllm.utils.cpu_triton_utils

Contains replacement functions that serve as fallbacks for Triton kernels in the CPU backend.

_copy_and_expand_dflash_inputs_kernel_impl(next_token_ids_ptr, target_positions_ptr, out_input_ids_ptr, out_context_positions_ptr, out_query_positions_ptr, out_context_slot_mapping_ptr, out_query_slot_mapping_ptr, out_token_indices_ptr, block_table_ptr, block_table_stride, query_start_loc_ptr, num_rejected_tokens_ptr, parallel_drafting_token_id, block_size, num_query_per_req, num_speculative_tokens, total_input_tokens, BLOCK_SIZE=None, HAS_NUM_REJECTED=False)

Adapter between the DFlash Triton launch and the C++ CPU op.

Source code in vllm/utils/cpu_triton_utils.py
def _copy_and_expand_dflash_inputs_kernel_impl(
    next_token_ids_ptr,
    target_positions_ptr,
    out_input_ids_ptr,
    out_context_positions_ptr,
    out_query_positions_ptr,
    out_context_slot_mapping_ptr,
    out_query_slot_mapping_ptr,
    out_token_indices_ptr,
    block_table_ptr,
    block_table_stride,
    query_start_loc_ptr,
    num_rejected_tokens_ptr,
    parallel_drafting_token_id,
    block_size,
    num_query_per_req,
    num_speculative_tokens,
    total_input_tokens,
    BLOCK_SIZE=None,
    HAS_NUM_REJECTED=False,
):
    """Adapter between the DFlash Triton launch and the C++ CPU op.

    Accepts the same arguments as the Triton launch. Despite the ``_ptr``
    suffix, every ``*_ptr`` argument is a tensor here. ``BLOCK_SIZE`` is a
    Triton tiling constant and is ignored.

    The input-id, position and slot-mapping outputs are filled through the
    int64 tensors returned by ``_ensure_int64``. Any of those outputs whose
    dtype is not int64 has the int64 result copied back into it at the end.
    ``out_token_indices_ptr`` is written directly, in its own dtype.

    When ``torch.ops._C.copy_and_expand_dflash_inputs_kernel_impl`` is
    registered, it does the work. Otherwise
    ``_copy_and_expand_dflash_inputs_python`` fills the same outputs in
    Python.

    Raises:
        AssertionError: if ``block_table_stride`` differs from
            ``block_table_ptr.stride(0)``.
    """
    assert block_table_stride == block_table_ptr.stride(0), (
        "block_table_stride mismatch: "
        f"{block_table_stride} vs {block_table_ptr.stride(0)}"
    )

    out_ids_i64 = _ensure_int64(out_input_ids_ptr)
    out_context_positions_i64 = _ensure_int64(out_context_positions_ptr)
    out_query_positions_i64 = _ensure_int64(out_query_positions_ptr)
    out_context_slot_mapping_i64 = _ensure_int64(out_context_slot_mapping_ptr)
    out_query_slot_mapping_i64 = _ensure_int64(out_query_slot_mapping_ptr)
    rejected_i64 = _ensure_int64(num_rejected_tokens_ptr) if HAS_NUM_REJECTED else None
    next_ids_i64 = _ensure_int64(next_token_ids_ptr)
    target_positions_i64 = _ensure_int64(target_positions_ptr)

    if hasattr(torch.ops._C, "copy_and_expand_dflash_inputs_kernel_impl"):
        torch.ops._C.copy_and_expand_dflash_inputs_kernel_impl(
            next_ids_i64,
            target_positions_i64,
            out_ids_i64,
            out_context_positions_i64,
            out_query_positions_i64,
            out_context_slot_mapping_i64,
            out_query_slot_mapping_i64,
            out_token_indices_ptr,
            block_table_ptr,
            query_start_loc_ptr,
            rejected_i64,
            parallel_drafting_token_id,
            block_size,
            num_query_per_req,
            num_speculative_tokens,
            total_input_tokens,
            HAS_NUM_REJECTED,
        )
    else:
        _copy_and_expand_dflash_inputs_python(
            next_ids_i64,
            target_positions_i64,
            out_ids_i64,
            out_context_positions_i64,
            out_query_positions_i64,
            out_context_slot_mapping_i64,
            out_query_slot_mapping_i64,
            out_token_indices_ptr,
            block_table_ptr,
            query_start_loc_ptr,
            rejected_i64,
            parallel_drafting_token_id,
            block_size,
            num_query_per_req,
            num_speculative_tokens,
            total_input_tokens,
        )

    # Propagate the int64 results into outputs that are not int64. A tensor's
    # dtype cannot change, so checking it now is equivalent to recording it
    # before the op ran.
    for dst, src in (
        (out_input_ids_ptr, out_ids_i64),
        (out_context_positions_ptr, out_context_positions_i64),
        (out_query_positions_ptr, out_query_positions_i64),
        (out_context_slot_mapping_ptr, out_context_slot_mapping_i64),
        (out_query_slot_mapping_ptr, out_query_slot_mapping_i64),
    ):
        if dst.dtype != torch.int64:
            dst.copy_(src.to(dst.dtype))


def _copy_and_expand_dflash_inputs_python(
    next_ids_i64,
    target_positions_i64,
    out_ids_i64,
    out_context_positions_i64,
    out_query_positions_i64,
    out_context_slot_mapping_i64,
    out_query_slot_mapping_i64,
    out_token_indices,
    block_table,
    query_start_loc,
    rejected_i64,
    parallel_drafting_token_id,
    block_size,
    num_query_per_req,
    num_speculative_tokens,
    total_input_tokens,
):
    """Pure-Python DFlash input expansion, used when the C++ op is missing.

    For request ``r``, whose context tokens span ``[ctx_start, ctx_end)`` in
    ``query_start_loc``, this function fills the following outputs:

    * **Context outputs.** Indices ``ctx_start .. ctx_end - 1`` receive the
      target position (source index clamped to ``total_input_tokens - 1``)
      and that position's slot.
    * **Query outputs.** Index ``r * num_query_per_req + k`` receives
      position ``last_pos + 1 + k`` and its slot. Here ``last_pos`` is the
      target position just before ``ctx_end - num_rejected`` (kept at least
      ``ctx_start``). The input id is the request's next token for
      ``k == 0`` and ``parallel_drafting_token_id`` otherwise.
    * **Token indices.** For every ``k >= 1``,
      ``out_token_indices[r * num_speculative_tokens + k - 1]`` receives the
      query output index.

    A slot is ``block_id * block_size + position % block_size``. The block
    number is clamped to the last column of ``block_table``.
    """
    # Read per-token inputs once as Python ints instead of calling ``.item()``
    # per element. Assumes query_start_loc and target positions are 1-D.
    starts = query_start_loc.tolist()
    positions = target_positions_i64.tolist()
    max_pos_idx = total_input_tokens - 1
    # The block table is indexed as a 2-D tensor, so valid columns are bounded
    # by its width. Clamping to the row stride instead would raise IndexError
    # whenever rows are padded (stride(0) > shape[1]); for contiguous rows the
    # two bounds are identical.
    max_block_num = block_table.shape[1] - 1
    drafting_id = int(parallel_drafting_token_id)

    def _store(dst, start, values):
        # Write a contiguous run of Python ints into ``dst`` with one op.
        if values:
            dst[start : start + len(values)] = torch.tensor(
                values, dtype=dst.dtype, device=dst.device
            )

    for req_idx in range(len(starts) - 1):
        ctx_start, ctx_end = starts[req_idx], starts[req_idx + 1]
        valid_ctx_end = ctx_end
        if rejected_i64 is not None:
            valid_ctx_end -= int(rejected_i64[req_idx].item())
        # Guard against out-of-bounds: ensure valid_ctx_end > ctx_start.
        valid_ctx_end = max(valid_ctx_end, ctx_start + 1)
        last_pos = positions[valid_ctx_end - 1]

        ctx_positions = [
            positions[min(ctx_idx, max_pos_idx)]
            for ctx_idx in range(ctx_start, ctx_end)
        ]
        query_positions = [last_pos + 1 + k for k in range(num_query_per_req)]

        # Resolve every slot of this request with a single block-table gather.
        all_positions = ctx_positions + query_positions
        slots = []
        if all_positions:
            block_nums = [min(p // block_size, max_block_num) for p in all_positions]
            block_ids = block_table[req_idx, block_nums].tolist()
            slots = [
                block_id * block_size + (pos % block_size)
                for block_id, pos in zip(block_ids, all_positions)
            ]
        num_ctx_tokens = len(ctx_positions)

        _store(out_context_positions_i64, ctx_start, ctx_positions)
        _store(out_context_slot_mapping_i64, ctx_start, slots[:num_ctx_tokens])

        query_base = req_idx * num_query_per_req
        _store(out_query_positions_i64, query_base, query_positions)
        _store(out_query_slot_mapping_i64, query_base, slots[num_ctx_tokens:])
        if num_query_per_req > 0:
            first_id = int(next_ids_i64[req_idx].item())
            _store(
                out_ids_i64,
                query_base,
                [first_id] + [drafting_id] * (num_query_per_req - 1),
            )
        # Query slot 0 holds the next token; slots 1.. are the sampled drafts.
        _store(
            out_token_indices,
            req_idx * num_speculative_tokens,
            [query_base + k for k in range(1, num_query_per_req)],
        )

_copy_and_expand_eagle_inputs_kernel_impl(target_token_ids_ptr, target_positions_ptr, next_token_ids_ptr, out_input_ids_ptr, out_positions_ptr, out_is_rejected_token_mask_ptr, out_is_masked_token_mask_ptr, out_new_token_indices_ptr, out_hidden_state_mapping_ptr, query_start_loc_ptr, query_end_loc_ptr, padding_token_id, parallel_drafting_token_id, total_input_tokens, num_padding_slots_per_request, shift_input_ids, BLOCK_SIZE_TOKENS=None, BLOCK_SIZE_REQS=None)

Adapter between the Triton kernel calling convention and the C++ implementation.

The Triton kernel uses '_ptr'-suffixed parameter names and compile-time constants (BLOCK_SIZE_TOKENS, BLOCK_SIZE_REQS) that the C++ implementation does not need. The C++ implementation reads token-id tensors as int64_t*, so int32 output tensors must be copied back after it writes int64 values.

Source code in vllm/utils/cpu_triton_utils.py
def _copy_and_expand_eagle_inputs_kernel_impl(
    target_token_ids_ptr,
    target_positions_ptr,
    next_token_ids_ptr,
    out_input_ids_ptr,
    out_positions_ptr,
    out_is_rejected_token_mask_ptr,
    out_is_masked_token_mask_ptr,
    out_new_token_indices_ptr,
    out_hidden_state_mapping_ptr,
    query_start_loc_ptr,
    query_end_loc_ptr,
    padding_token_id,
    parallel_drafting_token_id,
    total_input_tokens,
    num_padding_slots_per_request,
    shift_input_ids,
    BLOCK_SIZE_TOKENS=None,
    BLOCK_SIZE_REQS=None,
):
    """Bridge a Triton-style EAGLE input-expansion launch to the C++ CPU op.

    Parameter names keep the Triton ``_ptr`` suffix so this function can
    stand in for the kernel launch. The tile-size constants
    ``BLOCK_SIZE_TOKENS`` and ``BLOCK_SIZE_REQS`` are accepted and ignored.

    Token-id and position tensors reach the op through ``_ensure_int64``,
    because the C++ side reads them as ``int64_t*``. When
    ``out_input_ids_ptr`` or ``out_positions_ptr`` is not int64, the op fills
    the int64 tensor returned by ``_ensure_int64``. That result is then copied
    back into the caller's tensor with its original dtype.

    All remaining arguments are forwarded to the op unchanged.
    """
    ids_buf = _ensure_int64(out_input_ids_ptr)
    positions_buf = _ensure_int64(out_positions_ptr)

    eagle_op = torch.ops._C.copy_and_expand_eagle_inputs_kernel_impl
    eagle_op(
        _ensure_int64(target_token_ids_ptr),
        _ensure_int64(target_positions_ptr),
        _ensure_int64(next_token_ids_ptr),
        ids_buf,
        positions_buf,
        out_is_rejected_token_mask_ptr,
        out_is_masked_token_mask_ptr,
        out_new_token_indices_ptr,
        out_hidden_state_mapping_ptr,
        query_start_loc_ptr,
        query_end_loc_ptr,
        padding_token_id,
        parallel_drafting_token_id,
        total_input_tokens,
        num_padding_slots_per_request,
        shift_input_ids,
    )

    # A tensor's dtype is fixed, so checking it after the call is equivalent
    # to remembering it beforehand.
    for caller_tensor, int64_result in (
        (out_input_ids_ptr, ids_buf),
        (out_positions_ptr, positions_buf),
    ):
        if caller_tensor.dtype != torch.int64:
            caller_tensor.copy_(int64_result.to(caller_tensor.dtype))