def _load_weights_mxfp4_dequantize_hpu(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load checkpoint weights, dequantizing MXFP4-packed MoE expert weights on the fly.

    MoE expert tensors are sharded either along dim 0 (when expert parallelism is
    enabled) or along the intermediate dimension (tensor parallelism flattened
    across DP and PCP). Packed expert weights (``*.w13_weight`` / ``*.w2_weight``)
    are paired with their scales (``*.w13_weight_scale`` / ``*.w2_weight_scale``)
    regardless of which of the two arrives first, passed through
    ``convert_moe_packed_tensors``, and written into the block parameter.

    Args:
        ep_rank_end: Exclusive end of the dim-0 slice kept by this rank when
            expert parallelism is enabled.
        ep_rank_start: Start of that dim-0 slice.
        heads_per_rank: Number of attention-sink entries kept by this rank.
        head_start: Offset of this rank's first attention-sink entry.
        weights: Iterable of ``(name, tensor)`` checkpoint entries.
        stacked_params_mapping: ``(param_name, weight_name, shard_id)`` entries
            used to route checkpoint names into fused parameters.

    Returns:
        The names recorded as loaded (the renamed name for stacked parameters).

    Raises:
        ValueError: If, once ``weights`` is exhausted, a packed expert weight is
            still missing its scale or a scale is still missing its weight.
        KeyError: If a name routed to ``params_dict`` has no matching parameter.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    use_ep = self.parallel_config.enable_expert_parallel
    # In MoE, we need to flatten the tensor parallel size across the data
    # parallel size when EP is disabled.
    tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
        tp_size=get_tensor_model_parallel_world_size(),
        dp_size=get_dp_group().world_size,
        dp_rank=get_dp_group().rank_in_group,
        pcp_size=get_pcp_group().world_size,
        pcp_rank=get_pcp_group().rank_in_group,
    )
    intermediate_size = self.config.intermediate_size
    if use_ep:
        # Under EP the expert tensors are only sliced along dim 0 (see the
        # `weight[ep_rank_start:ep_rank_end, ...]` slices below), so they keep the
        # full intermediate dimension and the whole of it must be written back.
        tp_rank_start, tp_rank_end = 0, intermediate_size
    else:
        # Split the intermediate dimension in multiples of OCP_MX_BLOCK_SIZE so each
        # rank's range lines up with the per-block scales (sliced with
        # `// OCP_MX_BLOCK_SIZE` below); the last rank may get a shorter range.
        intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
    local_intermediate_size = tp_rank_end - tp_rank_start

    # Packed expert weights and their scales arrive as separate entries; the
    # order (and which shard file each lives in) is not controlled here, so the
    # first half of each pair is buffered, keyed by the block parameter name.
    pending_blocks: dict[str, torch.Tensor] = {}
    pending_scales: dict[str, torch.Tensor] = {}

    def _dequantize_into_param(block_name: str, blocks: torch.Tensor, scales: torch.Tensor) -> None:
        # Dequantize one (blocks, scales) pair and write it into its parameter.
        param = params_dict[block_name]
        dequantized = convert_moe_packed_tensors(blocks, scales)
        if ".w13_weight" in block_name:
            # Gate and up projections share dim 1, hence twice the local width.
            param[:, :2 * local_intermediate_size, :] = dequantized
        else:
            param[:, :, :local_intermediate_size] = dequantized

    def _add_blocks(block_name: str, blocks: torch.Tensor) -> None:
        # Dequantize now if the scales are already here, otherwise buffer.
        scales = pending_scales.pop(block_name, None)
        if scales is None:
            pending_blocks[block_name] = blocks
        else:
            _dequantize_into_param(block_name, blocks, scales)

    def _add_scales(block_name: str, scales: torch.Tensor) -> None:
        # Dequantize now if the blocks are already here, otherwise buffer.
        blocks = pending_blocks.pop(block_name, None)
        if blocks is None:
            pending_scales[block_name] = scales
        else:
            _dequantize_into_param(block_name, blocks, scales)

    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue
        # NOTE: the `_scale` checks must precede the plain `_weight` checks,
        # because ".w13_weight" is a substring of ".w13_weight_scale".
        if ".w13_weight_scale" in name:
            # MLP gate and up projection scales.
            if use_ep:
                narrow_weight_scale = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight_scale = weight[:, 2 * tp_rank_start:2 * tp_rank_end, :]
            _add_scales(name.replace("weight_scale", "weight"), narrow_weight_scale.contiguous())
            loaded_params.add(name)
            continue
        elif ".w13_weight" in name:
            # MLP gate and up projection packed weights.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end, :, :]
            _add_blocks(name, narrow_weight.contiguous())
            loaded_params.add(name)
            continue
        elif ".w2_weight_scale" in name:
            # MLP down projection scales (one per OCP_MX_BLOCK_SIZE elements).
            if use_ep:
                narrow_weight_scale = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight_scale = weight[..., tp_rank_start // OCP_MX_BLOCK_SIZE:tp_rank_end // OCP_MX_BLOCK_SIZE]
            _add_scales(name.replace("weight_scale", "weight"), narrow_weight_scale.contiguous())
            loaded_params.add(name)
            continue
        elif ".w2_weight" in name:
            # MLP down projection packed weights.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, :, tp_rank_start // OCP_MX_BLOCK_SIZE:tp_rank_end // OCP_MX_BLOCK_SIZE, :]
            _add_blocks(name, narrow_weight.contiguous())
            loaded_params.add(name)
            continue
        elif ".w13_bias" in name:
            # MLP gate and up projection biases.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end]
            param = params_dict[name]
            param[:, :2 * local_intermediate_size] = narrow_weight.contiguous()
            loaded_params.add(name)
            continue
        elif ".w2_bias" in name:
            # MLP down projection bias.
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # Only rank 0 keeps the bias to avoid duplication (presumably the
                # down-projection outputs are summed across ranks). Use a fresh
                # zero tensor rather than zeroing the caller's tensor in place.
                weight = torch.zeros_like(weight)
            param = params_dict[name]
            param.copy_(weight)
            loaded_params.add(name)
            continue
        elif "sinks" in name:
            # Attention sinks, split along dim 0 across ranks.
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, weight)
        loaded_params.add(name)

    # Any leftover half-pair means a parameter was never written; fail loudly
    # instead of reporting it as loaded.
    if pending_blocks or pending_scales:
        raise ValueError(
            "Incomplete MXFP4 expert weights: "
            f"blocks without scales for {sorted(pending_blocks)}, "
            f"scales without blocks for {sorted(pending_scales)}"
        )
    return loaded_params