Skip to content

vllm.model_executor.layers.quantization.utils.marlin_utils

Functions:

check_moe_marlin_supports_layer(layer, group_size, allow_tile_padding=False)

Whether the fused MoE Marlin kernel supports layer.

Callers without act-order may pass allow_tile_padding=True. A tile-misaligned intermediate size is then zero-padded to a valid thread tile at weight prep (see marlin_moe_padded_intermediate). The only remaining unsupported case is an intermediate size that is not a multiple of group_size, because a quantization group would then straddle the padded boundary. hidden_size is the MoE input/output extent and is never padded, so it must still be a multiple of 128. With act-order, the strict shape requirements apply.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def check_moe_marlin_supports_layer(
    layer: RoutedExperts, group_size: int, allow_tile_padding: bool = False
) -> bool:
    """Return True when the fused MoE Marlin kernel can run ``layer``.

    The layer is rejected outright on ROCm. Otherwise all of these must hold:

    * ``layer.hidden_size`` is a multiple of 128. hidden_size is the MoE I/O
      extent and is never padded.
    * The unpadded per-partition intermediate size fits the kernel tiles
      (see below).
    * ``group_size`` is one of -1, 32, 64 or 128.
    * ``layer.apply_router_weight_on_input`` is falsy.

    Intermediate-size rule: the strict form, which act-order keeps, is
    ``intermediate % max(64, group_size) == 0``. Callers without act-order
    may pass ``allow_tile_padding=True``. A tile-misaligned intermediate size
    is then zero-padded at weight prep (see marlin_moe_padded_intermediate),
    so only ``intermediate % group_size == 0`` is still required, or nothing
    at all when ``group_size <= 0``.
    """
    if current_platform.is_rocm():
        return False

    hidden = layer.hidden_size
    # The layer has not rounded intermediate_size yet, so read the stable
    # unpadded size. gate-up needs n = 2 * intermediate % 128, and down needs
    # k = intermediate % 64.
    intermediate = layer.moe_config.intermediate_size_per_partition_unpadded
    assert intermediate is not None
    # MoE Marlin cannot apply the router weight on the input.
    router_ok = not layer.apply_router_weight_on_input

    hidden_ok = hidden % 128 == 0
    if allow_tile_padding:
        # Tile alignment is fixed by padding; only whole groups are required.
        shape_ok = hidden_ok and (group_size <= 0 or intermediate % group_size == 0)
    else:
        shape_ok = hidden_ok and intermediate % max(64, group_size) == 0

    group_ok = group_size in [-1, 32, 64, 128]
    return shape_ok and group_ok and router_ok

marlin_moe_intermediate_size(w1_packed, w2_packed)

Given the Marlin-packed weight matrices w1_packed and w2_packed, return the MoE intermediate size N. Only w2_packed is read: N = w2_packed.size(1) * 16.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor):
    """Return the MoE intermediate size N encoded in Marlin-packed weights.

    Only ``w2_packed`` is read. Its dim 1 is taken to hold N // 16, one entry
    per 16-wide Marlin tile, so N is recovered by multiplying by the tile
    width. ``w1_packed`` is accepted for symmetry and is not used.
    """
    tile = 16
    num_tiles = w2_packed.size(1)
    return num_tiles * tile

marlin_moe_padded_intermediate(intermediate_size, group_size=-1)

Smallest MoE intermediate size satisfying the Marlin MoE thread tiles.

The kernel requires 2 * intermediate % 128 == 0 for gate-up and intermediate % 64 == 0 for down; both reduce to intermediate % 64 == 0. A misaligned size is zero-padded at weight prep to the next multiple of lcm(64, group_size), so the group count stays integral. The padded region never reaches the MoE output: w13's padded output channels are zeroed by the zero-padded scales, so the padded inputs to w2 are zero.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_moe_padded_intermediate(intermediate_size: int, group_size: int = -1) -> int:
    """Return the smallest MoE intermediate size accepted by the Marlin MoE tiles.

    Gate-up needs ``2 * intermediate % 128 == 0`` and down needs
    ``intermediate % 64 == 0``; together these reduce to
    ``intermediate % 64 == 0``. The result is also kept a multiple of
    ``group_size`` (when positive) so the number of quantization groups stays
    integral. The padded region never reaches the MoE output: w13's padded
    output channels are zeroed by the zero-padded scales, so the padded
    inputs to w2 are zero.

    A one-time warning is logged whenever padding is actually needed.
    """
    # Non-positive group_size means no group constraint, only the 64 tile.
    alignment = math.lcm(64, group_size) if group_size > 0 else 64
    padded = round_up(intermediate_size, alignment)
    if padded == intermediate_size:
        return padded
    logger.warning_once(
        "Marlin requires thread-tile padding for the MoE intermediate size "
        "of some layers in this model. Padded experts pad/slice activations "
        "on every forward; performance may be degraded."
    )
    return padded

marlin_pad_dim(x, size, padded)

Zero-pad the last dim from size to padded (activations K, bias N).

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_pad_dim(x: torch.Tensor, size: int, padded: int) -> torch.Tensor:
    """Right-pad the last dimension of ``x`` with zeros from ``size`` to ``padded``.

    Used for activation K and bias N extents. When no padding is needed,
    ``x`` itself is returned (no copy).
    """
    extra = padded - size
    if extra == 0:
        return x
    return torch.nn.functional.pad(x, (0, extra))

marlin_pad_qweight(qweight, size_n, size_k, padded_n, padded_k)

Zero-pad a GPTQ-layout packed weight (size_k / pack, size_n) for gptq_marlin_repack.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_pad_qweight(
    qweight: torch.Tensor, size_n: int, size_k: int, padded_n: int, padded_k: int
) -> torch.Tensor:
    """Zero-pad a GPTQ-layout packed weight before gptq_marlin_repack.

    ``qweight`` is laid out as ``(size_k / pack, size_n)``: K is packed into
    rows and N is unpacked. Columns are appended to reach ``padded_n``, and
    packed rows are appended to cover ``padded_k``. ``qweight`` itself is
    returned when both extents are already at their padded values.
    """
    if padded_n == size_n and padded_k == size_k:
        return qweight
    # Infer how many K values share one packed row from the packed height.
    values_per_row = size_k // qweight.size(0)
    extra_cols = padded_n - size_n
    extra_rows = (padded_k - size_k) // values_per_row
    return torch.nn.functional.pad(qweight, (0, extra_cols, 0, extra_rows))

marlin_pad_scales(scales, size_n, size_k, padded_n, padded_k, group_size)

Zero-pad weight scales (num_groups, size_n); call before marlin_permute_scales and pass the padded extents to it.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_pad_scales(
    scales: torch.Tensor,
    size_n: int,
    size_k: int,
    padded_n: int,
    padded_k: int,
    group_size: int,
) -> torch.Tensor:
    """Zero-pad weight scales of shape ``(num_groups, size_n)``.

    Call this before marlin_permute_scales and pass it the padded extents.
    Columns are added up to ``padded_n``. When ``group_size > 0``, rows are
    added so there is one row per group of ``padded_k``; otherwise no rows
    are added. ``scales`` itself is returned when nothing is padded.
    """
    if padded_n == size_n and padded_k == size_k:
        return scales
    extra_rows = 0
    if group_size > 0:
        extra_rows = padded_k // group_size - scales.size(0)
    assert extra_rows >= 0
    return torch.nn.functional.pad(scales, (0, padded_n - size_n, 0, extra_rows))

marlin_padded_nk(size_n, size_k, group_size=-1)

Minimal (padded_n, padded_k) satisfying a Marlin thread-tile family.

Marlin GEMM and repack require either n % 64 == 0 and k % 128 == 0, or n % 128 == 0 and k % 64 == 0. Shapes satisfying neither are zero-padded to whichever family gives the smaller padded area n * k; ties go to the smaller n + k. K stays divisible by group_size so padded scales keep an integral group count. Padded weight regions contribute nothing to the GEMM output: quantized value 0 decodes to 0.0 (FP4/FP8) or is cancelled by the zero-padded scales/zero-points (INT).

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_padded_nk(size_n: int, size_k: int, group_size: int = -1) -> tuple[int, int]:
    """Return the minimal ``(padded_n, padded_k)`` fitting a Marlin thread-tile family.

    Marlin GEMM and repack accept either ``n % 64 == 0 and k % 128 == 0`` or
    ``n % 128 == 0 and k % 64 == 0``. Both candidates are computed, and the
    one with the smaller padded area wins, with ties broken by the smaller
    ``n + k`` and then by the first family. K is also kept divisible by
    ``group_size`` (when positive) so padded scales keep an integral group
    count. Padded weight regions contribute nothing to the GEMM output:
    quantized value 0 decodes to 0.0 (FP4/FP8) or is cancelled by the
    zero-padded scales/zero-points (INT).

    A one-time warning is logged whenever any padding is applied.
    """

    def _align_k(tile_k: int) -> int:
        # Round K up to the tile, and also to whole quantization groups.
        step = math.lcm(tile_k, group_size) if group_size > 0 else tile_k
        return round_up(size_k, step)

    def _cost(nk: tuple[int, int]) -> tuple[int, int]:
        n, k = nk
        return n * k, n + k

    wide_k = (round_up(size_n, 64), _align_k(128))
    wide_n = (round_up(size_n, 128), _align_k(64))
    padded_nk = min(wide_k, wide_n, key=_cost)

    if padded_nk != (size_n, size_k):
        logger.warning_once(
            "Marlin requires thread-tile padding for some weight shapes in "
            "this model. Activations and/or outputs of the padded layers are "
            "padded/sliced on every forward; performance may be degraded."
        )
    return padded_nk

marlin_repacked_nk(qweight, num_bits)

Recover the (size_n, size_k) a Marlin weight was repacked with (including any tile padding) from its packed shape.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_repacked_nk(qweight: torch.Tensor, num_bits: int) -> tuple[int, int]:
    """Recover the ``(size_n, size_k)`` a Marlin weight was repacked with.

    The returned extents include any tile padding applied before repacking.
    Each row of ``qweight`` spans ``GPTQ_MARLIN_TILE`` values of K. Each
    column holds ``32 // num_bits`` packed values, and those are spread over
    ``GPTQ_MARLIN_TILE`` rows of the tile.
    """
    values_per_word = 32 // num_bits
    packed_rows = qweight.size(0)
    packed_cols = qweight.size(1)
    size_k = packed_rows * GPTQ_MARLIN_TILE
    size_n = packed_cols * values_per_word // GPTQ_MARLIN_TILE
    return size_n, size_k

marlin_unpad_output(output, size_n, padded_n)

Strip padded output columns back to the logical N.

TODO: marlin_gemm could instead write the unpadded columns directly into a caller-provided output buffer (its c argument), which would remove this slice copy.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def marlin_unpad_output(
    output: torch.Tensor, size_n: int, padded_n: int
) -> torch.Tensor:
    """Trim padded output columns so the last dim is back to the logical N.

    When ``padded_n == size_n``, ``output`` is returned as-is. Otherwise the
    result is a contiguous copy of the first ``size_n`` columns.

    TODO: marlin_gemm could instead write the un-padded columns directly
    into a caller-provided `c` buffer so this slice copy disappears.
    """
    if padded_n != size_n:
        output = output[..., :size_n].contiguous()
    return output

moe_packed_to_marlin_zero_points(q_zp_packed, size_k, size_n, num_bits, is_a_8bit=False)

Convert compressed-tensors packed zero points to Marlin format.

Unlike AWQ, compressed-tensors uses standard bit packing without interleaving, so we just unpack and apply Marlin permutation directly.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils.py
def moe_packed_to_marlin_zero_points(
    q_zp_packed: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
):
    """Convert per-expert compressed-tensors packed zero points to Marlin format.

    Unlike AWQ, compressed-tensors uses standard bit packing without
    interleaving. Each expert's slice is therefore unpacked with unpack_cols
    and passed straight to marlin_zero_points. The result is written into a
    new tensor with the same leading three dims, dtype and device as
    ``q_zp_packed``.

    NOTE(review): this assumes marlin_zero_points returns a tensor of the same
    shape as one expert slice of ``q_zp_packed``; confirm against that helper.
    """
    num_experts = q_zp_packed.shape[0]
    converted = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        dtype=q_zp_packed.dtype,
        device=q_zp_packed.device,
    )
    for expert_idx, expert_zp in enumerate(q_zp_packed):
        unpacked = unpack_cols(expert_zp, num_bits, size_k, size_n)
        converted[expert_idx] = marlin_zero_points(
            unpacked, size_k, size_n, num_bits, is_a_8bit
        )
    return converted