vllm.scalar_type

_SCALAR_TYPES_ID_MAP module-attribute

_SCALAR_TYPES_ID_MAP = {}

NanRepr

Bases: Enum

Source code in vllm/scalar_type.py
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
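
For illustration, the presets defined in scalar_types at the bottom of this page exercise all three representations (a quick sketch, assuming vllm is installed):

from vllm.scalar_type import NanRepr, scalar_types

assert scalar_types.float8_e5m2.nan_repr == NanRepr.IEEE_754
assert scalar_types.float8_e4m3fn.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
assert scalar_types.float6_e3m2f.nan_repr == NanRepr.NONE  # no NaN encoding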

EXTD_RANGE_MAX_MIN class-attribute instance-attribute

EXTD_RANGE_MAX_MIN = 2

IEEE_754 class-attribute instance-attribute

IEEE_754 = 1

NONE class-attribute instance-attribute

NONE = 0

ScalarType dataclass

ScalarType can represent a wide range of floating point and integer types; in particular, it can be used to represent sub-byte data types (something that torch.dtype currently does not support). It is also capable of representing types with a bias, i.e. stored_value = value + bias, which is useful for quantized types (e.g. standard GPTQ 4-bit uses a bias of 8). The implementation for this class can be found in csrc/core/scalar_type.hpp; these type signatures should be kept in sync with that file.
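
A minimal usage sketch (assuming vllm is installed); the 4-bit, bias-8 parameters mirror the GPTQ example above:

from vllm.scalar_type import ScalarType

gptq_w4 = ScalarType.uint(4, 8)      # 4-bit unsigned int stored with a bias of 8
print(gptq_w4)                       # uint4b8
print(gptq_w4.size_bits)             # 4
print(gptq_w4.min(), gptq_w4.max())  # -8 7 (the bias shifts the stored 0..15 range)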

Source code in vllm/scalar_type.py
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types; in particular, it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    which is useful for quantized types (e.g. standard GPTQ 4-bit uses a
    bias of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this is an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number of bits representing an integer (excluding the sign bit)
    if this is an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0). For example, if we store the
    type as an unsigned integer with a bias of 128, then the value 0 will
    be stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: use `has_infs()` instead to check whether infs are supported.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represented in this scalar type; returns a NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) - 1
        # (where e is the number of exponent bits); there is some precedent
        # for non-standard biases, e.g. `float8_e4m3b11fnuz` in
        # https://github.com/jax-ml/ml_dtypes, but to avoid premature
        # overcomplication we just assume the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or
                    (self.size_bits == 64 and self.is_signed())
                    ), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or self.size_bits
                    <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. The layout of this int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        _SCALAR_TYPES_ID_MAP[val] = self

        return val

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit); same as `signed`,
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type supports NaNs"
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754 and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means NaNs are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means the bias is zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def from_id(cls, scalar_type_id: int):
        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
            raise ValueError(
                f"scalar_type_id {scalar_type_id} doesn't exists.")
        return _SCALAR_TYPES_ID_MAP[scalar_type_id]

_finite_values_only class-attribute instance-attribute

_finite_values_only: bool = False

Private: use has_infs() instead to check whether infs are supported.

bias instance-attribute

bias: int

Bias used to encode the values in this scalar type (value = stored_value - bias, default 0). For example, if we store the type as an unsigned integer with a bias of 128, then the value 0 will be stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
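
The arithmetic from the example above, spelled out (plain Python, nothing vllm-specific):

bias = 128
for value in (-1, 0, 1):
    stored = value + bias          # encode: stored_value = value + bias
    assert stored - bias == value  # decode: value = stored_value - bias
    print(value, "->", stored)     # -1 -> 127, 0 -> 128, 1 -> 129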

exponent instance-attribute

exponent: int

Number of bits in the exponent if this is a floating point type (zero if this is an integer type)

id cached property

id: int

Convert the ScalarType to an int which can be passed to pytorch custom ops. The layout of this int must be kept in sync with the C++ ScalarType's from_id method.
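
A sketch of the resulting bit layout, repacked by hand; the field offsets are inferred from the or_and_advance calls in the source above (exponent: 8 bits, mantissa: 8, signed: 1, bias: 32, _finite_values_only: 1, nan_repr: 8 -- 58 bits total):

from vllm.scalar_type import scalar_types

t = scalar_types.uint4b8
packed = (t.exponent
          | (t.mantissa << 8)
          | (int(t.signed) << 16)
          | (t.bias << 17)
          | (int(t._finite_values_only) << 49)
          | (t.nan_repr.value << 50))
assert packed == t.id  # fits in an int64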

mantissa instance-attribute

mantissa: int

Number of bits in the mantissa if this is a floating point type, or the number of bits representing an integer (excluding the sign bit) if this is an integer type.

nan_repr class-attribute instance-attribute

nan_repr: NanRepr = IEEE_754

How NaNs are represented in this scalar type; returns a NanRepr value. (not applicable for integer types)

signed instance-attribute

signed: bool

If the type is signed (i.e. has a sign bit)

size_bits property

size_bits: int

__init__

__init__(
    exponent: int,
    mantissa: int,
    signed: bool,
    bias: int,
    _finite_values_only: bool = False,
    nan_repr: NanRepr = IEEE_754,
) -> None

__len__

__len__() -> int
Source code in vllm/scalar_type.py
def __len__(self) -> int:
    raise TypeError

__repr__

__repr__() -> str
Source code in vllm/scalar_type.py
def __repr__(self) -> str:
    return "ScalarType." + self.__str__()

__str__

__str__() -> str

Naming generally follows https://github.com/jax-ml/ml_dtypes.

For floating point types (leading f) the scheme is float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags], with flags:

- no-flags: means it follows IEEE 754 conventions
- f: means finite values only (no infinities)
- n: means NaNs are supported (non-standard encoding)

For integer types the scheme is [u]int<size_bits>[b<bias>]; if bias is not present it means the bias is zero.

Source code in vllm/scalar_type.py
def __str__(self) -> str:
    """
    naming generally follows: https://github.com/jax-ml/ml_dtypes
    for floating point types (leading f) the scheme is:
    `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
    flags:
      - no-flags: means it follows IEEE 754 conventions
      - f: means finite values only (no infinities)
      - n: means NaNs are supported (non-standard encoding)
    for integer types the scheme is:
      `[u]int<size_bits>[b<bias>]`
      - if bias is not present it means the bias is zero
    """
    if self.is_floating_point():
        ret = "float" + str(self.size_bits) + "_e" + str(
            self.exponent) + "m" + str(self.mantissa)

        if not self.is_ieee_754():
            if self._finite_values_only:
                ret = ret + "f"
            if self.nan_repr != NanRepr.NONE:
                ret = ret + "n"

        return ret
    else:
        ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
        if self.has_bias():
            ret = ret + "b" + str(self.bias)
        return ret
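
Sample names under this scheme, using the presets defined further down this page (a sketch):

from vllm.scalar_type import scalar_types

print(scalar_types.float8_e5m2)    # float8_e5m2   (IEEE 754, no flags)
print(scalar_types.float8_e4m3fn)  # float8_e4m3fn (finite values, non-standard NaNs)
print(scalar_types.float6_e3m2f)   # float6_e3m2f  (finite values only)
print(scalar_types.uint4b8)        # uint4b8       (4-bit unsigned, bias 8)
print(scalar_types.int4)           # int4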

_floating_point_max

_floating_point_max() -> float
Source code in vllm/scalar_type.py
def _floating_point_max(self) -> float:
    double_raw = self._floating_point_max_int()
    return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

_floating_point_max_int

_floating_point_max_int() -> int
Source code in vllm/scalar_type.py
def _floating_point_max_int(self) -> int:
    assert (
        self.mantissa <= 52 and self.exponent <= 11
    ), f"Cannot represent max/min as a double for type {self.__str__()}"

    max_mantissa = (1 << self.mantissa) - 1
    if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
        max_mantissa = max_mantissa - 1

    max_exponent = (1 << self.exponent) - 2
    if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
            or self.nan_repr == NanRepr.NONE):
        assert (
            self.exponent < 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"
        max_exponent = max_exponent + 1

    # adjust the exponent to match that of a double
    # for now we assume the exponent bias is the standard 2^(e-1) - 1
    # (where e is the number of exponent bits); there is some precedent
    # for non-standard biases, e.g. `float8_e4m3b11fnuz` in
    # https://github.com/jax-ml/ml_dtypes, but to avoid premature
    # overcomplication we just assume the standard exponent bias until
    # there is a need to support non-standard biases
    exponent_bias = (1 << (self.exponent - 1)) - 1
    exponent_bias_double = (1 << 10) - 1  # double e = 11

    max_exponent_double = (max_exponent - exponent_bias +
                           exponent_bias_double)

    # shift the mantissa and exponent into the proper positions for an
    # IEEE double and bitwise-or them together.
    return (max_mantissa <<
            (52 - self.mantissa)) | (max_exponent_double << 52)
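
Working the arithmetic by hand for float8_e5m2 (exponent=5, mantissa=2, IEEE 754), under the standard-bias assumption stated in the comments above:

import struct

max_mantissa = (1 << 2) - 1   # 3 (all mantissa bits set)
max_exponent = (1 << 5) - 2   # 30 (the all-ones exponent encodes inf/NaN)
exponent_bias = (1 << 4) - 1  # 15
max_exponent_double = max_exponent - exponent_bias + ((1 << 10) - 1)  # 1038

raw = (max_mantissa << (52 - 2)) | (max_exponent_double << 52)
print(struct.unpack('!d', struct.pack('!Q', raw))[0])  # 57344.0 == 1.75 * 2**15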

_raw_max

_raw_max() -> Union[int, float]
Source code in vllm/scalar_type.py
def _raw_max(self) -> Union[int, float]:
    if self.is_floating_point():
        return self._floating_point_max()
    else:
        assert (self.size_bits < 64 or
                (self.size_bits == 64 and self.is_signed())
                ), "Cannot represent max as an int"
        return (1 << self.mantissa) - 1

_raw_min

_raw_min() -> Union[int, float]
Source code in vllm/scalar_type.py
def _raw_min(self) -> Union[int, float]:
    if self.is_floating_point():
        assert self.is_signed(
        ), "We currently assume all floating point types are signed"
        sign_bit_double = 1 << 63

        max_raw = self._floating_point_max_int()
        min_raw = max_raw | sign_bit_double
        return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
    else:
        assert (not self.is_signed() or self.size_bits
                <= 64), "Cannot represent min as a int64_t"

        if self.is_signed():
            return -(1 << (self.size_bits - 1))
        else:
            return 0

float_ classmethod

float_(
    exponent: int,
    mantissa: int,
    finite_values_only: bool,
    nan_repr: NanRepr,
) -> ScalarType

Create a non-standard floating point type (i.e. does not follow IEEE 754 conventions).
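
For example, the fp6 preset further down this page is built exactly this way (a sketch):

from vllm.scalar_type import NanRepr, ScalarType

fp6 = ScalarType.float_(3, 2, True, NanRepr.NONE)  # same parameters as scalar_types.float6_e3m2f
print(fp6, fp6.max())  # float6_e3m2f 28.0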

Source code in vllm/scalar_type.py
@classmethod
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
           nan_repr: NanRepr) -> 'ScalarType':
    """
    Create a non-standard floating point type
    (i.e. does not follow IEEE 754 conventions).
    """
    assert (mantissa > 0 and exponent > 0)
    assert (nan_repr != NanRepr.IEEE_754), (
        "use `float_IEEE754` constructor for floating point types that "
        "follow IEEE 754 conventions")
    ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
    ret.id  # noqa B018: make sure the id is cached
    return ret

float_IEEE754 classmethod

float_IEEE754(exponent: int, mantissa: int) -> ScalarType

Create a standard floating point type (i.e. follows IEEE 754 conventions).
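
For example, this reproduces the float8_e5m2 preset defined further down this page (a sketch):

from vllm.scalar_type import ScalarType, scalar_types

fp8 = ScalarType.float_IEEE754(5, 2)
assert fp8 == scalar_types.float8_e5m2  # frozen dataclass: field-wise equality
print(fp8.max())                        # 57344.0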

Source code in vllm/scalar_type.py
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
    """
    Create a standard floating point type
    (i.e. follows IEEE 754 conventions).
    """
    assert (mantissa > 0 and exponent > 0)
    ret = cls(exponent, mantissa, True, 0)
    ret.id  # noqa B018: make sure the id is cached
    return ret

from_id classmethod

from_id(scalar_type_id: int)
Source code in vllm/scalar_type.py
@classmethod
def from_id(cls, scalar_type_id: int):
    if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
        raise ValueError(
            f"scalar_type_id {scalar_type_id} doesn't exists.")
    return _SCALAR_TYPES_ID_MAP[scalar_type_id]
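
Round-trip sketch: the convenience constructors eagerly compute each type's id, populating _SCALAR_TYPES_ID_MAP, so ids minted by this module map back to an equal ScalarType:

from vllm.scalar_type import ScalarType, scalar_types

t = scalar_types.uint4b8
assert ScalarType.from_id(t.id) == t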

has_bias

has_bias() -> bool

If the type has a non-zero bias

Source code in vllm/scalar_type.py
def has_bias(self) -> bool:
    "If the type has a non-zero bias"
    return self.bias != 0

has_infs

has_infs() -> bool

If the type is floating point and supports infinity

Source code in vllm/scalar_type.py
def has_infs(self) -> bool:
    "If the type is floating point and supports infinity"
    return not self._finite_values_only

has_nans

has_nans() -> bool

If the type supports NaNs

Source code in vllm/scalar_type.py
def has_nans(self) -> bool:
    "If the type supports NaNs"
    return self.nan_repr != NanRepr.NONE

int_ classmethod

int_(size_bits: int, bias: Optional[int]) -> ScalarType

Create a signed integer scalar type (size_bits includes sign-bit).
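
For example (a sketch): 4 total bits leave 1 sign bit plus 3 magnitude bits:

from vllm.scalar_type import ScalarType

int4 = ScalarType.int_(4, None)
print(int4, int4.min(), int4.max())  # int4 -8 7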

Source code in vllm/scalar_type.py
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
    "Create a signed integer scalar type (size_bits includes sign-bit)."
    ret = cls(0, size_bits - 1, True, bias if bias else 0)
    ret.id  # noqa B018: make sure the id is cached
    return ret

is_floating_point

is_floating_point() -> bool

If the type is a floating point type

Source code in vllm/scalar_type.py
def is_floating_point(self) -> bool:
    "If the type is a floating point type"
    return self.exponent != 0

is_ieee_754

is_ieee_754() -> bool

If the type is a floating point type that follows IEEE 754 conventions

Source code in vllm/scalar_type.py
def is_ieee_754(self) -> bool:
    """
    If the type is a floating point type that follows IEEE 754
    conventions
    """
    return self.nan_repr == NanRepr.IEEE_754 and \
        not self._finite_values_only

is_integer

is_integer() -> bool

If the type is an integer type

Source code in vllm/scalar_type.py
def is_integer(self) -> bool:
    "If the type is an integer type"
    return self.exponent == 0

is_signed

is_signed() -> bool

If the type is signed (i.e. has a sign bit); same as signed, added for consistency with: https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html

Source code in vllm/scalar_type.py
def is_signed(self) -> bool:
    """
    If the type is signed (i.e. has a sign bit); same as `signed`,
    added for consistency with:
    https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
    """
    return self.signed

max

max() -> Union[int, float]

Max representable value for this scalar type. (accounting for bias if there is one)
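
For a biased type, both bounds shift by the bias (a sketch using the uint8b128 preset defined further down this page):

from vllm.scalar_type import scalar_types

t = scalar_types.uint8b128  # stored range 0..255, bias 128
print(t.min(), t.max())     # -128 127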

Source code in vllm/scalar_type.py
def max(self) -> Union[int, float]:
    """
    Max representable value for this scalar type.
    (accounting for bias if there is one)
    """
    return self._raw_max() - self.bias

min

min() -> Union[int, float]

Min representable value for this scalar type. (accounting for bias if there is one)

Source code in vllm/scalar_type.py
def min(self) -> Union[int, float]:
    """
    Min representable value for this scalar type.
    (accounting for bias if there is one)
    """
    return self._raw_min() - self.bias

uint classmethod

uint(size_bits: int, bias: Optional[int]) -> ScalarType

Create an unsigned integer scalar type.
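
For example (a sketch): without a bias, the stored and logical ranges coincide:

from vllm.scalar_type import ScalarType

uint4 = ScalarType.uint(4, None)
print(uint4, uint4.min(), uint4.max())  # uint4 0 15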

Source code in vllm/scalar_type.py
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
    """Create a unsigned integer scalar type."""
    ret = cls(0, size_bits, False, bias if bias else 0)
    ret.id  # noqa B018: make sure the id is cached
    return ret

scalar_types

Source code in vllm/scalar_type.py
class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
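
A quick tour of a few presets (a sketch; the ranges follow from the definitions above):

from vllm.scalar_type import scalar_types

for t in (scalar_types.uint4b8, scalar_types.float8_e4m3fn,
          scalar_types.float4_e2m1f):
    print(f"{t}: {t.size_bits} bits, range [{t.min()}, {t.max()}]")
# uint4b8: 4 bits, range [-8, 7]
# float8_e4m3fn: 8 bits, range [-448.0, 448.0]
# float4_e2m1f: 4 bits, range [-6.0, 6.0]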

bfloat16 class-attribute instance-attribute

bfloat16 = float16_e8m7

float16 class-attribute instance-attribute

float16 = float16_e5m10

float16_e5m10 class-attribute instance-attribute

float16_e5m10 = float_IEEE754(5, 10)

float16_e8m7 class-attribute instance-attribute

float16_e8m7 = float_IEEE754(8, 7)

float4_e2m1f class-attribute instance-attribute

float4_e2m1f = float_(2, 1, True, NONE)

float6_e3m2f class-attribute instance-attribute

float6_e3m2f = float_(3, 2, True, NONE)

float8_e4m3fn class-attribute instance-attribute

float8_e4m3fn = float_(4, 3, True, EXTD_RANGE_MAX_MIN)

float8_e5m2 class-attribute instance-attribute

float8_e5m2 = float_IEEE754(5, 2)

int4 class-attribute instance-attribute

int4 = int_(4, None)

int8 class-attribute instance-attribute

int8 = int_(8, None)

uint2b2 class-attribute instance-attribute

uint2b2 = uint(2, 2)

uint3b4 class-attribute instance-attribute

uint3b4 = uint(3, 4)

uint4 class-attribute instance-attribute

uint4 = uint(4, None)

uint4b8 class-attribute instance-attribute

uint4b8 = uint(4, 8)

uint8 class-attribute instance-attribute

uint8 = uint(8, None)

uint8b128 class-attribute instance-attribute

uint8b128 = uint(8, 128)