Skip to content

vllm.model_executor.layers.quantization.utils.marlin_utils

Functions:

marlin_moe_intermediate_size(w1_packed, w2_packed)

Given the Marlin-packed weight matrices w1_packed and w2_packed, return the MoE intermediate size N.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor):
    """Return the MoE intermediate size N for Marlin-packed expert weights.

    The value is derived solely from ``w2_packed``: its dim-1 extent
    multiplied by the Marlin tile size (16). ``w1_packed`` is accepted for
    interface symmetry but is not read here.
    """
    # Dim 1 of the packed w2 counts tiles, so scale by the tile edge (16).
    tiles_in_dim1 = w2_packed.size(1)
    return tiles_in_dim1 * 16

marlin_pad_dim(x, size, padded)

Zero-pad the last dim from size to padded (activations K, bias N).

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_pad_dim(x: torch.Tensor, size: int, padded: int) -> torch.Tensor:
    """Right-pad the last dimension of ``x`` with zeros.

    ``padded - size`` zero entries are appended to the last dim (used for
    activations along K and bias along N). When no padding is required the
    input tensor itself is returned, without a copy.
    """
    extra = padded - size
    if extra == 0:
        return x
    # F.pad's spec is (left, right) for the last dim; only the right grows.
    return torch.nn.functional.pad(x, (0, extra))

marlin_pad_qweight(qweight, size_n, size_k, padded_n, padded_k)

Zero-pad a GPTQ-layout packed weight (size_k / pack, size_n) for gptq_marlin_repack.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_pad_qweight(
    qweight: torch.Tensor, size_n: int, size_k: int, padded_n: int, padded_k: int
) -> torch.Tensor:
    """Zero-pad a GPTQ-layout packed weight ahead of gptq_marlin_repack.

    ``qweight`` is laid out as (size_k / pack, size_n). Columns are extended
    from ``size_n`` to ``padded_n``; rows are extended by the K padding
    expressed in packed rows. The input is returned untouched when neither
    extent changes.
    """
    if padded_n == size_n and padded_k == size_k:
        return qweight
    # Number of logical K entries stored in each packed row.
    k_per_row = size_k // qweight.size(0)
    extra_cols = padded_n - size_n
    extra_rows = (padded_k - size_k) // k_per_row
    # Pad spec order: (last-dim left, last-dim right, dim-0 top, dim-0 bottom).
    return torch.nn.functional.pad(qweight, (0, extra_cols, 0, extra_rows))

marlin_pad_scales(scales, size_n, size_k, padded_n, padded_k, group_size)

Zero-pad weight scales (num_groups, size_n); call before marlin_permute_scales and pass the padded extents to it.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_pad_scales(
    scales: torch.Tensor,
    size_n: int,
    size_k: int,
    padded_n: int,
    padded_k: int,
    group_size: int,
) -> torch.Tensor:
    """Zero-pad weight scales shaped (num_groups, size_n).

    Intended to run before marlin_permute_scales, which should then be given
    the padded extents. Columns grow from ``size_n`` to ``padded_n``. Rows
    grow to ``padded_k // group_size`` when ``group_size`` is positive; for a
    non-positive ``group_size`` no rows are added. The input is returned
    as-is when neither N nor K is padded.
    """
    if padded_n == size_n and padded_k == size_k:
        return scales
    if group_size > 0:
        # One scale row per K-group: add rows for groups introduced by K padding.
        extra_groups = padded_k // group_size - scales.size(0)
    else:
        extra_groups = 0
    assert extra_groups >= 0
    extra_cols = padded_n - size_n
    return torch.nn.functional.pad(scales, (0, extra_cols, 0, extra_groups))

marlin_padded_nk(size_n, size_k, group_size=-1)

Minimal (padded_n, padded_k) satisfying a Marlin thread-tile family.

Marlin GEMM and repack require either (n % 64 == 0 and k % 128 == 0) or (n % 128 == 0 and k % 64 == 0); shapes satisfying neither are zero-padded up to whichever family is cheaper. K stays divisible by group_size so that the padded scales keep an integral group count. Padded weight regions contribute nothing to the GEMM output: a quantized value of 0 decodes to 0.0 (FP4/FP8) or is cancelled by the zero-padded scales/zero-points (INT).

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_padded_nk(size_n: int, size_k: int, group_size: int = -1) -> tuple[int, int]:
    """Pick the smallest (padded_n, padded_k) that fits a Marlin tile family.

    Two thread-tile families are considered: N a multiple of 64 with K a
    multiple of 128, or N a multiple of 128 with K a multiple of 64. In both
    cases K is additionally rounded to a multiple of ``group_size`` (when
    positive) so the padded scales keep an integral group count. The
    candidate with the smaller padded area wins, ties broken by the smaller
    N + K and then by the 64/128 family. Padded weight regions contribute
    nothing to the GEMM output: quantized value 0 decodes to 0.0 (FP4/FP8)
    or is cancelled by the zero-padded scales/zero-points (INT).

    A one-time warning is logged whenever the chosen shape differs from the
    input shape.
    """

    def _cost(shape: tuple[int, int]) -> tuple[int, int]:
        # Primary criterion: padded area; secondary: perimeter-like n + k.
        n, k = shape
        return n * k, n + k

    # A non-positive group size imposes no extra divisibility on K.
    k_group = max(group_size, 1)
    n64_k128 = (round_up(size_n, 64), round_up(size_k, math.lcm(128, k_group)))
    n128_k64 = (round_up(size_n, 128), round_up(size_k, math.lcm(64, k_group)))
    # min() keeps the first candidate on a full tie, i.e. the 64/128 family.
    chosen = min((n64_k128, n128_k64), key=_cost)

    if chosen == (size_n, size_k):
        return chosen

    logger.warning_once(
        "Marlin requires thread-tile padding for some weight shapes in "
        "this model. Activations and/or outputs of the padded layers are "
        "padded/sliced on every forward; performance may be degraded."
    )
    return chosen

marlin_repacked_nk(qweight, num_bits)

Recover the (size_n, size_k) a Marlin weight was repacked with (including any tile padding) from its packed shape.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_repacked_nk(qweight: torch.Tensor, num_bits: int) -> tuple[int, int]:
    """Infer the (size_n, size_k) used when ``qweight`` was Marlin-repacked.

    The result is computed purely from the packed shape, so it includes any
    tile padding that was applied before repacking.
    """
    packed_rows, packed_cols = qweight.size(0), qweight.size(1)
    # 32 // num_bits quantized values per packed element.
    values_per_word = 32 // num_bits
    n = packed_cols * values_per_word // GPTQ_MARLIN_TILE
    k = packed_rows * GPTQ_MARLIN_TILE
    return n, k

marlin_unpad_output(output, size_n, padded_n)

Strip padded output columns back to the logical N.

TODO: marlin_gemm could instead write the un-padded columns directly into a caller-provided c buffer so this slice copy disappears.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_unpad_output(
    output: torch.Tensor, size_n: int, padded_n: int
) -> torch.Tensor:
    """Drop the padded trailing columns so the last dim is the logical N.

    Returns ``output`` itself when no padding was applied; otherwise returns
    a contiguous copy of the first ``size_n`` entries along the last dim.

    TODO: marlin_gemm could instead write the un-padded columns directly
    into a caller-provided `c` buffer so this slice copy disappears.
    """
    if padded_n != size_n:
        # The slice is a strided view; .contiguous() materialises a copy.
        output = output[..., :size_n].contiguous()
    return output

moe_packed_to_marlin_zero_points(q_zp_packed, size_k, size_n, num_bits, is_a_8bit=False)

Convert compressed-tensors packed zero points to Marlin format.

Unlike AWQ, compressed-tensors uses standard bit packing without interleaving, so we just unpack and apply Marlin permutation directly.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def moe_packed_to_marlin_zero_points(
    q_zp_packed: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
):
    """Convert compressed-tensors packed zero points to Marlin format.

    Compressed-tensors uses standard bit packing without the interleaving
    that AWQ applies, so each expert's zero points only need to be unpacked
    and run through the Marlin zero-point conversion.

    The leading dim of ``q_zp_packed`` indexes experts. The returned tensor
    is allocated with the same first three extents, dtype and device as the
    input, and is filled one expert at a time.
    """
    num_experts = q_zp_packed.shape[0]
    result_shape = (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2])
    marlin_zp = torch.empty(
        result_shape, dtype=q_zp_packed.dtype, device=q_zp_packed.device
    )
    for expert_idx in range(num_experts):
        # Unpack this expert's columns, then convert to the Marlin layout.
        unpacked = unpack_cols(q_zp_packed[expert_idx], num_bits, size_k, size_n)
        marlin_zp[expert_idx] = marlin_zero_points(
            unpacked, size_k, size_n, num_bits, is_a_8bit
        )
    return marlin_zp