
llmcompressor.modifiers.pruning.wanda


Classes:

WandaPruningModifier

Bases: SparsityModifierBase

Modifier for applying the one-shot Wanda pruning algorithm to a model, from the paper: https://arxiv.org/abs/2306.11695
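For reference, the paper scores each weight by its magnitude multiplied by the L2 norm of the input feature it acts on, and compares scores only within the same output row. Written out, with $W \in \mathbb{R}^{C_{\text{out}} \times C_{\text{in}}}$ the layer weight, $X \in \mathbb{R}^{T \times C_{\text{in}}}$ the calibration inputs stacked over $T$ tokens, and $s$ the target sparsity (this notation is ours, not the library's):

```latex
S_{ij} = \lvert W_{ij} \rvert \cdot \lVert X_{:,j} \rVert_2,
\qquad
W_{ij} \leftarrow 0 \ \text{ for the } \lfloor s \, C_{\text{in}} \rfloor \text{ smallest } S_{ij} \text{ in each row } i
```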

Sample yaml:

    test_stage:
      sparsity_modifiers:
        WandaPruningModifier:
          sparsity: 0.5
          mask_structure: "2:4"
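The sample config requests a "2:4" mask. A minimal sketch of how an "N:M" string could be parsed and applied to a single weight row, assuming (as in the original Wanda reference code) that N is the number of lowest-scoring weights zeroed in each consecutive block of M. `parse_mask_structure` and `apply_n_m_to_row` are hypothetical helpers, not llmcompressor functions.

```python
from typing import List, Tuple


def parse_mask_structure(mask_structure: str) -> Tuple[int, int]:
    """Split an "N:M" string into (N, M); "0:0" means unstructured."""
    n_str, m_str = mask_structure.split(":")
    n, m = int(n_str), int(m_str)
    if n < 0 or m < 0 or (m > 0 and n > m):
        raise ValueError(f"invalid mask_structure {mask_structure!r}")
    return n, m


def apply_n_m_to_row(
    weights: List[float], scores: List[float], n: int, m: int
) -> List[float]:
    """Zero the n lowest-scoring weights in every block of m consecutive weights."""
    if m == 0:  # unstructured mask: handled elsewhere, leave the row unchanged
        return list(weights)
    if len(weights) % m != 0:
        raise ValueError("row length must be a multiple of m")
    pruned = list(weights)
    for start in range(0, len(weights), m):
        # indices of the n smallest scores inside this block
        lowest = sorted(range(start, start + m), key=lambda idx: scores[idx])[:n]
        for idx in lowest:
            pruned[idx] = 0.0
    return pruned


n, m = parse_mask_structure("2:4")
row = [0.9, -0.1, 0.4, 0.05, -0.7, 0.2, 0.3, -0.8]
print(apply_n_m_to_row(row, [abs(w) for w in row], n, m))
# [0.9, 0.0, 0.4, 0.0, -0.7, 0.0, 0.0, -0.8]
```

Here plain magnitudes stand in for the scores; under Wanda the scores would also include the input activation norms.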

Lifecycle:
  - on_initialize
      - register_hook(module, calibrate_module, "forward")
  - run_sequential / run_layer_sequential / run_basic
      - make_empty_row_scalars
      - accumulate_row_scalars
  - on_sequential_batch_end
      - sparsify_weight
  - on_finalize
      - remove_hooks()
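The lifecycle follows a common forward-hook pattern: a hook records statistics during calibration, compression runs once the calibration batches have passed through, and the hooks are removed at finalize. A self-contained toy sketch of that flow in plain Python; `ToyLinear`, `HookHandle`, and the `num_samples` dict are hypothetical stand-ins for a `torch.nn.Module`, its hook handle, and the modifier's internal state, not llmcompressor code.

```python
from typing import Callable, Dict, List, Tuple


class HookHandle:
    """Mimics the removable handle returned by torch's register_forward_hook."""

    def __init__(self, hooks: List[Callable], hook: Callable):
        self._hooks = hooks
        self._hook = hook

    def remove(self) -> None:
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class ToyLinear:
    """Hypothetical single-output layer (y = w . x) that supports forward hooks."""

    def __init__(self, weight: List[float]):
        self.weight = weight
        self._forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> HookHandle:
        self._forward_hooks.append(hook)
        return HookHandle(self._forward_hooks, hook)

    def __call__(self, x: List[float]) -> float:
        out = sum(w * xi for w, xi in zip(self.weight, x))
        # hooks receive (module, args, output), the order calibrate_module expects
        for hook in list(self._forward_hooks):
            hook(self, (x,), out)
        return out


num_samples: Dict[ToyLinear, int] = {}


def calibrate(module: ToyLinear, args: Tuple[List[float], ...], _output: float) -> None:
    # stand-in for calibrate_module: count forward passes per module
    num_samples[module] = num_samples.get(module, 0) + 1


layer = ToyLinear([0.5, -0.25])
handle = layer.register_forward_hook(calibrate)  # on_initialize
for batch in ([1.0, 2.0], [3.0, 4.0]):           # calibration forward passes
    layer(batch)
print(num_samples[layer])  # 2; on_sequential_batch_end would compress here
handle.remove()                                   # on_finalize
layer([1.0, 1.0])                                 # no longer recorded
print(num_samples[layer])  # 2
```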

Parameters:

  • sparsity

    Target sparsity to compress the model to, as the fraction of weights set to zero (e.g. 0.5)

  • sparsity_profile

    Can be set to 'owl' to use Outlier Weighed Layerwise Sparsity (OWL), more information can be found in the paper https://arxiv.org/pdf/2310.05175

  • mask_structure

    String defining the structure of the mask to apply. Must be of the form N:M, where N and M are integers that define a custom block shape (e.g. "2:4"). Defaults to "0:0", which represents an unstructured mask.

  • owl_m

    Number of outliers to use for OWL

  • owl_lmbda

    Lambda value to use for OWL

  • sequential_targets

    List of layer names to compress during pruning, or 'ALL' to compress every layer in the model. Alias for targets

  • targets

    List of layer names to compress during pruning, or 'ALL' to compress every layer in the model. Alias for sequential_targets

  • ignore

    Optional list of module class names or submodule names to not prune, even if they match a target. Defaults to an empty list.
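With sparsity_profile set to 'owl', layers receive different sparsities instead of a uniform one. A minimal sketch of the allocation idea from the OWL paper, under stated assumptions: `owl_m` is treated as the paper's outlier threshold multiplier (a score counts as an outlier when it exceeds `owl_m` times the layer's mean score), and the outlier ratios are rescaled to an interval of width 2 × `owl_lmbda` and centered, so the average sparsity stays at the target while layers with more outliers are pruned less. `outlier_ratio` and `owl_sparsities` are hypothetical names; this is not the library's code.

```python
from statistics import mean
from typing import Dict, List


def outlier_ratio(scores: List[float], owl_m: float) -> float:
    """Fraction of a layer's Wanda scores above owl_m times their mean."""
    threshold = owl_m * mean(scores)
    return sum(s > threshold for s in scores) / len(scores)


def owl_sparsities(
    ratios: Dict[str, float], sparsity: float, owl_lmbda: float
) -> Dict[str, float]:
    """Map each layer's outlier ratio to a sparsity whose average equals `sparsity`."""
    lo, hi = min(ratios.values()), max(ratios.values())
    if hi == lo:  # no spread between layers: keep the sparsity uniform
        return {name: sparsity for name in ratios}
    # rescale outlier ratios into [0, 2 * owl_lmbda]
    scaled = {k: (r - lo) / (hi - lo) * 2 * owl_lmbda for k, r in ratios.items()}
    center = mean(scaled.values())
    # more outliers -> lower sparsity; subtracting the centered value keeps the mean fixed
    return {k: sparsity - (v - center) for k, v in scaled.items()}


ratios = {
    "layers.0": outlier_ratio([1.0, 1.0, 1.0, 1.0, 10.0], owl_m=2),  # 0.2
    "layers.1": outlier_ratio([1.0, 1.2, 0.9, 1.1, 1.0], owl_m=2),   # 0.0
}
print(owl_sparsities(ratios, sparsity=0.5, owl_lmbda=0.08))
```

In this example the outlier-heavy layers.0 ends up near 0.42 sparsity and layers.1 near 0.58, averaging 0.5.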

Methods:

  • calibrate_module

    Calibration hook used to accumulate the row scalars of the input to the module

  • compress_modules

    Sparsify modules which have been calibrated

calibrate_module

calibrate_module(
    module: Module,
    args: Tuple[Tensor, ...],
    _output: Tensor,
)

Calibration hook used to accumulate the row scalars of the input to the module

Parameters:

  • module

    (Module) –

    module being calibrated

  • args

    (Tuple[Tensor, ...]) –

    inputs to the module, the first element of which is the canonical input

  • _output

    (Tensor) –

    uncompressed module output, unused

Source code in llmcompressor/modifiers/pruning/wanda/base.py
def calibrate_module(
    self,
    module: torch.nn.Module,
    args: Tuple[torch.Tensor, ...],
    _output: torch.Tensor,
):
    """
    Calibration hook used to accumulate the row scalars of the input to the module

    :param module: module being calibrated
    :param args: inputs to the module, the first element of which is the
        canonical input
    :param _output: uncompressed module output, unused
    """
    # Assume that the first argument is the input
    inp = args[0]

    # Initialize row scalars if not present
    if module not in self._num_samples:
        device = get_execution_device(module)
        self._row_scalars[module] = make_empty_row_scalars(module, device=device)
        self._num_samples[module] = 0

    # Accumulate scalars using data
    self._row_scalars[module], self._num_samples[module] = accumulate_row_scalars(
        inp,
        module,
        self._row_scalars[module],
        self._num_samples[module],
    )
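calibrate_module delegates to make_empty_row_scalars and accumulate_row_scalars, which are not shown on this page. The sketch below illustrates the running average used in the original Wanda reference implementation, which we assume these helpers resemble: one value per input feature holding the sum of squared activations divided by the number of samples seen, updated batch by batch so calibration activations never need to be stored. Plain lists stand in for tensors, and these signatures are hypothetical.

```python
from typing import List, Tuple


def make_empty_row_scalars(in_features: int) -> List[float]:
    """One running statistic per input feature, starting at zero."""
    return [0.0] * in_features


def accumulate_row_scalars(
    tokens: List[List[float]],
    batch_size: int,
    row_scalars: List[float],
    num_samples: int,
) -> Tuple[List[float], int]:
    """Fold one batch (flattened to token vectors) into the running mean.

    After each call, row_scalars[j] == (sum of x_j ** 2 over all tokens seen) / num_samples.
    """
    new_total = num_samples + batch_size
    # rescale the previous average to the new sample count
    updated = [r * num_samples / new_total for r in row_scalars]
    # add this batch's squared L2 norm for each input feature
    for j in range(len(updated)):
        updated[j] += sum(tok[j] ** 2 for tok in tokens) / new_total
    return updated, new_total


scalars, n = make_empty_row_scalars(2), 0
scalars, n = accumulate_row_scalars([[1.0, 2.0], [3.0, 0.0]], 1, scalars, n)
scalars, n = accumulate_row_scalars([[0.0, 2.0]], 1, scalars, n)
print(scalars, n)  # [5.0, 4.0] 2
```

Because sqrt(row_scalars[j]) equals the column norm of X divided by sqrt(num_samples), a factor shared by every weight in the layer, it produces the same ranking within a row as the raw norm in the paper's metric.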

compress_modules

compress_modules()

Sparsify modules which have been calibrated

Source code in llmcompressor/modifiers/pruning/wanda/base.py
def compress_modules(self):
    """
    Sparsify modules which have been calibrated
    """
    for module in list(self._num_samples.keys()):
        name = self._module_names[module]
        sparsity = self._module_sparsities[module]
        num_samples = self._num_samples[module]

        logger.info(f"Sparsifying {name} using {num_samples} samples")
        with torch.no_grad(), align_module_device(module), CompressionLogger(
            module
        ):
            sparsified_weight = sparsify_weight(
                module=module,
                row_scalars_dict=self._row_scalars,
                sparsity=sparsity,
                prune_n=self._prune_n,
                prune_m=self._prune_m,
            )

        update_offload_parameter(module, "weight", sparsified_weight)

        # self._row_scalars[module] already deleted by sparsify_weight
        del self._num_samples[module]
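sparsify_weight, which consumes these scalars, is likewise not shown. For the unstructured case (mask_structure "0:0"), a minimal sketch of the Wanda rule it applies: score each weight as |W_ij| times sqrt(row_scalars[j]) and zero the int(sparsity × in_features) lowest-scoring weights in every output row. `sparsify_rows` is a hypothetical name; per the call above, the library function also receives prune_n and prune_m for N:M masks and may differ in detail.

```python
import math
from typing import List


def sparsify_rows(
    weight: List[List[float]], row_scalars: List[float], sparsity: float
) -> List[List[float]]:
    """Zero the int(sparsity * in_features) lowest Wanda scores in each output row."""
    in_features = len(row_scalars)
    num_prune = int(in_features * sparsity)
    # per-input-feature activation norm recovered from the running mean of squares
    act_norms = [math.sqrt(r) for r in row_scalars]
    result = []
    for row in weight:
        scores = [abs(w) * a for w, a in zip(row, act_norms)]
        # indices of the num_prune smallest scores; comparison stays within this row
        prune_idx = set(sorted(range(in_features), key=lambda j: scores[j])[:num_prune])
        result.append([0.0 if j in prune_idx else w for j, w in enumerate(row)])
    return result


weight = [[1.0, -2.0, 0.6, 3.0], [0.2, 0.1, -4.0, 1.0]]
print(sparsify_rows(weight, row_scalars=[4.0, 0.25, 16.0, 1.0], sparsity=0.5))
# [[0.0, 0.0, 0.6, 3.0], [0.0, 0.0, -4.0, 1.0]]
```

In the first row, pure magnitude pruning would drop 0.6 and keep -2.0; weighting by activation norms reverses that, because -2.0 multiplies a weak input feature while 0.6 multiplies a strong one.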