Skip to content

vllm.v1.attention.ops.int4_per_token_head

Sub-byte packed (INT4) per-token-head KV cache mode.

INT4 packs two 4-bit values into each cache byte, pre-rotates the data with a single Randomized Hadamard Transform (RHT), and hides a 4-bit zero-point in the low mantissa bits of the scale. That is too different from the core kernel to share it, so this module owns the whole mode: nibble pack/unpack, the reshape (write) kernel, the split-dot attention (read) kernel, the RHT transform, and the public reshape_and_cache_int4 / unified_attention_int4 entry points.

Functions:

_get_rht_signs(d, round_idx, device)

Return a cached deterministic ±1 sign vector of length d.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
def _get_rht_signs(d: int, round_idx: int, device: torch.device) -> torch.Tensor:
    """Look up, or lazily build, the ±1 sign vector for one RHT round.

    The vector has length *d* and depends only on ``round_idx``: it is drawn
    on the CPU from a generator seeded with a fixed constant offset by
    ``round_idx`` and then moved to *device*.  Results are memoized in
    ``_RHT_SIGNS_CACHE`` under ``(d, round_idx, str(device))``.
    """
    cache_key = (d, round_idx, str(device))
    try:
        return _RHT_SIGNS_CACHE[cache_key]
    except KeyError:
        pass

    rng = torch.Generator(device="cpu")
    rng.manual_seed(0x9E3779B9 + round_idx * 0x517CC1B7)
    # Fair coin flips in {0, 1}, mapped affinely onto {-1, +1}.
    coin_flips = torch.bernoulli(
        torch.full((d,), 0.5, device="cpu"), generator=rng
    )
    signs_on_device = (2.0 * coin_flips - 1.0).to(device)
    _RHT_SIGNS_CACHE[cache_key] = signs_on_device
    return signs_on_device

_launch_packed_attn(*, q, k_cache, v_cache, out, cu_seqlens_q, max_seqlen_q, seqused_k, softmax_scale, window_size, block_table, softcap, sinks, alibi_slopes, use_alibi_sqrt, qq_bias, output_scale, mm_prefix_range, k_scale_cache, v_scale_cache, seq_threshold_3D, num_par_softmax_segments, softmax_segm_output, softmax_segm_max, softmax_segm_expsum, packing_factor)

Launch _attn_packed for one of the sub-byte modes.

Handles 2D-vs-3D dispatch, placeholder pointers for the unused side of that split, and the trailing reduce_segments pass. Writes into out (directly for 2D; via the segm buffers for 3D).

Source code in vllm/v1/attention/ops/int4_per_token_head.py
def _launch_packed_attn(
    *,
    q,
    k_cache,
    v_cache,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    softmax_scale,
    window_size,
    block_table,
    softcap,
    sinks,
    alibi_slopes,
    use_alibi_sqrt,
    qq_bias,
    output_scale,
    mm_prefix_range,
    k_scale_cache,
    v_scale_cache,
    seq_threshold_3D,
    num_par_softmax_segments,
    softmax_segm_output,
    softmax_segm_max,
    softmax_segm_expsum,
    packing_factor: int,
):
    """Dispatch ``_attn_packed`` over a sub-byte packed KV cache.

    Chooses between the 2D grid (one program per query block and KV head)
    and the 3D grid (an extra axis of parallel softmax segments), launches
    the kernel, and for the 3D case follows up with ``reduce_segments``.
    The result ends up in ``out``: written directly by the 2D kernel, or
    by the reduction over the segment buffers in the 3D case.

    The 3D path is taken only when every segment-related argument is
    provided, ``max_seqlen_q`` is at most 1, ``num_seqs`` does not exceed
    ``seq_threshold_3D`` and batch-invariant mode is off.
    """
    import vllm.envs as envs
    from vllm.v1.attention.ops.triton_unified_attention import _get_tile_size

    is_batch_invariant = envs.VLLM_BATCH_INVARIANT

    # Optional multimodal prefix ranges: must be rank 3 when supplied.
    if mm_prefix_range is None:
        use_mm_prefix, max_mm_ranges = False, 0
    else:
        assert mm_prefix_range.ndim == 3, (
            f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
        )
        use_mm_prefix, max_mm_ranges = True, mm_prefix_range.shape[1]

    # Problem geometry read off the tensors.
    num_seqs = len(seqused_k)
    num_q_tokens = q.shape[0]
    num_query_heads = q.shape[1]
    head_size = q.shape[2]
    block_size = v_cache.shape[1]
    num_kv_heads = k_cache.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size_padded = triton.next_power_of_2(head_size)

    # At least 16 rows per program; more when a KV head serves >16 queries.
    if num_queries_per_kv <= 16:
        BLOCK_M = 16
    else:
        BLOCK_M = triton.next_power_of_2(num_queries_per_kv)
    BLOCK_Q = BLOCK_M // num_queries_per_kv
    total_num_q_blocks = num_q_tokens // BLOCK_Q + num_seqs

    left_window = window_size[0]
    if left_window >= 0:
        sliding_window_val = left_window + 1
    else:
        sliding_window_val = 0

    elem_bytes = q.element_size()
    TILE_SIZE_PREFILL = _get_tile_size(
        head_size, sliding_window_val, elem_bytes, is_prefill=True
    )
    TILE_SIZE_DECODE = _get_tile_size(
        head_size, sliding_window_val, elem_bytes, is_prefill=False
    )

    has_3d_args = all(
        arg is not None
        for arg in (
            seq_threshold_3D,
            num_par_softmax_segments,
            softmax_segm_output,
            softmax_segm_max,
            softmax_segm_expsum,
        )
    )
    use_3d = (
        has_3d_args
        and not is_batch_invariant
        and max_seqlen_q <= 1
        and num_seqs <= seq_threshold_3D
    )

    grid: tuple[Any, ...]
    if use_3d:
        grid = (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
        tile_size = TILE_SIZE_DECODE
        num_segments = num_par_softmax_segments
        segm_output_ptr = softmax_segm_output
        segm_max_ptr = softmax_segm_max
        segm_expsum_ptr = softmax_segm_expsum
    else:
        grid = (total_num_q_blocks, num_kv_heads)
        tile_size = TILE_SIZE_PREFILL
        num_segments = 1
        # The 2D kernel never touches the segment tensors (and the 3D one
        # never touches ``output_ptr``), yet Triton still wants a non-null
        # pointer for every argument, so ``out`` stands in as a dummy.
        segm_output_ptr = segm_max_ptr = segm_expsum_ptr = out

    use_fp8 = output_scale is not None
    inv_output_scale = 1 / output_scale if use_fp8 else 1.0

    _attn_packed[grid](
        output_ptr=out,
        segm_output_ptr=segm_output_ptr,
        segm_max_ptr=segm_max_ptr,
        segm_expsum_ptr=segm_expsum_ptr,
        query_ptr=q,
        key_cache_ptr=k_cache,
        value_cache_ptr=v_cache,
        sink_ptr=sinks,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        qq_bias_ptr=qq_bias,
        scale=softmax_scale,
        out_scale=inv_output_scale,
        softcap=softcap,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        qq_bias_stride_0=qq_bias.stride(0) if qq_bias is not None else 0,
        BLOCK_SIZE=block_size,
        TILE_SIZE=tile_size,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=head_size_padded,
        PACKED_HEAD_PADDED=head_size_padded // packing_factor,
        USE_ALIBI_SLOPES=alibi_slopes is not None,
        USE_ALIBI_SQRT=use_alibi_sqrt,
        USE_QQ_BIAS=qq_bias is not None,
        USE_SOFTCAP=(softcap > 0),
        USE_SINKS=(sinks is not None),
        SLIDING_WINDOW=(left_window + 1),
        USE_MM_PREFIX=use_mm_prefix,
        MAX_MM_RANGES=max_mm_ranges,
        mm_prefix_range_ptr=mm_prefix_range,
        stride_k_cache_0=k_cache.stride(0),
        stride_k_cache_1=k_cache.stride(1),
        stride_k_cache_2=k_cache.stride(2),
        stride_k_cache_3=k_cache.stride(3),
        stride_v_cache_0=v_cache.stride(0),
        stride_v_cache_1=v_cache.stride(1),
        stride_v_cache_2=v_cache.stride(2),
        stride_v_cache_3=v_cache.stride(3),
        stride_ks_blk=k_scale_cache.stride(0),
        stride_ks_slot=k_scale_cache.stride(1),
        stride_ks_head=k_scale_cache.stride(2),
        stride_vs_blk=v_scale_cache.stride(0),
        stride_vs_slot=v_scale_cache.stride(1),
        stride_vs_head=v_scale_cache.stride(2),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_seqs,
        BLOCK_M=BLOCK_M,
        NUM_SEGMENTS_PER_SEQ=num_segments,
        USE_FP8=use_fp8,
        IS_3D=use_3d,
        PACKING_FACTOR=packing_factor,
    )

    if not use_3d:
        return

    # 3D only: fold the per-segment partial results into ``out``.
    reduce_segments[(num_q_tokens, num_query_heads)](
        output_ptr=out,
        segm_output_ptr=softmax_segm_output,
        segm_max_ptr=softmax_segm_max,
        segm_expsum_ptr=softmax_segm_expsum,
        seq_lens_ptr=seqused_k,
        num_seqs=num_seqs,
        num_query_heads=num_query_heads,
        out_scale_inv=inv_output_scale,
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        block_table_stride=block_table.stride(0),
        TILE_SIZE=TILE_SIZE_DECODE,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=head_size_padded,
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
        USE_FP8=use_fp8,
    )

_reshape_cache_int4_kernel(key_ptr, value_ptr, key_cache_ptr, value_cache_ptr, k_scale_cache_ptr, v_scale_cache_ptr, slot_mapping_ptr, stride_key_tok, stride_key_head, stride_val_tok, stride_val_head, stride_kc_blk, stride_kc_slot, stride_kc_head, stride_vc_blk, stride_vc_slot, stride_vc_head, stride_ks_blk, stride_ks_slot, stride_ks_head, stride_vs_blk, stride_vs_slot, stride_vs_head, block_size, head_size, head_size_v, PACKED_HEAD_PADDED)

INT4 asymmetric quantization with zero-point steganography.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
@triton.jit
def _reshape_cache_int4_kernel(
    key_ptr,
    value_ptr,
    key_cache_ptr,
    value_cache_ptr,
    k_scale_cache_ptr,
    v_scale_cache_ptr,
    slot_mapping_ptr,
    stride_key_tok: tl.int64,
    stride_key_head: tl.int64,
    stride_val_tok: tl.int64,
    stride_val_head: tl.int64,
    stride_kc_blk: tl.int64,
    stride_kc_slot: tl.int64,
    stride_kc_head: tl.int64,
    stride_vc_blk: tl.int64,
    stride_vc_slot: tl.int64,
    stride_vc_head: tl.int64,
    stride_ks_blk: tl.int64,
    stride_ks_slot: tl.int64,
    stride_ks_head: tl.int64,
    stride_vs_blk: tl.int64,
    stride_vs_slot: tl.int64,
    stride_vs_head: tl.int64,
    block_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    PACKED_HEAD_PADDED: tl.constexpr,
):
    """INT4 asymmetric quantization with zero-point steganography.

    One program handles one (token, head) pair: program axis 0 indexes the
    token and axis 1 the head.  For both the key row and the value row it

      1. computes the row's min and max,
      2. derives ``scale = max((max - min) / 15, 1e-6)`` and a zero-point
         ``round(-min / scale)`` clamped to ``[0, 15]``,
      3. quantizes every element to ``clamp(round(x / scale + zp), 0, 15)``,
      4. packs elements ``2i`` (low nibble) and ``2i + 1`` (high nibble)
         into cache byte ``i``,
      5. stores one float32 per (block, slot, head) in the scale cache whose
         low 4 bits have been overwritten with the zero-point.

    Rounding is round-half-away-from-zero, done by adding ±0.5 before the
    truncating cast to int32.  Tokens whose slot-mapping entry is negative
    are skipped.  The head dimension of ``key``/``value`` and of both caches
    is addressed with unit stride (offsets are added without a stride).
    """
    tok = tl.program_id(0)
    head = tl.program_id(1)

    # Negative slot: nothing to write for this token.
    slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
    if slot < 0:
        return

    # Split the flat slot index into (cache block, position inside block).
    blk = slot // block_size
    slot_in_blk = slot % block_size

    # Lane i owns packed byte i, i.e. source elements 2i and 2i + 1.
    half_offs = tl.arange(0, PACKED_HEAD_PADDED)
    even_offs = half_offs * 2
    odd_offs = half_offs * 2 + 1

    # ---------------------------------------------------------------- key
    half_k = head_size // 2
    even_k_mask = even_offs < head_size
    odd_k_mask = odd_offs < head_size
    key_base = key_ptr + tok * stride_key_tok + head * stride_key_head

    k_even = tl.load(key_base + even_offs, mask=even_k_mask, other=0.0).to(tl.float32)
    k_odd = tl.load(key_base + odd_offs, mask=odd_k_mask, other=0.0).to(tl.float32)

    # Row min/max; padded lanes are replaced by ±inf so they cannot win.
    k_min = tl.minimum(
        tl.min(tl.where(even_k_mask, k_even, float("inf"))),
        tl.min(tl.where(odd_k_mask, k_odd, float("inf"))),
    )
    k_max = tl.maximum(
        tl.max(tl.where(even_k_mask, k_even, float("-inf"))),
        tl.max(tl.where(odd_k_mask, k_odd, float("-inf"))),
    )
    # 15 quantization steps span the range; the floor avoids a zero scale
    # for constant rows.
    k_scale = tl.maximum((k_max - k_min) / 15.0, 1e-6)
    # Zero-point = round(-min / scale), clamped into the 4-bit range.  When
    # the row minimum is positive this clamps to 0.
    k_zp_f = tl.clamp(
        tl.where(
            -k_min / k_scale >= 0,
            (-k_min / k_scale + 0.5).to(tl.int32),
            (-k_min / k_scale - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    # Quantize: q = clamp(round(x / scale + zp), 0, 15).
    inv_k = 1.0 / k_scale
    k_even_s = k_even * inv_k + k_zp_f
    k_odd_s = k_odd * inv_k + k_zp_f
    k_even_q = tl.clamp(
        tl.where(
            k_even_s >= 0,
            (k_even_s + 0.5).to(tl.int32),
            (k_even_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )
    k_odd_q = tl.clamp(
        tl.where(
            k_odd_s >= 0,
            (k_odd_s + 0.5).to(tl.int32),
            (k_odd_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    # Steganography: reinterpret the float32 scale as int32, clear its low
    # 4 bits (``& -16``), OR in the 4-bit zero-point, and reinterpret back.
    k_zp_int = k_zp_f.to(tl.int32)
    k_scale_bits = k_scale.to(tl.int32, bitcast=True)
    k_scale_packed = ((k_scale_bits & -16) | (k_zp_int & 0xF)).to(
        tl.float32, bitcast=True
    )

    tl.store(
        k_scale_cache_ptr
        + blk * stride_ks_blk
        + slot_in_blk * stride_ks_slot
        + head * stride_ks_head,
        k_scale_packed,
    )

    # Even element -> low nibble, odd element -> high nibble.  Only the
    # first head_size // 2 bytes are real; the rest is power-of-2 padding.
    k_packed = pack_int4_nibbles(k_even_q.to(tl.uint8), k_odd_q.to(tl.uint8))
    tl.store(
        key_cache_ptr
        + blk * stride_kc_blk
        + slot_in_blk * stride_kc_slot
        + head * stride_kc_head
        + half_offs,
        k_packed,
        mask=half_offs < half_k,
    )

    # -------------------------------------------------------------- value
    # Same procedure as for the key, using the value head size.
    half_v = head_size_v // 2
    even_v_mask = even_offs < head_size_v
    odd_v_mask = odd_offs < head_size_v
    val_base = value_ptr + tok * stride_val_tok + head * stride_val_head

    v_even = tl.load(val_base + even_offs, mask=even_v_mask, other=0.0).to(tl.float32)
    v_odd = tl.load(val_base + odd_offs, mask=odd_v_mask, other=0.0).to(tl.float32)

    v_min = tl.minimum(
        tl.min(tl.where(even_v_mask, v_even, float("inf"))),
        tl.min(tl.where(odd_v_mask, v_odd, float("inf"))),
    )
    v_max = tl.maximum(
        tl.max(tl.where(even_v_mask, v_even, float("-inf"))),
        tl.max(tl.where(odd_v_mask, v_odd, float("-inf"))),
    )
    v_scale = tl.maximum((v_max - v_min) / 15.0, 1e-6)
    v_zp_f = tl.clamp(
        tl.where(
            -v_min / v_scale >= 0,
            (-v_min / v_scale + 0.5).to(tl.int32),
            (-v_min / v_scale - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    inv_v = 1.0 / v_scale
    v_even_s = v_even * inv_v + v_zp_f
    v_odd_s = v_odd * inv_v + v_zp_f
    v_even_q = tl.clamp(
        tl.where(
            v_even_s >= 0,
            (v_even_s + 0.5).to(tl.int32),
            (v_even_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )
    v_odd_q = tl.clamp(
        tl.where(
            v_odd_s >= 0,
            (v_odd_s + 0.5).to(tl.int32),
            (v_odd_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    # Hide the value zero-point in the low 4 bits of the value scale.
    v_zp_int = v_zp_f.to(tl.int32)
    v_scale_bits = v_scale.to(tl.int32, bitcast=True)
    v_scale_packed = ((v_scale_bits & -16) | (v_zp_int & 0xF)).to(
        tl.float32, bitcast=True
    )

    tl.store(
        v_scale_cache_ptr
        + blk * stride_vs_blk
        + slot_in_blk * stride_vs_slot
        + head * stride_vs_head,
        v_scale_packed,
    )

    v_packed = pack_int4_nibbles(v_even_q.to(tl.uint8), v_odd_q.to(tl.uint8))
    tl.store(
        value_cache_ptr
        + blk * stride_vc_blk
        + slot_in_blk * stride_vc_slot
        + head * stride_vc_head
        + half_offs,
        v_packed,
        mask=half_offs < half_v,
    )

_run_reshape_kernel(kernel, *, key, value, key_cache, value_cache, k_scale_cache, v_scale_cache, slot_mapping, packing_factor)

Launch the packed INT4 reshape kernel.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
def _run_reshape_kernel(
    kernel,
    *,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    k_scale_cache: torch.Tensor,
    v_scale_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    packing_factor: int,
) -> None:
    """Launch a packed reshape *kernel* on a ``(num_tokens, num_kv_heads)`` grid.

    ``key`` must be 3-D ``(num_tokens, num_kv_heads, head_size)``; the value
    head size is read from ``value.shape[2]`` and the cache block size from
    ``key_cache.shape[1]``.  Both head sizes must be divisible by
    ``packing_factor``.  The packed width handed to the kernel is the larger
    of the two head sizes divided by ``packing_factor``, rounded up to a
    power of two.
    """
    num_tokens, num_kv_heads, head_size = key.shape
    head_size_v = value.shape[2]
    assert head_size % packing_factor == 0
    assert head_size_v % packing_factor == 0

    widest_packed = max(head_size, head_size_v) // packing_factor
    packed_padded = triton.next_power_of_2(widest_packed)

    # Fixed warp count on ROCm/XPU; elsewhere one warp per 32 packed lanes,
    # kept within [1, 16].
    fixed_warps = current_platform.is_rocm() or current_platform.is_xpu()
    num_warps = 4 if fixed_warps else min(16, max(1, packed_padded // 32))

    kc_blk, kc_slot, kc_head = (key_cache.stride(i) for i in range(3))
    vc_blk, vc_slot, vc_head = (value_cache.stride(i) for i in range(3))
    ks_blk, ks_slot, ks_head = (k_scale_cache.stride(i) for i in range(3))
    vs_blk, vs_slot, vs_head = (v_scale_cache.stride(i) for i in range(3))

    launch_grid = (num_tokens, num_kv_heads)
    kernel[launch_grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        slot_mapping_ptr=slot_mapping,
        stride_key_tok=key.stride(0),
        stride_key_head=key.stride(1),
        stride_val_tok=value.stride(0),
        stride_val_head=value.stride(1),
        stride_kc_blk=kc_blk,
        stride_kc_slot=kc_slot,
        stride_kc_head=kc_head,
        stride_vc_blk=vc_blk,
        stride_vc_slot=vc_slot,
        stride_vc_head=vc_head,
        stride_ks_blk=ks_blk,
        stride_ks_slot=ks_slot,
        stride_ks_head=ks_head,
        stride_vs_blk=vs_blk,
        stride_vs_slot=vs_slot,
        stride_vs_head=vs_head,
        block_size=key_cache.shape[1],
        head_size=head_size,
        head_size_v=head_size_v,
        PACKED_HEAD_PADDED=packed_padded,
        num_warps=num_warps,
    )

fast_hadamard_transform(x)

Unnormalized Walsh-Hadamard Transform along the last dimension.

H_d × x where H_d × H_d = d × I. Last dim must be a power of 2.

Three-tier dispatch:
  1. Hadacore CUDA Tensor Core kernel (sm_80+).
  2. Triton MMA matmul kernel (CUDA fallback + ROCm MFMA/WMMA path).
  3. PyTorch butterfly (CPU and any GPU/dtype combo Triton can't take).
Source code in vllm/v1/attention/ops/int4_per_token_head.py
def fast_hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Unnormalized Walsh-Hadamard Transform along the last dimension.

    H_d × x where H_d × H_d = d × I.  Last dim must be a power of 2.

    Three-tier dispatch:
      1. Hadacore CUDA Tensor Core kernel (sm_80+), for CUDA tensors only.
      2. Triton MMA matmul kernel (CUDA fallback + ROCm MFMA/WMMA path).
      3. PyTorch butterfly (CPU and any GPU/dtype combo Triton can't take).

    Args:
        x: Tensor whose last dimension ``d`` is a power of two.  The input
            is not modified by the GPU tiers (they work on a copy or a cast).

    Returns:
        A tensor of the same shape holding the transform.  When ``d <= 1``
        on the butterfly path the input tensor itself is returned.

    Raises:
        AssertionError: If ``d`` is not a power of two.
    """
    d = x.shape[-1]
    assert d & (d - 1) == 0, f"Requires power-of-2 dim, got {d}"

    # Tier 1 — hadacore on CUDA.  ``_hadacore_available()`` takes no tensor,
    # so the device check has to happen here: a CPU tensor on a CUDA-capable
    # host must fall through to the butterfly instead of reaching the kernel.
    if x.is_cuda and _hadacore_available() and 0 < d <= (1 << 15):
        from vllm import _custom_ops as ops

        # hadacore returns x @ (H/√d); rescale to the unnormalized H × x
        # convention the INT4 scale math is calibrated to.
        rescale = d**0.5
        if x.dtype in (torch.float16, torch.bfloat16):
            # Clone first: the kernel runs in place and must not clobber x.
            y = ops.hadacore_transform(x.contiguous().clone(), inplace=True)
            return y * rescale
        # fp32 → bf16 round-trip; precision loss is irrelevant before
        # INT4 quantization.
        orig_dtype = x.dtype
        x_bf16 = x.contiguous().to(torch.bfloat16)
        y_bf16 = ops.hadacore_transform(x_bf16, inplace=True)
        return y_bf16.to(orig_dtype) * rescale

    # Tier 2 — Triton MMA kernel (covers ROCm via MFMA/WMMA codegen, and
    # also CUDA when hadacore is unavailable).
    if (
        x.is_cuda
        and _TRITON_HADAMARD_MIN_D <= d <= _TRITON_HADAMARD_MAX_D
        and x.dtype in (torch.float16, torch.bfloat16, torch.float32)
    ):
        return _triton_hadamard_transform(x)

    # Tier 3 — PyTorch butterfly (CPU / unsupported dtype / D < 16).
    # Each pass pairs up neighbouring runs of length h and replaces them
    # with their sum and difference; log2(d) passes complete the transform.
    h = 1
    while h < d:
        xv = x.view(*x.shape[:-1], d // (2 * h), 2, h)
        a = xv[..., 0, :]
        b = xv[..., 1, :]
        x = torch.stack([a + b, a - b], dim=-2).reshape(x.shape)
        h <<= 1
    return x

pack_int4_nibbles(lo, hi)

Pack two uint8 values (each in [0, 15]) into one byte.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
@triton.jit
def pack_int4_nibbles(lo, hi):
    """Combine two 4-bit values into one byte: *lo* in bits 0-3, *hi* in bits 4-7.

    Both inputs are masked to their low 4 bits before being combined.
    """
    low_nibble = lo & 0xF
    high_nibble = (hi & 0xF) << 4
    return low_nibble | high_nibble

reshape_and_cache_int4(key, value, key_cache, value_cache, slot_mapping, *, k_scale_cache, v_scale_cache)

Pre-rotate (RHT), pack to INT4 and write into the paged cache.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
def reshape_and_cache_int4(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    *,
    k_scale_cache: torch.Tensor,
    v_scale_cache: torch.Tensor,
) -> None:
    """Rotate, quantize to INT4 and store new keys/values in the paged cache.

    ``key`` and ``value`` are each passed through ``single_rht`` in float32
    and cast back to their original dtype; the INT4 reshape kernel then
    writes the packed data into ``key_cache`` / ``value_cache`` and the
    per-token-head scales into ``k_scale_cache`` / ``v_scale_cache``.
    """
    rotated_key = single_rht(key.float()).to(key.dtype)
    rotated_value = single_rht(value.float()).to(value.dtype)

    launch_kwargs = dict(
        key=rotated_key,
        value=rotated_value,
        key_cache=key_cache,
        value_cache=value_cache,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
        slot_mapping=slot_mapping,
        packing_factor=_INT4_PACKING_FACTOR,
    )
    _run_reshape_kernel(_reshape_cache_int4_kernel, **launch_kwargs)

single_rht(x, inverse=False)

Single Randomized Hadamard Transform: H × D₁ × x.

Used by INT4 per-token-head quantization to gaussianize data before asymmetric quantization.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
def single_rht(x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
    """Apply one Randomized Hadamard Transform round along the last dim.

    Forward (``inverse=False``) computes ``H × D₁ × x``: multiply by the
    round-0 sign vector, then take the Hadamard transform.  With
    ``inverse=True`` the two steps run in the opposite order
    (``D₁ × H × x``).  No normalization is applied in either direction.

    Used by INT4 per-token-head quantization to gaussianize data before
    asymmetric quantization.
    """
    signs = _get_rht_signs(x.shape[-1], 0, x.device)
    if not inverse:
        return fast_hadamard_transform(x * signs)
    return fast_hadamard_transform(x) * signs

unified_attention_int4(q, k_cache, v_cache, out, *, cu_seqlens_q, max_seqlen_q, seqused_k, max_seqlen_k, softmax_scale, window_size, block_table, softcap, sinks, alibi_slopes, use_alibi_sqrt, qq_bias, output_scale, mm_prefix_range, k_scale_cache, v_scale_cache, seq_threshold_3D=None, num_par_softmax_segments=None, softmax_segm_output=None, softmax_segm_max=None, softmax_segm_expsum=None)

Paged attention over the INT4 packed cache, writing into out.

The forward RHT has norm sqrt(head_size), so softmax_scale is divided by head_size and the inverse RHT divides the output by head_size as well.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
def unified_attention_int4(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    *,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    seqused_k: torch.Tensor,
    max_seqlen_k: int,
    softmax_scale: float,
    window_size: tuple[int, int],
    block_table: torch.Tensor,
    softcap: float,
    sinks: torch.Tensor | None,
    alibi_slopes: torch.Tensor | None,
    use_alibi_sqrt: bool,
    qq_bias: torch.Tensor | None,
    output_scale: torch.Tensor | None,
    mm_prefix_range: torch.Tensor | None,
    k_scale_cache: torch.Tensor,
    v_scale_cache: torch.Tensor,
    seq_threshold_3D: int | None = None,
    num_par_softmax_segments: int | None = None,
    softmax_segm_output: torch.Tensor | None = None,
    softmax_segm_max: torch.Tensor | None = None,
    softmax_segm_expsum: torch.Tensor | None = None,
) -> None:
    """Run paged attention against the INT4 packed cache; result goes to *out*.

    The query is rotated with the forward RHT (in float32, then cast back)
    before the packed attention kernel is launched, and the kernel output is
    rotated back with the inverse RHT and copied into *out* in place.

    Because the forward RHT has norm ``sqrt(head_size)``, ``softmax_scale``
    is divided by ``head_size`` and the inverse-rotated output is divided by
    ``head_size`` too.  ``max_seqlen_k`` is accepted but not used here.
    """
    head_size = q.shape[2]
    compute_dtype = q.dtype
    rotated_q = single_rht(q.float()).to(compute_dtype)

    _launch_packed_attn(
        q=rotated_q,
        k_cache=k_cache,
        v_cache=v_cache,
        out=out,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        softmax_scale=softmax_scale / head_size,
        window_size=window_size,
        block_table=block_table,
        softcap=softcap,
        sinks=sinks,
        alibi_slopes=alibi_slopes,
        use_alibi_sqrt=use_alibi_sqrt,
        qq_bias=qq_bias,
        output_scale=output_scale,
        mm_prefix_range=mm_prefix_range,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
        packing_factor=_INT4_PACKING_FACTOR,
    )

    # Undo the rotation carried by the cached values and write back in place.
    unrotated = single_rht(out.float(), inverse=True) / head_size
    out.copy_(unrotated.to(compute_dtype))

unpack_int4_nibbles(packed)

Split one packed byte into the (low, high) nibble pair as uint8.

Source code in vllm/v1/attention/ops/int4_per_token_head.py
@triton.jit
def unpack_int4_nibbles(packed):
    """Return the ``(low, high)`` 4-bit halves of each packed byte.

    The low half is bits 0-3 and the high half is bits 4-7 shifted down.
    """
    low_nibble = packed & 0xF
    high_nibble = (packed >> 4) & 0xF
    return low_nibble, high_nibble