Skip to content

llmcompressor.pipelines.sequential.helpers

Classes:

  • Subgraph

    Dataclass specifying an executable subgraph of a model graph

Functions:

  • handle_sequential_oom

    Catch out-of-memory (OOM) errors and suggest changing sequential targets

  • trace_subgraphs

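    Trace a model to produce subgraphs, where each sequential target belongs to exactly one subgraph and where executing each subgraph in order is equivalent to executing the original model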

SequentialTracer

SequentialTracer(ancestors: set[Module])

Bases: HFTracer

Get a tracer specialized for the given model. The resulting tracer will not trace inside of sequential targets, nor any modules which are not call graph ancestors of sequential targets

Parameters:

  • ancestors (set[Module]) –

    modules which are ancestors of sequential targets

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def __init__(self, ancestors: set[Module]):
    """
    Build a tracer which does not trace inside of sequential targets, nor inside of
    any module which is not a call graph ancestor of a sequential target

    :param ancestors: modules which are ancestors of sequential targets
    """
    # remember which modules may be traced through
    self.ancestors = ancestors

    # skip any mask creation functions not already caught by the autowrapper
    wrapped_functions = _get_autowrap_functions()
    super().__init__(autowrap_functions=wrapped_functions)

Subgraph dataclass

Subgraph(
    graph: Graph,
    input_names: set[str],
    consumed_names: set[str],
    _code: PythonCode | None = None,
)

Dataclass specifying an executable subgraph of a model graph

Parameters:

  • graph (Graph) –

    subgraph of model graph

  • input_names (set[str]) –

    argument names of the compiled forward function

  • consumed_names (set[str]) –

    argument names which are not used by any subsequent subgraphs and can therefore be deleted from the intermediates cache

Methods:

  • forward

    Execute the operations within the subgraph

forward

forward(*args, **kwargs) -> dict[str, Any]

Execute the operations within the subgraph

Parameters:

  • *args

    argument inputs to subgraph forward function

  • **kwargs

    keyword inputs to subgraph forward function

Returns:

  • dict[str, Any] –

    keyword outputs of the subgraph forward function (non-consumed variables)

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def forward(self, *args, **kwargs) -> dict[str, Any]:
    """
    Execute the operations within the subgraph

    The subgraph's python code is generated and executed lazily on the first call,
    and the resulting module globals are cached on `self._code` for later calls.

    :param \\*args: argument inputs to subgraph forward function
    :param \\**kwargs: keyword inputs to subgraph forward function
    :return: keyword outputs of subgraph forward function (non-consumed variables)
    """
    if self._code is None:
        code = self.graph.python_code("self")
        exec(code.src, code.globals)
        # only cache after the generated source executed successfully; caching it
        # beforehand would leave a half-initialized `_code` (without `forward`) if
        # `exec` raised, making every later call fail with an unrelated TypeError
        self._code = code

    forward_fn = self._code.globals.get("forward")

    with append_autowrap_source_on_fail():
        return forward_fn(*args, **kwargs)

find_target_nodes

find_target_nodes(
    graph: GraphModule, targets: set[Module]
) -> set[Node]

Find all nodes whose execution is equivalent to executing the target modules. Note that these nodes are guaranteed to be treated as leaf nodes by SequentialTracer

Parameters:

  • graph (GraphModule) –

    graph containing target nodes

  • targets (set[Module]) –

    modules whose nodes are being searched for

Returns:

  • set[Node]

    set of all nodes which call the target modules

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def find_target_nodes(graph: GraphModule, targets: set[Module]) -> set[Node]:
    """
    Collect the graph nodes whose execution is equivalent to executing one of the
    target modules. These nodes are guaranteed to be treated as leaf nodes by
    SequentialTracer

    :param graph: graph containing target nodes
    :param targets: modules whose nodes are being searched for
    :return: set of all nodes which call the target modules
    """
    matches: set[Node] = set()
    for candidate in graph.graph.nodes:
        # only module calls can correspond to executing a target module
        if candidate.op != "call_module":
            continue
        if graph.get_submodule(candidate.target) in targets:
            matches.add(candidate)
    return matches

get_sequential_ancestors

get_sequential_ancestors(
    model: Module, targets: set[Module]
) -> set[Module]

Find modules which are call graph ancestors of the given sequential targets

Parameters:

  • model (Module) –

    model containing sequential targets

  • targets (set[Module]) –

    sequential targets to find ancestors of

Returns:

  • set[Module]

    call graph ancestors of sequential targets

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def get_sequential_ancestors(model: Module, targets: set[Module]) -> set[Module]:
    """
    Collect every module that (transitively) contains one of the sequential targets

    :param model: model containing sequential targets
    :param targets: sequential targets to find ancestors of
    :return: call graph ancestors of sequential targets (targets themselves are not
        included unless they also contain another target)
    """
    found: set[Module] = set()

    def _visit(module: Module) -> bool:
        # a target, or an already-discovered ancestor, leads to a target
        if module in targets or module in found:
            return True

        # visit every child without short-circuiting so that ancestors in later
        # sibling subtrees are still recorded
        child_results = [_visit(child) for child in module.children()]
        leads_to_target = any(child_results)
        if leads_to_target:
            found.add(module)
        return leads_to_target

    _visit(model)
    return found

graph_is_well_formed

graph_is_well_formed(graph: Graph) -> bool

A graph is well formed if and only if nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes

Parameters:

  • graph (Graph) –

    graph being checked

Returns:

  • bool

    True if the graph is well formed, False otherwise

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def graph_is_well_formed(graph: Graph) -> bool:
    """
    Check that the user/input edges of a graph are mutually consistent, i.e.
    `nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes`, and that no node
    lists a user or an input more than once

    :param graph: graph being checked
    :return: True if the graph is well formed, False otherwise
    """
    for current in graph.nodes:
        users = current.users
        inputs = current.all_input_nodes

        # every user must list this node as one of its inputs
        if any(current not in user.all_input_nodes for user in users):
            return False

        # every input must list this node as one of its users
        if any(current not in producer.users for producer in inputs):
            return False

        # neither edge list may contain duplicates
        has_duplicate_users = len(users) != len(set(users))
        has_duplicate_inputs = len(inputs) != len(set(inputs))
        if has_duplicate_users or has_duplicate_inputs:
            return False

    return True

handle_sequential_oom

handle_sequential_oom(func)

Catch ooms and suggest changing sequential targets

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def handle_sequential_oom(func):
    """
    Decorate `func` so that an out-of-memory error is re-raised with a hint to
    choose smaller sequential targets (the original error is chained as the cause)
    """
    hint = (
        "Sequential pipeline ran out of memory. "
        "Please consider choosing a smaller module "
        "for `sequential_targets` argument, ex. 'Linear'"
    )

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except torch.OutOfMemoryError as err:
            raise torch.OutOfMemoryError(hint) from err
        return outcome

    return wrapper

partition_graph

partition_graph(
    model: Module, partitions: list[list[Node]]
) -> list[Subgraph]

Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping of output node names to their computed values. Note that the consumed_names attribute of each Subgraph remains empty, to be later populated by trace_consumed_names

Parameters:

  • model (Module) –

    model which owns the produced Subgraphs

  • partitions (list[list[Node]]) –

    list of partitions, where each partition is a list of nodes belonging to that partition

Returns:

  • list[Subgraph]

    list of subgraphs in order of execution

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def partition_graph(model: Module, partitions: list[list[Node]]) -> list[Subgraph]:
    """
    Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping
    of output node names to their computed values. Note that the `consumed_names`
    attribute of each Subgraph remains empty, to be later populated by
    `trace_consumed_names`

    :param model: model which owns the produced Subgraphs
    :param partitions: list of partitions, where each partition is a list of nodes
        belonging to that partition
    :return: list of subgraphs in order of execution
    """
    subgraphs = []

    for partition_nodes in partitions:
        # set view of the partition for O(1) membership tests; the list itself is
        # still used wherever node order matters
        partition_set = set(partition_nodes)

        # create a new graph for the partition
        graph = Graph(model)
        node_map = {}

        # add placeholders for inputs produced outside of this partition. a dict is
        # used (instead of a set) to deduplicate while preserving first-use order, so
        # that the generated forward signature is deterministic across runs
        new_input_nodes = dict.fromkeys(
            input_node
            for node in partition_nodes
            for input_node in node.all_input_nodes
            if input_node not in partition_set
        )
        for input_node in new_input_nodes:
            node_map[input_node] = graph.placeholder(input_node.name)

        # copy the partition's nodes, remapping their arguments onto this graph
        for node in partition_nodes:
            node_map[node] = graph.node_copy(node, lambda n: node_map[n])

        # unless an output node was copied from the partition, add one which returns
        # a dictionary of every value that is used outside of this partition
        if not graph.find_nodes(op="output"):
            output_dict = {
                node.name: node_map[node]
                for node in partition_nodes
                if any(user not in partition_set for user in node.users)
            }
            graph.output(output_dict)

        # save the subgraph for this partition
        graph.lint()
        input_names = {node.name for node in graph.nodes if node.op == "placeholder"}
        subgraphs.append(
            Subgraph(
                graph=graph,
                input_names=input_names,
                consumed_names=set(),  # populated later by `trace_consumed_names`
            )
        )

        assert graph_is_well_formed(graph)

    return subgraphs

populate_concrete_args

populate_concrete_args(
    model: Module, sample_input: dict
) -> dict

Creates concrete args. Unlike the equivalent function provided by transformers.utils.fx, this function also creates default values for variadic arguments, which are needed by some models.

Parameters:

  • model (Module) –

    model being traced

  • sample_input (dict) –

    values used to symbolically trace the model. All arguments to the model.forward function which are not in the sample_input are considered concrete args

Returns:

  • dict

    dictionary mapping concrete argument names to their default values

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def populate_concrete_args(model: Module, sample_input: dict) -> dict:
    """
    Build the concrete args used for tracing. Unlike the equivalent function provided
    by transformers.utils.fx, variadic arguments also receive (empty) default values,
    which some models require.

    :param model: model being traced
    :param sample_input: values used to symbolically trace the model. All arguments
        to the model.forward function which are not in the sample_input are considered
        concrete args
    :return: dictionary mapping concrete argument names to their default values, in
        the order they appear in the forward signature
    """

    def _default_for(param: inspect.Parameter):
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            return []
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return {}
        # tracing is always performed without a KV cache
        if param.name == "use_cache":
            return False
        return param.default

    signature = inspect.signature(model.forward)
    return {
        param.name: _default_for(param)
        for param in signature.parameters.values()
        if param.name not in sample_input
    }

topological_partition

topological_partition(
    graph: GraphModule,
    targets: set[Module],
    targets_per_subgraph: int = 1,
) -> list[list[Node]]

Partition the graph into partitions such that each target belongs to exactly one partition and executing each partition depends only on intermediate values produced by executing the partitions before it.

Parameters:

  • graph (GraphModule) –

    graph being partitioned

  • targets (set[Module]) –

    target modules which will be assigned to disjoint partitions

  • targets_per_subgraph (int, default: 1 ) –

    number of targets to include per subgraph

Returns:

  • list[list[Node]]

    list of partitions, where each partition is a list of nodes belonging to that partition

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def topological_partition(
    graph: GraphModule, targets: set[Module], targets_per_subgraph: int = 1
) -> list[list[Node]]:
    """
    Partition the graph into partitions such that each `target` belongs to exactly one
    partition and executing each partition depends only on intermediate values produced
    by executing the partitions before it.

    :param graph: graph being partitioned
    :param targets: target modules which will be assigned to disjoint partitions
    :param targets_per_subgraph: number of targets to include per subgraph
    :return: list of partitions, where each partition is a list of nodes belonging to
        that partition
    :raises ValueError: if `targets_per_subgraph` is less than one
    """
    # validate arguments before doing any work
    if targets_per_subgraph <= 0:
        raise ValueError(
            "targets_per_subgraph is required to be greater than or equal to one"
        )

    assert graph_is_well_formed(graph.graph)
    target_nodes = find_target_nodes(graph, targets)

    partitions: list[list[Node]] = [[]]
    # `get_attr` inputs are not counted, since those nodes are scheduled separately
    # at the end of this function
    remaining_indegrees = {
        node: sum(
            1 for input_node in node.all_input_nodes if input_node.op != "get_attr"
        )
        for node in graph.graph.nodes
    }
    partition_index = 0  # global counter
    targets_seen = 0  # number of targets encountered so far

    # start with graph input nodes,
    # but delay the `get_attr` nodes as long as possible
    queue = deque(
        node
        for node in graph.graph.nodes
        if remaining_indegrees[node] == 0 and node.op != "get_attr"
    )
    while len(queue) > 0:
        node = queue.popleft()

        is_target = node in target_nodes
        if is_target:
            # put all nodes prior to first target into separate subgraph
            is_head = partition_index == 0 and len(partitions[partition_index]) > 0

            # finish creating subgraph when number of targets has been seen
            is_complete = targets_seen >= targets_per_subgraph

            if is_head or is_complete:
                partition_index += 1
                partitions.append([])
                targets_seen = 0

        # assign to partition
        partitions[partition_index].append(node)

        # increment after assignment so is_complete fires after the target is placed
        if is_target:
            targets_seen += 1

        # recurse on last indegree only in order to guarantee that
        # the node is assigned to maximal partition
        for user in node.users:
            remaining_indegrees[user] -= 1
            if remaining_indegrees[user] == 0:
                queue.append(user)

    # an ideal implementation would involve implicitly consolidating partition indices
    # so that each node is assigned to the maximum partition possible (in order to delay
    # execution as long as possible), but saving these nodes for last covers the most
    # common and costly case (get_attr)
    # map each scheduled node to its partition index for O(1) lookups, rather than
    # scanning every partition list for every user
    node_partition = {
        scheduled: index
        for index, partition in enumerate(partitions)
        for scheduled in partition
    }
    for node in graph.graph.find_nodes(op="get_attr"):
        user_partitions = [
            node_partition[user] for user in node.users if user in node_partition
        ]

        # place the attribute at the start of the earliest partition that uses it;
        # `get_attr` nodes without any scheduled users are left out of the partitions
        if user_partitions:
            partitions[min(user_partitions)].insert(0, node)

    return partitions

trace_consumed_names

trace_consumed_names(subgraphs: list[Subgraph])

Populate the consumed_names attribute of each Subgraph according to when inputs are last used in order to vacate the intermediates cache and save memory

Parameters:

  • subgraphs (list[Subgraph]) –

    list of subgraphs with empty consumed_names attributes

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def trace_consumed_names(subgraphs: list[Subgraph]):
    """
    Fill in each Subgraph's `consumed_names` with the inputs it is the last subgraph
    to use, so those values can be evicted from the `intermediates` cache afterwards

    :param subgraphs: list of subgraphs with empty `consumed_names` attributes
    :raises ValueError: if an input name cannot be attributed to any subgraph
    """
    all_input_names = set().union(*(sg.input_names for sg in subgraphs))
    for input_name in all_input_names:
        # the last subgraph (in execution order) which reads this input
        last_reader = next(
            (sg for sg in reversed(subgraphs) if input_name in sg.input_names),
            None,
        )
        if last_reader is None:
            raise ValueError(f"Could not find input name {input_name} in subgraphs")
        last_reader.consumed_names.add(input_name)

trace_subgraphs

trace_subgraphs(
    model: PreTrainedModel,
    sample_input: dict[str, Any],
    sequential_targets: list[str],
    ignore: list[str],
    targets_per_subgraph: int = 1,
) -> list[Subgraph]

Trace a model to produce subgraphs, where each sequential target belongs to exactly one subgraph and where executing each subgraph in order is equivalent to executing the original model

Parameters:

  • model (PreTrainedModel) –

    model being traced

  • sample_input (dict[str, Any]) –

    inputs whose values will change during execution but whose __len__, __bool__, and __contains__ values are assumed constant across batches

  • sequential_targets (list[str]) –

    list of patterns matching sequential targets

  • ignore (list[str]) –

    function and method names to skip during tracing

  • targets_per_subgraph (int, default: 1 ) –

    number of targets to include per subgraph

Returns:

  • list[Subgraph]

    a list of Subgraphs in order of execution

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def trace_subgraphs(
    model: PreTrainedModel,
    sample_input: dict[str, Any],
    sequential_targets: list[str],
    ignore: list[str],
    targets_per_subgraph: int = 1,
) -> list[Subgraph]:
    """
    Trace a model to produce subgraphs, where each sequential target belongs to exactly
    one subgraph and where executing each subgraph in order is equivalent to executing
    the original model

    :param model: model being traced
    :param sample_input: inputs whose values will change during execution but whose
        __len__, __bool__, and __contains__ values are assumed constant across batches
    :param sequential_targets: list of patterns matching sequential targets
    :param ignore: function and method names to skip during tracing
    :param targets_per_subgraph: number of targets to include per subgraph
    :return: a list of Subgraphs in order of execution
    """
    # find modules
    targets = set(
        module for _, module in match_named_modules(model, sequential_targets)
    )
    ancestors = get_sequential_ancestors(model, targets)

    # initialize arguments
    tracer = SequentialTracer(ancestors)
    concrete_args = populate_concrete_args(model, sample_input)

    with contextlib.ExitStack() as stack:
        # calibration context
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(HooksMixin.disable_hooks())

        # flags useful for tracing
        stack.enter_context(patch_attr(model.config, "_attn_implementation", "eager"))
        stack.enter_context(patch_attr(torch.compiler, "_is_compiling_flag", True))

        # autowrap forwards
        stack.enter_context(autowrap_forwards(ancestors, ignore))

        # avoid bug where pytorch cannot handle wrapped root functions
        unwrapped = inspect.unwrap(model.forward).__get__(model)
        stack.enter_context(patch_attr(model, "forward", unwrapped))
        stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
        assert isinstance(model.forward, MethodType)
        assert isinstance(type(model).forward, FunctionType)

        # avoid device movement during tracing
        stack.enter_context(disable_onloading())

        with append_autowrap_source_on_fail():
            graph = GraphModule(
                model,
                tracer.trace(
                    model,
                    dummy_inputs=sample_input,
                    concrete_args=concrete_args,
                    # bug in trace throws an error for variadic
                    # args and kwargs in function signature
                    complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
                ),
            )

    # copy metadata
    graph.config = model.config
    graph.class_for_deserialization = model.__class__
    graph.device = model.device

    # perform subgraph partition
    partitions = topological_partition(graph, targets, targets_per_subgraph)
    subgraphs = partition_graph(model, partitions)
    trace_consumed_names(subgraphs)

    # As currently implemented, `topological_partition` generates an extra subgraph at
    # the beginning which does not contain a target, followed by one subgraph per group
    # of `targets_per_subgraph` targets. The leading subgraph adds a little more
    # runtime, and could be folded into the first subgraph in the future
    num_target_subgraphs = (
        len(targets) + targets_per_subgraph - 1
    ) // targets_per_subgraph
    expected_num_subgraphs = num_target_subgraphs + 1
    if len(subgraphs) != expected_num_subgraphs:
        logger.warning(
            f"Expected {expected_num_subgraphs} subgraphs, but traced "
            f"{len(subgraphs)}. This is likely due to having wrapped code which "
            "calls sequential targets"
        )

    return subgraphs