llmcompressor.utils.fsdp.helpers

Functions:

  • get_fsdp_parent

    Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper is found, returns None.

  • is_fsdp_model

    Check if a model instance is wrapped by FSDP

  • maybe_get_wrapped

    Given a model that may or may not have a distributed wrapper, return the underlying wrapped model.

  • set_wrapped_model

    Given a state with a model that may or may not have a distributed wrapper, set the underlying wrapped model.

get_fsdp_parent

get_fsdp_parent(
    layer_name: str, model: Module
) -> Optional[Module]

Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper is found, returns None.

Parameters:

  • layer_name

    (str) –

    layer name in model to get parent of

  • model

    (Module) –

    pytorch module to search through

Returns:

  • Optional[Module]

    FSDP wrapped parent of layer_name if available, otherwise None

Source code in llmcompressor/utils/fsdp/helpers.py
def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
    """
    Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper
    is found, just return None

    :param layer_name: layer name in model to get parent of
    :param model: pytorch module to search through
    :return: FSDP wrapped parent of layer_name if available, otherwise None
    """
    if not is_fsdp_model(model):
        return None

    parent_name = layer_name
    parent = operator.attrgetter(parent_name)(model)
    while not isinstance(parent, FullyShardedDataParallel):
        if len(parent_name) == 0:  # we've reached the root module and it's not FSDP
            # this should never get hit because we check for an FSDP root above
            # but while statements without a backup are too scary
            return None
        parent_name = ".".join(parent_name.split(".")[:-1])
        parent = operator.attrgetter(parent_name)(model)

    return parent
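
A hypothetical usage sketch (fsdp_model and the layer name below are placeholders; an FSDP-wrapped model requires an initialized torch.distributed process group):

from llmcompressor.utils.fsdp.helpers import get_fsdp_parent

# `fsdp_model` is assumed to be a model already wrapped with
# FullyShardedDataParallel; the layer name below is illustrative only.
parent = get_fsdp_parent("model.layers.0.self_attn.q_proj", fsdp_model)
if parent is None:
    print("no FSDP-wrapped parent found (or the model is not FSDP-wrapped)")
else:
    print(f"closest FSDP-wrapped parent: {type(parent).__name__}")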

is_fsdp_model

is_fsdp_model(model: Module) -> bool

Check if a model instance is wrapped by FSDP

Parameters:

  • model

    (Module) –

    pytorch model to check

Returns:

  • bool

    True if module is wrapped, False otherwise

Source code in llmcompressor/utils/fsdp/helpers.py
def is_fsdp_model(model: Module) -> bool:
    """
    Check if a model instance is wrapped by FSDP

    :param model: pytorch model to check
    :return: True if module is wrapped, False otherwise
    """
    if not FullyShardedDataParallel:
        return False

    return isinstance(model, FullyShardedDataParallel)
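
A small, self-contained illustration; wrapping the module with FullyShardedDataParallel (which requires an initialized distributed process group) would make the check return True:

import torch.nn as nn

from llmcompressor.utils.fsdp.helpers import is_fsdp_model

model = nn.Linear(16, 16)
print(is_fsdp_model(model))  # False: plain nn.Module, no FSDP wrapper
# Wrapping the model with torch.distributed.fsdp.FullyShardedDataParallel
# inside an initialized process group would make is_fsdp_model return True.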

maybe_get_wrapped

maybe_get_wrapped(model: Module) -> Module

Given a model that may or may not have a distributed wrapper, return the underlying wrapped model.

Parameters:

  • model

    (Module) –

    input model to get wrapped model from

Returns:

  • Module

    wrapped model

Source code in llmcompressor/utils/fsdp/helpers.py
def maybe_get_wrapped(model: Module) -> Module:
    """
    Given a model that may or may not have a distributed wrapper, return the underlying
    wrapped model.

    :param model: input model to get wrapped model from
    :returns: wrapped model
    """
    if is_fsdp_model(model=model):
        return model._fsdp_wrapped_module
    return model
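
For a model without a distributed wrapper this is a no-op; the FSDP case is sketched in the comments:

import torch.nn as nn

from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped

model = nn.Linear(16, 16)
assert maybe_get_wrapped(model) is model  # no wrapper, returned unchanged
# For an FSDP model, maybe_get_wrapped(fsdp_model) would instead return
# fsdp_model._fsdp_wrapped_module, i.e. the module inside the FSDP wrapper.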

set_wrapped_model

set_wrapped_model(state: State, wrapped_model: Module)

Given a state with a model that may or may not have a distributed wrapper, set the underlying wrapped model.

Parameters:

  • state

    (State) –

    state to update model of

  • wrapped_model

    (Module) –

    model to inject into the state

Source code in llmcompressor/utils/fsdp/helpers.py
def set_wrapped_model(state: State, wrapped_model: Module):
    """
    Given a state with a model that may or may not have a distributed wrapper, set
    the underlying wrapped model.

    :param state: state to update model of
    :param wrapped_model: model to inject into the state
    """
    if is_fsdp_model(state.model):
        state.model._fsdp_wrapped_module = wrapped_model
    else:
        state.model = wrapped_model
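
A minimal sketch of the unwrap/modify/re-set round trip, using a simple stand-in object in place of a real llmcompressor State (only its model attribute is exercised here):

import types

import torch.nn as nn

from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model

# Stand-in for llmcompressor's State: only the `model` attribute is used below.
state = types.SimpleNamespace(model=nn.Linear(16, 16))

inner = maybe_get_wrapped(state.model)  # unwrap; a no-op for a plain module
# ... modify `inner` in place (e.g. quantize or prune its weights) ...
set_wrapped_model(state, inner)         # put the (possibly modified) module back;
                                        # for FSDP this replaces _fsdp_wrapped_module
assert state.model is inner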