
llmcompressor.modifiers.distillation.utils.pytorch

Modules:

- kd_wrapper
- model_wrapper

Classes:

- KDModelWrapper
- KDModuleWrapper

KDModelWrapper

KDModelWrapper(
    student_model: Module,
    teacher_model: Module,
    wrappers: Dict[str, Any],
    comparison,
    fsdp_active: bool,
)

Bases: Module

Methods:

- finish_save – Finish saving model
- prepare_for_save – Prepare model structure to be saved, specifically `self.named_modules`

Source code in llmcompressor/modifiers/distillation/utils/pytorch/model_wrapper.py
def __init__(
    self,
    student_model: Module,
    teacher_model: Module,
    wrappers: Dict[str, Any],
    comparison,
    fsdp_active: bool,
):
    super(KDModelWrapper, self).__init__()

    self.student_model = student_model
    self.teacher_model = teacher_model
    self.wrappers = wrappers
    self.kd_comparison = comparison
    self._save_active = False
    self._fsdp_active = fsdp_active
    self.kd_enabled = False
    self.register_buffer(self.KD_LAST_COMPARISON, torch.zeros(1, device="cpu"))
    self._init_called = True  # make sure this is last property to be set

    def _clear_missing_keys(module, incompatible_keys):
        incompatible_keys.missing_keys.clear()

    self.register_load_state_dict_post_hook(_clear_missing_keys)
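
Example (not part of the library docs): a minimal construction sketch, assuming llmcompressor is installed. In normal use the distillation modifier builds this wrapper itself and populates wrappers with (student wrapper, teacher wrapper) pairs keyed by layer name, as the loops in finish_save and prepare_for_save below imply; the comparison callable shown here is a hypothetical KL-divergence-style scorer, not the library's own.

import torch
import torch.nn.functional as F
from torch.nn import Linear
from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
    KDModelWrapper,
)

student = Linear(16, 16)
teacher = Linear(16, 16)


# Hypothetical comparison callable: scores student activations against teacher
# activations (a KL-divergence style comparison is only one possibility).
def comparison(student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
    return F.kl_div(
        F.log_softmax(student_out, dim=-1),
        F.softmax(teacher_out, dim=-1),
        reduction="batchmean",
    )


kd_model = KDModelWrapper(
    student_model=student,
    teacher_model=teacher,
    wrappers={},  # layer name -> (student_wrapper, teacher_wrapper); left empty here
    comparison=comparison,
    fsdp_active=False,
)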

finish_save

finish_save()

Finish saving model

Source code in llmcompressor/modifiers/distillation/utils/pytorch/model_wrapper.py
def finish_save(self):
    """
    Finish saving model
    """
    self._save_active = False
    for student_wrapper, teacher_wrapper in self.wrappers.values():
        student_wrapper.finish_save()
        teacher_wrapper.finish_save()

prepare_for_save

prepare_for_save()

Prepare model structure to be saved, specifically `self.named_modules`

Source code in llmcompressor/modifiers/distillation/utils/pytorch/model_wrapper.py
def prepare_for_save(self):
    """
    Prepare model structure to be saved, specifically `self.named_modules`
    """
    self._save_active = True
    for student_wrapper, teacher_wrapper in self.wrappers.values():
        student_wrapper.prepare_for_save()
        teacher_wrapper.prepare_for_save()
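
Taken together, prepare_for_save and finish_save bracket a checkpoint write. A hedged sketch of that lifecycle, assuming kd_model is the KDModelWrapper built in the construction sketch above and that the checkpoint is written with a plain state_dict save:

import torch

# Put the wrapper into its "save" layout, write the checkpoint, then restore
# the runtime module structure.
kd_model.prepare_for_save()
torch.save(kd_model.state_dict(), "kd_checkpoint.pt")
kd_model.finish_save()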

KDModuleWrapper

KDModuleWrapper(
    layer: Module,
    hidden_size: Tuple,
    transforms: Optional[List[TransformFuncType]],
    fsdp_active: bool,
    offload_output: bool,
)

Bases: Module

Methods:

- finish_save – Finish saving model
- prepare_for_save – Prepare model structure to be saved, specifically `self.named_modules`

Source code in llmcompressor/modifiers/distillation/utils/pytorch/kd_wrapper.py
def __init__(
    self,
    layer: Module,
    hidden_size: Tuple,
    transforms: Optional[List[TransformFuncType]],
    fsdp_active: bool,
    offload_output: bool,
):
    super(KDModuleWrapper, self).__init__()

    self.layer = layer
    self._save_active = False
    self._fsdp_active = fsdp_active
    self.offload_output = offload_output
    self.kd_transforms = transforms
    self.kd_enabled = False
    self.register_buffer(
        self.KD_TRANSFORMED_BUFFER, torch.zeros(hidden_size, device="cpu")
    )
    self._init_called = True  # make sure this is last property to be set

    def _clear_missing_keys(module, incompatible_keys):
        incompatible_keys.missing_keys.clear()

    self.register_load_state_dict_post_hook(_clear_missing_keys)
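
Example (not part of the library docs): a minimal sketch of wrapping one student/teacher layer pair, assuming hidden_size sizes the transformed-output buffer registered in __init__ and transforms is an optional list of Tensor -> Tensor callables. The resulting dict mirrors the shape that KDModelWrapper.prepare_for_save and finish_save iterate over.

from torch.nn import Linear
from llmcompressor.modifiers.distillation.utils.pytorch.kd_wrapper import (
    KDModuleWrapper,
)

student_layer = Linear(16, 16)
teacher_layer = Linear(16, 16)

student_wrapper = KDModuleWrapper(
    layer=student_layer,
    hidden_size=(1, 16),  # shape used for the KD_TRANSFORMED_BUFFER buffer
    transforms=None,      # optionally a list of Tensor -> Tensor callables
    fsdp_active=False,
    offload_output=False,
)
teacher_wrapper = KDModuleWrapper(
    layer=teacher_layer,
    hidden_size=(1, 16),
    transforms=None,
    fsdp_active=False,
    offload_output=False,
)

# Keyed by layer name; this is the pairing KDModelWrapper expects in `wrappers`.
wrappers = {"model.layers.0": (student_wrapper, teacher_wrapper)}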

finish_save

finish_save()

Finish saving model

Source code in llmcompressor/modifiers/distillation/utils/pytorch/kd_wrapper.py
def finish_save(self):
    """
    Finish saving model
    """
    self._save_active = False

prepare_for_save

prepare_for_save()

Prepare model structure to be saved, specifically `self.named_modules`

Source code in llmcompressor/modifiers/distillation/utils/pytorch/kd_wrapper.py
def prepare_for_save(self):
    """
    Prepare model structure to be saved, specifically `self.named_modules`
    """
    self._save_active = True
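
Note that these module-level hooks are not normally called directly: KDModelWrapper.prepare_for_save and finish_save invoke them on every (student_wrapper, teacher_wrapper) pair in wrappers, as shown in the model wrapper's source above.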