llmcompressor.modifiers.quantization.cache

Quantized key-value cache implementation for efficient inference.

Provides quantized KV cache classes extending HuggingFace's DynamicCache with quantization support. Enables memory-efficient attention mechanisms by quantizing cached key and value tensors during model inference with configurable quantization strategies.

Classes:

QuantizedKVParameterCache

QuantizedKVParameterCache(
    quantization_args: QuantizationArgs,
)

Bases: DynamicCache

Quantized KV cache used in the forward call, based on HF's DynamicCache. The quantization strategy (tensor, group, channel) is set from the QuantizationArgs strategy. Implemented as a singleton so that the same cache is reused across all forward calls of self_attn. Each forward call invokes .update(), which calls ._quantize() and ._dequantize() as appropriate. The cached tensor size is [batch_size, num_heads, seq_len - residual_length, head_dim].

Triggered by adding kv_cache_scheme in the recipe.

Example:

```python
recipe = '''
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            kv_cache_scheme:
                num_bits: 8
                type: float
                strategy: tensor
                dynamic: false
                symmetric: true
'''
```
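To make the effect of such a scheme concrete, here is a minimal, self-contained sketch of symmetric per-tensor fake quantization — the operation the cache applies to key and value tensors. This is a simplified stand-in for the library's `_quantize()`/`_dequantize()` pair, not its actual implementation: the tensor is quantized to an 8-bit grid using a scale derived from its max absolute value, then immediately dequantized.

```python
# Sketch of symmetric per-tensor fake quantization (quantize then
# dequantize), illustrating the kv_cache_scheme above. Simplified
# stand-in for the library's _quantize()/_dequantize() pair.

def fake_quantize(values, num_bits=8):
    # Symmetric integer range for the bit width, e.g. [-127, 127] for 8 bits
    qmax = 2 ** (num_bits - 1) - 1
    # Per-tensor scale from the max absolute value (the observer's role);
    # fall back to 1.0 for an all-zero tensor
    scale = max(abs(v) for v in values) / qmax or 1.0
    # Quantize: round onto the integer grid and clamp to the range
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    # Dequantize: map back to floating point
    return [qi * scale for qi in q]

qdq = fake_quantize([0.5, -1.0, 0.25, 2.0])
```

The round trip preserves the tensor up to quantization error: the largest-magnitude element is recovered exactly, and the rest land within one quantization step of their originals.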

Methods:

  • get_seq_length

    Returns the sequence length of the cached states.

  • reset

    Reset the singleton instantiation; a new instance is created on the next init.

  • reset_states

    Reset the KV states (used in calibration).

  • update

    Get the k_scale and v_scale and output the fake-quantized key and value states.

Source code in llmcompressor/modifiers/quantization/cache.py
```python
def __init__(self, quantization_args: QuantizationArgs):
    if not self._initialized:
        super().__init__()

        self.quantization_args = quantization_args

        self.k_observers: List[Observer] = []
        self.v_observers: List[Observer] = []

        # each index corresponds to layer_idx of the attention layer
        self.k_scales: List[Tensor] = []
        self.v_scales: List[Tensor] = []

        self.k_zps: List[Tensor] = []
        self.v_zps: List[Tensor] = []

        self._initialized = True
```
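The `_initialized` guard above works together with a singleton `__new__`: repeated construction returns the one shared instance, and the guard keeps `__init__` from wiping its accumulated state. A minimal sketch of that pattern (hypothetical class name, not the library's code):

```python
# Sketch of the singleton-with-init-guard pattern used by the cache:
# __new__ hands back the one shared instance, and _initialized prevents
# __init__ from re-initializing its state on repeated construction.

class SingletonCache:
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, args):
        if not self._initialized:
            self.args = args
            self.scales = []
            SingletonCache._initialized = True

a = SingletonCache("first")
b = SingletonCache("second")  # same instance; __init__ body is skipped
```

This is why every self_attn layer sees the same cache object, and why `reset()` below must clear both `_instance` and `_initialized` to get a fresh cache.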

get_seq_length

get_seq_length(layer_idx: Optional[int] = 0) -> int

Returns the sequence length of the cached states. A layer index can be optionally passed.

Source code in llmcompressor/modifiers/quantization/cache.py
```python
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """
    Returns the sequence length of the cached states.
    A layer index can be optionally passed.
    """
    if len(self.key_cache) <= layer_idx:
        return 0
    # since we cannot get the seq_length of each layer directly and
    # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
    # this is a hack to get the actual seq_length for the given layer_idx
    # this part of code otherwise fails when used to
    # verify attn_weight shape in some models
    return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
```

reset

reset()

Reset the singleton instantiation; a new instance is created on the next init.

Source code in llmcompressor/modifiers/quantization/cache.py
```python
def reset(self):
    """
    Reset the instantiation, create new instance on init
    """
    QuantizedKVParameterCache._instance = None
    QuantizedKVParameterCache._initialized = False
```

reset_states

reset_states()

Reset the KV states (used in calibration).

Source code in llmcompressor/modifiers/quantization/cache.py
```python
def reset_states(self):
    """reset the kv states (used in calibration)"""
    self.key_cache: List[Tensor] = []
    self.value_cache: List[Tensor] = []
    # Used in `generate` to keep tally of how many tokens the cache has seen
    self._seen_tokens = 0
    self._quantized_key_cache: List[Tensor] = []
    self._quantized_value_cache: List[Tensor] = []
```

update

update(
    key_states: Tensor,
    value_states: Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Tensor, Tensor]

Get the k_scale and v_scale and output the fake-quantized key_states and value_states.

Source code in llmcompressor/modifiers/quantization/cache.py
```python
def update(
    self,
    key_states: Tensor,
    value_states: Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Get the k_scale and v_scale and output the
     fakequant-ed key_states and value_states
    """

    if len(self.k_observers) <= layer_idx:
        k_observer_name = self.quantization_args.observer
        k_observer = Observer.load_from_registry(
            k_observer_name, quantization_args=self.quantization_args
        )
        v_observer_name = self.quantization_args.observer
        v_observer = Observer.load_from_registry(
            v_observer_name, quantization_args=self.quantization_args
        )

        # NOTE: User may ignore some layers in configuration,
        # meaning len(self.k_observers) <= layer_idx-1
        # Must account for that case by padding list so that
        # index of lists corresponds to layer_idx
        _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
        _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

    q_key_states = self._quantize(
        key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
    )
    q_value_states = self._quantize(
        value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
    )

    qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
    qdq_value_states = self._dequantize(
        q_value_states, KVCacheScaleType.VALUE, layer_idx
    )

    keys_to_return, values_to_return = qdq_key_states, qdq_value_states

    return keys_to_return, values_to_return
```
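The padding noted in the comment above keeps the observer lists aligned with `layer_idx` even when some layers are ignored in the configuration. A plausible sketch of what `_pad_and_append_at_idx_` does (the real helper lives elsewhere in the library, so names and details here are illustrative):

```python
# Illustrative sketch of list padding keyed by layer index: grow the
# list with None placeholders so position idx exists, keeping list
# indices aligned with attention-layer indices even when some layers
# were skipped in the configuration.

def pad_and_append_at_idx(lst, idx, value):
    while len(lst) <= idx:
        lst.append(None)
    lst[idx] = value
    return lst

observers = []
pad_and_append_at_idx(observers, 2, "obs_layer2")  # layers 0 and 1 skipped
```

Without the padding, appending the layer-2 observer to an empty list would place it at index 0, and later lookups by `layer_idx` would read the wrong layer's observer.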