class RocmDotScaledMxfp8LinearKernel(Mxfp8LinearKernel):
    """Native CDNA4 (gfx950) MXFP8 linear via Triton ``tl.dot_scaled``.

    Weights stay in FP8 together with a per-block scale tensor
    (``MXFP8_BLOCK_SIZE`` elements of K per scale). When the reduction
    dimension K is a multiple of ``_DOT_SCALED_K_ALIGN`` the fused
    ``_mxfp8_dot_scaled_linear`` path is used; otherwise the weight is
    dequantized with ``dequant_mxfp8_to_bf16`` and a regular
    ``torch.nn.functional.linear`` is run.
    """

    # dot_scaled tiling needs the reduction dimension K to be a multiple of
    # this value; other K sizes take the dequantize fallback.
    _DOT_SCALED_K_ALIGN = 128

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current platform.

        Args:
            compute_capability: Not used by this kernel (presumably part of
                the shared kernel-selection interface).

        Returns:
            ``(True, None)`` on ROCm with native MX support, otherwise
            ``(False, reason)``.
        """
        if not current_platform.is_rocm():
            return False, "not ROCm"
        # supports_mx() == gfx95x (CDNA4 native microscaling hardware). On other
        # archs dot_scaled would upcast to BF16, so the kernel selector falls
        # through to the BF16 emulation (hipBLASLt) path instead.
        if not current_platform.supports_mx():
            return False, "native MX requires CDNA4 (gfx95x)"
        return True, None

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config.

        Shapes whose K is not a multiple of ``_DOT_SCALED_K_ALIGN`` are not
        rejected here; ``apply_weights`` routes them to a dequantize fallback.
        """
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Make ``weight`` contiguous and trim ``weight_scale`` to ``[N, ceil(K / block)]``.

        Both tensors are re-wrapped as frozen ``Parameter`` objects on
        ``layer``.
        """
        weight = layer.weight.data  # [N, K] fp8
        N, K = weight.shape
        # Ceil-divide so a trailing partial block keeps its scale. Floor
        # division would drop it whenever K is not a multiple of the block
        # size, which is exactly the case handled by the dequantize fallback.
        scale_k = (K + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE
        # NOTE(review): the slice suggests the loaded scale may be padded
        # beyond [N, scale_k]; slicing past the real size just clamps.
        weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ weight.T`` (plus ``bias``) using MXFP8 weights.

        Args:
            layer: Module holding ``weight`` (``[N, K]`` fp8) and
                ``weight_scale`` (``MXFP8_SCALE_DTYPE``), as prepared by
                ``process_weights_after_loading``.
            x: Activations whose last dimension is K; leading dims are
                flattened for the matmul and restored afterwards.
            bias: Optional tensor added to the output.

        Returns:
            Tensor of shape ``(*x.shape[:-1], N)``.

        Raises:
            ValueError: If ``layer.weight_scale`` is not ``MXFP8_SCALE_DTYPE``.
        """
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Expected {MXFP8_SCALE_DTYPE} weight_scale, got "
                f"{layer.weight_scale.dtype}."
            )
        out_shape = (*x.shape[:-1], layer.weight.shape[0])
        x2d = x.reshape(-1, x.shape[-1])
        if x2d.shape[-1] % self._DOT_SCALED_K_ALIGN == 0:
            out = _mxfp8_dot_scaled_linear(x2d, layer.weight, layer.weight_scale)
        else:
            # dot_scaled tiling needs K % 128 == 0; dequantize fallback otherwise.
            w_bf16 = dequant_mxfp8_to_bf16(layer.weight, layer.weight_scale)
            # F.linear requires matching input/weight dtypes: compute in the
            # dequantized weight's dtype, then cast back to the caller's dtype.
            out = torch.nn.functional.linear(
                x2d.to(w_bf16.dtype), w_bf16
            ).to(x.dtype)
        out = out.reshape(out_shape)
        if bias is not None:
            out = out + bias
        return out