Skip to content

llmcompressor.modifiers.pruning.reap.utils

Utilities for REAP: MoE detection, saliency tracking, and expert pruning.

Classes:

  • REAPSaliencyTracker

    Accumulates the REAP saliency S_j = mean(g_j * ||f_j||_2) per expert, averaged over the tokens routed to expert j.

Functions:

  • prune_moe_layer

    Structurally prune a MoE block to keep only retained experts: slice the expert ModuleList, shrink the router, and update expert-count attributes.

REAPSaliencyTracker

REAPSaliencyTracker(num_experts: int)

Accumulates the REAP saliency S_j = mean(g_j * ||f_j||_2) per expert, averaged over the tokens routed to expert j, where g_j is the router gate weight and f_j is the expert output.

Accumulators live on the device of the incoming data (allocated lazily), so the running sums are not copied to the host on every update; they are moved to the host only when mean_saliency is read.

Methods:

Source code in src/llmcompressor/modifiers/pruning/reap/utils.py
def __init__(self, num_experts: int):
    """
    Create an empty saliency tracker for a MoE layer with ``num_experts`` experts.

    Both accumulators start out as ``None`` and are allocated on first use.
    """
    # Running per-expert sum of gate-weighted output norms, and the number of
    # (token, expert) selections that contributed to it. Allocated lazily.
    self.sum_saliency: torch.Tensor | None = None
    self.count: torch.Tensor | None = None
    self.num_experts = num_experts

compute_retained_experts

compute_retained_experts(
    n_experts_to_drop: int,
    n_experts_to_drop_per_group: int | None,
    moe_attrs: MoeModelAttrs,
) -> list[int]

Select which experts to keep, dropping the lowest-saliency ones.

Source code in src/llmcompressor/modifiers/pruning/reap/utils.py
def compute_retained_experts(
    self,
    n_experts_to_drop: int,
    n_experts_to_drop_per_group: int | None,
    moe_attrs: MoeModelAttrs,
) -> list[int]:
    """
    Select which experts to keep, dropping the lowest-saliency ones.

    :param n_experts_to_drop: number of experts to drop across the whole layer;
        used only when ``n_experts_to_drop_per_group`` is ``None``
    :param n_experts_to_drop_per_group: if given, drop this many experts from
        each contiguous group of ``moe_attrs.group_size`` experts instead
    :param moe_attrs: supplies ``n_group`` and ``group_size`` for grouped mode
    :return: ascending list of retained expert indices. In grouped mode only
        indices inside the ``n_group * group_size`` grouped range can appear.
    """
    saliency = self.mean_saliency

    # Global mode: rank all experts together.
    if n_experts_to_drop_per_group is None:
        lowest = torch.topk(saliency, n_experts_to_drop, largest=False).indices
        dropped = {int(idx) for idx in lowest.tolist()}
        return [e for e in range(self.num_experts) if e not in dropped]

    # Grouped mode: experts are laid out contiguously by group, and each group
    # independently loses its least salient members.
    kept: list[int] = []
    size = moe_attrs.group_size
    for group_idx in range(moe_attrs.n_group):
        start = group_idx * size
        stop = start + size
        local_lowest = torch.topk(
            saliency[start:stop], n_experts_to_drop_per_group, largest=False
        ).indices
        dropped = {start + int(idx) for idx in local_lowest.tolist()}
        kept += [e for e in range(start, stop) if e not in dropped]
    return kept

update

update(
    topk_indices: Tensor,
    topk_weights: Tensor,
    expert_norms_dict: dict[int, Tensor],
)

Vectorized accumulation over one batch.

Parameters:

  • topk_indices (Tensor) –

    [num_tokens, top_k] selected expert ids

  • topk_weights (Tensor) –

    [num_tokens, top_k] gate weight per selection

  • expert_norms_dict (dict[int, Tensor]) –

    dict mapping expert_idx to output norms [num_routed_tokens] for tokens routed to that expert (sparse routing: experts only see tokens the router sent to them)

Source code in src/llmcompressor/modifiers/pruning/reap/utils.py
@torch.no_grad()
def update(
    self,
    topk_indices: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_norms_dict: dict[int, torch.Tensor],
):
    """
    Vectorized accumulation over one batch.

    For every (token, selected expert) pair, adds ``gate_weight * output_norm``
    to ``self.sum_saliency[expert]`` and ``1`` to ``self.count[expert]``.

    :param topk_indices: ``[num_tokens, top_k]`` selected expert ids
    :param topk_weights: ``[num_tokens, top_k]`` gate weight per selection
    :param expert_norms_dict: dict mapping expert_idx to output norms
        ``[num_routed_tokens]`` for tokens routed to that expert (sparse
        routing: experts only see tokens the router sent to them)
    :raises AssertionError: in sparse mode, if an expert reports a different
        number of norms than the router sent it tokens
    """
    if not expert_norms_dict:
        return

    # Accumulators are allocated on the device of the incoming norms.
    self._ensure(next(iter(expert_norms_dict.values())).device)

    if get_calibrate_all_experts_flag():
        # Norms for every expert are stacked column-wise into
        # [num_tokens, num_experts]. This presumes each expert saw every token
        # in this mode. Then pick the columns the router actually selected.
        stacked_norms = torch.stack(
            [expert_norms_dict[i] for i in range(self.num_experts)], dim=1
        )
        flat_idx = topk_indices.reshape(-1).to(torch.long)
        gathered_norms = stacked_norms.gather(1, topk_indices.to(torch.long))
        contrib = topk_weights.to(torch.float64) * gathered_norms.to(torch.float64)
        self.sum_saliency.index_add_(0, flat_idx, contrib.reshape(-1))
        self.count.index_add_(
            0, flat_idx, torch.ones_like(flat_idx, dtype=torch.float64)
        )
    else:
        # Flatten in (slot, token) order to match torch.where order
        # in LinearExperts2D torch.where scans row-by-row (slot 0 all
        # tokens, then slot 1 all tokens, etc.)
        flat_idx = topk_indices.T.reshape(-1).to(torch.long)
        flat_weights = topk_weights.T.reshape(-1).to(torch.float64)

        # Tokens routed to each expert, transferred to the host once per batch
        # instead of one ``.item()`` sync per expert.
        routed_counts = torch.bincount(flat_idx, minlength=self.num_experts).tolist()

        # Build flat norms tensor aligned with flat_idx
        flat_norms = torch.zeros_like(flat_weights)
        for expert_idx, expert_norms in expert_norms_dict.items():
            mask = flat_idx == expert_idx
            n_routed = (
                routed_counts[expert_idx]
                if 0 <= expert_idx < len(routed_counts)
                else 0
            )

            # Consistency check: number of norms must match number of routed
            # tokens. Raised explicitly (not via ``assert``) so it survives
            # ``python -O``; AssertionError is kept for backward compatibility.
            if len(expert_norms) != n_routed:
                raise AssertionError(
                    f"REAP saliency tracker: expert {expert_idx} has "
                    f"{len(expert_norms)} norms but router sent "
                    f"{n_routed} tokens to it. This indicates a bug in "
                    f"the expert hook or routing extraction logic."
                )

            flat_norms[mask] = expert_norms.to(torch.float64)

        # Vectorized computation using index_add_
        contrib = flat_weights * flat_norms
        self.sum_saliency.index_add_(0, flat_idx, contrib)
        self.count.index_add_(
            0, flat_idx, torch.ones_like(flat_idx, dtype=torch.float64)
        )

prune_moe_layer

prune_moe_layer(
    model: Module,
    layer_name: str,
    retained: list[int],
    moe_attrs: MoeModelAttrs,
) -> list[int]

Structurally prune a MoE block to keep only retained experts: slice the expert ModuleList, shrink the router, and update expert-count attributes. Offload-safe: experts are kept as existing module objects (offload state travels with them) and the small router is resized under align_module_device.

Source code in src/llmcompressor/modifiers/pruning/reap/utils.py
def prune_moe_layer(
    model: nn.Module,
    layer_name: str,
    retained: list[int],
    moe_attrs: MoeModelAttrs,
) -> list[int]:
    """
    Structurally prune a MoE block to keep only ``retained`` experts: slice the
    expert ``ModuleList``, shrink the router, and update expert-count attributes.
    Offload-safe: experts are kept as existing module objects (offload state
    travels with them) and the small router is resized under
    ``align_module_device``.

    :param model: model containing the MoE block
    :param layer_name: submodule path of the MoE block inside ``model``
    :param retained: positions (in the current expert container) of the
        experts to keep; the i-th entry becomes expert ``i`` after pruning
    :param moe_attrs: provides the attribute names of the router
        (``router_attr``) and the expert container (``experts_attr``)
    :return: ``retained``, unchanged
    """
    moe_block = model.get_submodule(layer_name)
    router = getattr(moe_block, moe_attrs.router_attr)
    experts = getattr(moe_block, moe_attrs.experts_attr)

    # Preserve auxiliary non-expert submodules (e.g., act_fn in LinearExperts2D).
    # Positional (digit-keyed) entries are always treated as experts, even when
    # they are not ExpertMLP instances. Re-adding a digit-keyed module here would
    # overwrite a retained expert slot or resurrect a pruned one, silently
    # undoing the pruning.
    non_expert_modules = {
        key: module
        for key, module in experts._modules.items()
        if not key.isdigit() and not isinstance(module, ExpertMLP)
    }

    # Rebuild with retained experts, renumbered densely from 0. The generator
    # is consumed here, before ``experts._modules`` is replaced.
    new_modules = OrderedDict(
        (str(new_pos), experts[old_pos]) for new_pos, old_pos in enumerate(retained)
    )

    # Re-add non-expert modules after the experts
    new_modules.update(non_expert_modules)

    experts._modules = new_modules
    experts.num_experts = len(retained)

    _prune_router(router, retained)

    # Update num_experts for any other modules in the layer that may track it
    for holder in (moe_block, router):
        for key in NUM_EXPERTS_MODULE_KEYS:
            if isinstance(getattr(holder, key, None), int):
                setattr(holder, key, len(retained))

    return retained