Manual fusion of tensor-parallel all-reduce with the following GemmaRMSNorm.
Under tensor parallelism a RowParallelLinear (e.g. attention o_proj) produces a per-rank partial sum that is all-reduced, and the result is then fed into a GemmaRMSNorm that adds the residual and normalizes. flashinfer ships a kernel that fuses all-reduce + residual-add + RMSNorm into a single launch; this helper drives it directly (no torch.compile pass) for models that run eager.
Scope: attention output only, no quantization. When the flashinfer fast path is not applicable (TP==1, flashinfer/NVSwitch unavailable, unsupported dtype, or an oversize batch) it falls back to all_reduce + GemmaRMSNorm, which is numerically identical to the unfused model path.
Functions:
_can_use_flashinfer(hidden_states, tp_size)
Whether the flashinfer fused path applies; returns (ok, max_token_num).
Source code in vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py
| def _can_use_flashinfer(hidden_states: torch.Tensor, tp_size: int) -> tuple[bool, int]:
"""Whether the flashinfer fused path applies; returns (ok, max_token_num)."""
if (
flashinfer_trtllm_fused_allreduce_norm is None
or get_fi_ar_workspace is None
or _AR_RESIDUAL_RMS_NORM is None
):
return False, 0
if (
not hidden_states.is_cuda
or hidden_states.dim() != 2
or not hidden_states.is_contiguous()
or hidden_states.dtype not in _FI_SUPPORTED_DTYPES
):
return False, 0
num_tokens, hidden_size = hidden_states.shape
max_token_num = _max_token_num(tp_size, hidden_size, hidden_states.dtype)
if max_token_num is None or num_tokens > max_token_num:
return False, 0
# Lazily create / fetch the (globally cached) workspace; returns None on
# GPUs without NVSwitch, in which case we fall back gracefully.
workspace = get_fi_ar_workspace(
world_size=tp_size,
rank=get_tensor_model_parallel_rank(),
max_token_num=max_token_num,
hidden_dim=hidden_size,
dtype=hidden_states.dtype,
group=get_tp_group().device_group,
)
if workspace is None:
return False, 0
return True, max_token_num
|
_max_token_num(tp_size, hidden_size, dtype)
Workspace token budget for flashinfer fused all-reduce, or None if the current world size / device is unsupported. Mirrors FlashInferAllReduce.
Source code in vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py
| def _max_token_num(tp_size: int, hidden_size: int, dtype: torch.dtype) -> int | None:
"""Workspace token budget for flashinfer fused all-reduce, or None if the
current world size / device is unsupported. Mirrors ``FlashInferAllReduce``."""
from vllm.config.compilation import PassConfig
max_size_mb = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(tp_size)
if not max_size_mb:
return None
element_size = torch.tensor([], dtype=dtype).element_size()
return int(max_size_mb * MiB) // (hidden_size * element_size)
|
fused_allreduce_gemma_rms_norm(hidden_states, residual, norm)
All-reduce hidden_states + add residual + GemmaRMSNorm, fused.
hidden_states is the per-rank partial (un-reduced) output of a row-parallel linear; norm is the GemmaRMSNorm applied right after. Returns (normed_output, new_residual), equivalent to norm(all_reduce(hidden_states), residual).
Source code in vllm/model_executor/layers/fused_allreduce_gemma_rms_norm.py
| def fused_allreduce_gemma_rms_norm(
hidden_states: torch.Tensor,
residual: torch.Tensor,
norm: GemmaRMSNorm,
) -> tuple[torch.Tensor, torch.Tensor]:
"""All-reduce ``hidden_states`` + add ``residual`` + GemmaRMSNorm, fused.
``hidden_states`` is the per-rank *partial* (un-reduced) output of a
row-parallel linear; ``norm`` is the GemmaRMSNorm applied right after.
Returns ``(normed_output, new_residual)``, equivalent to
``norm(all_reduce(hidden_states), residual)``.
"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No all-reduce needed; identical to the unfused path.
return norm(hidden_states, residual)
ok, max_token_num = _can_use_flashinfer(hidden_states, tp_size)
if ok:
norm_out = torch.empty_like(hidden_states)
# With norm_out provided, the kernel writes the new residual
# (all_reduce(hidden_states) + residual) into the hidden_states buffer
# and the normalized result into norm_out, leaving `residual` untouched.
flashinfer_trtllm_fused_allreduce_norm(
allreduce_in=hidden_states,
residual=residual,
rms_gamma=norm.weight,
rms_eps=norm.variance_epsilon,
world_size=tp_size,
weight_bias=1.0, # GemmaRMSNorm-style
launch_with_pdl=True,
fp32_acc=True,
max_token_num=max_token_num,
pattern_code=_AR_RESIDUAL_RMS_NORM,
norm_out=norm_out,
)
return norm_out, hidden_states
# Fallback: explicit all-reduce + GemmaRMSNorm (matches the unfused model).
reduced = tensor_model_parallel_all_reduce(hidden_states)
return norm(reduced, residual)
|