Skip to content

vllm_gaudi.ops.hpu_modelopt

logger module-attribute

# Module-level logger named after this module (__name__); not referenced by the code shown on this page.
logger = init_logger(__name__)

HPUModelOptFp8Config

Bases: ModelOptFp8Config

Config class for ModelOpt FP8.

Source code in vllm_gaudi/ops/hpu_modelopt.py
class HPUModelOptFp8Config(ModelOptFp8Config):
    """Config class for ModelOpt FP8 on Gaudi.

    Restricts the parent ModelOpt FP8 config to the plain ``"FP8"``
    algorithm, reports BF16 as the only supported activation dtype, and
    dispatches quantized linear layers to ``HPUModelOptFp8LinearMethod``.
    """

    def __init__(
        self,
        quant_method: str,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Build the config and reject any ModelOpt algorithm other than FP8.

        Raises:
            ValueError: If ``self.quant_method`` (read after the parent
                initializer has run) is not ``"FP8"``.
        """
        super().__init__(quant_method, is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

        if self.quant_method != "FP8":
            raise ValueError("Unsupported ModelOpt FP8 quant_algo on Gaudi: "
                             f"{self.quant_method}. Supported: FP8 only.")

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes supported on Gaudi (BF16 only)."""
        return [torch.bfloat16]

    def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully-qualified module name, used for exclusion matching.

        Returns:
            ``self.KVCacheMethodCls(self)`` for ``Attention`` layers;
            ``UnquantizedLinearMethod`` for excluded or legacy-vision linear
            layers; ``HPUModelOptFp8LinearMethod`` for all other linear layers;
            ``None`` otherwise.

        Raises:
            ValueError: For non-excluded ``FusedMoE`` layers, which are not
                yet supported on Gaudi.
        """
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # handle exclusion: only linear layers can take UnquantizedLinearMethod
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
        # exclude_modules config. But need to keep them for loading quantized
        # checkpoints generated by older versions. Then check substring matching
        # for patterns not caught by exact match.
        # As in the exclusion branch above, only linear layers receive
        # UnquantizedLinearMethod; other layer types (e.g. a vision FusedMoE)
        # must not be handed a linear method.
        if "vision_tower" in prefix or "vision_model" in prefix:
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # now, the layer is quantized, handle it here
        if isinstance(layer, LinearBase):
            return HPUModelOptFp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            raise ValueError("FP8 modelopt quantization not yet supported on Gaudi")

        return None

__init__

__init__(
    quant_method: str,
    is_checkpoint_fp8_serialized: bool,
    kv_cache_quant_method: str | None,
    exclude_modules: list[str],
) -> None
Source code in vllm_gaudi/ops/hpu_modelopt.py
def __init__(
    self,
    quant_method: str,
    is_checkpoint_fp8_serialized: bool,
    kv_cache_quant_method: str | None,
    exclude_modules: list[str],
) -> None:
    """Initialize via the parent config, then allow only the ``"FP8"`` algorithm.

    Raises:
        ValueError: If ``self.quant_method``, as seen after the parent
            initializer runs, is anything other than ``"FP8"``.
    """
    super().__init__(
        quant_method,
        is_checkpoint_fp8_serialized,
        kv_cache_quant_method,
        exclude_modules,
    )

    algo = self.quant_method
    if algo == "FP8":
        return
    raise ValueError("Unsupported ModelOpt FP8 quant_algo on Gaudi: "
                     f"{algo}. Supported: FP8 only.")

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm_gaudi/ops/hpu_modelopt.py
def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
    """Select the quantization method for ``layer``.

    Args:
        layer: The module being constructed.
        prefix: Fully-qualified module name, used for exclusion matching.

    Returns:
        ``self.KVCacheMethodCls(self)`` for ``Attention`` layers;
        ``UnquantizedLinearMethod`` for excluded or legacy-vision linear
        layers; ``HPUModelOptFp8LinearMethod`` for all other linear layers;
        ``None`` otherwise.

    Raises:
        ValueError: For non-excluded ``FusedMoE`` layers, which are not yet
            supported on Gaudi.
    """
    # handle kv-cache first so we can focus only on weight quantization thereafter
    if isinstance(layer, Attention):
        return self.KVCacheMethodCls(self)

    # handle exclusion: only linear layers can take UnquantizedLinearMethod
    if self.is_layer_excluded(prefix):
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        return None

    # TODO: This special hard coded logic is not needed for quantized checkpoints
    # generated by ModelOpt >= 0.39.0 where they are handled naturally by the
    # exclude_modules config. But need to keep them for loading quantized
    # checkpoints generated by older versions. Then check substring matching
    # for patterns not caught by exact match.
    # As in the exclusion branch above, only linear layers receive
    # UnquantizedLinearMethod; other layer types (e.g. a vision FusedMoE)
    # must not be handed a linear method.
    if "vision_tower" in prefix or "vision_model" in prefix:
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        return None

    # now, the layer is quantized, handle it here
    if isinstance(layer, LinearBase):
        return HPUModelOptFp8LinearMethod(self)
    if isinstance(layer, FusedMoE):
        raise ValueError("FP8 modelopt quantization not yet supported on Gaudi")

    return None

get_supported_act_dtypes

get_supported_act_dtypes() -> list[dtype]
Source code in vllm_gaudi/ops/hpu_modelopt.py
def get_supported_act_dtypes(self) -> list[torch.dtype]:
    """Activation dtypes accepted by the Gaudi ModelOpt FP8 path: BF16 only.

    A new list is built on every call.
    """
    supported: list[torch.dtype] = [torch.bfloat16]
    return supported

HPUModelOptFp8LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer static quantization on Gaudi. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales.

Limitations: (1) only per-tensor quantization is supported, due to torch._scaled_mm support; (2) only the float8_e4m3fn datatype is supported.

Args: quant_config: The ModelOpt quantization config.

Source code in vllm_gaudi/ops/hpu_modelopt.py
class HPUModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization on Gaudi.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
        1. Only per-tensor quantization is supported (due to
           torch._scaled_mm support).
        2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8-serialized checkpoints, its scales on ``layer``.

        Args:
            layer: Linear layer to populate.
            input_size_per_partition: Number of input features on this partition.
            output_partition_sizes: Output width of each logical shard fused
                into this layer; their sum is the partition's output size.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype used when the checkpoint is not
                FP8-serialized.
            **extra_weight_attrs: May provide a fallback ``weight_loader``.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        # Use V2 version of weight loader
        # See https://github.com/vllm-project/vllm/blob/releases/v0.11.2/vllm/model_executor/layers/linear.py#L493
        weight_loader = extra_weight_attrs.get("weight_loader")
        # Explicit None check: a Module's truthiness is not a presence test.
        if layer is not None and hasattr(layer, "weight_loader_v2"):
            weight_loader = layer.weight_loader_v2

        # Gaudi2 gets an extra wrapper; every device gets the synced loader.
        if hpu_ops.is_hpu_gaudi2:
            weight_loader = hpu_ops.gaudi_weight_wrapper(weight_loader)
        weight_loader = hpu_ops.synced_weight_loader(weight_loader)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype)
        weight = ModelWeightParameter(
            data=torch.empty(output_size_per_partition, input_size_per_partition, dtype=weight_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # One float32 scale per logical shard for the weight and the input.
            # Both are pre-filled with float32 min — presumably so that any
            # shard not overwritten by the loader never wins the .max() taken
            # in process_weights_after_loading (TODO confirm).
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales to one scale and store the weight transposed.

        If the per-shard weight scales differ, ``hpu_ops.requantize_with_max_scale``
        produces the weight and scale to use (presumably requantizing every
        shard to the max scale — confirm). The weight is then stored as
        ``weight.t()``, and ``weight_scale`` / ``input_scale`` are replaced by
        non-trainable Parameters; the input scale becomes the max of its entries.
        Requires ``weight_scale`` and ``input_scale`` to exist on ``layer``.
        """
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = hpu_ops.requantize_with_max_scale(layer.weight, layer.weight_scale,
                                                                    layer.logical_widths)
        # Stored transposed: apply() calls the HPU kernel with trans_B=False.
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear through ``hpu_ops.apply_fp8_linear_hpu``.

        ``x`` is flattened to 2-D over its last dimension, multiplied with the
        pre-transposed weight (``trans_B=False``), trimmed to the input's row
        count and reshaped back to ``[*x.shape[:-1], out_features]``.
        """
        weight_scale = layer.weight_scale.transpose(0, 1) if layer.weight_scale.dim() > 1 else layer.weight_scale
        input_scale = getattr(layer, 'input_scale', None)

        # View input as 2D matrix for fp8 methods; reshape (unlike view) also
        # accepts non-contiguous activations and still avoids a copy when it can.
        input_2d = x.reshape(-1, x.shape[-1])
        # The weight is stored transposed, so dim 1 is the output feature count.
        output_shape = [*x.shape[:-1], layer.weight.shape[1]]
        output = hpu_ops.apply_fp8_linear_hpu(input=input_2d,
                                              weight=layer.weight,
                                              weight_scale=weight_scale,
                                              input_scale=input_scale,
                                              bias=bias,
                                              trans_B=False)
        # Keep only the first input_2d.shape[0] rows (the kernel may return more — TODO confirm).
        output = torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
        return output

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config) -> None
Source code in vllm_gaudi/ops/hpu_modelopt.py
def __init__(self, quant_config: ModelOptFp8Config) -> None:
    """Remember the ModelOpt config that drives this linear method."""
    setattr(self, "quant_config", quant_config)

apply

apply(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor
Source code in vllm_gaudi/ops/hpu_modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the FP8 linear through ``hpu_ops.apply_fp8_linear_hpu``.

    ``x`` is flattened to 2-D over its last dimension, multiplied with the
    pre-transposed weight (``trans_B=False``), trimmed to the input's row
    count and reshaped back to ``[*x.shape[:-1], out_features]``.

    Args:
        layer: Linear layer holding ``weight``, ``weight_scale`` and
            optionally ``input_scale``.
        x: Input activations; the last dimension is the input feature size.
        bias: Optional bias forwarded to the HPU kernel.

    Returns:
        The layer output with the leading dimensions of ``x``.
    """
    weight_scale = layer.weight_scale.transpose(0, 1) if layer.weight_scale.dim() > 1 else layer.weight_scale
    input_scale = getattr(layer, 'input_scale', None)

    # View input as 2D matrix for fp8 methods; reshape (unlike view) also
    # accepts non-contiguous activations and still avoids a copy when it can.
    input_2d = x.reshape(-1, x.shape[-1])
    # Weight is stored transposed after loading, so dim 1 is the output feature count.
    output_shape = [*x.shape[:-1], layer.weight.shape[1]]
    output = hpu_ops.apply_fp8_linear_hpu(input=input_2d,
                                          weight=layer.weight,
                                          weight_scale=weight_scale,
                                          input_scale=input_scale,
                                          bias=bias,
                                          trans_B=False)
    # Keep only the first input_2d.shape[0] rows (the kernel may return more — TODO confirm).
    output = torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
    return output

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm_gaudi/ops/hpu_modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the weight and, for FP8-serialized checkpoints, its scales on ``layer``.

    Args:
        layer: Linear layer to populate.
        input_size_per_partition: Number of input features on this partition.
        output_partition_sizes: Output width of each logical shard fused into
            this layer; their sum is the partition's output size.
        input_size: Unused.
        output_size: Unused.
        params_dtype: Weight dtype used when the checkpoint is not FP8-serialized.
        **extra_weight_attrs: May provide a fallback ``weight_loader``.
    """
    del input_size, output_size
    output_size_per_partition = sum(output_partition_sizes)
    # Use V2 version of weight loader
    # See https://github.com/vllm-project/vllm/blob/releases/v0.11.2/vllm/model_executor/layers/linear.py#L493
    weight_loader = extra_weight_attrs.get("weight_loader")
    # Explicit None check: a Module's truthiness is not a presence test.
    if layer is not None and hasattr(layer, "weight_loader_v2"):
        weight_loader = layer.weight_loader_v2

    # Gaudi2 gets an extra wrapper; every device gets the synced loader.
    if hpu_ops.is_hpu_gaudi2:
        weight_loader = hpu_ops.gaudi_weight_wrapper(weight_loader)
    weight_loader = hpu_ops.synced_weight_loader(weight_loader)

    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    weight_dtype = (torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype)
    weight = ModelWeightParameter(
        data=torch.empty(output_size_per_partition, input_size_per_partition, dtype=weight_dtype),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # One float32 scale per logical shard for the weight and the input.
        # Both are pre-filled with float32 min — presumably so that any shard
        # not overwritten by the loader never wins a later .max() (TODO confirm).
        # WEIGHT SCALE
        weight_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)
        # INPUT SCALE
        scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm_gaudi/ops/hpu_modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:
    """Fold per-shard scales into a single scale and store the weight transposed.

    If all entries of ``layer.weight_scale`` are equal, the loaded weight and
    the max scale are kept; otherwise both come from
    ``hpu_ops.requantize_with_max_scale`` (presumably requantizing each shard
    to the max scale — confirm). Afterwards ``layer.weight`` holds
    ``weight.t()``, and ``weight_scale`` / ``input_scale`` are non-trainable
    Parameters; the input scale is reduced with ``max()``.
    """
    loaded_weight = layer.weight
    shard_scales = layer.weight_scale
    common_scale = shard_scales.max()
    # Any shard whose scale differs from the first one forces requantization.
    if bool((shard_scales != shard_scales[0]).any()):
        common_scale, loaded_weight = hpu_ops.requantize_with_max_scale(
            layer.weight,
            shard_scales,
            layer.logical_widths,
        )

    transposed = loaded_weight.t()
    layer.weight = Parameter(transposed, requires_grad=False)
    layer.weight_scale = Parameter(common_scale, requires_grad=False)
    reduced_input_scale = layer.input_scale.max()
    layer.input_scale = Parameter(reduced_input_scale, requires_grad=False)