
llmcompressor.pytorch.model_load.helpers

Functions:

  • get_session_model

    Return the PyTorch module stored by the active CompressionSession, or None if no session is active.

  • parse_dtype

    Parse a dtype or dtype string into a torch.dtype; unrecognized input yields the string "auto".

get_session_model

get_session_model() -> Optional[Module]

Returns:

  • Optional[Module]

    PyTorch module stored by the active CompressionSession, or None if no session is active

Source code in src/llmcompressor/pytorch/model_load/helpers.py
def get_session_model() -> Optional[Module]:
    """
    :return: pytorch module stored by the active CompressionSession,
        or None if no session is active
    """
    session = active_session()
    if not session:
        return None

    active_model = session.state.model
    return active_model

parse_dtype

parse_dtype(dtype_arg: Union[str, dtype]) -> torch.dtype

Parameters:

  • dtype_arg (Union[str, dtype]) –

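    torch.dtype or string to parse. Recognized strings are "half", "float16", "torch.float16" (float16); "bfloat16", "torch.bfloat16" (bfloat16); and "full", "float32", "torch.float32" (float32).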

Returns:

  • dtype

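    torch.dtype parsed from the input, or the string "auto" (take the precision from the model) when the input matches none of the recognized aliases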

Source code in src/llmcompressor/pytorch/model_load/helpers.py
def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype:
    """
    :param dtype_arg: dtype or string to parse
    :return: torch.dtype parsed from input string
    """
    dtype_arg = str(dtype_arg)
    dtype = "auto"  # get precision from model by default
    if dtype_arg in ("half", "float16", "torch.float16"):
        dtype = torch.float16
    elif dtype_arg in ("torch.bfloat16", "bfloat16"):
        dtype = torch.bfloat16
    elif dtype_arg in ("full", "float32", "torch.float32"):
        dtype = torch.float32

    return dtype
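
The branches above amount to an alias lookup with an "auto" fallback. The sketch below restates that logic as a table. It is deliberately torch-free so that it runs anywhere: plain strings stand in for `torch.float16`, `torch.bfloat16`, and `torch.float32`, and `_ALIASES` and `resolve_dtype_name` are hypothetical names, not part of llmcompressor.

```python
from typing import Any

# Alias table mirroring the branches of parse_dtype. The values are plain
# strings standing in for the corresponding torch.dtype objects (assumption
# made so the snippet does not require PyTorch).
_ALIASES = {
    "half": "float16",
    "float16": "float16",
    "torch.float16": "float16",
    "bfloat16": "bfloat16",
    "torch.bfloat16": "bfloat16",
    "full": "float32",
    "float32": "float32",
    "torch.float32": "float32",
}


def resolve_dtype_name(dtype_arg: Any) -> str:
    """Hypothetical helper: return the canonical dtype name, or "auto"."""
    # str() normalizes the argument first, as parse_dtype does, so an object
    # whose string form is "torch.float16" hits the same table entry.
    return _ALIASES.get(str(dtype_arg), "auto")


print(resolve_dtype_name("half"))
print(resolve_dtype_name("torch.bfloat16"))
print(resolve_dtype_name("int8"))  # not in the table, falls back to "auto"
```

The fallback is the point to watch: an unrecognized or misspelled alias does not raise, it quietly resolves to "auto".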