llmcompressor.transformers.finetune.data.base

Base classes for text generation dataset handling and processing.

This module provides the foundational TextGenerationDataset class with registry support for different dataset types. Handles dataset loading, tokenization, preprocessing, and text generation specific formatting for fine-tuning workflows.

Classes:

  • TextGenerationDataset

    Base class for text datasets; applies a sequence of transformations to prepare a dataset to be loaded by a dataloader

TextGenerationDataset

TextGenerationDataset(
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
)

Bases: RegistryMixin

Base class for text datasets. Applies the following transformations to prepare the dataset to be loaded by a dataloader (a minimal usage sketch follows the list):

  1. Load dataset from huggingface or local cache
  2. Preprocess dataset according to preprocess function or chat/dataset template
  3. Tokenize dataset using model tokenizer/processor
  4. Apply post processing such as grouping text and/or adding labels for finetuning
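
A minimal usage sketch of steps 1 and 3, assuming a Hugging Face tokenizer is used as the processor. The "wikitext" dataset name, the tokenizer checkpoint, and the DatasetArguments import path are illustrative assumptions; in practice a registered subclass is usually selected by name through the registry rather than instantiating the base class directly. Tokenization through the map wrapper is shown in the map section below.

from transformers import AutoTokenizer

from llmcompressor.args import DatasetArguments  # import path assumed
from llmcompressor.transformers.finetune.data.base import TextGenerationDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
dataset_args = DatasetArguments(dataset="wikitext", max_seq_length=512)  # placeholder values

# Step 1: load the raw dataset from the Hugging Face hub or a local cache
handler = TextGenerationDataset(
    dataset_args=dataset_args,
    split="train[:5%]",
    processor=tokenizer,
)
raw_dataset = handler.load_dataset()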

Parameters:

  • dataset_args

    (DatasetArguments) –

    configuration settings for dataset loading

  • split

    (str) –

    split from dataset to load, for instance test or train[:5%]

  • processor

    (Processor) –

    processor or tokenizer to use on dataset

Methods:

  • load_dataset

    Load the raw dataset from Hugging Face, using a cached copy if available

  • map

    Wrapper function around Dataset.map and IterableDataset.map.

Attributes:

  • preprocess (Union[Callable[[LazyRow], Any], None]) –

    The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY

Source code in llmcompressor/transformers/finetune/data/base.py
def __init__(
    self,
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
):
    self.dataset_args = dataset_args
    self.split = split
    self.processor = processor

    # get tokenizer
    self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

    if self.tokenizer is not None:
        # fill in pad token
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # configure sequence length
        max_seq_length = dataset_args.max_seq_length
        if dataset_args.max_seq_length > self.tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({max_seq_length}) is larger than "
                f"maximum length for model ({self.tokenizer.model_max_length}). "
                f"Using max_seq_length={self.tokenizer.model_max_length}."
            )
        self.max_seq_length = min(
            dataset_args.max_seq_length, self.tokenizer.model_max_length
        )

        # configure padding
        self.padding = (
            False
            if self.dataset_args.concatenate_data
            else "max_length"
            if self.dataset_args.pad_to_max_length
            else False
        )

    else:
        self.max_seq_length = None
        self.padding = False
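
The nested conditional that sets self.padding above can be hard to read. The following standalone restatement is equivalent and is included only for illustration; the constructor above is the actual implementation.

def resolve_padding(concatenate_data: bool, pad_to_max_length: bool):
    # Concatenated (grouped) text is never padded; otherwise pad to
    # "max_length" only when explicitly requested.
    if concatenate_data:
        return False
    return "max_length" if pad_to_max_length else False

def resolve_max_seq_length(requested: int, model_max_length: int) -> int:
    # The requested sequence length is clamped to what the model supports.
    return min(requested, model_max_length)

assert resolve_padding(concatenate_data=True, pad_to_max_length=True) is False
assert resolve_padding(concatenate_data=False, pad_to_max_length=True) == "max_length"
assert resolve_max_seq_length(4096, 2048) == 2048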

preprocess cached property

preprocess: Union[Callable[[LazyRow], Any], None]

The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY
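
For illustration, a hedged sketch of a preprocess function that satisfies this contract, assuming a dataset with "question" and "answer" columns; PROMPT_KEY below is a stand-in for the module's prompt key constant rather than a confirmed import.

PROMPT_KEY = "prompt"  # stand-in for the module-level prompt key constant

def preprocess_row(example):
    # Returned keys must be valid tokenizer/processor kwargs ("text" here);
    # PROMPT_KEY is optional and marks the prompt portion of the sample.
    prompt = f"Question: {example['question']}\nAnswer:"
    return {"text": prompt + " " + example["answer"], PROMPT_KEY: prompt}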

load_dataset

load_dataset()

Load the raw dataset from Hugging Face, using a cached copy if available

Parameters:

  • cache_dir

    disk location to search for cached dataset

Returns:

  • the requested dataset

Source code in llmcompressor/transformers/finetune/data/base.py
def load_dataset(self):
    """
    Load the raw dataset from Hugging Face, using cached copy if available

    :param cache_dir: disk location to search for cached dataset
    :return: the requested dataset
    """
    if self.dataset_args.dataset_path is not None:
        if self.dataset_args.dvc_data_repository is not None:
            self.dataset_args.raw_kwargs["storage_options"] = {
                "url": self.dataset_args.dvc_data_repository
            }
            self.dataset_args.raw_kwargs["data_files"] = (
                self.dataset_args.dataset_path
            )
        else:
            self.dataset_args.raw_kwargs["data_files"] = (
                get_custom_datasets_from_path(
                    self.dataset_args.dataset_path,
                    self.dataset_args.dataset
                    if hasattr(self.dataset_args, "dataset")
                    else self.dataset_args.dataset_name,
                )
            )

    logger.debug(f"Loading dataset {self.dataset_args.dataset}")
    return get_raw_dataset(
        self.dataset_args,
        None,
        split=self.split,
        streaming=self.dataset_args.streaming,
        **self.dataset_args.raw_kwargs,
    )
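
Both custom-data branches above are driven entirely by DatasetArguments fields. A hedged sketch of the two configurations, continuing the imports from the earlier sketch; the field names are taken from the source above, while the dataset name and paths are placeholders.

# Local files: data_files is derived from dataset_path via
# get_custom_datasets_from_path before get_raw_dataset is called.
local_args = DatasetArguments(
    dataset="custom",  # placeholder dataset/handler name
    dataset_path="./data/train.json",
)

# DVC-tracked files: dataset_path is passed through as data_files and the
# repository URL is forwarded to the datasets library via storage_options.
dvc_args = DatasetArguments(
    dataset="custom",
    dataset_path="data/train.json",
    dvc_data_repository="https://github.com/example-org/data-repo.git",
)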

map

map(
    dataset: Union[Dataset, IterableDataset],
    function: Callable[[Any], Any],
    **kwargs
) -> Union[Dataset, IterableDataset]

Wrapper function around Dataset.map and IterableDataset.map.

If the dataset is streaming (in the case of IterableDataset), non-applicable arguments are ignored and the dataset features are resolved

Source code in llmcompressor/transformers/finetune/data/base.py
def map(
    self,
    dataset: Union[Dataset, IterableDataset],
    function: Callable[[Any], Any],
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    """
    Wrapper function around Dataset.map and IterableDataset.map.

    If the dataset is streaming (in the case of IterableDataset), non-applicable
    arguments are ignored and the dataset features are resolved
    """
    if isinstance(dataset, IterableDataset):
        # remove arguments that don't apply to streaming
        kwargs.pop("num_proc", None)
        kwargs.pop("load_from_cache_file", None)
        kwargs.pop("desc", None)
        kwargs.pop("keep_in_memory", None)

    dataset = dataset.map(function, **kwargs)

    if isinstance(dataset, IterableDataset):
        dataset = dataset._resolve_features()

    return dataset
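
Continuing the earlier sketch (handler, raw_dataset, tokenizer), a hedged example of tokenizing through this wrapper. The same call works whether the raw data is a regular Dataset or a streaming IterableDataset, since the streaming-incompatible keyword arguments are dropped automatically; the "text" column name is an assumption about the raw dataset.

def tokenize_fn(batch):
    # Assumes the raw dataset exposes a "text" column.
    return tokenizer(
        batch["text"],
        padding=handler.padding,
        max_length=handler.max_seq_length,
        truncation=True,
    )

tokenized = handler.map(
    raw_dataset,
    tokenize_fn,
    batched=True,
    num_proc=4,                  # dropped automatically when streaming
    load_from_cache_file=False,  # dropped automatically when streaming
    desc="Tokenizing",           # dropped automatically when streaming
)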