vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding

Quantized embedding method for compressed-tensors.

Adds dequant-on-lookup support for a pack-quantized VocabParallelEmbedding (2- to 8-bit INT weights, channel- or group-quantized). Only the gathered token rows are unpacked and dequantized, so the full packed weight is never expanded into a dense tensor.
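The dequant-on-lookup idea can be sketched in NumPy as a reference for small inputs. This is an illustrative sketch, not part of vLLM: the helper name `dequant_gather_reference` is hypothetical, and it assumes `PACK_FACTOR == 32 // NUM_BITS`, a `(vocab, packed_cols)` int32 weight, and scales shaped `(vocab,)` for channel quantization or `(vocab, num_groups)` for group quantization.

```python
import numpy as np


def dequant_gather_reference(ids, packed, scale, hidden, num_bits, group_size):
    """Hypothetical NumPy reference: gather rows, unpack, dequantize.

    Assumes pack_factor = 32 // num_bits and group_size == 0 for
    channel quantization (one scale per vocabulary row).
    """
    pack_factor = 32 // num_bits
    cols = np.arange(hidden)

    # Gather first: only the requested rows are ever unpacked.
    rows = packed[ids]  # (n_tokens, packed_cols), int32

    # Reinterpret each int32 word as an unsigned 32-bit value held in int64.
    words = rows[:, cols // pack_factor].astype(np.int64) & 0xFFFFFFFF
    shift = (cols % pack_factor) * num_bits

    # Extract each field, then remove the 2**(num_bits - 1) offset.
    q = ((words >> shift) & ((1 << num_bits) - 1)) - (1 << (num_bits - 1))

    if group_size == 0:
        s = scale[ids].reshape(-1, 1)          # one scale per row
    else:
        s = scale[ids][:, cols // group_size]  # one scale per (row, group)
    return q.astype(np.float32) * s.astype(np.float32)


# Tiny hand-built example: 3 vocabulary rows, hidden=8, 4-bit values.
q_true = np.array(
    [[-8, -7, -6, -5, -4, -3, -2, -1],
     [0, 1, 2, 3, 4, 5, 6, 7],
     [7, 0, -8, 1, -1, 2, -2, 3]],
    dtype=np.int64,
)
u = (q_true + 8).astype(np.uint32)  # offset-binary fields in [0, 15]
packed_u32 = np.zeros((3, 1), dtype=np.uint32)
for i in range(8):
    packed_u32[:, 0] |= u[:, i] << np.uint32(4 * i)
packed = packed_u32.view(np.int32)

ids = np.array([2, 0, 2])
scale = np.array([0.5, 0.25, 2.0], dtype=np.float16)
out = dequant_gather_reference(ids, packed, scale, hidden=8, num_bits=4, group_size=0)
```

Here `out` has shape `(3, 8)`: one dequantized row per token id, while the remaining vocabulary rows stay packed.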

_dequant_gather_kernel(ids_ptr, packed_ptr, scale_ptr, out_ptr, hidden, packed_cols, num_groups, NUM_BITS, PACK_FACTOR, GROUP_SIZE, BLOCK)

Gather embedding rows by token id, unpack int32-packed INT weights, and dequantize to out dtype in one pass (no int8 intermediate).

Source code in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_embedding.py
@triton.jit
def _dequant_gather_kernel(
    ids_ptr,
    packed_ptr,
    scale_ptr,
    out_ptr,
    hidden,
    packed_cols,
    num_groups,
    NUM_BITS: tl.constexpr,
    PACK_FACTOR: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Gather embedding rows by token id, unpack int32-packed INT weights, and
    dequantize to ``out`` dtype in one pass (no int8 intermediate)."""
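    # One program per (token, column block): axis 0 selects the token,
    # axis 1 selects a BLOCK-wide slice of the hidden dimension.
    row = tl.program_id(0)
    col = tl.program_id(1) * BLOCK + tl.arange(0, BLOCK)
    col_mask = col < hidden  # guard the tail block
    # Widen the token id so the row offset below is computed in 64-bit.
    tid = tl.load(ids_ptr + row).to(tl.int64)

    # Each int32 word holds PACK_FACTOR values, lowest bits first.
    packed_idx = col // PACK_FACTOR
    shift = (col % PACK_FACTOR) * NUM_BITS
    packed = tl.load(
        packed_ptr + tid * packed_cols + packed_idx, mask=col_mask, other=0
    )
    # Extract the NUM_BITS-wide field, then subtract 2**(NUM_BITS - 1) to
    # recover the signed value from its offset encoding.
    q = ((packed >> shift) & ((1 << NUM_BITS) - 1)) - (1 << (NUM_BITS - 1))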

    if GROUP_SIZE == 0:  # channel: one scale per row
        scale = tl.load(scale_ptr + tid)
    else:  # group: one scale per (row, group)
        grp = col // GROUP_SIZE
        scale = tl.load(scale_ptr + tid * num_groups + grp, mask=col_mask, other=0.0)

    out = q.to(tl.float32) * scale.to(tl.float32)
    tl.store(
        out_ptr + row * hidden + col, out.to(out_ptr.dtype.element_ty), mask=col_mask
    )
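The index arithmetic in the kernel (`packed_idx`, `shift`, the bit mask, and the `1 << (NUM_BITS - 1)` offset) implies a specific bit layout inside each int32 word. The round trip below is a minimal sketch of that layout in NumPy. `pack_rows` and `unpack_rows` are hypothetical helpers written for illustration, and the sketch assumes `PACK_FACTOR == 32 // NUM_BITS`; it does not claim to reproduce the packer used by compressed-tensors.

```python
import numpy as np


def pack_rows(q, num_bits):
    """Hypothetical inverse of the kernel's unpack step.

    q: (rows, hidden) signed integers in [-2**(num_bits-1), 2**(num_bits-1) - 1].
    Returns an int32 array of shape (rows, ceil(hidden / pack_factor)).
    """
    pack_factor = 32 // num_bits
    rows, hidden = q.shape
    packed_cols = -(-hidden // pack_factor)  # ceiling division
    # Shift signed values into the unsigned range [0, 2**num_bits - 1].
    u = (q.astype(np.int64) + (1 << (num_bits - 1))).astype(np.uint32)
    packed = np.zeros((rows, packed_cols), dtype=np.uint32)
    for col in range(hidden):
        shift = np.uint32((col % pack_factor) * num_bits)
        packed[:, col // pack_factor] |= u[:, col] << shift
    return packed.view(np.int32)


def unpack_rows(packed, hidden, num_bits):
    """Mirror of the kernel's unpack arithmetic for a whole array."""
    pack_factor = 32 // num_bits
    col = np.arange(hidden)
    words = packed.view(np.uint32)[:, col // pack_factor].astype(np.int64)
    shift = (col % pack_factor) * num_bits
    return ((words >> shift) & ((1 << num_bits) - 1)) - (1 << (num_bits - 1))


# Round trip for a 3-bit case, where 10 values share a word and 2 bits go unused.
rng = np.random.default_rng(0)
q = rng.integers(-4, 4, size=(4, 20))
restored = unpack_rows(pack_rows(q, num_bits=3), hidden=20, num_bits=3)
```

With this layout, column 0 of a row occupies the lowest `num_bits` bits of the first word, and bit widths that do not divide 32 leave the top bits of each word unused.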