llmcompressor.modifiers.pruning.reap.utils
Utilities for REAP: MoE detection, saliency tracking, and expert pruning.
Classes:
- REAPSaliencyTracker – Accumulates the REAP saliency S_j = mean(g_j * ||f_j||_2) per expert, averaged over the tokens routed to expert j.
Functions:
- prune_moe_layer – Structurally prune a MoE block to keep only retained experts: slice the expert ModuleList, shrink the router, and update expert-count attributes.
REAPSaliencyTracker
Accumulates the REAP saliency S_j = mean(g_j * ||f_j||_2) per expert,
averaged over the tokens routed to expert j, where g_j is the router
gate weight and f_j is the expert output.
Accumulators live on the device of the incoming data (allocated lazily) to
avoid a host sync on every update; they are moved to the host only when
mean_saliency is read.
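Written out, the saliency defined above reads as follows. The symbol X_j (the set of calibration tokens routed to expert j) is notation introduced here for illustration and does not appear in the source.

```latex
S_j \;=\; \frac{1}{\lvert \mathcal{X}_j \rvert}
          \sum_{x \in \mathcal{X}_j} g_j(x)\, \bigl\lVert f_j(x) \bigr\rVert_2
```

Here g_j(x) is the router gate weight assigned to expert j for token x, and f_j(x) is that expert's output for the same token.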
Methods:
- compute_retained_experts – Select which experts to keep, dropping the lowest-saliency ones.
- update – Vectorized accumulation over one batch.
Source code in src/llmcompressor/modifiers/pruning/reap/utils.py
compute_retained_experts
compute_retained_experts(
n_experts_to_drop: int,
n_experts_to_drop_per_group: int | None,
moe_attrs: MoeModelAttrs,
) -> list[int]
Select which experts to keep, dropping the lowest-saliency ones.
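The selection step can be illustrated with a minimal, framework-free sketch. The helper name `select_retained_experts` is hypothetical, the sketch covers only a global drop count, and it ignores `n_experts_to_drop_per_group` and `moe_attrs`, whose exact semantics are not described here. Returning the survivors in ascending index order is also an assumption.

```python
def select_retained_experts(
    mean_saliency: list[float], n_experts_to_drop: int
) -> list[int]:
    """Hypothetical helper: keep every expert except the lowest-saliency ones."""
    n_experts = len(mean_saliency)
    if not 0 <= n_experts_to_drop < n_experts:
        raise ValueError("n_experts_to_drop must leave at least one expert")

    # Rank expert indices from lowest to highest saliency (sorted() is stable,
    # so ties are broken by the lower expert index).
    ranked = sorted(range(n_experts), key=lambda j: mean_saliency[j])
    dropped = set(ranked[:n_experts_to_drop])

    # Assumption: survivors are returned in their original index order.
    return [j for j in range(n_experts) if j not in dropped]


retained = select_retained_experts([0.9, 0.1, 0.5, 0.3], n_experts_to_drop=2)
print(retained)
```

In this example experts 1 and 3 have the two lowest saliency values, so experts 0 and 2 are kept.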
update
Vectorized accumulation over one batch.
Parameters:
- topk_indices (Tensor) – [num_tokens, top_k] selected expert ids
- topk_weights (Tensor) – [num_tokens, top_k] gate weight per selection
- expert_norms_dict (dict[int, Tensor]) – dict mapping expert_idx to output norms [num_routed_tokens] for tokens routed to that expert (sparse routing: experts only see tokens the router sent to them)
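The arithmetic behind one accumulation step can be sketched with plain Python lists in place of tensors. This is not the vectorized implementation; the class name `SaliencyAccumulator` is hypothetical, and two points are assumptions: that the norms for an expert are ordered like the tokens that selected it, and that an expert that never received a token reports a saliency of 0.0.

```python
class SaliencyAccumulator:
    """Hypothetical stand-in for the tracker, using lists instead of tensors."""

    def __init__(self) -> None:
        self.weighted_sum: dict[int, float] = {}  # running sum of g_j * ||f_j||_2
        self.count: dict[int, int] = {}           # tokens routed to each expert

    def update(self, topk_indices, topk_weights, expert_norms_dict) -> None:
        for expert_idx, norms in expert_norms_dict.items():
            # Gate weights of the selections that picked this expert.
            # Assumption: they appear in the same token order as `norms`.
            gates = [
                weight
                for idx_row, weight_row in zip(topk_indices, topk_weights)
                for idx, weight in zip(idx_row, weight_row)
                if idx == expert_idx
            ]
            if len(gates) != len(norms):
                raise ValueError(f"expert {expert_idx}: routing/norm size mismatch")

            batch_sum = sum(g * n for g, n in zip(gates, norms))
            self.weighted_sum[expert_idx] = (
                self.weighted_sum.get(expert_idx, 0.0) + batch_sum
            )
            self.count[expert_idx] = self.count.get(expert_idx, 0) + len(norms)

    def mean_saliency(self, n_experts: int) -> list[float]:
        # Assumption: experts that never received a token get saliency 0.0.
        return [
            self.weighted_sum[j] / self.count[j] if self.count.get(j) else 0.0
            for j in range(n_experts)
        ]


acc = SaliencyAccumulator()
acc.update(
    topk_indices=[[0, 1], [1, 2]],          # two tokens, top_k = 2
    topk_weights=[[0.75, 0.25], [0.5, 0.5]],
    expert_norms_dict={0: [2.0], 1: [4.0, 2.0], 2: [8.0]},
)
saliency = acc.mean_saliency(n_experts=4)
print(saliency)
```

Keeping a running sum and a token count per expert means the mean only has to be formed when it is read, which matches the lazy read described for mean_saliency.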
prune_moe_layer
prune_moe_layer(
model: Module,
layer_name: str,
retained: list[int],
moe_attrs: MoeModelAttrs,
) -> list[int]
Structurally prune a MoE block to keep only retained experts: slice the
expert ModuleList, shrink the router, and update expert-count attributes.
Offload-safe: experts are kept as existing module objects (offload state
travels with them) and the small router is resized under
align_module_device.
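The index bookkeeping behind the three steps (slice the experts, shrink the router, update the count) can be shown with a framework-free sketch. The function `prune_experts` is hypothetical, and it assumes the router stores one weight row per expert, as a linear router with `n_experts` outputs would. Offloading and device handling are not modelled.

```python
def prune_experts(
    experts: list, router_weight: list[list[float]], retained: list[int]
) -> tuple[list, list[list[float]], int]:
    """Hypothetical stand-in: keep only the retained experts and router rows."""
    if len(experts) != len(router_weight):
        raise ValueError("router must have exactly one row per expert")

    # Keep the existing expert objects (no copies), so any state attached
    # to them stays attached.
    kept_experts = [experts[j] for j in retained]

    # Shrink the router: row i of the new router scores the i-th kept expert.
    kept_router = [list(router_weight[j]) for j in retained]

    # The new expert count replaces the old one wherever it is recorded.
    return kept_experts, kept_router, len(kept_experts)


experts = [object() for _ in range(4)]
router_weight = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
kept_experts, kept_router, num_experts = prune_experts(
    experts, router_weight, retained=[0, 2]
)
print(num_experts, kept_router)
```

After pruning, the expert that used to be index 2 is index 1, and the router row that scored it moves with it, so router outputs and expert positions stay aligned.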