
llmcompressor.observers

Framework for monitoring and analyzing model behavior during compression.

Provides observers for tracking tensor statistics, activation ranges, and model behavior during compression workflows. Includes min-max observers, MSE observers, and helper utilities for quantization and other compression techniques.

Classes:

  • MinMaxObserver

    Implements a quantization observer that calculates scale and zero point based on the minimum and maximum values of the tensor being observed.

  • MovingAverageMSEObserver

    Implements a dynamic quantization observer that sets the scale and zero point based on a moving average of the mse-clipped min and max observed values.

  • Observer

    Base Observer class to be subclassed for specific implementations.

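Functions:

  • get_observer_token_count

    Parse the module and return the number of tokens observed by each module's observer.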

MinMaxObserver

MinMaxObserver(
    quantization_args: QuantizationArgs,
    averaging_constant: float = 0.01,
    **kwargs
)

Bases: Observer

Implements a quantization observer that calculates scale and zero point based on the minimum and maximum values of the tensor being observed. The observed min and max are smoothed with a moving average weighted by averaging_constant (default 0.01); set averaging_constant to 1.0 to disable averaging and use only the latest values.

Methods:

  • calculate_gparam

    Generate a global scale using the observed min and max.

  • calculate_qparams

    Generate a scale and zero-point using the observed min and max.

  • calculate_updated_min_max

    Updates the observed min and max using a moving average smoothed by the averaging_constant.

  • get_qparams_along_dim

    Calculate quantization parameters along the specified dimension

  • reset

    Reset the state of the observer, including the running min and max values.

Source code in llmcompressor/observers/min_max.py
def __init__(
    self,
    quantization_args: QuantizationArgs,
    averaging_constant: float = 0.01,
    **kwargs,
):
    super().__init__(quantization_args=quantization_args)

    self.min_val = {}
    self.max_val = {}
    self.averaging_constant = averaging_constant
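The min-max idea can be illustrated with the textbook asymmetric mapping from an observed range to a scale and an integer zero point. This is a hedged sketch, not the library's code: the observer delegates this step to compressed_tensors' calculate_qparams, which also handles symmetric schemes, floating-point formats and global scales and may differ in details such as rounding or the integer range. The helpers minmax_qparams and fake_quantize are hypothetical and work on plain floats.

```python
from typing import Tuple


def minmax_qparams(min_val: float, max_val: float, num_bits: int = 8) -> Tuple[float, int]:
    """Hypothetical helper: asymmetric min-max scale and zero point for signed ints."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # -128..127 for 8 bits
    # widen the range to include 0.0 so that zero is exactly representable
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # all-zero input: any positive scale works
    zero_point = round(qmin - min_val / scale)
    zero_point = max(qmin, min(qmax, zero_point))  # keep the zero point in range
    return scale, zero_point


def fake_quantize(x: float, scale: float, zero_point: int, num_bits: int = 8) -> float:
    """Quantize to the integer grid, clamp, then dequantize back to a float."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale


values = [-1.0, 0.25, 2.0]
scale, zero_point = minmax_qparams(min(values), max(values))
print(scale, zero_point)  # scale is 3/255 (about 0.0118), zero_point is -43
```

Every value inside the observed range is reproduced to within half a quantization step, which is why a tighter (or smoothed) min and max directly translates into finer resolution.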

calculate_gparam

calculate_gparam(observed: Tensor) -> torch.Tensor

Generate a global scale using the observed min and max.

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

Returns:

  • Tensor

    updated global scale derived from the observed tensor

Source code in llmcompressor/observers/min_max.py
def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
    """
    Generate a global scale using the observed min and max.

    :param observed: observed tensor to calculate quantization parameters for
    :return: updated global scale derived from the observed tensor
    """

    updated_min_val, updated_max_val = self.calculate_updated_min_max(
        observed=observed
    )
    return generate_gparam(
        updated_min_val=updated_min_val, updated_max_val=updated_max_val
    )

calculate_qparams

calculate_qparams(
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[torch.FloatTensor, torch.IntTensor]

Generate a scale and zero-point using the observed min and max.

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

  • reduce_dims

    (Optional[Tuple[int]], default: None ) –

    optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions

  • tensor_id

    (Optional[Any], default: None ) –

    Optional id if different ranges of observed tensors are passed, useful for sharding tensors by group_size

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

Returns:

  • Tuple[FloatTensor, IntTensor]

    tuple of scale and zero point derived from the observed tensor

Source code in llmcompressor/observers/min_max.py
def calculate_qparams(
    self,
    observed: torch.Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, torch.IntTensor]:
    """
    Generate a scale and zero-point using the observed min and max.

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size
    :param global_scale: optional scale to further scale local quantization scales
    :return: tuple of scale and zero point derived from the observed tensor
    """

    updated_min_val, updated_max_val = self.calculate_updated_min_max(
        observed=observed, tensor_id=tensor_id, reduce_dims=reduce_dims
    )
    return calculate_qparams(
        min_vals=updated_min_val,
        max_vals=updated_max_val,
        quantization_args=self.quantization_args,
        global_scale=global_scale,
    )

calculate_updated_min_max

calculate_updated_min_max(
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
)

Updates the observed min and max using a moving average smoothed by the averaging_constant. Set the averaging_constant to 1.0 to disable averaging.

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

  • reduce_dims

    (Optional[Tuple[int]], default: None ) –

    optional tuple of dimensions to reduce along, returned min and max values will be shaped (1,) along the reduced dimensions

  • tensor_id

    (Optional[Any], default: None ) –

    Optional id if different ranges of observed tensors are passed, useful for sharding tensors by group_size

Returns:

  • updated min and max values

Source code in llmcompressor/observers/min_max.py
def calculate_updated_min_max(
    self,
    observed: torch.Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
):
    """
    Updates the observed min and max using a moving average smoothed by the
    averaging_constant. Set the averaging_constant to 1.0 to disable averaging.

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned min and max values will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size
    :return: updated min and max values
    """
    tensor_id = tensor_id or "default"

    if not reduce_dims:
        min_val, max_val = torch.aminmax(observed)
    else:
        min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
        max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

    # early stopping, save some computation and memory
    if self.averaging_constant == 1.0:
        return min_val, max_val

    running_min_val = self.min_val.get(tensor_id, None)
    running_max_val = self.max_val.get(tensor_id, None)

    if running_min_val is None or running_max_val is None:
        updated_min_val = min_val
        updated_max_val = max_val
    else:
        updated_min_val = running_min_val + self.averaging_constant * (
            min_val - running_min_val
        )
        updated_max_val = running_max_val + self.averaging_constant * (
            max_val - running_max_val
        )

    self.min_val[tensor_id] = updated_min_val
    self.max_val[tensor_id] = updated_max_val
    return updated_min_val, updated_max_val
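The running-average bookkeeping in calculate_updated_min_max can be mirrored on plain Python floats. This is a sketch under stated assumptions: RunningMinMax is a hypothetical stand-in that takes a list of numbers instead of a tensor and reduces it to a single (min, max) pair, ignoring reduce_dims. Unlike the original `tensor_id or "default"`, it maps only None to "default", so an id of 0 keeps its own entry.

```python
from typing import Dict, Hashable, List, Optional, Tuple


class RunningMinMax:
    """Hypothetical stand-in for the observer's running min/max dictionaries."""

    def __init__(self, averaging_constant: float = 0.01) -> None:
        self.averaging_constant = averaging_constant
        self.min_val: Dict[Hashable, float] = {}
        self.max_val: Dict[Hashable, float] = {}

    def update(self, values: List[float], tensor_id: Optional[Hashable] = None) -> Tuple[float, float]:
        tensor_id = "default" if tensor_id is None else tensor_id
        batch_min, batch_max = min(values), max(values)
        if self.averaging_constant == 1.0:
            # averaging disabled: return the latest batch and store nothing
            return batch_min, batch_max
        running_min = self.min_val.get(tensor_id)
        running_max = self.max_val.get(tensor_id)
        if running_min is None or running_max is None:
            # the first observation for this id seeds the running values
            new_min, new_max = batch_min, batch_max
        else:
            c = self.averaging_constant
            new_min = running_min + c * (batch_min - running_min)
            new_max = running_max + c * (batch_max - running_max)
        self.min_val[tensor_id], self.max_val[tensor_id] = new_min, new_max
        return new_min, new_max


obs = RunningMinMax(averaging_constant=0.5)
print(obs.update([-1.0, 1.0]))  # (-1.0, 1.0)
print(obs.update([-3.0, 5.0]))  # (-2.0, 3.0)
```

With averaging_constant=0.5 the second batch moves the range halfway from (-1, 1) toward (-3, 5); with the default 0.01 each batch moves it only 1% of the way, so a single outlier batch barely shifts the scale.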

get_qparams_along_dim

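get_qparams_along_dim(
    observed: Tensor,
    dim: Union[int, Iterable[int]],
    tensor_id: Optional[Any] = None,
    global_scale: Optional[Tensor] = None,
)

Calculate quantization parameters along the specified dimension or dimensions. Every dimension of observed that is not listed in dim is reduced.

Source code in llmcompressor/observers/min_max.py
def get_qparams_along_dim(
    self,
    observed: torch.Tensor,
    dim: Union[int, Iterable[int]],
    tensor_id: Optional[Any] = None,
    global_scale: Optional[torch.Tensor] = None,
):
    """
    Calculate quantization parameters along the specified dimension or dimensions

    :param dim: dimension (or collection of dimensions) to keep; every other
        dimension of observed is reduced
    """
    # accept a single int as well as a collection such as {0, 1}
    if isinstance(dim, int):
        dim = [dim]
    # normalize negative indices so they match entries of range(observed.ndim)
    dim = {d if d >= 0 else observed.ndim + d for d in dim}

    reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
    return self.calculate_qparams(
        observed,
        reduce_dims=reduce_dims,
        tensor_id=tensor_id,
        global_scale=global_scale,
    )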
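Which dimensions get reduced, and the resulting keepdim shape, can be checked without torch. The helpers below are hypothetical; they assume that amin/amax are taken with keepdim semantics over reduce_dims, as in calculate_updated_min_max. They also show why a collection such as {0, 1} (passed by the TOKEN strategy) needs set handling: a plain `idx != dim` comparison against a set is always true, so every dimension would be reduced.

```python
from typing import Iterable, Sequence, Tuple, Union


def reduce_dims_for(ndim: int, keep: Union[int, Iterable[int]]) -> Tuple[int, ...]:
    """Hypothetical helper: dimensions to reduce when keeping `keep`."""
    keep_set = {keep} if isinstance(keep, int) else set(keep)
    keep_set = {d % ndim for d in keep_set}  # map negative indices, e.g. -1 -> ndim - 1
    return tuple(d for d in range(ndim) if d not in keep_set)


def kept_shape(shape: Sequence[int], reduce_dims: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of an amin/amax result taken with keepdim=True over reduce_dims."""
    return tuple(1 if d in reduce_dims else size for d, size in enumerate(shape))


# CHANNEL: weight of shape [out_features, in_features], keep dim 0
print(kept_shape((4096, 1024), reduce_dims_for(2, 0)))  # (4096, 1)
# TOKEN: activation of shape [batch, token, hidden], keep dims {0, 1}
print(kept_shape((2, 16, 1024), reduce_dims_for(3, {0, 1})))  # (2, 16, 1)
```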

reset

reset()

Reset the state of the observer, including the running min and max values.

Source code in llmcompressor/observers/min_max.py
def reset(self):
    """
    Reset the state of the observer, including the running min and max values
    """
    super().reset()
    self.min_val = {}
    self.max_val = {}

MovingAverageMSEObserver

MovingAverageMSEObserver(
    quantization_args: QuantizationArgs,
    maxshrink: float = 0.2,
    patience: int = 5,
    averaging_constant: float = 0.01,
    grid: float = 100.0,
    norm: float = 2.4,
    **kwargs
)

Bases: Observer

Implements a dynamic quantization observer that sets the scale and zero point based on a moving average of the mse-clipped min and max observed values.

Methods:

  • calculate_mse_min_max

    Computes the mse-clipped min and max values of the observed tensor by optimizing for quantization error.

  • calculate_qparams

    Computes a scale and zero point from the mse-clipped min and max values of the observed tensor, after updating their moving average.

  • calculate_updated_min_max

    Updates the mse-clipped min and max values of the observed tensor using a moving average smoothed by the averaging_constant.

  • reset

    Reset the state of the observer, including the running min and max values.

Source code in llmcompressor/observers/mse.py
def __init__(
    self,
    quantization_args: QuantizationArgs,
    maxshrink: float = 0.2,
    patience: int = 5,
    averaging_constant: float = 0.01,
    grid: float = 100.0,
    norm: float = 2.4,
    **kwargs,
):
    super().__init__(quantization_args=quantization_args)

    self.min_val = {}
    self.max_val = {}
    self.maxshrink = maxshrink
    self.patience = patience
    self.averaging_constant = averaging_constant
    self.grid = grid
    self.norm = norm
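With the default maxshrink=0.2 and grid=100.0, the MSE search visits int(0.2 * 100.0) = 20 shrink factors, from 1.0 down to 0.81 (not 0.8, because range() excludes its end). The hypothetical helper below only enumerates the candidate ranges that calculate_mse_min_max would try, using plain floats in place of tensors.

```python
from typing import List, Tuple


def shrink_candidates(
    abs_min: float, abs_max: float, maxshrink: float = 0.2, grid: float = 100.0
) -> List[Tuple[float, float, float]]:
    """Hypothetical helper listing (p, min, max) candidates the MSE search visits."""
    candidates = []
    for i in range(int(maxshrink * grid)):
        p = 1 - i / grid  # shrink factor: 1.0, 0.99, 0.98, ...
        candidates.append((p, p * abs_min, p * abs_max))
    return candidates


cands = shrink_candidates(-2.0, 4.0)
print(len(cands))  # 20 candidates with the defaults
print(cands[0])    # (1.0, -2.0, 4.0): the unclipped range is tried first
```

A larger grid gives finer steps between candidates, while maxshrink bounds how much of the range can be clipped away; patience can end the search before all candidates are tried.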

calculate_mse_min_max

calculate_mse_min_max(
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    global_scale: Optional[Tensor] = None,
)

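Computes the mse-clipped min and max values of the observed tensor by optimizing for quantization error. Candidate ranges are the absolute min and max scaled by a shrink factor p = 1 - i / grid for i in range(int(maxshrink * grid)). Each candidate is fake-quantized, its error is measured as the sum of |q - observed|^norm (per slice when reduce_dims is given), the range with the lowest error is kept, and the search stops early after patience consecutive candidates fail to improve.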

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

  • reduce_dims

    (Optional[Tuple[int]], default: None ) –

    optional tuple of dimensions to reduce along, returned values will be shaped (1,) along the reduced dimensions

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

Returns:

  • tuple of min and max values derived from the observed tensor

Source code in llmcompressor/observers/mse.py
def calculate_mse_min_max(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    global_scale: Optional[torch.Tensor] = None,
):
    """
    Computes the mse-clipped min and max values of the observed tensor by
    optimizing for quantization error

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned values will be shaped (1,) along the reduced dimensions
    :param global_scale: optional scale to further scale local quantization scales
    :return: tuple of min and max values derived from the observed tensor
    """
    from compressed_tensors.quantization.lifecycle import fake_quantize

    if not reduce_dims:
        absolute_min_val, absolute_max_val = torch.aminmax(observed)
    else:
        absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
        absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

    best = torch.full_like(
        absolute_min_val, torch.finfo(absolute_min_val.dtype).max
    )
    min_val = torch.ones_like(absolute_min_val)
    max_val = torch.zeros_like(absolute_max_val)

    # Early stopping params
    no_improve_count = 0

    for i in range(int(self.maxshrink * self.grid)):
        p = 1 - i / self.grid
        shrinked_min_val = p * absolute_min_val
        shrinked_max_val = p * absolute_max_val

        candidate_scales, candidate_zero_points = calculate_qparams(
            min_vals=shrinked_min_val,
            max_vals=shrinked_max_val,
            quantization_args=self.quantization_args,
            global_scale=global_scale,
        )
        q = fake_quantize(
            observed,
            candidate_scales,
            candidate_zero_points,
            self.quantization_args,
            global_scale=global_scale,
        )

        q -= observed
        q.abs_()
        q.pow_(self.norm)
        if not reduce_dims:
            err = torch.sum(q)
        else:
            err = torch.sum(q, reduce_dims, keepdims=True)

        tmp = err < best
        if torch.any(tmp):
            best[tmp] = err[tmp]
            min_val[tmp] = shrinked_min_val[tmp]
            max_val[tmp] = shrinked_max_val[tmp]
            no_improve_count = 0
        else:
            no_improve_count += 1
            if no_improve_count >= self.patience:
                break

    return min_val, max_val
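The grid search above can be reproduced in miniature on a list of floats. Assumptions: fake_quant is a hypothetical textbook asymmetric integer fake-quantizer standing in for compressed_tensors' fake_quantize, and mse_min_max reduces the whole list to one (min, max) pair instead of searching per channel. Like the original, it keeps the shrunken range with the lowest sum of |error|^norm and stops after patience consecutive non-improving candidates.

```python
from typing import List, Tuple


def fake_quant(x: float, lo: float, hi: float, num_bits: int = 8) -> float:
    """Hypothetical asymmetric fake-quantizer over [lo, hi] (zero always included)."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0
    zero_point = round(qmin - lo / scale)
    q = min(qmax, max(qmin, round(x / scale) + zero_point))  # out-of-range values clip
    return (q - zero_point) * scale


def mse_min_max(
    values: List[float],
    maxshrink: float = 0.2,
    patience: int = 5,
    grid: float = 100.0,
    norm: float = 2.4,
    num_bits: int = 8,
) -> Tuple[float, float]:
    """Pick the shrunken (min, max) whose fake-quantization error is smallest."""
    abs_min, abs_max = min(values), max(values)
    best_err = float("inf")
    best_range = (abs_min, abs_max)
    no_improve_count = 0
    for i in range(int(maxshrink * grid)):
        p = 1 - i / grid
        lo, hi = p * abs_min, p * abs_max
        # error: sum of |dequantized - original| ** norm
        err = sum(abs(fake_quant(v, lo, hi, num_bits) - v) ** norm for v in values)
        if err < best_err:
            best_err, best_range = err, (lo, hi)
            no_improve_count = 0
        else:
            no_improve_count += 1
            if no_improve_count >= patience:
                break  # early stopping
    return best_range


print(mse_min_max([-1.0, -0.2, 0.3, 0.9, 3.0], num_bits=4))
```

Shrinking the range trades clipping error on the largest values for finer resolution on the rest; a norm above 2 penalizes large individual errors more heavily than squared error does.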

calculate_qparams

calculate_qparams(
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]

Computes a scale and zero point for the observed tensor from its mse-clipped min and max values, after updating their moving average (smoothed by averaging_constant).

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

  • reduce_dims

    (Optional[Tuple[int]], default: None ) –

    optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions

  • tensor_id

    (Optional[Any], default: None ) –

    Optional id if different ranges of observed tensors are passed, useful for sharding tensors by group_size

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

Returns:

  • Tuple[FloatTensor, IntTensor]

    tuple of scale and zero point derived from the observed tensor

Source code in llmcompressor/observers/mse.py
def calculate_qparams(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Updates the mse-clipped min and max values of the observed tensor using
    a moving average smoothed by the averaging_constant

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size
    :param global_scale: optional scale to further scale local quantization scales
    :return: tuple of scale and zero point derived from the observed tensor
    """
    updated_min_val, updated_max_val = self.calculate_updated_min_max(
        observed=observed,
        tensor_id=tensor_id,
        reduce_dims=reduce_dims,
        global_scale=global_scale,
    )
    scale, zero_point = calculate_qparams(
        min_vals=updated_min_val,
        max_vals=updated_max_val,
        quantization_args=self.quantization_args,
        global_scale=global_scale,
    )
    return scale, zero_point

calculate_updated_min_max

calculate_updated_min_max(
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]

Updates the mse-clipped min and max values of the observed tensor using a moving average smoothed by the averaging_constant

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

  • reduce_dims

    (Optional[Tuple[int]], default: None ) –

    optional tuple of dimensions to reduce along, returned min and max values will be shaped (1,) along the reduced dimensions

  • tensor_id

    (Optional[Any], default: None ) –

    Optional id if different ranges of observed tensors are passed, useful for sharding tensors by group_size

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

Returns:

  • Tuple[Tensor, Tensor]

    updated min and max values derived from the observed tensor

Source code in llmcompressor/observers/mse.py
def calculate_updated_min_max(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the mse-clipped min and max values of the observed tensor using
    a moving average smoothed by the averaging_constant

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned min and max values will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size
    :param global_scale: optional scale to further scale local quantization scales
    :return: updated min and max values derived from the observed tensor
    """
    # TODO: will need to be expanded to support fp4 activations;
    # currently not supported
    min_val, max_val = self.calculate_mse_min_max(
        observed, reduce_dims, global_scale=global_scale
    )

    # normalize the id before the lookup, so values stored under "default"
    # on a previous call are found again and the moving average is applied
    tensor_id = tensor_id or "default"
    running_min_val = self.min_val.get(tensor_id, None)
    running_max_val = self.max_val.get(tensor_id, None)

    if running_min_val is None or running_max_val is None:
        updated_min_val = min_val
        updated_max_val = max_val
    else:
        updated_min_val = running_min_val + self.averaging_constant * (
            min_val - running_min_val
        )
        updated_max_val = running_max_val + self.averaging_constant * (
            max_val - running_max_val
        )

    self.min_val[tensor_id] = updated_min_val
    self.max_val[tensor_id] = updated_max_val
    return updated_min_val, updated_max_val

reset

reset()

Reset the state of the observer, including the running min and max values.

Source code in llmcompressor/observers/mse.py
def reset(self):
    """
    Reset the state of the observer, including the running min and max values
    """
    super().reset()
    self.min_val = {}
    self.max_val = {}

Observer

Observer(quantization_args: QuantizationArgs)

Bases: InternalModule, RegistryMixin

Base Observer class to be subclassed for specific implementations. Subclasses should override calculate_qparams to return a (scale, zero_point) pair.

Methods:

  • calculate_gparam

    Compute a global scale from the observed tensor; must be implemented by subclasses.

  • calculate_qparams

    Compute a scale and zero point from the observed tensor; must be implemented by subclasses.

  • forward

    Records the observed tokens, then maps directly to get_qparams (or to get_gparam)

  • get_gparam

    Function to derive a global scale parameter

  • get_qparams

    Convenience function to wrap overwritten calculate_qparams

  • post_calculate_qparams

    Hook for observer-specific logic to run after calculate_qparams.

  • record_observed_tokens

    Counts the number of tokens observed during the forward passes.

  • reset

    Reset the state of the observer

Source code in llmcompressor/observers/base.py
def __init__(
    self,
    quantization_args: QuantizationArgs,
):
    self.quantization_args: QuantizationArgs = quantization_args
    super().__init__()
    self._scale = None
    self._zero_point = None
    self._num_observed_tokens = None
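The base class mainly fixes a contract: a subclass overrides calculate_qparams and returns a (scale, zero_point) pair. The sketch below mirrors that contract with hypothetical plain-Python classes (ObserverContract and AbsMaxObserver are not part of llmcompressor) so it runs without torch or compressed_tensors; a real subclass would receive tensors and QuantizationArgs instead of a list and a bit width.

```python
from typing import Any, Optional, Sequence, Tuple


class ObserverContract:
    """Hypothetical plain-Python stand-in for the Observer base class contract."""

    def calculate_qparams(
        self,
        observed: Sequence[float],
        reduce_dims: Optional[Tuple[int, ...]] = None,
        tensor_id: Optional[Any] = None,
        global_scale: Optional[float] = None,
    ) -> Tuple[float, int]:
        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")


class AbsMaxObserver(ObserverContract):
    """Hypothetical subclass: symmetric scale from the largest magnitude, zero point 0."""

    def __init__(self, num_bits: int = 8) -> None:
        self.qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits

    def calculate_qparams(self, observed, reduce_dims=None, tensor_id=None, global_scale=None):
        # this sketch reduces over everything and ignores reduce_dims, tensor_id
        # and global_scale
        amax = max(abs(v) for v in observed)
        scale = amax / self.qmax if amax > 0 else 1.0
        return scale, 0


print(AbsMaxObserver().calculate_qparams([-2.54, 1.0]))  # scale about 0.02, zero point 0
```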

calculate_gparam

calculate_gparam(observed: Tensor) -> torch.Tensor

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

Returns:

  • Tensor

    global scale derived from the observed tensor

Source code in llmcompressor/observers/base.py
def calculate_gparam(
    self,
    observed: Tensor,
) -> torch.Tensor:
    """
    :param observed: observed tensor to calculate quantization parameters for
    :return: global scale derived from the observed tensor
    """
    raise NotImplementedError(f"{self.__class__} must implement calculate_gparam")

calculate_qparams

calculate_qparams(
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate quantization parameters for

  • reduce_dims

    (Optional[Tuple[int]], default: None ) –

    optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions

  • tensor_id

    (Optional[Any], default: None ) –

    optional id for tracking separate statistics when different ranges of observed tensors are passed, useful for sharding tensors by group_size or block quantization

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

Returns:

  • Tuple[FloatTensor, IntTensor]

    tuple of scale and zero point derived from the observed tensor

Source code in llmcompressor/observers/base.py
def calculate_qparams(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: optional id for tracking separate statistics when different
        ranges of observed tensors are passed, useful for sharding tensors by
        group_size or block quantization
    :param global_scale: optional scale to further scale local quantization scales
    :return: tuple of scale and zero point derived from the observed tensor
    """
    raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

forward

forward(
    observed: Tensor,
    g_idx: Optional[Tensor] = None,
    global_scale: Optional[Tensor] = None,
    should_calculate_gparam: bool = False,
) -> Tuple[FloatTensor, IntTensor]

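Records the number of observed tokens, then maps directly to get_qparams, or to get_gparam when should_calculate_gparam is True.

Parameters:

  • observed

    (Tensor) –

    observed tensor from which to calculate quantization parameters

  • g_idx

    (Optional[Tensor], default: None ) –

    optional mapping from column index to group index

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

  • should_calculate_gparam

    (bool, default: False ) –

    if True, return a global scale computed by get_gparam instead of a scale and zero point

Returns:

  • Tuple[FloatTensor, IntTensor]

    tuple of scale and zero point based on last observed value, or the global scale when should_calculate_gparam is True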

Source code in llmcompressor/observers/base.py
@torch.no_grad()
def forward(
    self,
    observed: Tensor,
    g_idx: Optional[Tensor] = None,
    global_scale: Optional[Tensor] = None,
    should_calculate_gparam: bool = False,
) -> Tuple[FloatTensor, IntTensor]:
    """
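    Records the observed tokens, then maps directly to get_qparams
    (or to get_gparam when should_calculate_gparam is True)

    :param observed: observed tensor from which to calculate
        quantization parameters
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional scale to further scale local quantization scales
    :param should_calculate_gparam: if True, return a global scale computed by
        get_gparam instead of a scale and zero point
    :return: tuple of scale and zero point based on last observed value, or the
        global scale when should_calculate_gparam is True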
    """
    self.record_observed_tokens(observed)
    if should_calculate_gparam:
        return self.get_gparam(observed=observed)
    return self.get_qparams(
        observed=observed,
        g_idx=g_idx,
        global_scale=global_scale,
    )

get_gparam

get_gparam(observed: Tensor)

Function to derive a global scale parameter

Parameters:

  • observed

    (Tensor) –

    observed tensor to calculate global parameters from

Returns:

  • derived global scale

Source code in llmcompressor/observers/base.py
def get_gparam(self, observed: Tensor):
    """
    Function to derive a global scale parameter
    :param observed: observed tensor to calculate global parameters
        from
    :return: derived global scale
    """
    if self.quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
        return self.calculate_gparam(observed)
    raise NotImplementedError(
        "global parameter generation is only supported for TENSOR_GROUP"
    )

get_qparams

get_qparams(
    observed: Optional[Tensor] = None,
    g_idx: Optional[Tensor] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]

Convenience wrapper around the subclass's calculate_qparams. It makes the observed tensor optional and caches the most recently calculated scale and zero point, which are returned when observed is None.

Parameters:

  • observed

    (Optional[Tensor], default: None ) –

    optional observed tensor to calculate quantization parameters from

  • g_idx

    (Optional[Tensor], default: None ) –

    optional mapping from column index to group index

  • global_scale

    (Optional[Tensor], default: None ) –

    optional scale to further scale local quantization scales

Returns:

  • Tuple[FloatTensor, IntTensor]

    tuple of scale and zero point based on last observed value

Source code in llmcompressor/observers/base.py
def get_qparams(
    self,
    observed: Optional[Tensor] = None,
    g_idx: Optional[Tensor] = None,
    global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Convenience wrapper around the subclass's calculate_qparams that makes the
    observed tensor optional and caches the most recently calculated scale and
    zero point

    :param observed: optional observed tensor to calculate quantization parameters
        from
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional scale to further scale local quantization scales
    :return: tuple of scale and zero point based on last observed value
    """
    if observed is not None:
        group_size = self.quantization_args.group_size

        if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
            # re-calculate scale and zero point, update the stored value
            self._scale, self._zero_point = self.calculate_qparams(observed)

        elif self.quantization_args.strategy in (
            QuantizationStrategy.TENSOR_GROUP,
            QuantizationStrategy.GROUP,
        ):
            rows = observed.shape[0]
            columns = observed.shape[1]
            num_groups = int(ceil(columns / group_size))
            if num_groups * group_size != columns:
                logger.bind(log_once=True).warning(
                    "Attempting to quantize a module weight whose columns "
                    f"({columns}) are not divisible by group_size ({group_size}). "
                    "This scheme is not supported by vLLM, please consider "
                    "adjusting the group_size for modules with this number of "
                    "columns",
                )

            self._scale = torch.empty(
                (rows, num_groups), dtype=observed.dtype, device=observed.device
            )
            if is_fp4(quantization_args=self.quantization_args):
                zp_dtype = FP8_E4M3_DATA.dtype
            else:
                zp_dtype = self.quantization_args.pytorch_dtype()

            self._zero_point = torch.empty(
                (rows, num_groups), dtype=zp_dtype, device=observed.device
            )

            # support column-order (default) quantization as well as other orderings
            # such as activation ordering. Fall back to column order when g_idx is
            # None or has not been initialized (contains -1)
            is_column_order = g_idx is None or -1 in g_idx
            if is_column_order:
                group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
            else:
                group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
                group_sizes = group_sizes[torch.argsort(group_indices)]

                perm = torch.argsort(g_idx)
                observed = safe_permute(observed, perm, dim=1)

            # TODO: experiment with vectorizing for loop for performance
            end = 0
            for group_index, group_count in enumerate(group_sizes):
                start = end
                end = start + group_count
                scale, zero_point = self.get_qparams_along_dim(
                    observed[:, start:end],
                    0,
                    tensor_id=group_index,
                    global_scale=global_scale,
                )

                self._scale[:, group_index] = scale.squeeze(1)
                self._zero_point[:, group_index] = zero_point.squeeze(1)

        elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
            # assume observed is transposed, because it's the output, hence use dim 0
            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

        elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
            # keep dims 0 and 1 (batch, token) and reduce over the rest;
            # assumes observed.shape = [batch, token, hidden]
            self._scale, self._zero_point = self.get_qparams_along_dim(
                observed,
                dim={0, 1},
            )

        elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
            # Block-wise quantization: one scale/zero_point per block of shape
            # [block_rows, block_cols]
            rows, cols = observed.shape[:2]
            bs = self.quantization_args.block_structure
            if not (
                isinstance(bs, (list, tuple))
                and len(bs) == 2
                and all(isinstance(x, int) for x in bs)
            ):
                raise ValueError(
                    f"Invalid block_structure '{bs}'. "
                    f"Must be a list of two ints [rows, cols]."
                )
            block_rows, block_cols = bs
            num_br = int(ceil(rows / block_rows))
            num_bc = int(ceil(cols / block_cols))

            # allocate per-block scale and zero_point
            self._scale = torch.empty(
                (num_br, num_bc), dtype=observed.dtype, device=observed.device
            )

            # Use same dtype logic as GROUP strategy for zero_point
            if is_fp4(quantization_args=self.quantization_args):
                zp_dtype = FP8_E4M3_DATA.dtype
            else:
                zp_dtype = self.quantization_args.pytorch_dtype()

            self._zero_point = torch.empty(
                (num_br, num_bc), dtype=zp_dtype, device=observed.device
            )

            # compute qparams for each block
            for i in range(num_br):
                r0 = i * block_rows
                r1 = min((i + 1) * block_rows, rows)
                for j in range(num_bc):
                    c0 = j * block_cols
                    c1 = min((j + 1) * block_cols, cols)
                    # reduce across both dims to get one scale and zp per block
                    # Use unique tensor_id for each block to maintain separate stats
                    block_tensor_id = f"block_{i}_{j}"
                    scale_bp, zp_bp = self.calculate_qparams(
                        observed[r0:r1, c0:c1],
                        reduce_dims=(0, 1),
                        tensor_id=block_tensor_id,
                    )
                    self._scale[i, j] = scale_bp
                    self._zero_point[i, j] = zp_bp

    return self._scale, self._zero_point
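The GROUP/TENSOR_GROUP and BLOCK branches of get_qparams mainly differ in how they carve observed into slices before computing per-slice qparams. The hypothetical helpers below reproduce only that slicing arithmetic on plain integers, assuming g_idx is given as a list; they compute no scales.

```python
import math
from collections import Counter
from typing import List, Optional, Sequence, Tuple

Span = Tuple[int, int]


def group_spans(columns: int, group_size: int, g_idx: Optional[Sequence[int]] = None) -> List[Span]:
    """Hypothetical helper: column span of each group, in processing order."""
    if g_idx is None or -1 in g_idx:
        # column order: ceil(columns / group_size) groups of group_size columns
        sizes = [group_size] * math.ceil(columns / group_size)
    else:
        # activation order: columns are first permuted so that equal g_idx values
        # are contiguous; each group's size is the number of columns assigned to it
        counts = Counter(g_idx)
        sizes = [counts[g] for g in sorted(counts)]
    spans, end = [], 0
    for size in sizes:
        start, end = end, end + size
        spans.append((start, min(end, columns)))  # tensor slicing clamps a short last group
    return spans


def block_spans(rows: int, cols: int, block_rows: int, block_cols: int) -> List[Tuple[Span, Span]]:
    """Hypothetical helper: (row span, column span) of each block, row-major."""
    return [
        ((r0, min(r0 + block_rows, rows)), (c0, min(c0 + block_cols, cols)))
        for r0 in range(0, rows, block_rows)
        for c0 in range(0, cols, block_cols)
    ]


print(group_spans(10, 4))                           # [(0, 4), (4, 8), (8, 10)]
print(group_spans(6, 3, g_idx=[1, 0, 1, 0, 0, 0]))  # [(0, 4), (4, 6)]
print(len(block_spans(5, 5, 2, 4)))                 # 6 blocks: 3 row bands x 2 column bands
```

The first call is the case that triggers the vLLM divisibility warning: 10 columns do not split evenly into groups of 4, so the last group is shorter.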

post_calculate_qparams

post_calculate_qparams() -> None

Hook for observer-specific logic to run after calculate_qparams. The base implementation does nothing.

Source code in llmcompressor/observers/base.py
def post_calculate_qparams(self) -> None:
    """
    Run any logic specific to its observers after running calculate_qparams
    """

record_observed_tokens

record_observed_tokens(batch_tensor: Tensor)

Counts the number of tokens observed during the forward passes. The count is aggregated in the _num_observed_tokens attribute of the class.

Note: The batch_tensor is expected to have two dimensions (batch_size * sequence_length, num_features). This is the general shape expected by the forward pass of the expert layers in a MoE model. If the input tensor does not have two dimensions, it is skipped and _num_observed_tokens is left unchanged (it stays None if no two-dimensional tensor has been observed yet).

Source code in llmcompressor/observers/base.py
def record_observed_tokens(self, batch_tensor: Tensor):
    """
    Counts the number of tokens observed during the
    forward passes. The count is aggregated in the
    _num_observed_tokens attribute of the class.

    Note: The batch_tensor is expected to have two dimensions
        (batch_size * sequence_length, num_features). This is the
        general shape expected by the forward pass of the expert
        layers in a MoE model. If the input tensor does not have
        two dimensions, it is skipped and the _num_observed_tokens
        attribute is left unchanged.
    """
    if not isinstance(batch_tensor, Tensor):
        raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")

    if batch_tensor.ndim != 2:
        logger.debug(
            "The input tensor is expected to have two dimensions "
            "(batch_size * sequence_length, num_features). "
            f"The input tensor has {batch_tensor.ndim} dimensions."
        )
        return

    if self._num_observed_tokens is None:
        # initialize the count
        self._num_observed_tokens = 0

    # batch_tensor (batch_size * sequence_length, num_features)
    # observed_tokens (batch_size * sequence_length)
    observed_tokens, _ = batch_tensor.shape
    self._num_observed_tokens += observed_tokens
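record_observed_tokens only counts the first dimension of two-dimensional inputs. A hypothetical stand-in that takes shapes instead of tensors shows how the count evolves, including skipped inputs.

```python
from typing import Optional, Sequence


class TokenCounter:
    """Hypothetical stand-in mirroring record_observed_tokens on plain shapes."""

    def __init__(self) -> None:
        self.num_observed_tokens: Optional[int] = None

    def record(self, shape: Sequence[int]) -> None:
        if len(shape) != 2:
            return  # not (tokens, features): skipped, count left unchanged
        if self.num_observed_tokens is None:
            self.num_observed_tokens = 0  # the first 2-D input initializes the count
        self.num_observed_tokens += shape[0]  # rows = batch_size * sequence_length


counter = TokenCounter()
counter.record((2, 16, 64))  # 3-D input: ignored, count stays None
counter.record((32, 64))     # 32 tokens
counter.record((10, 64))     # 10 more tokens
print(counter.num_observed_tokens)  # 42
```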

reset

reset()

Reset the state of the observer

Source code in llmcompressor/observers/base.py
def reset(self):
    """
    Reset the state of the observer
    """
    self._num_observed_tokens = None
    self._scale = None
    self._zero_point = None

get_observer_token_count

get_observer_token_count(module: Module) -> Counter

Parse the module and return the number of tokens observed by the input observer of each submodule, keyed by the submodule's name.

Parameters:

  • module

    (Module) –

    module to parse

Returns:

  • Counter

    counter with the number of tokens observed by each observer

Source code in llmcompressor/observers/helpers.py
def get_observer_token_count(module: torch.nn.Module) -> Counter:
    """
    Parse the module and return the number of tokens observed by the
    input observer of each submodule, keyed by the submodule's name.

    :param module: module to parse
    :return: counter with the number of tokens observed by each observer
    """
    token_counts = Counter()
    # use a separate name so the loop variable does not shadow the argument
    for name, submodule in module.named_modules():
        if name.endswith(".input_observer"):
            token_counts[name.replace(".input_observer", "")] = (
                submodule._num_observed_tokens
            )
    return token_counts
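get_observer_token_count keys its Counter by module name with the .input_observer suffix removed. The sketch below applies the same filtering to hypothetical (name, count) pairs standing in for named_modules() output; the module names are illustrative only.

```python
from collections import Counter
from typing import Iterable, Optional, Tuple


def token_counts_from_names(named_counts: Iterable[Tuple[str, Optional[int]]]) -> Counter:
    """Hypothetical helper mirroring get_observer_token_count on (name, count) pairs."""
    counts = Counter()
    for name, num_tokens in named_counts:
        # only input observers are reported, keyed by the owning module's name
        if name.endswith(".input_observer"):
            counts[name.replace(".input_observer", "")] = num_tokens
    return counts


named = [
    ("model.layers.0.mlp.experts.0.up_proj", None),
    ("model.layers.0.mlp.experts.0.up_proj.input_observer", 128),
    ("model.layers.0.mlp.experts.0.up_proj.weight_observer", None),
    ("model.layers.0.mlp.experts.1.up_proj.input_observer", 0),
]
print(token_counts_from_names(named))
```

A zero count, as for the second expert here, is one way to spot MoE experts that received no calibration tokens.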