llmcompressor.modifiers.transform.spinquant

Classes:

  • Event

    A class for defining an event that can be triggered during sparsification.

  • EventType

    An Enum for defining the different types of events that can be triggered

  • Modifier

    A base class for all modifiers to inherit from.

  • NormMapping

    SpinQuant needs to know where every norm layer exists in the model,

  • SpinQuantMapping

    SpinQuant needs to know the entire architecture of the model,

  • SpinQuantModifier

    Implements the transforms according to "SpinQuant: LLM quantization

  • State

    State class holds information about the current compression state.

Functions:

  • center_embeddings

    Shift each embedding to have a mean of zero.

  • fuse_norm_linears

    Fuse the scaling operation of a norm layer into subsequent linear layers.

Event dataclass

Event(
    type_: Optional[EventType] = None,
    steps_per_epoch: Optional[int] = None,
    batches_per_step: Optional[int] = None,
    invocations_per_step: int = 1,
    global_step: int = 0,
    global_batch: int = 0,
)

A class for defining an event that can be triggered during sparsification.

Parameters:

  • type_

    (Optional[EventType], default: None ) –

    The type of event.

  • steps_per_epoch

    (Optional[int], default: None ) –

    The number of steps per epoch.

  • batches_per_step

    (Optional[int], default: None ) –

    The number of batches per step where step is an optimizer step invocation. For most pathways, these are the same. See the invocations_per_step parameter for more details when they are not.

  • invocations_per_step

    (int, default: 1 ) –

    The number of invocations of the step wrapper before optimizer.step was called. Generally can be left as 1 (default). For older amp pathways, this is the number of times the scaler wrapper was invoked before the wrapped optimizer step function was called to handle accumulation in fp16.

  • global_step

    (int, default: 0 ) –

    The current global step.

  • global_batch

    (int, default: 0 ) –

    The current global batch.

Methods:

  • new_instance

    Creates a new instance of the event with the provided keyword arguments.

  • should_update

    Determines if the event should trigger an update.

Attributes:

  • current_index (float) –

    Calculates the current index of the event.

  • epoch (int) –

    Calculates the current epoch.

  • epoch_based (bool) –

    Determines if the event is based on epochs.

  • epoch_batch (int) –

    Calculates the current batch within the current epoch.

  • epoch_full (float) –

    Calculates the current epoch with the fraction of the current step.

  • epoch_step (int) –

    Calculates the current step within the current epoch.

current_index property writable

current_index: float

Calculates the current index of the event.

Returns:

  • float

    The current index of the event, which is either the global step or the epoch with the fraction of the current step.

Raises:

  • ValueError

    if the event is not epoch based or if the steps per epoch are too many.

epoch property

epoch: int

Calculates the current epoch.

Returns:

  • int

    The current epoch.

Raises:

  • ValueError

    if the event is not epoch based.

epoch_based property

epoch_based: bool

Determines if the event is based on epochs.

Returns:

  • bool

    True if the event is based on epochs, False otherwise.

epoch_batch property

epoch_batch: int

Calculates the current batch within the current epoch.

Returns:

  • int

    The current batch within the current epoch.

Raises:

  • ValueError

    if the event is not epoch based.

epoch_full property

epoch_full: float

Calculates the current epoch with the fraction of the current step.

Returns:

  • float

    The current epoch with the fraction of the current step.

Raises:

  • ValueError

    if the event is not epoch based.

epoch_step property

epoch_step: int

Calculates the current step within the current epoch.

Returns:

  • int

    The current step within the current epoch.

Raises:

  • ValueError

    if the event is not epoch based.

new_instance

new_instance(**kwargs) -> Event

Creates a new instance of the event with the provided keyword arguments.

Parameters:

  • kwargs

    Keyword arguments to set in the new instance.

Returns:

  • Event

    A new instance of the event with the provided kwargs.

Source code in llmcompressor/core/events/event.py
def new_instance(self, **kwargs) -> "Event":
    """
    Creates a new instance of the event with the provided keyword arguments.

    :param kwargs: Keyword arguments to set in the new instance.
    :return: A new instance of the event with the provided kwargs.
    :rtype: Event
    """
    logger.debug("Creating new instance of event with kwargs: {}", kwargs)
    instance = deepcopy(self)
    for key, value in kwargs.items():
        setattr(instance, key, value)
    return instance
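
A brief usage sketch of the copy-and-override behavior above (the import path is an assumption based on the source location shown; Event and EventType are defined in llmcompressor/core/events/event.py):

from llmcompressor.core import Event, EventType  # assumed re-export path

start_event = Event(type_=EventType.BATCH_START, global_step=10)
end_event = start_event.new_instance(type_=EventType.BATCH_END)

assert end_event.global_step == 10                  # copied via deepcopy
assert end_event.type_ == EventType.BATCH_END       # overridden by kwargs
assert start_event.type_ == EventType.BATCH_START   # original is left untouched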

should_update

should_update(
    start: Optional[float],
    end: Optional[float],
    update: Optional[float],
) -> bool

Determines if the event should trigger an update.

Parameters:

  • start

    (Optional[float]) –

    The start index to check against, set to None to ignore start.

  • end

    (Optional[float]) –

    The end index to check against, set to None to ignore end.

  • update

    (Optional[float]) –

    The update interval, set to None or 0.0 to always update, otherwise must be greater than 0.0, defaults to None.

Returns:

  • bool

    True if the event should trigger an update, False otherwise.

Source code in llmcompressor/core/events/event.py
def should_update(
    self, start: Optional[float], end: Optional[float], update: Optional[float]
) -> bool:
    """
    Determines if the event should trigger an update.

    :param start: The start index to check against, set to None to ignore start.
    :type start: Optional[float]
    :param end: The end index to check against, set to None to ignore end.
    :type end: Optional[float]
    :param update: The update interval, set to None or 0.0 to always update,
        otherwise must be greater than 0.0, defaults to None.
    :type update: Optional[float]
    :return: True if the event should trigger an update, False otherwise.
    :rtype: bool
    """
    current = self.current_index
    logger.debug(
        "Checking if event should update: "
        "current_index={}, start={}, end={}, update={}",
        current,
        start,
        end,
        update,
    )
    if start is not None and current < start:
        return False
    if end is not None and current > end:
        return False
    return update is None or update <= 0.0 or current % update < 1e-10
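
For intuition, here is a minimal standalone restatement of the same windowing check (plain Python, not a call into llmcompressor), followed by a few spot checks:

from typing import Optional

def in_update_window(
    current: float,
    start: Optional[float],
    end: Optional[float],
    update: Optional[float],
) -> bool:
    # Mirrors Event.should_update above: outside [start, end] never update;
    # a missing or non-positive interval means "always update"; otherwise
    # update when current lands (approximately) on a multiple of the interval.
    if start is not None and current < start:
        return False
    if end is not None and current > end:
        return False
    return update is None or update <= 0.0 or current % update < 1e-10

assert in_update_window(10.0, start=5.0, end=20.0, update=5.0)      # on the interval
assert not in_update_window(12.0, start=5.0, end=20.0, update=5.0)  # between intervals
assert not in_update_window(3.0, start=5.0, end=None, update=None)  # before start
assert in_update_window(7.0, start=5.0, end=20.0, update=None)      # no interval: always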

EventType

Bases: Enum

An Enum for defining the different types of events that can be triggered during model compression lifecycles. The purpose of each EventType is to trigger the corresponding modifier callback during training or post training pipelines.

Parameters:

  • INITIALIZE

    Event type for initialization.

  • FINALIZE

    Event type for finalization.

  • BATCH_START

    Event type for the start of a batch.

  • LOSS_CALCULATED

    Event type for when loss is calculated.

  • BATCH_END

    Event type for the end of a batch.

  • CALIBRATION_EPOCH_START

    Event type for the start of a calibration epoch.

  • SEQUENTIAL_EPOCH_END

    Event type for the end of a layer calibration epoch, specifically used by src/llmcompressor/pipelines/sequential/pipeline.py

  • CALIBRATION_EPOCH_END

    Event type for the end of a calibration epoch.

  • OPTIM_PRE_STEP

    Event type for pre-optimization step.

  • OPTIM_POST_STEP

    Event type for post-optimization step.

Modifier

Bases: ModifierInterface, HooksMixin

A base class for all modifiers to inherit from. Modifiers are used to modify the training process for a model. Defines base attributes and methods available to all modifiers

Lifecycle:

1. initialize
2. on_event ->
     • on_start if self.start <= event.current_index
     • on_end if event.current_index >= self.end
3. finalize
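
A minimal sketch of a custom modifier wired into this lifecycle (import paths are assumed; only on_initialize is abstract, the remaining hooks default to no-ops):

from llmcompressor.core import Event, State       # assumed re-export paths
from llmcompressor.modifiers import Modifier

class LoggingModifier(Modifier):
    """Toy modifier that only reports its lifecycle transitions."""

    def on_initialize(self, state: State, **kwargs) -> bool:
        print("initialized")
        return True  # True marks the modifier as initialized

    def on_start(self, state: State, event: Event, **kwargs):
        print(f"started at index {event.current_index}")

    def on_update(self, state: State, event: Event, **kwargs):
        pass  # react to batch events while the modifier is active

    def on_end(self, state: State, event: Event, **kwargs):
        print(f"ended at index {event.current_index}")

    def on_finalize(self, state: State, **kwargs) -> bool:
        print("finalized")
        return True  # True marks the modifier as finalized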

Parameters:

  • index

    The index of the modifier in the list of modifiers for the model

  • group

    The group name for the modifier

  • start

    The start step for the modifier

  • end

    The end step for the modifier

  • update

    The update step for the modifier

Methods:

  • finalize

    Finalize the modifier for the given model and state.

  • initialize

    Initialize the modifier for the given model and state.

  • on_end

    on_end is called when the modifier ends and must be implemented

  • on_event

    on_event is called whenever an event is triggered

  • on_finalize

    on_finalize is called on modifier finalization and

  • on_initialize

    on_initialize is called on modifier initialization and

  • on_start

    on_start is called when the modifier starts and

  • on_update

    on_update is called when the model in question must be

  • should_end

    Determines whether the modifier should end for the given event.

  • should_start

    Determines whether the modifier should start for the given event.

  • update_event

    Update modifier based on the given event. In turn calls

Attributes:

  • finalized (bool) –

    True if the modifier has been finalized

  • initialized (bool) –

    True if the modifier has been initialized

finalized property

finalized: bool

Returns:

  • bool

    True if the modifier has been finalized

initialized property

initialized: bool

Returns:

  • bool

    True if the modifier has been initialized

finalize

finalize(state: State, **kwargs)

Finalize the modifier for the given model and state.

Parameters:

  • state

    (State) –

    The current state of the model

  • kwargs

    Additional arguments for finalizing the modifier

Raises:

  • RuntimeError

    if the modifier has not been initialized

Source code in llmcompressor/modifiers/modifier.py
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier for the given model and state.

    :raises RuntimeError: if the modifier has not been initialized
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")

    if not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")

    # TODO: all finalization should succeed
    self.finalized_ = self.on_finalize(state=state, **kwargs)

initialize

initialize(state: State, **kwargs)

Initialize the modifier for the given model and state.

Parameters:

  • state

    (State) –

    The current state of the model

  • kwargs

    Additional arguments for initializing the modifier

Raises:

  • RuntimeError

    if the modifier has already been finalized

Source code in llmcompressor/modifiers/modifier.py
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier for the given model and state.

    :raises RuntimeError: if the modifier has already been finalized
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    """
    if self.initialized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been initialized"
        )

    if self.finalized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been finalized"
        )

    self.initialized_ = self.on_initialize(state=state, **kwargs)

    # trigger starts
    fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
    if self.should_start(fake_start_event):
        self.on_start(state, fake_start_event, **kwargs)
        self.started_ = True

on_end

on_end(state: State, event: Event, **kwargs)

on_end is called when the modifier ends and must be implemented by the inheriting modifier.

Parameters:

  • state

    (State) –

    The current state of the model

  • event

    (Event) –

    The event that triggered the end

  • kwargs

    Additional arguments for ending the modifier

Source code in llmcompressor/modifiers/modifier.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    on_end is called when the modifier ends and must be implemented
    by the inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the end
    :param kwargs: Additional arguments for ending the modifier
    """
    pass

on_event

on_event(state: State, event: Event, **kwargs)

on_event is called whenever an event is triggered

Parameters:

  • state

    (State) –

    The current state of the model

  • event

    (Event) –

    The event that triggered the update

  • kwargs

    Additional arguments for updating the model

Source code in llmcompressor/modifiers/modifier.py
def on_event(self, state: State, event: Event, **kwargs):
    """
    on_event is called whenever an event is triggered

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

on_finalize

on_finalize(state: State, **kwargs) -> bool

on_finalize is called on modifier finalization and must be implemented by the inheriting modifier.

Parameters:

  • state

    (State) –

    The current state of the model

  • kwargs

    Additional arguments for finalizing the modifier

Returns:

  • bool

    True if the modifier was finalized successfully, False otherwise

Source code in llmcompressor/modifiers/modifier.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    on_finalize is called on modifier finalization and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    :return: True if the modifier was finalized successfully,
        False otherwise
    """
    return True

on_initialize abstractmethod

on_initialize(state: State, **kwargs) -> bool

on_initialize is called on modifier initialization and must be implemented by the inheriting modifier.

Parameters:

  • state

    (State) –

    The current state of the model

  • kwargs

    Additional arguments for initializing the modifier

Returns:

  • bool

    True if the modifier was initialized successfully, False otherwise

Source code in llmcompressor/modifiers/modifier.py
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    on_initialize is called on modifier initialization and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    :return: True if the modifier was initialized successfully,
        False otherwise
    """
    raise NotImplementedError()

on_start

on_start(state: State, event: Event, **kwargs)

on_start is called when the modifier starts and must be implemented by the inheriting modifier.

Parameters:

  • state

    (State) –

    The current state of the model

  • event

    (Event) –

    The event that triggered the start

  • kwargs

    Additional arguments for starting the modifier

Source code in llmcompressor/modifiers/modifier.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    on_start is called when the modifier starts and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the start
    :param kwargs: Additional arguments for starting the modifier
    """
    pass

on_update

on_update(state: State, event: Event, **kwargs)

on_update is called when the model in question must be updated based on passed in event. Must be implemented by the inheriting modifier.

Parameters:

  • state

    (State) –

    The current state of the model

  • event

    (Event) –

    The event that triggered the update

  • kwargs

    Additional arguments for updating the model

Source code in llmcompressor/modifiers/modifier.py
def on_update(self, state: State, event: Event, **kwargs):
    """
    on_update is called when the model in question must be
    updated based on passed in event. Must be implemented by the
    inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

should_end

should_end(event: Event)

Parameters:

  • event

    (Event) –

    The event to check if the modifier should end

Returns:

  • True if the modifier should end based on the given event

Source code in llmcompressor/modifiers/modifier.py
def should_end(self, event: Event):
    """
    :param event: The event to check if the modifier should end
    :return: True if the modifier should end based on the given event
    """
    current = event.current_index

    return self.end is not None and current >= self.end

should_start

should_start(event: Event) -> bool

Parameters:

  • event

    (Event) –

    The event to check if the modifier should start

Returns:

  • bool

    True if the modifier should start based on the given event

Source code in llmcompressor/modifiers/modifier.py
def should_start(self, event: Event) -> bool:
    """
    :param event: The event to check if the modifier should start
    :return: True if the modifier should start based on the given event
    """
    if self.start is None:
        return False

    current = event.current_index

    return self.start <= current and (self.end is None or current < self.end)

update_event

update_event(state: State, event: Event, **kwargs)

Update the modifier based on the given event. In turn calls on_start, on_update, and on_end based on the event and modifier settings. Raises a RuntimeError if the modifier has not been initialized or has already been finalized.

Parameters:

  • state

    (State) –

    The current state of sparsification

  • event

    (Event) –

    The event to update the modifier with

  • kwargs

    Additional arguments for updating the modifier

Raises:

  • RuntimeError

    if the modifier has not been initialized or has been finalized

Source code in llmcompressor/modifiers/modifier.py
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update modifier based on the given event. In turn calls
    on_start, on_update, and on_end based on the event and
    modifier settings. Returns immediately if the modifier is
    not initialized

    :raises RuntimeError: if the modifier has been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")

    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")

    self.on_event(state, event, **kwargs)

    # handle starting the modifier if needed
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        self.started_ = True
        self.on_update(state, event, **kwargs)

        return

    # handle ending the modifier if needed
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        self.on_update(state, event, **kwargs)

        return

    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)
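
As a hedged sketch of how a training or calibration loop typically drives this method (assuming a modifier configured with start=2 and end=4, a prepared State named state, and that current_index resolves to global_step when no epoch information is set):

for step in range(6):
    modifier.update_event(state, Event(type_=EventType.BATCH_START, global_step=step))
    # ... forward pass, loss, and optimizer step would happen here ...
    modifier.update_event(state, Event(type_=EventType.BATCH_END, global_step=step))
# on_start fires on the BATCH_START event once the index reaches 2,
# on_end fires on the BATCH_END event once the index reaches 4,
# and on_update runs for events in between while the modifier is active.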

NormMapping

Bases: BaseModel

SpinQuant needs to know where every norm layer exists in the model, as well as all the subsequent Linear layers the norm passes into. This is because the norm layer weights need to be normalized before transforms can be fused into Linear layers.

Parameters:

  • norm

    name or regex that matches norm layer in model

  • linears

    list of names or regexes of Linear layers that receive input from norm.
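
A hypothetical mapping for a Llama-style decoder layer (the "re:" regex prefix and module names are illustrative assumptions, not library defaults):

from llmcompressor.modifiers.transform.spinquant import NormMapping  # assumed path

norm_mappings = [
    # input_layernorm feeds the attention projections
    NormMapping(
        norm="re:.*input_layernorm$",
        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
    ),
    # post_attention_layernorm feeds the MLP input projections
    NormMapping(
        norm="re:.*post_attention_layernorm$",
        linears=["re:.*gate_proj$", "re:.*up_proj$"],
    ),
]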

SpinQuantMapping

Bases: BaseModel

SpinQuant needs to know the entire architecture of the model, as R1, R2, R3, and R4 rotations need to be applied to specific layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

Parameters:

  • embedding

    name or regex of embedding layer

  • attn_q

    name or regex of q_proj layer in attention block

  • attn_k

    name or regex of k_proj layer in attention block

  • attn_v

    name or regex of v_proj layer in attention block

  • attn_o

    name or regex of o_proj layer in attention block

  • attn_head_dim

    head_dim of the attention module, needed because R2 needs to be applied "head-wisely" to v_proj and o_proj

  • mlp_in

    list of names or regexes of the Linear layers that receive the MLP block's input, usually up_proj and gate_proj

  • mlp_out

    list of names or regexes of the Linear layers that constitute the MLP block's output, usually down_proj
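
Similarly, a hypothetical full mapping for a Llama-style architecture (regex targets are illustrative assumptions):

from llmcompressor.modifiers.transform.spinquant import SpinQuantMapping  # assumed path

llama_mapping = SpinQuantMapping(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*gate_proj$", "re:.*up_proj$"],
    mlp_out=["re:.*down_proj$"],
)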

SpinQuantModifier

Bases: Modifier

Implements the transforms according to "SpinQuant: LLM quantization with learned rotations" (https://arxiv.org/abs/2405.16406)

Transforms (rotations) are extra layers added to a model which reduce the accuracy loss induced by quantization. This is achieved through "rotating" weights and activations into a space with a smaller dynamic range of values, thus decreasing the range of scales required for quantization.

The SpinQuant authors describe four different rotations which can be applied to a model. R1 and R2 are "offline" rotations, meaning that they can be fused into existing weights and therefore do not induce runtime cost. R3 and R4 are "online" rotations, meaning that they require additional computation at runtime.

Lifecycle:

  • on_initialize
      • infer SpinQuantMappings & NormMappings
      • as needed, create transform schemes for R1, R2, R3, & R4
  • on_start
      • normalize embeddings
      • fuse norm layers into subsequent Linear layers
      • apply TransformConfig
          • fuse transforms into weights for mergeable transforms
          • add hooks for online transforms
  • on sequential epoch end
  • on_end
  • on_finalize

Parameters:

  • rotations

    A list containing the names of rotations to apply to the model. Possible rotations include R1, R2, R3, and R4

  • transform_type

    The type of transform to apply to the model. "hadamard" has the least performance cost but only supports sizes which are powers of two. "random-hadamard" has more performance cost, but supports a much larger set of sizes. "random-matrix" has the greatest performance cost, but supports any size.

  • randomize

    if True, create distinct transforms for each application

  • learnable

    if True, attach gradients to transform weights for training

  • precision

    Precision at which all transforms should be applied. This applies to both weight fusing and online rotations

  • mappings

    Specifies layers within a model to target for transforms. A mapping will be inferred if None is provided

  • norm_mappings

    Specifies layers within a model to target for norm fusing. A mapping will be inferred if None is provided

  • transform_config

    Optional transform config for overriding provided arguments
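
A hedged end-to-end sketch pairing the modifier with a quantization recipe through llmcompressor's oneshot entrypoint (the model name, quantization scheme, and exact rotation/transform strings are assumptions to be checked against the release examples):

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier  # assumed re-export

recipe = [
    # Offline rotations only: R1 and R2 fuse into the weights, adding no runtime cost
    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

oneshot(model="meta-llama/Llama-3.2-1B-Instruct", recipe=recipe)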

State dataclass

State(
    model: Any = None,
    teacher_model: Any = None,
    optimizer: Any = None,
    optim_wrapped: bool = None,
    loss: Any = None,
    batch_data: Any = None,
    data: Data = Data(),
    hardware: Hardware = Hardware(),
    loggers: Optional[LoggerManager] = None,
    model_log_cadence: Optional[float] = None,
    _last_log_step: Union[float, int, None] = None,
)

State class holds information about the current compression state.

Parameters:

  • model

    (Any, default: None ) –

    The model being used for compression

  • teacher_model

    (Any, default: None ) –

    The teacher model being used for compression

  • optimizer

    (Any, default: None ) –

    The optimizer being used for training

  • optim_wrapped

    (bool, default: None ) –

    Whether or not the optimizer has been wrapped

  • loss

    (Any, default: None ) –

    The loss function being used for training

  • batch_data

    (Any, default: None ) –

    The current batch of data being used for compression

  • data

    (Data, default: Data() ) –

    The data sets being used for training, validation, testing, and/or calibration, wrapped in a Data instance

  • hardware

    (Hardware, default: Hardware() ) –

    Hardware instance holding info about the target hardware being used

  • loggers

    (Optional[LoggerManager], default: None ) –

    LoggerManager instance holding all the loggers to log

  • model_log_cadence

    (Optional[float], default: None ) –

    The cadence to log model information w.r.t epochs. If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.

Methods:

  • update

    Update the state with the given parameters.

Attributes:

  • compression_ready (bool) –

    Check if the model and optimizer are set for compression.

compression_ready property

compression_ready: bool

Check if the model and optimizer are set for compression.

Returns:

  • bool

    True if model and optimizer are set, False otherwise

update

update(
    model: Any = None,
    teacher_model: Any = None,
    optimizer: Any = None,
    attach_optim_callbacks: bool = True,
    train_data: Any = None,
    val_data: Any = None,
    test_data: Any = None,
    calib_data: Any = None,
    copy_data: bool = True,
    start: float = None,
    steps_per_epoch: int = None,
    batches_per_step: int = None,
    loggers: Union[
        None, LoggerManager, List[BaseLogger]
    ] = None,
    model_log_cadence: Optional[float] = None,
    **kwargs
) -> Dict

Update the state with the given parameters.

Parameters:

  • model

    (Any, default: None ) –

    The model to update the state with

  • teacher_model

    (Any, default: None ) –

    The teacher model to update the state with

  • optimizer

    (Any, default: None ) –

    The optimizer to update the state with

  • attach_optim_callbacks

    (bool, default: True ) –

    Whether or not to attach optimizer callbacks

  • train_data

    (Any, default: None ) –

    The training data to update the state with

  • val_data

    (Any, default: None ) –

    The validation data to update the state with

  • test_data

    (Any, default: None ) –

    The testing data to update the state with

  • calib_data

    (Any, default: None ) –

    The calibration data to update the state with

  • copy_data

    (bool, default: True ) –

    Whether or not to copy the data

  • start

    (float, default: None ) –

    The start index to update the state with

  • steps_per_epoch

    (int, default: None ) –

    The steps per epoch to update the state with

  • batches_per_step

    (int, default: None ) –

    The batches per step to update the state with

  • loggers

    (Union[None, LoggerManager, List[BaseLogger]], default: None ) –

    The metrics manager to setup logging important info and milestones to, also accepts a list of BaseLogger(s)

  • model_log_cadence

    (Optional[float], default: None ) –

    The cadence to log model information w.r.t epochs. If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.

  • kwargs

    Additional keyword arguments to update the state with

Returns:

  • Dict

    The updated state as a dictionary

Source code in llmcompressor/core/state.py
def update(
    self,
    model: Any = None,
    teacher_model: Any = None,
    optimizer: Any = None,
    attach_optim_callbacks: bool = True,
    train_data: Any = None,
    val_data: Any = None,
    test_data: Any = None,
    calib_data: Any = None,
    copy_data: bool = True,
    start: float = None,
    steps_per_epoch: int = None,
    batches_per_step: int = None,
    loggers: Union[None, LoggerManager, List[BaseLogger]] = None,
    model_log_cadence: Optional[float] = None,
    **kwargs,
) -> Dict:
    """
    Update the state with the given parameters.

    :param model: The model to update the state with
    :type model: Any
    :param teacher_model: The teacher model to update the state with
    :type teacher_model: Any
    :param optimizer: The optimizer to update the state with
    :type optimizer: Any
    :param attach_optim_callbacks: Whether or not to attach optimizer callbacks
    :type attach_optim_callbacks: bool
    :param train_data: The training data to update the state with
    :type train_data: Any
    :param val_data: The validation data to update the state with
    :type val_data: Any
    :param test_data: The testing data to update the state with
    :type test_data: Any
    :param calib_data: The calibration data to update the state with
    :type calib_data: Any
    :param copy_data: Whether or not to copy the data
    :type copy_data: bool
    :param start: The start index to update the state with
    :type start: float
    :param steps_per_epoch: The steps per epoch to update the state with
    :type steps_per_epoch: int
    :param batches_per_step: The batches per step to update the state with
    :type batches_per_step: int
    :param loggers: The metrics manager to setup logging important info and
        milestones to, also accepts a list of BaseLogger(s)
    :type loggers: Union[None, LoggerManager, List[BaseLogger]]
    :param model_log_cadence: The cadence to log model information w.r.t epochs.
        If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1.
    :type model_log_cadence: Optional[float]
    :param kwargs: Additional keyword arguments to update the state with
    :return: The updated state as a dictionary
    :rtype: Dict
    """
    logger.debug(
        "Updating state with provided parameters: {}",
        {
            "model": model,
            "teacher_model": teacher_model,
            "optimizer": optimizer,
            "attach_optim_callbacks": attach_optim_callbacks,
            "train_data": train_data,
            "val_data": val_data,
            "test_data": test_data,
            "calib_data": calib_data,
            "copy_data": copy_data,
            "start": start,
            "steps_per_epoch": steps_per_epoch,
            "batches_per_step": batches_per_step,
            "loggers": loggers,
            "model_log_cadence": model_log_cadence,
            "kwargs": kwargs,
        },
    )

    if model is not None:
        self.model = model
    if teacher_model is not None:
        self.teacher_model = teacher_model
    if optimizer is not None:
        self.optim_wrapped = attach_optim_callbacks
        self.optimizer = optimizer
    if train_data is not None:
        self.data.train = train_data if not copy_data else deepcopy(train_data)
    if val_data is not None:
        self.data.val = val_data if not copy_data else deepcopy(val_data)
    if test_data is not None:
        self.data.test = test_data if not copy_data else deepcopy(test_data)
    if calib_data is not None:
        self.data.calib = calib_data if not copy_data else deepcopy(calib_data)

    if "device" in kwargs:
        self.hardware.device = kwargs["device"]

    loggers = loggers or []
    if isinstance(loggers, list):
        loggers = LoggerManager(loggers)
    self.loggers = loggers

    if model_log_cadence is not None:
        self.model_log_cadence = model_log_cadence

    return kwargs
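
Grounded in the source above, a short usage sketch (the stand-in model and optimizer objects are hypothetical; State is assumed importable from llmcompressor.core):

from llmcompressor.core import State  # assumed re-export path

my_model, my_optimizer = object(), object()  # stand-ins for a real model/optimizer

state = State()
leftover = state.update(
    model=my_model,
    optimizer=my_optimizer,
    device="cuda:0",       # routed to state.hardware.device via kwargs
    custom_flag=True,      # unrecognized kwargs are passed through untouched
)

assert state.model is my_model
assert state.optim_wrapped is True   # attach_optim_callbacks defaults to True
assert state.hardware.device == "cuda:0"
assert leftover == {"device": "cuda:0", "custom_flag": True}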

center_embeddings

center_embeddings(embedding: Module)

Shift each embedding to have a mean of zero

Parameters:

  • embedding

    (Module) –

    embedding module containing embeddings to center

Source code in llmcompressor/modeling/fuse.py
def center_embeddings(embedding: torch.nn.Module):
    """
    Shift each embedding to have a mean of zero

    :param embedding: embedding module containing embeddings to center
    """
    if not hasattr(embedding, "weight"):
        raise ValueError(f"Cannot fuse norm of type {type(embedding)}")

    with align_module_device(embedding):
        weight_dtype = embedding.weight.dtype
        weight = embedding.weight.to(PRECISION)
        new_weight = weight - weight.mean(dim=-1, keepdim=True)
        new_weight = new_weight.to(weight_dtype)

    update_offload_parameter(embedding, "weight", new_weight)

fuse_norm_linears

fuse_norm_linears(norm: Module, linears: Iterable[Linear])

Fuse the scaling operation of a norm layer into the subsequent linear layers. This is useful for ensuring transform invariance between norm and linear layers.

Note that unitary transforms (rotation) commute with normalization, but not scaling

Parameters:

  • norm

    (Module) –

    norm layer whose weight will be fused into subsequent linears

  • linears

    (Iterable[Linear]) –

    linear layers which directly follow the norm layer

Source code in llmcompressor/modeling/fuse.py
def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
    """
    Fuse the scaling operation of norm layer into subsequent linear layers.
    This useful for ensuring transform invariance between norm and linear layers.

    Note that unitary transforms (rotation) commute with normalization, but not scaling

    :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
    """
    if not hasattr(norm, "weight"):
        raise ValueError(f"Cannot fuse norm of type {type(norm)}")

    for linear in linears:
        # NOTE: spinquant does this op in float64
        exec_device = get_execution_device(norm)
        with align_module_device(norm, exec_device), align_module_device(
            linear, exec_device
        ):
            weight_dtype = linear.weight.dtype
            new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
            new_weight = new_weight.to(weight_dtype)

        update_offload_parameter(linear, "weight", new_weight)

    new_norm_weight = torch.ones_like(norm.weight, device="cpu")
    update_offload_parameter(norm, "weight", new_norm_weight)
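
To make the fusion identity concrete, here is a small self-contained numerical check (pure PyTorch, not a call into llmcompressor): folding a norm layer's per-feature scale into the columns of the following Linear weight leaves the composed output unchanged, which is what allows rotations to be fused afterwards.

import torch

torch.manual_seed(0)
hidden = 8
x = torch.randn(2, hidden, dtype=torch.float64)

g = torch.randn(hidden, dtype=torch.float64)  # norm scale (gamma)
linear = torch.nn.Linear(hidden, 4, bias=False, dtype=torch.float64)

def rms_normalize(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Before fusion: Linear(RMSNorm(x)) with the scale g applied inside the norm
out_before = linear(rms_normalize(x) * g)

# After fusion: scale folded into the weight columns, norm scale reset to ones
fused = torch.nn.Linear(hidden, 4, bias=False, dtype=torch.float64)
fused.weight.data = linear.weight.data * g  # broadcasts over input columns
out_after = fused(rms_normalize(x) * torch.ones(hidden, dtype=torch.float64))

assert torch.allclose(out_before, out_after)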