def _load_weights_mxfp4_dequantize_hpu(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load checkpoint weights, dequantizing packed MoE expert weights.

    Packed expert tensors (``.w13_weight`` / ``.w2_weight``) are sliced to
    this rank and held until the matching ``*_weight_scale`` tensor shows
    up in ``weights``; the pair is then passed to
    ``convert_moe_packed_tensors`` and the result is written into the
    parameter named after the packed tensor. Biases, attention sinks,
    stacked parameters and all remaining weights are loaded directly.

    Args:
        ep_rank_end: Exclusive end of the expert range (dim 0) kept by this
            rank when expert parallelism is enabled.
        ep_rank_start: Inclusive start of that expert range.
        heads_per_rank: Number of entries taken from dim 0 of each
            ``sinks`` tensor.
        head_start: First dim-0 index taken from each ``sinks`` tensor.
        weights: ``(name, tensor)`` pairs. A packed expert tensor must be
            yielded before its ``*_weight_scale`` companion.
        stacked_params_mapping: ``(param_name, weight_name, shard_id)``
            triples; the first entry whose ``weight_name`` occurs in a
            tensor name decides the renaming and the shard id.

    Returns:
        The names recorded as handled. For expert weights these are the
        names as yielded by ``weights`` (packed and scale names alike); for
        stacked parameters it is the renamed parameter name.

    Raises:
        ValueError: If a scale tensor arrives before its packed tensor, or
            if packed tensors are still waiting for a scale once
            ``weights`` is exhausted.
        KeyError: If a handled tensor has no corresponding entry in
            ``self.named_parameters()``.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    use_ep = self.parallel_config.enable_expert_parallel

    # In MoE, we need to flatten the tensor parallel size across the data
    # parallel size when EP is disabled.
    tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
        tp_size=get_tensor_model_parallel_world_size(),
        dp_size=get_dp_group().world_size,
        dp_rank=get_dp_group().rank_in_group,
        pcp_size=get_pcp_group().world_size,
        pcp_rank=get_pcp_group().rank_in_group,
    )

    intermediate_size = self.config.intermediate_size
    # Use cdiv-based per-rank partitioning to match FusedMoE's bf16 param
    # layout (which is what gets allocated here because the gpt_oss mxfp4
    # quant config is bypassed in `_patched_normalize_quantization_config`).
    # Block-aligned partitioning would over-/under-size the rank slice when
    # `intermediate_size` is not divisible by `OCP_MX_BLOCK_SIZE * tp_size`
    # (e.g. gpt-oss-120b: 2880 / (32*4) = 22.5).
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

    # Both bounds are clamped to `intermediate_size`. Leaving the start
    # unclamped would make the local size negative on a rank whose nominal
    # start lies past the end, and `param[:, :negative]` would then select
    # the wrong region instead of an empty one.
    tp_rank_start = min(tp_rank * per_rank_intermediate_size, intermediate_size)
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
    local_intermediate_size = tp_rank_end - tp_rank_start

    # w13 holds two rows per intermediate channel (gate and up projection),
    # hence the factor of two on both the source slice and the target size.
    w13_rows = slice(2 * tp_rank_start, 2 * tp_rank_end)
    local_w13_size = 2 * local_intermediate_size
    # NOTE(review): rows/columns of a parameter beyond the local size are
    # never written here; presumably they are padding - confirm they are
    # initialised elsewhere.

    # For w2 the intermediate dim is the K (reduction) axis and mxfp4 scales
    # are stored per `OCP_MX_BLOCK_SIZE` block along K. When the rank range
    # is not block-aligned, expand outward to a block-aligned window for
    # dequantization, then crop the dequantized result back to the rank's
    # true range using `k_offset`.
    k_block_start = tp_rank_start // OCP_MX_BLOCK_SIZE
    k_block_end = cdiv(tp_rank_end, OCP_MX_BLOCK_SIZE)
    k_blocks = slice(k_block_start, k_block_end)
    k_offset = tp_rank_start - k_block_start * OCP_MX_BLOCK_SIZE

    # Packed expert tensors waiting for their scale, keyed by tensor name.
    pending_blocks: dict[str, torch.Tensor] = {}

    def take_pending_block(scale_name: str) -> tuple[str, torch.Tensor]:
        """Remove and return the packed tensor that pairs with a scale."""
        block_name = scale_name.replace("weight_scale", "weight")
        if block_name not in pending_blocks:
            raise ValueError(
                f"Expected block weight for {block_name} not found "
                f"when processing {scale_name}"
            )
        return block_name, pending_blocks.pop(block_name)

    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue

        # The `_weight_scale` checks must precede the `_weight` checks:
        # the latter substring also matches the scale names.
        if ".w13_weight_scale" in name:
            # MLP gate/up projection: scale arrived, dequantize and store.
            if use_ep:
                scale = weight[ep_rank_start:ep_rank_end, ...]
            else:
                scale = weight[:, w13_rows, :]
            scale = scale.contiguous()
            block_name, packed = take_pending_block(name)
            param = params_dict[block_name]
            dequantized = convert_moe_packed_tensors(packed, scale)
            if use_ep:
                param.copy_(dequantized)
            else:
                param[:, :local_w13_size, :] = dequantized
            loaded_params.add(name)
            continue

        if ".w13_weight" in name:
            # MLP gate/up projection: stash the packed slice for later.
            if use_ep:
                packed = weight[ep_rank_start:ep_rank_end, ...]
            else:
                packed = weight[:, w13_rows, :, :]
            pending_blocks[name] = packed.contiguous()
            loaded_params.add(name)
            continue

        if ".w2_weight_scale" in name:
            # MLP down projection: scale arrived, dequantize and store.
            if use_ep:
                scale = weight[ep_rank_start:ep_rank_end, ...]
            else:
                scale = weight[..., k_blocks]
            scale = scale.contiguous()
            block_name, packed = take_pending_block(name)
            param = params_dict[block_name]
            dequantized = convert_moe_packed_tensors(packed, scale)
            if use_ep:
                param.copy_(dequantized)
            else:
                # Crop block-aligned dequant output to the rank's true range.
                param[:, :, :local_intermediate_size] = dequantized[
                    ..., k_offset:k_offset + local_intermediate_size
                ]
            loaded_params.add(name)
            continue

        if ".w2_weight" in name:
            # MLP down projection: stash the block-aligned packed window.
            if use_ep:
                packed = weight[ep_rank_start:ep_rank_end, ...]
            else:
                packed = weight[:, :, k_blocks, :]
            pending_blocks[name] = packed.contiguous()
            loaded_params.add(name)
            continue

        if ".w13_bias" in name:
            # MLP gate/up projection bias, sliced like the w13 rows.
            param = params_dict[name]
            if use_ep:
                bias = weight[ep_rank_start:ep_rank_end, ...].contiguous()
                param.copy_(bias)
            else:
                bias = weight[:, w13_rows].contiguous()
                param[:, :local_w13_size] = bias
            loaded_params.add(name)
            continue

        if ".w2_bias" in name:
            # MLP down projection bias.
            if use_ep:
                bias = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # Only rank 0 carries the real bias to avoid duplication;
                # the other ranks get zeros. A fresh tensor is used so the
                # tensor owned by the `weights` iterator is not modified.
                bias = torch.zeros_like(weight)
            else:
                bias = weight
            params_dict[name].copy_(bias)
            loaded_params.add(name)
            continue

        if "sinks" in name:
            # Attention sinks: each rank keeps its own range along dim 0.
            param = params_dict[name]
            param.data.copy_(weight.narrow(0, head_start, heads_per_rank))
            loaded_params.add(name)
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # From here on `name` is the stacked parameter's name.
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight)
        # Reached for both stacked and plain weights that were loaded.
        loaded_params.add(name)

    if pending_blocks:
        # These names were already recorded in `loaded_params`, yet their
        # parameters were never written; fail loudly rather than return a
        # partially initialised model.
        orphaned = ", ".join(sorted(pending_blocks))
        raise ValueError(
            f"Packed expert weights without a matching weight_scale tensor: {orphaned}"
        )
    return loaded_params