llmcompressor.modifiers.quantization.gptq

Classes:

GPTQModifier

Bases: Modifier, QuantizationMixin

Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier uses input activations to calibrate a Hessian matrix, which is then used to determine optimal quantization values and orderings for the model weights.

Sample yaml:

test_stage:
  obcq_modifiers:
    GPTQModifier:
      block_size: 128
      dampening_frac: 0.001
      offload_hessians: False
      actorder: static
      config_groups:
        group_0:
          targets:
            - "Linear"
          input_activations: null
          output_activations: null
          weights:
            num_bits: 8
            type: "int"
            symmetric: true
            strategy: group
            group_size: 128

Lifecycle:

  • on_initialize
    - apply config to model
  • on_start
    - add activation calibration hooks
    - add gptq weight calibration hooks
  • on_sequential_epoch_end
    - quantize_weight
  • on_finalize
    - remove_hooks()
    - model.apply(freeze_module_quantization)
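
In practice, the modifier is typically applied through the oneshot entrypoint. The sketch below is a minimal illustration, assuming the top-level oneshot API (in older releases it is imported from llmcompressor.transformers instead); the model id and dataset name are placeholders, not required values.

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Sketch: quantize all Linear layers (except the LM head) to 4-bit
# weights using the W4A16 preset scheme.
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model id
    dataset="open_platypus",                     # placeholder calibration set
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)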

Parameters:

  • sequential_targets

    list of layer names to compress during GPTQ, or 'ALL' to compress every layer in the model

  • block_size

    Used to determine the number of weight columns to compress in one pass

  • dampening_frac

    Amount of dampening to apply to the Hessian H, as a fraction of the diagonal norm

  • actorder

    The order in which weight columns are quantized. For more information on actorder options, see https://github.com/vllm-project/vllm/pull/8135

  • offload_hessians

    Set to True for decreased memory usage but increased runtime.

  • config_groups

    dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized.

  • targets

    list of layer names to quantize if a scheme is provided. Defaults to Linear layers

  • ignore

    optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list.

  • scheme

    a single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which will be set to the targets parameter set at the modifier level. Can also be set to a dictionary of the format preset_scheme_name: targets, for example W8A8: ['Linear'] for 8-bit weights and activations (see the sketch after this parameter list).

  • kv_cache_scheme

    optional QuantizationArgs that specify the quantization of the kv cache. If None, the kv cache is not quantized. When applying kv cache quantization to a transformer AutoModelForCausalLM, the kv_cache_scheme gets converted into a QuantizationScheme that:

      - targets the k_proj and v_proj modules of the model; the outputs of those modules are the keys and values that might be cached

      - quantizes the outputs of the aforementioned layers, so that keys and values are compressed before storing them in the cache

    There is an explicit assumption that the model contains modules with k_proj and v_proj in their names. If this is not the case and kv_cache_scheme != None, the quantization of the kv cache will fail. See the sketch after this parameter list.
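
As an illustration of the scheme and kv_cache_scheme parameters above, the sketch below constructs the modifier directly in Python. The kv cache arguments shown are an assumption modeled on the fp8 kv cache examples (the keys mirror QuantizationArgs fields); they are not the only valid values.

from llmcompressor.modifiers.quantization import GPTQModifier

# Sketch: preset scheme given as {preset_scheme_name: targets}, plus an
# assumed 8-bit float kv cache scheme.
modifier = GPTQModifier(
    scheme={"W8A8": ["Linear"]},
    ignore=["lm_head"],
    kv_cache_scheme={
        "num_bits": 8,
        "type": "float",
        "strategy": "tensor",
        "dynamic": False,
        "symmetric": True,
    },
)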

Methods:

  • calibrate_module

    Calibration hook used to accumulate the Hessian of the input to the module

  • compress_modules

    Quantize modules which have been calibrated

  • on_end

    Finish calibrating by removing observers and calibration hooks

  • on_finalize

    Disable the quantization observers used by the OBCQ algorithm

  • on_initialize

    Initialize and run the GPTQ algorithm on the current state

calibrate_module

calibrate_module(
    module: Module,
    args: Tuple[Tensor, ...],
    _output: Tensor,
)

Calibration hook used to accumulate the Hessian of the input to the module

Parameters:

  • module

    (Module) –

    module being calibrated

  • args

    (Tuple[Tensor, ...]) –

    inputs to the module, the first element of which is the canonical input

  • _output

    (Tensor) –

    uncompressed module output, unused

Source code in llmcompressor/modifiers/quantization/gptq/base.py
def calibrate_module(
    self,
    module: torch.nn.Module,
    args: Tuple[torch.Tensor, ...],
    _output: torch.Tensor,
):
    """
    Calibration hook used to accumulate the Hessian of the input to the module

    :param module: module being calibrated
    :param args: inputs to the module, the first element of which is the
        canonical input
    :param _output: uncompressed module output, unused
    """
    # Assume that first argument is the input
    inp = args[0]

    # Initialize hessian if not present
    if module not in self._num_samples:
        init_device = (
            "cpu" if self.offload_hessians else get_execution_device(module)
        )
        self._hessians[module] = make_empty_hessian(module, device=init_device)
        self._num_samples[module] = 0

    # Accumulate hessian with input with optional offloading
    with self._maybe_onload_hessian(module):
        self._hessians[module], self._num_samples[module] = accumulate_hessian(
            inp,
            module,
            self._hessians[module],
            self._num_samples[module],
        )
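
Conceptually, the accumulated statistic is the GPTQ Hessian, H = (2/N) * sum over N calibration rows of x x^T. The helper below is a hypothetical, simplified stand-in for make_empty_hessian/accumulate_hessian that shows the running update; it is a sketch of the standard GPTQ formulation, not the library implementation.

from typing import Tuple

import torch

def accumulate_hessian_sketch(
    inp: torch.Tensor, H: torch.Tensor, num_samples: int
) -> Tuple[torch.Tensor, int]:
    # Flatten batch/sequence dims so rows are per-token input vectors
    x = inp.reshape(-1, inp.shape[-1]).to(H.dtype)
    total = num_samples + x.shape[0]
    # Rescale the previous running average, then add the new batch,
    # keeping H ~= (2 / total) * sum_i x_i x_i^T
    H *= num_samples / total
    x = x * (2.0 / total) ** 0.5
    H += x.t() @ x
    return H, total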

compress_modules

compress_modules()

Quantize modules which have been calibrated

Source code in llmcompressor/modifiers/quantization/gptq/base.py
def compress_modules(self):
    """
    Quantize modules which have been calibrated
    """
    for module in list(self._num_samples.keys()):
        name = self._module_names[module]
        num_samples = self._num_samples[module]
        quant_args = getattr_chain(module, "quantization_scheme.weights")

        logger.info(f"Quantizing {name} using {num_samples} samples")
        with torch.no_grad(), align_module_device(
            module
        ), self._maybe_onload_hessian(module), CompressionLogger(
            module
        ) as comp_logger:
            loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                module=module,
                quant_args=quant_args,
                hessians_dict=self._hessians,
                blocksize=self.block_size,
                percdamp=self.dampening_frac,
            )
            comp_logger.set_loss(loss)

        update_offload_parameter(module, "weight", quantized_weight)
        update_offload_parameter(module, "weight_scale", scale)
        update_offload_parameter(module, "weight_zero_point", zero_point)
        if g_idx is not None:
            update_offload_parameter(module, "weight_g_idx", g_idx)

        # self._hessians[module] already deleted by quantize_weight
        del self._num_samples[module]
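
Inside quantize_weight, dampening_frac controls a standard GPTQ stabilization step: a fraction of the mean Hessian diagonal is added before the Hessian is inverted. A minimal sketch of that step, assuming the formulation from the GPTQ reference implementation (the function name is hypothetical):

import torch

def dampen_and_invert(H: torch.Tensor, dampening_frac: float) -> torch.Tensor:
    # Add dampening_frac * mean(diag(H)) to the diagonal so H is well
    # conditioned, then return the upper Cholesky factor of H^-1, which
    # drives the column-by-column weight updates.
    diag = torch.arange(H.shape[0], device=H.device)
    H[diag, diag] += dampening_frac * H[diag, diag].mean()
    Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    return torch.linalg.cholesky(Hinv, upper=True)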

on_end

on_end(state: State, event: Event, **kwargs)

Finish calibrating by removing observers and calibration hooks

Source code in llmcompressor/modifiers/quantization/gptq/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks
    """
    self.ended_ = True
    QuantizationMixin.end_calibration(self, state.model)
    self.remove_hooks()  # remove gptq hooks

on_finalize

on_finalize(state: State, **kwargs) -> bool

Disable the quantization observers used by the OBCQ algorithm

Parameters:

  • state

    (State) –

    session state storing input model and calibration data

Source code in llmcompressor/modifiers/quantization/gptq/base.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    disable the quantization observers used by the OBCQ algorithm

    :param state: session state storing input model and calibration data
    """
    if not self.ended_:
        self.on_end(state, None)

    if len(self._num_samples) > 0:
        raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

    self._hessians = dict()
    self._num_samples = dict()

    return True

on_initialize

on_initialize(state: State, **kwargs) -> bool

Initialize and run the GPTQ algorithm on the current state

Parameters:

  • state

    (State) –

    session state storing input model and calibration data

Source code in llmcompressor/modifiers/quantization/gptq/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Initialize and run the GPTQ algorithm on the current state

    :param state: session state storing input model and calibration data
    """
    # apply config to model and prepare calibration hooks
    if QuantizationMixin.has_config(self):
        QuantizationMixin.initialize_quantization(self, state.model)

    # prepare module names
    self._module_names = {m: name for name, m in state.model.named_modules()}

    return True