Skip to content

vllm.models.deepseek_v4.nvidia.ops.o_proj

compute_fp8_einsum_recipe

compute_fp8_einsum_recipe() -> tuple[
    tuple[int, int, int], bool
]

fp8_einsum recipe + scale layout for the current GPU arch.

SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128. SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1.

Returns (einsum_recipe, tma_aligned_scales) for deep_gemm_fp8_o_proj.

Source code in vllm/models/deepseek_v4/nvidia/ops/o_proj.py
def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]:
    """Pick the fp8_einsum recipe and scale layout for the current GPU arch.

    The decision is keyed on the device's major compute capability:

    * major <= 9 (SM90): FP32 block scales stay [g, r/128, d/128]
      → sfb_gran_mn=128; the recipe is ``(1, 128, 128)``.
    * major >= 10 (SM100): INT32 packed scales become [g, r, ...]
      → sfb_gran_mn=1; the recipe is ``(1, 1, 128)``.

    Returns:
        ``(einsum_recipe, tma_aligned_scales)``, meant to be forwarded to
        ``deep_gemm_fp8_o_proj``. ``tma_aligned_scales`` is True only when
        the major capability is >= 10.
    """
    device_cap = current_platform.get_device_capability()
    assert device_cap is not None, "DeepseekV4 attention requires a CUDA device"

    # One predicate drives both outputs: the packed-scale layout and the
    # finer-grained recipe always go together.
    uses_packed_scales = device_cap.major >= 10
    if uses_packed_scales:
        recipe = (1, 1, 128)
    else:
        recipe = (1, 128, 128)
    return recipe, uses_packed_scales

deep_gemm_fp8_o_proj

deep_gemm_fp8_o_proj(
    o: Tensor,
    positions: Tensor,
    cos_sin_cache: Tensor,
    wo_a: Module,
    wo_b: Module,
    *,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int,
    rope_dim: int,
    o_lora_rank: int,
    einsum_recipe: tuple[int, int, int],
    tma_aligned_scales: bool,
) -> Tensor

O projection: inverse RoPE + FP8 quant + einsum + wo_b.

Shared by the FlashMLA and FlashInfer CUDA backends. einsum_recipe / tma_aligned_scales come from compute_fp8_einsum_recipe.

Source code in vllm/models/deepseek_v4/nvidia/ops/o_proj.py
def deep_gemm_fp8_o_proj(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    wo_a: nn.Module,
    wo_b: nn.Module,
    *,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int,
    rope_dim: int,
    o_lora_rank: int,
    einsum_recipe: tuple[int, int, int],
    tma_aligned_scales: bool,
) -> torch.Tensor:
    """Run the O projection: inverse RoPE + FP8 quant, einsum, then wo_b.

    Steps, in order:

    1. ``fused_inv_rope_fp8_quant`` processes ``o`` (inverse RoPE and FP8
       quantization) and returns the quantized activations plus scales.
    2. ``fp8_einsum`` evaluates ``"bhr,hdr->bhd"`` between those activations
       and ``wo_a``'s ``weight`` / ``weight_scale_inv``, writing into a
       bfloat16 buffer of shape ``(o.shape[0], n_groups, o_lora_rank)``.
    3. ``wo_b`` is applied to that buffer flattened from dim 1 onward.

    Shared by the FlashMLA and FlashInfer CUDA backends. ``einsum_recipe``
    and ``tma_aligned_scales`` are expected to come from
    ``compute_fp8_einsum_recipe``.

    Returns:
        Whatever ``wo_b`` returns for the flattened einsum output.
    """
    act_fp8, act_scale = fused_inv_rope_fp8_quant(
        o,
        positions,
        cos_sin_cache,
        n_groups=n_groups,
        heads_per_group=heads_per_group,
        nope_dim=nope_dim,
        rope_dim=rope_dim,
        tma_aligned_scales=tma_aligned_scales,
    )

    # Output buffer for the einsum; fp8_einsum fills it in place.
    num_rows = o.shape[0]
    einsum_out = torch.empty(
        (num_rows, n_groups, o_lora_rank),
        dtype=torch.bfloat16,
        device=o.device,
    )

    activation_operand = (act_fp8, act_scale)
    weight_operand = (wo_a.weight, wo_a.weight_scale_inv)
    fp8_einsum(
        "bhr,hdr->bhd",
        activation_operand,
        weight_operand,
        einsum_out,
        recipe=einsum_recipe,
    )

    # Collapse the (group, rank) dims into one before the final projection.
    flattened = einsum_out.flatten(1)
    return wo_b(flattened)