class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
    """
    CPU-only MoE method using dynamic 4-bit matmul kernels on Arm platforms.

    - Weights: int4 (stored as int8 values in [-8, 7], packed to uint8 nibbles)
    - Scales: fp32 for channel-wise, bf16 for group-wise quantization
    - Bias: same data type as the original weights
    - Activations: fp32/bf16, dynamically quantized per token (A8 int)
      inside the kernel
    """

    def __init__(
        self,
        weight_quant: QuantizationArgs,
        input_quant: QuantizationArgs,
        moe: FusedMoEConfig,
        layer_name: str | None = None,
    ):
        """Validate the W4A8 scheme and select an MoE backend.

        Args:
            weight_quant: Weight quantization args. Must be 4-bit; a
                ``group_size`` of ``None`` selects channel-wise scales.
            input_quant: Activation quantization args. Must be dynamic
                per-token.
            moe: Fused-MoE layer configuration.
            layer_name: Unused by this method.

        Raises:
            ValueError: If the activation scheme, weight bit-width or weight
                group size is unsupported.
            RuntimeError: On Arm, if this PyTorch build lacks the
                ``_dyn_quant_*`` 4-bit ops.
        """
        super().__init__(moe)
        self.has_bias = self.moe.has_bias
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        self.static_input_scales = False  # always dynamic per token

        # Weights can be channel-wise (group_size=None -> -1) or group-wise.
        self.group_size = (
            weight_quant.group_size if weight_quant.group_size is not None else -1
        )

        # Validate scheme: weights=W4 (channel or group),
        # activations=dynamic per-TOKEN (A8).
        if (
            input_quant.strategy != QuantizationStrategy.TOKEN
            or not input_quant.dynamic
        ):
            raise ValueError(
                "W4A8-int MoE needs dynamic per-token activation quantization."
            )
        if weight_quant.num_bits != 4:
            raise ValueError("This method only supports 4-bit weights (num_bits=4).")
        # Any group size other than the channel-wise sentinel must be positive,
        # otherwise the scale-column computation in create_weights divides by
        # zero or produces a negative dimension.
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError(
                "W4A8-int MoE needs a positive weight group_size "
                f"(or None for channel-wise), got {self.group_size}."
            )

        # Arm: check that the dynamic 4-bit ops exist in this PyTorch build.
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                _ = torch.ops.aten._dyn_quant_matmul_4bit
                _ = torch.ops.aten._dyn_quant_pack_4bit_weight
            except AttributeError as err:
                raise RuntimeError(
                    f"PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; "
                    "install a newer build."
                ) from err

        # Construct the QuantKey for weights.
        # W4A8 INT4: 4-bit weights (stored as int8), static quantization.
        # Channel-wise scales are fp32 and group-wise scales are bf16.
        # create_weights() reuses self.weight_scale_dtype.
        if self.group_size == -1:
            group_shape = GroupShape(-1, 1)
            self.weight_scale_dtype = torch.float32
        else:
            group_shape = GroupShape(1, self.group_size)
            self.weight_scale_dtype = torch.bfloat16
        weight_scale_desc = ScaleDesc(
            self.weight_scale_dtype, static=True, group_shape=group_shape
        )
        weight_key = QuantKey(torch.int8, weight_scale_desc, symmetric=True)
        self.backend, self.experts_cls = select_w4a8_int8_moe_backend(
            moe,
            weight_key,
            activation_key=None,  # unquantized inputs
        )

    # ---- parameter creation ----
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register unpacked weights, scales, optional biases and packed slots.

        Shapes per local rank (TP/EP), with E experts, hidden size H and
        I_local = intermediate_size_per_partition:

        - w13_weight: [E, 2*I_local, H], int8 (int4 values in [-8, 7])
        - w2_weight: [E, H, I_local], int8
        - Scales: one column per output row when channel-wise
          (group_size=-1), or in_features / group_size columns per output
          row when group-wise.
        - Biases (only if the MoE config has bias): [E, 2*I_local] and
          [E, H] in ``params_dtype``.

        Raises:
            ValueError: For group-wise quantization, if an input dimension
                is not divisible by the group size.
        """
        E = num_experts
        H = hidden_size
        IN = intermediate_size_per_partition
        g = self.group_size
        quant_method = "channel" if g == -1 else "group"

        def _n_scale_cols(in_features: int, name: str) -> int:
            # Number of scale columns per output row. Floor division would
            # silently drop a trailing partial group, so reject that case.
            if g == -1:
                return 1
            if in_features % g != 0:
                raise ValueError(
                    f"W4A8-int MoE: {name} ({in_features}) must be divisible "
                    f"by the weight group_size ({g})."
                )
            return in_features // g

        # Validate both dimensions before registering anything on the layer.
        w13_scale_cols = _n_scale_cols(H, "hidden_size")
        w2_scale_cols = _n_scale_cols(IN, "intermediate_size_per_partition")

        # Unpacked int4-as-int8 weights that the loader will fill.
        w13 = torch.nn.Parameter(
            torch.empty(E, 2 * IN, H, dtype=torch.int8), requires_grad=False
        )
        set_weight_attrs(w13, extra_weight_attrs)
        layer.register_parameter("w13_weight", w13)

        w2 = torch.nn.Parameter(
            torch.empty(E, H, IN, dtype=torch.int8), requires_grad=False
        )
        set_weight_attrs(w2, extra_weight_attrs)
        layer.register_parameter("w2_weight", w2)

        # Scales: KleidiAI channel-wise kernels accept float32 scales,
        # group-wise kernels accept bfloat16 scales.
        w13_s = torch.nn.Parameter(
            torch.ones(E, 2 * IN, w13_scale_cols, dtype=self.weight_scale_dtype),
            requires_grad=False,
        )
        set_weight_attrs(w13_s, {"quant_method": quant_method, **extra_weight_attrs})
        layer.register_parameter("w13_weight_scale", w13_s)

        w2_s = torch.nn.Parameter(
            torch.ones(E, H, w2_scale_cols, dtype=self.weight_scale_dtype),
            requires_grad=False,
        )
        set_weight_attrs(w2_s, {"quant_method": quant_method, **extra_weight_attrs})
        layer.register_parameter("w2_weight_scale", w2_s)

        if self.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(E, 2 * IN, dtype=params_dtype), requires_grad=False
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(
                torch.zeros(E, H, dtype=params_dtype), requires_grad=False
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # Empty placeholders; process_weights_after_loading replaces them
        # with the packed weights.
        for packed_name in ("w13_weight_packed", "w2_weight_packed"):
            layer.register_parameter(
                packed_name, torch.nn.Parameter(torch.empty(0), requires_grad=False)
            )
            set_weight_attrs(getattr(layer, packed_name), extra_weight_attrs)

        # Dims for the 4-bit fused matmuls.
        layer.w13_in_features = H
        layer.w13_out_features = 2 * IN
        layer.w2_in_features = IN
        layer.w2_out_features = H
        layer.group_size = g

    # post-load packing to dyn-4bit KleidiAI kernel's format
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Pack the loaded weights and build the fused MoE kernel.

        The unpacked ``w13_weight``/``w2_weight`` are replaced with empty
        tensors, so the layer no longer holds the int8 copies. Scales and
        (optional) biases are replaced with the values returned by
        ``convert_to_w4a8_int8_moe_format``.
        """
        w13_packed, w2_packed, w13_weight_scale, w2_weight_scale, w13_bias, w2_bias = (
            convert_to_w4a8_int8_moe_format(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                w13_weight_scale=layer.w13_weight_scale,
                w2_weight_scale=layer.w2_weight_scale,
                group_size=self.group_size,
                w13_bias=layer.w13_bias if self.has_bias else None,
                w2_bias=layer.w2_bias if self.has_bias else None,
            )
        )

        def _replace(name: str, tensor: torch.Tensor) -> None:
            # Swap a layer parameter for a frozen parameter wrapping `tensor`.
            replace_parameter(
                layer, name, torch.nn.Parameter(tensor, requires_grad=False)
            )

        # Packed weights take over; the unpacked int8 weights are dropped.
        _replace("w13_weight_packed", w13_packed)
        _replace("w2_weight_packed", w2_packed)
        _replace("w13_weight", torch.empty(0))
        _replace("w2_weight", torch.empty(0))
        _replace("w13_weight_scale", w13_weight_scale)
        _replace("w2_weight_scale", w2_weight_scale)
        if self.has_bias:
            _replace("w13_bias", w13_bias)
            _replace("w2_bias", w2_bias)

        quant_config = self.get_fused_moe_quant_config(layer)
        assert quant_config is not None
        assert self.experts_cls is not None
        self.moe_kernel = make_w4a8_int8_moe_kernel(
            moe_quant_config=quant_config,
            moe_config=self.moe,
            experts_cls=self.experts_cls,
            routing_tables=layer._expert_routing_tables(),
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config for this method's weight scheme.

        Block shape is derived from ``self.group_size``:
        -1 (channel-wise) -> (-1, 1); N (group-wise) -> (1, N).
        """
        block_shape = (-1, 1) if self.group_size == -1 else (1, self.group_size)
        return make_w4a8_int8_moe_quant_config(block_shape=block_shape)

    def apply_monolithic(
        self,
        layer: RoutedExperts,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        input_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MoE kernel directly on router logits.

        The layer's routing parameters are forwarded to the kernel.
        ``input_ids`` is not used here.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: RoutedExperts,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts: SharedExperts | None,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the MoE kernel using precomputed top-k weights and expert ids."""
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts=shared_experts,
            shared_experts_input=shared_experts_input,
        )