Skip to content

vllm.model_executor.layers.quantization.utils.mxfp6_utils

_dequant_mxfp6

_dequant_mxfp6(
    x: Tensor,
    scale: Tensor,
    float_dtype: dtype,
    quant_dtype: str,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp6_utils.py
def _dequant_mxfp6(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as e:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP6 models. Please install it with `pip install "
            "amd-quark`."
        ) from e

    pack_method = create_pack_method(None, dtype=quant_dtype)
    unpacked_x = pack_method.unpack(x, reorder=False)

    scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)

    # TODO: `dequantize_fp4_fp6_per_group` and `prepare_inputs_per_group`
    # always return fp32.
    return dequantize_fp4_fp6_per_group(
        unpacked_x,
        scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    ).to(float_dtype)

_dequant_mxfp6_fake

_dequant_mxfp6_fake(
    x: Tensor,
    scale: Tensor,
    float_dtype: dtype,
    quant_dtype: str,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp6_utils.py
def _dequant_mxfp6_fake(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
    assert (x.shape[-1] * 4) % 3 == 0
    return torch.empty(
        (*x.shape[:-1], (x.shape[-1] * 4) // 3), dtype=float_dtype, device=x.device
    )

_quant_dequant_mxfp6

_quant_dequant_mxfp6(
    x: Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp6_utils.py
def _quant_dequant_mxfp6(
    x: torch.Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> torch.Tensor:
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            fake_quantize_fp4_fp6_per_group_with_scale,
        )
        from quark.torch.quantization.utils import even_round, reshape_to_blocks
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP6 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    axis = -1
    block_x = reshape_to_blocks(x, OCP_MX_BLOCK_SIZE, axis)
    amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
    amax = amax.squeeze(-1)

    # TODO: there are other rounding strategies supported in quark and in the
    # config.json that we do not check for here!
    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"Scale calculation mode {scale_calculation_mode} is not yet "
            "supported in MX-FP6 quantization"
        )
    scale = even_round(amax, quant_dtype)

    # Apply dequantize(quantize(x)).
    x = fake_quantize_fp4_fp6_per_group_with_scale(
        x,
        scale.to(x.device),
        axis=axis,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )

    return x

_quant_dequant_mxfp6_fake

_quant_dequant_mxfp6_fake(
    x: Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp6_utils.py
def _quant_dequant_mxfp6_fake(
    x: torch.Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> torch.Tensor:
    return torch.empty_like(x)

dequant_mxfp6

dequant_mxfp6(
    x: Tensor,
    scale: Tensor,
    float_dtype: dtype,
    quant_dtype: str,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp6_utils.py
def dequant_mxfp6(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
    """Dequantize packed MX-FP6 ``x`` via the ``torch.ops.vllm.dequant_mxfp6`` op.

    Arguments are forwarded unchanged and positionally. The op is presumably
    registered elsewhere in this module with ``_dequant_mxfp6`` /
    ``_dequant_mxfp6_fake`` as its real/fake implementations (registration is
    not visible here — confirm).
    """
    registered_op = torch.ops.vllm.dequant_mxfp6
    return registered_op(x, scale, float_dtype, quant_dtype)

quant_dequant_mxfp6

quant_dequant_mxfp6(
    x: Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp6_utils.py
def quant_dequant_mxfp6(
    x: torch.Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> torch.Tensor:
    """Fake-quantize ``x`` to MX-FP6 via ``torch.ops.vllm.quant_dequant_mxfp6``.

    Arguments are forwarded unchanged and positionally. The op is presumably
    registered elsewhere in this module with ``_quant_dequant_mxfp6`` /
    ``_quant_dequant_mxfp6_fake`` as its real/fake implementations
    (registration is not visible here — confirm).
    """
    registered_op = torch.ops.vllm.quant_dequant_mxfp6
    return registered_op(x, quant_dtype, scale_calculation_mode)