vllm.utils.deep_gemm
Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import only these wrappers.
Classes:
Functions:
- calc_diff – Return a global difference metric for unit tests.
- fp8_fp4_mqa_logits – Compute MQA logits for a single sequence without KV paging.
- fp8_fp4_paged_mqa_logits – Compute MQA logits using a paged KV-cache.
- get_col_major_tma_aligned_tensor – Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor.
- get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor – Grouped (3D, expert-batched) variant of get_mn_major_tma_aligned_packed_ue8m0_tensor.
- get_mn_major_tma_aligned_packed_ue8m0_tensor – Pack UE8M0 (uint8) → int32 with the MN-major TMA-aligned layout the DeepGEMM kernels consume directly.
- get_paged_mqa_logits_metadata – Build scheduling metadata for paged MQA logits.
- get_theoretical_mk_alignment_for_contiguous_layout – Per-call optimal M alignment for grouped contiguous GEMMs.
- is_deep_gemm_e8m0_used – Return True if vLLM is configured to use DeepGEMM E8M0 scales on a Hopper or Blackwell-class GPU.
- is_deep_gemm_supported – Return True if DeepGEMM is supported on the current platform.
- pack_ue8m0_to_int – Pack 4 UE8M0 (uint8) scales into one int32.
DeepGemmQuantScaleFMT
Bases: Enum
Methods:
- from_oracle – Return the pre-initialized oracle decision.
- init_oracle_cache – Initialize the oracle decision and store it in the class cache.
Source code in vllm/utils/deep_gemm.py
from_oracle() classmethod
Return the pre-initialized oracle decision.
init_oracle_cache() classmethod
Initialize the oracle decision and store it in the class cache.
_import_deep_gemm() cached
Import the deep_gemm module.
Prefers an externally installed deep_gemm package (so users can pin a specific version), then falls back to the vendored copy bundled in the vLLM wheel.
Returns None when neither source is usable.
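The preference order above (external package first, vendored copy second, None otherwise) can be sketched with importlib. This is a minimal illustration, not vLLM's code: the helper name import_first_available is hypothetical, and it treats "usable" as "imports without ImportError", whereas the real helper may apply further checks.

```python
import functools
import importlib
from types import ModuleType
from typing import Optional


@functools.cache  # resolve once per argument tuple, mirroring the "cached" decorator
def import_first_available(*module_names: str) -> Optional[ModuleType]:
    """Return the first module in module_names that imports cleanly, else None."""
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError:  # includes ModuleNotFoundError
            continue  # fall through to the next candidate (e.g. a vendored copy)
    return None
```

A call such as import_first_available("deep_gemm", "<vendored module path>") would try the externally installed package first; the vendored path is a placeholder here.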
_lazy_init()
Import deep_gemm and resolve symbols on first use.
_missing(*_, **__)
Placeholder for unavailable DeepGEMM backend.
calc_diff(x, y)
Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves, this helper can be removed.
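The exact similarity formula is not spelled out above. A common cosine-style choice, assumed here, is sim = 2·Σxy / Σ(x² + y²), which equals 1 only when x and y are identical and is at most 1 otherwise, so 1 - sim is a non-negative global error. The sketch works on flat Python lists rather than torch tensors and is illustrative, not the vLLM implementation.

```python
from typing import Sequence


def calc_diff_sketch(x: Sequence[float], y: Sequence[float]) -> float:
    """Return 1 - 2*<x, y> / (|x|^2 + |y|^2) over the flattened inputs."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of elements")
    dot = sum(a * b for a, b in zip(x, y))
    denom = sum(a * a + b * b for a, b in zip(x, y))
    if denom == 0:
        return 0.0  # both inputs are all zeros: treat them as identical
    sim = 2.0 * dot / denom  # <= 1 because 2ab <= a^2 + b^2 element-wise
    return 1.0 - sim
```

A test would then assert that the result stays below a small threshold instead of comparing every element.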
fp8_fp4_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits)
Compute MQA logits for a single sequence without KV paging.
Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes q = (values, scales_or_None) where scales is None for FP8 Q (per-token scale is folded into weights) and a packed block-scale tensor for MXFP4 Q.
Parameters:
- q (tuple[Tensor, Tensor | None]) – Tuple (q_values, q_scale). FP8 path: q_values is [M, H, D] float8_e4m3fn and q_scale is None (the per-token scale is folded into weights). FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor.
- kv (tuple[Tensor, Tensor]) – Tuple (k_packed, k_scales). The FP8 layout is [N, D] float8_e4m3fn plus fp32 scales [N]; the FP4 layout is packed uint8.
- weights (Tensor) – Weights of shape [M, H], dtype torch.float32.
- cu_seqlen_ks (Tensor) – Start indices (inclusive) of valid K for each query position, shape [M], dtype int32.
- cu_seqlen_ke (Tensor) – End indices (exclusive) of valid K for each query position, shape [M], dtype int32.
- clean_logits (bool) – Whether to set unfilled logits to -inf.
Returns:
- Tensor – Logits tensor of shape [M, N], dtype torch.float32.
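Assuming "unfilled" logits are the positions outside each query's valid key window, the effect of clean_logits=True can be sketched on plain lists: for query row m, only key columns n with cu_seqlen_ks[m] <= n < cu_seqlen_ke[m] keep their value. The helper name is hypothetical, and the kernel's own scoring is not reproduced.

```python
import math
from typing import List, Sequence


def clean_logits_sketch(
    logits: List[List[float]],
    cu_seqlen_ks: Sequence[int],
    cu_seqlen_ke: Sequence[int],
) -> List[List[float]]:
    """Return a copy of logits with entries outside [ks[m], ke[m]) set to -inf."""
    cleaned = []
    for m, row in enumerate(logits):
        ks, ke = cu_seqlen_ks[m], cu_seqlen_ke[m]  # inclusive start, exclusive end
        cleaned.append([v if ks <= n < ke else -math.inf for n, v in enumerate(row)])
    return cleaned
```

Filling with -inf rather than 0 means a later softmax or top-k over each row ignores the invalid keys.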
fp8_fp4_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits)
Compute MQA logits using a paged KV-cache.
Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes q = (values, scales_or_None); pass (q_tensor, None) for the FP8 path and (q_values, q_scale) for MXFP4.
Parameters:
- q (tuple[Tensor, Tensor | None]) – Tuple (q_values, q_scale). FP8 path: q_values is [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor.
- kv_cache (Tensor) – Paged KV-cache. The FP8 layout is [num_blocks, block_size, 1, D+4], dtype torch.uint8, where the last 4 bytes of each (block, pos) entry store the float dequant scale.
- weights (Tensor) – Tensor of shape [B * next_n, H], dtype torch.float32.
- context_lens (Tensor) – Tensor of shape [B], dtype int32; effective context length for each batch element.
- block_tables (Tensor) – Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache.
- schedule_metadata (Tensor) – Returned by get_paged_mqa_logits_metadata; used to distribute work across SMs.
- max_model_len (int) – Maximum sequence length, used to size the logits output.
- clean_logits (bool) – Whether to set unfilled logits to -inf.
Returns:
- Tensor – Logits tensor of shape [B * next_n, max_model_len].
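The paged addressing and the D+4 byte FP8 row can be sketched on plain Python values. The sketch treats each (block, pos) entry as D payload bytes followed by a float32 scale, as the kv_cache description suggests; little-endian byte order and all helper names are assumptions, and the kernel may organize bytes differently in memory.

```python
import struct
from typing import List, Tuple


def locate_token(block_tables: List[List[int]], batch: int, pos: int,
                 block_size: int) -> Tuple[int, int]:
    """Map token position pos of sequence batch to (physical_block, offset)."""
    logical_block, offset = divmod(pos, block_size)
    return block_tables[batch][logical_block], offset


def pack_fp8_row(fp8_payload: bytes, scale: float) -> bytes:
    """Build one D+4 byte entry: D FP8 bytes, then a float32 scale (little-endian assumed)."""
    return bytes(fp8_payload) + struct.pack("<f", scale)


def unpack_fp8_row(row: bytes, head_dim: int) -> Tuple[bytes, float]:
    """Split a D+4 byte entry back into (FP8 payload, dequant scale)."""
    if len(row) != head_dim + 4:
        raise ValueError(f"expected {head_dim + 4} bytes, got {len(row)}")
    (scale,) = struct.unpack("<f", row[head_dim:])
    return row[:head_dim], scale
```

For example, with block_size = 64, token 70 of a sequence sits in its second logical block at offset 6, and block_tables supplies the physical block that holds it.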
get_col_major_tma_aligned_tensor(x)
Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor.
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks, gran_k)
Grouped (3D, expert-batched) variant of get_mn_major_tma_aligned_packed_ue8m0_tensor. Use for MoE weight scale tensors of shape (num_experts, mn, k_scale).
get_mn_major_tma_aligned_packed_ue8m0_tensor(x)
Pack UE8M0 (uint8) → int32 with the MN-major TMA-aligned layout the DeepGEMM kernels consume directly. 16× smaller than the fp32 legacy SF format. Use for non-grouped 2D scale tensors.
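Before packing, each scale is a single UE8M0 byte: an unsigned 8-bit exponent with no mantissa that decodes to 2^(e - 127), with 0xFF reserved for NaN (per the OCP Microscaling format). Because fp32 uses the same exponent bias of 127, a power-of-two fp32 scale converts to UE8M0 by keeping its exponent field. The helpers below are a hypothetical illustration of that per-scale conversion, not part of vLLM's API; the MN-major TMA-aligned layout itself is not modeled.

```python
import struct


def fp32_pow2_to_ue8m0(scale: float) -> int:
    """Encode a positive, normal power-of-two fp32 scale as a UE8M0 byte."""
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    exponent = (bits >> 23) & 0xFF  # biased fp32 exponent field
    mantissa = bits & 0x7FFFFF
    if bits >> 31 or mantissa or exponent in (0, 0xFF):
        raise ValueError(f"{scale!r} is not a positive normal power of two")
    return exponent  # UE8M0 and fp32 share the same bias of 127


def ue8m0_to_float(e: int) -> float:
    """Decode a UE8M0 byte: 2**(e - 127), with 0xFF reserved for NaN."""
    if not 0 <= e <= 0xFF:
        raise ValueError("UE8M0 values are single bytes")
    return float("nan") if e == 0xFF else 2.0 ** (e - 127)
```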
get_paged_mqa_logits_metadata(context_lens, block_size, num_sms)
Build scheduling metadata for paged MQA logits.
Parameters:
- context_lens (Tensor) – Tensor of shape [B], dtype int32; effective context length per batch element.
- block_size (int) – KV-cache block size in tokens (e.g., 64).
- num_sms (int) – Number of SMs available (e.g., 132 on Hopper).
Returns:
- Tensor – Backend-specific tensor consumed by fp8_fp4_paged_mqa_logits to schedule work across SMs.
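The returned metadata is backend-specific, so the following is only a hypothetical illustration of the underlying idea: count the KV-cache blocks each sequence needs, then cut the concatenated block list into num_sms near-equal slices described by (batch_idx, block_idx) start points. The function name and output format are assumptions, not DeepGEMM's actual metadata layout.

```python
from typing import List, Sequence, Tuple


def split_blocks_across_sms(
    context_lens: Sequence[int], block_size: int, num_sms: int
) -> List[Tuple[int, int]]:
    """Return num_sms + 1 (batch_idx, block_idx) boundaries; SM i covers
    the blocks from boundary i up to (but excluding) boundary i + 1."""
    blocks = [-(-n // block_size) for n in context_lens]  # ceil(len / block_size)
    total = sum(blocks)
    starts = []
    for sm in range(num_sms + 1):
        remaining = total * sm // num_sms  # global index of this SM's first block
        b = 0
        while b < len(blocks) and remaining >= blocks[b]:
            remaining -= blocks[b]  # skip whole sequences before the boundary
            b += 1
        starts.append((b, remaining))
    return starts
```

Balancing on block counts rather than on sequences keeps one long context from pinning a single SM while the others sit idle.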
get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None, num_groups=None)
Per-call optimal M alignment for grouped contiguous GEMMs.
expected_m is the total number of routed tokens (summed across experts, typically M × num_topk), and num_groups is the number of experts on this rank. The helper divides the two to recover the per-expert em, then picks an alignment from data-driven thresholds (see the comments in DeepGEMM's runtime.hpp).
For older callers that omit num_groups, expected_m is interpreted as an already per-expert em (legacy behaviour, preserved for backward compatibility).
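The expected_m / num_groups split can be made concrete. With M = 512 tokens each routed to num_topk = 8 experts, expected_m = 512 × 8 = 4096; spread over num_groups = 64 local experts, that gives em = 64 tokens per expert. The sketch covers only this step: the hypothetical name per_expert_m and the rounding (ceiling, floored at 1) are assumptions, and the data-driven alignment thresholds are not reproduced.

```python
from typing import Optional


def per_expert_m(expected_m: int, num_groups: Optional[int] = None) -> int:
    """Recover the per-expert token count em from the call arguments."""
    if num_groups is None:
        return expected_m  # legacy callers already pass a per-expert em
    return max(1, -(-expected_m // num_groups))  # ceil division, at least 1 (assumed)
```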
is_deep_gemm_e8m0_used() cached
Return True if vLLM is configured to use DeepGEMM E8M0 scales on a Hopper or Blackwell-class GPU.
is_deep_gemm_supported() cached
Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.
mk_alignment_scope(value)
Temporarily set DeepGEMM's BLOCK_M cap, restoring on exit.
Use around a sequence of grouped-contiguous GEMM calls whose workspace is padded to value (typically the per_call_align returned by compute_aligned_M_and_alignment).
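The set-then-restore behaviour can be sketched with contextlib.contextmanager. The module-level _block_m_cap and its getter/setter are hypothetical stand-ins for DeepGEMM's global setting; the sketch assumes the previous value can be read back before it is overwritten.

```python
import contextlib
from typing import Iterator, Optional

_block_m_cap: Optional[int] = None  # hypothetical stand-in for DeepGEMM's global cap


def set_block_m_cap(value: Optional[int]) -> None:
    global _block_m_cap
    _block_m_cap = value


def get_block_m_cap() -> Optional[int]:
    return _block_m_cap


@contextlib.contextmanager
def block_m_cap_scope(value: int) -> Iterator[None]:
    """Apply the cap for the duration of the with-block, then restore the old one."""
    previous = get_block_m_cap()
    set_block_m_cap(value)
    try:
        yield
    finally:
        set_block_m_cap(previous)  # restored even if a GEMM call raises
```

Restoring in a finally clause keeps one call site's cap from leaking into unrelated GEMMs that run later in the same process.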
pack_ue8m0_to_int(x)
Pack 4 UE8M0 (uint8) scales into one int32.
DeepGEMM's SM100/SM120 FP8/FP4 kernels accept either float32 scales (legacy format, 4 B/scale) or int32 packed UE8M0 scales (1 B/scale after 4:1 packing — 4× smaller than the legacy fp32 representation).
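The 4:1 packing can be sketched with struct on Python ints. The byte order (first scale in the least-significant byte) and the signed int32 view are assumptions here, and the helper names are hypothetical; the real function operates on tensors.

```python
import struct
from typing import Sequence, Tuple


def pack4_ue8m0(scales: Sequence[int]) -> int:
    """Pack four UE8M0 bytes into one signed int32 (scales[0] in the lowest byte)."""
    if len(scales) != 4:
        raise ValueError("expected exactly 4 UE8M0 scales")
    return struct.unpack("<i", bytes(scales))[0]  # bytes() rejects values outside 0..255


def unpack4_ue8m0(word: int) -> Tuple[int, ...]:
    """Inverse of pack4_ue8m0: split a signed int32 back into four bytes."""
    return tuple(struct.pack("<i", word))
```

Four one-byte scales now occupy the same 4 bytes that a single fp32 scale did.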
set_mk_alignment_for_contiguous_layout(value)
Set DeepGEMM's BLOCK_M cap for grouped contiguous GEMMs.
DeepGEMM's heuristic constrains BLOCK_M ≤ this value when picking a kernel layout. Use this together with the per-call alignment from compute_aligned_M_and_alignment so that the workspace's per-expert padding matches the kernel's BLOCK_M. On a mismatch, the scheduler reads the wrong expert_id from m_indices (which it samples at m_block_idx * BLOCK_M) and indexes the B-weights tensor out of bounds, which manifests as an illegal memory access (IMA) under CUDA-graph replay.
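The failure mode above can be reproduced on a toy contiguous layout. The sketch assumes each expert's rows are padded up to a multiple of the alignment and marks padding rows with -1; both the helpers and that padding convention are hypothetical. It checks whether every BLOCK_M-row tile covers rows of at most one expert, which is what reading a single expert_id per tile relies on.

```python
from typing import List, Sequence


def build_m_indices(tokens_per_expert: Sequence[int], alignment: int) -> List[int]:
    """Contiguous layout: expert e's rows, padded up to a multiple of alignment."""
    m_indices: List[int] = []
    for expert_id, n in enumerate(tokens_per_expert):
        padded = -(-n // alignment) * alignment  # round n up to the alignment
        m_indices += [expert_id] * n + [-1] * (padded - n)  # -1 marks padding rows
    return m_indices


def tiles_are_single_expert(m_indices: Sequence[int], block_m: int) -> bool:
    """True if each block_m-row tile contains rows from at most one expert."""
    for start in range(0, len(m_indices), block_m):
        experts = {e for e in m_indices[start:start + block_m] if e >= 0}
        if len(experts) > 1:
            return False
    return True
```

With 3 and 5 tokens padded to an alignment of 4, tiles of BLOCK_M = 4 each belong to one expert; with BLOCK_M = 8 the first tile mixes experts 0 and 1, so a kernel reading only that tile's first expert_id would apply expert 0's weights to expert 1's tokens.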
should_auto_disable_deep_gemm(model_type)
Check whether DeepGEMM should be auto-disabled for this model on Blackwell.
Returns True if the model is known to have accuracy degradation with DeepGEMM's E8M0 scale format on Blackwell GPUs (SM100+).
tf32_hc_prenorm_gemm(x, fn, out, sqrsum, num_split)
Perform the following computation:
out = x.float() @ fn.T
sqrsum = x.float().square().sum(-1)
See the caller for shape requirements.
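A plain-Python reference for the two outputs can help when checking results, assuming x is [M, K] and fn is [N, K] (so out is [M, N] and sqrsum is [M]). The hypothetical prenorm_gemm_reference computes in Python floats (double precision) rather than TF32, returns new lists instead of writing into preallocated out/sqrsum buffers, and ignores num_split, whose effect is not described here.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def prenorm_gemm_reference(x: Matrix, fn: Matrix) -> Tuple[List[List[float]], List[float]]:
    """Return (out, sqrsum) with out = x @ fn.T and sqrsum = (x ** 2).sum(-1)."""
    out = [[sum(a * b for a, b in zip(row, w)) for w in fn] for row in x]  # [M, N]
    sqrsum = [sum(a * a for a in row) for row in x]  # [M]: per-row sum of squares
    return out, sqrsum
```

Fusing the sum of squares into the GEMM lets a caller derive a per-row norm from sqrsum without a second pass over x.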