
vllm.model_executor.layers.minimax_rms_norm.rms_norm_tp

_all_reduce_variance(var)

All-reduce a per-token variance tensor across the TP group.

Variance is accumulated in fp32 for numerical stability. The FlashInfer fused all-reduce caches a single global workspace configured for the model's 16-bit activation dtype (use_fp32_lamport=False); routing an fp32 reduction through it would operate on a workspace that does not match the input dtype and corrupt the result. Because FlashInfer's fast path triggers only for 2D inputs, reducing a flattened (1D) view keeps these fp32 reductions on custom all-reduce or pynccl, both of which handle fp32 correctly.
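The routing argument above can be modeled as a toy decision function. This is a hypothetical sketch of the rule the docstring describes, not vLLM's dispatch code: `route_all_reduce` and the dtype strings are illustrative names, and the real backend selection may depend on additional factors. It also assumes the variance tensor is 2D before flattening (for example, a `[num_tokens, 1]` result of a `keepdim=True` reduction).

```python
from typing import Tuple


def route_all_reduce(
    shape: Tuple[int, ...], dtype: str, workspace_dtype: str = "bfloat16"
) -> Tuple[str, bool]:
    """Return (backend, result_is_correct) under the rule described above.

    Hypothetical model: only the 2D check and the workspace-dtype mismatch
    from the docstring are represented.
    """
    if len(shape) == 2:
        # FlashInfer fused fast path; its cached workspace assumes workspace_dtype,
        # so any other dtype produces a corrupted result
        return "flashinfer_fused", dtype == workspace_dtype
    # Non-2D inputs fall through to custom all-reduce / pynccl,
    # which the docstring says handle fp32 correctly
    return "custom_all_reduce_or_pynccl", True


num_tokens = 8
# Assumed [num_tokens, 1] fp32 variance: 2D, so it would hit the fast path.
print(route_all_reduce((num_tokens, 1), "float32"))
# -> ('flashinfer_fused', False)
# Flattening to [num_tokens] steers the fp32 reduction away from FlashInfer.
print(route_all_reduce((num_tokens,), "float32"))
# -> ('custom_all_reduce_or_pynccl', True)
```

Flattening is a cheap way to opt out of the fused path without changing the reduction's semantics: an all-reduce is elementwise, so the values are the same whatever the shape, and `view_as` restores the original layout afterwards.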

Source code in vllm/model_executor/layers/minimax_rms_norm/rms_norm_tp.py
import torch

from vllm.distributed import tensor_model_parallel_all_reduce


def _all_reduce_variance(var: torch.Tensor) -> torch.Tensor:
    """All-reduce a per-token variance tensor across the TP group.

    Variance is accumulated in fp32 for numerical stability. The FlashInfer
    fused all-reduce caches a single global workspace configured for the
    model's 16-bit activation dtype (``use_fp32_lamport=False``); routing an
    fp32 reduction through it would operate on a workspace that does not match
    the input dtype and corrupt the result. Because FlashInfer's fast path
    triggers only for 2D inputs, reducing a flattened (1D) view keeps these
    fp32 reductions on custom all-reduce or pynccl, both of which handle fp32
    correctly.
    """
    # Reduce a 1D view to bypass the 2D-only FlashInfer path, then restore the shape.
    return tensor_model_parallel_all_reduce(var.flatten()).view_as(var)
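To illustrate why the variance needs a cross-rank reduction at all, here is a pure-Python simulation of a tensor-parallel RMSNorm statistic. It is a sketch under stated assumptions, not the layer's actual code: each rank is assumed to hold an equal-width slice of the hidden dimension and to compute a local mean of squares, and the summed result is divided by the TP size. Lists stand in for tensors; `flatten` and `view_as` mimic the round trip in `_all_reduce_variance`, and `all_reduce_sum` is a hypothetical stand-in for `tensor_model_parallel_all_reduce`.

```python
from typing import List

Matrix = List[List[float]]  # assumed shape [num_tokens, 1]


def local_variance(x_shard: Matrix) -> Matrix:
    # Per-token mean of squares over this rank's slice of the hidden dim.
    return [[sum(v * v for v in row) / len(row)] for row in x_shard]


def flatten(m: Matrix) -> List[float]:
    # Analogue of var.flatten(): [num_tokens, 1] -> [num_tokens]
    return [v for row in m for v in row]


def view_as(flat: List[float], like: Matrix) -> Matrix:
    # Analogue of .view_as(var): restore the original [num_tokens, 1] layout.
    out, i = [], 0
    for row in like:
        out.append(flat[i:i + len(row)])
        i += len(row)
    return out


def all_reduce_sum(per_rank: List[List[float]]) -> List[float]:
    # Hypothetical stand-in for the TP all-reduce: elementwise sum over ranks.
    return [sum(vals) for vals in zip(*per_rank)]


def tp_variance(shards: List[Matrix]) -> Matrix:
    tp_size = len(shards)
    local = [local_variance(s) for s in shards]
    reduced = all_reduce_sum([flatten(v) for v in local])
    # Assumes equal shard widths, so averaging per-rank means gives the global mean.
    return view_as([r / tp_size for r in reduced], local[0])


# Two tokens with hidden size 4, split across 2 ranks (2 columns each).
x = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 2.0, 2.0]]
shards = [[row[:2] for row in x], [row[2:] for row in x]]
print(tp_variance(shards))  # -> [[7.5], [2.0]]
```

The result matches the mean of squares computed over each full row, (1 + 4 + 9 + 16) / 4 = 7.5 and (0 + 0 + 4 + 4) / 4 = 2.0, which is the property the all-reduce must preserve. With unequal shard widths, one would instead all-reduce the sums of squares and divide by the full hidden size.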