# Assumed module-level context (defined earlier in the full file): the OCP
# Microscaling (MX) formats share one scale per block of 32 elements.
import torch

OCP_MX_BLOCK_SIZE = 32


def _quant_dequant_mxfp6(
    x: torch.Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> torch.Tensor:
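    """Emulate MX-FP6 quantization by applying dequantize(quantize(x)) per
    group of OCP_MX_BLOCK_SIZE elements along the last dimension, using the
    hardware-emulation kernels from `amd-quark`."""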
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            fake_quantize_fp4_fp6_per_group_with_scale,
        )
        from quark.torch.quantization.utils import even_round, reshape_to_blocks
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use MX-FP6 models. "
            "Please install it with `pip install amd-quark`."
        ) from err
    axis = -1
    # Group the last dimension into blocks of OCP_MX_BLOCK_SIZE elements and
    # take each block's absolute maximum, which determines the block's
    # shared scale.
    block_x = reshape_to_blocks(x, OCP_MX_BLOCK_SIZE, axis)
    amax = torch.amax(torch.abs(block_x), dim=-1)
    # TODO: quark and the checkpoint's config.json support other rounding
    # strategies that we do not check for here.
    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"Scale calculation mode {scale_calculation_mode} is not yet "
            "supported in MX-FP6 quantization."
        )
    # `even_round` derives each block's shared scale (a power of two in the
    # OCP MX formats) from the block's amax.
    scale = even_round(amax, quant_dtype)
    # Apply dequantize(quantize(x)).
    x = fake_quantize_fp4_fp6_per_group_with_scale(
        x,
        scale.to(x.device),
        axis=axis,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return x
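
# A minimal usage sketch (not part of the upstream module): round-trip a
# random tensor through MX-FP6 quant-dequant and report the worst-case
# reconstruction error. Requires `amd-quark` to be installed; the dtype
# string "fp6_e3m2" is an assumption about the `quant_dtype` values quark
# accepts, so adjust it to match your checkpoint's config.
if __name__ == "__main__":
    x = torch.randn(2, 4 * OCP_MX_BLOCK_SIZE)
    x_qdq = _quant_dequant_mxfp6(x, quant_dtype="fp6_e3m2")
    print(f"max |x - qdq(x)|: {(x - x_qdq).abs().max().item():.6f}")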