Skip to content

llmcompressor.modifiers.transform.spinquant

Event dataclass

A class for defining an event that can be triggered during sparsification.

Parameters:

Name Type Description Default
type_ Optional[EventType]

The type of event.

None
steps_per_epoch Optional[int]

The number of steps per epoch.

None
batches_per_step Optional[int]

The number of batches per step where step is an optimizer step invocation. For most pathways, these are the same. See the invocations_per_step parameter for more details when they are not.

None
invocations_per_step int

The number of invocations of the step wrapper before optimizer.step was called. Generally can be left as 1 (default). For older amp pathways, this is the number of times the scaler wrapper was invoked before the wrapped optimizer step function was called to handle accumulation in fp16.

1
global_step int

The current global step.

0
global_batch int

The current global batch.

0
Source code in llmcompressor/core/events/event.py
@dataclass
class Event:
    """
    A class for defining an event that can be triggered during sparsification.

    :param type_: The type of event.
    :type type_: Optional[EventType]
    :param steps_per_epoch: The number of steps per epoch. When set (not None)
        the event is treated as epoch based.
    :type steps_per_epoch: Optional[int]
    :param batches_per_step: The number of batches per step where step is an
        optimizer step invocation. For most pathways, these are the same.
        See the invocations_per_step parameter for more details when they are not.
    :type batches_per_step: Optional[int]
    :param invocations_per_step: The number of invocations of the step wrapper
        before optimizer.step was called. Generally can be left as 1 (default).
        For older amp pathways, this is the number of times the scaler wrapper
        was invoked before the wrapped optimizer step function was called to
        handle accumulation in fp16.
    :type invocations_per_step: int
    :param global_step: The current global step.
    :type global_step: int
    :param global_batch: The current global batch.
    :type global_batch: int
    """

    type_: Optional[EventType] = None
    steps_per_epoch: Optional[int] = None
    batches_per_step: Optional[int] = None
    invocations_per_step: int = 1
    global_step: int = 0
    global_batch: int = 0

    @property
    def epoch_based(self) -> bool:
        """
        Determines if the event is based on epochs.

        :return: True if steps_per_epoch is set, False otherwise.
        :rtype: bool
        """
        return self.steps_per_epoch is not None

    @property
    def epoch(self) -> int:
        """
        Calculates the current epoch.

        :raises ValueError: if the event is not epoch based.
        :return: The current epoch (global_step floor-divided by
            steps_per_epoch).
        :rtype: int
        """
        if not self.epoch_based:
            logger.error("Attempt to access epoch for a non-epoch based event")
            raise ValueError("Event is not epoch based")
        return self.global_step // self.steps_per_epoch

    @property
    def epoch_full(self) -> float:
        """
        Calculates the current epoch with the fraction of the current step.

        :raises ValueError: if the event is not epoch based.
        :return: The current epoch with the fraction of the current step.
        :rtype: float
        """
        if not self.epoch_based:
            logger.error("Attempt to access epoch_full for a non-epoch based event")
            raise ValueError("Event is not epoch based")
        return self.global_step / float(self.steps_per_epoch)

    @property
    def epoch_step(self) -> int:
        """
        Calculates the current step within the current epoch.

        :raises ValueError: if the event is not epoch based.
        :return: The current step within the current epoch.
        :rtype: int
        """
        if not self.epoch_based:
            logger.error("Attempt to access epoch_step for a non-epoch based event")
            raise ValueError("Event is not epoch based")
        return self.global_step % self.steps_per_epoch

    @property
    def epoch_batch(self) -> int:
        """
        Calculates the current batch within the current epoch.

        :raises ValueError: if the event is not epoch based.
        :return: The current batch within the current epoch.
        :rtype: int
        """
        if not self.epoch_based:
            logger.error("Attempt to access epoch_batch for a non-epoch based event")
            raise ValueError("Event is not epoch based")
        # A missing (None) or zero batches_per_step means one batch per step.
        batches_per_epoch = (
            self.steps_per_epoch * self.batches_per_step
            if self.batches_per_step
            else self.steps_per_epoch
        )
        return self.global_batch % batches_per_epoch

    @property
    def current_index(self) -> float:
        """
        Calculates the current index of the event.

        :raises ValueError: if the event is epoch based and the fractional
            epoch is more than one epoch ahead of the integer epoch.
        :return: The current index of the event, which is either the global step
            (non-epoch based events) or the epoch with the fraction of the
            current step (epoch based events).
        :rtype: float
        """
        if not self.epoch_based:
            return self.global_step
        epoch_full = self.epoch_full
        # Sanity check: the fractional and integer epoch must stay within one
        # epoch of each other.
        if epoch_full - self.epoch > 1.0:
            logger.error("Too many steps per epoch for epoch based event")
            raise ValueError("Too many steps per epoch for epoch based event")
        return epoch_full

    @current_index.setter
    def current_index(self, value: float):
        """
        Sets the current index of the event, updating both global_step and
        global_batch.

        :param value: The current index value. Interpreted as a step for
            non-epoch based events and as a (fractional) epoch otherwise.
        :type value: float
        """
        logger.debug("Setting current index: {}", value)
        if self.epoch_based:
            self.global_step = int(value * self.steps_per_epoch)
        else:
            self.global_step = int(value)

        # The batch counter is derived identically for both index kinds:
        # it only differs from the step counter when a step spans 2+ batches.
        batches = self.batches_per_step
        if batches is None or batches < 2:
            self.global_batch = self.global_step
        else:
            self.global_batch = self.global_step * batches

    def should_update(
        self, start: Optional[float], end: Optional[float], update: Optional[float]
    ) -> bool:
        """
        Determines if the event should trigger an update.

        :param start: The start index to check against, set to None to ignore start.
        :type start: Optional[float]
        :param end: The end index to check against, set to None to ignore end.
        :type end: Optional[float]
        :param update: The update interval, set to None or a value <= 0.0 to
            always update, otherwise must be greater than 0.0.
        :type update: Optional[float]
        :return: True if the event should trigger an update, False otherwise.
        :rtype: bool
        """
        current = self.current_index
        logger.debug(
            "Checking if event should update: "
            "current_index={}, start={}, end={}, update={}",
            current,
            start,
            end,
            update,
        )
        if start is not None and current < start:
            return False
        if end is not None and current > end:
            return False
        if update is None or update <= 0.0:
            return True

        # Float modulo of an exact multiple can land just below ``update``
        # instead of at 0 (e.g. 0.3 % 0.1 == 0.0999...), so accept a remainder
        # that is within tolerance of either end of the interval.
        tolerance = 1e-10
        remainder = current % update
        return remainder < tolerance or update - remainder < tolerance

    def new_instance(self, **kwargs) -> "Event":
        """
        Creates a new instance of the event with the provided keyword arguments.

        The current event is deep-copied, so it is left untouched.

        :param kwargs: Attribute names and values to set on the new instance.
        :return: A new instance of the event with the provided kwargs.
        :rtype: Event
        """
        logger.debug("Creating new instance of event with kwargs: {}", kwargs)
        instance = deepcopy(self)
        for key, value in kwargs.items():
            setattr(instance, key, value)
        return instance

current_index property writable

Calculates the current index of the event.

Returns:

Type Description
float

The current index of the event, which is either the global step or the epoch with the fraction of the current step.

Raises:

Type Description
ValueError

if the event is epoch based and the steps per epoch are too many. (Non-epoch based events simply return the global step.)

epoch property

Calculates the current epoch.

Returns:

Type Description
int

The current epoch.

Raises:

Type Description
ValueError

if the event is not epoch based.

epoch_based property

Determines if the event is based on epochs.

Returns:

Type Description
bool

True if the event is based on epochs, False otherwise.

epoch_batch property

Calculates the current batch within the current epoch.

Returns:

Type Description
int

The current batch within the current epoch.

Raises:

Type Description
ValueError

if the event is not epoch based.

epoch_full property

Calculates the current epoch with the fraction of the current step.

Returns:

Type Description
float

The current epoch with the fraction of the current step.

Raises:

Type Description
ValueError

if the event is not epoch based.

epoch_step property

Calculates the current step within the current epoch.

Returns:

Type Description
int

The current step within the current epoch.

Raises:

Type Description
ValueError

if the event is not epoch based.

new_instance(**kwargs)

Creates a new instance of the event with the provided keyword arguments.

Parameters:

Name Type Description Default
kwargs

Keyword arguments to set in the new instance.

{}

Returns:

Type Description
Event

A new instance of the event with the provided kwargs.

Source code in llmcompressor/core/events/event.py
def new_instance(self, **kwargs) -> "Event":
    """
    Build a modified copy of this event.

    The event is deep-copied first, then every keyword argument is assigned
    onto the copy as an attribute; the original event is not changed.

    :param kwargs: Attribute names and the values to set on the copy.
    :return: The copied event with the overrides applied.
    :rtype: Event
    """
    logger.debug("Creating new instance of event with kwargs: {}", kwargs)
    clone = deepcopy(self)
    for attr_name in kwargs:
        setattr(clone, attr_name, kwargs[attr_name])
    return clone

should_update(start, end, update)

Determines if the event should trigger an update.

Parameters:

Name Type Description Default
start Optional[float]

The start index to check against, set to None to ignore start.

required
end Optional[float]

The end index to check against, set to None to ignore end.

required
update Optional[float]

The update interval; set to None or 0.0 to always update, otherwise must be greater than 0.0. This argument is required (it has no default).

required

Returns:

Type Description
bool

True if the event should trigger an update, False otherwise.

Source code in llmcompressor/core/events/event.py
def should_update(
    self, start: Optional[float], end: Optional[float], update: Optional[float]
) -> bool:
    """
    Determines if the event should trigger an update.

    :param start: The start index to check against, set to None to ignore start.
    :type start: Optional[float]
    :param end: The end index to check against, set to None to ignore end.
    :type end: Optional[float]
    :param update: The update interval, set to None or a value <= 0.0 to
        always update, otherwise must be greater than 0.0.
    :type update: Optional[float]
    :return: True if the event should trigger an update, False otherwise.
    :rtype: bool
    """
    current = self.current_index
    logger.debug(
        "Checking if event should update: "
        "current_index={}, start={}, end={}, update={}",
        current,
        start,
        end,
        update,
    )
    if start is not None and current < start:
        return False
    if end is not None and current > end:
        return False
    if update is None or update <= 0.0:
        return True

    # Float modulo of an exact multiple can land just below ``update``
    # instead of at 0 (e.g. 0.3 % 0.1 == 0.0999...), so accept a remainder
    # that is within tolerance of either end of the interval.
    tolerance = 1e-10
    remainder = current % update
    return remainder < tolerance or update - remainder < tolerance

EventType

Bases: Enum

An Enum for defining the different types of events that can be triggered during model compression lifecycles. The purpose of each EventType is to trigger the corresponding modifier callback during training or post training pipelines.

Parameters:

Name Type Description Default
INITIALIZE

Event type for initialization.

required
FINALIZE

Event type for finalization.

required
BATCH_START

Event type for the start of a batch.

required
LOSS_CALCULATED

Event type for when loss is calculated.

required
BATCH_END

Event type for the end of a batch.

required
CALIBRATION_EPOCH_START

Event type for the start of a calibration epoch.

required
SEQUENTIAL_EPOCH_END

Event type for the end of a layer calibration epoch, specifically used by src/llmcompressor/pipelines/sequential/pipeline.py

required
CALIBRATION_EPOCH_END

Event type for the end of a calibration epoch.

required
OPTIM_PRE_STEP

Event type for pre-optimization step.

required
OPTIM_POST_STEP

Event type for post-optimization step.

required
Source code in llmcompressor/core/events/event.py
@unique
class EventType(Enum):
    """
    Enumeration of the events that can be fired during a model compression
    lifecycle. Each member exists so that the matching modifier callback can
    be triggered from a training or post-training pipeline.

    :param INITIALIZE: Fired on initialization.
    :param FINALIZE: Fired on finalization.
    :param BATCH_START: Fired when a batch starts.
    :param LOSS_CALCULATED: Fired once the loss has been calculated.
    :param BATCH_END: Fired when a batch ends.
    :param CALIBRATION_EPOCH_START: Fired when a calibration epoch starts.
    :param SEQUENTIAL_EPOCH_END: Fired when a layer calibration epoch ends;
        specifically used by `src/llmcompressor/pipelines/sequential/pipeline.py`
    :param CALIBRATION_EPOCH_END: Fired when a calibration epoch ends.
    :param OPTIM_PRE_STEP: Fired before the optimizer step.
    :param OPTIM_POST_STEP: Fired after the optimizer step.
    """

    # --- training lifecycle ---
    INITIALIZE = "initialize"
    FINALIZE = "finalize"

    # --- batch lifecycle ---
    BATCH_START = "batch_start"
    LOSS_CALCULATED = "loss_calculated"
    BATCH_END = "batch_end"

    # --- calibration lifecycle ---
    CALIBRATION_EPOCH_START = "calibration_epoch_start"
    SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
    CALIBRATION_EPOCH_END = "calibration_epoch_end"

    # --- optimizer step lifecycle ---
    OPTIM_PRE_STEP = "optim_pre_step"
    OPTIM_POST_STEP = "optim_post_step"

Modifier

Bases: ModifierInterface, HooksMixin

A base class for all modifiers to inherit from. Modifiers are used to modify the training process for a model. Defines base attributes and methods available to all modifiers

Lifecycle: 1. initialize 2. on_event -> * on_start if self.start <= event.current_index * on_end if self.end <= event.current_index 3. finalize

Parameters:

Name Type Description Default
index

The index of the modifier in the list of modifiers for the model

required
group

The group name for the modifier

required
start

The start step for the modifier

required
end

The end step for the modifier

required
update

The update step for the modifier

required
Source code in llmcompressor/modifiers/modifier.py
class Modifier(ModifierInterface, HooksMixin):
    """
    A base class for all modifiers to inherit from.
    Modifiers are used to modify the training process for a model.
    Defines base attributes and methods available to all modifiers

    Lifecycle:
    1. initialize
    2. on_event ->
        * on_start if self.start <= event.current_index
        * on_end if self.end <= event.current_index
    3. finalize

    :param index: The index of the modifier in the list of modifiers
        for the model
    :param group: The group name for the modifier
    :param start: The start step for the modifier
    :param end: The end step for the modifier
    :param update: The update step for the modifier
    """

    # reject unknown fields instead of silently accepting them
    model_config = ConfigDict(extra="forbid")

    # user-facing configuration
    index: Optional[int] = None
    group: Optional[str] = None
    start: Optional[float] = None
    end: Optional[float] = None
    update: Optional[float] = None

    # lifecycle flags maintained by initialize / finalize / update_event
    initialized_: bool = False
    finalized_: bool = False
    started_: bool = False
    ended_: bool = False

    @property
    def initialized(self) -> bool:
        """
        :return: True if the modifier has been initialized
        """
        return self.initialized_

    @property
    def finalized(self) -> bool:
        """
        :return: True if the modifier has been finalized
        """
        return self.finalized_

    def initialize(self, state: State, **kwargs) -> None:
        """
        Initialize the modifier for the given model and state.

        Stores the result of on_initialize as the initialized flag, then
        probes should_start with a synthetic BATCH_START event at global
        step 0 and, if it passes, calls on_start and marks the modifier
        as started.

        :raises RuntimeError: if the modifier has already been initialized
            or has already been finalized
        :param state: The current state of the model
        :param kwargs: Additional arguments for initializing the modifier
        """
        if self.initialized_:
            raise RuntimeError(
                "Cannot initialize a modifier that has already been initialized"
            )

        if self.finalized_:
            raise RuntimeError(
                "Cannot initialize a modifier that has already been finalized"
            )

        self.initialized_ = self.on_initialize(state=state, **kwargs)

        # trigger starts
        fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
        if self.should_start(fake_start_event):
            self.on_start(state, fake_start_event, **kwargs)
            self.started_ = True

    def finalize(self, state: State, **kwargs) -> None:
        """
        Finalize the modifier for the given model and state.

        Stores the result of on_finalize as the finalized flag.

        :raises RuntimeError: if the modifier has already been finalized
            or has not been initialized
        :param state: The current state of the model
        :param kwargs: Additional arguments for finalizing the modifier
        """
        if self.finalized_:
            raise RuntimeError("cannot finalize a modifier twice")

        if not self.initialized_:
            raise RuntimeError("cannot finalize an uninitialized modifier")

        # TODO: all finalization should succeed
        self.finalized_ = self.on_finalize(state=state, **kwargs)

    def update_event(self, state: State, event: Event, **kwargs) -> None:
        """
        Update modifier based on the given event. Always calls on_event
        first, then calls on_start, on_end, and on_update based on the
        event and modifier settings.

        :raises RuntimeError: if the modifier has not been initialized
            or has been finalized
        :param state: The current state of sparsification
        :param event: The event to update the modifier with
        :param kwargs: Additional arguments for updating the modifier
        """
        if not self.initialized_:
            raise RuntimeError("Cannot update an uninitialized modifier")

        if self.finalized_:
            raise RuntimeError("Cannot update a finalized modifier")

        self.on_event(state, event, **kwargs)

        # handle starting the modifier if needed
        if (
            event.type_ == EventType.BATCH_START
            and not self.started_
            and self.should_start(event)
        ):
            self.on_start(state, event, **kwargs)
            self.started_ = True
            self.on_update(state, event, **kwargs)

            return

        # handle ending the modifier if needed
        if (
            event.type_ == EventType.BATCH_END
            and not self.ended_
            and self.should_end(event)
        ):
            self.on_end(state, event, **kwargs)
            self.ended_ = True
            self.on_update(state, event, **kwargs)

            return

        # regular update while the modifier is active
        if self.started_ and not self.ended_:
            self.on_update(state, event, **kwargs)

    def should_start(self, event: Event) -> bool:
        """
        :param event: The event to check if the modifier should start
        :return: True if the modifier should start based on the given event,
            i.e. start is set and start <= event.current_index, and end is
            either unset or still ahead of the current index
        """
        if self.start is None:
            return False

        current = event.current_index

        return self.start <= current and (self.end is None or current < self.end)

    def should_end(self, event: Event) -> bool:
        """
        :param event: The event to check if the modifier should end
        :return: True if the modifier should end based on the given event,
            i.e. end is set and event.current_index >= end
        """
        current = event.current_index

        return self.end is not None and current >= self.end

    @abstractmethod
    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        on_initialize is called on modifier initialization and
        must be implemented by the inheriting modifier.

        :param state: The current state of the model
        :param kwargs: Additional arguments for initializing the modifier
        :return: True if the modifier was initialized successfully,
            False otherwise
        """
        raise NotImplementedError()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        on_finalize is called on modifier finalization and may be
        overridden by the inheriting modifier. The default implementation
        reports success.

        :param state: The current state of the model
        :param kwargs: Additional arguments for finalizing the modifier
        :return: True if the modifier was finalized successfully,
            False otherwise
        """
        return True

    def on_start(self, state: State, event: Event, **kwargs) -> None:
        """
        on_start is called when the modifier starts and may be
        overridden by the inheriting modifier. The default implementation
        does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the start
        :param kwargs: Additional arguments for starting the modifier
        """
        pass

    def on_update(self, state: State, event: Event, **kwargs) -> None:
        """
        on_update is called when the model in question must be
        updated based on passed in event. May be overridden by the
        inheriting modifier; the default implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the update
        :param kwargs: Additional arguments for updating the model
        """
        pass

    def on_end(self, state: State, event: Event, **kwargs) -> None:
        """
        on_end is called when the modifier ends and may be overridden
        by the inheriting modifier. The default implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the end
        :param kwargs: Additional arguments for ending the modifier
        """
        pass

    def on_event(self, state: State, event: Event, **kwargs) -> None:
        """
        on_event is called whenever an event is triggered. The default
        implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the update
        :param kwargs: Additional arguments for updating the model
        """
        pass

finalized property

Returns:

Type Description
bool

True if the modifier has been finalized

initialized property

Returns:

Type Description
bool

True if the modifier has been initialized

finalize(state, **kwargs)

Finalize the modifier for the given model and state.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for finalizing the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has already been finalized or has not been initialized

Source code in llmcompressor/modifiers/modifier.py
def finalize(self, state: State, **kwargs):
    """
    Run finalization for this modifier against the given state.

    The outcome returned by on_finalize is stored as the finalized flag.

    :raises RuntimeError: if the modifier was already finalized, or if it
        was never initialized
    :param state: The current state of the model
    :param kwargs: Extra arguments forwarded to on_finalize
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")
    elif not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")

    # TODO: all finalization should succeed
    outcome = self.on_finalize(state=state, **kwargs)
    self.finalized_ = outcome

initialize(state, **kwargs)

Initialize the modifier for the given model and state.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for initializing the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has already been initialized or has already been finalized

Source code in llmcompressor/modifiers/modifier.py
def initialize(self, state: State, **kwargs):
    """
    Run initialization for this modifier against the given state.

    The outcome returned by on_initialize is stored as the initialized flag.
    Afterwards a synthetic BATCH_START event at global step 0 is used to
    check should_start; when it passes, on_start is invoked and the modifier
    is marked as started.

    :raises RuntimeError: if the modifier was already initialized, or if it
        was already finalized
    :param state: The current state of the model
    :param kwargs: Extra arguments forwarded to on_initialize and on_start
    """
    if self.initialized_:
        message = "Cannot initialize a modifier that has already been initialized"
        raise RuntimeError(message)
    elif self.finalized_:
        message = "Cannot initialize a modifier that has already been finalized"
        raise RuntimeError(message)

    outcome = self.on_initialize(state=state, **kwargs)
    self.initialized_ = outcome

    # a modifier whose start window already covers step 0 starts right away
    probe_event = Event(type_=EventType.BATCH_START, global_step=0)
    if not self.should_start(probe_event):
        return
    self.on_start(state, probe_event, **kwargs)
    self.started_ = True

on_end(state, event, **kwargs)

on_end is called when the modifier ends and may be overridden by the inheriting modifier; the default implementation does nothing.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the end

required
kwargs

Additional arguments for ending the modifier

{}
Source code in llmcompressor/modifiers/modifier.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Hook invoked when the modifier ends. Inheriting modifiers override it
    to add behavior; this base version is a no-op.

    :param state: The current state of the model
    :param event: The event that triggered the end
    :param kwargs: Additional arguments for ending the modifier
    """
    return None

on_event(state, event, **kwargs)

on_event is called whenever an event is triggered

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the update

required
kwargs

Additional arguments for updating the model

{}
Source code in llmcompressor/modifiers/modifier.py
def on_event(self, state: State, event: Event, **kwargs):
    """
    Hook invoked for every event that is triggered. This base version is
    a no-op.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    return None

on_finalize(state, **kwargs)

on_finalize is called on modifier finalization and may be overridden by the inheriting modifier; the default implementation returns True.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for finalizing the modifier

{}

Returns:

Type Description
bool

True if the modifier was finalized successfully, False otherwise

Source code in llmcompressor/modifiers/modifier.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    Hook invoked on modifier finalization. Inheriting modifiers override it
    to add behavior; this base version simply reports success.

    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    :return: True if the modifier was finalized successfully,
        False otherwise
    """
    succeeded = True
    return succeeded

on_initialize(state, **kwargs) abstractmethod

on_initialize is called on modifier initialization and must be implemented by the inheriting modifier.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for initializing the modifier

{}

Returns:

Type Description
bool

True if the modifier was initialized successfully, False otherwise

Source code in llmcompressor/modifiers/modifier.py
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Abstract hook invoked on modifier initialization; every inheriting
    modifier has to provide its own implementation.

    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    :return: True if the modifier was initialized successfully,
        False otherwise
    :raises NotImplementedError: always, when this base version is called
    """
    raise NotImplementedError

on_start(state, event, **kwargs)

on_start is called when the modifier starts and may be overridden by the inheriting modifier; the default implementation does nothing.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the start

required
kwargs

Additional arguments for starting the modifier

{}
Source code in llmcompressor/modifiers/modifier.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    Hook invoked when the modifier starts. Inheriting modifiers override it
    to add behavior; this base version is a no-op.

    :param state: The current state of the model
    :param event: The event that triggered the start
    :param kwargs: Additional arguments for starting the modifier
    """
    return None

on_update(state, event, **kwargs)

on_update is called when the model in question must be updated based on the passed-in event. It may be overridden by the inheriting modifier; the default implementation does nothing.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the update

required
kwargs

Additional arguments for updating the model

{}
Source code in llmcompressor/modifiers/modifier.py
def on_update(self, state: State, event: Event, **kwargs):
    """
    Hook invoked when the model must be updated in response to the given
    event. Inheriting modifiers override it to add behavior; this base
    version is a no-op.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    return None

should_end(event)

Parameters:

Name Type Description Default
event Event

The event to check if the modifier should end

required

Returns:

Type Description

True if the modifier should end based on the given event

Source code in llmcompressor/modifiers/modifier.py
def should_end(self, event: Event):
    """
    Decide whether the modifier should end for the given event.

    :param event: The event to check if the modifier should end
    :return: True when an end is configured and the event's current index
        has reached or passed it, False otherwise
    """
    # read the index first so the event is always queried, even without an end
    index = event.current_index

    if self.end is None:
        return False
    return index >= self.end

should_start(event)

Parameters:

Name Type Description Default
event Event

The event to check if the modifier should start

required

Returns:

Type Description
bool

True if the modifier should start based on the given event

Source code in llmcompressor/modifiers/modifier.py
def should_start(self, event: Event) -> bool:
    """
    Decide whether the modifier should start for the given event.

    :param event: The event to check if the modifier should start
    :return: True when a start is configured, the event's current index has
        reached it, and the end (if any) has not been reached yet
    """
    if self.start is None:
        return False

    index = event.current_index

    # not yet at the start of the window
    if not self.start <= index:
        return False
    # inside the window as long as no end is set or the end is still ahead
    return self.end is None or index < self.end

update_event(state, event, **kwargs)

Update modifier based on the given event. In turn calls on_event, then on_start, on_update, and on_end based on the event and modifier settings. Raises a RuntimeError if the modifier is not initialized.

Parameters:

Name Type Description Default
state State

The current state of sparsification

required
event Event

The event to update the modifier with

required
kwargs

Additional arguments for updating the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has not been initialized or has been finalized

Source code in llmcompressor/modifiers/modifier.py
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update modifier based on the given event. In turn calls
    on_start, on_update, and on_end based on the event and
    modifier settings.

    `on_event` is invoked first for every event. A BATCH_START event for
    which `should_start` holds triggers `on_start`, and a BATCH_END event
    for which `should_end` holds triggers `on_end`; each is followed by one
    `on_update` call. For any other event `on_update` is called only while
    the modifier has started and not yet ended.

    :raises RuntimeError: if the modifier has not been initialized or
        has already been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")

    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")

    # subclasses get to see every event, regardless of start/end state
    self.on_event(state, event, **kwargs)

    # handle starting the modifier if needed
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        # flag is set after on_start so the hook runs with started_ unchanged
        self.started_ = True
        self.on_update(state, event, **kwargs)

        return

    # handle ending the modifier if needed
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        # one last update is issued for the event that ended the modifier
        self.on_update(state, event, **kwargs)

        return

    # regular update while the modifier is active
    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)

NormMapping

Bases: BaseModel

SpinQuant needs to know where every norm layer exists in the model, as well as all the subsequent Linear layers the norm passes into. This is because the norm layer weights need to be normalized before transforms can be fused into Linear layers.

Parameters:

Name Type Description Default
norm

name or regex that matches norm layer in model

required
linears

list of names or regexes of Linear layers that receive input from norm.

required
Source code in llmcompressor/modifiers/transform/spinquant/norm_mappings.py
class NormMapping(BaseModel):
    """
    Pairs one norm layer with the Linear layers that consume its output.

    SpinQuant needs to know where every norm layer exists in the model,
    as well as all the subsequent Linear layers the norm passes into.
    This is because the norm layer weights need to be normalized before
    transforms can be fused into Linear layers.

    :param norm: name or regex that matches norm layer in model
    :param linears: list of names or regexes of Linear layers that
        receive input from norm. A single string is also accepted and is
        wrapped into a one-element list
    """

    norm: str
    linears: List[str]

    @field_validator("linears", mode="before")
    @classmethod
    def cast_to_list(cls, value):
        """
        Allow `linears="name"` as shorthand for `linears=["name"]`.

        :param value: raw value supplied for `linears`
        :return: `[value]` if a string was given, otherwise `value` unchanged
        """
        if isinstance(value, str):
            return [value]

        return value

SpinQuantMapping

Bases: BaseModel

SpinQuant needs to know the entire architecture of the model, as R1, R2, R3, and R4 rotations need to be applied to specific layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

Parameters:

Name Type Description Default
embedding

name or regex of embedding layer

required
attn_q

name or regex of q_proj layer in attention block

required
attn_k

name or regex of k_proj layer in attention block

required
attn_v

name or regex of v_proj layer in attention block

required
attn_o

name or regex of o_proj layer in attention block

required
attn_head_dim

head_dim of the attention module, needed because R2 needs to be applied "head-wisely" to v_proj and o_proj

required
mlp_in

list of names or regexes for the mlp blocks that receive the input to the MLP block, usually up_proj and gate_proj

required
mlp_out

list of names or regexes for the mlp blocks that constitute the output of the MLP block, usually down_proj

required
Source code in llmcompressor/modifiers/transform/spinquant/mappings.py
class SpinQuantMapping(BaseModel):
    """
    SpinQuant needs to know the entire architecture of the model,
    as R1, R2, R3, and R4 rotations need to be applied to specific
    layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

    :param embedding: name or regex of embedding layer
    :param attn_q: name or regex of q_proj layer in attention block
    :param attn_k: name or regex of k_proj layer in attention block
    :param attn_v: name or regex of v_proj layer in attention block
    :param attn_o: name or regex of o_proj layer in attention block
    :param attn_head_dim: head_dim of the attention module, needed
        because R2 needs to be applied "head-wisely" to v_proj and
        o_proj. Optional, defaults to None
    :param mlp_in: list of names or regexes for the mlp blocks that
        receive the input to the MLP block, usually up_proj and gate_proj.
        A single string is wrapped into a one-element list
    :param mlp_out: list of names or regexes for the mlp blocks that
        constitute the output of the MLP block, usually down_proj.
        A single string is wrapped into a one-element list
    :param lm_head: name or regex of lm_head layer
    """

    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]  # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    @classmethod
    def cast_to_list(cls, value):
        """
        Allow a single name/regex string wherever a list is expected.

        :param value: raw value supplied for `mlp_in` or `mlp_out`
        :return: `[value]` if a string was given, otherwise `value` unchanged
        """
        if isinstance(value, str):
            return [value]

        return value

SpinQuantModifier

Bases: Modifier

Implements the transforms according to "SpinQuant: LLM quantization with learned rotations" (https://arxiv.org/abs/2405.16406)

Transforms (rotations) are extra layers added to a model which reduce the accuracy loss induced by quantization. This is achieved through "rotating" weights and activations into a space with a smaller dynamic range of values, thus decreasing the range of scales required for quantization.

The SpinQuant authors describe four different rotations which can be applied to a model. R1 and R2 are "offline" rotations, meaning that they can be fused into existing weights and therefore do not induce runtime cost. R3 and R4 are "online" rotations, meaning that they require additional computation at runtime.

Lifecycle:

- on_initialize
    - infer SpinQuantMappings & NormMappings
    - as needed, create transform schemes for R1, R2, R3, & R4
- on_start
    - normalize embeddings
    - fuse norm layers into subsequent Linear layers
    - apply TransformConfig
        - fuse transforms into weights for mergeable transforms
        - add hooks for online transforms
- on sequential epoch end
- on_end
- on_finalize

Parameters:

Name Type Description Default
rotations

A list containing the names of rotations to apply to the model. Possible rotations include R1, R2, R3, and R4

required
transform_type

The type of transform to apply to the model. "hadamard" has the least performance cost but only supports sizes which are powers of two. "random-hadamard" has more performance cost, but supports a much larger set of sizes. "random-matrix" has the greatest performance cost, but supports any size

required
randomize

if True, create distinct transforms for each application

required
learnable

if True, attach gradients to transform weights for training

required
precision

Precision at which all transforms should be applied. This applies to both weight fusing and online rotations

required
mappings

Specifies layers within a model to target for transforms. A mapping will be inferred if None is provided

required
norm_mappings

Specifies layers within a model to target for norm fusing. A mapping will be inferred if None is provided

required
transform_config

Optional transform config for overriding provided arguments

required
Source code in llmcompressor/modifiers/transform/spinquant/base.py
class SpinQuantModifier(Modifier, use_enum_values=True):
    """
    Implements the transforms according to "SpinQuant: LLM quantization
    with learned rotations" (https://arxiv.org/abs/2405.16406)

    Transforms (rotations) are extra layers added to a model which reduce the accuracy
    loss induced by quantization. This is achieved through "rotating" weights and
    activations into a space with a smaller dynamic range of values, thus decreasing
    the range of scales required for quantization.

    The SpinQuant authors describe four different rotations which can be applied to a
    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
    rotations, meaning that they require additional computation at runtime.

    Lifecycle:
        - on_initialize
            - infer SpinQuantMappings & NormMappings (only those not provided)
            - as needed, create transform schemes for R1, R2, R3, & R4
        - on_start
            - center embeddings (zero mean)
            - fuse norm layers into subsequent Linear layers
            - apply TransformConfig
                - fuse transforms into weights for mergeable transforms
                - add hooks for online transforms
        - on sequential epoch end
        - on_end
        - on_finalize

    :param rotations: A list containing the names of rotations to apply to the model.
        Possible rotations include R1, R2, R3, and R4. Names are upper-cased before
        validation and a single name may be given as a plain string
    :param transform_type: The type of transform to apply to the model.
        `"hadamard"` has the least performance cost but only supports sizes which are
        powers of two.
        `"random-hadamard"` has more performance cost, but supports a much larger set
        of sizes.
        `"random-matrix"` has the greatest performance cost, but supports any size
    :param randomize: if True, create distinct transforms for each application.
        Currently rejected with NotImplementedError
    :param learnable: if True, attach gradients to transform weights for training.
        Currently rejected with NotImplementedError
    :param precision: Precision at which all transforms should be applied. This applies
        to both weight fusing and online rotations
    :param mappings: Specifies layers within a model to target for transforms.
        A mapping will be inferred if None is provided
    :param norm_mappings: Specifies layers within a model to target for norm fusing.
        A mapping will be inferred if None is provided
    :param transform_config: Optional transform config for overriding provided arguments
    """

    rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])
    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
        default="hadamard"
    )
    randomize: bool = Field(default=False)
    learnable: bool = Field(default=False)
    precision: TorchDtype = Field(default=torch.float64)

    # norm mappings separate from spinquant mappings to allow users to
    # override spinquant mappings with transform_config without overriding norms
    mappings: Optional[SpinQuantMapping] = Field(
        default=None,
        repr=False,
        exclude=True,
    )
    norm_mappings: Optional[List[NormMapping]] = Field(
        default=None,
        repr=False,
        exclude=True,
    )

    # optional override for more fine-grained control
    # also included in recipe serialization
    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)

    @field_validator("randomize", "learnable", mode="before")
    @classmethod
    def validate_not_implemented(cls, value, info: ValidationInfo):
        """
        Reject any truthy value for options that are not implemented yet.

        :raises NotImplementedError: if the option is enabled
        """
        if value:
            raise NotImplementedError(f"{info.field_name} is not supported right now")
        return value

    @field_validator("rotations", mode="before")
    @classmethod
    def validate_rotations(cls, value):
        """
        Upper-case rotation names so that e.g. "r1" is accepted for "R1".

        A bare string is treated as a single rotation name rather than being
        iterated character by character.
        """
        if isinstance(value, str):
            value = [value]
        if isinstance(value, Iterable):
            return tuple(v.upper() for v in value)
        return value

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Resolve layer mappings and, unless a transform config was supplied,
        build one from the requested rotations.

        :param state: state holding the model to be transformed
        :return: True, signalling successful initialization
        """
        # on_start reads both mappings (embedding centering, norm fusing) even
        # when the transform config is user-provided, so resolve them first.
        # Mappings passed in by the user are kept; only missing ones are inferred
        if self.mappings is None:
            self.mappings = infer_mapping_from_model(state.model)
        if self.norm_mappings is None:
            self.norm_mappings = infer_norm_mapping_from_model(state.model)

        if self.transform_config is not None:
            return True

        config_groups = {}
        if SpinquantRotation.R1 in self.rotations:
            config_groups["R1"] = self._create_r1_scheme()

        if SpinquantRotation.R2 in self.rotations:
            config_groups["R2"] = self._create_r2_scheme(state.model)

        if SpinquantRotation.R3 in self.rotations:
            config_groups["R3"] = self._create_r3_scheme()

        if SpinquantRotation.R4 in self.rotations:
            config_groups["R4"] = self._create_r4_scheme()

        self.transform_config = TransformConfig(config_groups=config_groups)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        """
        Center embeddings, fuse norms into their Linear layers and apply the
        transform config to the model.
        """
        self.started_ = True

        # needs to happen after the model has been hooked to execute on the GPU
        # otherwise we're applying weight transforms on CPU
        self._center_embeddings(state.model)
        self._fuse_norms(state.model)
        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        """
        Start on CALIBRATION_EPOCH_START and end on CALIBRATION_EPOCH_END,
        each at most once. SEQUENTIAL_EPOCH_END is a no-op.
        """
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        """Mark the modifier as ended."""
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Make sure `on_end` has run before finalization.

        :return: True, signalling successful finalization
        """
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _center_embeddings(self, model: PreTrainedModel):
        """Center every module matched by the embedding mapping."""
        for _, embedding in match_named_modules(
            model, [self.mappings.embedding], warn_on_fail=True
        ):
            center_embeddings(embedding)

    def _fuse_norms(self, model: PreTrainedModel):
        """Fuse each mapped norm layer into the Linear layers that follow it."""
        for mapping in self.norm_mappings:
            for norm, *linears in match_modules_set(
                model, (mapping.norm, *mapping.linears)
            ):
                fuse_norm_linears(norm, linears)

    def _create_r1_scheme(self) -> TransformScheme:
        """
        Build the R1 scheme: the transform is applied on the output side of
        the embedding, attn_o and mlp_out weights, and its inverse on the
        input side of attn_q/k/v, mlp_in and lm_head weights.
        """
        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            precision=self.precision,
            apply=[
                TransformArgs(
                    targets=[
                        self.mappings.embedding,
                        self.mappings.attn_o,
                        *self.mappings.mlp_out,
                    ],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=[
                        self.mappings.attn_q,
                        self.mappings.attn_k,
                        self.mappings.attn_v,
                        *self.mappings.mlp_in,
                        self.mappings.lm_head,
                    ],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
        """
        Build the R2 scheme: a head_dim-sized transform on the output side of
        attn_v and its inverse on the input side of attn_o.

        :param model: model whose config provides the attention head size
        :raises NotImplementedError: if the head size cannot be determined
            from the model config
        """
        config = model.config

        # a `head_dim` attribute that exists but is unset (None) falls back
        # to the value computed from hidden_size and num_attention_heads
        # NOTE(review): `self.mappings.attn_head_dim` is not consulted here
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            hidden_size = getattr(config, "hidden_size", None)
            num_heads = getattr(config, "num_attention_heads", None)
            if hidden_size is None or num_heads is None:
                raise NotImplementedError(
                    "Cannot infer attention head_dim for R2: model config defines "
                    "neither `head_dim` nor both `hidden_size` and "
                    "`num_attention_heads`"
                )
            head_dim = hidden_size // num_heads

        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            precision=self.precision,
            head_dim=head_dim,
            apply=[
                TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                TransformArgs(
                    targets=[self.mappings.attn_o],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r3_scheme(self) -> TransformScheme:
        """R3 is not available yet; always raises NotImplementedError."""
        raise NotImplementedError(
            "SpinQuant R3 and R4 rotations will be added in a future release"
        )

    def _create_r4_scheme(self) -> TransformScheme:
        """R4 is not available yet; always raises NotImplementedError."""
        raise NotImplementedError(
            "SpinQuant R3 and R4 rotations will be added in a future release"
        )

State dataclass

State class holds information about the current compression state.

Parameters:

Name Type Description Default
model Any

The model being used for compression

None
teacher_model Any

The teacher model being used for compression

None
optimizer Any

The optimizer being used for training

None
optim_wrapped bool

Whether or not the optimizer has been wrapped

None
loss Any

The loss function being used for training

None
batch_data Any

The current batch of data being used for compression

None
data Data

The data sets being used for training, validation, testing, and/or calibration, wrapped in a Data instance

Data()
hardware Hardware

Hardware instance holding info about the target hardware being used

Hardware()
loggers Optional[LoggerManager]

LoggerManager instance holding all the loggers to log

None
model_log_cadence Optional[float]

The cadence to log model information w.r.t epochs. If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.

None
Source code in llmcompressor/core/state.py
@dataclass
class State:
    """
    Container for the objects and settings of the current compression run.

    :param model: The model being used for compression
    :type model: Any
    :param teacher_model: The teacher model being used for compression
    :type teacher_model: Any
    :param optimizer: The optimizer being used for training
    :type optimizer: Any
    :param optim_wrapped: Whether or not the optimizer has been wrapped
    :type optim_wrapped: bool
    :param loss: The loss function being used for training
    :type loss: Any
    :param batch_data: The current batch of data being used for compression
    :type batch_data: Any
    :param data: The data sets being used for training, validation, testing,
        and/or calibration, wrapped in a Data instance
    :type data: Data
    :param hardware: Hardware instance holding info about the target hardware being used
    :type hardware: Hardware
    :param loggers: LoggerManager instance holding all the loggers to log
    :type loggers: Optional[LoggerManager]
    :param model_log_cadence: The cadence to log model information w.r.t epochs.
        If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
    :type model_log_cadence: Optional[float]
    """

    model: Any = None
    teacher_model: Any = None
    optimizer: Any = None
    optim_wrapped: bool = None
    loss: Any = None
    batch_data: Any = None
    data: Data = field(default_factory=Data)
    hardware: Hardware = field(default_factory=Hardware)
    loggers: Optional[LoggerManager] = None
    model_log_cadence: Optional[float] = None
    _last_log_step: Union[float, int, None] = None

    @property
    def compression_ready(self) -> bool:
        """
        Report whether both a model and an optimizer have been set.

        :return: True if model and optimizer are set, False otherwise
        :rtype: bool
        """
        is_ready = not (self.model is None or self.optimizer is None)
        logger.debug("Compression ready: {}", is_ready)
        return is_ready

    def update(
        self,
        model: Any = None,
        teacher_model: Any = None,
        optimizer: Any = None,
        attach_optim_callbacks: bool = True,
        train_data: Any = None,
        val_data: Any = None,
        test_data: Any = None,
        calib_data: Any = None,
        copy_data: bool = True,
        start: float = None,
        steps_per_epoch: int = None,
        batches_per_step: int = None,
        loggers: Union[None, LoggerManager, List[BaseLogger]] = None,
        model_log_cadence: Optional[float] = None,
        **kwargs,
    ) -> Dict:
        """
        Store every argument that was actually provided on the state.

        Arguments left as None keep the state's current value, with one
        exception: the logger manager is replaced on every call.

        :param model: The model to update the state with
        :type model: Any
        :param teacher_model: The teacher model to update the state with
        :type teacher_model: Any
        :param optimizer: The optimizer to update the state with
        :type optimizer: Any
        :param attach_optim_callbacks: Stored as `optim_wrapped` whenever an
            optimizer is provided
        :type attach_optim_callbacks: bool
        :param train_data: The training data to update the state with
        :type train_data: Any
        :param val_data: The validation data to update the state with
        :type val_data: Any
        :param test_data: The testing data to update the state with
        :type test_data: Any
        :param calib_data: The calibration data to update the state with
        :type calib_data: Any
        :param copy_data: Whether or not to deep-copy the data sets before
            storing them
        :type copy_data: bool
        :param start: The start index; only included in the debug log by
            this method
        :type start: float
        :param steps_per_epoch: The steps per epoch; only included in the
            debug log by this method
        :type steps_per_epoch: int
        :param batches_per_step: The batches per step; only included in the
            debug log by this method
        :type batches_per_step: int
        :param loggers: The metrics manager to setup logging important info and
            milestones to, also accepts a list of BaseLogger(s). A missing or
            empty value results in a LoggerManager built from an empty list
        :type loggers: Union[None, LoggerManager, List[BaseLogger]]
        :param model_log_cadence: The cadence to log model information w.r.t epochs.
            If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
        :type model_log_cadence: Optional[float]
        :param kwargs: Additional keyword arguments; a `device` entry is stored
            on the hardware info
        :return: The additional keyword arguments (`kwargs`) that were passed in
        :rtype: Dict
        """
        provided = {
            "model": model,
            "teacher_model": teacher_model,
            "optimizer": optimizer,
            "attach_optim_callbacks": attach_optim_callbacks,
            "train_data": train_data,
            "val_data": val_data,
            "test_data": test_data,
            "calib_data": calib_data,
            "copy_data": copy_data,
            "start": start,
            "steps_per_epoch": steps_per_epoch,
            "batches_per_step": batches_per_step,
            "loggers": loggers,
            "model_log_cadence": model_log_cadence,
            "kwargs": kwargs,
        }
        logger.debug("Updating state with provided parameters: {}", provided)

        if model is not None:
            self.model = model
        if teacher_model is not None:
            self.teacher_model = teacher_model
        if optimizer is not None:
            self.optim_wrapped = attach_optim_callbacks
            self.optimizer = optimizer

        # data sets are stored under the attribute named after their split
        splits = (
            ("train", train_data),
            ("val", val_data),
            ("test", test_data),
            ("calib", calib_data),
        )
        for split, dataset in splits:
            if dataset is None:
                continue
            setattr(self.data, split, deepcopy(dataset) if copy_data else dataset)

        if "device" in kwargs:
            self.hardware.device = kwargs["device"]

        # plain lists (including the empty fallback) are wrapped in a manager
        manager = loggers or []
        if isinstance(manager, list):
            manager = LoggerManager(manager)
        self.loggers = manager

        if model_log_cadence is not None:
            self.model_log_cadence = model_log_cadence

        return kwargs

compression_ready property

Check if the model and optimizer are set for compression.

Returns:

Type Description
bool

True if model and optimizer are set, False otherwise

update(model=None, teacher_model=None, optimizer=None, attach_optim_callbacks=True, train_data=None, val_data=None, test_data=None, calib_data=None, copy_data=True, start=None, steps_per_epoch=None, batches_per_step=None, loggers=None, model_log_cadence=None, **kwargs)

Update the state with the given parameters.

Parameters:

Name Type Description Default
model Any

The model to update the state with

None
teacher_model Any

The teacher model to update the state with

None
optimizer Any

The optimizer to update the state with

None
attach_optim_callbacks bool

Whether or not to attach optimizer callbacks

True
train_data Any

The training data to update the state with

None
val_data Any

The validation data to update the state with

None
test_data Any

The testing data to update the state with

None
calib_data Any

The calibration data to update the state with

None
copy_data bool

Whether or not to copy the data

True
start float

The start index to update the state with

None
steps_per_epoch int

The steps per epoch to update the state with

None
batches_per_step int

The batches per step to update the state with

None
loggers Union[None, LoggerManager, List[BaseLogger]]

The metrics manager to setup logging important info and milestones to, also accepts a list of BaseLogger(s)

None
model_log_cadence Optional[float]

The cadence to log model information w.r.t epochs. If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.

None
kwargs

Additional keyword arguments to update the state with

{}

Returns:

Type Description
Dict

The additional keyword arguments (`kwargs`) that were passed in, returned as a dictionary

Source code in llmcompressor/core/state.py
def update(
    self,
    model: Any = None,
    teacher_model: Any = None,
    optimizer: Any = None,
    attach_optim_callbacks: bool = True,
    train_data: Any = None,
    val_data: Any = None,
    test_data: Any = None,
    calib_data: Any = None,
    copy_data: bool = True,
    start: float = None,
    steps_per_epoch: int = None,
    batches_per_step: int = None,
    loggers: Union[None, LoggerManager, List[BaseLogger]] = None,
    model_log_cadence: Optional[float] = None,
    **kwargs,
) -> Dict:
    """
    Store every argument that was actually provided on the state.

    Arguments left as None keep the state's current value, with one
    exception: the logger manager is replaced on every call.

    :param model: The model to update the state with
    :type model: Any
    :param teacher_model: The teacher model to update the state with
    :type teacher_model: Any
    :param optimizer: The optimizer to update the state with
    :type optimizer: Any
    :param attach_optim_callbacks: Stored as `optim_wrapped` whenever an
        optimizer is provided
    :type attach_optim_callbacks: bool
    :param train_data: The training data to update the state with
    :type train_data: Any
    :param val_data: The validation data to update the state with
    :type val_data: Any
    :param test_data: The testing data to update the state with
    :type test_data: Any
    :param calib_data: The calibration data to update the state with
    :type calib_data: Any
    :param copy_data: Whether or not to deep-copy the data sets before
        storing them
    :type copy_data: bool
    :param start: The start index; only included in the debug log by
        this method
    :type start: float
    :param steps_per_epoch: The steps per epoch; only included in the
        debug log by this method
    :type steps_per_epoch: int
    :param batches_per_step: The batches per step; only included in the
        debug log by this method
    :type batches_per_step: int
    :param loggers: The metrics manager to setup logging important info and
        milestones to, also accepts a list of BaseLogger(s). A missing or
        empty value results in a LoggerManager built from an empty list
    :type loggers: Union[None, LoggerManager, List[BaseLogger]]
    :param model_log_cadence: The cadence to log model information w.r.t epochs.
        If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
    :type model_log_cadence: Optional[float]
    :param kwargs: Additional keyword arguments; a `device` entry is stored
        on the hardware info
    :return: The additional keyword arguments (`kwargs`) that were passed in
    :rtype: Dict
    """
    provided = {
        "model": model,
        "teacher_model": teacher_model,
        "optimizer": optimizer,
        "attach_optim_callbacks": attach_optim_callbacks,
        "train_data": train_data,
        "val_data": val_data,
        "test_data": test_data,
        "calib_data": calib_data,
        "copy_data": copy_data,
        "start": start,
        "steps_per_epoch": steps_per_epoch,
        "batches_per_step": batches_per_step,
        "loggers": loggers,
        "model_log_cadence": model_log_cadence,
        "kwargs": kwargs,
    }
    logger.debug("Updating state with provided parameters: {}", provided)

    if model is not None:
        self.model = model
    if teacher_model is not None:
        self.teacher_model = teacher_model
    if optimizer is not None:
        self.optim_wrapped = attach_optim_callbacks
        self.optimizer = optimizer

    # data sets are stored under the attribute named after their split
    splits = (
        ("train", train_data),
        ("val", val_data),
        ("test", test_data),
        ("calib", calib_data),
    )
    for split, dataset in splits:
        if dataset is None:
            continue
        setattr(self.data, split, deepcopy(dataset) if copy_data else dataset)

    if "device" in kwargs:
        self.hardware.device = kwargs["device"]

    # plain lists (including the empty fallback) are wrapped in a manager
    manager = loggers or []
    if isinstance(manager, list):
        manager = LoggerManager(manager)
    self.loggers = manager

    if model_log_cadence is not None:
        self.model_log_cadence = model_log_cadence

    return kwargs

center_embeddings(embedding)

Shift each embedding to have a mean of zero

Parameters:

Name Type Description Default
embedding Module

embedding module containing embeddings to center

required
Source code in llmcompressor/modeling/fuse.py
def center_embeddings(embedding: torch.nn.Module):
    """
    Shift each embedding to have a mean of zero

    The mean is taken over the last dimension of `embedding.weight`. The
    subtraction is carried out in `PRECISION` and the result is cast back to
    the weight's original dtype before being written to the module.

    :param embedding: embedding module containing embeddings to center
    :raises ValueError: if `embedding` has no `weight` attribute
    """
    if not hasattr(embedding, "weight"):
        raise ValueError(f"Cannot center embeddings of type {type(embedding)}")

    with align_module_device(embedding):
        weight_dtype = embedding.weight.dtype
        weight = embedding.weight.to(PRECISION)
        new_weight = weight - weight.mean(dim=-1, keepdim=True)
        new_weight = new_weight.to(weight_dtype)

    update_offload_parameter(embedding, "weight", new_weight)

fuse_norm_linears(norm, linears)

Fuse the scaling operation of norm layer into subsequent linear layers. This is useful for ensuring transform invariance between norm and linear layers.

Note that unitary transforms (rotation) commute with normalization, but not scaling

Parameters:

Name Type Description Default
norm Module

norm layer whose weight will be fused into subsequent linears

required
linears Iterable[Linear]

linear layers which directly follow the norm layer

required
Source code in llmcompressor/modeling/fuse.py
def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
    """
    Fold the norm layer's scale into the linear layers that follow it.
    This is useful for ensuring transform invariance between norm and
    linear layers.

    Each linear weight is multiplied by the norm weight in `PRECISION` and
    cast back to its original dtype; afterwards the norm weight is replaced
    by ones.

    Note that unitary transforms (rotation) commute with normalization, but not scaling

    :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
    :raises ValueError: if `norm` has no `weight` attribute
    """
    if not hasattr(norm, "weight"):
        raise ValueError(f"Cannot fuse norm of type {type(norm)}")

    for linear in linears:
        # NOTE: spinquant does this op in float64
        device = get_execution_device(norm)
        with align_module_device(norm, device):
            with align_module_device(linear, device):
                original_dtype = linear.weight.dtype
                upcast = linear.weight.to(PRECISION)
                fused = (upcast * norm.weight.to(PRECISION)).to(original_dtype)

        update_offload_parameter(linear, "weight", fused)

    # the scale now lives in the linears, so the norm keeps only a unit weight
    update_offload_parameter(
        norm, "weight", torch.ones_like(norm.weight, device="cpu")
    )