
vllm.distributed.eplb.eplb_utils

Utility functions for EPLB (Expert Parallel Load Balancing).

Classes:

  • CpuGpuEvent

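    Combines a CUDA event with a CPU threading event to enforce record->wait ordering across two threads.

Functions:

  • override_envs_for_eplb

    Override environment variables for EPLB when specific conditions are met.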

CpuGpuEvent

Combines a CUDA event with a CPU threading event to enforce record->wait ordering across two threads.

This class is designed for exactly two threads: one producer that calls record() and one consumer that calls wait(). Using it with more than two threads is not supported and will produce undefined behavior.

CUDA events alone are insufficient for cross-thread synchronization because waiting on a CUDA event that has not been recorded is a no-op: the wait returns immediately instead of blocking. This class adds a threading.Event so that the waiting thread blocks on the CPU side until record() has been called. At that point the CUDA event has been recorded on the producer's stream, so event.wait() correctly makes the consumer's GPU stream wait for it.
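
To illustrate why the CPU-side flag matters, here is a minimal CPU-only sketch. `FakeCudaEvent` and `consumer_sees_record` are hypothetical names: the fake event only mimics the one CUDA behavior described above (waiting on an unrecorded event returns immediately), and no GPU or torch is involved. The sketch models the ordering problem, not real stream synchronization.

```python
import threading


class FakeCudaEvent:
    """Hypothetical CPU-only stand-in for torch.cuda.Event."""

    def __init__(self) -> None:
        self.recorded = False

    def record(self) -> None:
        self.recorded = True

    def wait(self) -> bool:
        # Mimics CUDA: waiting on an unrecorded event returns immediately.
        # Returns True only if record() happened before this wait.
        return self.recorded


def consumer_sees_record(use_cpu_flag: bool) -> bool:
    event = FakeCudaEvent()
    cpu_flag = threading.Event()        # the extra flag CpuGpuEvent adds
    consumer_done = threading.Event()
    seen = []

    def consumer() -> None:
        if use_cpu_flag:
            cpu_flag.wait()             # block until the producer has recorded
        seen.append(event.wait())
        consumer_done.set()

    def producer() -> None:
        if not use_cpu_flag:
            # Force the bad interleaving: the consumer waits before any record.
            consumer_done.wait()
        event.record()
        cpu_flag.set()

    threads = [threading.Thread(target=consumer), threading.Thread(target=producer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen[0]


print(consumer_sees_record(use_cpu_flag=True))   # True
print(consumer_sees_record(use_cpu_flag=False))  # False
```

Without the CPU flag, nothing stops the consumer's wait from running first, and it silently "succeeds" without any ordering guarantee.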

Methods:

  • record

    Unblocks the waiting thread after calling event.record().

  • wait

    Blocks the calling thread until record() has been called, guaranteeing that the CUDA event is recorded before the stream waits on it.

Source code in vllm/distributed/eplb/eplb_utils.py
import threading

import torch


class CpuGpuEvent:
    """
    Combines a CUDA event with a CPU threading event to enforce record->wait
    ordering across two threads.

    This class is designed for exactly two threads: one producer that calls
    record() and one consumer that calls wait(). Using it with more than two
    threads is not supported and will produce undefined behavior.

    CUDA events alone are insufficient for cross-thread synchronization because
    waiting on an unrecorded CUDA event is a no-op. The wait will return
    immediately instead of blocking. This class adds a threading.Event so
    that the waiting thread blocks on the CPU side until record() is called, at
    which point the CUDA event is guaranteed to be in-flight and event.wait() will
    correctly synchronize the GPU stream.
    """

    def __init__(self):
        self._event = torch.cuda.Event()
        self._recorded = threading.Event()

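    def wait(self, stream: torch.cuda.Stream | None = None):
        """
        Blocks the calling thread until record() has been called, guaranteeing
        that the CUDA event is recorded before the stream waits on it.

        Should only be called by the async EPLB thread.
        """
        # CPU side: block until the producer has recorded the CUDA event.
        self._recorded.wait()
        # GPU side: make `stream` (or the current stream) wait on the event.
        self._event.wait(stream)
        # Re-arm so the producer may record the next event.
        self._recorded.clear()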

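    def record(self, stream: torch.cuda.Stream | None = None):
        """
        Records the CUDA event on `stream` (or the current stream), then
        unblocks the waiting thread.

        Should only be called by the main thread.
        """
        # Refuse to overwrite an event the consumer has not yet waited on.
        if self._recorded.is_set():
            raise RuntimeError(
                "CpuGpuEvent.record() called before the previous event was "
                "consumed by wait()"
            )
        # Use a fresh CUDA event for each record/wait cycle.
        self._event = torch.cuda.Event()
        self._event.record(stream)
        # Signal the consumer only after the event is recorded.
        self._recorded.set()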

record(stream=None)

Unblocks the waiting thread after calling event.record().

Should only be called by the main thread.
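
The guard at the top of record() makes the event a single-slot handoff: a second record() before the consumer's wait() is rejected rather than silently replacing the pending event. A minimal CPU-only sketch of that guard, using a hypothetical `OneSlotHandoff` class in which a plain `object()` token stands in for the fresh `torch.cuda.Event()`:

```python
import threading


class OneSlotHandoff:
    """Hypothetical CPU-only analog of CpuGpuEvent's record()/wait() guard."""

    def __init__(self) -> None:
        self._recorded = threading.Event()
        self._token = None

    def record(self) -> object:
        # Same guard as CpuGpuEvent.record(): never overwrite an
        # event that wait() has not consumed yet.
        if self._recorded.is_set():
            raise RuntimeError(
                "record() called before the previous event was consumed by wait()"
            )
        self._token = object()  # fresh token per cycle, like torch.cuda.Event()
        self._recorded.set()
        return self._token

    def wait(self) -> object:
        self._recorded.wait()
        token = self._token
        self._recorded.clear()  # allow the next record()
        return token


h = OneSlotHandoff()
first = h.record()
try:
    h.record()  # second record before wait(): rejected
except RuntimeError as exc:
    print("rejected:", exc)
assert h.wait() is first
h.record()  # allowed again once wait() has cleared the flag
```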

wait(stream=None)

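Blocks the calling thread until record() has been called, guaranteeing that the CUDA event is recorded before the stream waits on it. It then makes stream (or the current stream) wait on the event and clears the CPU flag so the next record() can proceed.

Should only be called by the async EPLB thread.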
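
Because wait() clears the flag, the same object can carry many record/wait cycles, provided the producer does not record again before the consumer has waited (otherwise record() raises). The sketch below is a CPU-only analog (hypothetical `CpuFlagEvent` and `run_cycles`); the second `consumed` event is an assumed back-channel used here to pace the producer, not necessarily how vLLM coordinates its threads.

```python
import threading


class CpuFlagEvent:
    """Hypothetical CPU-only analog of CpuGpuEvent (no CUDA involved)."""

    def __init__(self) -> None:
        self._recorded = threading.Event()
        self._payload = None

    def record(self, payload) -> None:
        if self._recorded.is_set():
            raise RuntimeError("previous event not yet consumed by wait()")
        self._payload = payload
        self._recorded.set()

    def wait(self):
        self._recorded.wait()       # CPU-side block until record()
        payload = self._payload     # stands in for event.wait(stream)
        self._recorded.clear()      # re-arm for the next cycle
        return payload


def run_cycles(n: int) -> list:
    ready = CpuFlagEvent()      # main thread -> async thread
    consumed = CpuFlagEvent()   # async thread -> main thread (assumed back-channel)
    received = []

    def async_thread() -> None:
        for _ in range(n):
            received.append(ready.wait())
            consumed.record(None)

    worker = threading.Thread(target=async_thread)
    worker.start()
    for step in range(n):
        ready.record(step)
        consumed.wait()  # do not record again until the last one was consumed
    worker.join()
    return received


print(run_cycles(3))  # [0, 1, 2]
```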

override_envs_for_eplb(parallel_config, moe_backend=None)

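Override environment variables for EPLB when specific conditions are met. Currently this sets NCCL_MAX_CTAS=8 when data parallelism is enabled (data_parallel_size > 1), EPLB is enabled, the EPLB communicator is NCCL-based (torch_nccl or pynccl), and the MoE backend is deep_gemm_mega_moe, unless NCCL_MAX_CTAS is already set to an integer.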

Parameters:

  • parallel_config (ParallelConfig) – The parallel configuration object.

  • moe_backend (str | None, default: None) – The configured MoE backend (e.g. deep_gemm_mega_moe).
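
The override only fires when all four checks in the source hold at once. A minimal sketch of that predicate, assuming a hypothetical flat `EplbSettings` dataclass in place of the real ParallelConfig (whose communicator actually lives at `parallel_config.eplb_config.communicator`) and a hypothetical function name `should_limit_nccl_ctas`:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class EplbSettings:
    """Hypothetical flat view of the fields override_envs_for_eplb reads."""

    data_parallel_size: int
    enable_eplb: bool
    communicator: str               # parallel_config.eplb_config.communicator
    moe_backend: str | None = None


def should_limit_nccl_ctas(s: EplbSettings) -> bool:
    # All four conditions must hold, mirroring the checks in the source.
    return (
        s.data_parallel_size > 1
        and s.enable_eplb
        and s.communicator in ("torch_nccl", "pynccl")
        and s.moe_backend == "deep_gemm_mega_moe"
    )


print(should_limit_nccl_ctas(EplbSettings(2, True, "pynccl", "deep_gemm_mega_moe")))  # True
print(should_limit_nccl_ctas(EplbSettings(1, True, "pynccl", "deep_gemm_mega_moe")))  # False
```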

Source code in vllm/distributed/eplb/eplb_utils.py
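import os

from vllm.config import ParallelConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


def override_envs_for_eplb(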
    parallel_config: ParallelConfig,
    moe_backend: str | None = None,
) -> None:
    """
    Override environment variables for EPLB when specific conditions are met.

    Args:
        parallel_config: The parallel configuration object.
        moe_backend: The configured MoE backend (e.g. ``deep_gemm_mega_moe``).
    """
    is_data_parallel = parallel_config.data_parallel_size > 1
    is_eplb_enabled = parallel_config.enable_eplb
    is_mega_moe = moe_backend == "deep_gemm_mega_moe"
    is_nccl_based_eplb_communicator = parallel_config.eplb_config.communicator in (
        "torch_nccl",
        "pynccl",
    )

    # Override NCCL_MAX_CTAS to avoid hangs when EPLB's NCCL weight exchange
    # contends with the MoE backend's cooperative launch for GPU SMs.
    #
    # DeepGEMM Mega MoE uses cooperative launch, which needs a large
    # fraction of the GPU's SMs to be resident at once. If those SMs are
    # occupied by NCCL, the cooperative launch blocks until enough SMs are
    # freed, which can deadlock. Capping NCCL's CTA count via NCCL_MAX_CTAS
    # leaves enough SMs free for the cooperative kernel to launch and complete.
    if (
        is_data_parallel
        and is_eplb_enabled
        and is_nccl_based_eplb_communicator
        and is_mega_moe
    ):
        current_value_str = os.getenv("NCCL_MAX_CTAS")

        # Respect an integer value the user has already set.
        if current_value_str and current_value_str.isdigit():
            return

        override_value = 8
        os.environ["NCCL_MAX_CTAS"] = str(override_value)
        logger.info_once(
            f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
            "for expert parallel with NCCL-based EPLB communicator and "
            "cooperative MoE backend (deep_gemm_mega_moe)",
            scope="global",
        )
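
The environment handling above has a simple precedence rule: an existing all-digit NCCL_MAX_CTAS wins, and anything else (unset, empty, or non-numeric such as "-4" or "auto") is replaced with 8. So a user who wants a different CTA limit can export NCCL_MAX_CTAS as an integer before launching. A minimal sketch of that rule, using a hypothetical helper `apply_nccl_max_ctas_default` that takes any mutable mapping so it can be exercised on a plain dict instead of os.environ:

```python
from collections.abc import MutableMapping


def apply_nccl_max_ctas_default(
    env: MutableMapping, override_value: int = 8
) -> bool:
    """Hypothetical helper mirroring the env handling in override_envs_for_eplb.

    Returns True if NCCL_MAX_CTAS was (re)written.
    """
    current = env.get("NCCL_MAX_CTAS")
    # An existing all-digit value is treated as a deliberate user choice.
    if current and current.isdigit():
        return False
    env["NCCL_MAX_CTAS"] = str(override_value)
    return True


env = {"NCCL_MAX_CTAS": "16"}
print(apply_nccl_max_ctas_default(env), env)  # False {'NCCL_MAX_CTAS': '16'}
env = {"NCCL_MAX_CTAS": "auto"}
print(apply_nccl_max_ctas_default(env), env)  # True {'NCCL_MAX_CTAS': '8'}
```

Passing os.environ instead of a dict would apply the same rule to the real process environment.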