llmcompressor.transformers

Tools for integrating LLM Compressor with transformers training flows.

Modules:

  • compression
  • finetune
  • sparsification

    Objects, classes, and methods for applying sparsification algorithms to Hugging Face transformers flows

  • tracing
  • utils

    Utilities for applying sparsification algorithms to Hugging Face transformers flows

Classes:

  • SessionManagerMixIn

    Mix-In class to extend the Hugging Face Trainer class to support LLM Compressor recipes for one-shot and finetuning flows

  • TextGenerationDataset

    Base class for text datasets. Applies transformations to prepare a dataset to be loaded by a dataloader

Functions:

  • is_model_ct_quantized_from_path

    Determine if model from path is quantized based on the config

SessionManagerMixIn

SessionManagerMixIn(
    recipe: str,
    model_args: ModelArguments,
    dataset_args: Optional[DatasetArguments] = None,
    teacher: Optional[Union[Module, str]] = None,
    recipe_args: Optional[
        Union[Dict[str, Any], str]
    ] = None,
    **kwargs
)

Mix-In class to extend the Hugging Face Trainer class to support LLM Compressor recipes for one-shot and finetuning flows.
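
For orientation, the following is a minimal usage sketch, assuming the mixin is combined with transformers.Trainer via multiple inheritance. The CompressionTrainer class name, the run_finetune helper, and the prepared model/argument objects are illustrative assumptions, not names exported by the library.

# Minimal sketch (not the library's own trainer): combine the mixin with
# Hugging Face's Trainer so recipes are applied during finetuning.
from transformers import Trainer, TrainingArguments

from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn


class CompressionTrainer(SessionManagerMixIn, Trainer):
    """Trainer that initializes and finalizes an LLM Compressor session."""


def run_finetune(model, model_args, dataset_args, train_dataset):
    # model, model_args (ModelArguments), dataset_args (DatasetArguments), and
    # the tokenized train_dataset are assumed to be prepared by the caller
    trainer = CompressionTrainer(
        recipe="recipe.yaml",  # recipe file to apply during training
        model_args=model_args,
        dataset_args=dataset_args,
        model=model,
        args=TrainingArguments(output_dir="./output"),
        train_dataset=train_dataset,
    )
    # initialize_session -> super().train() -> finalize_session
    return trainer.train()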

Parameters:

  • recipe

    (str) –

    path to recipe file to apply during training

  • recipe_args

    (Optional[Union[Dict[str, Any], str]], default: None ) –

    additional kwargs to use for evaluating recipe

  • dataset_args

    (Optional[DatasetArguments], default: None ) –

    kwargs for configuring dataset loading

  • teacher

    (Optional[Union[Module, str]], default: None ) –

    optional teacher model to use for distillation

Methods:

  • compute_loss

    Override of compute_loss to trigger callbacks and filter input columns

  • create_optimizer

    Override the optimizer to apply and update the recipe while training.

  • create_scheduler

    Create an LR scheduler to work with the applied recipes. This is a placeholder that just calls the super method

  • finalize_session

    Wrap up training by finalizing all modifiers initialized in the current session

  • initialize_session

    Initialize the CompressionSession from the specified epoch, evaluate the recipe, and initialize the modifiers for the training session

  • log_model_sparsification

    Log the current model sparsification info including pruned and quantized states

  • maybe_log_model_sparsification

    Log info on model sparsity and quantization if possible. Only print logs on the main process, and avoid logging for quantized FSDP models

  • save_model

    Override of the save_model function, which is expected to exist in the parent

  • train

    Run a sparsification training cycle. Runs initialization for the sparse session before calling super().train() and finalization of the session after

  • training_step

    Overrides the Trainer's training step to trigger the batch_start callback to the modifiers, then calls the parent function

Source code in llmcompressor/transformers/finetune/session_mixin.py
def __init__(
    self,
    recipe: str,
    model_args: "ModelArguments",
    dataset_args: Optional["DatasetArguments"] = None,
    teacher: Optional[Union[Module, str]] = None,
    recipe_args: Optional[Union[Dict[str, Any], str]] = None,
    **kwargs,
):
    self.recipe = recipe
    self.recipe_args = recipe_args
    self.model_args = model_args
    self.teacher = teacher

    # parse training and metadata args
    training_args = kwargs.get("args")

    self.metadata = None
    if training_args is not None:
        # trl_sft_trainer pathway. Both training_args and dataset_args
        # have `max_seq_length` which causes collision error. This is the
        # only shared parameter, where training arg is `TRLSFTConfig` that
        # inherits HuggingFace's `TrainingArguments`
        training_args_dict = training_args.to_dict()
        if "max_seq_length" in training_args_dict:
            training_args_dict["training_args_max_seq_length"] = (
                training_args_dict.pop("max_seq_length")
            )
            logger.warning(
                "Detected `max_seq_length` in both dataset_args "
                "and training_args. This is expected for TRL in distillation. "
                "Updating metadata to `training_args_max_seq_length`"
            )

        self.metadata = self._extract_metadata(
            metadata_args=METADATA_ARGS,
            training_args_dict=training_args_dict,
            dataset_args_dict=asdict(dataset_args) if dataset_args else {},
        )

    # setup metrics and session
    self.logger_manager = LoggerManager(log_python=False)
    create_session()

    # call Trainer initialization
    super().__init__(**kwargs)
    self.accelerator.wait_for_everyone()

    # setup callbacks and loss
    self.optim_callbacks = TrainingLoopCallbacks(self)
    self.callback_handler.add_callback(self.optim_callbacks)
    self.callback_disable_fp16 = DisableHalfPrecisionCallback(self)
    self.callback_handler.add_callback(self.callback_disable_fp16)
    self.criterion = torch.nn.CrossEntropyLoss()

    model_signature = inspect.signature(self.model.forward)
    self._signature_columns = list(model_signature.parameters.keys())

    if self.teacher is not None and teacher not in ("disable", "self"):
        teacher_signature = inspect.signature(self.teacher.forward)
        self._teacher_signature_columns = list(teacher_signature.parameters.keys())
    else:
        self._teacher_signature_columns = None

    if self.is_fsdp_enabled:
        self._prepare_model_for_fsdp()

    if dataset_args is not None:
        self.min_tokens_per_module = dataset_args.min_tokens_per_module

compute_loss

compute_loss(
    model: Module,
    inputs: Dict[str, Any],
    return_outputs: bool = False,
    num_items_in_batch: Optional[Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]

Override of compute_loss to trigger callbacks and filter input columns

Parameters:

  • model

    (Module) –

    the model to compute the loss for

  • inputs

    (Dict[str, Any]) –

    the inputs to pass through the model for calculating the loss

  • return_outputs

    (bool, default: False ) –

    True to return the outputs with the loss, False otherwise

  • num_items_in_batch

    (Optional[Tensor], default: None ) –

    the number of items which contribute to loss

Returns:

  • Union[Tensor, Tuple[Tensor, Any]]

    the resulting loss if not return_outputs, otherwise a tuple containing the loss and the model's outputs

Source code in llmcompressor/transformers/finetune/session_mixin.py
def compute_loss(
    self,
    model: Module,
    inputs: Dict[str, Any],
    return_outputs: bool = False,
    num_items_in_batch: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
    """
    Override of compute_loss to trigger callbacks and filter input columns

    :param model: the model to compute the loss for
    :param inputs: the inputs to pass through the model for calculating the loss
    :param return_outputs: True to return the outputs with the loss,
        False otherwise
    :param num_items_in_batch: the number of items which contribute to loss
    :return: the resulting loss if not return_outputs, otherwise a tuple
        containing the loss and the model's outputs
    """
    self._check_super_defined("compute_loss")

    # TODO: do we need these model signature columns?
    inputs = {k: inputs[k] for k in inputs if k in self._signature_columns}
    loss = super().compute_loss(
        model=model,
        inputs=inputs,
        return_outputs=return_outputs,
        num_items_in_batch=num_items_in_batch,
    )

    # take the mean across multiple GPUs
    # this is done outside the compute_loss function in the parent, replicating it
    # here for LLM Compressor logging and distillation
    loss = loss.mean()

    # Log step-wise loss and perplexity, for llama-recipes comparison
    # we want this before distillation loss so perplexity isn't thrown off
    do_log = self.state.global_step % self.args.logging_steps == 0
    if do_log:
        log = {}
        log["step_loss"] = loss.item()
        log["perplexity"] = torch.exp(loss).item()

    if active_session().lifecycle.initialized_:
        state = callbacks.loss_calculated(loss=loss)
        if state and state.loss is not None:
            loss = state.loss
            if do_log:
                log["distill_step_loss"] = loss.item() - log["step_loss"]
        callbacks.optim_pre_step()

    if do_log:
        self.log(log)

    return loss

create_optimizer

create_optimizer()

Override the optimizer to apply and update the recipe while training. create_optimizer must exist in the parent class and should set self.optimizer to the optimizer state and optionally set self.scaler if using amp.
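
To make the steps_per_epoch bookkeeping concrete, here is a small worked example with illustrative numbers (not library defaults):

import math

# Illustrative values only
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_train_samples = 1000

# n_gpu is handled internally by the dataloader, so it is not a factor here
total_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 8
steps_per_epoch = math.ceil(num_train_samples / total_batch_size)             # 125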

Source code in llmcompressor/transformers/finetune/session_mixin.py
def create_optimizer(self):
    """
    Override the optimizer to apply and update the recipe while training.
    create_optimizer must exist in the parent class and should set
    self.optimizer to the optimizer state and optionally set self.scaler
    if using amp.
    """

    self._check_super_defined("create_optimizer")
    super().create_optimizer()

    # n_gpu handled internally by dataloader
    total_batch_size = (
        self.args.per_device_train_batch_size
        * self.args.gradient_accumulation_steps
    )

    if isinstance(self.train_dataset, IterableDataset):
        logger.warning(
            "Training is being run with a streamed dataset, "
            "steps_per_epoch cannot be determined and will default to "
            "1. LLM Compressor modifiers utilizing this statistic may not "
            "behave as expected. "
        )
        self.total_steps_per_epoch = 1
    else:
        self.total_steps_per_epoch = math.ceil(
            len(self.train_dataset) / total_batch_size
        )

    active_session().initialize(
        optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
    )

    return self.optimizer

create_scheduler

create_scheduler(
    num_training_steps: int, optimizer: Optimizer = None
)

Create an LR scheduler to work with the applied recipes. This is a placeholder that just calls the super method, but would be expanded upon if we ever implement a LearningRateModifier.

Parameters:

  • num_training_steps

    (int) –

    the total number of training steps

  • optimizer

    (Optimizer, default: None ) –

    pre-initialized optimizer

Source code in llmcompressor/transformers/finetune/session_mixin.py
def create_scheduler(
    self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
    """
    Create an LR scheduler to work with the applied recipes. This is a placeholder
    that just calls the super method, but would be expanded upon if we ever
    implement a LearningRateModifier.

    :param num_training_steps: the total number of training steps
    :param optimizer: pre-initialized optimizer
    """

    # TODO: we don't currently have a LR scheduler in the new modifier framework
    self._check_super_defined("create_scheduler")
    return super().create_scheduler(
        num_training_steps=num_training_steps, optimizer=optimizer
    )

finalize_session

finalize_session()

Wrap up training by finalizing all modifiers initialized in the current session

Source code in llmcompressor/transformers/finetune/session_mixin.py
def finalize_session(self):
    """
    Wrap up training by finalizing all modifiers initialized in the current session
    """
    session = active_session()
    if not session.lifecycle.initialized_ or session.lifecycle.finalized:
        return False

    with summon_full_params_context(self.model, offload_to_cpu=True):
        # in order to update each layer we need to gather all its parameters
        active_session().finalize()
    logger.info("Finalized LLM Compressor session")
    model = get_session_model()
    self.model = model
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()

initialize_session

initialize_session(
    epoch: float,
    checkpoint: Optional[str] = None,
    stage: Optional[str] = None,
)

Initialize the CompressionSession from the specified epoch, evaluate the recipe, and initialize the modifiers for the training session

Parameters:

  • epoch

    (float) –

    Epoch to initialize session from, usually 0 unless loading from a checkpoint

  • checkpoint

    (Optional[str], default: None ) –

    Optional checkpoint to initialize from to continue training

  • stage

    (Optional[str], default: None ) –

    Optional stage of recipe to run, or None to run all stages

Source code in llmcompressor/transformers/finetune/session_mixin.py
def initialize_session(
    self,
    epoch: float,
    checkpoint: Optional[str] = None,
    stage: Optional[str] = None,
):
    """
    Initialize the CompressionSession from the specified epoch, evaluate the
    recipe, and initialize the modifiers for the training session

    :param epoch: Epoch to initialize session from, usually 0 unless loading
        from a checkpoint
    :param checkpoint: Optional checkpoint to initialize from to continue training
    :param stage: Optional stage of recipe to run, or None to run all stages
    """
    session = active_session()
    if session.lifecycle.initialized_ or session.lifecycle.finalized:
        return False

    train_data = self.get_train_dataloader()

    self.accelerator.wait_for_everyone()
    with summon_full_params_context(self.model, offload_to_cpu=True):
        active_session().initialize(
            recipe=self.recipe,
            recipe_stage=stage,
            recipe_args=self.recipe_args,
            model=self.model,
            teacher_model=self.teacher,  # TODO: what about for self/disable?
            train_data=train_data,
            start=epoch,
            copy_data=False,
            attach_optim_callbacks=True,
            fsdp_active=self.is_fsdp_enabled,
            metadata=self.metadata,
        )

    self.accelerator.wait_for_everyone()
    model = get_session_model()
    self.model_wrapped = self.model = model

    if self.recipe is None:
        logger.warning(
            "No training recipe was provided, finetuning will be run "
            "without event callbacks to LLM Compressor. To supply a recipe "
            "pass a yaml file or string to the `recipe` argument."
        )

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()

log_model_sparsification

log_model_sparsification()

Log the current model sparsification info including pruned and quantized states

Source code in llmcompressor/transformers/finetune/session_mixin.py
def log_model_sparsification(self):
    """
    Log the current model sparsification info including pruned and quantized states
    """
    sparsification_info = ModuleSparsificationInfo(self.model)

    logger.info(
        f"Sparsification info for {type(self.model).__name__}: "
        f"{sparsification_info.params_total} total params. "
    )
    sparsity_percent_formatted = "{:.2f}".format(
        sparsification_info.params_sparse_percent
    )
    logger.info(
        f"There are {sparsification_info.params_total} prunable "
        f"params which have {sparsity_percent_formatted}% "
        "avg sparsity."
    )

    quant_percent_formatted = "{:.2f}".format(
        sparsification_info.params_quantized_percent
    )
    logger.info(
        f"There are {sparsification_info.params_total} quantizable "
        f"params, with a quantization percentage of "
        f"{quant_percent_formatted}%."
    )

maybe_log_model_sparsification

maybe_log_model_sparsification()

Log info on model sparsity and quantization if possible. Only print logs on the main process, and avoid logging for quantized FSDP models

Source code in llmcompressor/transformers/finetune/session_mixin.py
def maybe_log_model_sparsification(self):
    """
    Log info on model sparsity and quantization if possible. Only print logs on the
    main process, and avoid logging for quantized FSDP models
    """
    with summon_full_params_context(self.model, offload_to_cpu=True):
        # offload to avoid OOM errors
        if not self.accelerator.is_main_process:
            # only calculate stats rank0 GPU
            return
        if self.is_fsdp_enabled and qat_active(self.model):
            # due to state dict changes we can't log sparsity info with quantized
            # models in FSDP
            return

        self.log_model_sparsification()

save_model

save_model(
    output_dir: str,
    _internal_call: bool = False,
    skip_sparsity_compression_stats: Optional[bool] = True,
)

Override of the save_model function, which is expected to exist in the parent. Calls into super() to save the model and additionally saves any recipes that were used with the model within the model folder.

Parameters:

  • output_dir

    (str) –

    the path to save the recipes into

  • _internal_call

    (bool, default: False ) –

    True if this is an internal call from the trainer in super(). Called from self.save_model(output_dir, _internal_call=True) in transformers/trainer/Trainer::_save_checkpoint

  • skip_sparsity_compression_stats

    (Optional[bool], default: True ) –

    True to skip computing sparsity and compression statistics when saving the checkpoint; defaults to True to avoid the high runtime cost

Source code in llmcompressor/transformers/finetune/session_mixin.py
def save_model(
    self,
    output_dir: str,
    _internal_call: bool = False,
    skip_sparsity_compression_stats: Optional[bool] = True,
):
    """
    Override of the save_model function and expects it to exist in the parent.
    Calls into super() to save the model and additionally saves any recipes
    that were used with the model within the model folder.

    :param output_dir: the path to save the recipes into
    :param _internal_call: True if this is an internal call from
        the trainer in super(). Called from
        self.save_model(output_dir, _internal_call=True)
        in transformers/trainer/Trainer::_save_checkpoint

    """
    if active_session() is None:
        logger.warning(
            "No active session found, skipping saving of recipes and model."
        )
        return

    # knowledge distillation requires making wrappers transparent during saving
    if isinstance(self.model, KDModelWrapper):
        self.model.prepare_for_save()  # TODO: move to finalize

    # save checkpoint
    # note that skip_sparsity_compression_stats
    # is True by default to avoid high runtime cost
    self.save_state()
    if self.accelerator.is_main_process:
        processor = getattr(self, "processing_class", self.tokenizer)
        # TODO: need to port over all saving parameters so that all
        # checkpoints are saved in the same way
        save_checkpoint(
            output_dir,
            model=self.model,
            processor=processor,
            save_safetensors=self.args.save_safetensors,
            save_compressed=self.model_args.save_compressed,
            skip_sparsity_compression_stats=skip_sparsity_compression_stats,
        )
    self.accelerator.wait_for_everyone()

    if isinstance(self.model, KDModelWrapper):
        self.model.finish_save()

train

train(*args, stage: Optional[str] = None, **kwargs)

Run a sparsification training cycle. Runs initialization for the sparse session before calling super().train() and finalization of the session after.

Logs sparsification details for the trained model.
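
As an illustration (the trainer object and stage name below are assumptions, not values defined by the library), a single recipe stage could be run as follows:

def run_stage(trainer):
    # `trainer` is a SessionManagerMixIn-based trainer; "sparsity_stage" is an
    # assumed stage name defined in the recipe being applied
    return trainer.train(stage="sparsity_stage")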

Parameters:

  • args

    positional args to pass to super().train()

  • stage

    (Optional[str], default: None ) –

    Optional stage of recipe to run, or None to run all stages

  • kwargs

    keyword args to pass to super().train()

Returns:

  • the output from super().train()

Source code in llmcompressor/transformers/finetune/session_mixin.py
def train(self, *args, stage: Optional[str] = None, **kwargs):
    """
    Run a sparsification training cycle. Runs initialization for the sparse session
    before calling super().train() and finalization of the session after.

    Logs sparsification details for the trained model.

    :param args: positional args to pass to super().train()
    :param stage: Optional stage of recipe to run, or None to run all stages
    :param kwargs: keyword args to pass to super().train()
    :return: the output from super().train()
    """

    # lifecycle
    checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
    self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)

    # do not save checkpoints as compressed
    original_save_compressed = self.model_args.save_compressed
    self.model_args.save_compressed = False

    # train with accelerator
    self.accelerator.wait_for_everyone()
    output = super().train(*args, **kwargs)
    self.accelerator.wait_for_everyone()

    # restore original setting for saving final model
    self.model_args.save_compressed = original_save_compressed

    # lifecycle
    self.finalize_session()
    self.accelerator.wait_for_everyone()

    # log model sparsity
    self.maybe_log_model_sparsification()
    self.accelerator.wait_for_everyone()

    return output

training_step

training_step(
    model: Module,
    inputs: Dict[str, Union[Tensor, Any]],
    num_items_in_batch: Optional[int] = None,
) -> torch.Tensor

Overrides the Trainer's training step to trigger the batch_start callback to the modifiers, then calls the parent function.

Parameters:

  • model

    (Module) –

    the model to compute the loss for

  • inputs

    (Dict[str, Union[Tensor, Any]]) –

    the inputs to pass through the model for calculating the loss

Returns:

  • Tensor

    output of the model

Source code in llmcompressor/transformers/finetune/session_mixin.py
def training_step(
    self,
    model: torch.nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
    """
    Overrides the Trainer's training step to trigger the batch_start callback to
    the modifiers, then calls the parent function.

    :param model: the model to compute the loss for
    :param inputs: the inputs to pass through the model for calculating the loss
    :return: output of the model
    """
    self._check_super_defined("training_step")

    callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
    model_outputs = super().training_step(
        model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
    )

    return model_outputs

TextGenerationDataset

TextGenerationDataset(
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
)

Bases: RegistryMixin

Base class for text datasets. Applies the following transformations to prepare a dataset to be loaded by a dataloader (see the sketch after this list):

  1. Load dataset from huggingface or local cache
  2. Preprocess dataset according to preprocess function or chat/dataset template
  3. Tokenize dataset using model tokenizer/processor
  4. Apply post processing such as grouping text and/or adding labels for finetuning
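
A hypothetical sketch of a custom dataset built on this base class. The MyChatDataset name, the "my_chat" registry key, and the row fields are assumptions for illustration, and registration relies on the register decorator assumed from the RegistryMixin base.

from llmcompressor.transformers.finetune.data.base import TextGenerationDataset


@TextGenerationDataset.register(name="my_chat")  # assumed RegistryMixin decorator
class MyChatDataset(TextGenerationDataset):
    """Hypothetical dataset that formats question/answer rows into text."""

    @property
    def preprocess(self):
        # Step 2 of the list above: map a raw row to processor/tokenizer kwargs
        def format_row(row):
            return {"text": f"Question: {row['question']}\nAnswer: {row['answer']}"}

        return format_row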

Parameters:

  • dataset_args

    (DatasetArguments) –

    configuration settings for dataset loading

  • split

    (str) –

    split from dataset to load, for instance test or train[:5%]

  • processor

    (Processor) –

    processor or tokenizer to use on dataset

Methods:

  • load_dataset

    Load the raw dataset from Hugging Face, using cached copy if available

  • map

    Wrapper function around Dataset.map and IterableDataset.map.

Attributes:

  • preprocess (Union[Callable[[LazyRow], Any], None]) –

    The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY

Source code in llmcompressor/transformers/finetune/data/base.py
def __init__(
    self,
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
):
    self.dataset_args = dataset_args
    self.split = split
    self.processor = processor

    # get tokenizer
    self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

    if self.tokenizer is not None:
        # fill in pad token
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # configure sequence length
        max_seq_length = dataset_args.max_seq_length
        if dataset_args.max_seq_length > self.tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({max_seq_length}) is larger than "
                f"maximum length for model ({self.tokenizer.model_max_length}). "
                f"Using max_seq_length={self.tokenizer.model_max_length}."
            )
        self.max_seq_length = min(
            dataset_args.max_seq_length, self.tokenizer.model_max_length
        )

        # configure padding
        self.padding = (
            False
            if self.dataset_args.concatenate_data
            else "max_length"
            if self.dataset_args.pad_to_max_length
            else False
        )

    else:
        self.max_seq_length = None
        self.padding = False

preprocess cached property

preprocess: Union[Callable[[LazyRow], Any], None]

The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY

load_dataset

load_dataset()

Load the raw dataset from Hugging Face, using cached copy if available

Parameters:

  • cache_dir

    disk location to search for cached dataset

Returns:

  • the requested dataset

Source code in llmcompressor/transformers/finetune/data/base.py
def load_dataset(self):
    """
    Load the raw dataset from Hugging Face, using cached copy if available

    :param cache_dir: disk location to search for cached dataset
    :return: the requested dataset
    """
    if self.dataset_args.dataset_path is not None:
        if self.dataset_args.dvc_data_repository is not None:
            self.dataset_args.raw_kwargs["storage_options"] = {
                "url": self.dataset_args.dvc_data_repository
            }
            self.dataset_args.raw_kwargs["data_files"] = (
                self.dataset_args.dataset_path
            )
        else:
            self.dataset_args.raw_kwargs["data_files"] = (
                get_custom_datasets_from_path(
                    self.dataset_args.dataset_path,
                    self.dataset_args.dataset
                    if hasattr(self.dataset_args, "dataset")
                    else self.dataset_args.dataset_name,
                )
            )

    logger.debug(f"Loading dataset {self.dataset_args.dataset}")
    return get_raw_dataset(
        self.dataset_args,
        None,
        split=self.split,
        streaming=self.dataset_args.streaming,
        **self.dataset_args.raw_kwargs,
    )

map

map(
    dataset: Union[Dataset, IterableDataset],
    function: Callable[[Any], Any],
    **kwargs
) -> Union[Dataset, IterableDataset]

Wrapper function around Dataset.map and IterableDataset.map.

If the dataset is streaming (in the case of IterableDataset), non-applicable arguments are ignored and the dataset features are resolved

Source code in llmcompressor/transformers/finetune/data/base.py
def map(
    self,
    dataset: Union[Dataset, IterableDataset],
    function: Callable[[Any], Any],
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    """
    Wrapper function around Dataset.map and IterableDataset.map.

    If the dataset is streaming (in the case of IterableDataset), non-applicable
    arguments are ignored and the dataset features are resolved
    """
    if isinstance(dataset, IterableDataset):
        # remove arguments that don't apply to streaming
        kwargs.pop("num_proc", None)
        kwargs.pop("load_from_cache_file", None)
        kwargs.pop("desc", None)
        kwargs.pop("keep_in_memory", None)

    dataset = dataset.map(function, **kwargs)

    if isinstance(dataset, IterableDataset):
        dataset = dataset._resolve_features()

    return dataset

is_model_ct_quantized_from_path

is_model_ct_quantized_from_path(path: str) -> bool

Determine if model from path is quantized based on the config

Parameters:

  • path

    (str) –

    path to the model or HF stub

Returns:

  • bool

    True if config contains quantization_config from the given path
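
A brief usage example; the model path shown is a placeholder, not a real stub:

from llmcompressor.transformers.utils.helpers import is_model_ct_quantized_from_path

# Placeholder path or Hugging Face stub
if is_model_ct_quantized_from_path("path/to/model"):
    print("Config declares compressed-tensors quantization")
else:
    print("No compressed-tensors quantization_config found")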

Source code in llmcompressor/transformers/utils/helpers.py
def is_model_ct_quantized_from_path(path: str) -> bool:
    """
    Determine if model from path is quantized based
    on the config

    :param path: path to the model or HF stub
    :return: True if config contains quantization_config from the given path

    """
    config = AutoConfig.from_pretrained(path)
    if config is not None:
        if (
            hasattr(config, "quantization_config")
            and config.quantization_config["quant_method"] == "compressed-tensors"
        ):
            return True
    return False