All-reduce + add residual + (standard) RMSNorm, fused via flashinfer.
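`hidden_states` is the per-rank *partial* output of a row-parallel linear run with `reduce_results=False`; `norm` is the RMSNorm applied right after it. Returns `(normed_output, new_residual)`, equivalent to `norm(all_reduce(hidden_states), residual)`. With a tensor-parallel world size of 1 there is nothing to reduce, so `norm` is applied directly. Falls back to an explicit all-reduce + RMSNorm when the flashinfer fast path is unavailable. On the fast path the kernel writes the new residual into the `hidden_states` buffer in place and that buffer is returned as `new_residual`, so the caller's input tensor is overwritten.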
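To make the equivalence concrete, here is a minimal pure-Python sketch of what the function computes for a single token, assuming two tensor-parallel ranks. The element-wise sum over `rank_partials` stands in for the all-reduce, and `reference_allreduce_rms_norm` is a hypothetical helper written for illustration; it is not part of vLLM or flashinfer.

```python
import math
from typing import List, Tuple

Vector = List[float]


def reference_allreduce_rms_norm(
    rank_partials: List[Vector],
    residual: Vector,
    gamma: Vector,
    eps: float,
) -> Tuple[Vector, Vector]:
    """Hypothetical single-token model of norm(all_reduce(x), residual).

    rank_partials holds each tensor-parallel rank's partial output for one
    token; summing them element-wise plays the role of the all-reduce.
    """
    # All-reduce (sum) of the per-rank partial outputs.
    reduced = [sum(vals) for vals in zip(*rank_partials)]
    # Residual add: this sum is what gets returned as the new residual.
    new_residual = [r + s for r, s in zip(reduced, residual)]
    # Standard RMSNorm: scale by 1 / sqrt(mean(x^2) + eps), then by gamma.
    mean_sq = sum(v * v for v in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normed = [v * inv_rms * g for v, g in zip(new_residual, gamma)]
    return normed, new_residual


normed, new_residual = reference_allreduce_rms_norm(
    rank_partials=[[1.0, 2.0], [3.0, -1.0]],  # two ranks, hidden size 2
    residual=[0.0, 1.0],
    gamma=[1.0, 1.0],
    eps=0.0,
)
print(new_residual)  # [4.0, 2.0]
```

The first return value corresponds to the normalized activations fed to the next sublayer; the second is the pre-norm sum that the next layer consumes as its residual, which is why the function returns both.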
Source code in vllm/models/deepseek_v32/nvidia/fused_ops.py
```python
def fused_allreduce_rms_norm(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    norm: RMSNorm,
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-reduce + add residual + (standard) RMSNorm, fused via flashinfer.

    ``hidden_states`` is the per-rank *partial* output of a row-parallel linear
    run with ``reduce_results=False``; ``norm`` is the RMSNorm applied right
    after. Returns ``(normed_output, new_residual)``, equivalent to
    ``norm(all_reduce(hidden_states), residual)``. Falls back to an explicit
    all-reduce + RMSNorm when the flashinfer fast path is unavailable.
    """
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # Single rank: the partial output is already complete, so only the
        # residual add + RMSNorm remain.
        return norm(hidden_states, residual)
    if flashinfer_trtllm_fused_allreduce_norm is not None:
        ok, max_token_num = _can_use_flashinfer(hidden_states, tp_size)
        if ok:
            norm_out = torch.empty_like(hidden_states)
            # With norm_out provided, the kernel writes the new residual
            # (all_reduce(hidden_states) + residual) into the hidden_states
            # buffer and the normalized result into norm_out.
            flashinfer_trtllm_fused_allreduce_norm(
                allreduce_in=hidden_states,
                residual=residual,
                rms_gamma=norm.weight,
                rms_eps=norm.variance_epsilon,
                world_size=tp_size,
                weight_bias=0.0,  # standard RMSNorm (Gemma would use 1.0)
                launch_with_pdl=True,
                fp32_acc=True,
                max_token_num=max_token_num,
                pattern_code=_AR_RESIDUAL_RMS_NORM,
                norm_out=norm_out,
            )
            return norm_out, hidden_states
    # Fallback: explicit all-reduce, then the residual add + RMSNorm.
    reduced = tensor_model_parallel_all_reduce(hidden_states)
    return norm(reduced, residual)
```
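The fast path's buffer contract is easy to miss: the returned `new_residual` *is* the input `hidden_states` tensor, now holding `all_reduce(hidden_states) + residual`. The pure-Python sketch below models that contract for one token. `simulate_fused_kernel` is hypothetical and only mimics the behaviour stated in the code comment; it is not the flashinfer kernel. It also assumes that `weight_bias` is added to `gamma` before scaling (0.0 for standard RMSNorm, 1.0 for Gemma-style `(1 + gamma)` scaling), which is inferred from the inline comment rather than taken from flashinfer's documentation.

```python
import math
from typing import List

Vector = List[float]


def simulate_fused_kernel(
    allreduce_in: List[Vector],
    residual: Vector,
    rms_gamma: Vector,
    rms_eps: float,
    weight_bias: float,
    norm_out: Vector,
) -> None:
    """Hypothetical model of the fused kernel's output contract (one token).

    allreduce_in[k] stands for rank k's hidden_states buffer. Every rank's
    buffer is overwritten in place with all_reduce(hidden_states) + residual,
    and the normalized result is written into the caller-provided norm_out.
    """
    reduced = [sum(vals) for vals in zip(*allreduce_in)]
    new_residual = [r + s for r, s in zip(reduced, residual)]
    mean_sq = sum(v * v for v in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + rms_eps)
    # Assumed semantics: effective scale is gamma + weight_bias.
    norm_out[:] = [
        v * inv_rms * (g + weight_bias) for v, g in zip(new_residual, rms_gamma)
    ]
    for buf in allreduce_in:
        buf[:] = new_residual  # in-place overwrite of the input buffer


hidden_states = [[1.0, 2.0], [3.0, -1.0]]  # per-rank partial outputs
rank0_buffer = hidden_states[0]  # the tensor rank 0 passed in
norm_out = [0.0, 0.0]
simulate_fused_kernel(hidden_states, [0.0, 1.0], [1.0, 1.0], 0.0, 0.0, norm_out)
# Rank 0 would return (norm_out, rank0_buffer): its input now holds the residual.
print(rank0_buffer)  # [4.0, 2.0]
```

In practice this means a caller must not read `hidden_states` as the partial output after the call; on the fast path it has already been replaced by the new residual, whereas the fallback path leaves it untouched.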