vllm.model_executor.layers.quantization.quark.schemes

Modules:

quark_ocp_mx
quark_scheme
quark_w8a8_fp8
quark_w8a8_int8

__all__ module-attribute

__all__ = [
    "QuarkScheme",
    "QuarkW8A8Fp8",
    "QuarkW8A8Int8",
    "QuarkOCP_MX",
]

QuarkOCP_MX

Bases: QuarkScheme

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
class QuarkOCP_MX(QuarkScheme):
    def __init__(
        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
    ):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec

        self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
        self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")

        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
            self.input_dtype, self.weight_dtype
        )

        if self.weight_dtype == "mxfp4":
            self.packed_factor: Union[int, Fraction] = 2
            self.dequant_func = dequant_mxfp4
        else:
            self.packed_factor = Fraction(numerator=8, denominator=6)
            self.dequant_func = partial(
                dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
            )

        if self.input_dtype == "mxfp4":
            self.quant_dequant_func = quant_dequant_mxfp4
        else:
            self.quant_dequant_func = partial(
                quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
            )

        self.static_input_scales = not input_quant_spec.get("is_dynamic")

        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkOCP_MX with static input scales is currently not "
                "implemented. Please open an issue."
            )

        # TODO: integrate (or test) mixed-precision kernel.
        self.emulate = not current_platform.supports_mx() or (
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        )

        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()

        if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
            # Currently need these kernels if not emulating
            raise NotImplementedError(
                f"{self.__class__.__name__} requires AITER to be installed "
                "for non-emulation mode! Please refer to "
                "https://github.com/ROCm/aiter for installation details."
            )

        if not current_platform.supports_mx():
            logger.warning_once(
                "The current platform does not support native MXFP4/MXFP6 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )

        if current_platform.supports_mx() and (
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        ):
            logger.warning_once(
                "The current platform supports native MXFP4/MXFP6 "
                f"computation, but kernels for input_dtype={self.input_dtype} "
                f"and weight_dtype={self.weight_dtype} are not yet integrated "
                "in vLLM. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )

    def get_packed_dim(self, dim: int, quant_dtype: str):
        if quant_dtype == "mxfp4":
            assert dim % 2 == 0
            return dim // 2
        elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
            # FP6 packs 4 * 6 = 24 bits into 3 bytes.
            assert (dim * 3) % 4 == 0
            return (dim * 3) // 4
        else:
            raise NotImplementedError(
                "Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
                f"got quant_dtype={quant_dtype}. Something is wrong, please "
                "open an issue."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

        if self.emulate:
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )
        else:
            if self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
                weight_scale_shuffle = weight_scale_shuffle.view(
                    sm // 32, 2, 16, sn // 8, 2, 4, 1
                )
                weight_scale_shuffle = weight_scale_shuffle.permute(
                    0, 3, 5, 2, 4, 1, 6
                ).contiguous()
                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
                layer.weight_scale = torch.nn.Parameter(
                    weight_scale_shuffle, requires_grad=False
                )

                # shuffle weight
                weight_shuffle = layer.weight.data
                weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16))
                layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False)
            else:
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data.T.contiguous(), requires_grad=False
                )

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.emulate:
            dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
            qdq_x = self.quant_dequant_func(x)
            return F.linear(qdq_x, dq_w, bias)
        else:
            return torch.ops.vllm.gemm_with_dynamic_quant(
                x,
                layer.weight,
                layer.weight_scale,
                self.rocm_use_aiter_fp4_asm_gemm,
                self.out_dtype,
            )

dequant_func instance-attribute

dequant_func = dequant_mxfp4

emulate instance-attribute

emulate = not current_platform.supports_mx() or (
    input_dtype != "mxfp4" or weight_dtype != "mxfp4"
)

input_dtype instance-attribute

input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")

input_quant_spec instance-attribute

input_quant_spec = input_quant_spec

ocp_mx_scheme instance-attribute

ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(input_dtype, weight_dtype)

out_dtype instance-attribute

out_dtype = torch.get_default_dtype()

packed_factor instance-attribute

packed_factor: Union[int, Fraction] = 2

qscheme instance-attribute

qscheme = 'per_group'

quant_dequant_func instance-attribute

quant_dequant_func = quant_dequant_mxfp4

rocm_use_aiter_fp4_asm_gemm instance-attribute

rocm_use_aiter_fp4_asm_gemm = (
    is_rocm_aiter_fp4_asm_gemm_enabled()
)

static_input_scales instance-attribute

static_input_scales = not input_quant_spec.get("is_dynamic")

weight_dtype instance-attribute

weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")

weight_quant_spec instance-attribute

weight_quant_spec = weight_quant_spec

__init__

__init__(
    weight_quant_spec: dict[str, Any],
    input_quant_spec: dict[str, Any],
)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
def __init__(
    self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
):
    self.out_dtype = torch.get_default_dtype()
    self.qscheme = "per_group"
    self.weight_quant_spec = weight_quant_spec
    self.input_quant_spec = input_quant_spec

    self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
    self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")

    self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
        self.input_dtype, self.weight_dtype
    )

    if self.weight_dtype == "mxfp4":
        self.packed_factor: Union[int, Fraction] = 2
        self.dequant_func = dequant_mxfp4
    else:
        self.packed_factor = Fraction(numerator=8, denominator=6)
        self.dequant_func = partial(
            dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
        )

    if self.input_dtype == "mxfp4":
        self.quant_dequant_func = quant_dequant_mxfp4
    else:
        self.quant_dequant_func = partial(
            quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
        )

    self.static_input_scales = not input_quant_spec.get("is_dynamic")

    if self.static_input_scales:
        raise NotImplementedError(
            "QuarkOCP_MX with static input scales is currently not "
            "implemented. Please open an issue."
        )

    # TODO: integrate (or test) mixed-precision kernel.
    self.emulate = not current_platform.supports_mx() or (
        self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
    )

    self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()

    if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
        # Currently need these kernels if not emulating
        raise NotImplementedError(
            f"{self.__class__.__name__} requires AITER to be installed "
            "for non-emulation mode! Please refer to "
            "https://github.com/ROCm/aiter for installation details."
        )

    if not current_platform.supports_mx():
        logger.warning_once(
            "The current platform does not support native MXFP4/MXFP6 "
            "computation. Simulated weight dequantization and activation "
            "QDQ (quantize and dequantize) will be used, with the linear "
            "layers computed in high precision."
        )

    if current_platform.supports_mx() and (
        self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
    ):
        logger.warning_once(
            "The current platform supports native MXFP4/MXFP6 "
            f"computation, but kernels for input_dtype={self.input_dtype} "
            f"and weight_dtype={self.weight_dtype} are not yet integrated "
            "in vLLM. Simulated weight dequantization and activation "
            "QDQ (quantize and dequantize) will be used, with the linear "
            "layers computed in high precision."
        )
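
A hedged constructor sketch of the logic above (the import path is an assumption and the exact Quark spec keys beyond "dtype" and "is_dynamic" may differ): with an "fp4" weight spec and a dynamic "fp4" input spec, the derived attributes come out as follows.

from vllm.model_executor.layers.quantization.quark.schemes import QuarkOCP_MX

scheme = QuarkOCP_MX(
    weight_quant_spec={"dtype": "fp4", "is_dynamic": False},
    input_quant_spec={"dtype": "fp4", "is_dynamic": True},
)
print(scheme.weight_dtype)   # "mxfp4"
print(scheme.packed_factor)  # 2 -> two FP4 values per uint8 byte
print(scheme.emulate)        # True unless the platform supports native MX compute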

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.emulate:
        dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
        qdq_x = self.quant_dequant_func(x)
        return F.linear(qdq_x, dq_w, bias)
    else:
        return torch.ops.vllm.gemm_with_dynamic_quant(
            x,
            layer.weight,
            layer.weight_scale,
            self.rocm_use_aiter_fp4_asm_gemm,
            self.out_dtype,
        )
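
In the emulated path above, dequant_func expands the packed MXFP4/MXFP6 weight to x.dtype before a plain F.linear. For intuition, here is a reference MXFP4 dequantization sketch: each uint8 byte holds two E2M1 values and every block of 32 values shares one E8M0 scale (2 ** (e - 127)). This is an illustrative stand-in, not vLLM's dequant_mxfp4 kernel; the nibble order and scale layout are assumptions.

import torch

# E2M1 (FP4) code points indexed by the 4-bit code 0..15:
# codes 0-7 encode +{0, 0.5, 1, 1.5, 2, 3, 4, 6}, codes 8-15 their negatives.
FP4_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequant_mxfp4_reference(packed: torch.Tensor, scale_e8m0: torch.Tensor,
                            out_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Reference MXFP4 dequant: one shared E8M0 scale per block of 32 values."""
    # Unpack two FP4 codes per uint8 byte (low nibble first is an assumption).
    lo, hi = packed & 0x0F, packed >> 4
    codes = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)
    values = FP4_E2M1_LUT.to(out_dtype)[codes.long()]
    # E8M0 scales are stored as biased exponents: scale = 2 ** (e - 127).
    scales = torch.exp2(scale_e8m0.to(out_dtype) - 127.0)
    blocked = values.reshape(*values.shape[:-1], -1, 32) * scales.unsqueeze(-1)
    return blocked.reshape(*packed.shape[:-1], -1)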

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
def create_weights(
    self,
    layer: torch.nn.Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    output_size_per_partition = sum(output_partition_sizes)
    layer.logical_widths = output_partition_sizes

    # WEIGHT
    weight = PackedvLLMParameter(
        data=torch.empty(
            output_size_per_partition,
            self.get_packed_dim(input_size_per_partition, self.weight_dtype),
            dtype=torch.uint8,
        ),
        input_dim=1,
        output_dim=0,
        packed_dim=1,
        packed_factor=self.packed_factor,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    weight_scale = GroupQuantScaleParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // OCP_MX_BLOCK_SIZE,
            dtype=torch.uint8,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_scale", weight_scale)
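
To make the registered shapes concrete: with one output partition of 4096, input_size_per_partition=4096, mxfp4 weights, and assuming the OCP MX group size OCP_MX_BLOCK_SIZE == 32, the parameters come out as below (a worked sketch, not vLLM code).

output_size, input_size = 4096, 4096

# Two FP4 values per uint8 byte halves the packed input dimension.
weight_shape = (output_size, input_size // 2)         # (4096, 2048), torch.uint8

# One E8M0 scale (stored as uint8) per 32-element group along the input dim.
weight_scale_shape = (output_size, input_size // 32)  # (4096, 128), torch.uint8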

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
@classmethod
def get_min_capability(cls) -> int:
    return 70

get_packed_dim

get_packed_dim(dim: int, quant_dtype: str)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
def get_packed_dim(self, dim: int, quant_dtype: str):
    if quant_dtype == "mxfp4":
        assert dim % 2 == 0
        return dim // 2
    elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
        # FP6 packs 4 * 6 = 24 bits into 3 bytes.
        assert (dim * 3) % 4 == 0
        return (dim * 3) // 4
    else:
        raise NotImplementedError(
            "Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
            f"got quant_dtype={quant_dtype}. Something is wrong, please "
            "open an issue."
        )
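
A quick check of the packing arithmetic for a 4096-wide dimension:

# mxfp4: two 4-bit values per uint8 byte.
assert 4096 // 2 == 2048
# mxfp6: four 6-bit values fit in three uint8 bytes (24 bits), i.e. dim * 3 / 4 bytes.
assert (4096 * 3) // 4 == 3072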

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

    if self.emulate:
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data, requires_grad=False
        )
    else:
        if self.rocm_use_aiter_fp4_asm_gemm:
            # shuffle weight scale
            weight_scale_shuffle = layer.weight_scale.data
            sm, sn = weight_scale_shuffle.shape
            weight_scale_shuffle = weight_scale_shuffle.view(
                sm // 32, 2, 16, sn // 8, 2, 4, 1
            )
            weight_scale_shuffle = weight_scale_shuffle.permute(
                0, 3, 5, 2, 4, 1, 6
            ).contiguous()
            weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
            layer.weight_scale = torch.nn.Parameter(
                weight_scale_shuffle, requires_grad=False
            )

            # shuffle weight
            weight_shuffle = layer.weight.data
            weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16))
            layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False)
        else:
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data.T.contiguous(), requires_grad=False
            )
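
The AITER scale shuffle above is a pure layout permutation. A quick self-contained check (a sketch, not vLLM code) that it only reorders elements of the (sm, sn) scale tensor:

import torch

sm, sn = 64, 16  # any multiples of 32 and 8 work
scale = torch.arange(sm * sn, dtype=torch.int32).reshape(sm, sn)
shuffled = (
    scale.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
    .permute(0, 3, 5, 2, 4, 1, 6)
    .contiguous()
    .view(sm, sn)
)
# Same multiset of values before and after the shuffle.
assert torch.equal(shuffled.flatten().sort().values, scale.flatten())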

QuarkScheme

Bases: ABC

Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by Quark.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
class QuarkScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by Quark.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function
        include the layer and the parameters needed to construct and load
        its quantized weights; see the concrete schemes for the signature.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter

        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
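
The four methods above define the full contract a scheme must implement: register parameters, repack after checkpoint loading, and run the forward pass. A hypothetical, minimal subclass illustrating that contract (not a real vLLM scheme — it keeps an unquantized weight and runs a plain linear, whereas real schemes register vLLM parameter classes with weight loaders and dispatch to quantized kernels):

from typing import Optional

import torch
import torch.nn.functional as F

class PassthroughScheme(QuarkScheme):
    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def create_weights(self, layer, output_partition_sizes, input_size_per_partition,
                       params_dtype, weight_loader, **kwargs):
        weight = torch.nn.Parameter(
            torch.empty(sum(output_partition_sizes), input_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer):
        pass  # nothing to repack for an unquantized weight

    def apply_weights(self, layer, x, bias: Optional[torch.Tensor] = None):
        return F.linear(x, layer.weight, bias)

# Toy usage of the lifecycle: create, "load", post-process, apply.
layer = torch.nn.Module()
scheme = PassthroughScheme()
scheme.create_weights(layer, output_partition_sizes=[8], input_size_per_partition=4,
                      params_dtype=torch.float32, weight_loader=None)
layer.weight.data.normal_()  # stand-in for checkpoint loading
scheme.process_weights_after_loading(layer)
y = scheme.apply_weights(layer, torch.randn(2, 4))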

apply_weights abstractmethod

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor]
)

Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied.

:param layer: torch.nn.Module with the registered weights and other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@abstractmethod
def apply_weights(
    self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
):
    """
    Run the forward pass for the particular scheme. This is where
    scheme-specific dequant/quant steps/kernels should be applied.

    :param layer: torch.nn.Module with the registered weights and
        other parameters relevant to the particular scheme.
    :param x: input to the layer
    :param bias: bias parameter

    """
    raise NotImplementedError

create_weights abstractmethod

create_weights(*args, **kwargs)

Weight creation for the particular scheme. Inputs to this function include the layer and the parameters needed to construct and load its quantized weights; see the concrete schemes for the signature.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@abstractmethod
def create_weights(self, *args, **kwargs):
    """
    Weight creation for the particular scheme. Inputs to this function
    include the layer and the parameters needed to construct and load
    its quantized weights; see the concrete schemes for the signature.
    """
    raise NotImplementedError

get_min_capability abstractmethod classmethod

get_min_capability() -> int

Get minimum device capability.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
    """
    Get minimum device capability.
    """
    raise NotImplementedError

process_weights_after_loading abstractmethod

process_weights_after_loading(layer: Module)

Called after weight loading is complete for any cleanup that needs to occur.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
    """
    Called after weight loading is complete for any cleanup that
    needs to occur.
    """
    raise NotImplementedError

QuarkW8A8Fp8

Bases: QuarkScheme

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
class QuarkW8A8Fp8(QuarkScheme):
    def __init__(
        self, weight_config: dict[str, Any], input_config: Optional[dict[str, Any]]
    ):
        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
        self.is_static_input_scheme: bool = False
        self.input_qscheme: Optional[str] = None
        if input_config is not None:
            self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
            self.input_qscheme = cast(str, input_config.get("qscheme"))

        per_token = (
            not self.is_static_input_scheme and self.input_qscheme == "per_channel"
        )
        self.act_quant_group_shape = (
            GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
        )
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.is_static_input_scheme,
            act_quant_group_shape=self.act_quant_group_shape,
        )
        self.out_dtype = torch.get_default_dtype()

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.weight_qscheme == "per_tensor":
            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)
                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                max_w_scale = layer.weight_scale
                weight = layer.weight

            max_w_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=max_w_scale,
                logical_widths=layer.logical_widths,
            )

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
        elif self.weight_qscheme == "per_channel":
            weight = layer.weight

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)
                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data
            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                weight_scale = weight_scale.view(-1, 1)
            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
            raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}")

        # INPUT SCALE
        if self.is_static_input_scheme:
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.weight_qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.weight_qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

act_quant_group_shape instance-attribute

act_quant_group_shape = (
    GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
)

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=is_static_input_scheme,
    act_quant_group_shape=act_quant_group_shape,
)

input_qscheme instance-attribute

input_qscheme: Optional[str] = None

is_static_input_scheme instance-attribute

is_static_input_scheme: bool = False

out_dtype instance-attribute

out_dtype = torch.get_default_dtype()

weight_qscheme instance-attribute

weight_qscheme = cast(str, weight_config.get("qscheme"))

__init__

__init__(
    weight_config: dict[str, Any],
    input_config: Optional[dict[str, Any]],
)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
def __init__(
    self, weight_config: dict[str, Any], input_config: Optional[dict[str, Any]]
):
    self.weight_qscheme = cast(str, weight_config.get("qscheme"))
    self.is_static_input_scheme: bool = False
    self.input_qscheme: Optional[str] = None
    if input_config is not None:
        self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
        self.input_qscheme = cast(str, input_config.get("qscheme"))

    per_token = (
        not self.is_static_input_scheme and self.input_qscheme == "per_channel"
    )
    self.act_quant_group_shape = (
        GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
    )
    self.fp8_linear = Fp8LinearOp(
        act_quant_static=self.is_static_input_scheme,
        act_quant_group_shape=self.act_quant_group_shape,
    )
    self.out_dtype = torch.get_default_dtype()
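
A sketch of how Quark config fragments map to the activation-quant granularity chosen in __init__ above. The import paths are assumptions and may differ by vLLM version; the config dicts are minimal illustrative fragments.

from vllm.model_executor.layers.quantization.quark.schemes import QuarkW8A8Fp8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

weight_cfg = {"qscheme": "per_channel"}

# Dynamic activation quant with a per_channel input qscheme -> per-token scales.
scheme = QuarkW8A8Fp8(weight_cfg, {"qscheme": "per_channel", "is_dynamic": True})
assert scheme.act_quant_group_shape == GroupShape.PER_TOKEN

# Static activation quant -> per-tensor scales read from the checkpoint.
scheme = QuarkW8A8Fp8(weight_cfg, {"qscheme": "per_tensor", "is_dynamic": False})
assert scheme.is_static_input_scheme
assert scheme.act_quant_group_shape == GroupShape.PER_TENSOR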

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return self.fp8_linear.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        out_dtype=self.out_dtype,
        input_scale=layer.input_scale,
        bias=bias,
    )

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    output_size_per_partition = sum(output_partition_sizes)
    layer.logical_widths = output_partition_sizes

    # WEIGHT
    weight = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    # TODO: update create_xxx_parameter functions to return
    # the newly added parameters
    if self.weight_qscheme == "per_channel":
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
    else:
        assert self.weight_qscheme == "per_tensor"
        weight_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

    # min requirement for fp8 kernels
    weight_scale[:] = torch.finfo(torch.float32).min
    layer.register_parameter("weight_scale", weight_scale)

    # INPUT SCALE
    if self.is_static_input_scheme:
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        input_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", input_scale)

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@classmethod
def get_min_capability(cls) -> int:
    # lovelace and up
    return 89

process_weights_after_loading

process_weights_after_loading(layer) -> None
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
def process_weights_after_loading(self, layer) -> None:
    # If per tensor, when we have a fused module (e.g. QKV) with per
    # tensor scales (thus N scales being passed to the kernel),
    # requantize so we can always run per tensor
    if self.weight_qscheme == "per_tensor":
        if current_platform.is_fp8_fnuz():
            input_scale = getattr(layer, "input_scale", None)
            weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=input_scale,
            )
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale, requires_grad=False)
        else:
            max_w_scale = layer.weight_scale
            weight = layer.weight

        max_w_scale, weight = requantize_with_max_scale(
            weight=weight,
            weight_scale=max_w_scale,
            logical_widths=layer.logical_widths,
        )

        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

    # If channelwise, scales are already lined up, so just transpose.
    elif self.weight_qscheme == "per_channel":
        weight = layer.weight

        if current_platform.is_fp8_fnuz():
            input_scale = getattr(layer, "input_scale", None)
            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                weight=weight,
                weight_scale=layer.weight_scale,
                input_scale=input_scale,
            )
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale, requires_grad=False)
        else:
            weight_scale = layer.weight_scale.data
        if self.act_quant_group_shape == GroupShape.PER_TOKEN:
            weight_scale = weight_scale.view(-1, 1)
        layer.weight = Parameter(weight.t(), requires_grad=False)
        # required by torch.compile to be torch.nn.Parameter
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    else:
        raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}")

    # INPUT SCALE
    if self.is_static_input_scheme:
        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
    else:
        layer.input_scale = None
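
For intuition about the per_tensor branch above: a fused module (e.g. QKV) is checkpointed with one per-tensor scale per shard, and requantize_with_max_scale rescales every shard onto the single largest scale so the kernel sees one scale. A reference sketch of that idea (not vLLM's requantize_with_max_scale; float32 stands in for fp8 storage, and the real helper also handles fp8 conversion and clamping):

import torch

def requantize_with_max_scale_ref(weight, weight_scale, logical_widths):
    max_scale = weight_scale.max()
    start = 0
    for shard_scale, width in zip(weight_scale, logical_widths):
        end = start + width
        # Dequantize with the shard's own scale, requantize with the shared max,
        # so that new_q * max_scale ~= old_q * shard_scale.
        weight[start:end] = weight[start:end] * (shard_scale / max_scale)
        start = end
    return max_scale, weight

w = torch.randn(6, 4)                    # fused weight: three shards of 2 rows each
scales = torch.tensor([0.5, 1.0, 2.0])
max_scale, w = requantize_with_max_scale_ref(w, scales, [2, 2, 2])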

QuarkW8A8Int8

Bases: QuarkScheme

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
class QuarkW8A8Int8(QuarkScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        qscheme: str,
        is_static_input_scheme: Optional[bool],
        input_symmetric: Optional[bool],
    ):
        self.qscheme = qscheme
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        layer.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.qscheme == "per_channel"),
            is_static_input_scheme=(self.is_static_input_scheme is True),
            input_symmetric=(self.input_symmetric is True),
        )

        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
            ChannelQuantZPParameter = ChannelQuantScaleParameter
            weight_zero_point = ChannelQuantZPParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            PerTensorZPParameter = PerTensorScaleParameter
            weight_zero_point = PerTensorZPParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.int8),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_zero_point", weight_zero_point)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)

            input_zero_point = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
            )
            layer.register_parameter("input_zero_point", input_zero_point)

        self.kernel = kernel_type(
            c=scaled_mm_linear_kernel_config,
            w_q_param_name="weight",
            w_s_param_name="weight_scale",
            i_s_param_name="input_scale",
            i_zp_param_name="input_zero_point",
            azp_adj_param_name="azp_adj",
        )

    # Checkpoints are serialized in quark format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.register_parameter("weight_zero_point", None)
        delattr(layer, "weight_zero_point")
        if self.input_symmetric:
            layer.register_parameter("input_zero_point", None)
            delattr(layer, "input_zero_point")

        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)

_kernel_backends_being_used class-attribute instance-attribute

_kernel_backends_being_used: set[str] = set()

input_symmetric instance-attribute

input_symmetric = input_symmetric

is_static_input_scheme instance-attribute

is_static_input_scheme = is_static_input_scheme

qscheme instance-attribute

qscheme = qscheme

__init__

__init__(
    qscheme: str,
    is_static_input_scheme: Optional[bool],
    input_symmetric: Optional[bool],
)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
def __init__(
    self,
    qscheme: str,
    is_static_input_scheme: Optional[bool],
    input_symmetric: Optional[bool],
):
    self.qscheme = qscheme
    self.is_static_input_scheme = is_static_input_scheme
    self.input_symmetric = input_symmetric

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor]
) -> Tensor
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
def apply_weights(
    self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    return self.kernel.apply_weights(layer, x, bias)

create_weights

create_weights(
    layer: Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: dtype,
    weight_loader: Callable,
    **kwargs,
)
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    layer.logical_widths = output_partition_sizes

    scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
        is_channelwise=(self.qscheme == "per_channel"),
        is_static_input_scheme=(self.is_static_input_scheme is True),
        input_symmetric=(self.input_symmetric is True),
    )

    kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)

    if kernel_type.__name__ not in self._kernel_backends_being_used:
        logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
        self._kernel_backends_being_used.add(kernel_type.__name__)

    # WEIGHT
    weight = ModelWeightParameter(
        data=torch.empty(
            sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )

    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    if self.qscheme == "per_channel":
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        ChannelQuantZPParameter = ChannelQuantScaleParameter
        weight_zero_point = ChannelQuantZPParameter(
            data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8),
            output_dim=0,
            weight_loader=weight_loader,
        )
    else:
        assert self.qscheme == "per_tensor"
        weight_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        PerTensorZPParameter = PerTensorScaleParameter
        weight_zero_point = PerTensorZPParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.int8),
            weight_loader=weight_loader,
        )
    layer.register_parameter("weight_scale", weight_scale)
    layer.register_parameter("weight_zero_point", weight_zero_point)

    # INPUT SCALE
    if self.is_static_input_scheme:
        input_scale = BasevLLMParameter(
            data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
        )
        layer.register_parameter("input_scale", input_scale)

        input_zero_point = BasevLLMParameter(
            data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
        )
        layer.register_parameter("input_zero_point", input_zero_point)

    self.kernel = kernel_type(
        c=scaled_mm_linear_kernel_config,
        w_q_param_name="weight",
        w_s_param_name="weight_scale",
        i_s_param_name="input_scale",
        i_zp_param_name="input_zero_point",
        azp_adj_param_name="azp_adj",
    )
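
The per_channel weight_scale registered above corresponds to symmetric int8 quantization with one scale per output channel. A rough illustration of what those scales encode (Quark produces the actual quantized checkpoints offline; this is not vLLM code):

import torch

def quantize_int8_per_channel(weight_fp: torch.Tensor):
    # One scale per output row: scale = max(|row|) / 127.
    scale = weight_fp.abs().amax(dim=1) / 127.0
    q = torch.clamp(torch.round(weight_fp / scale[:, None]), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(8, 16)
q, scale = quantize_int8_per_channel(w)
w_hat = q.float() * scale[:, None]  # dequantized approximation
assert (w - w_hat).abs().max() <= scale.max() / 2 + 1e-6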

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
@classmethod
def get_min_capability(cls) -> int:
    # turing and up
    return 75

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    layer.register_parameter("weight_zero_point", None)
    delattr(layer, "weight_zero_point")
    if self.input_symmetric:
        layer.register_parameter("input_zero_point", None)
        delattr(layer, "input_zero_point")

    self.kernel.process_weights_after_loading(layer)