
llmcompressor.pipelines

Compression pipelines for orchestrating different compression strategies.

Provides several compression pipelines: basic, sequential, independent, layer-sequential, and data-free. Each pipeline coordinates a different workflow for applying compression techniques to a model, so the most suitable one can be selected based on the modifiers in use and any other requirements or constraints.

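Classes:

  • BasicPipeline

  • CalibrationPipeline

  • DataFreePipeline

  • IndependentPipeline

  • SequentialPipeline

  • Subgraph
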
Functions:

  • handle_sequential_oom

    Catch out-of-memory (OOM) errors and suggest changing sequential targets

  • trace_subgraphs

    Trace a model to produce subgraphs, where each sequential target belongs to exactly one subgraph

BasicPipeline

CalibrationPipeline

Bases: ABC, RegistryMixin

Methods:

  • from_modifiers

    Infer which calibration pipeline to use based on the available modifiers and any user specifications

from_modifiers classmethod

from_modifiers(
    modifiers: list[Modifier], user: str | None = None
) -> CalibrationPipeline

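Infer which calibration pipeline to use based on the available modifiers and any user specifications. A pipeline name passed by the user takes precedence over the inferred one, and a warning is logged when the two differ (requesting independent never triggers this warning).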

Parameters:

  • modifiers (list[Modifier]) –

    modifiers to apply to model

  • user (str | None, default: None ) –

    pipeline name passed by user

Returns:

  • CalibrationPipeline

    CalibrationPipeline instance to be called with data (if not datafree)

Source code in src/llmcompressor/pipelines/registry.py
@classmethod
def from_modifiers(
    cls, modifiers: list[Modifier], user: str | None = None
) -> "CalibrationPipeline":
    """
    Infer which calibration pipeline to use based on the available modifiers and
    any user specifications

    :param modifiers: modifiers to apply to model
    :param user: pipeline name passed by user
    :return: CalibrationPipeline instance to be called with data (if not datafree)
    """
    user = standardize_lookup_name(user) if user else None
    inferred = standardize_lookup_name(cls._infer_pipeline(modifiers))
    independent = standardize_lookup_name("independent")

    if user == independent:
        inferred = independent

    if user is not None and user != inferred:
        logger.warning(
            f"Calibration pipeline is set to `{user}`, but it is recommended to "
            f"use `{inferred}`"
        )

    pipeline = user or inferred
    return cls.load_from_registry(pipeline)
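The precedence rules above can be reproduced without the registry machinery. The sketch below is illustrative only: resolve_pipeline, standardize, and REGISTRY are hypothetical names, the normalization in standardize is an assumption (the real standardize_lookup_name may behave differently), and it returns a name where the real method returns a pipeline instance from load_from_registry.

```python
from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

# Hypothetical stand-in for the pipeline registry (assumed names)
REGISTRY = {"basic", "sequential", "independent", "datafree"}


def standardize(name: str) -> str:
    # Assumed normalization; the real standardize_lookup_name may differ
    return name.lower().replace("_", "-")


def resolve_pipeline(inferred: str, user: str | None = None) -> str:
    """Mirror the precedence rules of `from_modifiers` (illustrative only)."""
    user = standardize(user) if user else None
    inferred = standardize(inferred)
    independent = standardize("independent")

    # Requesting "independent" replaces the inferred choice, so no warning follows
    if user == independent:
        inferred = independent

    # Any other disagreement is honored but logged
    if user is not None and user != inferred:
        logger.warning(
            "Calibration pipeline is set to `%s`, but it is recommended to use `%s`",
            user,
            inferred,
        )

    pipeline = user or inferred
    if pipeline not in REGISTRY:
        raise KeyError(f"Unknown pipeline: {pipeline}")
    return pipeline


print(resolve_pipeline("sequential", user="basic"))  # logs a warning, returns "basic"
```

The user's choice always wins; inference only decides the result when no pipeline is given, or suppresses the warning in the independent case.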

DataFreePipeline

IndependentPipeline

SequentialPipeline

Subgraph dataclass

Subgraph(
    graph: Graph,
    input_names: set[str],
    consumed_names: set[str],
    _code: PythonCode | None = None,
)

Dataclass specifying an executable subgraph of a model graph

Parameters:

  • graph (Graph) –

    subgraph of model graph

  • input_names (set[str]) –

    argument names of the compiled forward function

  • consumed_names (set[str]) –

    argument names which are not used by any subsequent subgraphs and can therefore be deleted from the intermediates cache
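The role of consumed_names can be illustrated with a toy executor. This sketch is hypothetical: ToySubgraph and run_subgraphs are not part of llmcompressor, and plain Python functions stand in for torch.fx graphs. It only shows how an intermediates cache can shrink as each subgraph runs.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToySubgraph:
    """Hypothetical stand-in for Subgraph: a plain function instead of an fx Graph."""

    fn: Callable[..., dict[str, Any]]
    input_names: set[str]
    consumed_names: set[str]


def run_subgraphs(subgraphs: list[ToySubgraph], inputs: dict[str, Any]) -> dict[str, Any]:
    cache = dict(inputs)  # intermediates cache, keyed by argument name
    for sg in subgraphs:
        kwargs = {name: cache[name] for name in sg.input_names}
        outputs = sg.fn(**kwargs)
        # Names that no later subgraph reads can be dropped to free memory
        for name in sg.consumed_names:
            del cache[name]
        cache.update(outputs)
    return cache


subgraphs = [
    ToySubgraph(lambda x: {"h": x + 1}, {"x"}, {"x"}),
    ToySubgraph(lambda h: {"y": h * 2}, {"h"}, {"h"}),
]
result = run_subgraphs(subgraphs, {"x": 3})
```

After both subgraphs run, only the final output remains in the cache, because every intermediate was marked as consumed by its last reader.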

Methods:

  • forward

    Execute the operations within the subgraph

forward

forward(*args, **kwargs) -> dict[str, Any]

Execute the operations within the subgraph

Parameters:

  • *args

    argument inputs to subgraph forward function

  • **kwargs

    keyword inputs to subgraph forward function

Returns:

  • dict[str, Any]

    keyword outputs of the subgraph forward function (non-consumed variables)

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def forward(self, *args, **kwargs) -> dict[str, Any]:
    """
    Execute the operations within the subgraph

    :param \\*args: argument inputs to subgraph forward function
    :param \\**kwargs: keyword inputs to subgraph forward function
    :return: keyword outputs of subgraph forward function (non-consumed variables)
    """
    if self._code is None:
        self._code = self.graph.python_code("self")
        exec(self._code.src, self._code.globals)

    forward_fn = self._code.globals.get("forward")

    with append_autowrap_source_on_fail():
        return forward_fn(*args, **kwargs)
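Subgraph.forward compiles the graph's generated Python source lazily on the first call, then looks up the resulting forward function in the exec globals. A minimal stand-alone sketch of that pattern follows. LazyForward is a hypothetical name, and a hand-written source string replaces the output of graph.python_code(...).

```python
from __future__ import annotations

from typing import Any


class LazyForward:
    """Hypothetical sketch: compile generated source on first call, then reuse it."""

    def __init__(self, src: str):
        self.src = src
        self._globals: dict[str, Any] | None = None

    def __call__(self, *args, **kwargs) -> dict[str, Any]:
        if self._globals is None:
            self._globals = {}
            # exec defines `forward` inside the provided globals dict
            exec(self.src, self._globals)
        forward_fn = self._globals["forward"]
        return forward_fn(*args, **kwargs)


SRC = """
def forward(x, y):
    return {"sum": x + y}
"""

lazy = LazyForward(SRC)
result = lazy(1, y=2)
```

Caching the compiled code means code generation and exec run once per subgraph rather than once per batch.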

handle_sequential_oom

handle_sequential_oom(func)

Catch out-of-memory (OOM) errors and suggest changing sequential targets

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def handle_sequential_oom(func):
    """Catch ooms and suggest changing sequential targets"""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except torch.OutOfMemoryError as e:
            raise torch.OutOfMemoryError(
                "Sequential pipeline ran out of memory. "
                "Please consider choosing a smaller module "
                "for `sequential_targets` argument, ex. 'Linear'"
            ) from e

    return wrapper
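The decorator pattern above can be exercised without a GPU. In this sketch the built-in MemoryError stands in for torch.OutOfMemoryError so it runs without PyTorch; handle_oom_with_hint and calibrate are hypothetical names, and calibrate only simulates an allocation failure.

```python
from functools import wraps


def handle_oom_with_hint(func):
    """Re-raise out-of-memory errors with an actionable hint (sketch).

    Uses the built-in MemoryError as a stand-in for torch.OutOfMemoryError.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except MemoryError as e:
            # `from e` keeps the original error attached as __cause__
            raise MemoryError(
                "Sequential pipeline ran out of memory. "
                "Please consider choosing a smaller module "
                "for `sequential_targets` argument, ex. 'Linear'"
            ) from e

    return wrapper


@handle_oom_with_hint
def calibrate(batch_size: int) -> int:
    # Hypothetical workload that "runs out of memory" for large batches
    if batch_size > 8:
        raise MemoryError("simulated allocation failure")
    return batch_size
```

Because the new error is chained with from e, the original allocation failure still appears in the traceback, and functools.wraps preserves the wrapped function's name and docstring.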

trace_subgraphs

trace_subgraphs(
    model: PreTrainedModel,
    sample_input: dict[str, Any],
    sequential_targets: list[str],
    ignore: list[str],
    targets_per_subgraph: int = 1,
) -> list[Subgraph]

Trace a model to produce subgraphs such that each sequential target belongs to exactly one subgraph and executing the subgraphs in order is equivalent to executing the original model.

Parameters:

  • model (PreTrainedModel) –

    model being traced

  • sample_input (dict[str, Any]) –

    inputs whose values will change during execution but whose __len__, __bool__, and __contains__ values are assumed constant across batches

  • sequential_targets (list[str]) –

    list of patterns matching sequential targets

  • ignore (list[str]) –

    function and method names to skip during tracing

  • targets_per_subgraph (int, default: 1 ) –

    number of targets to include per subgraph

Returns:

  • list[Subgraph]

    a list of Subgraphs in order of execution
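The contract described above (each target lands in exactly one subgraph, and running the subgraphs in order reproduces the full model) can be shown on a toy linear chain of operations. This is a hypothetical illustration, not a reproduction of topological_partition: partition_ops and run are made-up names, and a real traced graph is a DAG rather than a list.

```python
from typing import Callable

Op = tuple[str, Callable[[int], int]]


def partition_ops(ops: list[Op], targets: set[str], targets_per_subgraph: int = 1) -> list[list[Op]]:
    # The leading partition holds any ops before the first target
    partitions: list[list[Op]] = [[]]
    seen = 0
    for name, fn in ops:
        if name in targets:
            # Start a new partition every `targets_per_subgraph` targets
            if seen % targets_per_subgraph == 0:
                partitions.append([])
            seen += 1
        partitions[-1].append((name, fn))
    return partitions


def run(ops: list[Op], x: int) -> int:
    for _, fn in ops:
        x = fn(x)
    return x


OPS: list[Op] = [
    ("embed", lambda x: x + 1),
    ("layer0", lambda x: x * 2),
    ("norm0", lambda x: x - 1),
    ("layer1", lambda x: x * 3),
    ("head", lambda x: x + 5),
]
TARGETS = {"layer0", "layer1"}

parts = partition_ops(OPS, TARGETS)
chained = 1
for part in parts:
    chained = run(part, chained)  # same result as run(OPS, 1)
```

With one target per subgraph this yields len(TARGETS) + 1 partitions, matching the extra leading subgraph mentioned in the source comment below.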

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def trace_subgraphs(
    model: PreTrainedModel,
    sample_input: dict[str, Any],
    sequential_targets: list[str],
    ignore: list[str],
    targets_per_subgraph: int = 1,
) -> list[Subgraph]:
    """
    Trace a model to produce subgraphs, where each sequential target belongs to exactly
    one subgraph and where executing each subgraph in order is equivalent to executing
    the original model

    :param model: model being traced
    :param sample_input: inputs whose values will change during execution but whose
        __len__, __bool__, and __contains__ values are assumed constant across batches
    :param sequential_targets: list of patterns matching sequential targets
    :param ignore: function and method names to skip during tracing
    :param targets_per_subgraph: number of targets to include per subgraph
    :return: a list of Subgraphs in order of execution
    """
    # find modules
    targets = set(
        module for _, module in match_named_modules(model, sequential_targets)
    )
    ancestors = get_sequential_ancestors(model, targets)

    # initialize arguments
    tracer = SequentialTracer(ancestors)
    concrete_args = populate_concrete_args(model, sample_input)

    with contextlib.ExitStack() as stack:
        # calibration context
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(HooksMixin.disable_hooks())

        # flags useful for tracing
        stack.enter_context(patch_attr(model.config, "_attn_implementation", "eager"))
        stack.enter_context(patch_attr(torch.compiler, "_is_compiling_flag", True))

        # autowrap forwards
        stack.enter_context(autowrap_forwards(ancestors, ignore))

        # avoid bug where pytorch cannot handle wrapped root functions
        unwrapped = inspect.unwrap(model.forward).__get__(model)
        stack.enter_context(patch_attr(model, "forward", unwrapped))
        stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
        assert isinstance(model.forward, MethodType)
        assert isinstance(type(model).forward, FunctionType)

        # avoid device movement during tracing
        stack.enter_context(disable_onloading())

        with append_autowrap_source_on_fail():
            graph = GraphModule(
                model,
                tracer.trace(
                    model,
                    dummy_inputs=sample_input,
                    concrete_args=concrete_args,
                    complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
                    # bug in trace throws an error for variadic
                    # args and kwargs in function signature
                ),
            )

    # copy metadata
    graph.config = model.config
    graph.class_for_deserialization = model.__class__
    graph.device = model.device

    # perform subgraph partition
    partitions = topological_partition(graph, targets, targets_per_subgraph)
    subgraphs = partition_graph(model, partitions)
    trace_consumed_names(subgraphs)

    # As currently implemented, `topological_partition` generates an extra subgraph at
    # the beginning which does not contain a target. This adds a little more runtime,
    # and could be folded into the first subgraph in the future
    if len(subgraphs) != len(targets) + 1:
        logger.warning(
            f"Expected {len(targets) + 1} subgraphs, but traced {len(subgraphs)}. "
            "This is likely due to having wrapped code which calls sequential targets"
        )

    return subgraphs
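trace_subgraphs pushes several temporary attribute patches onto a contextlib.ExitStack, so every patch is undone when tracing finishes, even if tracing raises. The sketch below shows that pattern on plain objects. temporarily_set is a hypothetical helper that approximates how patch_attr is used here; the real patch_attr may differ, for example in how it treats attributes that did not exist beforehand.

```python
import contextlib
from types import SimpleNamespace
from typing import Any, Iterator


@contextlib.contextmanager
def temporarily_set(obj: Any, name: str, value: Any) -> Iterator[None]:
    """Set obj.name to value, restoring (or deleting) the original on exit."""
    sentinel = object()
    original = getattr(obj, name, sentinel)
    setattr(obj, name, value)
    try:
        yield
    finally:
        if original is sentinel:
            delattr(obj, name)
        else:
            setattr(obj, name, original)


config = SimpleNamespace(_attn_implementation="sdpa")
flags = SimpleNamespace()

with contextlib.ExitStack() as stack:
    stack.enter_context(temporarily_set(config, "_attn_implementation", "eager"))
    stack.enter_context(temporarily_set(flags, "_is_compiling_flag", True))
    during = (config._attn_implementation, flags._is_compiling_flag)

after = (config._attn_implementation, hasattr(flags, "_is_compiling_flag"))
```

ExitStack unwinds the patches in reverse order of entry, which keeps nested patches of the same object consistent.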