Skip to content

vllm_gaudi.models.gptoss_mxfp4

_original_load_weights module-attribute

# Unpatched vLLM `load_weights`, kept so `patched_load_weights` can fall back
# to it for checkpoints that are not gpt_oss mxfp4.
_original_load_weights = load_weights

_original_normalize_quantization_config module-attribute

# Handle to the unpatched vLLM quantization-config normalizer;
# `_patched_normalize_quantization_config` delegates to it for every config
# it does not bypass.
_original_normalize_quantization_config = _normalize_quantization_config

_load_weights_mxfp4_dequantize_hpu

_load_weights_mxfp4_dequantize_hpu(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]
Source code in vllm_gaudi/models/gptoss_mxfp4.py
def _load_weights_mxfp4_dequantize_hpu(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load gpt_oss mxfp4 checkpoint weights, dequantizing MoE experts on the fly.

    Packed MoE expert tensors (``.w13_weight`` / ``.w2_weight``) are sliced for
    this rank and buffered until the matching ``*_weight_scale`` tensor
    arrives. The pair is then dequantized with ``convert_moe_packed_tensors``
    and written into this rank's parameter.

    Each rank keeps one of two slices:

    * with expert parallelism, the expert range ``[ep_rank_start, ep_rank_end)``;
    * otherwise, all experts but only its slice of the intermediate dimension.

    Attention sinks are narrowed to this rank's heads. q/k/v shards are routed
    through ``stacked_params_mapping``. Any remaining weight is loaded with its
    parameter's ``weight_loader``, or skipped if no such parameter exists.

    Args:
        ep_rank_end: Exclusive end index of this rank's experts (used only
            when expert parallelism is enabled).
        ep_rank_start: Start index of this rank's experts (EP only).
        heads_per_rank: Number of attention heads owned by this rank.
        head_start: Index of this rank's first attention head.
        weights: Iterable of ``(name, tensor)`` checkpoint entries.
        stacked_params_mapping: ``(param_name, shard_name, shard_id)`` tuples
            used to rename q/k/v shards into the fused parameter.

    Returns:
        Names recorded as loaded. Checkpoint names are recorded as-is,
        including ``*_weight_scale`` entries. Stacked shards are recorded under
        their fused parameter name.

    Raises:
        ValueError: If a ``*_weight_scale`` tensor is encountered before its
            packed ``*_weight`` block tensor.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    # EP: each rank keeps whole experts [ep_rank_start, ep_rank_end).
    # Otherwise every rank keeps all experts but only a slice of the
    # intermediate dimension (see the TP branches below).
    use_ep = self.parallel_config.enable_expert_parallel

    # In MoE, we need to flatten the tensor parallel size across the data
    # parallel size when EP is disabled.
    tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
        tp_size=get_tensor_model_parallel_world_size(),
        dp_size=get_dp_group().world_size,
        dp_rank=get_dp_group().rank_in_group,
        pcp_size=get_pcp_group().world_size,
        pcp_rank=get_pcp_group().rank_in_group,
    )
    intermediate_size = self.config.intermediate_size
    # Use cdiv-based per-rank partitioning to match FusedMoE's bf16 param
    # layout (which is what gets allocated here because the gpt_oss mxfp4
    # quant config is bypassed in `_patched_normalize_quantization_config`).
    # Block-aligned partitioning would over-/under-size the rank slice when
    # `intermediate_size` is not divisible by `OCP_MX_BLOCK_SIZE * tp_size`
    # (e.g. gpt-oss-120b: 2880 / (32*4) = 22.5).
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
    # Can be smaller than per_rank_intermediate_size on the last rank (min above).
    local_intermediate_size = tp_rank_end - tp_rank_start

    # For w2 the intermediate dim is the K (reduction) axis and mxfp4 scales
    # are stored per `OCP_MX_BLOCK_SIZE` block along K. When the rank range
    # is not block-aligned, expand outward to a block-aligned window for
    # dequantization, then crop the dequantized result back to the rank's
    # true range using `k_offset`.
    k_block_start = tp_rank_start // OCP_MX_BLOCK_SIZE
    k_block_end = cdiv(tp_rank_end, OCP_MX_BLOCK_SIZE)
    k_offset = tp_rank_start - k_block_start * OCP_MX_BLOCK_SIZE

    # Packed mxfp4 block tensors (already sliced for this rank), keyed by
    # checkpoint name and held until the matching `*_weight_scale` arrives.
    block_weight_dict = {}

    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue

        # NOTE: the `*_weight_scale` checks must come before the `*_weight`
        # checks, because ".w13_weight" / ".w2_weight" are substrings of the
        # scale names.
        if ".w13_weight_scale" in name:
            # Handle MLP gate and up projection weights
            # Extract gate and up projection parts
            if use_ep:
                narrow_weight_scale = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # Each intermediate index spans two rows, hence the 2x range.
                # Presumably gate/up are interleaved row-wise in the
                # checkpoint; TODO confirm.
                narrow_weight_scale = weight[:, 2 * tp_rank_start:2 * tp_rank_end, :]

            narrow_weight_scale = narrow_weight_scale.contiguous()

            # Read block weight
            block_name = name.replace("weight_scale", "weight")
            if block_name not in block_weight_dict:
                raise ValueError(f"Expected block weight for {block_name} not found when processing {name}")
            block_weight = block_weight_dict[block_name]
            param = params_dict[block_name]

            # `weight` is rebound to the dequantized tensor
            # (convert_moe_packed_tensors defaults to bfloat16).
            weight = convert_moe_packed_tensors(block_weight, narrow_weight_scale)
            if use_ep:
                param.copy_(weight)
            else:
                # Only the leading 2 * local rows are written; any extra rows
                # of the cdiv-sized param are left untouched.
                param[:, :2 * (tp_rank_end - tp_rank_start), :] = weight
            # Drop the buffered blocks now that they have been consumed.
            del block_weight_dict[block_name]
            loaded_params.add(name)
            continue
        elif ".w13_weight" in name:
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end, :, :]
            narrow_weight = narrow_weight.contiguous()
            # Buffer until the scale arrives; dequantization happens there.
            block_weight_dict[name] = narrow_weight
            loaded_params.add(name)
            continue
        elif ".w2_weight_scale" in name:
            # Handle MLP down projection weights
            if use_ep:
                narrow_weight_scale = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # Scales are indexed by K block: take the block-aligned window.
                narrow_weight_scale = weight[..., k_block_start:k_block_end]
            narrow_weight_scale = narrow_weight_scale.contiguous()

            # Read block weight
            block_name = name.replace("weight_scale", "weight")
            if block_name not in block_weight_dict:
                raise ValueError(f"Expected block weight for {block_name} not found when processing {name}")
            block_weight = block_weight_dict[block_name]
            param = params_dict[block_name]

            weight = convert_moe_packed_tensors(block_weight, narrow_weight_scale)
            if use_ep:
                param.copy_(weight)
            else:
                # Crop block-aligned dequant output to the rank's true range.
                param[:, :, :local_intermediate_size] = weight[..., k_offset:k_offset + local_intermediate_size]
            del block_weight_dict[block_name]
            loaded_params.add(name)
            continue
        elif ".w2_weight" in name:
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # Dim 2 indexes K blocks: same window as the w2 scales.
                narrow_weight = weight[:, :, k_block_start:k_block_end, :]
            narrow_weight = narrow_weight.contiguous()
            block_weight_dict[name] = narrow_weight
            loaded_params.add(name)
            continue
        elif ".w13_bias" in name:
            # Handle MLP gate and up projection biases
            # Extract gate and up projection bias parts
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end]
            narrow_weight = narrow_weight.contiguous()

            param = params_dict[name]
            if use_ep:
                param.copy_(narrow_weight)
            else:
                param[:, :2 * (tp_rank_end - tp_rank_start)] = narrow_weight
            loaded_params.add(name)
            continue
        elif ".w2_bias" in name:
            # Handle MLP down projection bias
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # (only load on rank 0 to avoid duplication)
                # NOTE(review): zero_() mutates the incoming checkpoint tensor
                # in place; assumes it is writable and not reused by the
                # caller. Confirm.
                if tp_rank != 0:
                    weight.zero_()
            param = params_dict[name]
            param.copy_(weight)
            loaded_params.add(name)
            continue
        elif "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            param = params_dict[name]
            # Keep this rank's heads along dim 0.
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        # q/k/v shards are renamed to the fused param and loaded with their
        # shard_id. The for/else falls through to the generic path when no
        # mapping matches.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # default_weight_loader is called without a shard_id.
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            # Checkpoint entries with no matching parameter are skipped silently.
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight)
        # For stacked shards `name` is the fused parameter name at this point.
        loaded_params.add(name)
    return loaded_params

_patched_normalize_quantization_config

_patched_normalize_quantization_config(
    self, config: PretrainedConfig
)
Source code in vllm_gaudi/models/gptoss_mxfp4.py
def _patched_normalize_quantization_config(self, config: PretrainedConfig):
    """Bypass quantization-config normalization for gpt_oss mxfp4 checkpoints.

    gpt_oss mxfp4 checkpoints are dequantized to bf16 by the custom loader
    that ``patched_load_weights`` dispatches to. The mxfp4 quantization config
    must therefore not be applied, so FusedMoE allocates its plain bf16
    parameter layout. Returning ``None`` here achieves that. Every other
    config is delegated to the original vLLM implementation.

    Both the checkpoint's ``"mxfp4"`` spelling and the ``"gpt_oss_mxfp4"``
    name that newer vLLM normalizes it to are recognised. This matches the set
    of methods ``patched_load_weights`` routes to the dequantizing loader.

    Args:
        config: Hugging Face model config being normalized.

    Returns:
        ``None`` for gpt_oss + mxfp4; otherwise the original normalizer's result.
    """
    # Skip mxfp4 quantization to use custom loading logic for gpt_oss
    if getattr(config, "model_type", None) == "gpt_oss":
        quant_cfg = getattr(config, "quantization_config", None)
        if quant_cfg is not None:
            # `or ""` also tolerates an explicit `"quant_method": None`.
            quant_method = (quant_cfg.get("quant_method") or "").lower()
            if quant_method in ("mxfp4", "gpt_oss_mxfp4"):
                return None

    # For all other models, use the original vLLM implementation
    return _original_normalize_quantization_config(self, config)

convert_moe_packed_tensors

convert_moe_packed_tensors(
    blocks,
    scales,
    *,
    dtype: dtype = bfloat16,
    rows_per_chunk: int = 32768 * 1024,
) -> Tensor

Convert the mxfp4 weights, dequantize and make them compatible with the forward pass of GPT_OSS.

Source code in vllm_gaudi/models/gptoss_mxfp4.py
def convert_moe_packed_tensors(
    blocks,
    scales,
    *,
    dtype: torch.dtype = torch.bfloat16,
    # Large default chosen to process many rows per kernel launch and reduce overhead;
    # lower this if you need to limit peak memory usage.
    rows_per_chunk: int = 32768 * 1024,
) -> torch.Tensor:
    """Dequantize packed MXFP4 blocks into a dense tensor for the GPT_OSS forward pass.

    Each byte of ``blocks`` packs two 4-bit FP4 codes:

    * the low nibble becomes the even output element;
    * the high nibble becomes the odd output element.

    All bytes in one row of the last axis share a single exponent from
    ``scales``. That exponent is stored as an unsigned byte biased by 127.

    Args:
        blocks: Packed codes shaped ``(*prefix, G, B)``.
        scales: Biased exponents shaped ``(*prefix, G)``.
        dtype: Output dtype (also used for the decode table).
        rows_per_chunk: Number of ``G``-rows decoded per loop iteration.

    Returns:
        Tensor shaped ``(*prefix, G * B * 2)`` with dtype ``dtype``.

    Raises:
        AssertionError: If ``blocks.shape[:-1]`` differs from ``scales.shape``.
    """
    import math

    # The 16 FP4 codes: 0-7 are the non-negative magnitudes, 8-15 the same
    # magnitudes with the sign bit set (so code 8 decodes to -0.0).
    magnitudes = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
    fp4_codebook = [*magnitudes, *(-m for m in magnitudes)]

    # Remove the 127 bias to get the signed power of two torch.ldexp applies.
    exponents = scales.to(torch.int32) - 127

    assert blocks.shape[:-1] == exponents.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"

    codebook = torch.tensor(fp4_codebook, dtype=dtype, device=blocks.device)

    *lead_dims, num_groups, group_bytes = blocks.shape
    num_rows = math.prod(lead_dims) * num_groups

    packed_rows = blocks.reshape(num_rows, group_bytes)
    row_exponents = exponents.reshape(num_rows, 1)
    decoded = torch.empty(num_rows, group_bytes * 2, dtype=dtype, device=blocks.device)

    for start in range(0, num_rows, rows_per_chunk):
        stop = min(start + rows_per_chunk, num_rows)
        packed = packed_rows[start:stop]
        window = decoded[start:stop]
        # Low nibble -> even slots, high nibble -> odd slots.
        window[:, 0::2] = codebook[(packed & 0x0F).to(torch.long)]
        window[:, 1::2] = codebook[(packed >> 4).to(torch.long)]
        # Multiply each row in place by 2**exponent.
        torch.ldexp(window, row_exponents[start:stop], out=window)

    return decoded.reshape(*lead_dims, num_groups * group_bytes * 2).contiguous()

patched_load_weights

patched_load_weights(
    self, weights: Iterable[tuple[str, Tensor]]
) -> set[str]
Source code in vllm_gaudi/models/gptoss_mxfp4.py
def patched_load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load model weights, dequantizing gpt_oss mxfp4 checkpoints on HPU.

    Routing depends on ``quant_method`` in the model config's
    ``quantization_config``:

    * ``"mxfp4"``, or the ``"gpt_oss_mxfp4"`` spelling newer vLLM rewrites it
      to: compute the attention-head and expert ranges owned by the current
      rank and delegate to ``self._load_weights_mxfp4_dequantize_hpu``;
    * anything else: fall back to the original vLLM ``load_weights``.

    Args:
        weights: Iterable of ``(name, tensor)`` checkpoint entries.

    Returns:
        The set of loaded names reported by the selected loader.
    """
    # Check if this is gpt_oss model with mxfp4 quantization
    quant_cfg = getattr(self.config, "quantization_config", None)
    quant_method = quant_cfg.get("quant_method") if quant_cfg else None

    # Only use custom loading for gpt_oss + mxfp4.
    # Newer vLLM normalizes the checkpoint's "mxfp4" to "gpt_oss_mxfp4" in
    # GptOssForCausalLMConfig.verify_and_update_model_config(), so accept both.
    if quant_method in ("mxfp4", "gpt_oss_mxfp4"):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        # Index experts by the rank *within* the EP group, consistent with the
        # `rank_in_group` lookups the loader uses for the DP/PCP groups.
        # `GroupCoordinator.rank` is the global process rank. It only equals
        # the in-group rank when the EP group spans every process. Otherwise
        # (e.g. later pipeline stages) it would slice past the last expert.
        ep_group = get_ep_group()
        ep_size = ep_group.world_size
        ep_rank = ep_group.rank_in_group
        num_experts = self.config.num_local_experts
        # Assumes num_experts divides evenly by ep_size (no remainder handling).
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        return self._load_weights_mxfp4_dequantize_hpu(
            ep_rank_end,
            ep_rank_start,
            heads_per_rank,
            head_start,
            weights,
            stacked_params_mapping,
        )

    # For all other models, use the original vLLM implementation
    return _original_load_weights(self, weights)