Skip to content

vllm.models.deepseek_v32.nvidia.fused_ops

Fused ops for deepseek_v32 (eager / breakable-cudagraph path).

These recover fusions that vLLM's torch.compile passes would normally do but that don't fire when running eager under the breakable CUDA graph.

Functions:

fused_allreduce_rms_norm(hidden_states, residual, norm)

All-reduce + add residual + (standard) RMSNorm, fused via flashinfer.

hidden_states is the per-rank partial output of a row-parallel linear run with reduce_results=False; norm is the RMSNorm applied right after. Returns (normed_output, new_residual), equivalent to norm(all_reduce(hidden_states), residual). Falls back to an explicit all-reduce + RMSNorm when the flashinfer fast path is unavailable.

Source code in vllm/models/deepseek_v32/nvidia/fused_ops.py
def fused_allreduce_rms_norm(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    norm: RMSNorm,
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-reduce + add residual + (standard) RMSNorm, fused via flashinfer.

    ``hidden_states`` is the per-rank *partial* output of a row-parallel linear
    run with ``reduce_results=False``; ``norm`` is the RMSNorm applied right
    after. Returns ``(normed_output, new_residual)``, equivalent to
    ``norm(all_reduce(hidden_states), residual)``. Falls back to an explicit
    all-reduce + RMSNorm when the flashinfer fast path is unavailable.
    """
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        return norm(hidden_states, residual)

    if flashinfer_trtllm_fused_allreduce_norm is not None:
        ok, max_token_num = _can_use_flashinfer(hidden_states, tp_size)
        if ok:
            norm_out = torch.empty_like(hidden_states)
            # With norm_out provided, the kernel writes the new residual
            # (all_reduce(hidden_states) + residual) into the hidden_states
            # buffer and the normalized result into norm_out.
            flashinfer_trtllm_fused_allreduce_norm(
                allreduce_in=hidden_states,
                residual=residual,
                rms_gamma=norm.weight,
                rms_eps=norm.variance_epsilon,
                world_size=tp_size,
                weight_bias=0.0,  # standard RMSNorm (Gemma would use 1.0)
                launch_with_pdl=True,
                fp32_acc=True,
                max_token_num=max_token_num,
                pattern_code=_AR_RESIDUAL_RMS_NORM,
                norm_out=norm_out,
            )
            return norm_out, hidden_states

    reduced = tensor_model_parallel_all_reduce(hidden_states)
    return norm(reduced, residual)